batch_mode_batch_schema/
batch_choice.rs1crate::ix!();
3
4#[derive(Debug,Serialize)]
5#[serde(deny_unknown_fields)]
6pub struct BatchChoice {
7 index: u32,
8 message: BatchMessage,
9 logprobs: Option<serde_json::Value>,
10 finish_reason: FinishReason,
11}
12
13impl<'de> Deserialize<'de> for BatchChoice {
14
15 fn deserialize<D>(deserializer: D) -> Result<BatchChoice, D::Error>
16 where
17 D: Deserializer<'de>,
18 {
19 #[derive(Deserialize)]
20 struct BatchChoiceHelper {
21 index: u32,
22 message: BatchMessage,
23 logprobs: Option<serde_json::Value>,
24 finish_reason: FinishReason,
25 }
26
27 let helper = BatchChoiceHelper::deserialize(deserializer)?;
28
29 if let Some(ref logprobs) = helper.logprobs {
31 if !logprobs.is_object() && !logprobs.is_null() {
32 return Err(de::Error::custom("`logprobs` must be an object or null"));
33 }
34 }
35
36 Ok(BatchChoice {
37 index: helper.index,
38 message: helper.message,
39 logprobs: helper.logprobs,
40 finish_reason: helper.finish_reason,
41 })
42 }
43}
44
45impl BatchChoice {
46 pub fn index(&self) -> u32 {
47 self.index
48 }
49
50 pub fn message(&self) -> &BatchMessage {
51 &self.message
52 }
53
54 pub fn finish_reason(&self) -> &FinishReason {
55 &self.finish_reason
56 }
57}
58
#[cfg(test)]
mod batch_choice_tests {
    use super::*;

    /// Parses `json` into a `BatchChoice`, returning the raw serde_json result
    /// so each test can assert on success or failure as needed.
    fn parse(json: &str) -> Result<BatchChoice, serde_json::Error> {
        serde_json::from_str(json)
    }

    /// A non-numeric `index` must be rejected.
    #[test]
    fn test_invalid_index_type() {
        let json = r#"{
            "index": "invalid_index",
            "message": {
                "role": "assistant",
                "content": "Invalid index type.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        assert!(parse(json).is_err());
    }

    /// A fully populated choice with `logprobs: null` deserializes and exposes
    /// its fields through the accessors.
    #[test]
    fn test_batch_choice_deserialization() {
        let json = r#"{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "This is the assistant's response.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        let choice = parse(json).unwrap();
        assert_eq!(choice.index(), 0);
        assert_eq!(choice.message().role(), &MessageRole::Assistant);
        assert_eq!(choice.finish_reason(), &FinishReason::Stop);
        assert!(choice.logprobs.is_none());
    }

    /// An object-valued `logprobs` is retained, and `"length"` maps to
    /// `FinishReason::Length`.
    #[test]
    fn test_object_logprobs_and_length_finish_reason() {
        let json = r#"{
            "index": 1,
            "message": {
                "role": "assistant",
                "content": "Another response.",
                "refusal": null
            },
            "logprobs": {
                "tokens": ["This", "is", "a", "test"],
                "token_logprobs": [-0.1, -0.2, -0.3, -0.4]
            },
            "finish_reason": "length"
        }"#;

        let choice = parse(json).unwrap();
        assert!(choice.logprobs.is_some());
        assert_eq!(choice.finish_reason(), &FinishReason::Length);
    }

    /// An unrecognized finish reason string is preserved in
    /// `FinishReason::Unknown`.
    #[test]
    fn test_unrecognized_finish_reason_is_preserved() {
        let json = r#"{
            "index": 2,
            "message": {
                "role": "assistant",
                "content": "Response with unknown finish reason.",
                "refusal": null
            },
            "logprobs": null,
            "finish_reason": "unknown_reason"
        }"#;

        let choice = parse(json).unwrap();
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("unknown_reason".to_string())
        );
    }

    /// Omitting `finish_reason` is expected to yield
    /// `FinishReason::Unknown("None")`.
    #[test]
    fn test_missing_finish_reason() {
        let json = r#"{
            "index": 3,
            "message": {
                "role": "assistant",
                "content": "Response without finish_reason.",
                "refusal": null
            },
            "logprobs": null
        }"#;

        let choice = parse(json).unwrap();
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    /// Only `index` and a minimal `message` are supplied; the remaining
    /// fields are expected to fall back to their "absent" representations.
    #[test]
    fn test_minimal_fields() {
        let json = r#"{
            "index": 4,
            "message": {
                "role": "assistant",
                "content": "Response with minimal fields."
            }
        }"#;

        let choice = parse(json).unwrap();
        assert_eq!(choice.index(), 4);
        assert_eq!(choice.message().content(), "Response with minimal fields.");
        assert!(choice.logprobs.is_none());
        assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    /// A choice without a `message` must be rejected.
    #[test]
    fn test_missing_message_is_rejected() {
        let json = r#"{
            "index": 5,
            "logprobs": null,
            "finish_reason": "stop"
        }"#;

        assert!(parse(json).is_err());
    }

    /// A `logprobs` value that is neither an object nor null must be rejected
    /// by the custom validation step.
    #[test]
    fn test_non_object_logprobs_is_rejected() {
        let json = r#"{
            "index": 6,
            "message": {
                "role": "assistant",
                "content": "Invalid logprobs.",
                "refusal": null
            },
            "logprobs": "invalid_logprobs",
            "finish_reason": "stop"
        }"#;

        let error = parse(json).expect_err("string `logprobs` should be rejected");
        assert!(error
            .to_string()
            .contains("`logprobs` must be an object or null"));
    }
}