1pub mod backend;
10mod preflight;
11pub mod session;
12pub mod git_session;
13
14use crate::config::{LlmProvider, PawanConfig};
15use crate::tools::{ToolDefinition, ToolRegistry};
16use crate::{PawanError, Result};
17use backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
18use backend::LlmBackend;
19use serde::{Deserialize, Serialize};
20use serde_json::{json, Value};
21use std::path::PathBuf;
22
/// A single message in the conversation history exchanged with the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message (system / user / assistant / tool).
    pub role: Role,
    /// Textual content of the message.
    pub content: String,
    /// Tool invocations requested by the assistant; defaults to empty when
    /// the field is absent in the serialized payload.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallRequest>,
    /// Result attached to `Role::Tool` messages; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMessage>,
}
37
/// The author of a [`Message`]; serialized in lowercase
/// ("system", "user", "assistant", "tool") to match OpenAI-style chat APIs.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
47
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Provider-assigned call id, echoed back in the matching tool result.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as raw JSON.
    pub arguments: Value,
}
58
/// The outcome of one tool call, attached to a `Role::Tool` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the `ToolCallRequest` this result answers.
    pub tool_call_id: String,
    /// Tool output (or an error object) as JSON.
    pub content: Value,
    /// Whether the tool ran without error.
    pub success: bool,
}
69
/// Full record of one executed tool call, kept for reporting and callbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Id copied from the originating `ToolCallRequest`.
    pub id: String,
    /// Tool name that was invoked (or denied).
    pub name: String,
    /// Arguments the tool was called with.
    pub arguments: Value,
    /// Result JSON (possibly truncated before being stored).
    pub result: Value,
    /// Whether execution succeeded.
    pub success: bool,
    /// Wall-clock execution time in milliseconds (0 for denied calls).
    pub duration_ms: u64,
}
86
/// Token accounting reported by the backend, accumulated across LLM calls.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
94
95#[derive(Debug, Clone)]
97pub struct LLMResponse {
98 pub content: String,
100 pub tool_calls: Vec<ToolCallRequest>,
102 pub finish_reason: String,
104 pub usage: Option<TokenUsage>,
106}
107
/// Final result of a full agent run (possibly many LLM/tool iterations).
#[derive(Debug)]
pub struct AgentResponse {
    /// The model's final answer, with thinking spans stripped.
    pub content: String,
    /// Every tool call executed during the run, in order.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Number of LLM iterations consumed.
    pub iterations: usize,
    /// Token usage accumulated across all iterations.
    pub usage: TokenUsage,
}
120
/// Callback invoked with each streamed token as it arrives from the backend.
pub type TokenCallback = Box<dyn Fn(&str) + Send + Sync>;

/// Callback invoked after a tool call completes, with its full record.
pub type ToolCallback = Box<dyn Fn(&ToolCallRecord) + Send + Sync>;

/// Callback invoked with the tool name just before a tool starts executing.
pub type ToolStartCallback = Box<dyn Fn(&str) + Send + Sync>;
/// The agent: owns the conversation history, the tool registry, and the LLM
/// backend, and drives tool-calling iterations until the model stops.
pub struct PawanAgent {
    config: PawanConfig,
    tools: ToolRegistry,
    history: Vec<Message>,
    workspace_root: PathBuf,
    backend: Box<dyn LlmBackend>,

    // Rough context size: total history chars / 4, recomputed each iteration.
    context_tokens_estimate: usize,

    // Optional Eruka memory bridge; `None` when disabled in config.
    eruka: Option<crate::eruka_bridge::ErukaClient>,
}
149
150impl PawanAgent {
151 pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
153 let tools = ToolRegistry::with_defaults(workspace_root.clone());
154 let system_prompt = config.get_system_prompt();
155 let backend = Self::create_backend(&config, &system_prompt);
156 let eruka = if config.eruka.enabled {
157 Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
158 } else {
159 None
160 };
161
162 Self {
163 config,
164 tools,
165 history: Vec::new(),
166 workspace_root,
167 backend,
168 context_tokens_estimate: 0,
169 eruka,
170 }
171 }
172
173 fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
175 match config.provider {
176 LlmProvider::Nvidia | LlmProvider::OpenAI => {
177 let (api_url, api_key) = match config.provider {
178 LlmProvider::Nvidia => {
179 let url = std::env::var("NVIDIA_API_URL")
180 .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
181 let key = std::env::var("NVIDIA_API_KEY").ok();
182 if key.is_none() {
183 tracing::warn!("NVIDIA_API_KEY not set. Add it to .env or export it.");
184 }
185 (url, key)
186 },
187 LlmProvider::OpenAI => {
188 let url = std::env::var("OPENAI_API_URL")
189 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
190 let key = std::env::var("OPENAI_API_KEY").ok();
191 if key.is_none() {
192 tracing::warn!("OPENAI_API_KEY not set. Add it to .env or export it.");
193 }
194 (url, key)
195 },
196 _ => unreachable!(),
197 };
198
199 let cloud = config.cloud.as_ref().map(|c| {
201 let (cloud_url, cloud_key) = match c.provider {
202 LlmProvider::Nvidia => {
203 let url = std::env::var("NVIDIA_API_URL")
204 .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
205 let key = std::env::var("NVIDIA_API_KEY").ok();
206 (url, key)
207 },
208 LlmProvider::OpenAI => {
209 let url = std::env::var("OPENAI_API_URL")
210 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
211 let key = std::env::var("OPENAI_API_KEY").ok();
212 (url, key)
213 },
214 _ => {
215 tracing::warn!("Cloud fallback only supports nvidia/openai providers");
216 ("https://integrate.api.nvidia.com/v1".to_string(), None)
217 }
218 };
219 backend::openai_compat::CloudFallback {
220 api_url: cloud_url,
221 api_key: cloud_key,
222 model: c.model.clone(),
223 fallback_models: c.fallback_models.clone(),
224 }
225 });
226
227 Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
228 api_url,
229 api_key,
230 model: config.model.clone(),
231 temperature: config.temperature,
232 top_p: config.top_p,
233 max_tokens: config.max_tokens,
234 system_prompt: system_prompt.to_string(),
235 use_thinking: config.use_thinking_mode(),
236 max_retries: config.max_retries,
237 fallback_models: config.fallback_models.clone(),
238 cloud,
239 }))
240 }
241 LlmProvider::Ollama => {
242 let url = std::env::var("OLLAMA_URL")
243 .unwrap_or_else(|_| "http://localhost:11434".to_string());
244
245 Box::new(backend::ollama::OllamaBackend::new(
246 url,
247 config.model.clone(),
248 config.temperature,
249 system_prompt.to_string(),
250 ))
251 }
252 }
253 }
254
    /// Builder-style override of the tool registry (replaces the defaults).
    pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
        self.tools = tools;
        self
    }
260
    /// Mutable access to the tool registry, e.g. to register extra tools.
    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
        &mut self.tools
    }
265
    /// Builder-style override of the LLM backend (useful for tests/mocks).
    pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
        self.backend = backend;
        self
    }
271
    /// Read-only view of the conversation history.
    pub fn history(&self) -> &[Message] {
        &self.history
    }
276
277 pub fn save_session(&self) -> Result<String> {
279 let mut session = session::Session::new(&self.config.model);
280 session.messages = self.history.clone();
281 session.total_tokens = self.context_tokens_estimate as u64;
282 session.save()?;
283 Ok(session.id)
284 }
285
286 pub fn resume_session(&mut self, session_id: &str) -> Result<()> {
288 let session = session::Session::load(session_id)?;
289 self.history = session.messages;
290 self.context_tokens_estimate = session.total_tokens as usize;
291 Ok(())
292 }
293
    /// The agent's configuration.
    pub fn config(&self) -> &PawanConfig {
        &self.config
    }
298
    /// Drop all conversation history. Note: does not reset
    /// `context_tokens_estimate`; the estimate is recomputed on the next run.
    pub fn clear_history(&mut self) {
        self.history.clear();
    }
303 fn prune_history(&mut self) {
307 let len = self.history.len();
308 if len <= 5 {
309 return; }
311
312 let keep_end = 4;
313 let start = 1; let end = len - keep_end;
315 let pruned_count = end - start;
316
317 let mut summary = String::new();
319 for msg in &self.history[start..end] {
320 let chunk = if msg.content.len() > 200 {
321 &msg.content[..200]
322 } else {
323 &msg.content
324 };
325 summary.push_str(chunk);
326 summary.push('\n');
327 if summary.len() > 2000 {
328 summary.truncate(2000);
329 break;
330 }
331 }
332
333 let summary_msg = Message {
334 role: Role::System,
335 content: format!("Previous conversation summary (pruned): {}", summary),
336 tool_calls: vec![],
337 tool_result: None,
338 };
339
340 let first = self.history[0].clone();
342 let tail: Vec<Message> = self.history[len - keep_end..].to_vec();
343
344 self.history.clear();
345 self.history.push(first);
346 self.history.push(summary_msg);
347 self.history.extend(tail);
348
349 tracing::info!(pruned = pruned_count, context_estimate = self.context_tokens_estimate, "Pruned messages from history");
350 }
351
    /// Append a message to the conversation history without calling the LLM.
    pub fn add_message(&mut self, message: Message) {
        self.history.push(message);
    }
356
    /// Definitions of all registered tools, as advertised to the LLM.
    pub fn get_tool_definitions(&self) -> Vec<ToolDefinition> {
        self.tools.get_definitions()
    }
361
    /// Run the agent loop on `user_prompt` with no streaming or tool callbacks.
    pub async fn execute(&mut self, user_prompt: &str) -> Result<AgentResponse> {
        self.execute_with_callbacks(user_prompt, None, None, None)
            .await
    }
367
    /// Core agent loop: push the user prompt, then repeatedly call the LLM,
    /// executing any requested tools and feeding results back, until the model
    /// answers without tool calls or `max_tool_iterations` is exceeded.
    ///
    /// Callbacks: `on_token` receives streamed output tokens, `on_tool_start`
    /// fires with the tool name before execution, and `on_tool` fires with the
    /// completed [`ToolCallRecord`] afterwards (including denied calls).
    ///
    /// Errors: returns `PawanError::Agent` when the iteration cap is hit.
    /// Permanent LLM failures are converted into an error-text response rather
    /// than an `Err`, so the caller still gets an `AgentResponse`.
    pub async fn execute_with_callbacks(
        &mut self,
        user_prompt: &str,
        on_token: Option<TokenCallback>,
        on_tool: Option<ToolCallback>,
        on_tool_start: Option<ToolStartCallback>,
    ) -> Result<AgentResponse> {
        // Best-effort Eruka memory injection; failures are logged, never fatal.
        if let Some(eruka) = &self.eruka {
            if let Err(e) = eruka.inject_core_memory(&mut self.history).await {
                tracing::warn!("Eruka memory injection failed (non-fatal): {}", e);
            }
        }

        self.history.push(Message {
            role: Role::User,
            content: user_prompt.to_string(),
            tool_calls: vec![],
            tool_result: None,
        });

        let mut all_tool_calls = Vec::new();
        let mut total_usage = TokenUsage::default();
        let mut iterations = 0;
        let max_iterations = self.config.max_tool_iterations;

        loop {
            iterations += 1;
            if iterations > max_iterations {
                return Err(PawanError::Agent(format!(
                    "Max tool iterations ({}) exceeded",
                    max_iterations
                )));
            }
            // Crude token estimate: ~4 chars per token; prune when over budget.
            self.context_tokens_estimate = self.history.iter().map(|m| m.content.len()).sum::<usize>() / 4;
            if self.context_tokens_estimate > self.config.max_context_tokens {
                self.prune_history();
            }

            let tool_defs = self.tools.get_definitions();

            // ---- LLM call with retry on transient failures ------------------
            let response = {
                #[allow(unused_assignments)]
                let mut last_err = None;
                let max_llm_retries = 3;
                let mut attempt = 0;
                loop {
                    attempt += 1;
                    match self.backend.generate(&self.history, &tool_defs, on_token.as_ref()).await {
                        Ok(resp) => break resp,
                        Err(e) => {
                            // Transience is detected by substring-matching the
                            // error text (timeouts, resets, 5xx, rate limits).
                            let err_str = e.to_string();
                            let is_transient = err_str.contains("timeout")
                                || err_str.contains("connection")
                                || err_str.contains("429")
                                || err_str.contains("500")
                                || err_str.contains("502")
                                || err_str.contains("503")
                                || err_str.contains("504")
                                || err_str.contains("reset")
                                || err_str.contains("broken pipe");

                            if is_transient && attempt <= max_llm_retries {
                                // Exponential backoff: 2^attempt seconds.
                                let delay = std::time::Duration::from_secs(2u64.pow(attempt as u32));
                                tracing::warn!(
                                    attempt = attempt,
                                    delay_secs = delay.as_secs(),
                                    error = err_str.as_str(),
                                    "LLM call failed (transient) — retrying"
                                );
                                tokio::time::sleep(delay).await;

                                // Errors mentioning context/tokens may indicate
                                // an overflow — shrink history before retrying.
                                if err_str.contains("context") || err_str.contains("token") {
                                    tracing::info!("Pruning history before retry (possible context overflow)");
                                    self.prune_history();
                                }
                                continue;
                            }

                            // Permanent failure: surface the error as assistant
                            // content (finish_reason "error") instead of
                            // aborting the whole call.
                            last_err = Some(e);
                            break {
                                tracing::error!(
                                    attempt = attempt,
                                    error = last_err.as_ref().map(|e| e.to_string()).unwrap_or_default().as_str(),
                                    "LLM call failed permanently — returning error as content"
                                );
                                LLMResponse {
                                    content: format!(
                                        "LLM error after {} attempts: {}. The task could not be completed.",
                                        attempt,
                                        last_err.as_ref().map(|e| e.to_string()).unwrap_or_default()
                                    ),
                                    tool_calls: vec![],
                                    finish_reason: "error".to_string(),
                                    usage: None,
                                }
                            };
                        }
                    }
                }
            };

            // Accumulate token usage across iterations when reported.
            if let Some(ref usage) = response.usage {
                total_usage.prompt_tokens += usage.prompt_tokens;
                total_usage.completion_tokens += usage.completion_tokens;
                total_usage.total_tokens += usage.total_tokens;
            }

            // ---- Strip <think>…</think> spans from the visible content ------
            // NOTE(review): indices are found in a lowercased copy but applied
            // to the original string; `to_lowercase` can change byte lengths
            // for some non-ASCII characters, which could mis-slice or panic —
            // confirm the tags are always ASCII-cased or search `s` directly.
            let clean_content = {
                let mut s = response.content.clone();
                loop {
                    let lower = s.to_lowercase();
                    let open = lower.find("<think>");
                    let close = lower.find("</think>");
                    match (open, close) {
                        // Only remove well-formed open/close pairs
                        // ("</think>" is 8 bytes, hence the j + 8 offsets).
                        (Some(i), Some(j)) if j > i => {
                            let before = s[..i].trim_end().to_string();
                            let after = if s.len() > j + 8 { s[j + 8..].trim_start().to_string() } else { String::new() };
                            s = if before.is_empty() { after } else if after.is_empty() { before } else { format!("{}\n{}", before, after) };
                        }
                        _ => break,
                    }
                }
                s
            };

            // ---- No tool calls: either nudge the model or finish ------------
            if response.tool_calls.is_empty() {
                let has_tools = !tool_defs.is_empty();
                // Heuristic: a long first answer, or one opening with planning
                // phrases, suggests the model described work instead of doing it.
                let lower = clean_content.to_lowercase();
                let planning_prefix = lower.starts_with("let me")
                    || lower.starts_with("i'll help")
                    || lower.starts_with("i will help")
                    || lower.starts_with("sure, i")
                    || lower.starts_with("okay, i");
                let looks_like_planning = clean_content.len() > 200 || (planning_prefix && clean_content.len() > 50);
                // Nudge only once (iteration 1), and never on error responses.
                if has_tools && looks_like_planning && iterations == 1 && iterations < max_iterations && response.finish_reason != "error" {
                    tracing::warn!(
                        "No tool calls at iteration {} (content: {}B) — nudging model to use tools",
                        iterations, clean_content.len()
                    );
                    self.history.push(Message {
                        role: Role::Assistant,
                        content: clean_content.clone(),
                        tool_calls: vec![],
                        tool_result: None,
                    });
                    self.history.push(Message {
                        role: Role::User,
                        content: "You must use tools to complete this task. Do NOT just describe what you would do — actually call the tools. Start with bash or read_file.".to_string(),
                        tool_calls: vec![],
                        tool_result: None,
                    });
                    continue;
                }

                // Loop-breaking: if the model repeats its previous non-empty
                // assistant answer verbatim, inject a correction and retry.
                if iterations > 1 {
                    let prev_assistant = self.history.iter().rev()
                        .find(|m| m.role == Role::Assistant && !m.content.is_empty());
                    if let Some(prev) = prev_assistant {
                        if prev.content.trim() == clean_content.trim() && iterations < max_iterations {
                            tracing::warn!("Repeated response detected at iteration {} — injecting correction", iterations);
                            self.history.push(Message {
                                role: Role::Assistant,
                                content: clean_content.clone(),
                                tool_calls: vec![],
                                tool_result: None,
                            });
                            self.history.push(Message {
                                role: Role::User,
                                content: "You gave the same response as before. Try a different approach. Use anchor_text in edit_file_lines, or use insert_after, or use bash with sed.".to_string(),
                                tool_calls: vec![],
                                tool_result: None,
                            });
                            continue;
                        }
                    }
                }

                // Final answer: record it and return the accumulated stats.
                self.history.push(Message {
                    role: Role::Assistant,
                    content: clean_content.clone(),
                    tool_calls: vec![],
                    tool_result: None,
                });

                return Ok(AgentResponse {
                    content: clean_content,
                    tool_calls: all_tool_calls,
                    iterations,
                    usage: total_usage,
                });
            }

            // ---- Tool calls requested: record the raw assistant turn --------
            self.history.push(Message {
                role: Role::Assistant,
                content: response.content.clone(),
                tool_calls: response.tool_calls.clone(),
                tool_result: None,
            });

            for tool_call in &response.tool_calls {
                // Permission policy: a denied tool is short-circuited with an
                // error record instead of being executed.
                if let Some(crate::config::ToolPermission::Deny) =
                    self.config.permissions.get(&tool_call.name)
                {
                    let record = ToolCallRecord {
                        id: tool_call.id.clone(),
                        name: tool_call.name.clone(),
                        arguments: tool_call.arguments.clone(),
                        result: json!({"error": "Tool denied by permission policy"}),
                        success: false,
                        duration_ms: 0,
                    };

                    if let Some(ref callback) = on_tool {
                        callback(&record);
                    }
                    all_tool_calls.push(record);

                    self.history.push(Message {
                        role: Role::Tool,
                        content: "{\"error\": \"Tool denied by permission policy\"}".to_string(),
                        tool_calls: vec![],
                        tool_result: Some(ToolResultMessage {
                            tool_call_id: tool_call.id.clone(),
                            content: json!({"error": "Tool denied by permission policy"}),
                            success: false,
                        }),
                    });
                    continue;
                }

                if let Some(ref callback) = on_tool_start {
                    callback(&tool_call.name);
                }

                // Debug trace with the arguments capped at 200 chars.
                tracing::debug!(
                    tool = tool_call.name.as_str(),
                    args_len = serde_json::to_string(&tool_call.arguments).unwrap_or_default().len(),
                    "Tool call: {}({})",
                    tool_call.name,
                    serde_json::to_string(&tool_call.arguments)
                        .unwrap_or_default()
                        .chars()
                        .take(200)
                        .collect::<String>()
                );

                let start = std::time::Instant::now();

                // Execute with a timeout: bash gets the configured limit,
                // every other tool gets a fixed 30s budget.
                let result = {
                    let tool_future = self.tools.execute(&tool_call.name, tool_call.arguments.clone());
                    let timeout_dur = if tool_call.name == "bash" {
                        std::time::Duration::from_secs(self.config.bash_timeout_secs)
                    } else {
                        std::time::Duration::from_secs(30)
                    };
                    match tokio::time::timeout(timeout_dur, tool_future).await {
                        Ok(inner) => inner,
                        Err(_) => Err(PawanError::Tool(format!(
                            "Tool '{}' timed out after {}s", tool_call.name, timeout_dur.as_secs()
                        ))),
                    }
                };
                let duration_ms = start.elapsed().as_millis() as u64;

                // Tool failures are reported back to the model as an error
                // object with a hint, not propagated as Err.
                let (result_value, success) = match result {
                    Ok(v) => (v, true),
                    Err(e) => {
                        tracing::warn!(tool = tool_call.name.as_str(), error = %e, "Tool execution failed");
                        (json!({"error": e.to_string(), "tool": tool_call.name, "hint": "Try a different approach or tool"}), false)
                    }
                };

                // Cap oversized results so they do not blow up the context.
                // A truncated JSON string rarely re-parses, so this usually
                // falls back to wrapping the prefix in a `{"content": ...}`
                // object with a truncation marker.
                let max_result_chars = self.config.max_result_chars;
                let result_value = {
                    let result_str = serde_json::to_string(&result_value).unwrap_or_default();
                    if result_str.len() > max_result_chars {
                        let truncated: String = result_str.chars().take(max_result_chars).collect();
                        let truncated = truncated.as_str();
                        serde_json::from_str(truncated).unwrap_or_else(|_| {
                            json!({"content": format!("{}...[truncated from {} chars]", truncated, result_str.len())})
                        })
                    } else {
                        result_value
                    }
                };


                let record = ToolCallRecord {
                    id: tool_call.id.clone(),
                    name: tool_call.name.clone(),
                    arguments: tool_call.arguments.clone(),
                    result: result_value.clone(),
                    success,
                    duration_ms,
                };

                if let Some(ref callback) = on_tool {
                    callback(&record);
                }

                all_tool_calls.push(record);

                // Feed the (possibly truncated) result back to the model.
                self.history.push(Message {
                    role: Role::Tool,
                    content: serde_json::to_string(&result_value).unwrap_or_default(),
                    tool_calls: vec![],
                    tool_result: Some(ToolResultMessage {
                        tool_call_id: tool_call.id.clone(),
                        content: result_value,
                        success,
                    }),
                });
            }
        }
    }
702
    /// Run the self-healing flow: collect compiler diagnostics and failing
    /// tests from the workspace, build a markdown prompt describing them, and
    /// ask the agent to fix each issue (verifying with `cargo check`).
    ///
    /// NOTE: the string literals below intentionally embed raw newlines so the
    /// prompt renders as markdown sections — do not reformat them.
    pub async fn heal(&mut self) -> Result<AgentResponse> {
        let healer = crate::healing::Healer::new(
            self.workspace_root.clone(),
            self.config.healing.clone(),
        );

        let diagnostics = healer.get_diagnostics().await?;
        let failed_tests = healer.get_failed_tests().await?;

        let mut prompt = format!(
            "I need you to heal this Rust project at: {}

",
            self.workspace_root.display()
        );

        // Compiler diagnostics section, only when there are any.
        if !diagnostics.is_empty() {
            prompt.push_str(&format!(
                "## Compilation Issues ({} found)
{}
",
                diagnostics.len(),
                healer.format_diagnostics_for_prompt(&diagnostics)
            ));
        }

        // Failing tests section, only when there are any.
        if !failed_tests.is_empty() {
            prompt.push_str(&format!(
                "## Failed Tests ({} found)
{}
",
                failed_tests.len(),
                healer.format_tests_for_prompt(&failed_tests)
            ));
        }

        // Nothing to fix: still run the agent so it can verify the clean state.
        if diagnostics.is_empty() && failed_tests.is_empty() {
            prompt.push_str("No issues found! Run cargo check and cargo test to verify.
");
        }

        prompt.push_str("
Fix each issue one at a time. Verify with cargo check after each fix.");

        self.execute(&prompt).await
    }
750 pub async fn heal_with_retries(&mut self, max_attempts: usize) -> Result<AgentResponse> {
752 let mut last_response = self.heal().await?;
753
754 for attempt in 1..max_attempts {
755 let fixer = crate::healing::CompilerFixer::new(self.workspace_root.clone());
756 let remaining = fixer.check().await?;
757 let errors: Vec<_> = remaining.iter().filter(|d| d.kind == crate::healing::DiagnosticKind::Error).collect();
758
759 if errors.is_empty() {
760 tracing::info!(attempts = attempt, "Healing complete");
761 return Ok(last_response);
762 }
763
764 tracing::warn!(errors = errors.len(), attempt = attempt, "Errors remain after heal attempt, retrying");
765 last_response = self.heal().await?;
766 }
767
768 tracing::info!(attempts = max_attempts, "Healing finished (may still have errors)");
769 Ok(last_response)
770 }
    /// Run an arbitrary coding task described in natural language against the
    /// workspace (explore, edit, `cargo check`, then relevant tests).
    pub async fn task(&mut self, task_description: &str) -> Result<AgentResponse> {
        // Raw string: the prompt's leading-column lines are intentional.
        let prompt = format!(
            r#"I need you to complete the following coding task:

{}

The workspace is at: {}

Please:
1. First explore the codebase to understand the relevant code
2. Make the necessary changes
3. Verify the changes compile with `cargo check`
4. Run relevant tests if applicable

Explain your changes as you go."#,
            task_description,
            self.workspace_root.display()
        );

        self.execute(&prompt).await
    }
793
794 pub async fn generate_commit_message(&mut self) -> Result<String> {
796 let prompt = r#"Please:
7971. Run `git status` to see what files are changed
7982. Run `git diff --cached` to see staged changes (or `git diff` for unstaged)
7993. Generate a concise, descriptive commit message following conventional commits format
800
801Only output the suggested commit message, nothing else."#;
802
803 let response = self.execute(prompt).await?;
804 Ok(response.content)
805 }
806}
807
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializing a plain user message keeps both the lowercase role and the
    /// content in the JSON output.
    #[test]
    fn test_message_serialization() {
        let msg = Message {
            role: Role::User,
            content: "Hello".to_string(),
            tool_calls: Vec::new(),
            tool_result: None,
        };

        let encoded = serde_json::to_string(&msg).unwrap();
        for expected in ["user", "Hello"] {
            assert!(encoded.contains(expected));
        }
    }

    /// Serializing a tool call request includes the tool name and arguments.
    #[test]
    fn test_tool_call_request() {
        let request = ToolCallRequest {
            id: String::from("123"),
            name: String::from("read_file"),
            arguments: json!({"path": "test.txt"}),
        };

        let encoded = serde_json::to_string(&request).unwrap();
        for expected in ["read_file", "test.txt"] {
            assert!(encoded.contains(expected));
        }
    }
}