batch_mode_batch_schema/
batch_choice.rs1crate::ix!();
3
4#[derive(Getters,Builder,Clone,Debug,Serialize)]
5#[builder(setter(into))]
6#[getset(get="pub")]
7#[serde(deny_unknown_fields)]
8pub struct BatchChoice {
9 index: u32,
10 message: BatchMessage,
11 logprobs: Option<serde_json::Value>,
12 finish_reason: FinishReason,
13}
14
15impl<'de> Deserialize<'de> for BatchChoice {
16
17 fn deserialize<D>(deserializer: D) -> Result<BatchChoice, D::Error>
18 where
19 D: Deserializer<'de>,
20 {
21 #[derive(Deserialize)]
22 struct BatchChoiceHelper {
23 index: u32,
24 message: BatchMessage,
25 logprobs: Option<serde_json::Value>,
26 finish_reason: FinishReason,
27 }
28
29 let helper = BatchChoiceHelper::deserialize(deserializer)?;
30
31 if let Some(ref logprobs) = helper.logprobs {
33 if !logprobs.is_object() && !logprobs.is_null() {
34 return Err(de::Error::custom("`logprobs` must be an object or null"));
35 }
36 }
37
38 Ok(BatchChoice {
39 index: helper.index,
40 message: helper.message,
41 logprobs: helper.logprobs,
42 finish_reason: helper.finish_reason,
43 })
44 }
45}
46
#[cfg(test)]
mod batch_choice_tests {
    use super::*;

    /// Parses a fixture that must be a valid `BatchChoice`; panics at the caller's line otherwise.
    #[track_caller]
    fn parse_ok(json: &str) -> BatchChoice {
        serde_json::from_str(json).expect("fixture should deserialize into BatchChoice")
    }

    /// Asserts that a fixture is rejected by `BatchChoice`'s deserializer.
    #[track_caller]
    fn assert_rejected(json: &str) {
        let result: Result<BatchChoice, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "expected a deserialization error, got {:?}",
            result
        );
    }

    #[test]
    fn test_invalid_index_type() {
        assert_rejected(
            r#"{
                "index": "invalid_index",
                "message": {
                    "role": "assistant",
                    "content": "Invalid index type.",
                    "refusal": null
                },
                "logprobs": null,
                "finish_reason": "stop"
            }"#,
        );
    }

    #[test]
    fn test_batch_choice_deserialization() {
        let choice = parse_ok(
            r#"{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "This is the assistant's response.",
                    "refusal": null
                },
                "logprobs": null,
                "finish_reason": "stop"
            }"#,
        );
        pretty_assert_eq!(*choice.index(), 0);
        pretty_assert_eq!(choice.message().role(), &MessageRole::Assistant);
        pretty_assert_eq!(choice.finish_reason(), &FinishReason::Stop);
        assert!(choice.logprobs().is_none());
    }

    #[test]
    fn test_object_logprobs_and_length_finish_reason() {
        let choice = parse_ok(
            r#"{
                "index": 1,
                "message": {
                    "role": "assistant",
                    "content": "Another response.",
                    "refusal": null
                },
                "logprobs": {
                    "tokens": ["This", "is", "a", "test"],
                    "token_logprobs": [-0.1, -0.2, -0.3, -0.4]
                },
                "finish_reason": "length"
            }"#,
        );
        assert!(choice.logprobs().is_some());
        pretty_assert_eq!(choice.finish_reason(), &FinishReason::Length);
    }

    #[test]
    fn test_unrecognized_finish_reason_is_preserved() {
        let choice = parse_ok(
            r#"{
                "index": 2,
                "message": {
                    "role": "assistant",
                    "content": "Response with unknown finish reason.",
                    "refusal": null
                },
                "logprobs": null,
                "finish_reason": "unknown_reason"
            }"#,
        );
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("unknown_reason".to_string())
        );
    }

    #[test]
    fn test_missing_finish_reason_becomes_unknown_none() {
        let choice = parse_ok(
            r#"{
                "index": 3,
                "message": {
                    "role": "assistant",
                    "content": "Response without finish_reason.",
                    "refusal": null
                },
                "logprobs": null
            }"#,
        );
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_minimal_fields() {
        let choice = parse_ok(
            r#"{
                "index": 4,
                "message": {
                    "role": "assistant",
                    "content": "Response with minimal fields."
                }
            }"#,
        );
        pretty_assert_eq!(*choice.index(), 4);
        pretty_assert_eq!(choice.message().content(), "Response with minimal fields.");
        assert!(choice.logprobs().is_none());
        pretty_assert_eq!(
            choice.finish_reason(),
            &FinishReason::Unknown("None".to_string())
        );
    }

    #[test]
    fn test_missing_message_is_rejected() {
        assert_rejected(
            r#"{
                "index": 5,
                "logprobs": null,
                "finish_reason": "stop"
            }"#,
        );
    }

    #[test]
    fn test_string_logprobs_is_rejected() {
        assert_rejected(
            r#"{
                "index": 6,
                "message": {
                    "role": "assistant",
                    "content": "Invalid logprobs.",
                    "refusal": null
                },
                "logprobs": "invalid_logprobs",
                "finish_reason": "stop"
            }"#,
        );
    }

    #[test]
    fn test_array_logprobs_is_rejected() {
        assert_rejected(
            r#"{
                "index": 7,
                "message": {
                    "role": "assistant",
                    "content": "Array logprobs.",
                    "refusal": null
                },
                "logprobs": [-0.1, -0.2],
                "finish_reason": "stop"
            }"#,
        );
    }
}