1use crate::agent::executor::AgentExecutor;
2use crate::apis::anthropic::AnthropicClient;
3use crate::apis::api_client::{ApiClientEnum, DynApiClient, Message};
4use crate::apis::ollama::OllamaClient;
5use crate::apis::openai::OpenAIClient;
6use crate::prompts::DEFAULT_AGENT_PROMPT;
7use crate::tools::code::parser::CodeParser;
8use anyhow::{Context, Result};
9use std::sync::Arc;
10use tokio::sync::mpsc;
11
/// Supported LLM backends an [`Agent`] can talk to.
///
/// Fieldless discriminant, so it is cheap to copy and compare; `Debug`
/// is derived so the provider can appear in logs and error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LLMProvider {
    /// Anthropic API client.
    Anthropic,
    /// OpenAI API client.
    OpenAI,
    /// Ollama client (typically a locally hosted server).
    Ollama,
}
18
19#[derive(Clone)]
20pub struct Agent {
21 provider: LLMProvider,
22 model: Option<String>,
23 api_client: Option<DynApiClient>,
24 system_prompt: Option<String>,
25 progress_sender: Option<mpsc::Sender<String>>,
26 code_parser: Option<Arc<CodeParser>>,
27 conversation_history: Vec<crate::apis::api_client::Message>,
29}
30
31impl Agent {
32 pub fn new(provider: LLMProvider) -> Self {
33 Self {
34 provider,
35 model: None,
36 api_client: None,
37 system_prompt: None,
38 progress_sender: None,
39 code_parser: None,
40 conversation_history: Vec::new(),
41 }
42 }
43
44 pub fn new_with_api_key(provider: LLMProvider, api_key: String) -> Self {
45 let mut agent = Self::new(provider);
48 agent.model = Some(api_key);
51 agent
52 }
53
54 pub fn with_model(mut self, model: String) -> Self {
55 self.model = Some(model);
56 self
57 }
58
59 pub fn with_system_prompt(mut self, prompt: String) -> Self {
60 self.system_prompt = Some(prompt);
61 self
62 }
63
64 pub fn with_progress_sender(mut self, sender: mpsc::Sender<String>) -> Self {
65 self.progress_sender = Some(sender);
66 self
67 }
68
69 pub fn clear_history(&mut self) {
70 self.conversation_history.clear();
71 }
72
73 pub fn add_message(&mut self, message: Message) {
75 self.conversation_history.push(message);
76 }
77
78 pub async fn initialize(&mut self) -> Result<()> {
79 self.api_client = Some(match self.provider {
81 LLMProvider::Anthropic => {
82 let client = AnthropicClient::new(self.model.clone())?;
83 ApiClientEnum::Anthropic(Arc::new(client))
84 }
85 LLMProvider::OpenAI => {
86 let client = OpenAIClient::new(self.model.clone())?;
87 ApiClientEnum::OpenAi(Arc::new(client))
88 }
89 LLMProvider::Ollama => {
90 let client = OllamaClient::new(self.model.clone())?;
91 ApiClientEnum::Ollama(Arc::new(client))
92 }
93 });
94
95 let parser = CodeParser::new()?;
97 self.code_parser = Some(Arc::new(parser));
98
99 Ok(())
100 }
101
102 pub async fn initialize_with_api_key(&mut self, api_key: String) -> Result<()> {
103 self.api_client = Some(match self.provider {
105 LLMProvider::Anthropic => {
106 let client = AnthropicClient::with_api_key(api_key, self.model.clone())?;
107 ApiClientEnum::Anthropic(Arc::new(client))
108 }
109 LLMProvider::OpenAI => {
110 let client = OpenAIClient::with_api_key(api_key, self.model.clone())?;
111 ApiClientEnum::OpenAi(Arc::new(client))
112 }
113 LLMProvider::Ollama => {
114 let client = if api_key.trim().is_empty() {
117 OllamaClient::new(self.model.clone())?
118 } else {
119 let model = self
121 .model
122 .clone()
123 .unwrap_or_else(|| "qwen2.5-coder:14b".to_string());
124 OllamaClient::with_base_url(model, api_key)?
125 };
126 ApiClientEnum::Ollama(Arc::new(client))
127 }
128 });
129
130 let parser = CodeParser::new()?;
132 self.code_parser = Some(Arc::new(parser));
133
134 Ok(())
135 }
136
    /// Run `query` through the agent loop and return the final answer text.
    ///
    /// Requires a prior `initialize*` call (errors otherwise). The stored
    /// `conversation_history` is replayed into a fresh `AgentExecutor`,
    /// and afterwards replaced with the executor's updated transcript so
    /// context persists across calls.
    ///
    /// WARNING(review): despite taking `&self`, this method mutates the
    /// agent through an unsafe pointer cast (see below). Mutating through
    /// a shared reference without interior mutability is undefined
    /// behavior in Rust and is unsound whenever any other reference to
    /// this `Agent` exists. The signature should become `&mut self`, or
    /// `conversation_history` should move behind a `Mutex`/`RefCell`.
    pub async fn execute(&self, query: &str) -> Result<String> {
        let api_client = self
            .api_client
            .as_ref()
            .context("Agent not initialized. Call initialize() first.")?;

        let mut executor = AgentExecutor::new(api_client.clone());

        // Seed the executor with the transcript from previous turns.
        if !self.conversation_history.is_empty() {
            executor.set_conversation_history(self.conversation_history.clone());
        }

        // Debug tracing is keyed off RUST_LOG containing "debug".
        let is_debug_mode = std::env::var("RUST_LOG")
            .map(|v| v.contains("debug"))
            .unwrap_or(false);

        if is_debug_mode {
            if let Some(progress_sender) = &self.progress_sender {
                // try_send: drop the debug line rather than block/fail.
                let _ = progress_sender.try_send(format!(
                    "[debug] Agent execute with history: {} messages",
                    self.conversation_history.len()
                ));
                for (i, msg) in self.conversation_history.iter().enumerate() {
                    let _ = progress_sender.try_send(format!(
                        "[debug] History message {}: role={}, preview={}",
                        i,
                        msg.role,
                        // NOTE(review): byte slicing `content[..30]` can
                        // panic on a non-ASCII char boundary — assumes
                        // ASCII-ish content; verify or use char_indices.
                        if msg.content.len() > 30 {
                            format!("{}...", &msg.content[..30])
                        } else {
                            msg.content.clone()
                        }
                    ));
                }
            }
        }

        if let Some(sender) = &self.progress_sender {
            executor = executor.with_progress_sender(sender.clone());
        }

        // Ensure exactly one system message reaches the executor: only add
        // one if the replayed history does not already contain it.
        let has_system_message = self
            .conversation_history
            .iter()
            .any(|msg| msg.role == "system");

        if !has_system_message {
            if let Some(system_prompt) = &self.system_prompt {
                executor.add_system_message(system_prompt.clone());
            } else {
                executor.add_system_message(DEFAULT_AGENT_PROMPT.to_string());
            }
        }

        executor.add_user_message(query.to_string());

        let result = executor.execute().await?;

        // SAFETY(review): this cast fabricates a `&mut Agent` from `&self`,
        // which violates Rust's aliasing rules and is undefined behavior —
        // do not imitate. Kept as-is pending the signature fix described in
        // the method docs.
        if let Some(mutable_self) = unsafe { (self as *const Self as *mut Self).as_mut() } {
            let mut updated_history = executor.get_conversation_history();

            let has_system_in_updated = updated_history.iter().any(|msg| msg.role == "system");

            // Re-attach a system message if the executor dropped it:
            // prefer the one from our prior history, then the configured
            // prompt, then the default.
            if !has_system_in_updated {
                let system_content = mutable_self
                    .conversation_history
                    .iter()
                    .find(|msg| msg.role == "system")
                    .map(|msg| msg.content.clone())
                    .or_else(|| mutable_self.system_prompt.clone())
                    .unwrap_or_else(|| DEFAULT_AGENT_PROMPT.to_string());

                updated_history.insert(0, Message::system(system_content));
            }

            // Deduplicate: keep only the first system message.
            let mut seen_system = false;
            updated_history.retain(|msg| {
                if msg.role == "system" {
                    if seen_system {
                        return false;
                    }
                    seen_system = true;
                }
                true
            });

            // Stable sort moves the (single, post-retain) system message to
            // the front while preserving the order of everything else.
            updated_history.sort_by(|a, b| {
                if a.role == "system" {
                    std::cmp::Ordering::Less
                } else if b.role == "system" {
                    std::cmp::Ordering::Greater
                } else {
                    std::cmp::Ordering::Equal
                }
            });

            mutable_self.conversation_history = updated_history;

            let is_debug_mode = std::env::var("RUST_LOG")
                .map(|v| v.contains("debug"))
                .unwrap_or(false);

            if is_debug_mode {
                if let Some(progress_sender) = &self.progress_sender {
                    let _ = progress_sender.try_send(format!(
                        "[debug] Updated conversation history: {} messages",
                        mutable_self.conversation_history.len()
                    ));
                    for (i, msg) in mutable_self.conversation_history.iter().enumerate() {
                        let _ = progress_sender.try_send(format!(
                            "[debug] Updated message {}: role={}, preview={}",
                            i,
                            msg.role,
                            // Same char-boundary panic risk as above.
                            if msg.content.len() > 30 {
                                format!("{}...", &msg.content[..30])
                            } else {
                                msg.content.clone()
                            }
                        ));
                    }
                }
            }
        }

        Ok(result)
    }
289}