// oli_server/agent/core.rs

use crate::agent::executor::AgentExecutor;
use crate::apis::anthropic::AnthropicClient;
use crate::apis::api_client::{ApiClientEnum, DynApiClient, Message};
use crate::apis::ollama::OllamaClient;
use crate::apis::openai::OpenAIClient;
use crate::prompts::DEFAULT_AGENT_PROMPT;
use crate::tools::code::parser::CodeParser;
use anyhow::{Context, Result};
use std::sync::Arc;
use tokio::sync::mpsc;

/// The LLM backend an [`Agent`] talks to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LLMProvider {
    /// Anthropic (Claude) API.
    Anthropic,
    /// OpenAI API.
    OpenAI,
    /// Ollama — locally served models.
    Ollama,
}

19#[derive(Clone)]
20pub struct Agent {
21    provider: LLMProvider,
22    model: Option<String>,
23    api_client: Option<DynApiClient>,
24    system_prompt: Option<String>,
25    progress_sender: Option<mpsc::Sender<String>>,
26    code_parser: Option<Arc<CodeParser>>,
27    // Store the conversation history
28    conversation_history: Vec<crate::apis::api_client::Message>,
29}
30
31impl Agent {
32    pub fn new(provider: LLMProvider) -> Self {
33        Self {
34            provider,
35            model: None,
36            api_client: None,
37            system_prompt: None,
38            progress_sender: None,
39            code_parser: None,
40            conversation_history: Vec::new(),
41        }
42    }
43
44    pub fn new_with_api_key(provider: LLMProvider, api_key: String) -> Self {
45        // Create a new agent with the given provider and API key
46        // The API key will be used during initialization
47        let mut agent = Self::new(provider);
48        // Store the API key as the model temporarily
49        // It will be handled properly in initialize_with_api_key
50        agent.model = Some(api_key);
51        agent
52    }
53
54    pub fn with_model(mut self, model: String) -> Self {
55        self.model = Some(model);
56        self
57    }
58
59    pub fn with_system_prompt(mut self, prompt: String) -> Self {
60        self.system_prompt = Some(prompt);
61        self
62    }
63
64    pub fn with_progress_sender(mut self, sender: mpsc::Sender<String>) -> Self {
65        self.progress_sender = Some(sender);
66        self
67    }
68
69    pub fn clear_history(&mut self) {
70        self.conversation_history.clear();
71    }
72
73    /// Add a message to the conversation history
74    pub fn add_message(&mut self, message: Message) {
75        self.conversation_history.push(message);
76    }
77
78    pub async fn initialize(&mut self) -> Result<()> {
79        // Create the API client based on provider and model
80        self.api_client = Some(match self.provider {
81            LLMProvider::Anthropic => {
82                let client = AnthropicClient::new(self.model.clone())?;
83                ApiClientEnum::Anthropic(Arc::new(client))
84            }
85            LLMProvider::OpenAI => {
86                let client = OpenAIClient::new(self.model.clone())?;
87                ApiClientEnum::OpenAi(Arc::new(client))
88            }
89            LLMProvider::Ollama => {
90                let client = OllamaClient::new(self.model.clone())?;
91                ApiClientEnum::Ollama(Arc::new(client))
92            }
93        });
94
95        // Initialize the code parser
96        let parser = CodeParser::new()?;
97        self.code_parser = Some(Arc::new(parser));
98
99        Ok(())
100    }
101
102    pub async fn initialize_with_api_key(&mut self, api_key: String) -> Result<()> {
103        // Create the API client based on provider and model, using the provided API key
104        self.api_client = Some(match self.provider {
105            LLMProvider::Anthropic => {
106                let client = AnthropicClient::with_api_key(api_key, self.model.clone())?;
107                ApiClientEnum::Anthropic(Arc::new(client))
108            }
109            LLMProvider::OpenAI => {
110                let client = OpenAIClient::with_api_key(api_key, self.model.clone())?;
111                ApiClientEnum::OpenAi(Arc::new(client))
112            }
113            LLMProvider::Ollama => {
114                // For Ollama, we'll use the api_key as the base URL if provided
115                // Otherwise, use the default localhost URL
116                let client = if api_key.trim().is_empty() {
117                    OllamaClient::new(self.model.clone())?
118                } else {
119                    // Treat the "API key" as the base URL for Ollama
120                    let model = self
121                        .model
122                        .clone()
123                        .unwrap_or_else(|| "qwen2.5-coder:14b".to_string());
124                    OllamaClient::with_base_url(model, api_key)?
125                };
126                ApiClientEnum::Ollama(Arc::new(client))
127            }
128        });
129
130        // Initialize the code parser
131        let parser = CodeParser::new()?;
132        self.code_parser = Some(Arc::new(parser));
133
134        Ok(())
135    }
136
137    pub async fn execute(&self, query: &str) -> Result<String> {
138        let api_client = self
139            .api_client
140            .as_ref()
141            .context("Agent not initialized. Call initialize() first.")?;
142
143        // Create and configure executor with persisted conversation history
144        let mut executor = AgentExecutor::new(api_client.clone());
145
146        // Add existing conversation history if any
147        if !self.conversation_history.is_empty() {
148            executor.set_conversation_history(self.conversation_history.clone());
149        }
150
151        // Log the conversation history we're passing to the executor only when debug is explicitly enabled
152        let is_debug_mode = std::env::var("RUST_LOG")
153            .map(|v| v.contains("debug"))
154            .unwrap_or(false);
155
156        if is_debug_mode {
157            if let Some(progress_sender) = &self.progress_sender {
158                let _ = progress_sender.try_send(format!(
159                    "[debug] Agent execute with history: {} messages",
160                    self.conversation_history.len()
161                ));
162                for (i, msg) in self.conversation_history.iter().enumerate() {
163                    let _ = progress_sender.try_send(format!(
164                        "[debug]   History message {}: role={}, preview={}",
165                        i,
166                        msg.role,
167                        if msg.content.len() > 30 {
168                            format!("{}...", &msg.content[..30])
169                        } else {
170                            msg.content.clone()
171                        }
172                    ));
173                }
174            }
175        }
176
177        // Add progress sender if available
178        if let Some(sender) = &self.progress_sender {
179            executor = executor.with_progress_sender(sender.clone());
180        }
181
182        // Always preserve system message at the beginning - if it doesn't exist
183        let has_system_message = self
184            .conversation_history
185            .iter()
186            .any(|msg| msg.role == "system");
187
188        // Add system prompt if it doesn't exist in history
189        if !has_system_message {
190            // Add system prompt if available
191            if let Some(system_prompt) = &self.system_prompt {
192                executor.add_system_message(system_prompt.clone());
193            } else {
194                // Use default system prompt
195                executor.add_system_message(DEFAULT_AGENT_PROMPT.to_string());
196            }
197        }
198
199        // Add the original user query
200        executor.add_user_message(query.to_string());
201
202        // Let the executor determine if codebase parsing is needed
203        // It will use the updated might_need_codebase_parsing method that relies on the LLM
204        // This happens within executor.execute() and adds a suggestion to use ParseCode tool
205        // when appropriate, rather than automatically parsing everything
206
207        // Execute and get result
208        let result = executor.execute().await?;
209
210        // Save updated conversation history for future calls
211        // We need to make sure we preserve the system message in the history
212        if let Some(mutable_self) = unsafe { (self as *const Self as *mut Self).as_mut() } {
213            // Get updated history from executor
214            let mut updated_history = executor.get_conversation_history();
215
216            // Make sure we have a system message, without it conversation history won't work properly
217            let has_system_in_updated = updated_history.iter().any(|msg| msg.role == "system");
218
219            // Always ensure we have a system message
220            if !has_system_in_updated {
221                // Get system message from original history or from system_prompt
222                let system_content = mutable_self
223                    .conversation_history
224                    .iter()
225                    .find(|msg| msg.role == "system")
226                    .map(|msg| msg.content.clone())
227                    .or_else(|| mutable_self.system_prompt.clone())
228                    .unwrap_or_else(|| DEFAULT_AGENT_PROMPT.to_string());
229
230                // Insert system message at the beginning
231                updated_history.insert(0, Message::system(system_content));
232            }
233
234            // Remove any duplicate system messages that might have been added
235            let mut seen_system = false;
236            updated_history.retain(|msg| {
237                if msg.role == "system" {
238                    if seen_system {
239                        return false; // Remove duplicate system messages
240                    }
241                    seen_system = true;
242                }
243                true
244            });
245
246            // Make sure the system message is at the beginning
247            updated_history.sort_by(|a, b| {
248                if a.role == "system" {
249                    std::cmp::Ordering::Less
250                } else if b.role == "system" {
251                    std::cmp::Ordering::Greater
252                } else {
253                    std::cmp::Ordering::Equal
254                }
255            });
256
257            // Update the history
258            mutable_self.conversation_history = updated_history;
259
260            // Debug: Log the updated conversation history only when debug is explicitly enabled
261            let is_debug_mode = std::env::var("RUST_LOG")
262                .map(|v| v.contains("debug"))
263                .unwrap_or(false);
264
265            if is_debug_mode {
266                if let Some(progress_sender) = &self.progress_sender {
267                    let _ = progress_sender.try_send(format!(
268                        "[debug] Updated conversation history: {} messages",
269                        mutable_self.conversation_history.len()
270                    ));
271                    for (i, msg) in mutable_self.conversation_history.iter().enumerate() {
272                        let _ = progress_sender.try_send(format!(
273                            "[debug]   Updated message {}: role={}, preview={}",
274                            i,
275                            msg.role,
276                            if msg.content.len() > 30 {
277                                format!("{}...", &msg.content[..30])
278                            } else {
279                                msg.content.clone()
280                            }
281                        ));
282                    }
283                }
284            }
285        }
286
287        Ok(result)
288    }
289}