// ---------------- [ File: batch_mode_batch_scribe/src/language_model_request_body.rs ]
crate::ix!();
3
/// Request body sent to a language-model chat-completion endpoint.
///
/// Carries the target model, the ordered conversation messages, and the
/// completion-token budget. Read access to all fields is exposed via
/// getset-generated `pub` getters (`model()`, `messages()`,
/// `max_completion_tokens()`).
#[derive(Getters,Setters,Clone,Debug, Serialize, Deserialize)]
#[getset(get="pub")]
pub struct LanguageModelRequestBody {

    /// Target model identifier.
    /// Serialized through the `model_type` serde module — presumably to map
    /// the enum to/from the provider's wire-format model string; confirm in
    /// that module's definition.
    #[serde(with = "model_type")]
    model: LanguageModelType,

    /// Conversation messages in send order (system message first by
    /// convention of the constructors below).
    messages: Vec<LanguageModelMessage>,

    /// Upper bound on the number of tokens the model may generate.
    max_completion_tokens: u32,
}
19
20impl LanguageModelRequestBody {
21
22 pub fn mock() -> Self {
23 LanguageModelRequestBody {
24 model: LanguageModelType::Gpt4o,
25 messages: vec![],
26 max_completion_tokens: 128,
27 }
28 }
29
30 pub fn default_max_tokens() -> u32 {
31 8192
33 }
34
35 pub fn default_max_tokens_given_image(_image_b64: &str) -> u32 {
36 2048
38 }
39
40 pub fn new_basic(model: LanguageModelType, system_message: &str, user_message: &str) -> Self {
41 Self {
42 model,
43 messages: vec![
44 LanguageModelMessage::system_message(system_message),
45 LanguageModelMessage::user_message(user_message),
46 ],
47 max_completion_tokens: Self::default_max_tokens(),
48 }
49 }
50
51 pub fn new_with_image(model: LanguageModelType, system_message: &str, user_message: &str, image_b64: &str) -> Self {
52 Self {
53 model,
54 messages: vec![
55 LanguageModelMessage::system_message(system_message),
56 LanguageModelMessage::user_message_with_image(user_message,image_b64),
57 ],
58 max_completion_tokens: Self::default_max_tokens_given_image(image_b64),
59 }
60 }
61}
62
#[cfg(test)]
mod language_model_request_body_exhaustive_tests {
    use super::*;

    // NOTE(review): these tests previously mixed direct private-field access
    // (possible only because this child module shares the parent's privacy
    // scope) with the getset-generated getters. They now uniformly use the
    // public getter API, so the tests exercise the same surface external
    // callers see.

    #[traced_test]
    fn mock_produces_gpt4o_empty_messages_128_tokens() {
        trace!("===== BEGIN TEST: mock_produces_gpt4o_empty_messages_128_tokens =====");
        let body = LanguageModelRequestBody::mock();
        debug!("Mock body: {:?}", body);

        match body.model() {
            LanguageModelType::Gpt4o => trace!("Correct model: Gpt4o"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }

        assert!(
            body.messages().is_empty(),
            "Mock body should have no messages"
        );
        pretty_assert_eq!(
            *body.max_completion_tokens(), 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_gpt4o_empty_messages_128_tokens =====");
    }

    #[traced_test]
    fn default_max_tokens_returns_8192() {
        trace!("===== BEGIN TEST: default_max_tokens_returns_8192 =====");
        let tokens = LanguageModelRequestBody::default_max_tokens();
        debug!("default_max_tokens: {}", tokens);
        pretty_assert_eq!(tokens, 8192, "default_max_tokens should return 8192");
        trace!("===== END TEST: default_max_tokens_returns_8192 =====");
    }

    #[traced_test]
    fn default_max_tokens_given_image_returns_2048() {
        trace!("===== BEGIN TEST: default_max_tokens_given_image_returns_2048 =====");
        let image_b64 = "fake_base64_image_data";
        let tokens = LanguageModelRequestBody::default_max_tokens_given_image(image_b64);
        debug!("default_max_tokens_given_image: {}", tokens);
        pretty_assert_eq!(
            tokens, 2048,
            "default_max_tokens_given_image should return 2048"
        );
        trace!("===== END TEST: default_max_tokens_given_image_returns_2048 =====");
    }

    #[traced_test]
    fn new_basic_sets_provided_model_and_messages_and_uses_default_tokens() {
        trace!("===== BEGIN TEST: new_basic_sets_provided_model_and_messages_and_uses_default_tokens =====");
        let model = LanguageModelType::Gpt4o;
        let system_message = "System says hello";
        let user_message = "User says hi";
        let body = LanguageModelRequestBody::new_basic(model.clone(), system_message, user_message);
        debug!("Constructed body: {:?}", body);

        match body.model() {
            LanguageModelType::Gpt4o => trace!("Model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        pretty_assert_eq!(body.messages().len(), 2, "Should have 2 messages total");
        match &body.messages()[0].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, system_message, "System message mismatch");
            },
            _ => panic!("Expected text content for system message"),
        }
        match &body.messages()[1].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, user_message, "User message mismatch");
            },
            _ => panic!("Expected text content for user message"),
        }

        pretty_assert_eq!(
            *body.max_completion_tokens(),
            LanguageModelRequestBody::default_max_tokens(),
            "max_completion_tokens should match default"
        );

        trace!("===== END TEST: new_basic_sets_provided_model_and_messages_and_uses_default_tokens =====");
    }

    #[traced_test]
    fn new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens() {
        trace!("===== BEGIN TEST: new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens =====");
        let model = LanguageModelType::Gpt4o;
        let system_message = "System with image instructions";
        let user_message = "User requests image";
        let image_b64 = "fake_image_b64";
        let body = LanguageModelRequestBody::new_with_image(model.clone(), system_message, user_message, image_b64);
        debug!("Constructed body with image: {:?}", body);

        match body.model() {
            LanguageModelType::Gpt4o => trace!("Model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        pretty_assert_eq!(body.messages().len(), 2, "Should have 2 messages total");
        match &body.messages()[0].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, system_message, "System message mismatch");
            },
            _ => panic!("Expected text content for system message"),
        }

        match &body.messages()[1].content() {
            ChatCompletionRequestUserMessageContent::Array(parts) => {
                pretty_assert_eq!(parts.len(), 2, "Expected text + image parts");
            },
            _ => panic!("Expected array content for user message with image"),
        }

        pretty_assert_eq!(
            *body.max_completion_tokens(),
            LanguageModelRequestBody::default_max_tokens_given_image(image_b64),
            "max_completion_tokens should match default for images"
        );

        trace!("===== END TEST: new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens =====");
    }

    #[traced_test]
    fn serialization_and_deserialization_round_trip() {
        trace!("===== BEGIN TEST: serialization_and_deserialization_round_trip =====");
        let original = LanguageModelRequestBody::new_basic(
            LanguageModelType::Gpt4o,
            "System Info",
            "User Query"
        );
        let serialized = serde_json::to_string(&original)
            .expect("Failed to serialize LanguageModelRequestBody");
        debug!("Serialized: {}", serialized);

        let deserialized: LanguageModelRequestBody = serde_json::from_str(&serialized)
            .expect("Failed to deserialize LanguageModelRequestBody");
        debug!("Deserialized: {:?}", deserialized);

        // Compare via Debug text because LanguageModelType is not known here
        // to implement PartialEq.
        pretty_assert_eq!(format!("{:?}", original.model()), format!("{:?}", deserialized.model()));
        pretty_assert_eq!(
            original.messages().len(),
            deserialized.messages().len(),
            "Messages length mismatch after round-trip"
        );
        pretty_assert_eq!(
            *original.max_completion_tokens(),
            *deserialized.max_completion_tokens(),
            "max_completion_tokens mismatch after round-trip"
        );

        trace!("===== END TEST: serialization_and_deserialization_round_trip =====");
    }
}