batch_mode_batch_schema/
batch_choice.rs

1// ---------------- [ File: src/batch_choice.rs ]
2crate::ix!();
3
4#[derive(Debug,Serialize)]
5#[serde(deny_unknown_fields)]
6pub struct BatchChoice {
7    index:         u32,
8    message:       BatchMessage,
9    logprobs:      Option<serde_json::Value>,
10    finish_reason: FinishReason,
11}
12
13impl<'de> Deserialize<'de> for BatchChoice {
14
15    fn deserialize<D>(deserializer: D) -> Result<BatchChoice, D::Error>
16    where
17        D: Deserializer<'de>,
18    {
19        #[derive(Deserialize)]
20        struct BatchChoiceHelper {
21            index:         u32,
22            message:       BatchMessage,
23            logprobs:      Option<serde_json::Value>,
24            finish_reason: FinishReason,
25        }
26
27        let helper = BatchChoiceHelper::deserialize(deserializer)?;
28
29        // Validate `logprobs`
30        if let Some(ref logprobs) = helper.logprobs {
31            if !logprobs.is_object() && !logprobs.is_null() {
32                return Err(de::Error::custom("`logprobs` must be an object or null"));
33            }
34        }
35
36        Ok(BatchChoice {
37            index: helper.index,
38            message: helper.message,
39            logprobs: helper.logprobs,
40            finish_reason: helper.finish_reason,
41        })
42    }
43}
44
45impl BatchChoice {
46    pub fn index(&self) -> u32 {
47        self.index
48    }
49
50    pub fn message(&self) -> &BatchMessage {
51        &self.message
52    }
53
54    pub fn finish_reason(&self) -> &FinishReason {
55        &self.finish_reason
56    }
57}
58
#[cfg(test)]
mod batch_choice_tests {
    use super::*;

    /// Parses `json` as a [`BatchChoice`], keeping the error for negative cases.
    fn parse(json: &str) -> Result<BatchChoice, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_invalid_index_type() {
        let json = r#"{
            "index": "invalid_index",
            "message": {
                "role": "assistant",
                "content": "Invalid index type.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        assert!(parse(json).is_err());
    }

    #[test]
    fn test_batch_choice_deserialization() {
        // Choice with all fields present
        let json = r#"{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "This is the assistant's response.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        let choice = parse(json).expect("fully populated choice should parse");
        assert_eq!(choice.index(), 0);
        assert_eq!(choice.message().role(), &MessageRole::Assistant);
        assert_eq!(choice.finish_reason(), &FinishReason::Stop);
        assert!(choice.logprobs.is_none());
    }

    #[test]
    fn test_logprobs_object_is_kept() {
        let json = r#"{
            "index": 1,
            "message": {
                "role": "assistant",
                "content": "Another response.",
                "refusal": null
            },
            "logprobs": {
                "tokens": ["This", "is", "a", "test"],
                "token_logprobs": [-0.1, -0.2, -0.3, -0.4]
            },
            "finish_reason": "length"
        }"#;

        let choice = parse(json).expect("object logprobs should parse");
        assert!(choice.logprobs.is_some());
        assert_eq!(choice.finish_reason(), &FinishReason::Length);
    }

    #[test]
    fn test_unknown_finish_reason() {
        let json = r#"{
            "index": 2,
            "message": {
                "role": "assistant",
                "content": "Response with unknown finish reason.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "unknown_reason"
        }"#;

        let choice = parse(json).expect("unknown finish_reason should parse");
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("unknown_reason".to_string())
        );
    }

    #[test]
    fn test_missing_finish_reason() {
        let json = r#"{
            "index": 3,
            "message": {
                "role": "assistant",
                "content": "Response without finish_reason.",
                "refusal": null
            },
            "logprobs": null
        }"#;

        let choice = parse(json).expect("missing finish_reason should parse");
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_missing_optional_fields() {
        let json = r#"{
            "index": 4,
            "message": {
                "role": "assistant",
                "content": "Response with minimal fields."
            }
        }"#;

        let choice = parse(json).expect("minimal choice should parse");
        assert_eq!(choice.index(), 4);
        assert_eq!(choice.message().content(), "Response with minimal fields.");
        assert!(choice.logprobs.is_none());
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_missing_message_field() {
        let json = r#"{
            "index": 5,
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        assert!(parse(json).is_err());
    }

    #[test]
    fn test_invalid_logprobs_type() {
        let json = r#"{
            "index": 6,
            "message": {
                "role": "assistant",
                "content": "Invalid logprobs.",
                "refusal": null
            },
            "logprobs": "invalid_logprobs",
            "finish_reason": "stop"
        }"#;

        let err = parse(json).expect_err("string logprobs must be rejected");
        // Pin the error to the custom `logprobs` check, not some other failure.
        assert!(
            err.to_string().contains("`logprobs` must be an object or null"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_logprobs_array_rejected() {
        let json = r#"{
            "index": 7,
            "message": {
                "role": "assistant",
                "content": "Array logprobs.",
                "refusal": null
            },
            "logprobs": [-0.1, -0.2],
            "finish_reason": "stop"
        }"#;

        let err = parse(json).expect_err("array logprobs must be rejected");
        assert!(
            err.to_string().contains("`logprobs` must be an object or null"),
            "unexpected error: {}",
            err
        );
    }
}