1use std::{borrow::Cow, collections::HashMap};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
6pub struct RequestBody {
7 pub messages: Vec<Message>, pub model: String,
13
14 #[serde(skip_serializing_if = "Option::is_none")]
18 pub frequency_penalty: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
27 pub logit_bias: Option<HashMap<String, serde_json::Value>>, #[serde(skip_serializing_if = "Option::is_none")]
31 pub logprobs: Option<bool>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub top_logprobs: Option<u8>, #[deprecated(note = "Use max_completion_tokens instead")]
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub max_tokens: Option<u32>,
46
47 #[serde(skip_serializing_if = "Option::is_none")]
49 pub max_completion_tokens: Option<u32>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
56 pub reasoning_effort: Option<ReasoningEffort>,
57
58 #[serde(skip_serializing_if = "Option::is_none")]
60 pub n: Option<u8>, #[serde(skip_serializing_if = "Option::is_none")]
66 pub presence_penalty: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
74 pub response_format: Option<ResponseFormat>,
75
76 #[serde(skip_serializing_if = "Option::is_none")]
81 pub seed: Option<i64>,
82
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub stop: Option<Stop>,
86
87 #[serde(skip_serializing_if = "Option::is_none")]
91 pub stream: Option<bool>,
92
93 #[serde(skip_serializing_if = "Option::is_none")]
98 pub temperature: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
106 pub top_p: Option<f32>, #[serde(skip_serializing_if = "Option::is_none")]
111 pub tools: Option<Vec<Tool>>,
112
113 #[serde(skip_serializing_if = "Option::is_none")]
114 pub tool_choice: Option<ToolChoice>,
115
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub user: Option<String>,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
122 pub store: Option<bool>, #[serde(skip_serializing_if = "Option::is_none")]
126 pub metadata: Option<HashMap<String, serde_json::Value>>,
127
128 #[serde(skip_serializing_if = "Option::is_none")]
130 pub parallel_tool_calls: Option<bool>, #[serde(skip_serializing_if = "Option::is_none")]
135 pub modalities: Option<Vec<String>>,
136
137 #[serde(skip_serializing_if = "Option::is_none")]
140 pub prediction: Option<PredictionConfig>,
141
142 #[serde(skip_serializing_if = "Option::is_none")]
144 pub audio: Option<AudioConfig>,
145
146 #[serde(skip_serializing_if = "Option::is_none")]
149 pub service_tier: Option<String>, #[serde(skip_serializing_if = "Option::is_none")]
153 pub stream_options: Option<StreamOptions>,
154
155 #[serde(skip_serializing_if = "Option::is_none")]
157 pub web_search_options: Option<WebSearchOptions>,
158
159 #[serde(skip_serializing_if = "Option::is_none")]
162 pub reasoning: Option<OpenRouterReasoning>,
163
164 #[serde(skip_serializing_if = "Option::is_none")]
167 pub safety_identifier: Option<String>,
168
169 #[serde(skip_serializing_if = "Option::is_none")]
172 pub prompt_cache_key: Option<String>,
173
174 #[serde(skip_serializing_if = "Option::is_none")]
178 pub truncation: Option<Truncation>,
179
180 #[serde(skip_serializing_if = "Option::is_none")]
184 pub verbosity: Option<Verbosity>,
185
186 #[serde(skip_serializing_if = "Option::is_none")]
188 pub text: Option<serde_json::Value>,
189}
190
191impl RequestBody {
192 pub fn first_user_message(&self) -> Option<&Message> {
193 self.messages
194 .iter()
195 .find(|message| matches!(message, Message::User(_)))
196 }
197
198 pub fn first_user_message_text(&self) -> Option<String> {
199 self.messages
200 .iter()
201 .find(|message| matches!(message, Message::User(_)))
202 .and_then(|message| match message {
203 Message::User(user_message) => match &user_message.content {
204 Content::Text(text) => Some(text.clone()),
205 Content::Array(array) => Some(
206 array
207 .iter()
208 .filter_map(|content_part| match content_part {
209 ContentPart::Text(text_part) => Some(text_part.text.clone()),
210 _ => None,
211 })
212 .collect::<Vec<String>>()
213 .join(""),
214 ),
215 },
216 _ => None,
217 })
218 }
219
220 pub fn first_system_message(&self) -> Option<&Message> {
221 self.messages
222 .iter()
223 .find(|message| matches!(message, Message::System(_)))
224 }
225
226 pub fn first_system_message_text(&self) -> Option<String> {
227 self.first_system_message()
228 .and_then(|message| match message {
229 Message::System(system_message) => Some(system_message.content.clone()),
230 _ => None,
231 })
232 }
233
234 pub fn last_user_message(&self) -> Option<&Message> {
235 self.messages
236 .iter()
237 .rev()
238 .find(|message| matches!(message, Message::User(_)))
239 }
240
241 pub fn last_user_message_text(&self) -> Option<String> {
242 self.messages
243 .iter()
244 .rev()
245 .find(|message| matches!(message, Message::User(_)))
246 .and_then(|message| match message {
247 Message::User(user_message) => match &user_message.content {
248 Content::Text(text) => Some(text.clone()),
249 Content::Array(array) => Some(
250 array
251 .iter()
252 .filter_map(|content_part| match content_part {
253 ContentPart::Text(text_part) => Some(text_part.text.clone()),
254 _ => None,
255 })
256 .collect::<Vec<String>>()
257 .join(" "),
258 ),
259 },
260 _ => None,
261 })
262 }
263
264 pub fn last_system_message(&self) -> Option<&Message> {
265 self.messages
266 .iter()
267 .rev()
268 .find(|message| matches!(message, Message::System(_)))
269 }
270
271 pub fn last_system_message_text(&self) -> Option<String> {
272 self.last_system_message()
273 .and_then(|message| match message {
274 Message::System(system_message) => Some(system_message.content.clone()),
275 _ => None,
276 })
277 }
278}
279
280pub struct RequestBodyBuilder {
281 inner: RequestBody,
282}
283
284impl RequestBodyBuilder {
285 pub fn new() -> Self {
286 RequestBodyBuilder {
287 inner: RequestBody::default(),
288 }
289 }
290
291 pub fn model(mut self, model: impl Into<String>) -> Self {
292 self.inner.model = model.into();
293 self
294 }
295
296 pub fn messages(mut self, messages: Vec<Message>) -> Self {
297 self.inner.messages = messages;
298 self
299 }
300
301 pub fn push_user_message(mut self, message: impl Into<String>) -> Self {
302 self.inner.messages.push(Message::User(UserMessage {
303 content: Content::Text(message.into()),
304 name: None,
305 }));
306 self
307 }
308
309 pub fn push_system_message(mut self, message: impl Into<String>) -> Self {
310 self.inner.messages.push(Message::System(SystemMessage {
311 content: message.into(),
312 ..Default::default()
313 }));
314 self
315 }
316
317 pub fn prepend_system_message(mut self, message: impl Into<String>) -> Self {
318 self.inner.messages.insert(
319 0,
320 Message::System(SystemMessage {
321 content: message.into(),
322 ..Default::default()
323 }),
324 );
325 self
326 }
327
328 pub fn frequency_penalty(mut self, frequency_penalty: f32) -> Self {
329 self.inner.frequency_penalty = Some(frequency_penalty);
330 self
331 }
332
333 pub fn logit_bias(mut self, logit_bias: HashMap<String, serde_json::Value>) -> Self {
334 self.inner.logit_bias = Some(logit_bias);
335 self
336 }
337
338 pub fn logprobs(mut self, logprobs: bool) -> Self {
339 self.inner.logprobs = Some(logprobs);
340 self
341 }
342
343 pub fn top_logprobs(mut self, top_logprobs: u8) -> Self {
344 self.inner.top_logprobs = Some(top_logprobs);
345 self
346 }
347
348 pub fn max_completion_tokens(mut self, max_completion_tokens: u32) -> Self {
349 self.inner.max_completion_tokens = Some(max_completion_tokens);
350 self
351 }
352
353 pub fn reasoning_effort(mut self, reasoning_effort: ReasoningEffort) -> Self {
354 self.inner.reasoning_effort = Some(reasoning_effort);
355 self
356 }
357
358 pub fn n(mut self, n: u8) -> Self {
359 self.inner.n = Some(n);
360 self
361 }
362
363 pub fn presence_penalty(mut self, presence_penalty: f32) -> Self {
364 self.inner.presence_penalty = Some(presence_penalty);
365 self
366 }
367
368 pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
369 self.inner.response_format = Some(response_format);
370 self
371 }
372
373 pub fn seed(mut self, seed: i64) -> Self {
374 self.inner.seed = Some(seed);
375 self
376 }
377
378 pub fn stop(mut self, stop: Stop) -> Self {
379 self.inner.stop = Some(stop);
380 self
381 }
382
383 pub fn stream(mut self, stream: bool) -> Self {
384 self.inner.stream = Some(stream);
385 self
386 }
387
388 pub fn temperature(mut self, temperature: f32) -> Self {
389 self.inner.temperature = Some(temperature);
390 self
391 }
392
393 pub fn top_p(mut self, top_p: f32) -> Self {
394 self.inner.top_p = Some(top_p);
395 self
396 }
397
398 pub fn tools(mut self, tools: Vec<Tool>) -> Self {
399 self.inner.tools = Some(tools);
400 self
401 }
402
403 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
404 self.inner.tool_choice = Some(tool_choice);
405 self
406 }
407
408 pub fn user(mut self, user: impl Into<String>) -> Self {
409 self.inner.user = Some(user.into());
410 self
411 }
412
413 pub fn store(mut self, store: bool) -> Self {
414 self.inner.store = Some(store);
415 self
416 }
417
418 pub fn metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
419 self.inner.metadata = Some(metadata);
420 self
421 }
422
423 pub fn parallel_tool_calls(mut self, parallel_tool_calls: bool) -> Self {
424 self.inner.parallel_tool_calls = Some(parallel_tool_calls);
425 self
426 }
427
428 pub fn modalities(mut self, modalities: Vec<String>) -> Self {
429 self.inner.modalities = Some(modalities);
430 self
431 }
432
433 pub fn prediction(mut self, prediction: PredictionConfig) -> Self {
434 self.inner.prediction = Some(prediction);
435 self
436 }
437
438 pub fn audio(mut self, audio: AudioConfig) -> Self {
439 self.inner.audio = Some(audio);
440 self
441 }
442
443 pub fn service_tier(mut self, service_tier: String) -> Self {
444 self.inner.service_tier = Some(service_tier);
445 self
446 }
447
448 pub fn stream_options(mut self, stream_options: StreamOptions) -> Self {
449 self.inner.stream_options = Some(stream_options);
450 self
451 }
452
453 pub fn web_search_options(mut self, web_search_options: WebSearchOptions) -> Self {
454 self.inner.web_search_options = Some(web_search_options);
455 self
456 }
457
458 pub fn safety_identifier(mut self, safety_identifier: impl Into<String>) -> Self {
459 self.inner.safety_identifier = Some(safety_identifier.into());
460 self
461 }
462
463 pub fn prompt_cache_key(mut self, prompt_cache_key: impl Into<String>) -> Self {
464 self.inner.prompt_cache_key = Some(prompt_cache_key.into());
465 self
466 }
467
468 pub fn truncation(mut self, truncation: Truncation) -> Self {
469 self.inner.truncation = Some(truncation);
470 self
471 }
472
473 pub fn verbosity(mut self, verbosity: Verbosity) -> Self {
474 self.inner.verbosity = Some(verbosity);
475 self
476 }
477
478 pub fn text(mut self, text: serde_json::Value) -> Self {
479 self.inner.text = Some(text);
480 self
481 }
482
483 pub fn build(self) -> RequestBody {
484 self.inner
485 }
486}
487
488impl Default for RequestBodyBuilder {
489 fn default() -> Self {
490 Self::new()
491 }
492}
493
494#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
495#[serde(untagged)]
496pub enum Stop {
497 String(String),
498 Array(Vec<String>), }
500
501#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)]
502pub struct Tool {
503 pub r#type: ToolType,
504 pub function: FunctionTool,
505}
506
507#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
508#[serde(rename_all = "lowercase")]
509pub struct FunctionTool {
510 pub name: Cow<'static, str>,
512 #[serde(skip_serializing_if = "Option::is_none")]
514 pub description: Option<Cow<'static, str>>,
515 #[serde(skip_serializing_if = "Option::is_none")]
519 pub parameters: Option<serde_json::Value>,
520}
521
522#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
523#[serde(rename_all = "lowercase")]
524pub enum ToolType {
525 #[default]
526 Function,
527}
528
529#[derive(Debug, Serialize, Deserialize, Default, Clone, Copy, PartialEq)]
530#[serde(rename_all = "snake_case")]
531pub enum FinishReason {
532 #[default]
533 Stop,
534 Length,
535 ToolCalls,
536 ContentFilter,
537}
538
539#[derive(Debug, Default, Serialize, Deserialize, Clone, PartialEq)]
540pub struct TopLogprobs {
541 pub token: String,
543 pub logprob: f32,
545 pub bytes: Option<Vec<u8>>,
547}
548
549#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
550#[serde(tag = "role")]
551pub enum Message {
552 #[serde(rename = "system")]
553 System(SystemMessage),
554 #[serde(rename = "user")]
555 User(UserMessage),
556 #[serde(rename = "assistant")]
557 Assistant(AssistantMessage),
558 #[serde(rename = "tool")]
559 Tool(ToolMessage),
560}
561
562#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
563pub struct SystemMessage {
564 pub content: String,
566 #[serde(skip_serializing_if = "Option::is_none")]
568 pub name: Option<String>,
569}
570
571#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
572pub struct UserMessage {
573 pub content: Content,
575 #[serde(skip_serializing_if = "Option::is_none")]
577 pub name: Option<String>,
578}
579
580#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
581#[serde(untagged)]
582pub enum Content {
583 Text(String),
585 Array(Vec<ContentPart>),
589}
590
591#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
592#[serde(tag = "type")]
593pub enum ContentPart {
594 #[serde(rename = "text")]
595 Text(TextContentPart),
596 #[serde(rename = "image_url")]
597 Image(ImageContentPart),
598
599 #[serde(untagged)]
602 Other(serde_json::Value),
603}
604
605#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
606pub struct TextContentPart {
607 pub text: String,
608}
609
610#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
611#[serde(rename_all = "lowercase")]
612pub enum ImageUrlDetail {
613 #[default]
614 Auto,
615 Low,
616 High,
617}
618
619#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
620pub struct ImageUrl {
621 pub url: String,
623 pub detail: Option<ImageUrlDetail>,
625}
626
627#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
628pub struct ImageContentPart {
629 pub image_url: ImageUrl,
630 #[serde(skip_serializing)]
632 pub dimensions: Option<(u32, u32)>,
633}
634
635#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
636pub struct AssistantMessage {
637 pub content: Option<String>,
639 #[serde(skip_serializing_if = "Option::is_none")]
641 pub name: Option<String>,
642 #[serde(skip_serializing_if = "Option::is_none")]
643 pub tool_calls: Option<Vec<ToolCall>>,
644}
645
646#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
647#[serde(tag = "type")]
648pub enum ToolCall {
649 #[serde(rename = "function")]
650 Function(ToolCallFunction),
651}
652
653impl ToolCall {
654 pub fn id(&self) -> &str {
655 match self {
656 ToolCall::Function(f) => &f.id,
657 }
658 }
659 pub fn name(&self) -> &str {
660 match self {
661 ToolCall::Function(f) => &f.function.name,
662 }
663 }
664}
665
666#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)]
667pub struct ToolCallFunction {
668 pub id: String,
670 pub function: ToolCallFunctionObj,
672}
673
674#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
675pub struct ToolCallFunctionObj {
676 pub name: String,
678 pub arguments: String,
680}
681
682#[derive(Debug, Serialize, Deserialize, Default, Clone, PartialEq)]
683pub struct ToolMessage {
684 pub content: String,
686 pub tool_call_id: String,
688}
689
690#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
697#[serde(rename_all = "lowercase")]
698pub enum ToolChoice {
699 #[default]
700 None,
701 Auto,
702 #[serde(untagged)]
703 Function(ToolChoiceFunction),
704}
705
706#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
708pub struct ToolChoiceFunction {
709 pub r#type: ToolType,
711 pub function: FunctionName,
712}
713
714#[derive(Clone, Serialize, Default, Debug, Deserialize, PartialEq)]
715pub struct FunctionName {
716 pub name: String,
718}
719
720#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
721#[serde(rename_all = "snake_case")]
722#[serde(tag = "type")]
723pub enum ResponseFormat {
724 #[default]
725 Text,
726 JsonObject,
727 JsonSchema {
728 description: Option<String>,
729 properties: Option<serde_json::Value>,
730 name: String,
731 strict: Option<bool>,
732 },
733}
734
735#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
736pub struct StreamOptions {
737 #[serde(skip_serializing_if = "Option::is_none")]
742 pub include_usage: Option<bool>,
743}
744
745#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
746pub struct PredictionConfig {
747 pub text: String,
749 #[serde(skip_serializing_if = "Option::is_none")]
751 pub logprobs: Option<Vec<f32>>,
752}
753
754#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
755pub struct AudioConfig {
756 pub voice: String,
759 #[serde(skip_serializing_if = "Option::is_none")]
762 pub format: Option<String>,
763 #[serde(skip_serializing_if = "Option::is_none")]
766 pub speed: Option<f32>,
767}
768
769#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
770pub struct WebSearchOptions {
771 #[serde(skip_serializing_if = "Option::is_none")]
774 pub search_context_size: Option<String>,
775
776 #[serde(skip_serializing_if = "Option::is_none")]
778 pub user_location: Option<WebSearchUserLocation>,
779}
780
781#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
782pub struct WebSearchUserLocation {
783 pub r#type: String,
785
786 pub approximate: WebSearchUserLocationApproximate,
788}
789
790#[derive(Debug, Deserialize, Default, Serialize, Clone, PartialEq)]
791pub struct WebSearchUserLocationApproximate {
792 #[serde(skip_serializing_if = "Option::is_none")]
794 pub city: Option<String>,
795
796 #[serde(skip_serializing_if = "Option::is_none")]
798 pub country: Option<String>,
799
800 #[serde(skip_serializing_if = "Option::is_none")]
802 pub region: Option<String>,
803
804 #[serde(skip_serializing_if = "Option::is_none")]
806 pub timezone: Option<String>,
807}
808
809#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
810#[serde(rename_all = "lowercase")]
811pub enum ReasoningEffort {
812 High,
813 Medium,
814 Low,
815 Minimal,
816}
817
818#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
819#[serde(rename_all = "lowercase")]
820pub enum Truncation {
821 Auto,
822 Disabled,
823}
824
825#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
826#[serde(rename_all = "lowercase")]
827pub enum Verbosity {
828 Low,
829 Medium,
830 High,
831}
832
833#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
834pub struct OpenRouterReasoning {
835 effort: ReasoningEffort,
836 exclude: bool,
837}
838
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes known-good JSON payloads, then checks that serializing and re-parsing the
    /// expected value round-trips.
    #[test]
    fn serde() {
        let tests = vec![
            (
                "default",
                r#"{"model":"gpt-3.5-turbo","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Hello!"}]}"#,
                RequestBody {
                    model: "gpt-3.5-turbo".to_string(),
                    messages: vec![
                        Message::System(SystemMessage {
                            content: "You are a helpful assistant.".to_string(),
                            ..Default::default()
                        }),
                        Message::User(UserMessage {
                            content: Content::Text("Hello!".to_string()),
                            name: None,
                        }),
                    ],
                    ..Default::default()
                },
            ),
            (
                "image input",
                r#"{"model": "gpt-4-vision-preview","messages": [{"role": "user","content": [{"type": "text","text": "What's in this image?"},{"type": "image_url","image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"}}]}],"max_completion_tokens": 300}"#,
                RequestBody {
                    model: "gpt-4-vision-preview".to_string(),
                    messages: vec![Message::User(UserMessage {
                        content: Content::Array(vec![
                            ContentPart::Text(TextContentPart {
                                text: "What's in this image?".to_string(),
                            }),
                            ContentPart::Image(ImageContentPart {
                                dimensions: None,
                                image_url: ImageUrl {
                                    url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg".to_string(),
                                    detail: None,
                                },
                            }),
                        ]),
                        name: None,
                    })],
                    max_completion_tokens: Some(300),
                    ..Default::default()
                },
            ),
            (
                "streaming",
                r#"{"model": "gpt-3.5-turbo","messages": [{"role": "system","content": "You are a helpful assistant."},{"role": "user","content": "Hello!"}],"stream": true}"#,
                RequestBody {
                    model: "gpt-3.5-turbo".to_string(),
                    messages: vec![
                        Message::System(SystemMessage {
                            content: "You are a helpful assistant.".to_string(),
                            ..Default::default()
                        }),
                        Message::User(UserMessage {
                            content: Content::Text("Hello!".to_string()),
                            name: None,
                        }),
                    ],
                    stream: Some(true),
                    ..Default::default()
                },
            ),
            (
                "functions",
                r#"{"model": "gpt-3.5-turbo","messages": [{"role": "user","content": "What is the weather like in Boston?"}],"tools": [{"type": "function","function": {"name": "get_current_weather","description": "Get the current weather in a given location","parameters": {"type": "object","properties": {"location": {"type": "string","description": "The city and state, e.g. San Francisco, CA"},"unit": {"type": "string","enum": ["celsius", "fahrenheit"]}},"required": ["location"]}}}],"tool_choice": "auto"}"#,
                RequestBody {
                    model: "gpt-3.5-turbo".to_string(),
                    messages: vec![Message::User(UserMessage {
                        content: Content::Text("What is the weather like in Boston?".to_string()),
                        name: None,
                    })],
                    tools: Some(vec![Tool {
                        r#type: ToolType::Function,
                        function: FunctionTool {
                            name: "get_current_weather".to_string().into(),
                            description: Some(
                                "Get the current weather in a given location".to_string().into(),
                            ),
                            parameters: Some(serde_json::json!({
                                "type": "object",
                                "properties": {
                                    "location": {
                                        "type": "string",
                                        "description": "The city and state, e.g. San Francisco, CA"
                                    },
                                    "unit": {
                                        "type": "string",
                                        "enum": ["celsius", "fahrenheit"]
                                    }
                                },
                                "required": ["location"]
                            })),
                        },
                    }]),
                    tool_choice: Some(ToolChoice::Auto),
                    ..Default::default()
                },
            ),
            (
                "logprobs",
                r#"{"model": "gpt-3.5-turbo","messages": [{"role": "user","content": "Hello!"}],"logprobs": true,"top_logprobs": 2}"#,
                RequestBody {
                    model: "gpt-3.5-turbo".to_string(),
                    messages: vec![Message::User(UserMessage {
                        content: Content::Text("Hello!".to_string()),
                        name: None,
                    })],
                    logprobs: Some(true),
                    top_logprobs: Some(2),
                    ..Default::default()
                },
            ),
        ];

        for (name, json, expected) in tests {
            let actual: RequestBody = serde_json::from_str(json).unwrap();
            assert_eq!(actual, expected, "deserialize test failed: {}", name);

            let serialized = serde_json::to_string(&expected).unwrap();
            let actual: RequestBody = serde_json::from_str(&serialized).unwrap();
            assert_eq!(actual, expected, "serialize test failed: {}", name);
        }
    }

    #[test]
    fn test_first_user_message_text() {
        let request_body = RequestBody {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![
                Message::System(SystemMessage {
                    content: "You are a helpful assistant.".to_string(),
                    ..Default::default()
                }),
                Message::User(UserMessage {
                    content: Content::Text("Hello, how are you?".to_string()),
                    name: None,
                }),
            ],
            ..Default::default()
        };
        assert_eq!(
            request_body.first_user_message_text(),
            Some("Hello, how are you?".to_string())
        );

        // Text parts of the first user message are concatenated with no separator.
        let request_body_with_array = RequestBody {
            model: "gpt-4-vision-preview".to_string(),
            messages: vec![Message::User(UserMessage {
                content: Content::Array(vec![
                    ContentPart::Text(TextContentPart {
                        text: "What's in this image?".to_string(),
                    }),
                    ContentPart::Text(TextContentPart {
                        text: " Please describe it.".to_string(),
                    }),
                    ContentPart::Image(ImageContentPart {
                        dimensions: None,
                        image_url: ImageUrl {
                            url: "https://example.com/image.jpg".to_string(),
                            detail: None,
                        },
                    }),
                ]),
                name: None,
            })],
            ..Default::default()
        };
        assert_eq!(
            request_body_with_array.first_user_message_text(),
            Some("What's in this image? Please describe it.".to_string())
        );

        let request_body_no_user = RequestBody {
            model: "gpt-3.5-turbo".to_string(),
            messages: vec![Message::System(SystemMessage {
                content: "You are a helpful assistant.".to_string(),
                ..Default::default()
            })],
            ..Default::default()
        };
        assert_eq!(request_body_no_user.first_user_message_text(), None);
    }

    #[test]
    fn test_last_message_text_helpers() {
        let request_body = RequestBody {
            model: "gpt-4-vision-preview".to_string(),
            messages: vec![
                Message::System(SystemMessage {
                    content: "First system prompt.".to_string(),
                    ..Default::default()
                }),
                Message::User(UserMessage {
                    content: Content::Text("First question".to_string()),
                    name: None,
                }),
                Message::Assistant(AssistantMessage {
                    content: Some("First answer".to_string()),
                    ..Default::default()
                }),
                Message::System(SystemMessage {
                    content: "Second system prompt.".to_string(),
                    ..Default::default()
                }),
                Message::User(UserMessage {
                    content: Content::Array(vec![
                        ContentPart::Text(TextContentPart {
                            text: "What's in this image?".to_string(),
                        }),
                        ContentPart::Image(ImageContentPart {
                            dimensions: None,
                            image_url: ImageUrl {
                                url: "https://example.com/image.jpg".to_string(),
                                detail: None,
                            },
                        }),
                        ContentPart::Text(TextContentPart {
                            text: "Describe it.".to_string(),
                        }),
                    ]),
                    name: None,
                }),
            ],
            ..Default::default()
        };

        // Text parts of the last user message are joined with a single space.
        assert_eq!(
            request_body.last_user_message_text(),
            Some("What's in this image? Describe it.".to_string())
        );
        assert_eq!(
            request_body.first_user_message_text(),
            Some("First question".to_string())
        );
        assert_eq!(
            request_body.first_system_message_text(),
            Some("First system prompt.".to_string())
        );
        assert_eq!(
            request_body.last_system_message_text(),
            Some("Second system prompt.".to_string())
        );

        let empty = RequestBody::default();
        assert!(empty.first_user_message().is_none());
        assert!(empty.last_user_message().is_none());
        assert_eq!(empty.last_user_message_text(), None);
        assert_eq!(empty.first_system_message_text(), None);
        assert_eq!(empty.last_system_message_text(), None);
    }

    #[test]
    fn test_builder() {
        let request_body = RequestBodyBuilder::new()
            .model("gpt-4o")
            .push_user_message("Hello!")
            .push_system_message("Appended system prompt.")
            .prepend_system_message("Leading system prompt.")
            .temperature(0.5)
            .stream(true)
            .tool_choice(ToolChoice::Auto)
            .build();

        assert_eq!(request_body.model, "gpt-4o");
        assert_eq!(request_body.messages.len(), 3);
        assert!(matches!(
            request_body.messages.first(),
            Some(Message::System(_))
        ));
        assert_eq!(
            request_body.first_system_message_text(),
            Some("Leading system prompt.".to_string())
        );
        assert_eq!(
            request_body.last_system_message_text(),
            Some("Appended system prompt.".to_string())
        );
        assert_eq!(
            request_body.first_user_message_text(),
            Some("Hello!".to_string())
        );
        assert_eq!(request_body.temperature, Some(0.5));
        assert_eq!(request_body.stream, Some(true));
        assert_eq!(request_body.tool_choice, Some(ToolChoice::Auto));

        assert_eq!(RequestBodyBuilder::default().build(), RequestBody::default());
    }

    #[test]
    fn test_tool_call_accessors() {
        let json = r#"{"type":"function","id":"call_1","function":{"name":"get_current_weather","arguments":"{\"location\":\"Boston\"}"}}"#;
        let tool_call: ToolCall = serde_json::from_str(json).unwrap();
        assert_eq!(tool_call.id(), "call_1");
        assert_eq!(tool_call.name(), "get_current_weather");
    }

    #[test]
    fn test_other_content_part() {
        let audio_json =
            r#"{"type":"audio","url":"https://example.com/audio.mp3","name":"recording.mp3"}"#;
        let audio_content: ContentPart = serde_json::from_str(audio_json).unwrap();

        match &audio_content {
            ContentPart::Other(value) => {
                assert_eq!(value["type"], "audio");
                assert_eq!(value["url"], "https://example.com/audio.mp3");
                assert_eq!(value["name"], "recording.mp3");
            }
            _ => panic!("Expected Other variant for audio content"),
        }

        let serialized = serde_json::to_string(&audio_content).unwrap();
        let deserialized: ContentPart = serde_json::from_str(&serialized).unwrap();
        assert_eq!(audio_content, deserialized);

        let doc_json = r#"{"type":"document","url":"https://example.com/doc.pdf","mime_type":"application/pdf"}"#;
        let doc_content: ContentPart = serde_json::from_str(doc_json).unwrap();

        match &doc_content {
            ContentPart::Other(value) => {
                assert_eq!(value["type"], "document");
                assert_eq!(value["url"], "https://example.com/doc.pdf");
                assert_eq!(value["mime_type"], "application/pdf");
            }
            _ => panic!("Expected Other variant for document content"),
        }

        let video_json = r#"{"type":"video","url":"https://example.com/video.mp4","thumbnail":"https://example.com/thumb.jpg"}"#;
        let video_content: ContentPart = serde_json::from_str(video_json).unwrap();

        match &video_content {
            ContentPart::Other(value) => {
                assert_eq!(value["type"], "video");
                assert_eq!(value["url"], "https://example.com/video.mp4");
                assert_eq!(value["thumbnail"], "https://example.com/thumb.jpg");
            }
            _ => panic!("Expected Other variant for video content"),
        }

        let request_body = RequestBody {
            model: "gpt-4".to_string(),
            messages: vec![Message::User(UserMessage {
                content: Content::Array(vec![
                    ContentPart::Text(TextContentPart {
                        text: "Please analyze this audio file:".to_string(),
                    }),
                    audio_content.clone(),
                ]),
                name: None,
            })],
            ..Default::default()
        };

        let serialized = serde_json::to_string(&request_body).unwrap();
        let deserialized: RequestBody = serde_json::from_str(&serialized).unwrap();
        assert_eq!(request_body, deserialized);

        // Unknown parts are skipped when flattening user text.
        assert_eq!(
            request_body.first_user_message_text(),
            Some("Please analyze this audio file:".to_string())
        );

        let text_json = r#"{"type":"text","text":"Hello world"}"#;
        let text_content: ContentPart = serde_json::from_str(text_json).unwrap();
        match text_content {
            ContentPart::Text(text_part) => {
                assert_eq!(text_part.text, "Hello world");
            }
            _ => panic!("Expected Text variant"),
        }

        let image_json =
            r#"{"type":"image_url","image_url":{"url":"https://example.com/image.jpg"}}"#;
        let image_content: ContentPart = serde_json::from_str(image_json).unwrap();
        match image_content {
            ContentPart::Image(image_part) => {
                assert_eq!(image_part.image_url.url, "https://example.com/image.jpg");
            }
            _ => panic!("Expected Image variant"),
        }
    }
}