Skip to main content

pawan/agent/
construction.rs

1//! PawanAgent construction — factory, backend selection, and builder methods.
2
3use super::backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
4use super::backend::LlmBackend;
5use super::PawanAgent;
6use crate::config::{LlmProvider, PawanConfig};
7use crate::credentials;
8use crate::tools::ToolRegistry;
9use crate::{PawanError, Result};
10use std::path::PathBuf;
11
/// Check whether an HTTP(S) endpoint accepts TCP connections within 100 ms.
///
/// Only the `host:port` part of `url` is used; scheme, userinfo, path, query
/// and fragment are ignored. When the port is omitted it defaults to 443 for
/// `https://` URLs and 80 otherwise. IP literals are used directly, the exact
/// host `localhost` is mapped to `127.0.0.1`, and any other hostname is
/// resolved; each resolved address is tried with a 100 ms connect timeout.
///
/// Returns `false` when the URL cannot be parsed, the host cannot be
/// resolved, or no address accepts a connection. A listening socket is
/// enough to return `true` — this is a reachability probe, not a health check.
pub(crate) fn probe_local_endpoint(url: &str) -> bool {
    use std::net::{IpAddr, SocketAddr, TcpStream, ToSocketAddrs};
    use std::time::Duration;

    const TIMEOUT: Duration = Duration::from_millis(100);

    let (host, port) = match split_endpoint_host_port(url) {
        Some(hp) => hp,
        None => return false,
    };

    // Normalise "localhost" → 127.0.0.1 so we don't accidentally try ::1
    // (IPv6) when the listener is bound only to IPv4. Only the exact host is
    // rewritten, never a substring such as "mylocalhost".
    let host = if host.eq_ignore_ascii_case("localhost") {
        "127.0.0.1"
    } else {
        host
    };

    // IP literals need no lookup; other hosts (e.g. a LAN inference box)
    // are resolved instead of being rejected as unparsable.
    let candidates: Vec<SocketAddr> = match host.parse::<IpAddr>() {
        Ok(ip) => vec![SocketAddr::new(ip, port)],
        Err(_) => match (host, port).to_socket_addrs() {
            Ok(addrs) => addrs.collect(),
            Err(_) => return false,
        },
    };

    candidates
        .iter()
        .any(|addr| TcpStream::connect_timeout(addr, TIMEOUT).is_ok())
}

/// Extract `(host, port)` from an `http://` / `https://` URL.
///
/// The port defaults to 443 for `https://` and 80 otherwise. Bracketed IPv6
/// literals (`http://[::1]:8080`) are returned without brackets. Returns
/// `None` for an empty host or a port that is not a valid `u16`.
fn split_endpoint_host_port(url: &str) -> Option<(&str, u16)> {
    let (rest, default_port) = if let Some(r) = url.strip_prefix("https://") {
        (r, 443)
    } else if let Some(r) = url.strip_prefix("http://") {
        (r, 80)
    } else {
        (url, 80)
    };

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Drop any "user:pass@" prefix.
    let hostport = authority.rsplit('@').next().unwrap_or(authority);

    let (host, port): (&str, u16) = if let Some(bracketed) = hostport.strip_prefix('[') {
        let (host, after) = bracketed.split_once(']')?;
        match after.strip_prefix(':') {
            Some(p) => (host, p.parse().ok()?),
            None if after.is_empty() => (host, default_port),
            None => return None,
        }
    } else {
        match hostport.rsplit_once(':') {
            Some((h, p)) => (h, p.parse().ok()?),
            None => (hostport, default_port),
        }
    };

    if host.is_empty() {
        None
    } else {
        Some((host, port))
    }
}
44
45/// Retrieve an API key with fallback chain:
46/// 1. Environment variable
47/// 2. Secure credential store
48/// 3. Return None (caller should prompt user)
49///
50/// If the key is found in the secure store, it's also set as an env var
51/// for subsequent calls.
52pub(crate) fn get_api_key_with_secure_fallback(env_var: &str, key_name: &str) -> Option<String> {
53    // First, check environment variable
54    if let Ok(key) = std::env::var(env_var) {
55        return Some(key);
56    }
57
58    // Second, try secure credential store
59    match credentials::get_api_key(key_name) {
60        Ok(Some(key)) => {
61            // Cache in env var for subsequent calls
62            std::env::set_var(env_var, &key);
63            Some(key)
64        }
65        Ok(None) => None,
66        Err(e) => {
67            tracing::warn!("Failed to retrieve {} from secure store: {}", key_name, e);
68            None
69        }
70    }
71}
72
73/// Prompt user to enter an API key and store it securely.
74///
75/// This function:
76/// 1. Prompts the user to enter the API key
77/// 2. Stores it in the secure credential store
78/// 3. Sets it as an environment variable for the current session
79///
80/// Returns the entered key on success, or None if the user cancels.
81fn prompt_and_store_api_key(env_var: &str, key_name: &str, provider: &str) -> Option<String> {
82    eprintln!("\n🔑 {} API key not found.", provider);
83    eprintln!("You can set it via:");
84    eprintln!("  - Environment variable: export {}=<your-key>", env_var);
85    eprintln!("  - Interactive entry (recommended for security)");
86    eprintln!("\nEnter your {} API key:", provider);
87    eprintln!("  (Your key will be stored securely in the OS credential store)\n");
88
89    // Read input securely (no echo)
90    #[cfg(unix)]
91    let key = {
92        use std::io::{self, Write};
93
94        // Use termios to disable echo on Unix
95        let mut stdout = io::stdout();
96        stdout.flush().ok();
97
98        // Read password without echo
99        rpassword::prompt_password("> ").ok()
100    };
101
102    #[cfg(windows)]
103    let key = {
104        use std::io::{self, Write};
105
106        let mut stdout = io::stdout();
107        stdout.flush().ok();
108
109        // On Windows, use a simple prompt (rpassword handles this)
110        rpassword::prompt_password("> ").ok()
111    };
112
113    #[cfg(not(any(unix, windows)))]
114    let key = {
115        use std::io::{self, BufRead, Write};
116
117        let mut stdout = io::stdout();
118        let mut stdin = io::stdin();
119        stdout.flush().ok();
120        print!("> ");
121        stdout.flush().ok();
122
123        let mut input = String::new();
124        stdin.lock().read_line(&mut input).ok();
125        Some(input.trim().to_string())
126    };
127
128    match key {
129        Some(k) if !k.trim().is_empty() => {
130            let key = k.trim().to_string();
131
132            // Store in secure credential store
133            match credentials::store_api_key(key_name, &key) {
134                Ok(()) => {
135                    tracing::info!("{} API key stored securely", provider);
136                    std::env::set_var(env_var, &key);
137                    Some(key)
138                }
139                Err(e) => {
140                    tracing::warn!("Failed to store key securely: {}. Using session-only.", e);
141                    std::env::set_var(env_var, &key);
142                    Some(key)
143                }
144            }
145        }
146        _ => {
147            eprintln!(
148                "\n⚠️  No key entered. {} will not work until a key is set.",
149                provider
150            );
151            None
152        }
153    }
154}
155
156pub(crate) fn scan_context_file(content: &str, source: &str) -> Result<String> {
157    // Check for suspicious patterns
158    let suspicious = [
159        "IGNORE ALL PREVIOUS",
160        "DISREGARD ALL",
161        "OVERRIDE",
162        "You are now",
163        "Your new role",
164        "IMPORTANT: do not",
165        "<system-directive>",
166        "<role>",
167        "<contract>",
168        // Invisible unicode
169        "\u{200B}",
170        "\u{200C}",
171        "\u{200D}",
172        "\u{FEFF}",
173        "\u{202E}",
174        "\u{2060}",
175        "\u{2061}",
176        "\u{2062}",
177    ];
178
179    let upper = content.to_uppercase();
180    let allow = source.ends_with("AGENTS.md") || source.ends_with("CLAUDE.md");
181
182    for pattern in &suspicious {
183        let hit = if pattern.is_ascii() {
184            upper.contains(&pattern.to_uppercase())
185        } else {
186            content.contains(pattern)
187        };
188
189        if hit {
190            tracing::warn!(source = %source, pattern = %pattern, "prompt injection pattern detected");
191            if allow {
192                continue;
193            }
194            return Err(PawanError::Config(format!(
195                "Suspicious content in {}: contains '{}'",
196                source, pattern
197            )));
198        }
199    }
200    Ok(content.to_string())
201}
202
203/// Load per-turn architecture context from `<workspace_root>/.pawan/arch.md`.
204///
205/// Returns `None` if the file is absent or empty.
206/// Caps content at 2 000 chars to avoid context bloat from large files;
207/// an ellipsis marker is appended when truncation occurs.
208pub(crate) fn load_arch_context(workspace_root: &std::path::Path) -> Result<Option<String>> {
209    let path = workspace_root.join(".pawan").join("arch.md");
210    if !path.exists() {
211        return Ok(None);
212    }
213
214    let bytes = std::fs::read(&path).map_err(PawanError::Io)?;
215    let content = String::from_utf8(bytes).map_err(|_| {
216        PawanError::Config(
217            "Suspicious content in .pawan/arch.md: file is not valid UTF-8 (binary?)".to_string(),
218        )
219    })?;
220
221    if content.trim().is_empty() {
222        return Ok(None);
223    }
224
225    let content = scan_context_file(&content, ".pawan/arch.md")?;
226
227    const MAX_CHARS: usize = 2_000;
228    if content.len() > MAX_CHARS {
229        // Truncate on a char boundary
230        let boundary = content
231            .char_indices()
232            .map(|(i, _)| i)
233            .nth(MAX_CHARS)
234            .unwrap_or(content.len());
235        Ok(Some(format!("{}…(truncated)", &content[..boundary])))
236    } else {
237        Ok(Some(content))
238    }
239}
240
241impl PawanAgent {
242    /// Create a new PawanAgent with auto-selected backend
243    pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
244        let tools = ToolRegistry::with_defaults(workspace_root.clone());
245        let system_prompt = config.get_system_prompt();
246        let backend = Self::create_backend(&config, &system_prompt);
247        let eruka = if config.eruka.enabled {
248            Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
249        } else {
250            None
251        };
252        let (arch_context, arch_context_error) = match load_arch_context(&workspace_root) {
253            Ok(v) => (v, None),
254            Err(e) => (None, Some(e.to_string())),
255        };
256
257        Self {
258            config,
259            tools,
260            history: Vec::new(),
261            workspace_root,
262            backend,
263            context_tokens_estimate: 0,
264            eruka,
265            session_id: uuid::Uuid::new_v4().to_string(),
266            arch_context,
267            arch_context_error,
268            last_tool_call_time: None,
269        }
270    }
271
272    /// Create the appropriate backend based on config.
273    ///
274    /// If `use_ares_backend` is true and the `ares` feature is compiled in,
275    /// delegates to ares-server's LLMClient (unified provider abstraction with
276    /// connection pooling). Otherwise uses pawan's built-in OpenAI-compatible
277    /// backend (the original path).
278    pub(crate) fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
279        // Local-inference-first cost guard: if enabled and the local server
280        // responds within 100 ms, route all traffic there instead of cloud.
281        if config.local_first {
282            let local_url = config
283                .local_endpoint
284                .clone()
285                .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
286            if probe_local_endpoint(&local_url) {
287                tracing::info!(
288                    url = %local_url,
289                    model = %config.model,
290                    "local_first: local server reachable, using local inference"
291                );
292                return Box::new(OpenAiCompatBackend::new(
293                    super::backend::openai_compat::OpenAiCompatConfig {
294                        api_url: local_url,
295                        api_key: None,
296                        model: config.model.clone(),
297                        temperature: config.temperature,
298                        top_p: config.top_p,
299                        max_tokens: config.max_tokens,
300                        system_prompt: system_prompt.to_string(),
301                        use_thinking: false,
302                        max_retries: config.max_retries,
303                        fallback_models: Vec::new(),
304                        cloud: None,
305                    },
306                ));
307            }
308            tracing::info!(
309                url = %local_url,
310                "local_first: local server unreachable, falling back to cloud provider"
311            );
312        }
313
314        // Try ares backend first if requested
315        if config.use_ares_backend {
316            if let Some(backend) = Self::try_create_ares_backend(config, system_prompt) {
317                return backend;
318            }
319            tracing::warn!(
320                "use_ares_backend=true but ares backend creation failed; \
321                 falling back to pawan's native backend"
322            );
323        }
324
325        match config.provider {
326            LlmProvider::Nvidia | LlmProvider::OpenAI | LlmProvider::Mlx => {
327                let (api_url, api_key) = match config.provider {
328                    LlmProvider::Nvidia => {
329                        let url = std::env::var("NVIDIA_API_URL")
330                            .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
331
332                        // Try to get key from env or secure store
333                        let key =
334                            get_api_key_with_secure_fallback("NVIDIA_API_KEY", "nvidia_api_key");
335
336                        // If no key found, prompt user (skip interactive prompts in unit tests)
337                        let key = if key.is_some() {
338                            key
339                        } else if cfg!(test) {
340                            Some("pawan-test-dummy-key".to_string())
341                        } else {
342                            prompt_and_store_api_key("NVIDIA_API_KEY", "nvidia_api_key", "NVIDIA")
343                        };
344
345                        if key.is_none() {
346                            tracing::warn!("NVIDIA_API_KEY not set. Model calls will fail until a key is provided.");
347                        }
348                        (url, key)
349                    }
350                    LlmProvider::OpenAI => {
351                        let url = config
352                            .base_url
353                            .clone()
354                            .or_else(|| std::env::var("OPENAI_API_URL").ok())
355                            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
356
357                        let key =
358                            get_api_key_with_secure_fallback("OPENAI_API_KEY", "openai_api_key");
359                        let key = if key.is_some() {
360                            key
361                        } else if cfg!(test) {
362                            Some("pawan-test-dummy-key".to_string())
363                        } else {
364                            prompt_and_store_api_key("OPENAI_API_KEY", "openai_api_key", "OpenAI")
365                        };
366
367                        (url, key)
368                    }
369                    LlmProvider::Mlx => {
370                        // MLX LM server — Apple Silicon native, always local
371                        let url = config
372                            .base_url
373                            .clone()
374                            .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
375                        tracing::info!(url = %url, "Using MLX LM server (Apple Silicon native)");
376                        (url, None) // mlx_lm.server requires no API key
377                    }
378                    _ => unreachable!(),
379                };
380
381                // Build cloud fallback if configured
382                let cloud = config.cloud.as_ref().map(|c| {
383                    let (cloud_url, cloud_key) = match c.provider {
384                        LlmProvider::Nvidia => {
385                            let url = std::env::var("NVIDIA_API_URL")
386                                .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
387                            let key = get_api_key_with_secure_fallback(
388                                "NVIDIA_API_KEY",
389                                "nvidia_api_key",
390                            );
391                            (url, key)
392                        }
393                        LlmProvider::OpenAI => {
394                            let url = std::env::var("OPENAI_API_URL")
395                                .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
396                            let key = get_api_key_with_secure_fallback(
397                                "OPENAI_API_KEY",
398                                "openai_api_key",
399                            );
400                            (url, key)
401                        }
402                        LlmProvider::Mlx => ("http://localhost:8080/v1".to_string(), None),
403                        _ => {
404                            tracing::warn!(
405                                "Cloud fallback only supports nvidia/openai/mlx providers"
406                            );
407                            ("https://integrate.api.nvidia.com/v1".to_string(), None)
408                        }
409                    };
410                    super::backend::openai_compat::CloudFallback {
411                        api_url: cloud_url,
412                        api_key: cloud_key,
413                        model: c.model.clone(),
414                        fallback_models: c.fallback_models.clone(),
415                    }
416                });
417
418                Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
419                    api_url,
420                    api_key,
421                    model: config.model.clone(),
422                    temperature: config.temperature,
423                    top_p: config.top_p,
424                    max_tokens: config.max_tokens,
425                    system_prompt: system_prompt.to_string(),
426                    // Enforce thinking budget: if set, disable thinking entirely
427                    // and give all tokens to action output
428                    use_thinking: config.thinking_budget == 0 && config.use_thinking_mode(),
429                    max_retries: config.max_retries,
430                    fallback_models: config.fallback_models.clone(),
431                    cloud,
432                }))
433            }
434            LlmProvider::Ollama => {
435                let url = std::env::var("OLLAMA_URL")
436                    .unwrap_or_else(|_| "http://localhost:11434".to_string());
437
438                Box::new(super::backend::ollama::OllamaBackend::new(
439                    url,
440                    config.model.clone(),
441                    config.temperature,
442                    system_prompt.to_string(),
443                ))
444            }
445        }
446    }
447
448    /// Try to construct an ares-backed LLM backend from pawan config.
449    /// Returns `None` if the provider isn't supported by ares or required
450    /// credentials are missing — the caller should fall back to pawan's
451    /// native backend.
452    fn try_create_ares_backend(
453        config: &PawanConfig,
454        system_prompt: &str,
455    ) -> Option<Box<dyn LlmBackend>> {
456        use ares::llm::client::{ModelParams, Provider};
457
458        // Map pawan LlmProvider → ares Provider variants.
459        // ares supports: OpenAI (with custom base_url), Ollama, LlamaCpp, Anthropic.
460        // Pawan's Nvidia/OpenAI/Mlx all use OpenAI-compatible endpoints, so they
461        // all map to ares Provider::OpenAI with different base URLs.
462        let params = ModelParams {
463            temperature: Some(config.temperature),
464            max_tokens: Some(config.max_tokens as u32),
465            top_p: Some(config.top_p),
466            frequency_penalty: None,
467            presence_penalty: None,
468        };
469
470        let provider = match config.provider {
471            LlmProvider::Nvidia => {
472                let api_base = std::env::var("NVIDIA_API_URL")
473                    .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
474                let api_key = std::env::var("NVIDIA_API_KEY").ok()?;
475                Provider::OpenAI {
476                    api_key,
477                    api_base,
478                    model: config.model.clone(),
479                    params,
480                }
481            }
482            LlmProvider::OpenAI => {
483                let api_base = config
484                    .base_url
485                    .clone()
486                    .or_else(|| std::env::var("OPENAI_API_URL").ok())
487                    .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
488                let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
489                Provider::OpenAI {
490                    api_key,
491                    api_base,
492                    model: config.model.clone(),
493                    params,
494                }
495            }
496            LlmProvider::Mlx => {
497                // MLX LM server is OpenAI-compatible, no API key needed
498                let api_base = config
499                    .base_url
500                    .clone()
501                    .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
502                Provider::OpenAI {
503                    api_key: String::new(),
504                    api_base,
505                    model: config.model.clone(),
506                    params,
507                }
508            }
509            LlmProvider::Ollama => {
510                // Ares Ollama client is async-constructed (async with_params),
511                // which doesn't fit pawan's sync PawanAgent::new path.
512                // Fall back to pawan's native OllamaBackend for now.
513                return None;
514            }
515        };
516
517        // OpenAI variants construct synchronously — we skip the async
518        // Provider::create_client() entirely for sync construction.
519        let client: Box<dyn ares::llm::LLMClient> = match provider {
520            Provider::OpenAI {
521                api_key,
522                api_base,
523                model,
524                params,
525            } => Box::new(ares::llm::openai::OpenAIClient::with_params(
526                api_key, api_base, model, params,
527            )),
528            _ => return None,
529        };
530
531        tracing::info!(
532            provider = ?config.provider,
533            model = %config.model,
534            "Using ares-backed LLM backend"
535        );
536
537        Some(Box::new(super::backend::ares_backend::AresBackend::new(
538            client,
539            system_prompt.to_string(),
540        )))
541    }
542
543    /// Create with a specific tool registry
544    pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
545        self.tools = tools;
546        self
547    }
548
549    /// Get mutable access to the tool registry (for registering MCP tools)
550    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
551        &mut self.tools
552    }
553
554    /// Create with a custom backend
555    pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
556        self.backend = backend;
557        self
558    }
559}