1use crate::error::{Error, Result};
2use crate::message::{Content, ContentPart, Message};
3use crate::provider::HTTPProvider;
4use crate::{Chat, LlmToolInfo, OpenAi};
5use reqwest::{Method, Request, Url};
6use serde::{Deserialize, Serialize};
7use std::env;
8use tracing::{debug, error, info, instrument, trace, warn};
9
/// Connection settings for talking to the OpenAI API.
///
/// Construct via [`Default`] to read values from the environment, or build
/// one explicitly and pass it to `OpenAIProvider::with_config`.
#[derive(Debug, Clone)]
pub struct OpenAIConfig {
    /// API key sent as the `Authorization: Bearer …` header.
    pub api_key: String,
    /// Base URL the chat-completions path is appended to
    /// (default: `https://api.openai.com/v1`).
    pub base_url: String,
    /// Optional value for the `OpenAI-Organization` header; omitted when `None`.
    pub organization: Option<String>,
}
20
21impl Default for OpenAIConfig {
22 fn default() -> Self {
23 Self {
24 api_key: env::var("OPENAI_API_KEY").unwrap_or_default(),
25 base_url: "https://api.openai.com/v1".to_string(),
26 organization: env::var("OPENAI_ORGANIZATION").ok(),
27 }
28 }
29}
30
/// HTTP provider implementation for the OpenAI chat-completions API.
///
/// Holds only the connection configuration; request construction and
/// response parsing live in the `HTTPProvider<OpenAi>` impl below.
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
    // Connection settings (key, base URL, organization) used for every request.
    config: OpenAIConfig,
}
37
38impl OpenAIProvider {
39 #[instrument(level = "debug")]
51 pub fn new() -> Self {
52 info!("Creating new OpenAIProvider with default configuration");
53 let config = OpenAIConfig::default();
54 debug!("API key set: {}", !config.api_key.is_empty());
55 debug!("Base URL: {}", config.base_url);
56 debug!("Organization set: {}", config.organization.is_some());
57
58 Self { config }
59 }
60
61 #[instrument(skip(config), level = "debug")]
77 pub fn with_config(config: OpenAIConfig) -> Self {
78 info!("Creating new OpenAIProvider with custom configuration");
79 debug!("API key set: {}", !config.api_key.is_empty());
80 debug!("Base URL: {}", config.base_url);
81 debug!("Organization set: {}", config.organization.is_some());
82
83 Self { config }
84 }
85}
86
87impl Default for OpenAIProvider {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl HTTPProvider<OpenAi> for OpenAIProvider {
94 fn accept(&self, model: OpenAi, chat: &Chat) -> Result<Request> {
95 info!("Creating request for OpenAI model: {:?}", model);
96 debug!("Messages in chat history: {}", chat.history.len());
97
98 let url_str = format!("{}/chat/completions", self.config.base_url);
99 debug!("Parsing URL: {}", url_str);
100 let url = match Url::parse(&url_str) {
101 Ok(url) => {
102 debug!("URL parsed successfully: {}", url);
103 url
104 }
105 Err(e) => {
106 error!("Failed to parse URL '{}': {}", url_str, e);
107 return Err(e.into());
108 }
109 };
110
111 let mut request = Request::new(Method::POST, url);
112 debug!("Created request: {} {}", request.method(), request.url());
113
114 debug!("Setting request headers");
116
117 let auth_header = match format!("Bearer {}", self.config.api_key).parse() {
119 Ok(header) => header,
120 Err(e) => {
121 error!("Invalid API key format: {}", e);
122 return Err(Error::Authentication("Invalid API key format".into()));
123 }
124 };
125
126 let content_type_header = match "application/json".parse() {
127 Ok(header) => header,
128 Err(e) => {
129 error!("Failed to set content type: {}", e);
130 return Err(Error::Other("Failed to set content type".into()));
131 }
132 };
133
134 request.headers_mut().insert("Authorization", auth_header);
135 request
136 .headers_mut()
137 .insert("Content-Type", content_type_header);
138
139 if let Some(org) = &self.config.organization {
141 match org.parse() {
142 Ok(header) => {
143 request.headers_mut().insert("OpenAI-Organization", header);
144 debug!("Added organization header");
145 }
146 Err(e) => {
147 warn!("Failed to set organization header: {}", e);
148 }
150 }
151 }
152
153 trace!("Request headers set: {:#?}", request.headers());
154
155 debug!("Creating request payload");
157 let payload = match self.create_request_payload(model, chat) {
158 Ok(payload) => {
159 debug!("Request payload created successfully");
160 trace!("Model: {}", payload.model);
161 trace!("Max tokens: {:?}", payload.max_tokens);
162 trace!("Number of messages: {}", payload.messages.len());
163 payload
164 }
165 Err(e) => {
166 error!("Failed to create request payload: {}", e);
167 return Err(e);
168 }
169 };
170
171 debug!("Serializing request payload");
173 let body_bytes = match serde_json::to_vec(&payload) {
174 Ok(bytes) => {
175 debug!("Payload serialized successfully ({} bytes)", bytes.len());
176 bytes
177 }
178 Err(e) => {
179 error!("Failed to serialize payload: {}", e);
180 return Err(Error::Serialization(e));
181 }
182 };
183
184 *request.body_mut() = Some(body_bytes.into());
185 info!("Request created successfully");
186
187 Ok(request)
188 }
189
190 fn parse(&self, raw_response_text: String) -> Result<Message> {
191 info!("Parsing response from OpenAI API");
192 trace!("Raw response: {}", raw_response_text);
193
194 if let Ok(error_response) = serde_json::from_str::<OpenAIErrorResponse>(&raw_response_text)
196 {
197 if let Some(error) = error_response.error {
198 error!("OpenAI API returned an error: {}", error.message);
199 return Err(Error::ProviderUnavailable(error.message));
200 }
201 }
202
203 debug!("Deserializing response JSON");
205 let openai_response = match serde_json::from_str::<OpenAIResponse>(&raw_response_text) {
206 Ok(response) => {
207 debug!("Response deserialized successfully");
208 debug!("Response model: {}", response.model);
209 if !response.choices.is_empty() {
210 debug!("Number of choices: {}", response.choices.len());
211 debug!(
212 "First choice finish reason: {:?}",
213 response.choices[0].finish_reason
214 );
215 }
216 if let Some(usage) = &response.usage {
217 debug!(
218 "Token usage - prompt: {}, completion: {}, total: {}",
219 usage.prompt_tokens, usage.completion_tokens, usage.total_tokens
220 );
221 }
222 response
223 }
224 Err(e) => {
225 error!("Failed to deserialize response: {}", e);
226 error!("Raw response: {}", raw_response_text);
227 return Err(Error::Serialization(e));
228 }
229 };
230
231 debug!("Converting OpenAI response to Message");
233 let message = Message::from(&openai_response);
234
235 info!("Response parsed successfully");
236 trace!("Response message processed");
237
238 Ok(message)
239 }
240}
241
/// Maps a model value to the identifier string sent in the request payload's
/// `model` field.
pub trait OpenAIModelInfo {
    /// Returns the wire-format model id used by `create_request_payload`.
    fn openai_model_id(&self) -> String;
}
246
247impl OpenAIProvider {
248 #[instrument(skip(self, chat), level = "debug")]
253 fn create_request_payload(&self, model: OpenAi, chat: &Chat) -> Result<OpenAIRequest> {
254 info!("Creating request payload for chat with OpenAI model");
255 debug!("System prompt length: {}", chat.system_prompt.len());
256 debug!("Messages in history: {}", chat.history.len());
257 debug!("Max output tokens: {}", chat.max_output_tokens);
258
259 let model_id = model.openai_model_id();
260 debug!("Using model ID: {}", model_id);
261
262 debug!("Converting messages to OpenAI format");
264 let mut messages: Vec<OpenAIMessage> = Vec::new();
265
266 if !chat.system_prompt.is_empty() {
268 debug!("Adding system prompt");
269 messages.push(OpenAIMessage {
270 role: "system".to_string(),
271 content: Some(chat.system_prompt.clone()),
272 function_call: None,
273 name: None,
274 tool_calls: None,
275 tool_call_id: None,
276 });
277 }
278
279 for msg in &chat.history {
281 debug!("Converting message with role: {}", msg.role_str());
282 messages.push(OpenAIMessage::from(msg));
283 }
284
285 debug!("Converted {} messages for the request", messages.len());
304
305 let tools = chat
307 .tools
308 .as_ref()
309 .map(|tools| tools.iter().map(OpenAITool::from).collect());
310
311 let tool_choice = if let Some(choice) = &chat.tool_choice {
313 match choice {
315 crate::tool::ToolChoice::Auto => Some(serde_json::json!("auto")),
316 crate::tool::ToolChoice::Any => Some(serde_json::json!("required")),
318 crate::tool::ToolChoice::None => Some(serde_json::json!("none")),
319 crate::tool::ToolChoice::Specific(name) => {
320 Some(serde_json::json!({
322 "type": "function",
323 "function": {
324 "name": name
325 }
326 }))
327 }
328 }
329 } else if tools.is_some() {
330 Some(serde_json::json!("auto"))
332 } else {
333 None
334 };
335
336 debug!("Creating OpenAIRequest");
338
339 let is_o_series = model_id.starts_with("o");
341
342 let request = OpenAIRequest {
343 model: model_id,
344 messages,
345 temperature: None,
346 top_p: None,
347 n: None,
348 max_tokens: if is_o_series {
350 None
351 } else {
352 Some(chat.max_output_tokens)
353 },
354 max_completion_tokens: if is_o_series {
355 Some(chat.max_output_tokens)
356 } else {
357 None
358 },
359 presence_penalty: None,
360 frequency_penalty: None,
361 stream: None,
362 tools,
363 tool_choice,
364 };
365
366 info!("Request payload created successfully");
367 Ok(request)
368 }
369}
370
/// A single message in OpenAI's chat-completions wire format.
///
/// Used for both the request payload and the response `choices[].message`.
/// All optional fields are omitted from JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIMessage {
    /// One of "system", "user", "assistant", or "tool".
    pub role: String,
    /// Text content; `None` e.g. for assistant messages that carry only tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Legacy (pre-tools) function-call field, still accepted in responses.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<OpenAIFunctionCall>,
    /// Optional participant name (used with the "user" role here).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool invocations requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,
    /// For "tool" role messages: id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
392
/// Function definition advertised to the model as part of a tool.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIFunction {
    /// Function name the model will reference in tool calls.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// JSON-schema object describing the function's parameters.
    pub parameters: serde_json::Value,
}
403
/// Tool entry in the request payload: a typed wrapper around a function.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAITool {
    /// Always "function" as produced by `From<&LlmToolInfo>`.
    pub r#type: String,
    /// The wrapped function definition.
    pub function: OpenAIFunction,
}
412
413impl From<&LlmToolInfo> for OpenAITool {
414 fn from(value: &LlmToolInfo) -> Self {
415 OpenAITool {
416 r#type: "function".to_string(),
417 function: OpenAIFunction {
418 name: value.name.clone(),
419 description: value.description.clone(),
420 parameters: value.parameters.clone(),
421 },
422 }
423 }
424}
425
/// Function invocation payload inside a tool call (or legacy `function_call`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIFunctionCall {
    /// Name of the function being invoked.
    pub name: String,
    /// Arguments as a raw JSON string (not parsed here).
    pub arguments: String,
}
434
/// A tool invocation requested by the assistant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OpenAIToolCall {
    /// Call id, echoed back by the matching "tool" role message.
    pub id: String,
    /// Tool kind ("function" in current usage).
    pub r#type: String,
    /// The function name and raw JSON arguments.
    pub function: OpenAIFunctionCall,
}
445
/// Request body for the chat-completions endpoint.
///
/// Optional fields are omitted from the serialized JSON when `None`, so this
/// struct can serve both standard and o-series models (which differ in
/// `max_tokens` vs `max_completion_tokens` — see `create_request_payload`).
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIRequest {
    /// Model id string.
    pub model: String,
    /// Conversation: optional system message followed by the chat history.
    pub messages: Vec<OpenAIMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Number of completions to request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<usize>,
    /// Output token cap for non-o-series models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Output token cap for o-series models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Streaming flag; this provider sends non-streaming requests (`None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Tool definitions available to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAITool>>,
    /// Tool-choice directive: a JSON string ("auto"/"required"/"none") or a
    /// `{"type":"function","function":{"name":…}}` object.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<serde_json::Value>,
}
484
/// Top-level chat-completions response body.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIResponse {
    /// Response id assigned by the API.
    pub id: String,
    /// Object type discriminator from the API.
    pub object: String,
    /// Creation timestamp (seconds since epoch, per the API).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Generated completions; only index 0 is used by `From<&OpenAIResponse>`.
    pub choices: Vec<OpenAIChoice>,
    /// Token accounting; copied into message metadata when present.
    pub usage: Option<OpenAIUsage>,
}
501
/// One generated completion within a response.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIChoice {
    /// Position of this choice in the response.
    pub index: usize,
    /// The generated message.
    pub message: OpenAIMessage,
    /// Why generation stopped (e.g. reported by the API); may be absent.
    pub finish_reason: Option<String>,
}
512
/// Token accounting attached to a response.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Sum reported by the API.
    pub total_tokens: u32,
}
523
/// Error envelope returned by the API on failure.
///
/// NOTE: because `error` is optional, almost any JSON object deserializes
/// into this type — callers must check that `error` is `Some` (see `parse`).
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIErrorResponse {
    pub error: Option<OpenAIError>,
}
530
/// Error details inside [`OpenAIErrorResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct OpenAIError {
    /// Human-readable description surfaced to callers.
    pub message: String,
    /// Error category; the wire field is named "type".
    #[serde(rename = "type")]
    pub error_type: String,
    /// Machine-readable error code, when provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
543
544impl From<&Message> for OpenAIMessage {
546 fn from(msg: &Message) -> Self {
547 let role = match msg {
548 Message::System { .. } => "system",
549 Message::User { .. } => "user",
550 Message::Assistant { .. } => "assistant",
551 Message::Tool { .. } => "tool",
552 }
553 .to_string();
554
555 let (content, name, function_call, tool_calls, tool_call_id) = match msg {
556 Message::System { content, .. } => (Some(content.clone()), None, None, None, None),
557 Message::User { content, name, .. } => {
558 let content_str = match content {
559 Content::Text(text) => Some(text.clone()),
560 Content::Parts(parts) => {
561 let combined_text = parts
563 .iter()
564 .filter_map(|part| match part {
565 ContentPart::Text { text } => Some(text.clone()),
566 _ => None,
567 })
568 .collect::<Vec<String>>()
569 .join("\n");
570
571 if combined_text.is_empty() {
572 None
573 } else {
574 Some(combined_text)
575 }
576 }
577 };
578 (content_str, name.clone(), None, None, None)
579 }
580 Message::Assistant {
581 content,
582 tool_calls,
583 ..
584 } => {
585 let content_str = match content {
586 Some(Content::Text(text)) => Some(text.clone()),
587 Some(Content::Parts(parts)) => {
588 let combined_text = parts
590 .iter()
591 .filter_map(|part| match part {
592 ContentPart::Text { text } => Some(text.clone()),
593 _ => None,
594 })
595 .collect::<Vec<String>>()
596 .join("\n");
597
598 if combined_text.is_empty() {
599 None
600 } else {
601 Some(combined_text)
602 }
603 }
604 None => None,
605 };
606
607 let openai_tool_calls = if !tool_calls.is_empty() {
609 let mut calls = Vec::with_capacity(tool_calls.len());
610
611 for tc in tool_calls {
612 calls.push(OpenAIToolCall {
613 id: tc.id.clone(),
614 r#type: tc.tool_type.clone(),
615 function: OpenAIFunctionCall {
616 name: tc.function.name.clone(),
617 arguments: tc.function.arguments.clone(),
618 },
619 });
620 }
621
622 Some(calls)
623 } else {
624 None
625 };
626
627 (content_str, None, None, openai_tool_calls, None)
628 }
629 Message::Tool {
630 tool_call_id,
631 content,
632 ..
633 } => (
634 Some(content.clone()),
635 None,
636 None,
637 None,
638 Some(tool_call_id.clone()),
639 ),
640 };
641
642 OpenAIMessage {
643 role,
644 content,
645 function_call,
646 name,
647 tool_calls,
648 tool_call_id,
649 }
650 }
651}
652
/// Converts a deserialized chat-completions response into the crate's
/// [`Message`], using only the first choice. Token usage, when present,
/// is copied into the message metadata.
impl From<&OpenAIResponse> for Message {
    fn from(response: &OpenAIResponse) -> Self {
        // No choices: return a stub assistant message rather than failing.
        if response.choices.is_empty() {
            return Message::assistant("No response generated");
        }

        // Only choices[0] is converted; any additional choices are ignored.
        let choice = &response.choices[0];
        let message = &choice.message;

        let mut msg = match message.role.as_str() {
            "assistant" => {
                let content = message
                    .content
                    .as_ref()
                    .map(|text| Content::Text(text.clone()));

                if let Some(openai_tool_calls) = &message.tool_calls {
                    if !openai_tool_calls.is_empty() {
                        // Map each wire tool call onto the crate's ToolCall.
                        let mut tool_calls = Vec::with_capacity(openai_tool_calls.len());

                        for call in openai_tool_calls {
                            let tool_call = crate::message::ToolCall {
                                id: call.id.clone(),
                                tool_type: call.r#type.clone(),
                                function: crate::message::Function {
                                    name: call.function.name.clone(),
                                    arguments: call.function.arguments.clone(),
                                },
                            };
                            tool_calls.push(tool_call);
                        }

                        Message::Assistant {
                            content,
                            tool_calls,
                            metadata: Default::default(),
                        }
                    } else {
                        // `tool_calls` present but empty: treat as plain text.
                        // NOTE(review): in this arm the legacy `function_call`
                        // field is deliberately NOT consulted — it is only
                        // checked when `tool_calls` is absent entirely.
                        if let Some(Content::Text(text)) = content {
                            Message::assistant(text)
                        } else {
                            Message::Assistant {
                                content,
                                tool_calls: Vec::new(),
                                metadata: Default::default(),
                            }
                        }
                    }
                } else if let Some(fc) = &message.function_call {
                    // Legacy (pre-tools) function_call: synthesize a single
                    // tool call with a deterministic placeholder id.
                    let tool_call = crate::message::ToolCall {
                        id: format!("legacy_function_{}", fc.name),
                        tool_type: "function".to_string(),
                        function: crate::message::Function {
                            name: fc.name.clone(),
                            arguments: fc.arguments.clone(),
                        },
                    };

                    Message::Assistant {
                        content,
                        tool_calls: vec![tool_call],
                        metadata: Default::default(),
                    }
                } else {
                    // Neither tool_calls nor function_call: plain assistant text.
                    if let Some(Content::Text(text)) = content {
                        Message::assistant(text)
                    } else {
                        Message::Assistant {
                            content,
                            tool_calls: Vec::new(),
                            metadata: Default::default(),
                        }
                    }
                }
            }
            "user" => {
                // Preserve an optional participant name; missing content
                // degrades to an empty string.
                if let Some(name) = &message.name {
                    if let Some(content) = &message.content {
                        Message::user_with_name(name, content)
                    } else {
                        Message::user_with_name(name, "")
                    }
                } else if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
            "system" => {
                if let Some(content) = &message.content {
                    Message::system(content)
                } else {
                    Message::system("")
                }
            }
            "tool" => {
                if let Some(tool_call_id) = &message.tool_call_id {
                    if let Some(content) = &message.content {
                        Message::tool(tool_call_id, content)
                    } else {
                        Message::tool(tool_call_id, "")
                    }
                } else {
                    // A tool message without a call id is malformed; fall back
                    // to a user message carrying the same content.
                    if let Some(content) = &message.content {
                        Message::user(content)
                    } else {
                        Message::user("")
                    }
                }
            }
            _ => {
                // Unknown role: degrade to a user message.
                if let Some(content) = &message.content {
                    Message::user(content)
                } else {
                    Message::user("")
                }
            }
        };

        // Attach token accounting so callers can inspect usage per message.
        if let Some(usage) = &response.usage {
            msg = msg.with_metadata(
                "prompt_tokens",
                serde_json::Value::Number(usage.prompt_tokens.into()),
            );
            msg = msg.with_metadata(
                "completion_tokens",
                serde_json::Value::Number(usage.completion_tokens.into()),
            );
            msg = msg.with_metadata(
                "total_tokens",
                serde_json::Value::Number(usage.total_tokens.into()),
            );
        }

        msg
    }
}
800
#[cfg(test)]
mod tests {
    use super::*;

    /// Message → OpenAIMessage conversion for each simple role, plus an
    /// assistant message carrying a tool call.
    #[test]
    fn test_message_conversion() {
        let msg = Message::user("Hello, world!");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "user");
        assert_eq!(openai_msg.content, Some("Hello, world!".to_string()));

        let msg = Message::system("You are a helpful assistant.");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "system");
        assert_eq!(
            openai_msg.content,
            Some("You are a helpful assistant.".to_string())
        );

        let msg = Message::assistant("I can help with that.");
        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "assistant");
        assert_eq!(
            openai_msg.content,
            Some("I can help with that.".to_string())
        );

        // Assistant message with an attached tool call must round-trip the
        // call id, function name, and raw argument string.
        let tool_call = crate::message::ToolCall {
            id: "tool_123".to_string(),
            tool_type: "function".to_string(),
            function: crate::message::Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"San Francisco\"}".to_string(),
            },
        };

        let msg = Message::Assistant {
            content: Some(Content::Text("I'll check the weather".to_string())),
            tool_calls: vec![tool_call],
            metadata: Default::default(),
        };

        let openai_msg = OpenAIMessage::from(&msg);

        assert_eq!(openai_msg.role, "assistant");
        assert_eq!(
            openai_msg.content,
            Some("I'll check the weather".to_string())
        );
        assert!(openai_msg.tool_calls.is_some());
        let tool_calls = openai_msg.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "tool_123");
        assert_eq!(tool_calls[0].function.name, "get_weather");
    }

    /// The API error envelope deserializes with the renamed "type" field and
    /// optional code intact.
    #[test]
    fn test_error_response_parsing() {
        let error_json = r#"{
            "error": {
                "message": "The model does not exist",
                "type": "invalid_request_error",
                "code": "model_not_found"
            }
        }"#;

        let error_response: OpenAIErrorResponse = serde_json::from_str(error_json).unwrap();
        assert!(error_response.error.is_some());
        let error = error_response.error.unwrap();
        assert_eq!(error.error_type, "invalid_request_error");
        assert_eq!(error.code, Some("model_not_found".to_string()));
    }

    // Shared fixture: a single get_weather tool definition with a minimal
    // JSON-schema parameter object.
    fn get_weather_tool_info() -> crate::tool::LlmToolInfo {
        use serde_json::json;

        crate::tool::LlmToolInfo {
            name: "get_weather".to_string(),
            description: "Get current temperature for a given location.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "City and country e.g. Bogotá, Colombia"
                    }
                },
                "required": ["location"],
                "additionalProperties": false
            }),
        }
    }

    // Shared fixture: an empty chat pre-loaded with the weather tool.
    fn base_chat_with_tool() -> crate::Chat {
        crate::Chat::default().with_tools(vec![get_weather_tool_info()])
    }

    /// Stage 1 of a tool-use round trip: a lone user message serializes with
    /// the tool definitions attached and tool_choice defaulting to "auto".
    #[test]
    fn test_stage1_user_only_serialization() {
        use crate::model::OpenAi;
        use crate::message::Message;

        let chat = base_chat_with_tool()
            .add_message(Message::user("What is the weather like in Paris today?"));

        let provider = OpenAIProvider::new();
        let request = provider
            .create_request_payload(OpenAi::GPT35Turbo, &chat)
            .expect("payload generation failed");

        assert_eq!(request.messages.len(), 1);
        let msg = &request.messages[0];
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content.as_deref(), Some("What is the weather like in Paris today?"));
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());

        let tools = request.tools.expect("tools should be present");
        assert!(tools.iter().any(|t| t.function.name == "get_weather"));

        assert_eq!(request.tool_choice, Some(serde_json::json!("auto")));
    }

    /// Stage 2: the assistant's tool-call message serializes with no content
    /// and the call id/type/function preserved verbatim.
    #[test]
    fn test_stage2_assistant_tool_call_serialization() {
        use crate::model::OpenAi;
        use crate::message::{Function, Message, ToolCall};

        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";

        let assistant_msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: CALL_ID.to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Paris, France\"}".to_string(),
            },
        }]);

        let chat = base_chat_with_tool()
            .add_message(Message::user("What is the weather like in Paris today?"))
            .add_message(assistant_msg);

        let provider = OpenAIProvider::new();
        let request = provider
            .create_request_payload(OpenAi::GPT35Turbo, &chat)
            .expect("payload generation failed");

        assert_eq!(request.messages.len(), 2);

        assert_eq!(request.messages[0].role, "user");
        let assistant = &request.messages[1];
        assert_eq!(assistant.role, "assistant");
        assert!(assistant.content.is_none());

        let calls = assistant.tool_calls.as_ref().expect("tool_calls missing");
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, CALL_ID);
        assert_eq!(call.r#type, "function");
        assert_eq!(call.function.name, "get_weather");
        assert_eq!(call.function.arguments, "{\"location\":\"Paris, France\"}");
    }

    /// Stage 3: the tool's reply serializes as a "tool" role message bound to
    /// the originating call id.
    #[test]
    fn test_stage3_tool_response_serialization() {
        use crate::model::OpenAi;
        use crate::message::{Function, Message, ToolCall};

        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";

        let assistant_msg = Message::assistant_with_tool_calls(vec![ToolCall {
            id: CALL_ID.to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Paris, France\"}".to_string(),
            },
        }]);

        let tool_msg = Message::tool(CALL_ID, "10C");

        let chat = base_chat_with_tool()
            .add_message(Message::user("What is the weather like in Paris today?"))
            .add_message(assistant_msg)
            .add_message(tool_msg);

        let provider = OpenAIProvider::new();
        let request = provider
            .create_request_payload(OpenAi::GPT35Turbo, &chat)
            .expect("payload generation failed");

        assert_eq!(request.messages.len(), 3);
        assert_eq!(request.messages[0].role, "user");
        assert_eq!(request.messages[1].role, "assistant");
        assert_eq!(request.messages[2].role, "tool");

        let tool = &request.messages[2];
        assert_eq!(tool.tool_call_id.as_deref(), Some(CALL_ID));
        assert_eq!(tool.content.as_deref(), Some("10C"));
    }

    /// Stage 4: the complete four-message round trip keeps role order and the
    /// final assistant text, with no stray tool_calls on the closing message.
    #[test]
    fn test_stage4_full_conversation_serialization() {
        use crate::model::OpenAi;
        use crate::message::{Function, Message, ToolCall};

        const CALL_ID: &str = "call_19InQqbLUTQIuc6MlV5QSogY";
        let user_msg = Message::user("What is the weather like in Paris today?");

        let assistant_call = Message::assistant_with_tool_calls(vec![ToolCall {
            id: CALL_ID.to_string(),
            tool_type: "function".to_string(),
            function: Function {
                name: "get_weather".to_string(),
                arguments: "{\"location\":\"Paris, France\"}".to_string(),
            },
        }]);

        let tool_msg = Message::tool(CALL_ID, "10C");

        let final_assistant = Message::assistant("The weather in Paris today is 10°C. Let me know if you need more details or the forecast for the coming days!");

        let chat = base_chat_with_tool()
            .add_message(user_msg)
            .add_message(assistant_call)
            .add_message(tool_msg)
            .add_message(final_assistant);

        let provider = OpenAIProvider::new();
        let request = provider
            .create_request_payload(OpenAi::GPT35Turbo, &chat)
            .expect("payload generation failed");

        let roles: Vec<_> = request.messages.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, vec!["user", "assistant", "tool", "assistant"]);

        let assistant_after_tool = &request.messages[3];
        assert_eq!(assistant_after_tool.role, "assistant");
        assert_eq!(assistant_after_tool.content.as_deref(), Some("The weather in Paris today is 10°C. Let me know if you need more details or the forecast for the coming days!"));
        assert!(assistant_after_tool.tool_calls.is_none());
    }

}
1081 }