1use super::backend::openai_compat::{OpenAiCompatBackend, OpenAiCompatConfig};
4use super::backend::LlmBackend;
5use super::PawanAgent;
6use crate::config::{LlmProvider, PawanConfig};
7use crate::credentials;
8use crate::tools::ToolRegistry;
9use crate::{PawanError, Result};
10use std::path::PathBuf;
11
/// Quick TCP reachability check for the server behind `url`.
///
/// Only the authority (`host[:port]`) of `url` is used: an optional leading
/// `http://` / `https://` scheme is stripped and everything from the first
/// `/`, `?` or `#` onwards is discarded. When no port is given, 443 is assumed
/// for `https://` URLs and 80 otherwise. `localhost` is mapped to `127.0.0.1`.
///
/// IP literals are used directly. Any other host name is handed to the system
/// resolver, which may block for longer than the connect timeout.
///
/// Returns `true` if a TCP connection to at least one candidate address
/// succeeds within 100 ms; any parse, resolution or connect failure yields
/// `false`.
pub(crate) fn probe_local_endpoint(url: &str) -> bool {
    use std::net::{SocketAddr, TcpStream, ToSocketAddrs};
    use std::time::Duration;

    const CONNECT_TIMEOUT: Duration = Duration::from_millis(100);

    // Strip the scheme (at most once) and remember the port it implies.
    let (rest, default_port) = match url.strip_prefix("https://") {
        Some(rest) => (rest, 443_u16),
        None => (url.strip_prefix("http://").unwrap_or(url), 80_u16),
    };

    // Keep only `host[:port]`; a path, query or fragment must not leak into it.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    if authority.is_empty() {
        return false;
    }

    // A bracketed IPv6 literal contains ':' itself, so an explicit port is
    // only present when something follows the closing bracket.
    let has_port = match authority.strip_prefix('[') {
        Some(v6) => v6.contains("]:"),
        None => authority.contains(':'),
    };
    let mut addr = if has_port {
        authority.to_string()
    } else {
        format!("{authority}:{default_port}")
    };

    // Pin `localhost` to the IPv4 loopback so the common case needs no resolver.
    if let Some(port) = addr.strip_prefix("localhost:") {
        addr = format!("127.0.0.1:{port}");
    }

    let candidates: Vec<SocketAddr> = match addr.parse::<SocketAddr>() {
        Ok(literal) => vec![literal],
        // Not an IP literal: let the system resolver turn the name into addresses.
        Err(_) => match addr.as_str().to_socket_addrs() {
            Ok(resolved) => resolved.collect(),
            Err(_) => return false,
        },
    };

    candidates
        .iter()
        .any(|candidate| TcpStream::connect_timeout(candidate, CONNECT_TIMEOUT).is_ok())
}
44
45pub(crate) fn get_api_key_with_secure_fallback(env_var: &str, key_name: &str) -> Option<String> {
53 if let Ok(key) = std::env::var(env_var) {
55 return Some(key);
56 }
57
58 match credentials::get_api_key(key_name) {
60 Ok(Some(key)) => {
61 std::env::set_var(env_var, &key);
63 Some(key)
64 }
65 Ok(None) => None,
66 Err(e) => {
67 tracing::warn!("Failed to retrieve {} from secure store: {}", key_name, e);
68 None
69 }
70 }
71}
72
73fn prompt_and_store_api_key(env_var: &str, key_name: &str, provider: &str) -> Option<String> {
82 eprintln!("\n🔑 {} API key not found.", provider);
83 eprintln!("You can set it via:");
84 eprintln!(" - Environment variable: export {}=<your-key>", env_var);
85 eprintln!(" - Interactive entry (recommended for security)");
86 eprintln!("\nEnter your {} API key:", provider);
87 eprintln!(" (Your key will be stored securely in the OS credential store)\n");
88
89 #[cfg(unix)]
91 let key = {
92 use std::io::{self, Write};
93
94 let mut stdout = io::stdout();
96 stdout.flush().ok();
97
98 rpassword::prompt_password("> ").ok()
100 };
101
102 #[cfg(windows)]
103 let key = {
104 use std::io::{self, Write};
105
106 let mut stdout = io::stdout();
107 stdout.flush().ok();
108
109 rpassword::prompt_password("> ").ok()
111 };
112
113 #[cfg(not(any(unix, windows)))]
114 let key = {
115 use std::io::{self, BufRead, Write};
116
117 let mut stdout = io::stdout();
118 let mut stdin = io::stdin();
119 stdout.flush().ok();
120 print!("> ");
121 stdout.flush().ok();
122
123 let mut input = String::new();
124 stdin.lock().read_line(&mut input).ok();
125 Some(input.trim().to_string())
126 };
127
128 match key {
129 Some(k) if !k.trim().is_empty() => {
130 let key = k.trim().to_string();
131
132 match credentials::store_api_key(key_name, &key) {
134 Ok(()) => {
135 tracing::info!("{} API key stored securely", provider);
136 std::env::set_var(env_var, &key);
137 Some(key)
138 }
139 Err(e) => {
140 tracing::warn!("Failed to store key securely: {}. Using session-only.", e);
141 std::env::set_var(env_var, &key);
142 Some(key)
143 }
144 }
145 }
146 _ => {
147 eprintln!(
148 "\n⚠️ No key entered. {} will not work until a key is set.",
149 provider
150 );
151 None
152 }
153 }
154}
155
156pub(crate) fn scan_context_file(content: &str, source: &str) -> Result<String> {
157 let suspicious = [
159 "IGNORE ALL PREVIOUS",
160 "DISREGARD ALL",
161 "OVERRIDE",
162 "You are now",
163 "Your new role",
164 "IMPORTANT: do not",
165 "<system-directive>",
166 "<role>",
167 "<contract>",
168 "\u{200B}",
170 "\u{200C}",
171 "\u{200D}",
172 "\u{FEFF}",
173 "\u{202E}",
174 "\u{2060}",
175 "\u{2061}",
176 "\u{2062}",
177 ];
178
179 let upper = content.to_uppercase();
180 let allow = source.ends_with("AGENTS.md") || source.ends_with("CLAUDE.md");
181
182 for pattern in &suspicious {
183 let hit = if pattern.is_ascii() {
184 upper.contains(&pattern.to_uppercase())
185 } else {
186 content.contains(pattern)
187 };
188
189 if hit {
190 tracing::warn!(source = %source, pattern = %pattern, "prompt injection pattern detected");
191 if allow {
192 continue;
193 }
194 return Err(PawanError::Config(format!(
195 "Suspicious content in {}: contains '{}'",
196 source, pattern
197 )));
198 }
199 }
200 Ok(content.to_string())
201}
202
203pub(crate) fn load_arch_context(workspace_root: &std::path::Path) -> Result<Option<String>> {
209 let path = workspace_root.join(".pawan").join("arch.md");
210 if !path.exists() {
211 return Ok(None);
212 }
213
214 let bytes = std::fs::read(&path).map_err(PawanError::Io)?;
215 let content = String::from_utf8(bytes).map_err(|_| {
216 PawanError::Config(
217 "Suspicious content in .pawan/arch.md: file is not valid UTF-8 (binary?)".to_string(),
218 )
219 })?;
220
221 if content.trim().is_empty() {
222 return Ok(None);
223 }
224
225 let content = scan_context_file(&content, ".pawan/arch.md")?;
226
227 const MAX_CHARS: usize = 2_000;
228 if content.len() > MAX_CHARS {
229 let boundary = content
231 .char_indices()
232 .map(|(i, _)| i)
233 .nth(MAX_CHARS)
234 .unwrap_or(content.len());
235 Ok(Some(format!("{}…(truncated)", &content[..boundary])))
236 } else {
237 Ok(Some(content))
238 }
239}
240
241impl PawanAgent {
242 pub fn new(config: PawanConfig, workspace_root: PathBuf) -> Self {
244 let tools = ToolRegistry::with_defaults(workspace_root.clone());
245 let system_prompt = config.get_system_prompt();
246 let backend = Self::create_backend(&config, &system_prompt);
247 let eruka = if config.eruka.enabled {
248 Some(crate::eruka_bridge::ErukaClient::new(config.eruka.clone()))
249 } else {
250 None
251 };
252 let (arch_context, arch_context_error) = match load_arch_context(&workspace_root) {
253 Ok(v) => (v, None),
254 Err(e) => (None, Some(e.to_string())),
255 };
256
257 Self {
258 config,
259 tools,
260 history: Vec::new(),
261 workspace_root,
262 backend,
263 context_tokens_estimate: 0,
264 eruka,
265 session_id: uuid::Uuid::new_v4().to_string(),
266 arch_context,
267 arch_context_error,
268 last_tool_call_time: None,
269 }
270 }
271
272 pub(crate) fn create_backend(config: &PawanConfig, system_prompt: &str) -> Box<dyn LlmBackend> {
279 if config.local_first {
282 let local_url = config
283 .local_endpoint
284 .clone()
285 .unwrap_or_else(|| "http://localhost:11434/v1".to_string());
286 if probe_local_endpoint(&local_url) {
287 tracing::info!(
288 url = %local_url,
289 model = %config.model,
290 "local_first: local server reachable, using local inference"
291 );
292 return Box::new(OpenAiCompatBackend::new(
293 super::backend::openai_compat::OpenAiCompatConfig {
294 api_url: local_url,
295 api_key: None,
296 model: config.model.clone(),
297 temperature: config.temperature,
298 top_p: config.top_p,
299 max_tokens: config.max_tokens,
300 system_prompt: system_prompt.to_string(),
301 use_thinking: false,
302 max_retries: config.max_retries,
303 fallback_models: Vec::new(),
304 cloud: None,
305 },
306 ));
307 }
308 tracing::info!(
309 url = %local_url,
310 "local_first: local server unreachable, falling back to cloud provider"
311 );
312 }
313
314 if config.use_ares_backend {
316 if let Some(backend) = Self::try_create_ares_backend(config, system_prompt) {
317 return backend;
318 }
319 tracing::warn!(
320 "use_ares_backend=true but ares backend creation failed; \
321 falling back to pawan's native backend"
322 );
323 }
324
325 match config.provider {
326 LlmProvider::Nvidia | LlmProvider::OpenAI | LlmProvider::Mlx => {
327 let (api_url, api_key) = match config.provider {
328 LlmProvider::Nvidia => {
329 let url = std::env::var("NVIDIA_API_URL")
330 .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
331
332 let key =
334 get_api_key_with_secure_fallback("NVIDIA_API_KEY", "nvidia_api_key");
335
336 let key = if key.is_some() {
338 key
339 } else if cfg!(test) {
340 Some("pawan-test-dummy-key".to_string())
341 } else {
342 prompt_and_store_api_key("NVIDIA_API_KEY", "nvidia_api_key", "NVIDIA")
343 };
344
345 if key.is_none() {
346 tracing::warn!("NVIDIA_API_KEY not set. Model calls will fail until a key is provided.");
347 }
348 (url, key)
349 }
350 LlmProvider::OpenAI => {
351 let url = config
352 .base_url
353 .clone()
354 .or_else(|| std::env::var("OPENAI_API_URL").ok())
355 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
356
357 let key =
358 get_api_key_with_secure_fallback("OPENAI_API_KEY", "openai_api_key");
359 let key = if key.is_some() {
360 key
361 } else if cfg!(test) {
362 Some("pawan-test-dummy-key".to_string())
363 } else {
364 prompt_and_store_api_key("OPENAI_API_KEY", "openai_api_key", "OpenAI")
365 };
366
367 (url, key)
368 }
369 LlmProvider::Mlx => {
370 let url = config
372 .base_url
373 .clone()
374 .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
375 tracing::info!(url = %url, "Using MLX LM server (Apple Silicon native)");
376 (url, None) }
378 _ => unreachable!(),
379 };
380
381 let cloud = config.cloud.as_ref().map(|c| {
383 let (cloud_url, cloud_key) = match c.provider {
384 LlmProvider::Nvidia => {
385 let url = std::env::var("NVIDIA_API_URL")
386 .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
387 let key = get_api_key_with_secure_fallback(
388 "NVIDIA_API_KEY",
389 "nvidia_api_key",
390 );
391 (url, key)
392 }
393 LlmProvider::OpenAI => {
394 let url = std::env::var("OPENAI_API_URL")
395 .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
396 let key = get_api_key_with_secure_fallback(
397 "OPENAI_API_KEY",
398 "openai_api_key",
399 );
400 (url, key)
401 }
402 LlmProvider::Mlx => ("http://localhost:8080/v1".to_string(), None),
403 _ => {
404 tracing::warn!(
405 "Cloud fallback only supports nvidia/openai/mlx providers"
406 );
407 ("https://integrate.api.nvidia.com/v1".to_string(), None)
408 }
409 };
410 super::backend::openai_compat::CloudFallback {
411 api_url: cloud_url,
412 api_key: cloud_key,
413 model: c.model.clone(),
414 fallback_models: c.fallback_models.clone(),
415 }
416 });
417
418 Box::new(OpenAiCompatBackend::new(OpenAiCompatConfig {
419 api_url,
420 api_key,
421 model: config.model.clone(),
422 temperature: config.temperature,
423 top_p: config.top_p,
424 max_tokens: config.max_tokens,
425 system_prompt: system_prompt.to_string(),
426 use_thinking: config.thinking_budget == 0 && config.use_thinking_mode(),
429 max_retries: config.max_retries,
430 fallback_models: config.fallback_models.clone(),
431 cloud,
432 }))
433 }
434 LlmProvider::Ollama => {
435 let url = std::env::var("OLLAMA_URL")
436 .unwrap_or_else(|_| "http://localhost:11434".to_string());
437
438 Box::new(super::backend::ollama::OllamaBackend::new(
439 url,
440 config.model.clone(),
441 config.temperature,
442 system_prompt.to_string(),
443 ))
444 }
445 }
446 }
447
448 fn try_create_ares_backend(
453 config: &PawanConfig,
454 system_prompt: &str,
455 ) -> Option<Box<dyn LlmBackend>> {
456 use ares::llm::client::{ModelParams, Provider};
457
458 let params = ModelParams {
463 temperature: Some(config.temperature),
464 max_tokens: Some(config.max_tokens as u32),
465 top_p: Some(config.top_p),
466 frequency_penalty: None,
467 presence_penalty: None,
468 };
469
470 let provider = match config.provider {
471 LlmProvider::Nvidia => {
472 let api_base = std::env::var("NVIDIA_API_URL")
473 .unwrap_or_else(|_| crate::DEFAULT_NVIDIA_API_URL.to_string());
474 let api_key = std::env::var("NVIDIA_API_KEY").ok()?;
475 Provider::OpenAI {
476 api_key,
477 api_base,
478 model: config.model.clone(),
479 params,
480 }
481 }
482 LlmProvider::OpenAI => {
483 let api_base = config
484 .base_url
485 .clone()
486 .or_else(|| std::env::var("OPENAI_API_URL").ok())
487 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
488 let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
489 Provider::OpenAI {
490 api_key,
491 api_base,
492 model: config.model.clone(),
493 params,
494 }
495 }
496 LlmProvider::Mlx => {
497 let api_base = config
499 .base_url
500 .clone()
501 .unwrap_or_else(|| "http://localhost:8080/v1".to_string());
502 Provider::OpenAI {
503 api_key: String::new(),
504 api_base,
505 model: config.model.clone(),
506 params,
507 }
508 }
509 LlmProvider::Ollama => {
510 return None;
514 }
515 };
516
517 let client: Box<dyn ares::llm::LLMClient> = match provider {
520 Provider::OpenAI {
521 api_key,
522 api_base,
523 model,
524 params,
525 } => Box::new(ares::llm::openai::OpenAIClient::with_params(
526 api_key, api_base, model, params,
527 )),
528 _ => return None,
529 };
530
531 tracing::info!(
532 provider = ?config.provider,
533 model = %config.model,
534 "Using ares-backed LLM backend"
535 );
536
537 Some(Box::new(super::backend::ares_backend::AresBackend::new(
538 client,
539 system_prompt.to_string(),
540 )))
541 }
542
543 pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
545 self.tools = tools;
546 self
547 }
548
549 pub fn tools_mut(&mut self) -> &mut ToolRegistry {
551 &mut self.tools
552 }
553
554 pub fn with_backend(mut self, backend: Box<dyn LlmBackend>) -> Self {
556 self.backend = backend;
557 self
558 }
559}