1use crate::Ref;
2use crate::chat::{Completion, Request, ResponseContent, ResponseToolCalls, ToolCall};
3use crate::knowledge::Knowledge;
4use crate::mcp::MCPClient;
5use crate::memory::{Memory, Message};
6use crate::tool::Tool;
7use std::sync::Arc;
8
9pub struct Executor<M: Completion> {
11 model: Ref<M>,
12 knowledges: Arc<Vec<Box<dyn Knowledge>>>,
13 tools: Ref<Vec<Box<dyn Tool>>>,
14 memory: Option<Ref<dyn Memory>>,
15 mcp_clients: Ref<Vec<MCPClient>>,
17}
18
19impl<M: Completion> Executor<M> {
20 pub fn new(
22 model: Ref<M>,
23 knowledges: Arc<Vec<Box<dyn Knowledge>>>,
24 tools: Ref<Vec<Box<dyn Tool>>>,
25 memory: Option<Ref<dyn Memory>>,
26 mcp_clients: Ref<Vec<MCPClient>>,
27 ) -> Self {
28 Self {
29 model,
30 knowledges,
31 tools,
32 memory,
33 mcp_clients,
34 }
35 }
36
37 pub async fn invoke(&mut self, mut request: Request) -> anyhow::Result<String> {
39 request.knowledges = {
40 let mut enriched_knowledges = Vec::new();
41 for knowledge in self.knowledges.iter() {
42 let enriched = knowledge.enrich(&request.prompt)?;
43 enriched_knowledges.push(enriched);
44 }
45 enriched_knowledges
46 };
47 self.add_user_message(&request.prompt).await;
49 let mut model = self.model.write().await;
51 let response = model.completion(request.clone()).await?;
52
53 let mut responses = vec![response.content()];
54 self.add_ai_message(&responses[0]).await;
55
56 for call in response.toolcalls() {
58 let tool_call = self.execute_tool(call).await?;
59 self.add_ai_message_with_tool_call(&tool_call).await?;
60 responses.push(tool_call);
61 }
62
63 Ok(responses.join("\n"))
64 }
65
66 async fn add_user_message(&self, message: &dyn std::fmt::Display) {
68 if let Some(memory) = &self.memory {
69 let mut memory = memory.write().await;
70 memory.add_user_message(message);
71 }
72 }
73
74 async fn add_ai_message(&self, message: &dyn std::fmt::Display) {
76 if let Some(memory) = &self.memory {
77 let mut memory = memory.write().await;
78 memory.add_ai_message(message);
79 }
80 }
81
82 async fn add_ai_message_with_tool_call(
84 &self,
85 tool_call: &dyn std::fmt::Display,
86 ) -> anyhow::Result<()> {
87 if let Some(memory) = &self.memory {
88 let mut memory = memory.write().await;
89 let tool_call: serde_json::Value = serde_json::from_str(&format!("{tool_call}"))?;
90 memory.add_message(Message::new_ai_message("").with_tool_calls(tool_call));
91 }
92 Ok(())
93 }
94
95 async fn execute_tool(&self, call: ToolCall) -> anyhow::Result<String> {
97 let tools = self.tools.read().await;
98 if let Some(tool) = tools
99 .iter()
100 .find(|t| t.name().eq_ignore_ascii_case(&call.function.name))
101 {
102 Ok(tool.run(&call.function.arguments).await?)
103 } else {
104 let mcp_clients = self.mcp_clients.read().await;
105 if !mcp_clients.is_empty() {
106 for mcp_client in mcp_clients.iter() {
107 if mcp_client.tools.contains_key(&call.function.name) {
108 let arguments = serde_json::from_str(&call.function.arguments)?;
109 let response = mcp_client.call_tool(&call.function.name, arguments).await?;
110 if let Some(text) = response.content[0].as_text() {
111 return Ok(text.to_string());
112 } else {
113 return Ok("".to_string());
114 }
115 }
116 }
117 }
118 Err(anyhow::anyhow!("Tool not found: {}", call.function.name))
119 }
120 }
121}