use std::{collections::{HashMap, VecDeque}, sync::Arc};

use reqwest::{Client, Response};

use crate::chat::api::WebSearchOptions;

use super::{
    api::{APIRequest, APIResponse, APIResponseHeaders},
    err::ClientError,
    function::{FunctionCall, FunctionDef, Tool, ToolDef},
    prompt::{Message, MessageContext},
};

/// Client for an OpenAI-compatible Chat Completions API.
#[derive(Clone)]
pub struct OpenAIClient {
    /// Underlying HTTP client.
    pub client: Client,
    /// Base URL of the API (stored without a trailing slash).
    pub end_point: String,
    /// Optional bearer token sent in the `Authorization` header.
    pub api_key: Option<String>,
    /// Registered tools, keyed by name, together with an enabled flag.
    pub tools: HashMap<String, (Arc<dyn Tool + Send + Sync>, bool)>,
    /// Default model configuration used when no per-call override is given.
    pub model_config: Option<ModelConfig>,
}

/// Per-request model settings forwarded to the API.
#[derive(Debug, Clone)]
pub struct ModelConfig {
    pub model: String,
    /// Optional name recorded on assistant messages produced with this config.
    pub model_name: Option<String>,
    pub top_p: Option<f64>,
    pub parallel_tool_calls: Option<bool>,
    pub temperature: Option<f64>,
    pub max_completion_tokens: Option<u64>,
    pub reasoning_effort: Option<String>,
    pub presence_penalty: Option<f64>,
    /// Whether exported function definitions are marked as strict.
    pub strict: Option<bool>,
    pub web_search_options: Option<WebSearchOptions>,
}

/// A parsed API response together with selected response headers.
#[derive(Debug, Clone)]
pub struct APIResult {
    pub response: APIResponse,
    pub headers: APIResponseHeaders,
}

impl OpenAIClient {
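    /// Creates a client for an OpenAI-compatible Chat Completions endpoint.
    /// Any trailing `/` on `end_point` is stripped.
    ///
    /// A minimal sketch (endpoint and key are placeholders):
    ///
    /// ```ignore
    /// let client = OpenAIClient::new("https://api.openai.com/v1", Some("sk-..."));
    /// ```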
    pub fn new(end_point: &str, api_key: Option<&str>) -> Self {
        Self {
            client: Client::new(),
            end_point: end_point.trim_end_matches('/').to_string(),
            api_key: api_key.map(|s| s.to_string()),
            tools: HashMap::new(),
            model_config: None,
        }
    }

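    /// Sets the default `ModelConfig` used when a call does not pass its own.
    ///
    /// A minimal sketch; the model id and sampling values are placeholders:
    ///
    /// ```ignore
    /// client.set_model_config(&ModelConfig {
    ///     model: "gpt-4o-mini".to_string(),
    ///     model_name: None,
    ///     top_p: None,
    ///     parallel_tool_calls: None,
    ///     temperature: Some(0.7),
    ///     max_completion_tokens: Some(1024),
    ///     reasoning_effort: None,
    ///     presence_penalty: None,
    ///     strict: Some(false),
    ///     web_search_options: None,
    /// });
    /// ```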
    pub fn set_model_config(&mut self, model_config: &ModelConfig) {
        self.model_config = Some(model_config.clone());
    }

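    /// Registers a tool under its `def_name()`; newly added tools start enabled.
    ///
    /// Sketch, assuming `MyTool` is a type in your code that implements this
    /// crate's `Tool` trait:
    ///
    /// ```ignore
    /// client.def_tool(Arc::new(MyTool::default()));
    /// ```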
    pub fn def_tool<T: Tool + Send + Sync + 'static>(&mut self, tool: Arc<T>) {
        self.tools
            .insert(tool.def_name().to_string(), (tool, true));
    }

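    /// Lists registered tools as `(name, description, enabled)` tuples.
    ///
    /// ```ignore
    /// for (name, description, enabled) in client.list_tools() {
    ///     println!("{name} (enabled: {enabled}): {description}");
    /// }
    /// ```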
    pub fn list_tools(&self) -> Vec<(String, String, bool)> {
        let mut tools = Vec::new();
        for (tool_name, (tool, enable)) in self.tools.iter() {
            tools.push((
                tool_name.to_string(),
                tool.def_description().to_string(),
                *enable,
            ));
        }
        tools
    }

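    /// Enables or disables a registered tool without removing it; unknown
    /// names are ignored.
    ///
    /// ```ignore
    /// client.switch_tool("get_weather", false); // "get_weather" is a placeholder name
    /// ```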
    pub fn switch_tool(&mut self, tool_name: &str, t_enable: bool) {
        if let Some((_, enable)) = self.tools.get_mut(tool_name) {
            *enable = t_enable;
        }
    }

    pub fn export_tool_def(&self) -> Result<Vec<ToolDef>, ClientError> {
        let mut defs = Vec::new();
        for (tool_name, (tool, enable)) in self.tools.iter() {
            if *enable {
                defs.push(ToolDef {
                    tool_type: "function".to_string(),
                    function: FunctionDef {
                        name: tool_name.clone(),
                        description: tool.def_description().to_string(),
                        parameters: tool.def_parameters(),
                        strict: self
                            .model_config
                            .as_ref()
                            .ok_or(ClientError::ModelConfigNotSet)?
                            .strict
                            .unwrap_or(false),
                    },
                });
            }
        }
        Ok(defs)
    }

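    /// Sends the prompt with tool use disabled (`tool_choice` is `"none"`).
    ///
    /// Sketch; the prompt here is one previously built via `create_prompt()`:
    ///
    /// ```ignore
    /// let result = client.send(&chat.prompt, None).await?;
    /// println!("{:?}", result.response);
    /// ```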
    pub async fn send(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("none")), model)
            .await
    }

    pub async fn send_can_use_tool(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("auto")), model)
            .await
    }

    pub async fn send_use_tool(
        &self,
        prompt: &VecDeque<Message>,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        self.call_api(prompt, Some(&serde_json::json!("required")), model)
            .await
    }

    pub async fn send_with_tool(
        &self,
        prompt: &VecDeque<Message>,
        tool_name: &str,
        model: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        // Force the model to call the named tool.
        let tool_choice =
            serde_json::json!({"type": "function", "function": {"name": tool_name}});

        self.call_api(prompt, Some(&tool_choice), model).await
    }

    pub async fn call_api(
        &self,
        prompt: &VecDeque<Message>,
        tool_choice: Option<&serde_json::Value>,
        model_config: Option<&ModelConfig>,
    ) -> Result<APIResult, ClientError> {
        if !self.end_point.starts_with("https://") && !self.end_point.starts_with("http://") {
            return Err(ClientError::InvalidEndpoint);
        }

        // Prefer the per-call config; otherwise fall back to the client default.
        let model_config = model_config
            .or(self.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?;
        let tools = self.export_tool_def()?;
        let res = self
            .request_api(
                &self.end_point,
                self.api_key.as_deref(),
                model_config,
                prompt,
                &tools,
                tool_choice.unwrap_or(&serde_json::Value::Null),
            )
            .await?;

        let headers = APIResponseHeaders {
            retry_after: res
                .headers()
                .get("Retry-After")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            reset: res
                .headers()
                .get("X-RateLimit-Reset")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            rate_limit: res
                .headers()
                .get("X-RateLimit-Remaining")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            limit: res
                .headers()
                .get("X-RateLimit-Limit")
                .and_then(|v| v.to_str().ok().and_then(|v| v.parse().ok())),
            extra_other: res
                .headers()
                .iter()
                .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
                .collect(),
        };
        let text = res.text().await.map_err(|_| ClientError::InvalidResponse)?;
        log::debug!("Response: {}", text);
        let response_body: APIResponse =
            serde_json::from_str(&text).map_err(|_| ClientError::InvalidResponse)?;

        Ok(APIResult {
            response: response_body,
            headers,
        })
    }

    pub async fn request_api(
        &self,
        end_point: &str,
        api_key: Option<&str>,
        model_config: &ModelConfig,
        message: &VecDeque<Message>,
        tools: &[ToolDef],
        tool_choice: &serde_json::Value,
    ) -> Result<Response, ClientError> {
        let request = APIRequest {
            model: model_config.model.clone(),
            messages: message.clone(),
            tools: tools.to_vec(),
            tool_choice: tool_choice.clone(),
            parallel_tool_calls: model_config.parallel_tool_calls,
            temperature: model_config.temperature,
            max_completion_tokens: model_config.max_completion_tokens,
            top_p: model_config.top_p,
            reasoning_effort: model_config.reasoning_effort.clone(),
            presence_penalty: model_config.presence_penalty,
            web_search_options: model_config.web_search_options.clone(),
        };

        let res = self
            .client
            .post(format!("{}/chat/completions", end_point))
            .header("Content-Type", "application/json")
            .header(
                "Authorization",
                format!("Bearer {}", api_key.unwrap_or("")),
            )
            .json(&request)
            .send()
            .await
            .map_err(|_| ClientError::NetworkError)?;

        Ok(res)
    }

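    /// Creates an empty conversation state bound to a clone of this client.
    ///
    /// Sketch of a round trip (message construction is elided, since it
    /// depends on the `Message` variants defined in the prompt module):
    ///
    /// ```ignore
    /// let mut chat = client.create_prompt();
    /// chat.add(vec![/* user message(s) built from this crate's Message type */]).await;
    /// let reply = chat.generate(None).await?;
    /// println!("{}", reply.content.unwrap_or_default());
    /// ```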
    pub fn create_prompt(&self) -> OpenAIClientState {
        OpenAIClientState {
            prompt: VecDeque::new(),
            client: self.clone(),
            entry_limit: None,
        }
    }
}

/// A conversation: an ordered prompt plus the client used to extend it.
#[derive(Clone)]
pub struct OpenAIClientState {
    pub prompt: VecDeque<Message>,
    pub client: OpenAIClient,
    /// Maximum number of prompt entries to retain (oldest entries are dropped first).
    pub entry_limit: Option<u64>,
}

/// Outcome of a single generation call, including the raw API result.
#[derive(Debug, Clone)]
pub struct GenerateResponse {
    pub has_content: bool,
    pub has_tool_calls: bool,
    pub content: Option<String>,
    pub tool_calls: Option<Vec<FunctionCall>>,
    pub api_result: APIResult,
}

impl OpenAIClientState {
    pub async fn add(&mut self, messages: Vec<Message>) -> &mut Self {
        if let Some(limit) = self.entry_limit {
            // Drop oldest entries until the new messages fit; stop once the
            // prompt is empty so an over-long batch cannot loop forever.
            while self.prompt.len() as u64 + messages.len() as u64 > limit
                && !self.prompt.is_empty()
            {
                self.prompt.pop_front();
            }
        }
        self.prompt.extend(messages);
        self
    }

    pub async fn add_last(&mut self, messages: Vec<Message>) -> &mut Self {
        if let Some(limit) = self.entry_limit {
            while self.prompt.len() as u64 + messages.len() as u64 > limit
                && !self.prompt.is_empty()
            {
                self.prompt.pop_front();
            }
        }
        // Push in reverse so the given messages keep their relative order at
        // the front of the prompt.
        for msg in messages.into_iter().rev() {
            self.prompt.push_front(msg);
        }
        self
    }

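    /// Caps how many prompt entries are retained; the oldest are dropped first.
    ///
    /// ```ignore
    /// chat.set_entry_limit(50).await; // keep at most 50 messages
    /// ```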
    pub async fn set_entry_limit(&mut self, limit: u64) -> &mut Self {
        self.entry_limit = Some(limit);
        while self.prompt.len() as u64 > limit {
            self.prompt.pop_front();
        }
        self
    }

    pub async fn clear(&mut self) -> &mut Self {
        self.prompt.clear();
        self
    }

    pub async fn last(&mut self) -> Option<&Message> {
        self.prompt.back()
    }

    pub async fn generate(&mut self, model: Option<&ModelConfig>) -> Result<GenerateResponse, ClientError> {
        let model = model
            .or(self.client.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?;

        let result = self.client.send(&self.prompt, Some(model)).await?;
        let choice = result
            .response
            .choices
            .as_ref()
            .and_then(|choices| choices.first())
            .ok_or(ClientError::InvalidResponse)?;

        let content = choice
            .message
            .content
            .as_ref()
            .ok_or(ClientError::UnknownError)?;

        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: vec![MessageContext::Text(content.clone())],
            tool_calls: None,
        }])
        .await;

        Ok(GenerateResponse {
            has_content: true,
            has_tool_calls: false,
            content: Some(content.clone()),
            tool_calls: None,
            api_result: result,
        })
    }

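    /// Sends the prompt with `tool_choice` set to `"auto"`; if the model
    /// returns tool calls, each enabled tool is run and its output is
    /// appended to the prompt as a tool message.
    ///
    /// Sketch; the optional closure just logs which tool is being invoked:
    ///
    /// ```ignore
    /// let res = chat
    ///     .generate_can_use_tool(None, Some(|name: &str, args: &serde_json::Value| {
    ///         println!("calling {name} with {args}");
    ///     }))
    ///     .await?;
    /// if res.has_tool_calls {
    ///     // Tool outputs are now part of the prompt; ask for a final answer.
    ///     let final_reply = chat.generate(None).await?;
    /// }
    /// ```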
    pub async fn generate_can_use_tool<F>(&mut self, model: Option<&ModelConfig>, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where
        F: Fn(&str, &serde_json::Value),
    {
        let model = model
            .or(self.client.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?;

        let result = self.client.send_can_use_tool(&self.prompt, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let has_content = choice.message.content.is_some();
        let has_tool_calls = choice.message.tool_calls.is_some();

        if !has_content && !has_tool_calls {
            return Err(ClientError::UnknownError);
        }

        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content {
                vec![MessageContext::Text(choice.message.content.clone().unwrap())]
            } else {
                vec![]
            },
            tool_calls: choice.message.tool_calls.clone(),
        }])
        .await;

        if let Some(tool_calls) = &choice.message.tool_calls {
            for call in tool_calls {
                let (tool, enabled) = self
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = tool
                    .run(call.function.arguments.clone())
                    .unwrap_or_else(|e| format!("Error: {}", e));
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }])
                .await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls,
            content: choice.message.content.clone(),
            tool_calls: choice.message.tool_calls.clone(),
            api_result: result,
        })
    }

    pub async fn generate_use_tool<F>(&mut self, model: Option<&ModelConfig>, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where
        F: Fn(&str, &serde_json::Value),
    {
        let model = model
            .or(self.client.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?;

        let result = self.client.send_use_tool(&self.prompt, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        if tool_calls.is_none() {
            return Err(ClientError::ToolNotFound);
        }

        let has_content = content.is_some();

        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content {
                vec![MessageContext::Text(choice.message.content.clone().unwrap())]
            } else {
                vec![]
            },
            tool_calls: choice.message.tool_calls.clone(),
        }])
        .await;

        if let Some(calls) = tool_calls.clone() {
            for call in calls {
                let (tool, enabled) = self
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }])
                .await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls: true,
            content,
            tool_calls,
            api_result: result,
        })
    }

    pub async fn generate_with_tool<F>(&mut self, model: Option<&ModelConfig>, tool_name: &str, show_call: Option<F>) -> Result<GenerateResponse, ClientError>
    where
        F: Fn(&str, &serde_json::Value),
    {
        let model = model
            .or(self.client.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?;

        let result = self.client.send_with_tool(&self.prompt, tool_name, Some(model)).await?;
        let choices = result
            .response
            .choices
            .as_ref()
            .ok_or(ClientError::InvalidResponse)?;

        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        if tool_calls.is_none() {
            return Err(ClientError::ToolNotFound);
        }

        let has_content = content.is_some();

        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content {
                vec![MessageContext::Text(choice.message.content.clone().unwrap())]
            } else {
                vec![]
            },
            tool_calls: choice.message.tool_calls.clone(),
        }])
        .await;

        if let Some(calls) = tool_calls.clone() {
            for call in calls {
                let (tool, enabled) = self
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                if let Some(show_call) = &show_call {
                    show_call(&call.function.name, &call.function.arguments);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }])
                .await;
            }
        }

        Ok(GenerateResponse {
            has_content,
            has_tool_calls: true,
            content,
            tool_calls,
            api_result: result,
        })
    }
}

/// An in-progress reasoning loop: the conversation plus the latest response.
pub struct ReasoningState<'a> {
    pub state: &'a mut OpenAIClientState,
    pub model: ModelConfig,
    pub has_content: bool,
    pub has_tool_calls: bool,
    pub content: Option<String>,
    pub tool_calls: Option<Vec<FunctionCall>>,
    pub api_result: APIResult,
}

/// How tool use is requested for a reasoning step.
pub enum ToolMode {
    /// Tools are not offered to the model (`tool_choice: "none"`).
    Disable,
    /// The model decides whether to call a tool (`tool_choice: "auto"`).
    Auto,
    /// The model is forced to call the named tool.
    Force(String),
}

impl OpenAIClientState {
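    /// Starts a reasoning loop: sends the prompt once under the given
    /// `ToolMode` and returns a `ReasoningState` that can execute tool calls
    /// and re-query until the model produces a plain answer.
    ///
    /// Sketch of the intended driving loop:
    ///
    /// ```ignore
    /// let mut step = chat.reasoning(None, &ToolMode::Auto).await?;
    /// while !step.can_finish() {
    ///     step.proceed(&ToolMode::Auto).await?;
    /// }
    /// println!("{}", step.content.unwrap_or_default());
    /// ```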
    pub async fn reasoning<'a>(&'a mut self, model: Option<&ModelConfig>, mode: &ToolMode) -> Result<ReasoningState<'a>, ClientError> {
        let model = model
            .or(self.client.model_config.as_ref())
            .ok_or(ClientError::ModelConfigNotSet)?
            .clone();

        let result = match mode {
            ToolMode::Disable => self.client.send(&self.prompt, Some(&model)).await?,
            ToolMode::Auto => self.client.send_can_use_tool(&self.prompt, Some(&model)).await?,
            ToolMode::Force(tool_name) => self.client.send_with_tool(&self.prompt, tool_name, Some(&model)).await?,
        };

        let choices = result.response.choices.as_ref().ok_or(ClientError::InvalidResponse)?;
        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        let has_content = content.is_some();

        self.add(vec![Message::Assistant {
            name: model.model_name.clone(),
            content: if has_content {
                vec![MessageContext::Text(choice.message.content.clone().unwrap())]
            } else {
                vec![]
            },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        Ok(ReasoningState {
            state: self,
            model,
            has_content,
            has_tool_calls: tool_calls.is_some(),
            content,
            tool_calls,
            api_result: result,
        })
    }
}

impl<'a> ReasoningState<'a> {
    pub fn can_finish(&self) -> bool {
        self.has_content && !self.has_tool_calls
    }

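    /// Lists the pending tool calls as `(name, arguments)` pairs.
    ///
    /// ```ignore
    /// for (name, args) in step.show_tool_calls() {
    ///     println!("pending call: {name}({args})");
    /// }
    /// ```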
    pub fn show_tool_calls(&self) -> Vec<(&str, &serde_json::Value)> {
        if let Some(tool_calls) = &self.tool_calls {
            tool_calls
                .iter()
                .map(|call| (call.function.name.as_str(), &call.function.arguments))
                .collect()
        } else {
            vec![]
        }
    }

    pub async fn proceed(&mut self, mode: &ToolMode) -> Result<(), ClientError> {
        // Run any pending tool calls and append their results to the prompt.
        if let Some(tool_calls) = &self.tool_calls {
            for call in tool_calls {
                let (tool, enabled) = self
                    .state
                    .client
                    .tools
                    .get(&call.function.name)
                    .ok_or(ClientError::ToolNotFound)?;
                if !*enabled {
                    return Err(ClientError::ToolNotFound);
                }
                let result_text = match tool.run(call.function.arguments.clone()) {
                    Ok(res) => res,
                    Err(e) => format!("Error: {}", e),
                };
                self.state.add(vec![Message::Tool {
                    tool_call_id: call.id.clone(),
                    content: vec![MessageContext::Text(result_text)],
                }]).await;
            }
        }

        let result = match mode {
            ToolMode::Disable => self.state.client.send(&self.state.prompt, Some(&self.model)).await?,
            ToolMode::Auto => self.state.client.send_can_use_tool(&self.state.prompt, Some(&self.model)).await?,
            ToolMode::Force(tool_name) => self.state.client.send_with_tool(&self.state.prompt, tool_name, Some(&self.model)).await?,
        };

        let choices = result.response.choices.as_ref().ok_or(ClientError::InvalidResponse)?;
        let choice = choices.first().ok_or(ClientError::InvalidResponse)?;
        let content = choice.message.content.clone();
        let tool_calls = choice.message.tool_calls.clone();

        let has_content = content.is_some();

        self.state.add(vec![Message::Assistant {
            name: self.model.model_name.clone(),
            content: if has_content {
                vec![MessageContext::Text(choice.message.content.clone().unwrap())]
            } else {
                vec![]
            },
            tool_calls: choice.message.tool_calls.clone(),
        }]).await;

        self.has_content = has_content;
        self.has_tool_calls = tool_calls.is_some();
        self.content = content;
        self.tool_calls = tool_calls;
        self.api_result = result;
        Ok(())
    }
}