// alith_core/executor.rs

1use crate::Ref;
2use crate::chat::{Completion, Request, ResponseContent, ResponseToolCalls, ToolCall};
3use crate::knowledge::Knowledge;
4use crate::mcp::MCPClient;
5use crate::memory::{Memory, Message};
6use crate::tool::Tool;
7use std::sync::Arc;
8
9/// Manages the execution of tasks using an LLM, tools, and (optionally) memory components.
10pub struct Executor<M: Completion> {
11    model: Ref<M>,
12    knowledges: Arc<Vec<Box<dyn Knowledge>>>,
13    tools: Ref<Vec<Box<dyn Tool>>>,
14    memory: Option<Ref<dyn Memory>>,
15    /// The MCP client used to communicate with the MCP server
16    mcp_clients: Ref<Vec<MCPClient>>,
17}
18
19impl<M: Completion> Executor<M> {
20    /// Creates a new `Executor` instance.
21    pub fn new(
22        model: Ref<M>,
23        knowledges: Arc<Vec<Box<dyn Knowledge>>>,
24        tools: Ref<Vec<Box<dyn Tool>>>,
25        memory: Option<Ref<dyn Memory>>,
26        mcp_clients: Ref<Vec<MCPClient>>,
27    ) -> Self {
28        Self {
29            model,
30            knowledges,
31            tools,
32            memory,
33            mcp_clients,
34        }
35    }
36
37    /// Executes the task by managing interactions between the LLM and tools.
38    pub async fn invoke(&mut self, mut request: Request) -> anyhow::Result<String> {
39        request.knowledges = {
40            let mut enriched_knowledges = Vec::new();
41            for knowledge in self.knowledges.iter() {
42                let enriched = knowledge.enrich(&request.prompt)?;
43                enriched_knowledges.push(enriched);
44            }
45            enriched_knowledges
46        };
47        // Add user memory
48        self.add_user_message(&request.prompt).await;
49        // Interact with the LLM to get a response.
50        let mut model = self.model.write().await;
51        let response = model.completion(request.clone()).await?;
52
53        let mut responses = vec![response.content()];
54        self.add_ai_message(&responses[0]).await;
55
56        // Attempt to parse and execute a tool action.
57        for call in response.toolcalls() {
58            let tool_call = self.execute_tool(call).await?;
59            self.add_ai_message_with_tool_call(&tool_call).await?;
60            responses.push(tool_call);
61        }
62
63        Ok(responses.join("\n"))
64    }
65
66    /// Add a user message into the memory if the memory has been set.
67    async fn add_user_message(&self, message: &dyn std::fmt::Display) {
68        if let Some(memory) = &self.memory {
69            let mut memory = memory.write().await;
70            memory.add_user_message(message);
71        }
72    }
73
74    /// Add an AI message into the memory if the memory has been set.
75    async fn add_ai_message(&self, message: &dyn std::fmt::Display) {
76        if let Some(memory) = &self.memory {
77            let mut memory = memory.write().await;
78            memory.add_ai_message(message);
79        }
80    }
81
82    /// Add an AI message into the memory if the memory has been set.
83    async fn add_ai_message_with_tool_call(
84        &self,
85        tool_call: &dyn std::fmt::Display,
86    ) -> anyhow::Result<()> {
87        if let Some(memory) = &self.memory {
88            let mut memory = memory.write().await;
89            let tool_call: serde_json::Value = serde_json::from_str(&format!("{tool_call}"))?;
90            memory.add_message(Message::new_ai_message("").with_tool_calls(tool_call));
91        }
92        Ok(())
93    }
94
95    /// Executes a tool action and returns the result.
96    async fn execute_tool(&self, call: ToolCall) -> anyhow::Result<String> {
97        let tools = self.tools.read().await;
98        if let Some(tool) = tools
99            .iter()
100            .find(|t| t.name().eq_ignore_ascii_case(&call.function.name))
101        {
102            Ok(tool.run(&call.function.arguments).await?)
103        } else {
104            let mcp_clients = self.mcp_clients.read().await;
105            if !mcp_clients.is_empty() {
106                for mcp_client in mcp_clients.iter() {
107                    if mcp_client.tools.contains_key(&call.function.name) {
108                        let arguments = serde_json::from_str(&call.function.arguments)?;
109                        let response = mcp_client.call_tool(&call.function.name, arguments).await?;
110                        if let Some(text) = response.content[0].as_text() {
111                            return Ok(text.to_string());
112                        } else {
113                            return Ok("".to_string());
114                        }
115                    }
116                }
117            }
118            Err(anyhow::anyhow!("Tool not found: {}", call.function.name))
119        }
120    }
121}