1pub mod backend;
10mod preflight;
11pub mod session;
12pub mod git_session;
13
14use crate::config::{LlmProvider, PawanConfig};
15use crate::tools::{ToolDefinition, ToolRegistry};
16use crate::{PawanError, Result};
17use backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
18use backend::LlmBackend;
19use serde::{Deserialize, Serialize};
20use serde_json::{json, Value};
21use std::path::PathBuf;
22
/// A single entry in the agent's conversation history.
///
/// Persisted in sessions and sent to the LLM backend on every generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored this message.
    pub role: Role,
    /// Plain-text body; for `Role::Tool` messages this is the JSON-encoded
    /// tool result.
    pub content: String,
    /// Tool invocations requested by an assistant turn; empty for all
    /// other roles (defaults to empty on deserialization).
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Structured tool outcome; present only on `Role::Tool` messages and
    /// omitted from serialization when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMessage>,
}
37
/// Conversation participant kind, serialized in lowercase
/// ("system", "user", "assistant", "tool").
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
47
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Call id assigned by the backend; echoed back in the tool result so
    /// the model can correlate request and response.
    pub id: String,
    /// Registered tool name (e.g. "bash", "read_file", "write_file").
    pub name: String,
    /// JSON arguments exactly as produced by the model.
    pub arguments: Value,
}
58
/// The outcome of one tool call, attached to a `Role::Tool` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the originating `ToolCallRequest`.
    pub tool_call_id: String,
    /// Tool output (or an `{"error": ...}` object on failure/denial).
    pub content: Value,
    /// False when the tool errored, timed out, or was denied.
    pub success: bool,
}
69
/// Audit record of a completed (or denied) tool call, surfaced to callers
/// via `AgentResponse` and the `ToolCallback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Id of the originating request.
    pub id: String,
    /// Tool name that was invoked.
    pub name: String,
    /// Arguments it was invoked with.
    pub arguments: Value,
    /// The (possibly truncated) result value.
    pub result: Value,
    /// Whether execution succeeded.
    pub success: bool,
    /// Wall-clock execution time; 0 for calls denied before running.
    pub duration_ms: u64,
}
86
/// Token accounting for one LLM call (or, accumulated, a whole agent run).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    /// Portion of the completion spent "thinking" (reasoning traces).
    pub reasoning_tokens: u64,
    /// Portion of the completion spent on visible output / tool calls.
    pub action_tokens: u64,
}
98
/// One raw generation from the LLM backend, before the agent loop
/// post-processes it (think-tag stripping, nudges, tool dispatch).
#[derive(Debug, Clone)]
pub struct LLMResponse {
    /// Visible model output (may still contain <think> spans).
    pub content: String,
    /// Separate reasoning trace, when the backend provides one.
    pub reasoning: Option<String>,
    /// Tool calls the model wants executed; empty means a final answer.
    pub tool_calls: Vec<ToolCallRequest>,
    /// Backend finish reason; "error" marks a synthesized failure response.
    pub finish_reason: String,
    /// Per-call token usage, when the backend reports it.
    pub usage: Option<TokenUsage>,
}
113
/// Final result of a full agent run (`execute` and friends).
#[derive(Debug)]
pub struct AgentResponse {
    /// The model's final, cleaned answer text.
    pub content: String,
    /// Every tool call made during the run, in execution order.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of LLM round-trips it took.
    pub iterations: usize,
    /// Token usage summed over all iterations.
    pub usage: TokenUsage,
}
126
/// Streaming callback: invoked with each token chunk as it arrives.
pub type TokenCallback = Box<dyn Fn(&str) + Send + Sync>;

/// Invoked after each tool call finishes (success, failure, or denial).
pub type ToolCallback = Box<dyn Fn(&ToolCallRecord) + Send + Sync>;

/// Invoked with the tool name just before a tool starts executing.
pub type ToolStartCallback = Box<dyn Fn(&str) + Send + Sync>;
/// A request for user approval before running a permission-gated tool.
#[derive(Debug, Clone)]
pub struct PermissionRequest {
    /// The tool awaiting approval.
    pub tool_name: String,
    /// First ~120 chars of the arguments, for display to the user.
    pub args_summary: String,
}
144
/// Asks the user to approve a tool call. The returned oneshot receiver
/// resolves to `true` to allow execution; `false` (or a dropped sender)
/// denies it.
pub type PermissionCallback =
    Box<dyn Fn(PermissionRequest) -> tokio::sync::oneshot::Receiver<bool> + Send + Sync>;
149
/// The coding agent: owns the conversation history, the tool registry,
/// and the LLM backend, and drives the generate → tool-call loop.
pub struct PawanAgent {
    config: PawanConfig,
    tools: ToolRegistry,
    /// Full conversation so far; pruned when the context estimate overflows.
    history: Vec<Message>,
    workspace_root: PathBuf,
    backend: Box<dyn LlmBackend>,

    /// Rough token count of `history` (content bytes / 4).
    context_tokens_estimate: usize,

    /// Optional Eruka memory bridge; `None` when disabled in config.
    eruka: Option<crate::eruka_bridge::ErukaClient>,
}
177
178impl PawanAgent {
179 pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
181 let tools = ToolRegistry::with_defaults(workspace_root.clone());
182 let system_prompt = config.get_system_prompt();
183 let backend = Self::create_backend(&config, &system_prompt);
184 let eruka = if config.eruka.enabled {
185 Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
186 } else {
187 None
188 };
189
190 Self {
191 config,
192 tools,
193 history: Vec::new(),
194 workspace_root,
195 backend,
196 context_tokens_estimate: 0,
197 eruka,
198 }
199 }
200
    /// Constructs the LLM backend for `config.provider`.
    ///
    /// Nvidia / OpenAI / MLX all speak the OpenAI-compatible API and differ
    /// only in how the endpoint URL and API key are resolved; Ollama gets a
    /// dedicated backend. Also called at runtime by `switch_model`.
    fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
        match config.provider {
            LlmProvider::Nvidia | LlmProvider::OpenAI | LlmProvider::Mlx => {
                // Resolve endpoint + credentials per provider.
                let (api_url, api_key) = match config.provider {
                    LlmProvider::Nvidia => {
                        let url = std::env::var("NVIDIA_API_URL")
                            .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                        let key = std::env::var("NVIDIA_API_KEY").ok();
                        // A missing key is only a warning here; the request
                        // itself will fail later with a backend error.
                        if key.is_none() {
                            tracing::warn!("NVIDIA_API_KEY not set. Add it to .env or export it.");
                        }
                        (url, key)
                    },
                    LlmProvider::OpenAI => {
                        // Precedence: explicit config URL > env var > default.
                        let url = config.base_url.clone()
                            .or_else(|| std::env::var("OPENAI_API_URL").ok())
                            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
                        let key = std::env::var("OPENAI_API_KEY").ok();
                        (url, key)
                    },
                    LlmProvider::Mlx => {
                        // Local MLX server: no API key required.
                        let url = config.base_url.clone()
                            .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
                        tracing::info!(url = %url, "Using MLX LM server (Apple Silicon native)");
                        (url, None)
                    },
                    // The outer match arm restricts us to the three
                    // variants handled above.
                    _ => unreachable!(),
                };

                // Optional cloud-fallback endpoint, resolved similarly
                // (but without honoring config.base_url).
                let cloud = config.cloud.as_ref().map(|c| {
                    let (cloud_url, cloud_key) = match c.provider {
                        LlmProvider::Nvidia => {
                            let url = std::env::var("NVIDIA_API_URL")
                                .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
                            let key = std::env::var("NVIDIA_API_KEY").ok();
                            (url, key)
                        },
                        LlmProvider::OpenAI => {
                            let url = std::env::var("OPENAI_API_URL")
                                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
                            let key = std::env::var("OPENAI_API_KEY").ok();
                            (url, key)
                        },
                        LlmProvider::Mlx => {
                            ("http://localhost:8080/v1".to_string(), None)
                        },
                        _ => {
                            // Unsupported fallback providers degrade to the
                            // Nvidia endpoint without a key.
                            tracing::warn!("Cloud fallback only supports nvidia/openai/mlx providers");
                            ("https://integrate.api.nvidia.com/v1".to_string(), None)
                        }
                    };
                    backend::openai_compat::CloudFallback {
                        api_url: cloud_url,
                        api_key: cloud_key,
                        model: c.model.clone(),
                        fallback_models: c.fallback_models.clone(),
                    }
                });

                Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
                    api_url,
                    api_key,
                    model: config.model.clone(),
                    temperature: config.temperature,
                    top_p: config.top_p,
                    max_tokens: config.max_tokens,
                    system_prompt: system_prompt.to_string(),
                    // NOTE(review): thinking mode is enabled only when no
                    // explicit thinking budget is set — presumably a nonzero
                    // budget is enforced elsewhere; confirm.
                    use_thinking: config.thinking_budget == 0 && config.use_thinking_mode(),
                    max_retries: config.max_retries,
                    fallback_models: config.fallback_models.clone(),
                    cloud,
                }))
            }
            LlmProvider::Ollama => {
                let url = std::env::var("OLLAMA_URL")
                    .unwrap_or_else(|_| "http://localhost:11434".to_string());

                Box::new(backend::ollama::OllamaBackend::new(
                    url,
                    config.model.clone(),
                    config.temperature,
                    system_prompt.to_string(),
                ))
            }
        }
    }
292
293 pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
295 self.tools = tools;
296 self
297 }
298
299 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
301 &mut self.tools
302 }
303
304 pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
306 self.backend = backend;
307 self
308 }
309
310 pub fn history(&self) -> &[Message] {
312 &self.history
313 }
314
315 pub fn save_session(&self) -> Result<String> {
317 let mut session = session::Session::new(&self.config.model);
318 session.messages = self.history.clone();
319 session.total_tokens = self.context_tokens_estimate as u64;
320 session.save()?;
321 Ok(session.id)
322 }
323
324 pub fn resume_session(&mut self, session_id: &str) -> Result<()> {
326 let session = session::Session::load(session_id)?;
327 self.history = session.messages;
328 self.context_tokens_estimate = session.total_tokens as usize;
329 Ok(())
330 }
331
332 pub fn config(&self) -> &PawanConfig {
334 &self.config
335 }
336
337 pub fn clear_history(&mut self) {
339 self.history.clear();
340 }
    /// Replaces the middle of the conversation with one summary message.
    ///
    /// Keeps `history[0]` (the oldest message, assumed to be the system
    /// prompt) and the last 4 messages verbatim; everything between is
    /// ranked by `message_importance`, condensed into a ~2KB text summary,
    /// and replaced by a single `Role::System` summary message.
    fn prune_history(&mut self) {
        let len = self.history.len();
        // With <= 5 messages the kept head (1) + tail (4) already cover
        // everything — nothing to prune.
        if len <= 5 {
            return;
        }

        let keep_end = 4;
        let start = 1;
        let end = len - keep_end;
        let pruned_count = end - start;

        // Score the prunable middle and sort highest-importance first.
        let mut scored: Vec<(f32, &Message)> = self.history[start..end]
            .iter()
            .map(|msg| {
                let score = Self::message_importance(msg);
                (score, msg)
            })
            .collect();
        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        // Build the summary in importance order (NOT chronological order):
        // up to 200 chars per message, capped at ~2000 bytes total.
        let mut summary = String::with_capacity(2048);
        for (score, msg) in &scored {
            let prefix = match msg.role {
                Role::User => "User: ",
                Role::Assistant => "Assistant: ",
                // Scores above 0.7 for tool messages correspond to failed
                // calls (see `message_importance`), so label them as errors.
                Role::Tool => if *score > 0.7 { "Tool error: " } else { "Tool: " },
                Role::System => "System: ",
            };
            let chunk: String = msg.content.chars().take(200).collect();
            summary.push_str(prefix);
            summary.push_str(&chunk);
            summary.push('\n');
            if summary.len() > 2000 {
                // Cut on a UTF-8 character boundary at (or just past) byte
                // 2000 so `truncate` cannot panic mid-character.
                let safe_end = summary.char_indices()
                    .take_while(|(i, _)| *i <= 2000)
                    .last()
                    .map(|(i, c)| i + c.len_utf8())
                    .unwrap_or(0);
                summary.truncate(safe_end);
                break;
            }
        }

        let summary_msg = Message {
            role: Role::System,
            content: format!("Previous conversation summary (pruned {} messages, importance-ranked): {}", pruned_count, summary),
            tool_calls: vec![],
            tool_result: None,
        };

        // Swap the pruned middle for the single summary message, directly
        // after the preserved head.
        self.history.drain(start..end);
        self.history.insert(start, summary_msg);

        tracing::info!(pruned = pruned_count, context_estimate = self.context_tokens_estimate, "Pruned messages from history (importance-ranked)");
    }
405
406 fn message_importance(msg: &Message) -> f32 {
409 match msg.role {
410 Role::User => 0.6, Role::System => 0.3, Role::Assistant => {
413 if msg.content.contains("error") || msg.content.contains("Error") { 0.8 }
414 else { 0.4 }
415 }
416 Role::Tool => {
417 if let Some(ref result) = msg.tool_result {
418 if !result.success { 0.9 } else { 0.2 } } else {
421 0.3
422 }
423 }
424 }
425 }
426
427 pub fn add_message(&mut self, message: Message) {
429 self.history.push(message);
430 }
431
432 pub fn switch_model(&mut self, model: &str) {
434 self.config.model = model.to_string();
435 let system_prompt = self.config.get_system_prompt();
436 self.backend = Self::create_backend(&self.config, &system_prompt);
437 tracing::info!(model = model, "Model switched at runtime");
438 }
439
440 pub fn model_name(&self) -> &str {
442 &self.config.model
443 }
444
445 pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
447 self.tools.get_definitions()
448 }
449
450 pub async fn execute(&mut self, user_prompt: &str) -> Result<AgentResponse> {
452 self.execute_with_callbacks(user_prompt, None, None, None)
453 .await
454 }
455
456 pub async fn execute_with_callbacks(
458 &mut self,
459 user_prompt: &str,
460 on_token: Option<TokenCallback>,
461 on_tool: Option<ToolCallback>,
462 on_tool_start: Option<ToolStartCallback>,
463 ) -> Result<AgentResponse> {
464 self.execute_with_all_callbacks(user_prompt, on_token, on_tool, on_tool_start, None)
465 .await
466 }
467
    /// The core agent loop: pushes `user_prompt`, then alternates LLM
    /// generations and tool executions until the model answers without
    /// tool calls (or the iteration budget is exhausted).
    ///
    /// Callbacks (all optional): `on_token` streams output chunks,
    /// `on_tool_start`/`on_tool` bracket each tool call, and
    /// `on_permission` arbitrates permission-gated tools.
    ///
    /// # Errors
    /// Returns `PawanError::Agent` when `max_tool_iterations` is exceeded.
    /// Backend failures are retried and, if permanent, folded into the
    /// returned content rather than propagated as `Err`.
    pub async fn execute_with_all_callbacks(
        &mut self,
        user_prompt: &str,
        on_token: Option<TokenCallback>,
        on_tool: Option<ToolCallback>,
        on_tool_start: Option<ToolStartCallback>,
        on_permission: Option<PermissionCallback>,
    ) -> Result<AgentResponse> {
        // Best-effort Eruka memory injection; failures are logged, never fatal.
        if let Some(eruka) = &self.eruka {
            if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
                tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
            }
        }

        self.history.push(Message {
            role: Role::User,
            content: user_prompt.to_string(),
            tool_calls: vec![],
            tool_result: None,
        });

        let mut all_tool_calls = Vec::new();
        let mut total_usage = TokenUsage::default();
        let mut iterations = 0;
        let max_iterations = self.config.max_tool_iterations;

        loop {
            iterations += 1;
            if iterations > max_iterations {
                return Err(PawanError::Agent(format!(
                    "Max tool iterations ({}) exceeded",
                    max_iterations
                )));
            }

            // With exactly 3 iterations left, nudge the model to stop
            // exploring and produce its most important output.
            let remaining = max_iterations.saturating_sub(iterations);
            if remaining == 3 && iterations > 1 {
                self.history.push(Message {
                    role: Role::User,
                    content: format!(
                        "[SYSTEM] You have {} tool iterations remaining. \
                        Stop exploring and write the most important output now. \
                        If you have code to write, write it immediately.",
                        remaining
                    ),
                    tool_calls: vec![],
                    tool_result: None,
                });
            }
            // Crude context estimate: ~4 content bytes per token.
            self.context_tokens_estimate = self.history.iter().map(|m| m.content.len()).sum::<usize>() / 4;
            if self.context_tokens_estimate > self.config.max_context_tokens {
                self.prune_history();
            }

            // Select a small, query-relevant subset of tools (at most 12),
            // keyed off the most recent user message.
            let latest_query = self.history.iter().rev()
                .find(|m| m.role == Role::User)
                .map(|m| m.content.as_str())
                .unwrap_or("");
            let tool_defs = self.tools.select_for_query(latest_query, 12);
            if iterations == 1 {
                let tool_names: Vec<&str> = tool_defs.iter().map(|t| t.name.as_str()).collect();
                tracing::info!(tools = ?tool_names, count = tool_defs.len(), "Selected tools for query");
            }

            // Call the backend with bounded retries on transient failures.
            // A permanent failure is converted into a synthetic "error"
            // response so the loop can terminate gracefully.
            let response = {
                #[allow(unused_assignments)]
                let mut last_err = None;
                let max_llm_retries = 3;
                let mut attempt = 0;
                loop {
                    attempt += 1;
                    match self.backend.generate(&self.history, &tool_defs, on_token.as_ref()).await {
                        Ok(resp) => break resp,
                        Err(e) => {
                            // Transient-error detection is string-based;
                            // keep in sync with backend error formatting.
                            let err_str = e.to_string();
                            let is_transient = err_str.contains("timeout")
                                || err_str.contains("connection")
                                || err_str.contains("429")
                                || err_str.contains("500")
                                || err_str.contains("502")
                                || err_str.contains("503")
                                || err_str.contains("504")
                                || err_str.contains("reset")
                                || err_str.contains("broken pipe");

                            if is_transient && attempt <= max_llm_retries {
                                // Exponential backoff: 2s, 4s, 8s.
                                let delay = std::time::Duration::from_secs(2u64.pow(attempt as u32));
                                tracing::warn!(
                                    attempt = attempt,
                                    delay_secs = delay.as_secs(),
                                    error = err_str.as_str(),
                                    "LLM call failed (transient) — retrying"
                                );
                                tokio::time::sleep(delay).await;

                                // Context/token errors suggest overflow:
                                // shrink history before retrying.
                                if err_str.contains("context") || err_str.contains("token") {
                                    tracing::info!("Pruning history before retry (possible context overflow)");
                                    self.prune_history();
                                }
                                continue;
                            }

                            // Permanent failure: surface the error as the
                            // model's "answer" instead of returning Err.
                            last_err = Some(e);
                            break {
                                tracing::error!(
                                    attempt = attempt,
                                    error = last_err.as_ref().map(|e| e.to_string()).unwrap_or_default().as_str(),
                                    "LLM call failed permanently — returning error as content"
                                );
                                LLMResponse {
                                    content: format!(
                                        "LLM error after {} attempts: {}. The task could not be completed.",
                                        attempt,
                                        last_err.as_ref().map(|e| e.to_string()).unwrap_or_default()
                                    ),
                                    reasoning: None,
                                    tool_calls: vec![],
                                    finish_reason: "error".to_string(),
                                    usage: None,
                                }
                            };
                        }
                    }
                }
            };

            // Accumulate token usage across iterations.
            if let Some(ref usage) = response.usage {
                total_usage.prompt_tokens += usage.prompt_tokens;
                total_usage.completion_tokens += usage.completion_tokens;
                total_usage.total_tokens += usage.total_tokens;
                total_usage.reasoning_tokens += usage.reasoning_tokens;
                total_usage.action_tokens += usage.action_tokens;

                if usage.reasoning_tokens > 0 {
                    tracing::info!(
                        iteration = iterations,
                        think = usage.reasoning_tokens,
                        act = usage.action_tokens,
                        total = usage.completion_tokens,
                        "Token budget: think:{} act:{} (total:{})",
                        usage.reasoning_tokens, usage.action_tokens, usage.completion_tokens
                    );
                }

                // Budget overruns are reported, not enforced, here.
                let thinking_budget = self.config.thinking_budget;
                if thinking_budget > 0 && usage.reasoning_tokens > thinking_budget as u64 {
                    tracing::warn!(
                        budget = thinking_budget,
                        actual = usage.reasoning_tokens,
                        "Thinking budget exceeded ({}/{} tokens)",
                        usage.reasoning_tokens, thinking_budget
                    );
                }
            }

            // Strip every <think>...</think> span the model leaked into its
            // visible output; 8 == "</think>".len().
            // NOTE(review): offsets come from the lowercased copy but slice
            // the original string — byte offsets only line up while
            // lowercasing preserves byte lengths (true for ASCII). Confirm
            // content before the tags can't contain case-changing
            // multi-byte characters.
            let clean_content = {
                let mut s = response.content.clone();
                loop {
                    let lower = s.to_lowercase();
                    let open = lower.find("<think>");
                    let close = lower.find("</think>");
                    match (open, close) {
                        (Some(i), Some(j)) if j > i => {
                            let before = s[..i].trim_end().to_string();
                            let after = if s.len() > j + 8 { s[j + 8..].trim_start().to_string() } else { String::new() };
                            s = if before.is_empty() { after } else if after.is_empty() { before } else { format!("{}\n{}", before, after) };
                        }
                        _ => break,
                    }
                }
                s
            };

            // No tool calls: either a final answer, or the model is
            // narrating instead of acting. Detect the latter and nudge.
            if response.tool_calls.is_empty() {
                let has_tools = !tool_defs.is_empty();
                let lower = clean_content.to_lowercase();
                let planning_prefix = lower.starts_with("let me")
                    || lower.starts_with("i'll help")
                    || lower.starts_with("i will help")
                    || lower.starts_with("sure, i")
                    || lower.starts_with("okay, i");
                let looks_like_planning = clean_content.len() > 200 || (planning_prefix && clean_content.len() > 50);
                // Nudge at most once, on the very first iteration.
                if has_tools && looks_like_planning && iterations == 1 && iterations < max_iterations && response.finish_reason != "error" {
                    tracing::warn!(
                        "No tool calls at iteration {} (content: {}B) — nudging model to use tools",
                        iterations, clean_content.len()
                    );
                    self.history.push(Message {
                        role: Role::Assistant,
                        content: clean_content.clone(),
                        tool_calls: vec![],
                        tool_result: None,
                    });
                    self.history.push(Message {
                        role: Role::User,
                        content: "You must use tools to complete this task. Do NOT just describe what you would do — actually call the tools. Start with bash or read_file.".to_string(),
                        tool_calls: vec![],
                        tool_result: None,
                    });
                    continue;
                }

                // Loop-breaker: if the model repeats its previous answer
                // verbatim, inject a correction instead of returning it.
                if iterations > 1 {
                    let prev_assistant = self.history.iter().rev()
                        .find(|m| m.role == Role::Assistant && !m.content.is_empty());
                    if let Some(prev) = prev_assistant {
                        if prev.content.trim() == clean_content.trim() && iterations < max_iterations {
                            tracing::warn!("Repeated response detected at iteration {} — injecting correction", iterations);
                            self.history.push(Message {
                                role: Role::Assistant,
                                content: clean_content.clone(),
                                tool_calls: vec![],
                                tool_result: None,
                            });
                            self.history.push(Message {
                                role: Role::User,
                                content: "You gave the same response as before. Try a different approach. Use anchor_text in edit_file_lines, or use insert_after, or use bash with sed.".to_string(),
                                tool_calls: vec![],
                                tool_result: None,
                            });
                            continue;
                        }
                    }
                }

                // Genuine final answer: record it and return.
                self.history.push(Message {
                    role: Role::Assistant,
                    content: clean_content.clone(),
                    tool_calls: vec![],
                    tool_result: None,
                });

                return Ok(AgentResponse {
                    content: clean_content,
                    tool_calls: all_tool_calls,
                    iterations,
                    usage: total_usage,
                });
            }

            // Record the assistant turn (raw content) with its tool calls.
            self.history.push(Message {
                role: Role::Assistant,
                content: response.content.clone(),
                tool_calls: response.tool_calls.clone(),
                tool_result: None,
            });

            for tool_call in &response.tool_calls {
                // Mark the tool as active so future selections favor it.
                self.tools.activate(&tool_call.name);

                let perm = crate::config::ToolPermission::resolve(
                    &tool_call.name, &self.config.permissions
                );
                // `denied` holds the human-readable refusal reason, or
                // None when execution may proceed.
                let denied = match perm {
                    crate::config::ToolPermission::Deny => Some("Tool denied by permission policy"),
                    crate::config::ToolPermission::Prompt => {
                        if tool_call.name == "bash" {
                            if let Some(cmd) = tool_call.arguments.get("command").and_then(|v| v.as_str()) {
                                if crate::tools::bash::is_read_only(cmd) {
                                    // Read-only bash commands skip the prompt.
                                    tracing::debug!(command = cmd, "Auto-allowing read-only bash command under Prompt permission");
                                    None
                                } else if let Some(ref perm_cb) = on_permission {
                                    let args_summary = cmd.chars().take(120).collect::<String>();
                                    let rx = perm_cb(PermissionRequest {
                                        tool_name: tool_call.name.clone(),
                                        args_summary,
                                    });
                                    // A dropped sender counts as denial.
                                    match rx.await {
                                        Ok(true) => None,
                                        _ => Some("User denied tool execution"),
                                    }
                                } else {
                                    Some("Bash command requires user approval (read-only commands auto-allowed)")
                                }
                            } else {
                                // No "command" argument to inspect: refuse.
                                Some("Tool requires user approval")
                            }
                        } else if let Some(ref perm_cb) = on_permission {
                            let args_summary = tool_call.arguments.to_string().chars().take(120).collect::<String>();
                            let rx = perm_cb(PermissionRequest {
                                tool_name: tool_call.name.clone(),
                                args_summary,
                            });
                            match rx.await {
                                Ok(true) => None,
                                _ => Some("User denied tool execution"),
                            }
                        } else {
                            Some("Tool requires user approval (set permission to 'allow' or use TUI mode)")
                        }
                    }
                    crate::config::ToolPermission::Allow => None,
                };
                // Denied: record a failed call so the model sees why.
                if let Some(reason) = denied {
                    let record = ToolCallRecord {
                        id: tool_call.id.clone(),
                        name: tool_call.name.clone(),
                        arguments: tool_call.arguments.clone(),
                        result: json!({"error": reason}),
                        success: false,
                        duration_ms: 0,
                    };

                    if let Some(ref callback) = on_tool {
                        callback(&record);
                    }
                    all_tool_calls.push(record);

                    self.history.push(Message {
                        role: Role::Tool,
                        content: format!("{{\"error\": \"{}\"}}", reason),
                        tool_calls: vec![],
                        tool_result: Some(ToolResultMessage {
                            tool_call_id: tool_call.id.clone(),
                            content: json!({"error": reason}),
                            success: false,
                        }),
                    });
                    continue;
                }

                if let Some(ref callback) = on_tool_start {
                    callback(&tool_call.name);
                }

                tracing::debug!(
                    tool = tool_call.name.as_str(),
                    args_len = serde_json::to_string(&tool_call.arguments).unwrap_or_default().len(),
                    "Tool call: {}({})",
                    tool_call.name,
                    serde_json::to_string(&tool_call.arguments)
                        .unwrap_or_default()
                        .chars()
                        .take(200)
                        .collect::<String>()
                );

                // Best-effort argument validation against the tool's
                // schema; failures are logged but never block execution.
                if let Some(tool) = self.tools.get(&tool_call.name) {
                    let schema = tool.parameters_schema();
                    if let Ok(params) = thulp_core::ToolDefinition::parse_mcp_input_schema(&schema) {
                        let thulp_def = thulp_core::ToolDefinition {
                            name: tool_call.name.clone(),
                            description: String::new(),
                            parameters: params,
                        };
                        if let Err(e) = thulp_def.validate_args(&tool_call.arguments) {
                            tracing::warn!(
                                tool = tool_call.name.as_str(),
                                error = %e,
                                "Tool argument validation failed (continuing anyway)"
                            );
                        }
                    }
                }

                let start = std::time::Instant::now();

                // Execute under a timeout: configurable for bash, fixed
                // 30s for every other tool.
                let result = {
                    let tool_future = self.tools.execute(&tool_call.name, tool_call.arguments.clone());
                    let timeout_dur = if tool_call.name == "bash" {
                        std::time::Duration::from_secs(self.config.bash_timeout_secs)
                    } else {
                        std::time::Duration::from_secs(30)
                    };
                    match tokio::time::timeout(timeout_dur, tool_future).await {
                        Ok(inner) => inner,
                        Err(_) => Err(PawanError::Tool(format!(
                            "Tool '{}' timed out after {}s", tool_call.name, timeout_dur.as_secs()
                        ))),
                    }
                };
                let duration_ms = start.elapsed().as_millis() as u64;

                // Fold tool errors into the result JSON so the model can
                // react instead of aborting the agent loop.
                let (result_value, success) = match result {
                    Ok(v) => (v, true),
                    Err(e) => {
                        tracing::warn!(tool = tool_call.name.as_str(), error = %e, "Tool execution failed");
                        (json!({"error": e.to_string(), "tool": tool_call.name, "hint": "Try a different approach or tool"}), false)
                    }
                };

                // Keep oversized outputs from blowing the context window.
                let max_result_chars = self.config.max_result_chars;
                let result_value = truncate_tool_result(result_value, max_result_chars);

                let record = ToolCallRecord {
                    id: tool_call.id.clone(),
                    name: tool_call.name.clone(),
                    arguments: tool_call.arguments.clone(),
                    result: result_value.clone(),
                    success,
                    duration_ms,
                };

                if let Some(ref callback) = on_tool {
                    callback(&record);
                }

                all_tool_calls.push(record);

                self.history.push(Message {
                    role: Role::Tool,
                    content: serde_json::to_string(&result_value).unwrap_or_default(),
                    tool_calls: vec![],
                    tool_result: Some(ToolResultMessage {
                        tool_call_id: tool_call.id.clone(),
                        content: result_value,
                        success,
                    }),
                });

                // Compile gate: after a successful write_file of a .rs
                // file, run `cargo check` and feed failures back to the
                // model as a system message.
                if success && tool_call.name == "write_file" {
                    let wrote_rs = tool_call.arguments.get("path")
                        .and_then(|p| p.as_str())
                        .map(|p| p.ends_with(".rs"))
                        .unwrap_or(false);
                    if wrote_rs {
                        let ws = self.workspace_root.clone();
                        let check_result = tokio::process::Command::new("cargo")
                            .arg("check")
                            .arg("--message-format=short")
                            .current_dir(&ws)
                            .output()
                            .await;
                        match check_result {
                            Ok(output) if !output.status.success() => {
                                let stderr = String::from_utf8_lossy(&output.stderr);
                                // Cap injected compiler output at 1500 chars.
                                let err_msg: String = stderr.chars().take(1500).collect();
                                tracing::info!("Compile-gate: cargo check failed after write_file, injecting errors");
                                self.history.push(Message {
                                    role: Role::User,
                                    content: format!(
                                        "[SYSTEM] cargo check failed after your write_file. Fix the errors:\n```\n{}\n```",
                                        err_msg
                                    ),
                                    tool_calls: vec![],
                                    tool_result: None,
                                });
                            }
                            Ok(_) => {
                                tracing::debug!("Compile-gate: cargo check passed");
                            }
                            Err(e) => {
                                // cargo itself failed to launch; not fatal.
                                tracing::warn!("Compile-gate: cargo check failed to run: {}", e);
                            }
                        }
                    }
                }
            }
        }
    }
951
    /// Collects compiler diagnostics and failing tests for the workspace,
    /// builds a healing prompt from them, and runs the agent on it.
    pub async fn heal(&mut self) -> Result<AgentResponse> {
        let healer = crate::healing::Healer::new(
            self.workspace_root.clone(),
            self.config.healing.clone(),
        );

        let diagnostics = healer.get_diagnostics().await?;
        let failed_tests = healer.get_failed_tests().await?;

        // The prompt literals below embed raw newlines deliberately —
        // they are part of the exact prompt text.
        let mut prompt = format!(
            "I need you to heal this Rust project at: {}

",
            self.workspace_root.display()
        );

        if !diagnostics.is_empty() {
            prompt.push_str(&format!(
                "## Compilation Issues ({} found)
{}
",
                diagnostics.len(),
                healer.format_diagnostics_for_prompt(&diagnostics)
            ));
        }

        if !failed_tests.is_empty() {
            prompt.push_str(&format!(
                "## Failed Tests ({} found)
{}
",
                failed_tests.len(),
                healer.format_tests_for_prompt(&failed_tests)
            ));
        }

        // Even with nothing to fix, still run the agent for verification.
        if diagnostics.is_empty() && failed_tests.is_empty() {
            prompt.push_str("No issues found! Run cargo check and cargo test to verify.
");
        }

        prompt.push_str("
Fix each issue one at a time. Verify with cargo check after each fix.");

        self.execute(&prompt).await
    }
999 pub async fn heal_with_retries(&mut self, max_attempts: usize) -> Result<AgentResponse> {
1001 let mut last_response = self.heal().await?;
1002
1003 for attempt in 1..max_attempts {
1004 let fixer = crate::healing::CompilerFixer::new(self.workspace_root.clone());
1005 let remaining = fixer.check().await?;
1006 let errors: Vec<_> = remaining.iter().filter(|d| d.kind == crate::healing::DiagnosticKind::Error).collect();
1007
1008 if errors.is_empty() {
1009 tracing::info!(attempts = attempt, "Healing complete");
1010 return Ok(last_response);
1011 }
1012
1013 tracing::warn!(errors = errors.len(), attempt = attempt, "Errors remain after heal attempt, retrying");
1014 last_response = self.heal().await?;
1015 }
1016
1017 tracing::info!(attempts = max_attempts, "Healing finished (may still have errors)");
1018 Ok(last_response)
1019 }
1020 pub async fn task(&mut self, task_description: &str) -> Result<AgentResponse> {
1022 let prompt = format!(
1023 r#"I need you to complete the following coding task:
1024
1025{}
1026
1027The workspace is at: {}
1028
1029Please:
10301. First explore the codebase to understand the relevant code
10312. Make the necessary changes
10323. Verify the changes compile with `cargo check`
10334. Run relevant tests if applicable
1034
1035Explain your changes as you go."#,
1036 task_description,
1037 self.workspace_root.display()
1038 );
1039
1040 self.execute(&prompt).await
1041 }
1042
1043 pub async fn generate_commit_message(&mut self) -> Result<String> {
1045 let prompt = r#"Please:
10461. Run `git status` to see what files are changed
10472. Run `git diff --cached` to see staged changes (or `git diff` for unstaged)
10483. Generate a concise, descriptive commit message following conventional commits format
1049
1050Only output the suggested commit message, nothing else."#;
1051
1052 let response = self.execute(prompt).await?;
1053 Ok(response.content)
1054 }
1055}
1056
1057fn truncate_tool_result(value: Value, max_chars: usize) -> Value {
1061 let serialized = serde_json::to_string(&value).unwrap_or_default();
1062 if serialized.len() <= max_chars {
1063 return value;
1064 }
1065
1066 match value {
1068 Value::Object(map) => {
1069 let mut result = serde_json::Map::new();
1070 let total = serialized.len();
1071 for (k, v) in map {
1072 if let Value::String(s) = &v {
1073 if s.len() > 500 {
1074 let target = s.len() * max_chars / total;
1076 let target = target.max(200); let truncated: String = s.chars().take(target).collect();
1078 result.insert(k, json!(format!("{}...[truncated from {} chars]", truncated, s.len())));
1079 continue;
1080 }
1081 }
1082 result.insert(k, truncate_tool_result(v, max_chars));
1084 }
1085 Value::Object(result)
1086 }
1087 Value::String(s) if s.len() > max_chars => {
1088 let truncated: String = s.chars().take(max_chars).collect();
1089 json!(format!("{}...[truncated from {} chars]", truncated, s.len()))
1090 }
1091 Value::Array(arr) if serialized.len() > max_chars => {
1092 let mut result = Vec::new();
1094 let mut running_len = 2; for item in arr {
1096 let item_str = serde_json::to_string(&item).unwrap_or_default();
1097 running_len += item_str.len() + 1; if running_len > max_chars {
1099 result.push(json!(format!("...[{} more items truncated]", 0)));
1100 break;
1101 }
1102 result.push(item);
1103 }
1104 Value::Array(result)
1105 }
1106 other => other,
1107 }
1108}
1109
1110#[cfg(test)]
1111mod tests {
1112 use super::*;
1113
1114 #[test]
1115 fn test_message_serialization() {
1116 let msg = Message {
1117 role: Role::User,
1118 content: "Hello".to_string(),
1119 tool_calls: vec![],
1120 tool_result: None,
1121 };
1122
1123 let json = serde_json::to_string(&msg).expect("Serialization failed");
1124 assert!(json.contains("user"));
1125 assert!(json.contains("Hello"));
1126 }
1127
1128 #[test]
1129 fn test_tool_call_request() {
1130 let tc = ToolCallRequest {
1131 id: "123".to_string(),
1132 name: "read_file".to_string(),
1133 arguments: json!({"path": "test.txt"}),
1134 };
1135
1136 let json = serde_json::to_string(&tc).expect("Serialization failed");
1137 assert!(json.contains("read_file"));
1138 assert!(json.contains("test.txt"));
1139 }
1140
1141 fn agent_with_messages(n: usize) -> PawanAgent {
1144 let config = PawanConfig::default();
1145 let mut agent = PawanAgent::new(config, PathBuf::from("."));
1146 agent.add_message(Message {
1148 role: Role::System,
1149 content: "System prompt".to_string(),
1150 tool_calls: vec![],
1151 tool_result: None,
1152 });
1153 for i in 1..n {
1154 agent.add_message(Message {
1155 role: if i % 2 == 1 { Role::User } else { Role::Assistant },
1156 content: format!("Message {}", i),
1157 tool_calls: vec![],
1158 tool_result: None,
1159 });
1160 }
1161 assert_eq!(agent.history().len(), n);
1162 agent
1163 }
1164
1165 #[test]
1166 fn test_prune_history_no_op_when_small() {
1167 let mut agent = agent_with_messages(5);
1168 agent.prune_history();
1169 assert_eq!(agent.history().len(), 5, "Should not prune <= 5 messages");
1170 }
1171
1172 #[test]
1173 fn test_prune_history_reduces_messages() {
1174 let mut agent = agent_with_messages(12);
1175 assert_eq!(agent.history().len(), 12);
1176 agent.prune_history();
1177 assert_eq!(agent.history().len(), 6);
1179 }
1180
1181 #[test]
1182 fn test_prune_history_preserves_system_prompt() {
1183 let mut agent = agent_with_messages(10);
1184 let original_system = agent.history()[0].content.clone();
1185 agent.prune_history();
1186 assert_eq!(agent.history()[0].content, original_system, "System prompt must survive pruning");
1187 }
1188
1189 #[test]
1190 fn test_prune_history_preserves_last_messages() {
1191 let mut agent = agent_with_messages(10);
1192 let last4: Vec<String> = agent.history()[6..10].iter().map(|m| m.content.clone()).collect();
1194 agent.prune_history();
1195 let after_last4: Vec<String> = agent.history()[2..6].iter().map(|m| m.content.clone()).collect();
1197 assert_eq!(last4, after_last4, "Last 4 messages must be preserved after pruning");
1198 }
1199
1200 #[test]
1201 fn test_prune_history_inserts_summary() {
1202 let mut agent = agent_with_messages(10);
1203 agent.prune_history();
1204 assert_eq!(agent.history()[1].role, Role::System);
1205 assert!(agent.history()[1].content.contains("summary"), "Summary message should contain 'summary'");
1206 }
1207
1208 #[test]
1209 fn test_prune_history_utf8_safe() {
1210 let config = PawanConfig::default();
1211 let mut agent = PawanAgent::new(config, PathBuf::from("."));
1212 agent.add_message(Message {
1214 role: Role::System, content: "sys".into(), tool_calls: vec![], tool_result: None,
1215 });
1216 for _ in 0..10 {
1217 agent.add_message(Message {
1218 role: Role::User,
1219 content: "こんにちは世界 🌍 ".repeat(50),
1220 tool_calls: vec![],
1221 tool_result: None,
1222 });
1223 }
1224 agent.prune_history();
1226 assert!(agent.history().len() < 11, "Should have pruned");
1227 let summary = &agent.history()[1].content;
1229 assert!(summary.is_char_boundary(0));
1230 }
1231
1232 #[test]
1233 fn test_prune_history_exactly_6_messages() {
1234 let mut agent = agent_with_messages(6);
1236 agent.prune_history();
1237 assert_eq!(agent.history().len(), 6);
1239 }
1240
1241 #[test]
1242 fn test_message_role_roundtrip() {
1243 for role in [Role::User, Role::Assistant, Role::System, Role::Tool] {
1244 let json = serde_json::to_string(&role).unwrap();
1245 let back: Role = serde_json::from_str(&json).unwrap();
1246 assert_eq!(role, back);
1247 }
1248 }
1249
1250 #[test]
1251 fn test_agent_response_construction() {
1252 let resp = AgentResponse {
1253 content: String::new(),
1254 tool_calls: vec![],
1255 iterations: 3,
1256 usage: TokenUsage::default(),
1257 };
1258 assert!(resp.content.is_empty());
1259 assert!(resp.tool_calls.is_empty());
1260 assert_eq!(resp.iterations, 3);
1261 }
1262
1263 #[test]
1266 fn test_truncate_small_result_unchanged() {
1267 let val = json!({"success": true, "output": "hello"});
1268 let result = truncate_tool_result(val.clone(), 8000);
1269 assert_eq!(result, val);
1270 }
1271
1272 #[test]
1273 fn test_truncate_large_string_value() {
1274 let big = "x".repeat(10000);
1275 let val = json!({"stdout": big, "success": true});
1276 let result = truncate_tool_result(val, 2000);
1277 let stdout = result["stdout"].as_str().unwrap();
1278 assert!(stdout.len() < 10000, "Should be truncated");
1279 assert!(stdout.contains("truncated"), "Should indicate truncation");
1280 }
1281
1282 #[test]
1283 fn test_truncate_preserves_valid_json() {
1284 let big = "x".repeat(20000);
1285 let val = json!({"data": big, "meta": "keep"});
1286 let result = truncate_tool_result(val, 5000);
1287 let serialized = serde_json::to_string(&result).unwrap();
1289 let _reparsed: Value = serde_json::from_str(&serialized).unwrap();
1290 assert_eq!(result["meta"], "keep");
1292 }
1293
1294 #[test]
1295 fn test_truncate_bare_string() {
1296 let big = json!("x".repeat(10000));
1297 let result = truncate_tool_result(big, 500);
1298 let s = result.as_str().unwrap();
1299 assert!(s.len() <= 600); assert!(s.contains("truncated"));
1301 }
1302
1303 #[test]
1304 fn test_truncate_array() {
1305 let items: Vec<Value> = (0..1000).map(|i| json!(format!("item_{}", i))).collect();
1306 let val = Value::Array(items);
1307 let result = truncate_tool_result(val, 500);
1308 let arr = result.as_array().unwrap();
1309 assert!(arr.len() < 1000, "Array should be truncated");
1310 }
1311
1312 #[test]
1315 fn test_importance_failed_tool_highest() {
1316 let msg = Message {
1317 role: Role::Tool,
1318 content: "error".into(),
1319 tool_calls: vec![],
1320 tool_result: Some(ToolResultMessage {
1321 tool_call_id: "1".into(),
1322 content: json!({"error": "failed"}),
1323 success: false,
1324 }),
1325 };
1326 assert!(PawanAgent::message_importance(&msg) > 0.8, "Failed tools should be high importance");
1327 }
1328
1329 #[test]
1330 fn test_importance_successful_tool_lowest() {
1331 let msg = Message {
1332 role: Role::Tool,
1333 content: "ok".into(),
1334 tool_calls: vec![],
1335 tool_result: Some(ToolResultMessage {
1336 tool_call_id: "1".into(),
1337 content: json!({"success": true}),
1338 success: true,
1339 }),
1340 };
1341 assert!(PawanAgent::message_importance(&msg) < 0.3, "Successful tools should be low importance");
1342 }
1343
1344 #[test]
1345 fn test_importance_user_medium() {
1346 let msg = Message { role: Role::User, content: "hello".into(), tool_calls: vec![], tool_result: None };
1347 let score = PawanAgent::message_importance(&msg);
1348 assert!(score > 0.4 && score < 0.8, "User messages should be medium: {}", score);
1349 }
1350
1351 #[test]
1352 fn test_importance_error_assistant_high() {
1353 let msg = Message { role: Role::Assistant, content: "Error: something failed".into(), tool_calls: vec![], tool_result: None };
1354 assert!(PawanAgent::message_importance(&msg) > 0.7, "Error assistant messages should be high importance");
1355 }
1356
1357 #[test]
1358 fn test_importance_ordering() {
1359 let failed_tool = Message { role: Role::Tool, content: "err".into(), tool_calls: vec![], tool_result: Some(ToolResultMessage { tool_call_id: "1".into(), content: json!({}), success: false }) };
1360 let user = Message { role: Role::User, content: "hi".into(), tool_calls: vec![], tool_result: None };
1361 let ok_tool = Message { role: Role::Tool, content: "ok".into(), tool_calls: vec![], tool_result: Some(ToolResultMessage { tool_call_id: "2".into(), content: json!({}), success: true }) };
1362
1363 let f = PawanAgent::message_importance(&failed_tool);
1364 let u = PawanAgent::message_importance(&user);
1365 let s = PawanAgent::message_importance(&ok_tool);
1366 assert!(f > u && u > s, "Ordering should be: failed({}) > user({}) > success({})", f, u, s);
1367 }
1368}