batch_mode_batch_scribe/
language_model_request_body.rs

1// ---------------- [ File: batch-mode-batch-scribe/src/language_model_request_body.rs ]
2crate::ix!();
3
4/// Body details of the API request.
5#[derive(Getters,Setters,Clone,Debug, Serialize, Deserialize)]
6#[getset(get="pub")]
7pub struct LanguageModelRequestBody {
8
9    /// Model used for the request.
10    #[serde(with = "model_type")]
11    model: LanguageModelType,
12
13    /// Array of messages exchanged in the request.
14    messages: Vec<LanguageModelMessage>,
15
16    /// Maximum number of tokens to be used by the model.
17    max_completion_tokens: u32,
18}
19
20impl LanguageModelRequestBody {
21
22    pub fn mock() -> Self {
23        LanguageModelRequestBody {
24            model:                 LanguageModelType::Gpt4o,
25            messages:              vec![],
26            max_completion_tokens: 128,
27        }
28    }
29
30    pub fn default_max_tokens() -> u32 {
31        //1024 
32        8192
33    }
34
35    pub fn default_max_tokens_given_image(_image_b64: &str) -> u32 {
36        //TODO: is this the right value?
37        2048
38    }
39
40    pub fn new_basic(model: LanguageModelType, system_message: &str, user_message: &str) -> Self {
41        Self {
42            model,
43            messages: vec![
44                LanguageModelMessage::system_message(system_message),
45                LanguageModelMessage::user_message(user_message),
46            ],
47            max_completion_tokens: Self::default_max_tokens(),
48        }
49    }
50
51    pub fn new_with_image(model: LanguageModelType, system_message: &str, user_message: &str, image_b64: &str) -> Self {
52        Self {
53            model,
54            messages: vec![
55                LanguageModelMessage::system_message(system_message),
56                LanguageModelMessage::user_message_with_image(user_message,image_b64),
57            ],
58            max_completion_tokens: Self::default_max_tokens_given_image(image_b64),
59        }
60    }
61}
62
#[cfg(test)]
mod language_model_request_body_exhaustive_tests {
    use super::*;

    /// `mock()` must yield Gpt4o, no messages and a 128-token limit.
    #[traced_test]
    fn mock_produces_gpt4o_empty_messages_128_tokens() {
        trace!("===== BEGIN TEST: mock_produces_gpt4o_empty_messages_128_tokens =====");
        let body = LanguageModelRequestBody::mock();
        debug!("Mock body: {:?}", body);

        match body.model {
            LanguageModelType::Gpt4o => trace!("Correct model: Gpt4o"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }

        assert!(
            body.messages.is_empty(),
            "Mock body should have no messages"
        );
        pretty_assert_eq!(
            body.max_completion_tokens, 128,
            "Mock body should have max_completion_tokens=128"
        );

        trace!("===== END TEST: mock_produces_gpt4o_empty_messages_128_tokens =====");
    }

    /// Pins the text-only default token limit.
    #[traced_test]
    fn default_max_tokens_returns_8192() {
        trace!("===== BEGIN TEST: default_max_tokens_returns_8192 =====");
        let tokens = LanguageModelRequestBody::default_max_tokens();
        debug!("default_max_tokens: {}", tokens);
        pretty_assert_eq!(tokens, 8192, "default_max_tokens should return 8192");
        trace!("===== END TEST: default_max_tokens_returns_8192 =====");
    }

    /// Pins the image-request default token limit.
    #[traced_test]
    fn default_max_tokens_given_image_returns_2048() {
        trace!("===== BEGIN TEST: default_max_tokens_given_image_returns_2048 =====");
        let image_b64 = "fake_base64_image_data";
        let tokens = LanguageModelRequestBody::default_max_tokens_given_image(image_b64);
        debug!("default_max_tokens_given_image: {}", tokens);
        pretty_assert_eq!(
            tokens, 2048,
            "default_max_tokens_given_image should return 2048"
        );
        trace!("===== END TEST: default_max_tokens_given_image_returns_2048 =====");
    }

    /// `new_basic` keeps the model, emits [system, user] text messages and
    /// uses the text-only token default.
    #[traced_test]
    fn new_basic_sets_provided_model_and_messages_and_uses_default_tokens() {
        trace!("===== BEGIN TEST: new_basic_sets_provided_model_and_messages_and_uses_default_tokens =====");
        let model = LanguageModelType::Gpt4o;
        let system_message = "System says hello";
        let user_message = "User says hi";
        // `model` is not used again, so it is moved rather than cloned.
        let body = LanguageModelRequestBody::new_basic(model, system_message, user_message);
        debug!("Constructed body: {:?}", body);

        match body.model {
            LanguageModelType::Gpt4o => trace!("Model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        pretty_assert_eq!(body.messages.len(), 2, "Should have 2 messages total");
        match &body.messages[0].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, system_message, "System message mismatch");
            },
            _ => panic!("Expected text content for system message"),
        }
        match &body.messages[1].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, user_message, "User message mismatch");
            },
            _ => panic!("Expected text content for user message"),
        }

        pretty_assert_eq!(
            *body.max_completion_tokens(),
            LanguageModelRequestBody::default_max_tokens(),
            "max_completion_tokens should match default"
        );

        trace!("===== END TEST: new_basic_sets_provided_model_and_messages_and_uses_default_tokens =====");
    }

    /// `new_with_image` keeps the model, emits a text system message followed
    /// by a two-part (text + image) user message, and uses the image token
    /// default.
    #[traced_test]
    fn new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens() {
        trace!("===== BEGIN TEST: new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens =====");
        let model = LanguageModelType::Gpt4o;
        let system_message = "System with image instructions";
        let user_message = "User requests image";
        let image_b64 = "fake_image_b64";
        // `model` is not used again, so it is moved rather than cloned.
        let body = LanguageModelRequestBody::new_with_image(model, system_message, user_message, image_b64);
        debug!("Constructed body with image: {:?}", body);

        match body.model {
            LanguageModelType::Gpt4o => trace!("Model is Gpt4o as expected"),
            _ => panic!("Expected LanguageModelType::Gpt4o"),
        }
        pretty_assert_eq!(body.messages.len(), 2, "Should have 2 messages total");
        match &body.messages[0].content() {
            ChatCompletionRequestUserMessageContent::Text(text) => {
                pretty_assert_eq!(text, system_message, "System message mismatch");
            },
            _ => panic!("Expected text content for system message"),
        }

        match &body.messages[1].content() {
            ChatCompletionRequestUserMessageContent::Array(parts) => {
                pretty_assert_eq!(parts.len(), 2, "Expected text + image parts");
            },
            _ => panic!("Expected array content for user message with image"),
        }

        pretty_assert_eq!(
            body.max_completion_tokens,
            LanguageModelRequestBody::default_max_tokens_given_image(image_b64),
            "max_completion_tokens should match default for images"
        );

        trace!("===== END TEST: new_with_image_sets_provided_model_and_messages_and_uses_image_default_tokens =====");
    }

    /// Serializing then deserializing must preserve the whole body, including
    /// message contents, not just the message count.
    #[traced_test]
    fn serialization_and_deserialization_round_trip() {
        trace!("===== BEGIN TEST: serialization_and_deserialization_round_trip =====");
        let original = LanguageModelRequestBody::new_basic(
            LanguageModelType::Gpt4o,
            "System Info",
            "User Query"
        );
        let serialized = serde_json::to_string(&original)
            .expect("Failed to serialize LanguageModelRequestBody");
        debug!("Serialized: {}", serialized);

        let deserialized: LanguageModelRequestBody = serde_json::from_str(&serialized)
            .expect("Failed to deserialize LanguageModelRequestBody");
        debug!("Deserialized: {:?}", deserialized);

        // Compare essential fields
        pretty_assert_eq!(format!("{:?}", original.model), format!("{:?}", deserialized.model));
        pretty_assert_eq!(
            original.messages.len(),
            deserialized.messages.len(),
            "Messages length mismatch after round-trip"
        );
        pretty_assert_eq!(
            original.max_completion_tokens,
            deserialized.max_completion_tokens,
            "max_completion_tokens mismatch after round-trip"
        );

        // Full-fidelity check: the round-tripped body must produce the same JSON
        // document as the original. Comparing `serde_json::Value`s means object
        // key order does not matter, but every field and message content does.
        let original_json = serde_json::to_value(&original)
            .expect("Failed to convert original LanguageModelRequestBody to JSON value");
        let round_tripped_json = serde_json::to_value(&deserialized)
            .expect("Failed to convert deserialized LanguageModelRequestBody to JSON value");
        pretty_assert_eq!(
            original_json,
            round_tripped_json,
            "JSON document mismatch after round-trip"
        );

        trace!("===== END TEST: serialization_and_deserialization_round_trip =====");
    }

    /// Pins the top-level JSON shape of the request body: field names come
    /// straight from the struct (no serde renames), so the wire keys are
    /// `model`, `messages` and `max_completion_tokens`.
    #[traced_test]
    fn serialized_json_exposes_expected_top_level_fields() {
        trace!("===== BEGIN TEST: serialized_json_exposes_expected_top_level_fields =====");
        let body = LanguageModelRequestBody::new_basic(
            LanguageModelType::Gpt4o,
            "System Info",
            "User Query"
        );
        let value = serde_json::to_value(&body)
            .expect("Failed to convert LanguageModelRequestBody to JSON value");
        debug!("JSON value: {}", value);

        assert!(value.get("model").is_some(), "Serialized body should contain `model`");
        pretty_assert_eq!(
            value["messages"].as_array().map(Vec::len),
            Some(2),
            "Serialized `messages` should be an array of 2 entries"
        );
        pretty_assert_eq!(
            value["max_completion_tokens"].as_u64(),
            Some(u64::from(LanguageModelRequestBody::default_max_tokens())),
            "Serialized `max_completion_tokens` should equal the text-only default"
        );

        trace!("===== END TEST: serialized_json_exposes_expected_top_level_fields =====");
    }
}