batch_mode_batch_schema/
batch_choice.rs

1// ---------------- [ File: batch-mode-batch-schema/src/batch_choice.rs ]
2crate::ix!();
3
4#[derive(Getters,Builder,Clone,Debug,Serialize)]
5#[builder(setter(into))]
6#[getset(get="pub")]
7#[serde(deny_unknown_fields)]
8pub struct BatchChoice {
9    index:         u32,
10    message:       BatchMessage,
11    logprobs:      Option<serde_json::Value>,
12    finish_reason: FinishReason,
13}
14
15impl<'de> Deserialize<'de> for BatchChoice {
16
17    fn deserialize<D>(deserializer: D) -> Result<BatchChoice, D::Error>
18    where
19        D: Deserializer<'de>,
20    {
21        #[derive(Deserialize)]
22        struct BatchChoiceHelper {
23            index:         u32,
24            message:       BatchMessage,
25            logprobs:      Option<serde_json::Value>,
26            finish_reason: FinishReason,
27        }
28
29        let helper = BatchChoiceHelper::deserialize(deserializer)?;
30
31        // Validate `logprobs`
32        if let Some(ref logprobs) = helper.logprobs {
33            if !logprobs.is_object() && !logprobs.is_null() {
34                return Err(de::Error::custom("`logprobs` must be an object or null"));
35            }
36        }
37
38        Ok(BatchChoice {
39            index: helper.index,
40            message: helper.message,
41            logprobs: helper.logprobs,
42            finish_reason: helper.finish_reason,
43        })
44    }
45}
46
#[cfg(test)]
mod batch_choice_tests {
    use super::*;

    /// Parses `json` into a `BatchChoice`.
    ///
    /// On failure, the test fails with the serde error and the offending input.
    fn parse_ok(json: &str) -> BatchChoice {
        match serde_json::from_str::<BatchChoice>(json) {
            Ok(choice) => choice,
            Err(err) => panic!(
                "expected a valid BatchChoice, got error: {}\ninput: {}",
                err, json
            ),
        }
    }

    /// Fails the test unless `json` is rejected by `BatchChoice`'s deserializer.
    fn assert_rejected(json: &str) {
        let result = serde_json::from_str::<BatchChoice>(json);
        assert!(
            result.is_err(),
            "expected an error for input {}, got {:?}",
            json,
            result
        );
    }

    #[test]
    fn test_batch_choice_deserialization() {
        // All fields present, `logprobs` explicitly null.
        let choice = parse_ok(
            r#"{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "This is the assistant's response.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#,
        );
        pretty_assert_eq!(*choice.index(), 0);
        pretty_assert_eq!(choice.message().role(), &MessageRole::Assistant);
        pretty_assert_eq!(choice.finish_reason(), &FinishReason::Stop);
        assert!(choice.logprobs().is_none());
    }

    #[test]
    fn test_logprobs_object_is_accepted() {
        let choice = parse_ok(
            r#"{
            "index": 1,
            "message": {
                "role": "assistant",
                "content": "Another response.",
                "refusal": null
            },
            "logprobs": {
                "tokens": ["This", "is", "a", "test"],
                "token_logprobs": [-0.1, -0.2, -0.3, -0.4]
            },
            "finish_reason": "length"
        }"#,
        );
        assert!(matches!(choice.logprobs(), Some(v) if v.is_object()));
        pretty_assert_eq!(choice.finish_reason(), &FinishReason::Length);
    }

    #[test]
    fn test_unknown_finish_reason_is_preserved() {
        let choice = parse_ok(
            r#"{
            "index": 2,
            "message": {
                "role": "assistant",
                "content": "Response with unknown finish reason.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "unknown_reason"
        }"#,
        );
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("unknown_reason".to_string())
        );
    }

    #[test]
    fn test_missing_finish_reason_becomes_unknown_none() {
        let choice = parse_ok(
            r#"{
            "index": 3,
            "message": {
                "role": "assistant",
                "content": "Response without finish_reason.",
                "refusal": null
            },
            "logprobs": null
        }"#,
        );
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_missing_optional_fields() {
        let choice = parse_ok(
            r#"{
            "index": 4,
            "message": {
                "role": "assistant",
                "content": "Response with minimal fields."
            }
        }"#,
        );
        pretty_assert_eq!(*choice.index(), 4);
        pretty_assert_eq!(choice.message().content(), "Response with minimal fields.");
        assert!(choice.logprobs().is_none());
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_invalid_index_type() {
        assert_rejected(
            r#"{
            "index": "invalid_index",
            "message": {
                "role": "assistant",
                "content": "Invalid index type.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#,
        );
    }

    #[test]
    fn test_missing_message_is_rejected() {
        assert_rejected(
            r#"{
            "index": 5,
            "logprobs": null,
            "finish_reason": "stop"
        }"#,
        );
    }

    #[test]
    fn test_non_object_logprobs_is_rejected() {
        assert_rejected(
            r#"{
            "index": 6,
            "message": {
                "role": "assistant",
                "content": "Invalid logprobs.",
                "refusal": null
            },
            "logprobs": "invalid_logprobs",
            "finish_reason": "stop"
        }"#,
        );
    }
}