1pub mod backend;
10mod preflight;
11pub mod session;
12pub mod git_session;
13
14use crate::config::{LlmProvider, PawanConfig};
15use crate::tools::{ToolDefinition, ToolRegistry};
16use crate::{PawanError, Result};
17use backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
18use backend::LlmBackend;
19use serde::{Deserialize, Serialize};
20use serde_json::{json, Value};
21use std::path::PathBuf;
22
/// A single entry in the conversation history exchanged with the LLM backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message (system, user, assistant, or tool).
    pub role: Role,
    /// Plain-text body of the message.
    pub content: String,
    /// Tool invocations requested by the assistant; defaults to empty when
    /// the field is absent in serialized input.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Result payload carried by tool-role messages; omitted from JSON when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMessage>,
}
37
/// The speaker of a [`Message`]; serialized as lowercase strings
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`) to match the
/// OpenAI-style chat format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
47
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Backend-assigned call id, echoed back in the matching tool result.
    pub id: String,
    /// Name of the tool to invoke (must exist in the `ToolRegistry`).
    pub name: String,
    /// Tool arguments as arbitrary JSON.
    pub arguments: Value,
}
58
/// The outcome of one tool call, attached to a tool-role [`Message`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the [`ToolCallRequest`] this result answers.
    pub tool_call_id: String,
    /// JSON result produced by the tool (or an error object on failure).
    pub content: Value,
    /// Whether the tool executed without error.
    pub success: bool,
}
69
/// A completed tool call as reported to callers (request + result + timing).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Backend-assigned call id.
    pub id: String,
    /// Tool name that was invoked.
    pub name: String,
    /// Arguments the tool was called with.
    pub arguments: Value,
    /// JSON result (possibly truncated; an error object on failure).
    pub result: Value,
    /// Whether execution succeeded.
    pub success: bool,
    /// Wall-clock execution time in milliseconds.
    pub duration_ms: u64,
}
86
/// Token accounting aggregated across one or more LLM calls.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt/context.
    pub prompt_tokens: u64,
    /// Tokens generated by the model.
    pub completion_tokens: u64,
    /// prompt + completion total as reported by the backend.
    pub total_tokens: u64,
    /// Completion tokens spent on reasoning/"thinking" output.
    pub reasoning_tokens: u64,
    /// Completion tokens spent on actionable (non-reasoning) output.
    pub action_tokens: u64,
}
98
/// One raw response from the LLM backend, before the agent loop
/// post-processes it.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Model output text (may still contain `<think>` spans).
    pub content: String,
    /// Separated reasoning text, when the backend extracts it.
    pub reasoning: Option<String>,
    /// Tool calls the model requested this turn.
    pub tool_calls: Vec<ToolCallRequest>,
    /// Backend finish reason (e.g. "stop", "tool_calls", or "error" for
    /// synthesized failure responses).
    pub finish_reason: String,
    /// Per-call token usage, when the backend reports it.
    pub usage: Option<TokenUsage>,
}
113
/// The final result of a full agent run (one user prompt, possibly many
/// LLM/tool iterations).
#[derive(Debug)]
pub struct AgentResponse {
    /// Final assistant answer with thinking spans removed.
    pub content: String,
    /// Every tool call executed during the run, in order.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of LLM iterations the loop performed.
    pub iterations: usize,
    /// Token usage summed over all iterations.
    pub usage: TokenUsage,
}
126
/// Callback invoked with each streamed output token from the backend.
pub type TokenCallback = Box<dyn Fn(&str) + Send + Sync>;

/// Callback invoked with each completed [`ToolCallRecord`].
pub type ToolCallback = Box<dyn Fn(&ToolCallRecord) + Send + Sync>;

/// Callback invoked with the tool name just before a tool runs.
pub type ToolStartCallback = Box<dyn Fn(&str) + Send + Sync>;
/// The core agent: owns the conversation state, the tool registry, and the
/// LLM backend, and drives the iterative tool-calling loop.
pub struct PawanAgent {
    // Runtime configuration (model, provider, limits, permissions).
    config: PawanConfig,
    // Tools the model is allowed to call.
    tools: ToolRegistry,
    // Full conversation history sent to the backend.
    history: Vec<Message>,
    // Root directory tools operate in.
    workspace_root: PathBuf,
    // Active LLM backend (OpenAI-compatible or Ollama).
    backend: Box<dyn LlmBackend>,

    // Rough token estimate of `history` (content bytes / 4).
    context_tokens_estimate: usize,

    // Optional Eruka memory bridge; `None` when disabled in config.
    eruka: Option<crate::eruka_bridge::ErukaClient>,
}
164
165impl PawanAgent {
166 pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
168 let tools = ToolRegistry::with_defaults(workspace_root.clone());
169 let system_prompt = config.get_system_prompt();
170 let backend = Self::create_backend(&config, &system_prompt);
171 let eruka = if config.eruka.enabled {
172 Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
173 } else {
174 None
175 };
176
177 Self {
178 config,
179 tools,
180 history: Vec::new(),
181 workspace_root,
182 backend,
183 context_tokens_estimate: 0,
184 eruka,
185 }
186 }
187
    /// Construct the LLM backend for the configured provider.
    ///
    /// NVIDIA/OpenAI/MLX all speak the OpenAI-compatible protocol and differ
    /// only in how the endpoint URL and API key are resolved (environment
    /// variables win over built-in defaults; `config.base_url` is honored
    /// where supported). Ollama gets its own dedicated backend.
    fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
        match config.provider {
            LlmProvider::Nvidia | LlmProvider::OpenAI | LlmProvider::Mlx => {
                // Resolve endpoint + credentials per provider.
                let (api_url, api_key) = match config.provider {
                    LlmProvider::Nvidia => {
                        let url = std::env::var("NVIDIA_API_URL")
                            .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                        let key = std::env::var("NVIDIA_API_KEY").ok();
                        // Missing key is only warned about here; the request
                        // itself will fail later if the key was required.
                        if key.is_none() {
                            tracing::warn!("NVIDIA_API_KEY not set. Add it to .env or export it.");
                        }
                        (url, key)
                    },
                    LlmProvider::OpenAI => {
                        let url = config.base_url.clone()
                            .or_else(|| std::env::var("OPENAI_API_URL").ok())
                            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                        let key = std::env::var("OPENAI_API_KEY").ok();
                        (url, key)
                    },
                    LlmProvider::Mlx => {
                        // Local MLX server; no API key needed.
                        let url = config.base_url.clone()
                            .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
                        tracing::info!(url = %url, "Using MLX LM server (Apple Silicon native)");
                        (url, None) },
                    // The outer match arm already restricts the provider set.
                    _ => unreachable!(),
                };

                // Optional cloud fallback endpoint, resolved the same way.
                let cloud = config.cloud.as_ref().map(|c| {
                    let (cloud_url, cloud_key) = match c.provider {
                        LlmProvider::Nvidia => {
                            let url = std::env::var("NVIDIA_API_URL")
                                .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                            let key = std::env::var("NVIDIA_API_KEY").ok();
                            (url, key)
                        },
                        LlmProvider::OpenAI => {
                            let url = std::env::var("OPENAI_API_URL")
                                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
                            let key = std::env::var("OPENAI_API_KEY").ok();
                            (url, key)
                        },
                        LlmProvider::Mlx => {
                            ("http://localhost:8080/v1".to_string(), None)
                        },
                        // Unsupported fallback providers degrade to the NVIDIA
                        // endpoint with no key (will likely fail at request
                        // time) rather than panicking.
                        _ => {
                            tracing::warn!("Cloud fallback only supports nvidia/openai/mlx providers");
                            ("https://integrate.api.nvidia.com/v1".to_string(), None)
                        }
                    };
                    backend::openai_compat::CloudFallback {
                        api_url: cloud_url,
                        api_key: cloud_key,
                        model: c.model.clone(),
                        fallback_models: c.fallback_models.clone(),
                    }
                });

                Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
                    api_url,
                    api_key,
                    model: config.model.clone(),
                    temperature: config.temperature,
                    top_p: config.top_p,
                    max_tokens: config.max_tokens,
                    system_prompt: system_prompt.to_string(),
                    // NOTE(review): thinking mode is only enabled when no
                    // explicit thinking budget is set — presumably a nonzero
                    // budget is enforced elsewhere; confirm intended semantics.
                    use_thinking: config.thinking_budget == 0 && config.use_thinking_mode(),
                    max_retries: config.max_retries,
                    fallback_models: config.fallback_models.clone(),
                    cloud,
                }))
            }
            LlmProvider::Ollama => {
                let url = std::env::var("OLLAMA_URL")
                    .unwrap_or_else(|_| "http://localhost:11434".to_string());

                Box::new(backend::ollama::OllamaBackend::new(
                    url,
                    config.model.clone(),
                    config.temperature,
                    system_prompt.to_string(),
                ))
            }
        }
    }
279
280 pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
282 self.tools = tools;
283 self
284 }
285
    /// Mutable access to the tool registry (e.g. to register extra tools).
    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
        &mut self.tools
    }
290
291 pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
293 self.backend = backend;
294 self
295 }
296
    /// Read-only view of the conversation history.
    pub fn history(&self) -> &[Message] {
        &self.history
    }
301
302 pub fn save_session(&self) -> Result<String> {
304 let mut session = session::Session::new(&self.config.model);
305 session.messages = self.history.clone();
306 session.total_tokens = self.context_tokens_estimate as u64;
307 session.save()?;
308 Ok(session.id)
309 }
310
311 pub fn resume_session(&mut self, session_id: &str) -> Result<()> {
313 let session = session::Session::load(session_id)?;
314 self.history = session.messages;
315 self.context_tokens_estimate = session.total_tokens as usize;
316 Ok(())
317 }
318
    /// Read-only access to the agent's configuration.
    pub fn config(&self) -> &PawanConfig {
        &self.config
    }
323
    /// Drop the entire conversation history (does not reset token counters).
    pub fn clear_history(&mut self) {
        self.history.clear();
    }
328 fn prune_history(&mut self) {
332 let len = self.history.len();
333 if len <= 5 {
334 return; }
336
337 let keep_end = 4;
338 let start = 1; let end = len - keep_end;
340 let pruned_count = end - start;
341
342 let mut summary = String::new();
344 for msg in &self.history[start..end] {
345 let chunk = if msg.content.len() > 200 {
346 &msg.content[..200]
347 } else {
348 &msg.content
349 };
350 summary.push_str(chunk);
351 summary.push('\n');
352 if summary.len() > 2000 {
353 summary.truncate(2000);
354 break;
355 }
356 }
357
358 let summary_msg = Message {
359 role: Role::System,
360 content: format!("Previous conversation summary (pruned): {}", summary),
361 tool_calls: vec![],
362 tool_result: None,
363 };
364
365 let first = self.history[0].clone();
367 let tail: Vec<Message> = self.history[len - keep_end..].to_vec();
368
369 self.history.clear();
370 self.history.push(first);
371 self.history.push(summary_msg);
372 self.history.extend(tail);
373
374 tracing::info!(pruned = pruned_count, context_estimate = self.context_tokens_estimate, "Pruned messages from history");
375 }
376
    /// Append a message to the conversation history without calling the LLM.
    pub fn add_message(&mut self, message: Message) {
        self.history.push(message);
    }
381
    /// Switch to a different model at runtime, rebuilding the backend.
    ///
    /// The model is set on the config *before* fetching the system prompt,
    /// since the prompt may depend on the selected model.
    pub fn switch_model(&mut self, model: &str) {
        self.config.model = model.to_string();
        let system_prompt = self.config.get_system_prompt();
        self.backend = Self::create_backend(&self.config, &system_prompt);
        tracing::info!(model = model, "Model switched at runtime");
    }
389
    /// Name of the currently configured model.
    pub fn model_name(&self) -> &str {
        &self.config.model
    }
394
    /// All tool definitions currently registered with the agent.
    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
        self.tools.get_definitions()
    }
399
    /// Run the agent on `user_prompt` with no streaming or tool callbacks.
    /// Convenience wrapper around [`Self::execute_with_callbacks`].
    pub async fn execute(&mut self, user_prompt: &str) -> Result<AgentResponse> {
        self.execute_with_callbacks(user_prompt, None, None, None)
            .await
    }
405
406 pub async fn execute_with_callbacks(
408 &mut self,
409 user_prompt: &str,
410 on_token: Option<TokenCallback>,
411 on_tool: Option<ToolCallback>,
412 on_tool_start: Option<ToolStartCallback>,
413 ) -> Result<AgentResponse> {
414 if let Some(eruka) = &self.eruka {
416 if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
417 tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
418 }
419 }
420
421 self.history.push(Message {
422 role: Role::User,
423 content: user_prompt.to_string(),
424 tool_calls: vec![],
425 tool_result: None,
426 });
427
428 let mut all_tool_calls = Vec::new();
429 let mut total_usage = TokenUsage::default();
430 let mut iterations = 0;
431 let max_iterations = self.config.max_tool_iterations;
432
433 loop {
434 iterations += 1;
435 if iterations > max_iterations {
436 return Err(PawanError::Agent(format!(
437 "Max tool iterations ({}) exceeded",
438 max_iterations
439 )));
440 }
441
442 let remaining = max_iterations.saturating_sub(iterations);
444 if remaining == 3 && iterations > 1 {
445 self.history.push(Message {
446 role: Role::User,
447 content: format!(
448 "[SYSTEM] You have {} tool iterations remaining. \
449 Stop exploring and write the most important output now. \
450 If you have code to write, write it immediately.",
451 remaining
452 ),
453 tool_calls: vec![],
454 tool_result: None,
455 });
456 }
457 self.context_tokens_estimate = self.history.iter().map(|m| m.content.len()).sum::<usize>() / 4;
459 if self.context_tokens_estimate > self.config.max_context_tokens {
460 self.prune_history();
461 }
462
463 let latest_query = self.history.iter().rev()
466 .find(|m| m.role == Role::User)
467 .map(|m| m.content.as_str())
468 .unwrap_or("");
469 let tool_defs = self.tools.select_for_query(latest_query, 12);
470 if iterations == 1 {
471 let tool_names: Vec<&str> = tool_defs.iter().map(|t| t.name.as_str()).collect();
472 tracing::info!(tools = ?tool_names, count = tool_defs.len(), "Selected tools for query");
473 }
474
475 let response = {
477 #[allow(unused_assignments)]
478 let mut last_err = None;
479 let max_llm_retries = 3;
480 let mut attempt = 0;
481 loop {
482 attempt += 1;
483 match self.backend.generate(&self.history, &tool_defs, on_token.as_ref()).await {
484 Ok(resp) => break resp,
485 Err(e) => {
486 let err_str = e.to_string();
487 let is_transient = err_str.contains("timeout")
488 || err_str.contains("connection")
489 || err_str.contains("429")
490 || err_str.contains("500")
491 || err_str.contains("502")
492 || err_str.contains("503")
493 || err_str.contains("504")
494 || err_str.contains("reset")
495 || err_str.contains("broken pipe");
496
497 if is_transient && attempt <= max_llm_retries {
498 let delay = std::time::Duration::from_secs(2u64.pow(attempt as u32));
499 tracing::warn!(
500 attempt = attempt,
501 delay_secs = delay.as_secs(),
502 error = err_str.as_str(),
503 "LLM call failed (transient) — retrying"
504 );
505 tokio::time::sleep(delay).await;
506
507 if err_str.contains("context") || err_str.contains("token") {
509 tracing::info!("Pruning history before retry (possible context overflow)");
510 self.prune_history();
511 }
512 continue;
513 }
514
515 last_err = Some(e);
517 break {
518 tracing::error!(
520 attempt = attempt,
521 error = last_err.as_ref().map(|e| e.to_string()).unwrap_or_default().as_str(),
522 "LLM call failed permanently — returning error as content"
523 );
524 LLMResponse {
525 content: format!(
526 "LLM error after {} attempts: {}. The task could not be completed.",
527 attempt,
528 last_err.as_ref().map(|e| e.to_string()).unwrap_or_default()
529 ),
530 reasoning: None,
531 tool_calls: vec![],
532 finish_reason: "error".to_string(),
533 usage: None,
534 }
535 };
536 }
537 }
538 }
539 };
540
541 if let Some(ref usage) = response.usage {
543 total_usage.prompt_tokens += usage.prompt_tokens;
544 total_usage.completion_tokens += usage.completion_tokens;
545 total_usage.total_tokens += usage.total_tokens;
546 total_usage.reasoning_tokens += usage.reasoning_tokens;
547 total_usage.action_tokens += usage.action_tokens;
548
549 if usage.reasoning_tokens > 0 {
551 tracing::info!(
552 iteration = iterations,
553 think = usage.reasoning_tokens,
554 act = usage.action_tokens,
555 total = usage.completion_tokens,
556 "Token budget: think:{} act:{} (total:{})",
557 usage.reasoning_tokens, usage.action_tokens, usage.completion_tokens
558 );
559 }
560
561 let thinking_budget = self.config.thinking_budget;
563 if thinking_budget > 0 && usage.reasoning_tokens > thinking_budget as u64 {
564 tracing::warn!(
565 budget = thinking_budget,
566 actual = usage.reasoning_tokens,
567 "Thinking budget exceeded ({}/{} tokens)",
568 usage.reasoning_tokens, thinking_budget
569 );
570 }
571 }
572
573 let clean_content = {
575 let mut s = response.content.clone();
576 loop {
577 let lower = s.to_lowercase();
578 let open = lower.find("<think>");
579 let close = lower.find("</think>");
580 match (open, close) {
581 (Some(i), Some(j)) if j > i => {
582 let before = s[..i].trim_end().to_string();
583 let after = if s.len() > j + 8 { s[j + 8..].trim_start().to_string() } else { String::new() };
584 s = if before.is_empty() { after } else if after.is_empty() { before } else { format!("{}\n{}", before, after) };
585 }
586 _ => break,
587 }
588 }
589 s
590 };
591
592 if response.tool_calls.is_empty() {
593 let has_tools = !tool_defs.is_empty();
596 let lower = clean_content.to_lowercase();
597 let planning_prefix = lower.starts_with("let me")
598 || lower.starts_with("i'll help")
599 || lower.starts_with("i will help")
600 || lower.starts_with("sure, i")
601 || lower.starts_with("okay, i");
602 let looks_like_planning = clean_content.len() > 200 || (planning_prefix && clean_content.len() > 50);
603 if has_tools && looks_like_planning && iterations == 1 && iterations < max_iterations && response.finish_reason != "error" {
604 tracing::warn!(
605 "No tool calls at iteration {} (content: {}B) — nudging model to use tools",
606 iterations, clean_content.len()
607 );
608 self.history.push(Message {
609 role: Role::Assistant,
610 content: clean_content.clone(),
611 tool_calls: vec![],
612 tool_result: None,
613 });
614 self.history.push(Message {
615 role: Role::User,
616 content: "You must use tools to complete this task. Do NOT just describe what you would do — actually call the tools. Start with bash or read_file.".to_string(),
617 tool_calls: vec![],
618 tool_result: None,
619 });
620 continue;
621 }
622
623 if iterations > 1 {
625 let prev_assistant = self.history.iter().rev()
626 .find(|m| m.role == Role::Assistant && !m.content.is_empty());
627 if let Some(prev) = prev_assistant {
628 if prev.content.trim() == clean_content.trim() && iterations < max_iterations {
629 tracing::warn!("Repeated response detected at iteration {} — injecting correction", iterations);
630 self.history.push(Message {
631 role: Role::Assistant,
632 content: clean_content.clone(),
633 tool_calls: vec![],
634 tool_result: None,
635 });
636 self.history.push(Message {
637 role: Role::User,
638 content: "You gave the same response as before. Try a different approach. Use anchor_text in edit_file_lines, or use insert_after, or use bash with sed.".to_string(),
639 tool_calls: vec![],
640 tool_result: None,
641 });
642 continue;
643 }
644 }
645 }
646
647 self.history.push(Message {
648 role: Role::Assistant,
649 content: clean_content.clone(),
650 tool_calls: vec![],
651 tool_result: None,
652 });
653
654 return Ok(AgentResponse {
655 content: clean_content,
656 tool_calls: all_tool_calls,
657 iterations,
658 usage: total_usage,
659 });
660 }
661
662 self.history.push(Message {
663 role: Role::Assistant,
664 content: response.content.clone(),
665 tool_calls: response.tool_calls.clone(),
666 tool_result: None,
667 });
668
669 for tool_call in &response.tool_calls {
670 self.tools.activate(&tool_call.name);
672
673 if let Some(crate::config::ToolPermission::Deny) =
675 self.config.permissions.get(&tool_call.name)
676 {
677 let record = ToolCallRecord {
678 id: tool_call.id.clone(),
679 name: tool_call.name.clone(),
680 arguments: tool_call.arguments.clone(),
681 result: json!({"error": "Tool denied by permission policy"}),
682 success: false,
683 duration_ms: 0,
684 };
685
686 if let Some(ref callback) = on_tool {
687 callback(&record);
688 }
689 all_tool_calls.push(record);
690
691 self.history.push(Message {
692 role: Role::Tool,
693 content: "{\"error\": \"Tool denied by permission policy\"}".to_string(),
694 tool_calls: vec![],
695 tool_result: Some(ToolResultMessage {
696 tool_call_id: tool_call.id.clone(),
697 content: json!({"error": "Tool denied by permission policy"}),
698 success: false,
699 }),
700 });
701 continue;
702 }
703
704 if let Some(ref callback) = on_tool_start {
706 callback(&tool_call.name);
707 }
708
709 tracing::debug!(
711 tool = tool_call.name.as_str(),
712 args_len = serde_json::to_string(&tool_call.arguments).unwrap_or_default().len(),
713 "Tool call: {}({})",
714 tool_call.name,
715 serde_json::to_string(&tool_call.arguments)
716 .unwrap_or_default()
717 .chars()
718 .take(200)
719 .collect::<String>()
720 );
721
722 let start = std::time::Instant::now();
723
724 let result = {
726 let tool_future = self.tools.execute(&tool_call.name, tool_call.arguments.clone());
727 let timeout_dur = if tool_call.name == "bash" {
729 std::time::Duration::from_secs(self.config.bash_timeout_secs)
730 } else {
731 std::time::Duration::from_secs(30)
732 };
733 match tokio::time::timeout(timeout_dur, tool_future).await {
734 Ok(inner) => inner,
735 Err(_) => Err(PawanError::Tool(format!(
736 "Tool '{}' timed out after {}s", tool_call.name, timeout_dur.as_secs()
737 ))),
738 }
739 };
740 let duration_ms = start.elapsed().as_millis() as u64;
741
742 let (result_value, success) = match result {
743 Ok(v) => (v, true),
744 Err(e) => {
745 tracing::warn!(tool = tool_call.name.as_str(), error = %e, "Tool execution failed");
746 (json!({"error": e.to_string(), "tool": tool_call.name, "hint": "Try a different approach or tool"}), false)
747 }
748 };
749
750 let max_result_chars = self.config.max_result_chars;
752 let result_value = {
753 let result_str = serde_json::to_string(&result_value).unwrap_or_default();
754 if result_str.len() > max_result_chars {
755 let truncated: String = result_str.chars().take(max_result_chars).collect();
757 let truncated = truncated.as_str();
758 serde_json::from_str(truncated).unwrap_or_else(|_| {
759 json!({"content": format!("{}...[truncated from {} chars]", truncated, result_str.len())})
760 })
761 } else {
762 result_value
763 }
764 };
765
766
767 let record = ToolCallRecord {
768 id: tool_call.id.clone(),
769 name: tool_call.name.clone(),
770 arguments: tool_call.arguments.clone(),
771 result: result_value.clone(),
772 success,
773 duration_ms,
774 };
775
776 if let Some(ref callback) = on_tool {
777 callback(&record);
778 }
779
780 all_tool_calls.push(record);
781
782 self.history.push(Message {
783 role: Role::Tool,
784 content: serde_json::to_string(&result_value).unwrap_or_default(),
785 tool_calls: vec![],
786 tool_result: Some(ToolResultMessage {
787 tool_call_id: tool_call.id.clone(),
788 content: result_value,
789 success,
790 }),
791 });
792
793 if success && tool_call.name == "write_file" {
796 let wrote_rs = tool_call.arguments.get("path")
797 .and_then(|p| p.as_str())
798 .map(|p| p.ends_with(".rs"))
799 .unwrap_or(false);
800 if wrote_rs {
801 let ws = self.workspace_root.clone();
802 let check_result = tokio::process::Command::new("cargo")
803 .arg("check")
804 .arg("--message-format=short")
805 .current_dir(&ws)
806 .output()
807 .await;
808 match check_result {
809 Ok(output) if !output.status.success() => {
810 let stderr = String::from_utf8_lossy(&output.stderr);
811 let err_msg: String = stderr.chars().take(1500).collect();
813 tracing::info!("Compile-gate: cargo check failed after write_file, injecting errors");
814 self.history.push(Message {
815 role: Role::User,
816 content: format!(
817 "[SYSTEM] cargo check failed after your write_file. Fix the errors:\n```\n{}\n```",
818 err_msg
819 ),
820 tool_calls: vec![],
821 tool_result: None,
822 });
823 }
824 Ok(_) => {
825 tracing::debug!("Compile-gate: cargo check passed");
826 }
827 Err(e) => {
828 tracing::warn!("Compile-gate: cargo check failed to run: {}", e);
829 }
830 }
831 }
832 }
833 }
834 }
835 }
836
    /// Collect project diagnostics (compiler issues + failed tests) and ask
    /// the agent to fix everything it finds.
    ///
    /// # Errors
    /// Propagates failures from the healer's diagnostics/test collection and
    /// from the agent run itself.
    pub async fn heal(&mut self) -> Result<AgentResponse> {
        let healer = crate::healing::Healer::new(
            self.workspace_root.clone(),
            self.config.healing.clone(),
        );

        let diagnostics = healer.get_diagnostics().await?;
        let failed_tests = healer.get_failed_tests().await?;

        // NOTE: the string literals below embed real newlines on purpose —
        // do not reflow them; they shape the prompt's markdown sections.
        let mut prompt = format!(
            "I need you to heal this Rust project at: {}

",
            self.workspace_root.display()
        );

        if !diagnostics.is_empty() {
            prompt.push_str(&format!(
                "## Compilation Issues ({} found)
{}
",
                diagnostics.len(),
                healer.format_diagnostics_for_prompt(&diagnostics)
            ));
        }

        if !failed_tests.is_empty() {
            prompt.push_str(&format!(
                "## Failed Tests ({} found)
{}
",
                failed_tests.len(),
                healer.format_tests_for_prompt(&failed_tests)
            ));
        }

        if diagnostics.is_empty() && failed_tests.is_empty() {
            prompt.push_str("No issues found! Run cargo check and cargo test to verify.
");
        }

        prompt.push_str("
Fix each issue one at a time. Verify with cargo check after each fix.");

        self.execute(&prompt).await
    }
884 pub async fn heal_with_retries(&mut self, max_attempts: usize) -> Result<AgentResponse> {
886 let mut last_response = self.heal().await?;
887
888 for attempt in 1..max_attempts {
889 let fixer = crate::healing::CompilerFixer::new(self.workspace_root.clone());
890 let remaining = fixer.check().await?;
891 let errors: Vec<_> = remaining.iter().filter(|d| d.kind == crate::healing::DiagnosticKind::Error).collect();
892
893 if errors.is_empty() {
894 tracing::info!(attempts = attempt, "Healing complete");
895 return Ok(last_response);
896 }
897
898 tracing::warn!(errors = errors.len(), attempt = attempt, "Errors remain after heal attempt, retrying");
899 last_response = self.heal().await?;
900 }
901
902 tracing::info!(attempts = max_attempts, "Healing finished (may still have errors)");
903 Ok(last_response)
904 }
    /// Run an arbitrary coding task in the workspace, with a standard
    /// explore → edit → verify checklist wrapped around the description.
    pub async fn task(&mut self, task_description: &str) -> Result<AgentResponse> {
        // Raw string keeps the checklist formatting intact.
        let prompt = format!(
            r#"I need you to complete the following coding task:

{}

The workspace is at: {}

Please:
1. First explore the codebase to understand the relevant code
2. Make the necessary changes
3. Verify the changes compile with `cargo check`
4. Run relevant tests if applicable

Explain your changes as you go."#,
            task_description,
            self.workspace_root.display()
        );

        self.execute(&prompt).await
    }
927
928 pub async fn generate_commit_message(&mut self) -> Result<String> {
930 let prompt = r#"Please:
9311. Run `git status` to see what files are changed
9322. Run `git diff --cached` to see staged changes (or `git diff` for unstaged)
9333. Generate a concise, descriptive commit message following conventional commits format
934
935Only output the suggested commit message, nothing else."#;
936
937 let response = self.execute(prompt).await?;
938 Ok(response.content)
939 }
940}
941
#[cfg(test)]
mod tests {
    use super::*;

    // A user message with no tool data should serialize with its lowercase
    // role and its content intact.
    #[test]
    fn test_message_serialization() {
        let message = Message {
            role: Role::User,
            content: "Hello".to_string(),
            tool_calls: Vec::new(),
            tool_result: None,
        };

        let serialized = serde_json::to_string(&message).expect("Serialization failed");
        assert!(serialized.contains("user"));
        assert!(serialized.contains("Hello"));
    }

    // A tool-call request should round-trip its name and JSON arguments.
    #[test]
    fn test_tool_call_request() {
        let request = ToolCallRequest {
            id: String::from("123"),
            name: String::from("read_file"),
            arguments: json!({"path": "test.txt"}),
        };

        let serialized = serde_json::to_string(&request).expect("Serialization failed");
        assert!(serialized.contains("read_file"));
        assert!(serialized.contains("test.txt"));
    }
}
972}