call_agent/chat/
client.rs

use std::{collections::{HashMap, VecDeque}, sync::Arc};

use reqwest::{Client, Response};

use crate::chat::api::WebSearchOptions;

use super::{
    api::{APIRequest, APIResponse, APIResponseHeaders},
    err::ClientError,
    function::{FunctionCall, FunctionDef, Tool, ToolDef},
    prompt::{Message, MessageContext},
};

/// Main client structure for interacting with the OpenAI API.
#[derive(Clone)]
pub struct OpenAIClient {
    /// HTTP client
    pub client: Client,
    /// API endpoint
    pub end_point: String,
    /// Optional API key
    pub api_key: Option<String>,
    /// Registered tools: key is the tool name, value is a tuple (tool, is_enabled)
    pub tools: HashMap<String, (Arc<dyn Tool + Send + Sync>, bool)>,
    /// Configuration for the model request.
    pub model_config: Option<ModelConfig>,
}

/// Configuration for the model request.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    /// Model name.
    pub model: String,
    /// Optional display name attached to assistant messages.
    pub model_name: Option<String>,
    /// Top-p sampling parameter.
    pub top_p: Option<f64>,
    /// Whether to allow parallel tool calls.
    /// default: true
    pub parallel_tool_calls: Option<bool>,
    /// Sampling temperature; controls the diversity of generated tokens.
    pub temperature: Option<f64>,
    /// Maximum number of tokens the model may generate.
    pub max_completion_tokens: Option<u64>,
    /// Level of effort for reasoning in reasoning-capable models:
    /// - "low": Low effort
    /// - "medium": Medium effort
    /// - "high": High effort
    /// default: "medium"
    pub reasoning_effort: Option<String>,
    /// Presence penalty applied to the model.
    /// Range: -2.0..2.0
    pub presence_penalty: Option<f64>,
    /// Enables strict structured output for tool definitions.
    /// default: false
    /// Forcibly disabled when tool calls run in parallel.
    pub strict: Option<bool>,
    /// Options for performing web search with available models.
    pub web_search_options: Option<WebSearchOptions>,
}

/// Contains the API response and its headers.
#[derive(Debug, Clone)]
pub struct APIResult {
    /// The parsed API response.
    pub response: APIResponse,
    /// Headers returned by the API.
    pub headers: APIResponseHeaders,
}

impl OpenAIClient {
    /// Create a new OpenAIClient.
    ///
    /// # Arguments
    ///
    /// * `end_point` - The endpoint of the OpenAI API.
    /// * `api_key` - Optional API key.
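    ///
    /// # Example
    ///
    /// A minimal sketch; the endpoint and key are placeholders.
    ///
    /// ```ignore
    /// let client = OpenAIClient::new("https://api.openai.com/v1", Some("YOUR_API_KEY"));
    /// ```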
    pub fn new(end_point: &str, api_key: Option<&str>) -> Self {
        Self {
            client: Client::new(),
            end_point: end_point.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            tools: HashMap::new(),
            model_config: None,
        }
    }

    /// Set the default model configuration.
    ///
    /// # Arguments
    ///
    /// * `model_config` - The model configuration.
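    ///
    /// # Example
    ///
    /// A minimal sketch; the model name and sampling values are illustrative
    /// placeholders, not crate defaults.
    ///
    /// ```ignore
    /// let config = ModelConfig {
    ///     model: "gpt-4o".to_string(),
    ///     model_name: None,
    ///     top_p: None,
    ///     parallel_tool_calls: Some(true),
    ///     temperature: Some(0.7),
    ///     max_completion_tokens: Some(1024),
    ///     reasoning_effort: None,
    ///     presence_penalty: None,
    ///     strict: Some(false),
    ///     web_search_options: None,
    /// };
    /// client.set_model_config(&config);
    /// ```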
    pub fn set_model_config(&mut self, model_config: &ModelConfig) {
        self.model_config = Some(model_config.clone());
    }

    /// Register a tool.
    ///
    /// If a tool with the same name already exists, it will be overwritten.
    ///
    /// # Arguments
    ///
    /// * `tool` - Reference-counted tool implementing the Tool trait.
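    ///
    /// # Example
    ///
    /// A minimal sketch; `MyTool` is a hypothetical type implementing the
    /// `Tool` trait.
    ///
    /// ```ignore
    /// client.def_tool(Arc::new(MyTool));
    /// ```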
    pub fn def_tool<T: Tool + Send + Sync + 'static>(&mut self, tool: Arc<T>) {
        self.tools
            .insert(tool.def_name().to_string(), (tool, true));
    }

    /// List all registered tools.
    ///
    /// # Returns
    ///
    /// A list of tuples containing (tool name, tool description, enabled flag).
    pub fn list_tools(&self) -> Vec<(String, String, bool)> {
        let mut tools = Vec::new();
        for (tool_name, (tool, enable)) in self.tools.iter() {
            tools.push((
                tool_name.to_string(),
                tool.def_description().to_string(),
                *enable,
            ));
        }
        tools
    }

    /// Switch the enable/disable state of a tool.
    ///
    /// # Arguments
    ///
    /// * `tool_name` - The name of the tool.
    /// * `t_enable` - True to enable, false to disable.
    pub fn switch_tool(&mut self, tool_name: &str, t_enable: bool) {
        if let Some((_, enable)) = self.tools.get_mut(tool_name) {
            *enable = t_enable;
        }
    }

    /// Export the definitions of all enabled tools.
    ///
    /// # Returns
    ///
    /// A vector of tool definitions, or `ClientError::ModelConfigNotSet` if no
    /// model configuration is set (the `strict` flag is read from it).
    pub fn export_tool_def(&self) -> Result<Vec<ToolDef>, ClientError> {
        let mut defs = Vec::new();
        for (tool_name, (tool, enable)) in self.tools.iter() {
            if *enable {
                defs.push(ToolDef {
                    tool_type: "function".to_string(),
                    function: FunctionDef {
                        name: tool_name.clone(),
                        description: tool.def_description().to_string(),
                        parameters: tool.def_parameters(),
                        strict: self
                            .model_config
                            .as_ref()
                            .ok_or(ClientError::ModelConfigNotSet)?
                            .strict
                            .unwrap_or(false),
                    },
                });
            }
        }
        Ok(defs)
    }

    /// Send a chat request to the API with tool use disabled (`tool_choice: "none"`).
    ///
    /// # Arguments
    ///
    /// * `prompt` - The conversation messages.
    /// * `model` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// The API result or a ClientError.
    pub async fn send(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("none")), model)
            .await
    }

    /// Send a chat request with tool auto-selection (`tool_choice: "auto"`).
    ///
    /// # Arguments
    ///
    /// * `prompt` - The conversation messages.
    /// * `model` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// The API result or a ClientError.
    pub async fn send_can_use_tool(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("auto")), model)
            .await
    }

    /// Send a chat request requiring the use of a tool (`tool_choice: "required"`).
    ///
    /// # Arguments
    ///
    /// * `prompt` - The conversation messages.
    /// * `model` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// The API result or a ClientError.
    pub async fn send_use_tool(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("required")), model)
            .await
    }

    /// Send a chat request forcing the use of a specific tool.
    ///
    /// # Arguments
    ///
    /// * `prompt` - The conversation messages.
    /// * `tool_name` - The name of the tool to force.
    /// * `model` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// The API result or a ClientError.
    pub async fn send_with_tool(
        &self,
        prompt: &VecDeque<Message>,
        tool_name: &str,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        let tool_choice =
            serde_json::json!({"type": "function", "function": {"name": tool_name}});
        self.call_api(prompt, Some(&tool_choice), model).await
    }

    /// Calls the OpenAI chat completions API.
    ///
    /// # Arguments
    ///
    /// * `prompt` - The conversation messages.
    /// * `tool_choice` - The tool-choice mode, e.g.:
    ///   - "auto"
    ///   - "none"
    ///   - "required"
    ///   - { "type": "function", "function": { "name": "get_weather" } }
    /// * `model_config` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// An APIResult on success or a ClientError on failure.
    pub async fn call_api(
        &self,
        prompt: &VecDeque<Message>,
        tool_choice: Option<&serde_json::Value>,
        model_config: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        let url = format!("{}/chat/completions", self.end_point);
        if !url.starts_with("https://") && !url.starts_with("http://") {
            return Err(ClientError::InvalidEndpoint);
        }

        let model_config = model_config
            .unwrap_or(self.model_config.as_ref().ok_or(ClientError::ModelConfigNotSet)?);
        let tools = self.export_tool_def()?;
        let res = self
            .request_api(
                &self.end_point,
                self.api_key.as_deref(),
                model_config,
                prompt,
                &tools,
                tool_choice.unwrap_or(&serde_json::Value::Null),
            )
            .await?;

        let headers = APIResponseHeaders {
            retry_after: res
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            reset: res
                .headers()
                .get("X-RateLimit-Reset")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            rate_limit: res
                .headers()
                .get("X-RateLimit-Remaining")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            limit: res
                .headers()
                .get("X-RateLimit-Limit")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            extra_other: res
                .headers()
                .iter()
                .map(|(k, v)| {
                    (
                        k.as_str().to_string(),
                        v.to_str().unwrap_or("").to_string(),
                    )
                })
                .collect(),
        };
        let text = res.text().await.map_err(|_| ClientError::InvalidResponse)?;
        log::debug!("Response: {}", text);
        let response_body: APIResponse =
            serde_json::from_str(&text).map_err(|_| ClientError::InvalidResponse)?;

        Ok(APIResult {
            response: response_body,
            headers,
        })
    }

    /// Build the request body and POST it to the `/chat/completions` endpoint.
    pub async fn request_api(
        &self,
        end_point: &str,
        api_key: Option<&str>,
        model_config: &ModelConfig,
        message: &VecDeque<Message>,
        tools: &[ToolDef],
        tool_choice: &serde_json::Value,
    ) -> Result<Response, ClientError> {
        let request = APIRequest {
            model:                  model_config.model.clone(),
            messages:               message.clone(),
            tools:                  tools.to_vec(),
            tool_choice:            tool_choice.clone(),
            parallel_tool_calls:    model_config.parallel_tool_calls,
            temperature:            model_config.temperature,
            max_completion_tokens:  model_config.max_completion_tokens,
            top_p:                  model_config.top_p,
            reasoning_effort:       model_config.reasoning_effort.clone(),
            presence_penalty:       model_config.presence_penalty,
            web_search_options:     model_config.web_search_options.clone(),
        };

        let res = self
            .client
            .post(&format!("{}/chat/completions", end_point))
            .header("Content-Type", "application/json")
            .header(
                "Authorization",
                format!("Bearer {}", api_key.unwrap_or("")),
            )
            .json(&request)
            .send()
            .await
            .map_err(|_| ClientError::NetworkError)?;

        Ok(res)
    }

    /// Create a new prompt conversation.
    ///
    /// # Returns
    ///
    /// A new OpenAIClientState with an empty message history.
    pub fn create_prompt(&self) -> OpenAIClientState {
        OpenAIClientState {
            prompt: VecDeque::new(),
            client: self.clone(),
            entry_limit: None,
        }
    }
}

/// Represents a client state with a prompt history.
#[derive(Clone)]
pub struct OpenAIClientState {
    /// Conversation history messages.
    pub prompt: VecDeque<Message>,
    /// Reference to the OpenAIClient.
    pub client: OpenAIClient,
    /// Optional cap on the number of messages kept in the prompt.
    pub entry_limit: Option<u64>,
}

/// Result of a generation call: the extracted content and tool calls plus the raw API result.
#[derive(Debug, Clone)]
pub struct GenerateResponse {
    /// True if the assistant returned text content.
    pub has_content: bool,
    /// True if the assistant requested tool calls.
    pub has_tool_calls: bool,
    /// The assistant's text content, if any.
    pub content: Option<String>,
    /// The requested tool calls, if any.
    pub tool_calls: Option<Vec<FunctionCall>>,
    /// The full API result.
    pub api_result: APIResult,
}

impl OpenAIClientState {
    /// Add messages to the end of the conversation prompt.
    ///
    /// # Arguments
    ///
    /// * `messages` - A vector of messages to add.
    ///
    /// # Returns
    ///
    /// A mutable reference to self.
    pub async fn add(&mut self, messages: Vec<Message>) -> &mut Self {
        if let Some(limit) = self.entry_limit {
            while self.prompt.len() as u64 + messages.len() as u64 > limit {
                // Stop once the history is empty so that a batch larger than
                // the limit cannot spin this loop forever.
                if self.prompt.pop_front().is_none() {
                    break;
                }
            }
        }
        self.prompt.extend(messages);
        self
    }

    /// Add messages to the front of the conversation prompt, preserving their order.
    ///
    /// # Arguments
    ///
    /// * `messages` - A vector of messages to prepend.
    ///
    /// # Returns
    ///
    /// A mutable reference to self.
    pub async fn add_last(&mut self, messages: Vec<Message>) -> &mut Self {
        if let Some(limit) = self.entry_limit {
            while self.prompt.len() as u64 + messages.len() as u64 > limit {
                if self.prompt.pop_front().is_none() {
                    break;
                }
            }
        }
        // Push in reverse so the batch keeps its relative order at the front.
        for msg in messages.into_iter().rev() {
            self.prompt.push_front(msg);
        }
        self
    }

    /// Set the maximum number of entries in the conversation prompt.
    ///
    /// # Arguments
    ///
    /// * `limit` - The maximum number of entries.
    ///
    /// # Returns
    ///
    /// A mutable reference to self.
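    ///
    /// # Example
    ///
    /// A minimal sketch: with a limit of 8, the oldest messages are dropped
    /// from the front as new ones arrive.
    ///
    /// ```ignore
    /// state.set_entry_limit(8).await;
    /// ```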
    pub async fn set_entry_limit(&mut self, limit: u64) -> &mut Self {
        self.entry_limit = Some(limit);
        while self.prompt.len() as u64 > limit {
            self.prompt.pop_front();
        }
        self
    }

    /// Clear all messages from the conversation prompt.
    ///
    /// # Returns
    ///
    /// A mutable reference to self.
    pub async fn clear(&mut self) -> &mut Self {
        self.prompt.clear();
        self
    }

    /// Retrieve the last message in the prompt.
    ///
    /// # Returns
    ///
    /// An Option containing a reference to the last Message.
    pub async fn last(&self) -> Option<&Message> {
        self.prompt.back()
    }

    /// Generate an AI response.
    ///
    /// This method sends the prompt to the API and, upon successful response,
    /// adds the assistant's message to the prompt.
    ///
    /// # Arguments
    ///
    /// * `model` - Optional model configuration; falls back to the client's default.
    ///
    /// # Returns
    ///
    /// A GenerateResponse with the assistant's content, or a ClientError.
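    ///
    /// # Example
    ///
    /// A minimal sketch; `user_message` stands in for a `Message` built by the
    /// caller (its exact construction depends on the `prompt` module).
    ///
    /// ```ignore
    /// let mut state = client.create_prompt();
    /// state.add(vec![user_message]).await;
    /// let res = state.generate(None).await?;
    /// println!("{:?}", res.content);
    /// ```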
    pub async fn generate(&mut self, model: Option<&ModelConfig>) -> Result<GenerateResponse, ClientError> {
        // Retrieve model configuration: use the provided model or fall back to the client's config.
        let model = model.unwrap_or(
            self.client
                .model_config
                .as_ref()
                .ok_or(ClientError::ModelConfigNotSet)?
        );

        // Send the request and extract the first choice.
        let result = self.client.send(&self.prompt, Some(model)).await?;
        let choice = result
            .response
            .choices
            .as_ref()
            .and_then(|choices| choices.first())
            .ok_or(ClientError::InvalidResponse)?;

        // Ensure there is content in the assistant's reply.
        let content = choice
            .message
            .content
            .as_ref()
            .ok_or(ClientError::UnknownError)?;

        // Add the assistant's message to the conversation.
        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: vec![MessageContext::Text(content.clone())],
            tool_calls: None,
        }])
        .await;

        Ok(GenerateResponse {
            has_content: true,
            has_tool_calls: false,
            content: Some(content.clone()),
            tool_calls: None,
            api_result: result,
        })
    }

    /// Generate an AI response, possibly calling a tool.
    ///
    /// If the API response includes a function call, the corresponding tool is run.
    ///
    /// # Arguments
    ///
    /// * `model` - Optional model configuration; falls back to the client's default.
    /// * `show_call` - Optional callback invoked for each tool call (e.g. `show_call("tool_name", &args)`).
    ///
    /// # Returns
    ///
    /// A GenerateResponse with the content and/or tool calls, or a ClientError.
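    ///
    /// # Example
    ///
    /// A minimal sketch. Note that because `show_call` is a generic
    /// `Option<F>`, passing `None` requires a type annotation.
    ///
    /// ```ignore
    /// let res = state
    ///     .generate_can_use_tool(None, Some(|name: &str, args: &serde_json::Value| {
    ///         println!("tool call: {} {}", name, args);
    ///     }))
    ///     .await?;
    /// // Without a callback:
    /// let res = state
    ///     .generate_can_use_tool(None, None::<fn(&str, &serde_json::Value)>)
    ///     .await?;
    /// ```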
    pub async fn generate_can_use_tool<F>(&mut self, model: Option<&ModelConfig>, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where F: Fn(&str, &serde_json::Value) {
        // Use the provided model configuration or fall back to the client's configuration.
        let model = model.or(self.client.model_config.as_ref()).ok_or(ClientError::ModelConfigNotSet)?;

        // Send the request with "can use tool" mode.
        let result = self.client.send_can_use_tool(&self.prompt, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let has_content = choice.message.content.is_some();
        let has_tool_calls = choice.message.tool_calls.is_some();

        // Ensure that there is either content or a tool call.
        if !has_content && !has_tool_calls {
            return Err(ClientError::UnknownError);
        }

        // If content is returned, add the assistant message.
        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content { vec![MessageContext::Text(choice.message.content.clone().unwrap())] } else { vec![] },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        // Process any tool calls.
        if let Some(tool_calls) = &choice.message.tool_calls {
            for call in tool_calls {
                let (tool, enabled) = self.client.tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = tool
                    .run(call.function.arguments.clone())
                    .unwrap_or_else(|e| format!("Error: {}", e));
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }]).await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls,
            content: choice.message.content.clone(),
            tool_calls: choice.message.tool_calls.clone(),
            api_result: result,
        })
    }

    /// Generate an AI response, requiring the model to call at least one tool.
    ///
    /// Any tool calls in the response are executed.
    ///
    /// # Arguments
    ///
    /// * `model` - Optional model configuration; falls back to the client's default.
    /// * `show_call` - Optional callback invoked for each tool call (e.g. `show_call("tool_name", &args)`).
    ///
    /// # Returns
    ///
    /// A GenerateResponse with the content and tool calls, or a ClientError.
    pub async fn generate_use_tool<F>(&mut self, model: Option<&ModelConfig>, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where F: Fn(&str, &serde_json::Value) {
        let model = model.unwrap_or(
            self.client
                .model_config
                .as_ref()
                .ok_or(ClientError::ModelConfigNotSet)?
        );

        let result = self.client.send_use_tool(&self.prompt, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        // If there is no tool call, return an error.
        if tool_calls.is_none() {
            return Err(ClientError::ToolNotFound);
        }

        let has_content = content.is_some();

        // Add the assistant's reply to the conversation.
        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content { vec![MessageContext::Text(choice.message.content.clone().unwrap())] } else { vec![] },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        // Process any tool calls.
        if let Some(calls) = tool_calls.clone() {
            for call in calls {
                let (tool, enabled) = self
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }]).await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls: true,
            content,
            tool_calls,
            api_result: result,
        })
    }

    /// Generate an AI response while forcing the use of a specific tool.
    ///
    /// If the response includes a function call, the specified tool is executed.
    ///
    /// # Arguments
    ///
    /// * `model` - Optional model configuration; falls back to the client's default.
    /// * `tool_name` - The name of the tool to force.
    /// * `show_call` - Optional callback invoked for each tool call (e.g. `show_call("tool_name", &args)`).
    ///
    /// # Returns
    ///
    /// A GenerateResponse with the content and tool calls, or a ClientError.
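    ///
    /// # Example
    ///
    /// A minimal sketch; "get_weather" is a placeholder for a registered tool
    /// name.
    ///
    /// ```ignore
    /// let res = state
    ///     .generate_with_tool(None, "get_weather", None::<fn(&str, &serde_json::Value)>)
    ///     .await?;
    /// ```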
    pub async fn generate_with_tool<F>(&mut self, model: Option<&ModelConfig>, tool_name: &str, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where F: Fn(&str, &serde_json::Value) {
        let model = model.unwrap_or(
            self.client.model_config.as_ref().ok_or(ClientError::ModelConfigNotSet)?
        );

        let result = self.client.send_with_tool(&self.prompt, tool_name, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        // If there is no tool call, return an error.
        if tool_calls.is_none() {
            return Err(ClientError::ToolNotFound);
        }

        let has_content = content.is_some();

        // Add the assistant's reply to the conversation.
        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content { vec![MessageContext::Text(choice.message.content.clone().unwrap())] } else { vec![] },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        // Process any tool calls.
        if let Some(calls) = tool_calls.clone() {
            for call in calls {
                let (tool, enabled) = self
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }]).await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls: true,
            content,
            tool_calls,
            api_result: result,
        })
    }
}

/// Holds an in-progress reasoning loop: the prompt state plus the latest turn's results.
pub struct ReasoningState<'a> {
    /// The underlying prompt state.
    pub state: &'a mut OpenAIClientState,
    /// The model configuration used for each turn.
    pub model: ModelConfig,
    /// True if the latest turn returned text content.
    pub has_content: bool,
    /// True if the latest turn requested tool calls.
    pub has_tool_calls: bool,
    /// The latest turn's text content, if any.
    pub content: Option<String>,
    /// The latest turn's tool calls, if any.
    pub tool_calls: Option<Vec<FunctionCall>>,
    /// The latest turn's full API result.
    pub api_result: APIResult,
}

pub enum ToolMode {
    /// Disable tool use.
    Disable,
    /// Let the model decide whether to call tools.
    Auto,
    /// Force the named tool to be called.
    Force(String)
}

/// New API introduced in v1.4.0.
impl<'a> OpenAIClientState {
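    /// Start a reasoning loop: send the prompt once in the given `ToolMode` and
    /// capture the first turn as a `ReasoningState` that can be stepped with `proceed`.
    ///
    /// # Example
    ///
    /// A minimal sketch of the step-by-step loop.
    ///
    /// ```ignore
    /// let mut r = state.reasoning(None, &ToolMode::Auto).await?;
    /// while !r.can_finish() {
    ///     // Runs any pending tool calls, then requests the next turn.
    ///     r.proceed(&ToolMode::Auto).await?;
    /// }
    /// println!("{:?}", r.content);
    /// ```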
    pub async fn reasoning(&'a mut self, model: Option<&ModelConfig>, mode: &ToolMode) -> Result<ReasoningState<'a>, ClientError> {
        let model = model.unwrap_or(
            self.client.model_config.as_ref().ok_or(ClientError::ModelConfigNotSet)?
        ).clone();

        let result = match mode {
            ToolMode::Disable => self.client.send(&self.prompt, Some(&model)).await?,
            ToolMode::Auto => self.client.send_can_use_tool(&self.prompt, Some(&model)).await?,
            ToolMode::Force(tool_name) => self.client.send_with_tool(&self.prompt, tool_name, Some(&model)).await?,
        };

        let choices = result.response.choices.as_ref().ok_or(ClientError::InvalidResponse)?;
        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        let has_content = content.is_some();

        // Add the assistant's reply to the conversation.
        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content { vec![MessageContext::Text(choice.message.content.clone().unwrap())] } else { vec![] },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        Ok(ReasoningState {
            state: self,
            model,
            has_content,
            has_tool_calls: tool_calls.is_some(),
            content,
            tool_calls,
            api_result: result,
        })
    }
}

impl<'a> ReasoningState<'a> {
    /// Check whether the reasoning loop can finish: content was produced and
    /// no tool calls are pending.
    pub fn can_finish(&self) -> bool {
        self.has_content && !self.has_tool_calls
    }

    /// Return the pending tool calls as (name, arguments) pairs.
    pub fn show_tool_calls(&self) -> Vec<(&str, &serde_json::Value)> {
        if let Some(tool_calls) = &self.tool_calls {
            tool_calls.iter().map(|call| (call.function.name.as_str(), &call.function.arguments)).collect()
        } else {
            vec![]
        }
    }

    /// Proceed with the reasoning loop, one step at a time: run any pending
    /// tool calls, then request the next turn from the model.
    ///
    /// # Arguments
    ///
    /// * `mode` - The tool mode to use for the next turn.
    ///
    /// # Returns
    ///
    /// Ok(()) on success, or a ClientError.
    pub async fn proceed(&mut self, mode: &ToolMode) -> Result<(), ClientError> {
        if let Some(tool_calls) = &self.tool_calls {
            for call in tool_calls {
                let (tool, enabled) = self.state.client.tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.state.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }]).await;
            }
        }

        let result = match mode {
            ToolMode::Disable => self.state.client.send(&self.state.prompt, Some(&self.model)).await?,
            ToolMode::Auto => self.state.client.send_can_use_tool(&self.state.prompt, Some(&self.model)).await?,
            ToolMode::Force(tool_name) => self.state.client.send_with_tool(&self.state.prompt, tool_name, Some(&self.model)).await?,
        };

        let choices = result.response.choices.as_ref().ok_or(ClientError::InvalidResponse)?;
        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        let has_content = content.is_some();

        self.state.add(vec![Message::Assistant {
            name: self.model.model_name.clone(),
            content: if has_content { vec![MessageContext::Text(choice.message.content.clone().unwrap())] } else { vec![] },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        self.has_content = has_content;
        self.has_tool_calls = tool_calls.is_some();
        self.content = content;
        self.tool_calls = tool_calls;
        self.api_result = result;
        Ok(())
    }
}