// pawan/agent/mod.rs
1//! Pawan Agent - The core agent that handles tool-calling loops
2//!
3//! This module provides the main `PawanAgent` which:
4//! - Manages conversation history
5//! - Coordinates tool calling with the LLM via pluggable backends
6//! - Provides streaming responses
7//! - Supports multiple LLM backends (NVIDIA API, Ollama, OpenAI)
8
9pub mod backend;
10mod preflight;
11pub mod session;
12pub mod git_session;
13
14use crate::config::{LlmProvider, PawanConfig};
15use crate::tools::{ToolDefinition, ToolRegistry};
16use crate::{PawanError, Result};
17use backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
18use backend::LlmBackend;
19use serde::{Deserialize, Serialize};
20use serde_json::{json, Value};
21use std::path::PathBuf;
22
/// A message in the conversation
///
/// One entry of the agent's chat history: a system/user/assistant turn, an
/// assistant turn carrying tool-call requests, or a tool-result turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Role of the message sender
    pub role: Role,
    /// Content of the message (may be empty on assistant turns that only
    /// request tool calls)
    pub content: String,
    /// Tool calls (if any); defaults to empty when absent in serialized input
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Tool results (if this is a tool result message); omitted from
    /// serialized output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMessage>,
}
37
/// Role of a message sender
///
/// Serialized as lowercase strings ("system", "user", "assistant", "tool")
/// via `#[serde(rename_all = "lowercase")]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
47
/// A request to call a tool
///
/// Emitted by the LLM backend; the agent loop executes it and feeds the
/// result back as a `Role::Tool` message keyed by `id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Unique ID for this tool call (echoed back in the matching result)
    pub id: String,
    /// Name of the tool to call
    pub name: String,
    /// Arguments for the tool, as arbitrary JSON
    pub arguments: Value,
}
58
/// Result from a tool execution
///
/// Attached to a `Role::Tool` message so the model can correlate the result
/// with the originating `ToolCallRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// ID of the tool call this result is for
    pub tool_call_id: String,
    /// The result content (JSON; an `{"error": ...}` object on failure)
    pub content: Value,
    /// Whether the tool executed successfully
    pub success: bool,
}
69
/// Record of a tool call execution
///
/// Collected across the whole agent run and returned in `AgentResponse`;
/// also delivered live to the `ToolCallback`, if one is registered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Unique ID for this tool call
    pub id: String,
    /// Name of the tool
    pub name: String,
    /// Arguments passed to the tool
    pub arguments: Value,
    /// Result from the tool (possibly truncated to fit the context budget)
    pub result: Value,
    /// Whether execution was successful
    pub success: bool,
    /// Duration in milliseconds (0 for permission-denied calls that never ran)
    pub duration_ms: u64,
}
86
/// Token usage from an LLM response
///
/// Accumulated across loop iterations into `AgentResponse::usage`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    // Tokens consumed by the prompt/context.
    pub prompt_tokens: u64,
    // Tokens generated by the model.
    pub completion_tokens: u64,
    // prompt + completion, as reported by the backend.
    pub total_tokens: u64,
}
94
/// LLM response from a generation request
///
/// Produced by an `LlmBackend` for one turn; the agent loop either executes
/// the requested `tool_calls` or, when empty, treats `content` as the final
/// answer.
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Text content of the response
    pub content: String,
    /// Tool calls requested by the model (empty means the turn is final)
    pub tool_calls: Vec<ToolCallRequest>,
    /// Reason the response finished ("error" is used for synthetic
    /// failure responses built by the retry logic)
    pub finish_reason: String,
    /// Token usage (if the backend reports it)
    pub usage: Option<TokenUsage>,
}
107
/// Result from a complete agent execution
///
/// Returned by `execute`/`execute_with_callbacks` once the model produces a
/// turn without tool calls.
#[derive(Debug)]
pub struct AgentResponse {
    /// Final text response (thinking blocks stripped)
    pub content: String,
    /// All tool calls made during execution, in order
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of iterations taken
    pub iterations: usize,
    /// Cumulative token usage across all iterations
    pub usage: TokenUsage,
}
120
/// Callback for receiving streaming tokens (invoked with each text fragment
/// as the backend streams it)
pub type TokenCallback = Box<dyn Fn(&str) + Send + Sync>;

/// Callback for receiving tool call updates (invoked with the full record
/// after each tool finishes or is denied)
pub type ToolCallback = Box<dyn Fn(&ToolCallRecord) + Send + Sync>;

/// Callback for tool call start notifications (invoked with the tool name
/// just before execution begins)
pub type ToolStartCallback = Box<dyn Fn(&str) + Send + Sync>;
129
/// The main Pawan agent
///
/// Owns the conversation state and coordinates the tool-calling loop between
/// the configured LLM backend and the tool registry.
pub struct PawanAgent {
    /// Configuration
    config: PawanConfig,
    /// Tool registry
    tools: ToolRegistry,
    /// Conversation history
    history: Vec<Message>,
    /// Workspace root
    workspace_root: PathBuf,
    /// LLM backend (pluggable; selected from config by default)
    backend: Box<dyn LlmBackend>,

    /// Estimated token count for current context (rough bytes/4 heuristic,
    /// recomputed each loop iteration)
    context_tokens_estimate: usize,

    /// Eruka bridge for 3-tier memory injection (`None` when disabled)
    eruka: Option<crate::eruka_bridge::ErukaClient>,
}
149
150impl PawanAgent {
151    /// Create a new PawanAgent with auto-selected backend
152    pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
153        let tools = ToolRegistry::with_defaults(workspace_root.clone());
154        let system_prompt = config.get_system_prompt();
155        let backend = Self::create_backend(&config, &system_prompt);
156        let eruka = if config.eruka.enabled {
157            Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
158        } else {
159            None
160        };
161
162        Self {
163            config,
164            tools,
165            history: Vec::new(),
166            workspace_root,
167            backend,
168            context_tokens_estimate: 0,
169            eruka,
170        }
171    }
172
173    /// Create the appropriate backend based on config
174    fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
175        match config.provider {
176            LlmProvider::Nvidia | LlmProvider::OpenAI => {
177                let (api_url, api_key) = match config.provider {
178                    LlmProvider::Nvidia => {
179                        let url = std::env::var("NVIDIA_API_URL")
180                            .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
181                        let key = std::env::var("NVIDIA_API_KEY").ok();
182                        if key.is_none() {
183                            tracing::warn!("NVIDIA_API_KEY not set. Add it to .env or export it.");
184                        }
185                        (url, key)
186                    },
187                    LlmProvider::OpenAI => {
188                        let url = std::env::var("OPENAI_API_URL")
189                            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
190                        let key = std::env::var("OPENAI_API_KEY").ok();
191                        if key.is_none() {
192                            tracing::warn!("OPENAI_API_KEY not set. Add it to .env or export it.");
193                        }
194                        (url, key)
195                    },
196                    _ => unreachable!(),
197                };
198                
199                // Build cloud fallback if configured
200                let cloud = config.cloud.as_ref().map(|c| {
201                    let (cloud_url, cloud_key) = match c.provider {
202                        LlmProvider::Nvidia => {
203                            let url = std::env::var("NVIDIA_API_URL")
204                                .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
205                            let key = std::env::var("NVIDIA_API_KEY").ok();
206                            (url, key)
207                        },
208                        LlmProvider::OpenAI => {
209                            let url = std::env::var("OPENAI_API_URL")
210                                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
211                            let key = std::env::var("OPENAI_API_KEY").ok();
212                            (url, key)
213                        },
214                        _ => {
215                            tracing::warn!("Cloud fallback only supports nvidia/openai providers");
216                            ("https://integrate.api.nvidia.com/v1".to_string(), None)
217                        }
218                    };
219                    backend::openai_compat::CloudFallback {
220                        api_url: cloud_url,
221                        api_key: cloud_key,
222                        model: c.model.clone(),
223                        fallback_models: c.fallback_models.clone(),
224                    }
225                });
226
227                Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
228                    api_url,
229                    api_key,
230                    model: config.model.clone(),
231                    temperature: config.temperature,
232                    top_p: config.top_p,
233                    max_tokens: config.max_tokens,
234                    system_prompt: system_prompt.to_string(),
235                    use_thinking: config.use_thinking_mode(),
236                    max_retries: config.max_retries,
237                    fallback_models: config.fallback_models.clone(),
238                    cloud,
239                }))
240            }
241            LlmProvider::Ollama => {
242                let url = std::env::var("OLLAMA_URL")
243                    .unwrap_or_else(|_| "http://localhost:11434".to_string());
244
245                Box::new(backend::ollama::OllamaBackend::new(
246                    url,
247                    config.model.clone(),
248                    config.temperature,
249                    system_prompt.to_string(),
250                ))
251            }
252        }
253    }
254
    /// Create with a specific tool registry
    ///
    /// Builder-style: replaces the default registry installed by `new`.
    pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
        self.tools = tools;
        self
    }

    /// Get mutable access to the tool registry (for registering MCP tools)
    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
        &mut self.tools
    }

    /// Create with a custom backend
    ///
    /// Builder-style: replaces the backend auto-selected from the config.
    pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
        self.backend = backend;
        self
    }

    /// Get the current conversation history
    pub fn history(&self) -> &[Message] {
        &self.history
    }
276
    /// Save current conversation as a session, returns session ID
    ///
    /// Snapshots the full message history plus the current token estimate
    /// into a new `session::Session` and persists it via `Session::save`.
    pub fn save_session(&self) -> Result<String> {
        let mut session = session::Session::new(&self.config.model);
        session.messages = self.history.clone();
        session.total_tokens = self.context_tokens_estimate as u64;
        session.save()?;
        Ok(session.id)
    }

    /// Resume a saved session by ID
    ///
    /// Replaces (not appends to) the in-memory history and token estimate
    /// with the saved session's contents.
    pub fn resume_session(&mut self, session_id: &str) -> Result<()> {
        let session = session::Session::load(session_id)?;
        self.history = session.messages;
        self.context_tokens_estimate = session.total_tokens as usize;
        Ok(())
    }
293
    /// Get the configuration
    pub fn config(&self) -> &PawanConfig {
        &self.config
    }

    /// Clear the conversation history
    ///
    /// Note: does not reset `context_tokens_estimate`; it is recomputed on
    /// the next execute loop iteration.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }
303    /// Prune conversation history to reduce context size.
304    /// Keeps the first message (system prompt) and last 4 messages,
305    /// replaces everything in between with a summary message.
306    fn prune_history(&mut self) {
307        let len = self.history.len();
308        if len <= 5 {
309            return; // Nothing to prune
310        }
311
312        let keep_end = 4;
313        let start = 1; // Skip system prompt at index 0
314        let end = len - keep_end;
315        let pruned_count = end - start;
316
317        // Build summary from middle messages
318        let mut summary = String::new();
319        for msg in &self.history[start..end] {
320            let chunk = if msg.content.len() > 200 {
321                &msg.content[..200]
322            } else {
323                &msg.content
324            };
325            summary.push_str(chunk);
326            summary.push('\n');
327            if summary.len() > 2000 {
328                summary.truncate(2000);
329                break;
330            }
331        }
332
333        let summary_msg = Message {
334            role: Role::System,
335            content: format!("Previous conversation summary (pruned): {}", summary),
336            tool_calls: vec![],
337            tool_result: None,
338        };
339
340        // Keep first message, insert summary, then last 4
341        let first = self.history[0].clone();
342        let tail: Vec<Message> = self.history[len - keep_end..].to_vec();
343
344        self.history.clear();
345        self.history.push(first);
346        self.history.push(summary_msg);
347        self.history.extend(tail);
348
349        tracing::info!(pruned = pruned_count, context_estimate = self.context_tokens_estimate, "Pruned messages from history");
350    }
351
    /// Add a message to history
    ///
    /// Appends without any pruning or token accounting; callers that inject
    /// many messages should rely on the execute loop's pruning.
    pub fn add_message(&mut self, message: Message) {
        self.history.push(message);
    }

    /// Get tool definitions for the LLM
    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
        self.tools.get_definitions()
    }
361
    /// Execute a single prompt with tool calling support
    ///
    /// Convenience wrapper over `execute_with_callbacks` with no streaming,
    /// tool, or tool-start callbacks.
    pub async fn execute(&mut self, user_prompt: &str) -> Result<AgentResponse> {
        self.execute_with_callbacks(user_prompt, None, None, None)
            .await
    }
367
368    /// Execute with optional callbacks for streaming
369    pub async fn execute_with_callbacks(
370        &mut self,
371        user_prompt: &str,
372        on_token: Option<TokenCallback>,
373        on_tool: Option<ToolCallback>,
374        on_tool_start: Option<ToolStartCallback>,
375    ) -> Result<AgentResponse> {
376        // Inject Eruka core memory before first LLM call
377        if let Some(eruka) = &self.eruka {
378            if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
379                tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
380            }
381        }
382
383        self.history.push(Message {
384            role: Role::User,
385            content: user_prompt.to_string(),
386            tool_calls: vec![],
387            tool_result: None,
388        });
389
390        let mut all_tool_calls = Vec::new();
391        let mut total_usage = TokenUsage::default();
392        let mut iterations = 0;
393        let max_iterations = self.config.max_tool_iterations;
394
395        loop {
396            iterations += 1;
397            if iterations > max_iterations {
398                return Err(PawanError::Agent(format!(
399                    "Max tool iterations ({}) exceeded",
400                    max_iterations
401                )));
402            }
403            // Estimate context tokens
404            self.context_tokens_estimate = self.history.iter().map(|m| m.content.len()).sum::<usize>() / 4;
405            if self.context_tokens_estimate > self.config.max_context_tokens {
406                self.prune_history();
407            }
408
409            let tool_defs = self.tools.get_definitions();
410
411            // --- Resilient LLM call: retry on transient failures instead of crashing ---
412            let response = {
413                #[allow(unused_assignments)]
414                let mut last_err = None;
415                let max_llm_retries = 3;
416                let mut attempt = 0;
417                loop {
418                    attempt += 1;
419                    match self.backend.generate(&self.history, &tool_defs, on_token.as_ref()).await {
420                        Ok(resp) => break resp,
421                        Err(e) => {
422                            let err_str = e.to_string();
423                            let is_transient = err_str.contains("timeout")
424                                || err_str.contains("connection")
425                                || err_str.contains("429")
426                                || err_str.contains("500")
427                                || err_str.contains("502")
428                                || err_str.contains("503")
429                                || err_str.contains("504")
430                                || err_str.contains("reset")
431                                || err_str.contains("broken pipe");
432
433                            if is_transient && attempt <= max_llm_retries {
434                                let delay = std::time::Duration::from_secs(2u64.pow(attempt as u32));
435                                tracing::warn!(
436                                    attempt = attempt,
437                                    delay_secs = delay.as_secs(),
438                                    error = err_str.as_str(),
439                                    "LLM call failed (transient) — retrying"
440                                );
441                                tokio::time::sleep(delay).await;
442
443                                // If context is too large, prune before retry
444                                if err_str.contains("context") || err_str.contains("token") {
445                                    tracing::info!("Pruning history before retry (possible context overflow)");
446                                    self.prune_history();
447                                }
448                                continue;
449                            }
450
451                            // Non-transient or max retries exhausted
452                            last_err = Some(e);
453                            break {
454                                // Return a synthetic "give up" response instead of crashing
455                                tracing::error!(
456                                    attempt = attempt,
457                                    error = last_err.as_ref().map(|e| e.to_string()).unwrap_or_default().as_str(),
458                                    "LLM call failed permanently — returning error as content"
459                                );
460                                LLMResponse {
461                                    content: format!(
462                                        "LLM error after {} attempts: {}. The task could not be completed.",
463                                        attempt,
464                                        last_err.as_ref().map(|e| e.to_string()).unwrap_or_default()
465                                    ),
466                                    tool_calls: vec![],
467                                    finish_reason: "error".to_string(),
468                                    usage: None,
469                                }
470                            };
471                        }
472                    }
473                }
474            };
475
476            // Accumulate token usage
477            if let Some(ref usage) = response.usage {
478                total_usage.prompt_tokens += usage.prompt_tokens;
479                total_usage.completion_tokens += usage.completion_tokens;
480                total_usage.total_tokens += usage.total_tokens;
481            }
482
483            // --- Guardrail: strip thinking blocks from content ---
484            let clean_content = {
485                let mut s = response.content.clone();
486                loop {
487                    let lower = s.to_lowercase();
488                    let open = lower.find("<think>");
489                    let close = lower.find("</think>");
490                    match (open, close) {
491                        (Some(i), Some(j)) if j > i => {
492                            let before = s[..i].trim_end().to_string();
493                            let after = if s.len() > j + 8 { s[j + 8..].trim_start().to_string() } else { String::new() };
494                            s = if before.is_empty() { after } else if after.is_empty() { before } else { format!("{}\n{}", before, after) };
495                        }
496                        _ => break,
497                    }
498                }
499                s
500            };
501
502            if response.tool_calls.is_empty() {
503                // --- Guardrail: detect chatty no-op (content but no tools on early iterations) ---
504                // Only nudge if tools are available AND response looks like planning text (not a real answer)
505                let has_tools = !tool_defs.is_empty();
506                let lower = clean_content.to_lowercase();
507                let planning_prefix = lower.starts_with("let me")
508                    || lower.starts_with("i'll help")
509                    || lower.starts_with("i will help")
510                    || lower.starts_with("sure, i")
511                    || lower.starts_with("okay, i");
512                let looks_like_planning = clean_content.len() > 200 || (planning_prefix && clean_content.len() > 50);
513                if has_tools && looks_like_planning && iterations == 1 && iterations < max_iterations && response.finish_reason != "error" {
514                    tracing::warn!(
515                        "No tool calls at iteration {} (content: {}B) — nudging model to use tools",
516                        iterations, clean_content.len()
517                    );
518                    self.history.push(Message {
519                        role: Role::Assistant,
520                        content: clean_content.clone(),
521                        tool_calls: vec![],
522                        tool_result: None,
523                    });
524                    self.history.push(Message {
525                        role: Role::User,
526                        content: "You must use tools to complete this task. Do NOT just describe what you would do — actually call the tools. Start with bash or read_file.".to_string(),
527                        tool_calls: vec![],
528                        tool_result: None,
529                    });
530                    continue;
531                }
532
533                // --- Guardrail: detect repeated responses ---
534                if iterations > 1 {
535                    let prev_assistant = self.history.iter().rev()
536                        .find(|m| m.role == Role::Assistant && !m.content.is_empty());
537                    if let Some(prev) = prev_assistant {
538                        if prev.content.trim() == clean_content.trim() && iterations < max_iterations {
539                            tracing::warn!("Repeated response detected at iteration {} — injecting correction", iterations);
540                            self.history.push(Message {
541                                role: Role::Assistant,
542                                content: clean_content.clone(),
543                                tool_calls: vec![],
544                                tool_result: None,
545                            });
546                            self.history.push(Message {
547                                role: Role::User,
548                                content: "You gave the same response as before. Try a different approach. Use anchor_text in edit_file_lines, or use insert_after, or use bash with sed.".to_string(),
549                                tool_calls: vec![],
550                                tool_result: None,
551                            });
552                            continue;
553                        }
554                    }
555                }
556
557                self.history.push(Message {
558                    role: Role::Assistant,
559                    content: clean_content.clone(),
560                    tool_calls: vec![],
561                    tool_result: None,
562                });
563
564                return Ok(AgentResponse {
565                    content: clean_content,
566                    tool_calls: all_tool_calls,
567                    iterations,
568                    usage: total_usage,
569                });
570            }
571
572            self.history.push(Message {
573                role: Role::Assistant,
574                content: response.content.clone(),
575                tool_calls: response.tool_calls.clone(),
576                tool_result: None,
577            });
578
579            for tool_call in &response.tool_calls {
580                // Check permission
581                if let Some(crate::config::ToolPermission::Deny) =
582                    self.config.permissions.get(&tool_call.name)
583                {
584                    let record = ToolCallRecord {
585                        id: tool_call.id.clone(),
586                        name: tool_call.name.clone(),
587                        arguments: tool_call.arguments.clone(),
588                        result: json!({"error": "Tool denied by permission policy"}),
589                        success: false,
590                        duration_ms: 0,
591                    };
592
593                    if let Some(ref callback) = on_tool {
594                        callback(&record);
595                    }
596                    all_tool_calls.push(record);
597
598                    self.history.push(Message {
599                        role: Role::Tool,
600                        content: "{\"error\": \"Tool denied by permission policy\"}".to_string(),
601                        tool_calls: vec![],
602                        tool_result: Some(ToolResultMessage {
603                            tool_call_id: tool_call.id.clone(),
604                            content: json!({"error": "Tool denied by permission policy"}),
605                            success: false,
606                        }),
607                    });
608                    continue;
609                }
610
611                // Notify tool start
612                if let Some(ref callback) = on_tool_start {
613                    callback(&tool_call.name);
614                }
615
616                // Debug: log tool call args for diagnosis
617                tracing::debug!(
618                    tool = tool_call.name.as_str(),
619                    args_len = serde_json::to_string(&tool_call.arguments).unwrap_or_default().len(),
620                    "Tool call: {}({})",
621                    tool_call.name,
622                    serde_json::to_string(&tool_call.arguments)
623                        .unwrap_or_default()
624                        .chars()
625                        .take(200)
626                        .collect::<String>()
627                );
628
629                let start = std::time::Instant::now();
630
631                // Resilient tool execution: catch panics + errors
632                let result = {
633                    let tool_future = self.tools.execute(&tool_call.name, tool_call.arguments.clone());
634                    // Timeout individual tool calls (prevent hangs)
635                    let timeout_dur = if tool_call.name == "bash" {
636                        std::time::Duration::from_secs(self.config.bash_timeout_secs)
637                    } else {
638                        std::time::Duration::from_secs(30)
639                    };
640                    match tokio::time::timeout(timeout_dur, tool_future).await {
641                        Ok(inner) => inner,
642                        Err(_) => Err(PawanError::Tool(format!(
643                            "Tool '{}' timed out after {}s", tool_call.name, timeout_dur.as_secs()
644                        ))),
645                    }
646                };
647                let duration_ms = start.elapsed().as_millis() as u64;
648
649                let (result_value, success) = match result {
650                    Ok(v) => (v, true),
651                    Err(e) => {
652                        tracing::warn!(tool = tool_call.name.as_str(), error = %e, "Tool execution failed");
653                        (json!({"error": e.to_string(), "tool": tool_call.name, "hint": "Try a different approach or tool"}), false)
654                    }
655                };
656
657                // Truncate tool results that exceed max chars to prevent context bloat
658                let max_result_chars = self.config.max_result_chars;
659                let result_value = {
660                    let result_str = serde_json::to_string(&result_value).unwrap_or_default();
661                    if result_str.len() > max_result_chars {
662                        // UTF-8 safe truncation
663                        let truncated: String = result_str.chars().take(max_result_chars).collect();
664                        let truncated = truncated.as_str();
665                        serde_json::from_str(truncated).unwrap_or_else(|_| {
666                            json!({"content": format!("{}...[truncated from {} chars]", truncated, result_str.len())})
667                        })
668                    } else {
669                        result_value
670                    }
671                };
672
673
674                let record = ToolCallRecord {
675                    id: tool_call.id.clone(),
676                    name: tool_call.name.clone(),
677                    arguments: tool_call.arguments.clone(),
678                    result: result_value.clone(),
679                    success,
680                    duration_ms,
681                };
682
683                if let Some(ref callback) = on_tool {
684                    callback(&record);
685                }
686
687                all_tool_calls.push(record);
688
689                self.history.push(Message {
690                    role: Role::Tool,
691                    content: serde_json::to_string(&result_value).unwrap_or_default(),
692                    tool_calls: vec![],
693                    tool_result: Some(ToolResultMessage {
694                        tool_call_id: tool_call.id.clone(),
695                        content: result_value,
696                        success,
697                    }),
698                });
699            }
700        }
701    }
702
703    /// Execute a healing task with real diagnostics
704    pub async fn heal(&mut self) -> Result<AgentResponse> {
705        let healer = crate::healing::Healer::new(
706            self.workspace_root.clone(),
707            self.config.healing.clone(),
708        );
709
710        let diagnostics = healer.get_diagnostics().await?;
711        let failed_tests = healer.get_failed_tests().await?;
712
713        let mut prompt = format!(
714            "I need you to heal this Rust project at: {}
715
716",
717            self.workspace_root.display()
718        );
719
720        if !diagnostics.is_empty() {
721            prompt.push_str(&format!(
722                "## Compilation Issues ({} found)
723{}
724",
725                diagnostics.len(),
726                healer.format_diagnostics_for_prompt(&diagnostics)
727            ));
728        }
729
730        if !failed_tests.is_empty() {
731            prompt.push_str(&format!(
732                "## Failed Tests ({} found)
733{}
734",
735                failed_tests.len(),
736                healer.format_tests_for_prompt(&failed_tests)
737            ));
738        }
739
740        if diagnostics.is_empty() && failed_tests.is_empty() {
741            prompt.push_str("No issues found! Run cargo check and cargo test to verify.
742");
743        }
744
745        prompt.push_str("
746Fix each issue one at a time. Verify with cargo check after each fix.");
747
748        self.execute(&prompt).await
749    }
750    /// Execute healing with retries — calls heal(), checks for remaining errors, retries if needed
751    pub async fn heal_with_retries(&mut self, max_attempts: usize) -> Result<AgentResponse> {
752        let mut last_response = self.heal().await?;
753
754        for attempt in 1..max_attempts {
755            let fixer = crate::healing::CompilerFixer::new(self.workspace_root.clone());
756            let remaining = fixer.check().await?;
757            let errors: Vec<_> = remaining.iter().filter(|d| d.kind == crate::healing::DiagnosticKind::Error).collect();
758
759            if errors.is_empty() {
760                tracing::info!(attempts = attempt, "Healing complete");
761                return Ok(last_response);
762            }
763
764            tracing::warn!(errors = errors.len(), attempt = attempt, "Errors remain after heal attempt, retrying");
765            last_response = self.heal().await?;
766        }
767
768        tracing::info!(attempts = max_attempts, "Healing finished (may still have errors)");
769        Ok(last_response)
770    }
    /// Execute a task with a specific prompt
    ///
    /// Wraps `task_description` in a standard coding-task template (explore,
    /// edit, verify with `cargo check`, test) and runs the tool-calling loop.
    pub async fn task(&mut self, task_description: &str) -> Result<AgentResponse> {
        let prompt = format!(
            r#"I need you to complete the following coding task:

{}

The workspace is at: {}

Please:
1. First explore the codebase to understand the relevant code
2. Make the necessary changes
3. Verify the changes compile with `cargo check`
4. Run relevant tests if applicable

Explain your changes as you go."#,
            task_description,
            self.workspace_root.display()
        );

        self.execute(&prompt).await
    }
793
    /// Generate a commit message for current changes
    ///
    /// Runs the agent with a prompt instructing it to inspect `git status`/
    /// `git diff` via tools and emit only a conventional-commits message.
    /// Returns the agent's final text response verbatim.
    pub async fn generate_commit_message(&mut self) -> Result<String> {
        let prompt = r#"Please:
1. Run `git status` to see what files are changed
2. Run `git diff --cached` to see staged changes (or `git diff` for unstaged)
3. Generate a concise, descriptive commit message following conventional commits format

Only output the suggested commit message, nothing else."#;

        let response = self.execute(prompt).await?;
        Ok(response.content)
    }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    /// A serialized message carries its lowercase role and its content.
    #[test]
    fn test_message_serialization() {
        let message = Message {
            role: Role::User,
            content: "Hello".to_string(),
            tool_calls: Vec::new(),
            tool_result: None,
        };

        let serialized = serde_json::to_string(&message).unwrap();
        assert!(serialized.contains("user"));
        assert!(serialized.contains("Hello"));
    }

    /// A serialized tool-call request carries its tool name and arguments.
    #[test]
    fn test_tool_call_request() {
        let request = ToolCallRequest {
            id: String::from("123"),
            name: String::from("read_file"),
            arguments: json!({"path": "test.txt"}),
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("read_file"));
        assert!(serialized.contains("test.txt"));
    }
}