Skip to main content

localgpt_core/agent/tools/
mod.rs

1pub mod spawn_agent;
2pub mod ssrf;
3pub mod web_search;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use once_cell::sync::Lazy;
8use readability::extractor;
9use regex::Regex;
10use serde_json::{Value, json};
11use std::fs;
12use std::io::Cursor;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tracing::debug;
16
17use super::providers::ToolSchema;
18use crate::config::{Config, SearchProviderType};
19use crate::memory::MemoryManager;
20
21use spawn_agent::{SpawnAgentTool, SpawnContext};
22use web_search::{SearchRouter, WebSearchTool};
23
24#[derive(Debug, Clone)]
25pub struct ToolResult {
26    pub call_id: String,
27    pub output: String,
28}
29
30/// Permission level required to execute a tool.
31///
32/// Tools default to `Safe`. CLI dangerous tools (bash, file write, etc.) override
33/// to `Elevated`. Admin tools (config edit, key rotation) override to `Admin`.
34#[derive(
35    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
36)]
37#[serde(rename_all = "lowercase")]
38pub enum PermissionLevel {
39    /// Read-only tools: memory search, web fetch, etc.
40    Safe = 0,
41    /// File write, shell exec, browser automation.
42    Elevated = 1,
43    /// Config changes, daemon control, encryption key rotation.
44    Admin = 2,
45}
46
47impl Default for PermissionLevel {
48    fn default() -> Self {
49        Self::Safe
50    }
51}
52
53impl std::fmt::Display for PermissionLevel {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            Self::Safe => f.write_str("safe"),
57            Self::Elevated => f.write_str("elevated"),
58            Self::Admin => f.write_str("admin"),
59        }
60    }
61}
62
63#[async_trait]
64pub trait Tool: Send + Sync {
65    fn name(&self) -> &str;
66    fn schema(&self) -> ToolSchema;
67    async fn execute(&self, arguments: &str) -> Result<String>;
68
69    /// Permission level required to execute this tool. Default: Safe.
70    fn permission_level(&self) -> PermissionLevel {
71        PermissionLevel::Safe
72    }
73
74    /// MCP tool annotations (readOnlyHint, destructiveHint, etc.).
75    ///
76    /// Returns `None` by default. Override to provide annotations per the MCP spec.
77    fn annotations(&self) -> Option<Value> {
78        None
79    }
80}
81
82/// Create the safe (mobile-compatible) tools: memory search, memory get, web fetch, web search.
83///
84/// Dangerous tools (bash, read_file, write_file, edit_file) are provided by the CLI crate.
85/// Use `Agent::new_with_tools()` to supply the full tool set.
86pub fn create_safe_tools(
87    config: &Config,
88    memory: Option<Arc<MemoryManager>>,
89) -> Result<Vec<Box<dyn Tool>>> {
90    use super::hardcoded_filters;
91    use super::tool_filters::CompiledToolFilter;
92
93    let workspace = config.workspace_path();
94
95    // Use indexed memory search if MemoryManager is provided, otherwise fallback to grep-based
96    let memory_search_tool: Box<dyn Tool> = if let Some(ref mem) = memory {
97        Box::new(MemorySearchToolWithIndex::new(Arc::clone(mem)))
98    } else {
99        Box::new(MemorySearchTool::new(workspace.clone()))
100    };
101
102    // Compile web_fetch filter from user config and merge small hardcoded
103    // fail-fast deny rules (authoritative SSRF protection is still handled by
104    // validate_web_fetch_url() with host parsing + DNS/IP checks).
105    let web_fetch_filter = config
106        .tools
107        .filters
108        .get("web_fetch")
109        .map(CompiledToolFilter::compile)
110        .unwrap_or_else(|| Ok(CompiledToolFilter::permissive()))?
111        .merge_hardcoded(
112            hardcoded_filters::WEB_FETCH_DENY_SUBSTRINGS,
113            hardcoded_filters::WEB_FETCH_DENY_PATTERNS,
114        )?;
115
116    let mut tools: Vec<Box<dyn Tool>> = vec![
117        memory_search_tool,
118        Box::new(MemoryGetTool::new(workspace.clone())),
119        Box::new(WebFetchTool::new(
120            config.tools.web_fetch_max_bytes,
121            web_fetch_filter,
122        )?),
123    ];
124
125    // Conditionally add web search tool
126    if let Some(ref ws_config) = config.tools.web_search
127        && !matches!(ws_config.provider, SearchProviderType::None)
128    {
129        match SearchRouter::from_config(ws_config) {
130            Ok(router) => tools.push(Box::new(WebSearchTool::new(Arc::new(router)))),
131            Err(e) => tracing::warn!("Web search init failed: {e}"),
132        }
133    }
134
135    // Document loader tool (always available — uses shell commands for extraction)
136    tools.push(Box::new(DocumentLoadTool::new(workspace, &config.tools)));
137
138    // Wiki tools (structured knowledge management)
139    if config.memory.wiki_enabled
140        && let Some(ref mem) = memory
141    {
142        match crate::memory::wiki::WikiStore::new(
143            mem.db_path(),
144            config.memory.wiki_fresh_days,
145            config.memory.wiki_stale_days,
146        ) {
147            Ok(store) => {
148                let store = Arc::new(store);
149                tools.push(Box::new(WikiAddTool::new(Arc::clone(&store))));
150                tools.push(Box::new(WikiSearchTool::new(Arc::clone(&store))));
151                tools.push(Box::new(WikiStatusTool::new(store)));
152            }
153            Err(e) => tracing::warn!("Wiki store init failed: {e}"),
154        }
155    }
156
157    // Audio transcription tool (only if STT providers are configured)
158    if let Some(ref stt_config) = config.tools.stt {
159        let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
160        let registry = crate::media::SttRegistry::from_config(stt_config, &env_vars);
161        if registry.has_providers() {
162            let audio_cache = if config.tools.media_cache_enabled {
163                Some(crate::media::cache::MediaCache::new(
164                    config.workspace_path().join(".cache").join("media"),
165                    config.tools.media_cache_max_mb,
166                ))
167            } else {
168                None
169            };
170            tools.push(Box::new(AudioTranscribeTool::new(
171                Arc::new(registry),
172                config.workspace_path(),
173                audio_cache,
174            )));
175        } else {
176            tracing::debug!("STT configured but no providers available (missing API keys?)");
177        }
178    }
179
180    Ok(tools)
181}
182
183/// Create spawn_agent tool for hierarchical delegation.
184///
185/// This tool allows an agent to spawn specialist subagents for tasks like
186/// exploration, planning, implementation, or analysis.
187///
188/// # Arguments
189/// * `config` - Application configuration (cloned)
190/// * `memory` - Memory manager (shared with parent agent, required)
191///
192/// # Returns
193/// A boxed spawn_agent tool
194pub fn create_spawn_agent_tool(config: Config, memory: Arc<MemoryManager>) -> Box<dyn Tool> {
195    Box::new(SpawnAgentTool::from_config(config, memory))
196}
197
198/// Create spawn_agent tool with custom depth (for subagents).
199///
200/// Subagents get spawn_agent tool only if they're below the max depth.
201pub fn create_spawn_agent_tool_at_depth(
202    config: Config,
203    memory: Arc<MemoryManager>,
204    depth: u8,
205) -> Option<Box<dyn Tool>> {
206    let max_depth = config.agent.max_spawn_depth.unwrap_or(1);
207
208    if depth >= max_depth {
209        // At or past max depth, don't provide spawn_agent
210        return None;
211    }
212
213    let tool = SpawnAgentTool::new(SpawnContext {
214        depth,
215        config,
216        memory,
217        model: None,
218        max_depth,
219    });
220
221    Some(Box::new(tool))
222}
223
224// Memory Search Tool
225pub struct MemorySearchTool {
226    workspace: PathBuf,
227}
228
229impl MemorySearchTool {
230    pub fn new(workspace: PathBuf) -> Self {
231        Self { workspace }
232    }
233}
234
235#[async_trait]
236impl Tool for MemorySearchTool {
237    fn name(&self) -> &str {
238        "memory_search"
239    }
240
241    fn schema(&self) -> ToolSchema {
242        ToolSchema {
243            name: "memory_search".to_string(),
244            description: "Search the memory index for relevant information".to_string(),
245            parameters: json!({
246                "type": "object",
247                "properties": {
248                    "query": {
249                        "type": "string",
250                        "description": "The search query"
251                    },
252                    "limit": {
253                        "type": "integer",
254                        "description": "Maximum number of results (default: 5)"
255                    }
256                },
257                "required": ["query"]
258            }),
259        }
260    }
261
262    async fn execute(&self, arguments: &str) -> Result<String> {
263        let args: Value = serde_json::from_str(arguments)?;
264        let query = args["query"]
265            .as_str()
266            .ok_or_else(|| anyhow::anyhow!("Missing query"))?;
267        let limit = args["limit"].as_u64().unwrap_or(5) as usize;
268
269        debug!("Memory search: {} (limit: {})", query, limit);
270
271        // Simple grep-based search for now
272        // TODO: Use proper memory index
273        let mut results = Vec::new();
274
275        let memory_file = self.workspace.join("MEMORY.md");
276        if memory_file.exists()
277            && let Ok(content) = fs::read_to_string(&memory_file)
278        {
279            for (i, line) in content.lines().enumerate() {
280                if line.to_lowercase().contains(&query.to_lowercase()) {
281                    results.push(format!("MEMORY.md:{}: {}", i + 1, line));
282                    if results.len() >= limit {
283                        break;
284                    }
285                }
286            }
287        }
288
289        // Search daily logs
290        let memory_dir = self.workspace.join("memory");
291        if memory_dir.exists()
292            && let Ok(entries) = fs::read_dir(&memory_dir)
293        {
294            for entry in entries.filter_map(|e| e.ok()) {
295                if results.len() >= limit {
296                    break;
297                }
298
299                let path = entry.path();
300                if path.extension().map(|e| e == "md").unwrap_or(false)
301                    && let Ok(content) = fs::read_to_string(&path)
302                {
303                    let filename = path.file_name().unwrap().to_string_lossy();
304                    for (i, line) in content.lines().enumerate() {
305                        if line.to_lowercase().contains(&query.to_lowercase()) {
306                            results.push(format!("memory/{}:{}: {}", filename, i + 1, line));
307                            if results.len() >= limit {
308                                break;
309                            }
310                        }
311                    }
312                }
313            }
314        }
315
316        if results.is_empty() {
317            Ok("No results found".to_string())
318        } else {
319            Ok(results.join("\n"))
320        }
321    }
322}
323
324// Memory Search Tool with Index - uses MemoryManager for hybrid FTS+vector search
325pub struct MemorySearchToolWithIndex {
326    memory: Arc<MemoryManager>,
327}
328
329impl MemorySearchToolWithIndex {
330    pub fn new(memory: Arc<MemoryManager>) -> Self {
331        Self { memory }
332    }
333}
334
335#[async_trait]
336impl Tool for MemorySearchToolWithIndex {
337    fn name(&self) -> &str {
338        "memory_search"
339    }
340
341    fn schema(&self) -> ToolSchema {
342        let description = if self.memory.has_embeddings() {
343            "Search the memory index using hybrid semantic + keyword search for relevant information"
344        } else {
345            "Search the memory index for relevant information"
346        };
347
348        ToolSchema {
349            name: "memory_search".to_string(),
350            description: description.to_string(),
351            parameters: json!({
352                "type": "object",
353                "properties": {
354                    "query": {
355                        "type": "string",
356                        "description": "The search query"
357                    },
358                    "limit": {
359                        "type": "integer",
360                        "description": "Maximum number of results (default: 5)"
361                    }
362                },
363                "required": ["query"]
364            }),
365        }
366    }
367
368    async fn execute(&self, arguments: &str) -> Result<String> {
369        let args: Value = serde_json::from_str(arguments)?;
370        let query = args["query"]
371            .as_str()
372            .ok_or_else(|| anyhow::anyhow!("Missing query"))?;
373        let limit = args["limit"].as_u64().unwrap_or(5) as usize;
374
375        let search_type = if self.memory.has_embeddings() {
376            "hybrid"
377        } else {
378            "FTS"
379        };
380        debug!(
381            "Memory search ({}): {} (limit: {})",
382            search_type, query, limit
383        );
384
385        let results = self.memory.search(query, limit)?;
386
387        if results.is_empty() {
388            return Ok("No results found".to_string());
389        }
390
391        // Format results with citation-style references
392        let formatted: Vec<String> = results
393            .iter()
394            .enumerate()
395            .map(|(i, chunk)| {
396                let preview: String = chunk.content.chars().take(200).collect();
397                let preview = preview.replace('\n', " ");
398                format!(
399                    "{}. [{}:{}-{}] (score: {:.3})\n   {}{}",
400                    i + 1,
401                    chunk.file,
402                    chunk.line_start,
403                    chunk.line_end,
404                    chunk.score,
405                    preview,
406                    if chunk.content.len() > 200 { "..." } else { "" }
407                )
408            })
409            .collect();
410
411        Ok(formatted.join("\n\n"))
412    }
413}
414
415// Memory Get Tool - efficient snippet fetching after memory_search
416pub struct MemoryGetTool {
417    workspace: PathBuf,
418}
419
420impl MemoryGetTool {
421    pub fn new(workspace: PathBuf) -> Self {
422        Self { workspace }
423    }
424
425    fn resolve_path(&self, path: &str) -> PathBuf {
426        // Handle paths relative to workspace
427        if path.starts_with("memory/") || path == "MEMORY.md" || path == "HEARTBEAT.md" {
428            self.workspace.join(path)
429        } else {
430            PathBuf::from(shellexpand::tilde(path).to_string())
431        }
432    }
433
434    /// Validate that a resolved path stays within the workspace directory.
435    /// Checks the parent directory's canonical path if the file doesn't exist yet.
436    fn is_within_workspace(&self, resolved: &std::path::Path) -> bool {
437        let workspace_canonical = match self.workspace.canonicalize() {
438            Ok(p) => p,
439            Err(_) => return false,
440        };
441        // Try canonicalizing the file itself first
442        if let Ok(canonical) = resolved.canonicalize() {
443            return canonical.starts_with(&workspace_canonical);
444        }
445        // File doesn't exist — check the parent directory instead
446        if let Some(parent) = resolved.parent()
447            && let Ok(parent_canonical) = parent.canonicalize()
448        {
449            return parent_canonical.starts_with(&workspace_canonical);
450        }
451        false
452    }
453}
454
455#[async_trait]
456impl Tool for MemoryGetTool {
457    fn name(&self) -> &str {
458        "memory_get"
459    }
460
461    fn schema(&self) -> ToolSchema {
462        ToolSchema {
463            name: "memory_get".to_string(),
464            description: "Safe snippet read from MEMORY.md or memory/*.md with optional line range; use after memory_search to pull only the needed lines and keep context small.".to_string(),
465            parameters: json!({
466                "type": "object",
467                "properties": {
468                    "path": {
469                        "type": "string",
470                        "description": "Path to the file (e.g., 'MEMORY.md' or 'memory/2024-01-15.md')"
471                    },
472                    "from": {
473                        "type": "integer",
474                        "description": "Starting line number (1-indexed, default: 1)"
475                    },
476                    "lines": {
477                        "type": "integer",
478                        "description": "Number of lines to read (default: 50)"
479                    }
480                },
481                "required": ["path"]
482            }),
483        }
484    }
485
486    async fn execute(&self, arguments: &str) -> Result<String> {
487        let args: Value = serde_json::from_str(arguments)?;
488        let path = args["path"]
489            .as_str()
490            .ok_or_else(|| anyhow::anyhow!("Missing path"))?;
491
492        // Reject null bytes in raw input
493        if path.contains('\0') {
494            anyhow::bail!("Invalid path: null bytes not allowed");
495        }
496
497        let from = args["from"].as_u64().unwrap_or(1).max(1) as usize;
498        let lines_count = (args["lines"].as_u64().unwrap_or(50) as usize).min(10_000);
499
500        let resolved_path = self.resolve_path(path);
501
502        // Check for path traversal on the resolved path (catches .. after tilde expansion)
503        if resolved_path
504            .components()
505            .any(|c| matches!(c, std::path::Component::ParentDir))
506        {
507            anyhow::bail!("Invalid path: path traversal not allowed");
508        }
509
510        // Verify resolved path stays within workspace
511        if !self.is_within_workspace(&resolved_path) {
512            anyhow::bail!("Access denied: path is outside workspace");
513        }
514
515        debug!(
516            "Memory get: {} (from: {}, lines: {})",
517            resolved_path.display(),
518            from,
519            lines_count
520        );
521
522        if !resolved_path.exists() {
523            return Ok(format!("File not found: {}", path));
524        }
525
526        let content = fs::read_to_string(&resolved_path)?;
527        let lines: Vec<&str> = content.lines().collect();
528        let total_lines = lines.len();
529
530        // Convert from 1-indexed to 0-indexed
531        let start = (from - 1).min(total_lines);
532        let end = (start + lines_count).min(total_lines);
533
534        if start >= total_lines {
535            return Ok(format!(
536                "Line {} is past end of file ({} lines)",
537                from, total_lines
538            ));
539        }
540
541        let selected: Vec<String> = lines[start..end]
542            .iter()
543            .enumerate()
544            .map(|(i, line)| format!("{:4}\t{}", start + i + 1, line))
545            .collect();
546
547        let header = format!(
548            "# {} (lines {}-{} of {})\n",
549            path,
550            start + 1,
551            end,
552            total_lines
553        );
554        Ok(header + &selected.join("\n"))
555    }
556}
557
558// Document Load Tool — extracts text from PDF, DOCX, EPUB, HTML via shell commands
559pub struct DocumentLoadTool {
560    loaders: crate::media::DocumentLoaders,
561    workspace: PathBuf,
562    max_bytes: usize,
563    output_max_chars: usize,
564    cache: Option<crate::media::cache::MediaCache>,
565}
566
567impl DocumentLoadTool {
568    pub fn new(workspace: PathBuf, config: &crate::config::ToolsConfig) -> Self {
569        let loaders = match config.document_loaders {
570            Some(ref custom) => crate::media::DocumentLoaders::with_custom(custom),
571            None => crate::media::DocumentLoaders::new(),
572        };
573        let cache = if config.media_cache_enabled {
574            Some(crate::media::cache::MediaCache::new(
575                workspace.join(".cache").join("media"),
576                config.media_cache_max_mb,
577            ))
578        } else {
579            None
580        };
581        Self {
582            loaders,
583            workspace,
584            max_bytes: config.document_max_bytes,
585            output_max_chars: config.tool_output_max_chars,
586            cache,
587        }
588    }
589
590    fn validate_path(&self, path_str: &str) -> Result<PathBuf> {
591        if path_str.contains('\0') {
592            anyhow::bail!("Invalid path: null bytes not allowed");
593        }
594        let expanded = shellexpand::tilde(path_str).to_string();
595        let resolved = if std::path::Path::new(&expanded).is_absolute() {
596            PathBuf::from(expanded)
597        } else {
598            self.workspace.join(expanded)
599        };
600        if resolved
601            .components()
602            .any(|c| matches!(c, std::path::Component::ParentDir))
603        {
604            anyhow::bail!("Invalid path: path traversal not allowed");
605        }
606        Ok(resolved)
607    }
608}
609
610#[async_trait]
611impl Tool for DocumentLoadTool {
612    fn name(&self) -> &str {
613        "document_load"
614    }
615
616    fn schema(&self) -> ToolSchema {
617        ToolSchema {
618            name: "document_load".to_string(),
619            description: "Extract text content from PDF, DOCX, EPUB, or HTML documents. Returns the document text.".to_string(),
620            parameters: json!({
621                "type": "object",
622                "properties": {
623                    "path": {
624                        "type": "string",
625                        "description": "Path to the document file (relative to workspace or absolute)"
626                    }
627                },
628                "required": ["path"]
629            }),
630        }
631    }
632
633    async fn execute(&self, arguments: &str) -> Result<String> {
634        let args: Value = serde_json::from_str(arguments)?;
635        let path_str = args["path"]
636            .as_str()
637            .ok_or_else(|| anyhow::anyhow!("Missing path"))?;
638
639        let resolved = self.validate_path(path_str)?;
640
641        if !resolved.exists() {
642            anyhow::bail!("File not found: {}", path_str);
643        }
644
645        let metadata = fs::metadata(&resolved)?;
646        if metadata.len() as usize > self.max_bytes {
647            anyhow::bail!(
648                "File too large: {} bytes (max: {} bytes / {}MB)",
649                metadata.len(),
650                self.max_bytes,
651                self.max_bytes / 1_048_576
652            );
653        }
654
655        let ext = resolved.extension().and_then(|e| e.to_str()).unwrap_or("");
656        if !self.loaders.has_loader(ext) {
657            let supported = self.loaders.supported_extensions().join(", ");
658            anyhow::bail!("Unsupported format: .{}. Supported: {}", ext, supported);
659        }
660
661        // Check cache
662        if let Some(ref cache) = self.cache
663            && let Some(cached) = cache.get(&resolved)
664        {
665            return Ok(cached);
666        }
667
668        debug!("Loading document: {} ({})", resolved.display(), ext);
669        let text = self.loaders.extract_text(&resolved)?;
670
671        if let Some(ref cache) = self.cache {
672            let _ = cache.put(&resolved, &text);
673        }
674
675        if self.output_max_chars > 0 && text.len() > self.output_max_chars {
676            let truncated = truncate_on_char_boundary(&text, self.output_max_chars);
677            Ok(format!(
678                "{}\n\n[Truncated, {} chars total]",
679                truncated,
680                text.len()
681            ))
682        } else {
683            Ok(text)
684        }
685    }
686}
687
688// Audio Transcribe Tool — transcribes audio files via Groq/OpenAI/CLI
689pub struct AudioTranscribeTool {
690    registry: Arc<crate::media::SttRegistry>,
691    workspace: PathBuf,
692    cache: Option<crate::media::cache::MediaCache>,
693}
694
695impl AudioTranscribeTool {
696    pub fn new(
697        registry: Arc<crate::media::SttRegistry>,
698        workspace: PathBuf,
699        cache: Option<crate::media::cache::MediaCache>,
700    ) -> Self {
701        Self {
702            registry,
703            workspace,
704            cache,
705        }
706    }
707
708    fn validate_path(&self, path_str: &str) -> Result<PathBuf> {
709        if path_str.contains('\0') {
710            anyhow::bail!("Invalid path: null bytes not allowed");
711        }
712        let expanded = shellexpand::tilde(path_str).to_string();
713        let resolved = if std::path::Path::new(&expanded).is_absolute() {
714            PathBuf::from(expanded)
715        } else {
716            self.workspace.join(expanded)
717        };
718        if resolved
719            .components()
720            .any(|c| matches!(c, std::path::Component::ParentDir))
721        {
722            anyhow::bail!("Invalid path: path traversal not allowed");
723        }
724        Ok(resolved)
725    }
726}
727
728#[async_trait]
729impl Tool for AudioTranscribeTool {
730    fn name(&self) -> &str {
731        "transcribe_audio"
732    }
733
734    fn schema(&self) -> ToolSchema {
735        ToolSchema {
736            name: "transcribe_audio".to_string(),
737            description: "Transcribe audio files (MP3, M4A, WAV, OGG, FLAC, WEBM) to text using speech-to-text.".to_string(),
738            parameters: json!({
739                "type": "object",
740                "properties": {
741                    "path": {
742                        "type": "string",
743                        "description": "Path to the audio file"
744                    },
745                    "language": {
746                        "type": "string",
747                        "description": "Language hint (ISO 639-1, e.g., 'en', 'zh', 'ja'). Default: 'en'"
748                    }
749                },
750                "required": ["path"]
751            }),
752        }
753    }
754
755    async fn execute(&self, arguments: &str) -> Result<String> {
756        let args: Value = serde_json::from_str(arguments)?;
757        let path_str = args["path"]
758            .as_str()
759            .ok_or_else(|| anyhow::anyhow!("Missing path"))?;
760
761        let resolved = self.validate_path(path_str)?;
762
763        if !resolved.exists() {
764            anyhow::bail!("File not found: {}", path_str);
765        }
766
767        let mime_type = crate::media::audio::mime_type_from_path(&resolved);
768        if mime_type == "audio/octet-stream" {
769            let ext = resolved.extension().and_then(|e| e.to_str()).unwrap_or("?");
770            anyhow::bail!(
771                "Unsupported audio format: .{}. Supported: ogg, opus, mp3, m4a, wav, webm, flac",
772                ext
773            );
774        }
775
776        // Check cache
777        if let Some(ref cache) = self.cache
778            && let Some(cached) = cache.get(&resolved)
779        {
780            return Ok(cached);
781        }
782
783        let audio_data = fs::read(&resolved)?;
784        debug!(
785            "Transcribing audio: {} ({} bytes, {})",
786            resolved.display(),
787            audio_data.len(),
788            mime_type
789        );
790
791        let text = self.registry.transcribe(&audio_data, mime_type).await?;
792
793        if let Some(ref cache) = self.cache {
794            let _ = cache.put(&resolved, &text);
795        }
796
797        Ok(text)
798    }
799}
800
801fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
802    &s[..s.floor_char_boundary(max_bytes)]
803}
804
805/// Delegates to [`ssrf::validate_url`] — validates URL for SSRF safety before
806/// any HTTP request is made (scheme, hostname, IP range, DNS pinning).
807async fn validate_web_fetch_url(url: &str) -> Result<reqwest::Url> {
808    ssrf::validate_url(url).await
809}
810
811const MAX_WEB_FETCH_REDIRECTS: usize = 10;
812
813fn should_follow_redirect(status: reqwest::StatusCode) -> bool {
814    matches!(
815        status,
816        reqwest::StatusCode::MOVED_PERMANENTLY
817            | reqwest::StatusCode::FOUND
818            | reqwest::StatusCode::SEE_OTHER
819            | reqwest::StatusCode::TEMPORARY_REDIRECT
820            | reqwest::StatusCode::PERMANENT_REDIRECT
821    )
822}
823
824async fn resolve_and_validate_redirect_target(
825    current: &reqwest::Url,
826    location: &str,
827) -> Result<reqwest::Url> {
828    let candidate = current
829        .join(location)
830        .map_err(|e| anyhow::anyhow!("Invalid redirect target '{}': {}", location, e))?;
831    validate_web_fetch_url(candidate.as_str()).await
832}
833
834fn extract_fallback_text(html: &str) -> String {
835    static SCRIPT_RE: Lazy<Regex> =
836        Lazy::new(|| Regex::new(r"(?is)<script[^>]*>.*?</script>").expect("valid script regex"));
837    static STYLE_RE: Lazy<Regex> =
838        Lazy::new(|| Regex::new(r"(?is)<style[^>]*>.*?</style>").expect("valid style regex"));
839    static TAG_RE: Lazy<Regex> =
840        Lazy::new(|| Regex::new(r"(?is)<[^>]+>").expect("valid tag regex"));
841    static WS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\s+").expect("valid whitespace regex"));
842
843    let no_scripts = SCRIPT_RE.replace_all(html, " ");
844    let no_styles = STYLE_RE.replace_all(&no_scripts, " ");
845    let no_tags = TAG_RE.replace_all(&no_styles, " ");
846    WS_RE.replace_all(no_tags.trim(), " ").to_string()
847}
848
849fn extract_readable_text(html: &str, url: &reqwest::Url) -> String {
850    let mut cursor = Cursor::new(html.as_bytes());
851    match extractor::extract(&mut cursor, url) {
852        Ok(product) => {
853            let text = product.text.trim();
854            if text.is_empty() {
855                return extract_fallback_text(html);
856            }
857
858            let title = product.title.trim();
859            if title.is_empty() {
860                text.to_string()
861            } else {
862                format!("# {}\n\n{}", title, text)
863            }
864        }
865        Err(e) => {
866            debug!("Readability extraction failed for {}: {}", url, e);
867            extract_fallback_text(html)
868        }
869    }
870}
871
872// Web Fetch Tool
873pub struct WebFetchTool {
874    client: reqwest::Client,
875    max_bytes: usize,
876    filter: super::tool_filters::CompiledToolFilter,
877}
878
879impl WebFetchTool {
880    pub fn new(max_bytes: usize, filter: super::tool_filters::CompiledToolFilter) -> Result<Self> {
881        let client = reqwest::Client::builder()
882            .redirect(reqwest::redirect::Policy::none())
883            .build()?;
884
885        Ok(Self {
886            client,
887            max_bytes,
888            filter,
889        })
890    }
891
892    async fn fetch_with_validated_redirects(
893        &self,
894        mut current_url: reqwest::Url,
895    ) -> Result<(reqwest::Response, reqwest::Url)> {
896        for redirect_count in 0..=MAX_WEB_FETCH_REDIRECTS {
897            let response = self
898                .client
899                .get(current_url.clone())
900                .header("User-Agent", "LocalGPT/0.1")
901                .send()
902                .await?;
903
904            if !should_follow_redirect(response.status()) {
905                return Ok((response, current_url));
906            }
907
908            if redirect_count == MAX_WEB_FETCH_REDIRECTS {
909                anyhow::bail!(
910                    "Too many redirects (>{}) while fetching {}",
911                    MAX_WEB_FETCH_REDIRECTS,
912                    current_url
913                );
914            }
915
916            let location = response
917                .headers()
918                .get(reqwest::header::LOCATION)
919                .ok_or_else(|| {
920                    anyhow::anyhow!(
921                        "Redirect response {} missing Location header",
922                        response.status()
923                    )
924                })?
925                .to_str()
926                .map_err(|_| anyhow::anyhow!("Redirect Location header is not valid UTF-8"))?;
927
928            let next_url = resolve_and_validate_redirect_target(&current_url, location).await?;
929            debug!(
930                "Following redirect {}: {} -> {}",
931                redirect_count + 1,
932                current_url,
933                next_url
934            );
935            current_url = next_url;
936        }
937
938        unreachable!("redirect loop should return or bail")
939    }
940}
941
942#[async_trait]
943impl Tool for WebFetchTool {
944    fn name(&self) -> &str {
945        "web_fetch"
946    }
947
948    fn schema(&self) -> ToolSchema {
949        ToolSchema {
950            name: "web_fetch".to_string(),
951            description: "Fetch content from a URL".to_string(),
952            parameters: json!({
953                "type": "object",
954                "properties": {
955                    "url": {
956                        "type": "string",
957                        "description": "The URL to fetch"
958                    }
959                },
960                "required": ["url"]
961            }),
962        }
963    }
964
965    async fn execute(&self, arguments: &str) -> Result<String> {
966        let args: Value = serde_json::from_str(arguments)?;
967        let url = args["url"]
968            .as_str()
969            .ok_or_else(|| anyhow::anyhow!("Missing url"))?;
970
971        // Check URL against SSRF deny filters (fast, static patterns)
972        self.filter.check(url, "web_fetch", "url")?;
973
974        let parsed_url = validate_web_fetch_url(url).await?;
975        debug!("Fetching URL: {}", parsed_url);
976
977        let (response, final_url) = self.fetch_with_validated_redirects(parsed_url).await?;
978
979        let status = response.status();
980        let content_type = response
981            .headers()
982            .get(reqwest::header::CONTENT_TYPE)
983            .and_then(|v| v.to_str().ok())
984            .unwrap_or("")
985            .to_string();
986
987        // Limit download size to prevent memory exhaustion from malicious servers.
988        // Allow up to 2x max_bytes raw download since extraction often shrinks content.
989        let download_limit = self.max_bytes * 2;
990
991        // Fast reject via Content-Length header when available
992        if let Some(content_length) = response.content_length()
993            && content_length as usize > download_limit
994        {
995            anyhow::bail!(
996                "Response too large ({} bytes, limit {})",
997                content_length,
998                download_limit
999            );
1000        }
1001
1002        // Stream response body with size cap (handles chunked/missing Content-Length)
1003        let mut body_bytes = Vec::new();
1004        let mut stream = response.bytes_stream();
1005        use futures::StreamExt;
1006        while let Some(chunk) = stream.next().await {
1007            let chunk = chunk?;
1008            body_bytes.extend_from_slice(&chunk);
1009            if body_bytes.len() > download_limit {
1010                anyhow::bail!(
1011                    "Response too large (>{} bytes), download aborted",
1012                    download_limit
1013                );
1014            }
1015        }
1016        let body = String::from_utf8_lossy(&body_bytes).to_string();
1017        let extracted =
1018            if content_type.contains("text/html") || content_type.contains("application/xhtml") {
1019                extract_readable_text(&body, &final_url)
1020            } else {
1021                body
1022            };
1023
1024        // Truncate if too long
1025        let truncated = if extracted.len() > self.max_bytes {
1026            let prefix = truncate_on_char_boundary(&extracted, self.max_bytes);
1027            format!(
1028                "{}...\n\n[Truncated, {} bytes total]",
1029                prefix,
1030                extracted.len()
1031            )
1032        } else {
1033            extracted
1034        };
1035
1036        Ok(format!(
1037            "Status: {}\nURL: {}\nContent-Type: {}\n\n{}",
1038            status, final_url, content_type, truncated
1039        ))
1040    }
1041}
1042
1043/// Extract relevant detail from tool arguments for display.
1044/// Returns a human-readable summary of the key argument (file path, command, query, URL).
1045pub fn extract_tool_detail(tool_name: &str, arguments: &str) -> Option<String> {
1046    let args: Value = serde_json::from_str(arguments).ok()?;
1047
1048    match tool_name {
1049        "edit_file" | "write_file" | "read_file" | "replace" => args
1050            .get("path")
1051            .or_else(|| args.get("file_path"))
1052            .and_then(|v| v.as_str())
1053            .map(|s| s.to_string()),
1054        "bash" | "run_shell_command" => args.get("command").and_then(|v| v.as_str()).map(|s| {
1055            if s.len() > 60 {
1056                format!("{}...", &s[..57])
1057            } else {
1058                s.to_string()
1059            }
1060        }),
1061        "memory_search" => args
1062            .get("query")
1063            .and_then(|v| v.as_str())
1064            .map(|s| format!("\"{}\"", s)),
1065        "web_fetch" => args
1066            .get("url")
1067            .or_else(|| args.get("prompt"))
1068            .and_then(|v| v.as_str())
1069            .map(|s| s.to_string()),
1070        "web_search" | "google_web_search" => args
1071            .get("query")
1072            .and_then(|v| v.as_str())
1073            .map(|s| format!("\"{}\"", s)),
1074        "grep_search" | "glob" => args
1075            .get("pattern")
1076            .and_then(|v| v.as_str())
1077            .map(|s| format!("\"{}\"", s)),
1078        "list_directory" => args
1079            .get("dir_path")
1080            .and_then(|v| v.as_str())
1081            .map(|s| s.to_string()),
1082        "codebase_investigator" => args
1083            .get("objective")
1084            .and_then(|v| v.as_str())
1085            .map(|s| s.to_string()),
1086        "document_load" => args
1087            .get("path")
1088            .and_then(|v| v.as_str())
1089            .map(|s| s.to_string()),
1090        "transcribe_audio" => args
1091            .get("path")
1092            .and_then(|v| v.as_str())
1093            .map(|s| s.to_string()),
1094
1095        // Gen tools - 3D scene manipulation
1096        "gen_spawn_primitive" => {
1097            let name = args.get("name").and_then(|v| v.as_str());
1098            let shape = args.get("shape").and_then(|v| v.as_str()).unwrap_or("?");
1099            name.map(|n| format!("{} ({})", n, shape))
1100        }
1101        "gen_spawn_batch" => args
1102            .get("entities")
1103            .and_then(|v| v.as_array())
1104            .map(|arr| format!("{} entities", arr.len())),
1105        "gen_modify_batch" => args
1106            .get("entities")
1107            .and_then(|v| v.as_array())
1108            .map(|arr| format!("{} entities", arr.len())),
1109        "gen_delete_batch" => args
1110            .get("names")
1111            .and_then(|v| v.as_array())
1112            .map(|arr| format!("{} entities", arr.len())),
1113        "gen_spawn_mesh" => args
1114            .get("name")
1115            .and_then(|v| v.as_str())
1116            .map(|s| s.to_string()),
1117        "gen_modify_entity" => args
1118            .get("name")
1119            .and_then(|v| v.as_str())
1120            .map(|s| s.to_string()),
1121        "gen_delete_entity" => args
1122            .get("name")
1123            .and_then(|v| v.as_str())
1124            .map(|s| s.to_string()),
1125        "gen_entity_info" => args
1126            .get("name")
1127            .and_then(|v| v.as_str())
1128            .map(|s| s.to_string()),
1129        "gen_set_light" => args
1130            .get("name")
1131            .and_then(|v| v.as_str())
1132            .map(|s| s.to_string()),
1133        "gen_load_gltf" => args
1134            .get("path")
1135            .and_then(|v| v.as_str())
1136            .map(|s| s.to_string()),
1137        "gen_export_screenshot" => args
1138            .get("path")
1139            .and_then(|v| v.as_str())
1140            .map(|s| s.to_string()),
1141        "gen_export_gltf" => args
1142            .get("path")
1143            .and_then(|v| v.as_str())
1144            .map(|s| s.to_string()),
1145        "gen_save_world" => args
1146            .get("name")
1147            .and_then(|v| v.as_str())
1148            .map(|s| format!("'{}'", s)),
1149        "gen_load_world" => args
1150            .get("path")
1151            .and_then(|v| v.as_str())
1152            .map(|s| s.to_string()),
1153        "gen_export_world" => args
1154            .get("format")
1155            .and_then(|v| v.as_str())
1156            .map(|f| format!("format: {}", f)),
1157
1158        // Gen tools - audio
1159        "gen_audio_emitter" => args
1160            .get("name")
1161            .and_then(|v| v.as_str())
1162            .map(|s| s.to_string()),
1163        "gen_modify_audio" => args
1164            .get("name")
1165            .and_then(|v| v.as_str())
1166            .map(|s| s.to_string()),
1167
1168        // Gen tools - behaviors
1169        "gen_add_behavior" => {
1170            let entity = args.get("entity").and_then(|v| v.as_str());
1171            let behavior_type = args
1172                .get("behavior")
1173                .and_then(|b| b.get("type"))
1174                .and_then(|v| v.as_str());
1175            match (entity, behavior_type) {
1176                (Some(e), Some(t)) => Some(format!("{} [{}]", e, t)),
1177                (Some(e), None) => Some(e.to_string()),
1178                _ => None,
1179            }
1180        }
1181        "gen_remove_behavior" => args
1182            .get("entity")
1183            .and_then(|v| v.as_str())
1184            .map(|s| s.to_string()),
1185        "gen_list_behaviors" => args
1186            .get("entity")
1187            .and_then(|v| v.as_str())
1188            .map(|s| s.to_string()),
1189
1190        // Gen tools with no meaningful detail
1191        "gen_scene_info"
1192        | "gen_screenshot"
1193        | "gen_set_camera"
1194        | "gen_set_environment"
1195        | "gen_set_ambience"
1196        | "gen_audio_info"
1197        | "gen_pause_behaviors"
1198        | "gen_clear_scene" => None,
1199
1200        _ => None,
1201    }
1202}
1203
1204// ── Wiki Tools ──────────────────────────────────────────────────────────
1205
1206/// wiki_add — Add or update a structured knowledge claim with evidence.
1207pub struct WikiAddTool {
1208    store: Arc<crate::memory::wiki::WikiStore>,
1209}
1210
1211impl WikiAddTool {
1212    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
1213        Self { store }
1214    }
1215}
1216
1217#[async_trait]
1218impl Tool for WikiAddTool {
1219    fn name(&self) -> &str {
1220        "wiki_add"
1221    }
1222
1223    fn schema(&self) -> ToolSchema {
1224        ToolSchema {
1225            name: "wiki_add".to_string(),
1226            description: "Add or update a structured knowledge claim with optional evidence. Deduplicates similar claims automatically.".to_string(),
1227            parameters: json!({
1228                "type": "object",
1229                "properties": {
1230                    "text": {
1231                        "type": "string",
1232                        "description": "The claim text"
1233                    },
1234                    "category": {
1235                        "type": "string",
1236                        "enum": ["fact", "preference", "decision", "question"],
1237                        "description": "Claim category (default: fact)"
1238                    },
1239                    "confidence": {
1240                        "type": "number",
1241                        "description": "Confidence score 0.0-1.0 (default: 0.8)"
1242                    },
1243                    "evidence_source": {
1244                        "type": "string",
1245                        "description": "Source of evidence (file path, URL, session ID)"
1246                    },
1247                    "evidence_excerpt": {
1248                        "type": "string",
1249                        "description": "Relevant text excerpt from the source"
1250                    }
1251                },
1252                "required": ["text"]
1253            }),
1254        }
1255    }
1256
1257    async fn execute(&self, arguments: &str) -> Result<String> {
1258        let args: Value = serde_json::from_str(arguments)?;
1259        let text = args["text"]
1260            .as_str()
1261            .ok_or_else(|| anyhow::anyhow!("Missing text"))?;
1262
1263        let category = args["category"]
1264            .as_str()
1265            .map(crate::memory::wiki::ClaimCategory::parse)
1266            .transpose()?
1267            .unwrap_or(crate::memory::wiki::ClaimCategory::Fact);
1268
1269        let confidence = args["confidence"].as_f64().unwrap_or(0.8) as f32;
1270        let evidence_source = args["evidence_source"].as_str();
1271        let evidence_excerpt = args["evidence_excerpt"].as_str();
1272
1273        let id = self.store.add_claim(
1274            text,
1275            category,
1276            confidence,
1277            evidence_source,
1278            evidence_excerpt,
1279        )?;
1280
1281        Ok(format!("Claim stored (id: {}, category: {})", id, category))
1282    }
1283}
1284
1285/// wiki_search — Search structured knowledge claims.
1286pub struct WikiSearchTool {
1287    store: Arc<crate::memory::wiki::WikiStore>,
1288}
1289
1290impl WikiSearchTool {
1291    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
1292        Self { store }
1293    }
1294}
1295
1296#[async_trait]
1297impl Tool for WikiSearchTool {
1298    fn name(&self) -> &str {
1299        "wiki_search"
1300    }
1301
1302    fn schema(&self) -> ToolSchema {
1303        ToolSchema {
1304            name: "wiki_search".to_string(),
1305            description: "Search structured knowledge claims by text, category, or freshness."
1306                .to_string(),
1307            parameters: json!({
1308                "type": "object",
1309                "properties": {
1310                    "query": {
1311                        "type": "string",
1312                        "description": "Search query"
1313                    },
1314                    "category": {
1315                        "type": "string",
1316                        "enum": ["fact", "preference", "decision", "question"],
1317                        "description": "Filter by category (optional)"
1318                    },
1319                    "include_stale": {
1320                        "type": "boolean",
1321                        "description": "Include stale claims (default: false)"
1322                    },
1323                    "limit": {
1324                        "type": "integer",
1325                        "description": "Maximum results (default: 10)"
1326                    }
1327                },
1328                "required": ["query"]
1329            }),
1330        }
1331    }
1332
1333    async fn execute(&self, arguments: &str) -> Result<String> {
1334        let args: Value = serde_json::from_str(arguments)?;
1335        let query = args["query"]
1336            .as_str()
1337            .ok_or_else(|| anyhow::anyhow!("Missing query"))?;
1338
1339        let category = args["category"]
1340            .as_str()
1341            .map(crate::memory::wiki::ClaimCategory::parse)
1342            .transpose()?;
1343
1344        let include_stale = args["include_stale"].as_bool().unwrap_or(false);
1345        let limit = args["limit"].as_u64().unwrap_or(10) as usize;
1346
1347        let claims = self.store.search(query, category, include_stale, limit)?;
1348
1349        if claims.is_empty() {
1350            return Ok("No claims found".to_string());
1351        }
1352
1353        let formatted: Vec<String> = claims
1354            .iter()
1355            .enumerate()
1356            .map(|(i, c)| {
1357                let freshness = self.store.freshness(c.updated_at);
1358                let evidence_summary = if c.evidence.is_empty() {
1359                    String::new()
1360                } else {
1361                    format!(
1362                        "\n   Evidence ({}):\n{}",
1363                        c.evidence.len(),
1364                        c.evidence
1365                            .iter()
1366                            .take(3)
1367                            .map(|e| format!(
1368                                "   - [{}] {}",
1369                                e.source,
1370                                e.excerpt.chars().take(80).collect::<String>()
1371                            ))
1372                            .collect::<Vec<_>>()
1373                            .join("\n")
1374                    )
1375                };
1376                format!(
1377                    "{}. [{}] ({}, {}, conf: {:.1}) {freshness}\n   {}{}",
1378                    i + 1,
1379                    c.id.chars().take(8).collect::<String>(),
1380                    c.category,
1381                    c.status,
1382                    c.confidence,
1383                    c.text,
1384                    evidence_summary,
1385                    freshness = freshness,
1386                )
1387            })
1388            .collect();
1389
1390        Ok(formatted.join("\n\n"))
1391    }
1392}
1393
1394/// wiki_status — Knowledge base health overview.
1395pub struct WikiStatusTool {
1396    store: Arc<crate::memory::wiki::WikiStore>,
1397}
1398
1399impl WikiStatusTool {
1400    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
1401        Self { store }
1402    }
1403}
1404
1405#[async_trait]
1406impl Tool for WikiStatusTool {
1407    fn name(&self) -> &str {
1408        "wiki_status"
1409    }
1410
1411    fn schema(&self) -> ToolSchema {
1412        ToolSchema {
1413            name: "wiki_status".to_string(),
1414            description: "Get knowledge base health overview: total claims, breakdown by category/status/freshness, top stale claims.".to_string(),
1415            parameters: json!({
1416                "type": "object",
1417                "properties": {},
1418                "required": []
1419            }),
1420        }
1421    }
1422
1423    async fn execute(&self, _arguments: &str) -> Result<String> {
1424        let status = self.store.status()?;
1425
1426        let mut out = format!(
1427            "## Knowledge Base Status\n\nTotal claims: {}\n",
1428            status.total_claims
1429        );
1430
1431        if !status.by_category.is_empty() {
1432            out.push_str("\n**By category:**\n");
1433            for (cat, count) in &status.by_category {
1434                out.push_str(&format!("- {}: {}\n", cat, count));
1435            }
1436        }
1437
1438        if !status.by_status.is_empty() {
1439            out.push_str("\n**By status:**\n");
1440            for (st, count) in &status.by_status {
1441                out.push_str(&format!("- {}: {}\n", st, count));
1442            }
1443        }
1444
1445        out.push_str("\n**By freshness:**\n");
1446        for (freshness, count) in &status.by_freshness {
1447            out.push_str(&format!("- {}: {}\n", freshness, count));
1448        }
1449
1450        if !status.top_stale.is_empty() {
1451            out.push_str("\n**Top stale claims:**\n");
1452            for c in &status.top_stale {
1453                out.push_str(&format!(
1454                    "- [{}] {}\n",
1455                    c.id.chars().take(8).collect::<String>(),
1456                    c.text
1457                ));
1458            }
1459        }
1460
1461        Ok(out)
1462    }
1463}
1464
// Unit tests for the web-fetch helpers (readable-text extraction and redirect
// validation) and for the input guards (path traversal, NUL bytes, format,
// size, missing file) of MemoryGetTool, DocumentLoadTool and
// AudioTranscribeTool.
#[cfg(test)]
mod tests {
    use super::*;

    // SSRF unit tests for is_private_ip and is_blocked_hostname are in ssrf.rs.
    // These integration tests verify the redirect validation path delegates correctly.

    // Script contents must not leak into the extracted text, while inline
    // markup (<b>) is flattened into the surrounding sentence.
    #[test]
    fn test_extract_readable_text_removes_html() {
        let html = r#"
            <html><head><style>.x{display:none}</style></head>
            <body><script>alert(1)</script><h1>Title</h1><p>Hello <b>world</b>.</p></body></html>
        "#;
        let url = reqwest::Url::parse("https://example.com/test").unwrap();
        let text = extract_readable_text(html, &url);
        assert!(text.contains("Hello world"));
        assert!(!text.contains("alert(1)"));
    }

    // A redirect from a public IP to loopback must be rejected by the SSRF check.
    #[tokio::test]
    async fn test_redirect_target_validation_blocks_private_ip() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let err = resolve_and_validate_redirect_target(&current, "http://127.0.0.1/admin").await;
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("private/reserved IP"),
            "expected SSRF block message, got: {msg}"
        );
    }

    // A relative Location is resolved against the current URL and accepted.
    #[tokio::test]
    async fn test_redirect_target_validation_allows_relative_public_ip_target() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let next = resolve_and_validate_redirect_target(&current, "/next")
            .await
            .unwrap();
        assert_eq!(next.as_str(), "https://93.184.216.34/next");
    }

    // Redirects to non-HTTP(S) schemes (here file://) must be rejected.
    #[tokio::test]
    async fn test_redirect_target_validation_blocks_non_http_scheme() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let err = resolve_and_validate_redirect_target(&current, "file:///etc/passwd").await;
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(msg.contains("Only http/https"));
    }

    #[tokio::test]
    async fn test_memory_get_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_workspace");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = MemoryGetTool::new(workspace);

        // Path with .. should be rejected
        let args = r#"{"path": "memory/../../../etc/passwd"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("path traversal"));
    }

    #[tokio::test]
    async fn test_memory_get_rejects_null_bytes() {
        let workspace = std::env::temp_dir().join("localgpt_test_workspace");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = MemoryGetTool::new(workspace);

        // The JSON `\u0000` escape decodes to a NUL character inside the path.
        let args = r#"{"path": "memory/\u0000evil.md"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_memory_get_caps_lines_parameter() {
        let workspace = std::env::temp_dir().join("localgpt_test_mg_lines");
        let _ = std::fs::create_dir_all(workspace.join("memory"));
        // Create a small test file
        std::fs::write(workspace.join("MEMORY.md"), "line1\nline2\nline3\n").unwrap();
        let tool = MemoryGetTool::new(workspace.clone());

        // Even with a huge lines value, it should be capped and work normally
        let args = r#"{"path": "MEMORY.md", "lines": 999999999}"#;
        let result = tool.execute(args).await.unwrap();
        assert!(result.contains("line1"));
        // Cleanup
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // --- DocumentLoadTool tests ---

    // Default tools config shared by the DocumentLoadTool tests.
    fn test_tools_config() -> crate::config::ToolsConfig {
        crate::config::ToolsConfig::default()
    }

    // Schema exposes a required `path` parameter under the tool's name.
    #[test]
    fn test_document_load_tool_schema() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_schema");
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());
        assert_eq!(tool.name(), "document_load");
        let schema = tool.schema();
        assert_eq!(schema.name, "document_load");
        let params = &schema.parameters;
        assert!(params["properties"]["path"].is_object());
        assert_eq!(params["required"][0], "path");
    }

    #[tokio::test]
    async fn test_document_load_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_traversal");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());

        let args = r#"{"path": "../../../etc/passwd"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("path traversal"));
    }

    // An unknown extension is rejected, and the error lists supported formats
    // (at least "pdf").
    #[tokio::test]
    async fn test_document_load_rejects_unsupported_format() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_format");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("test.xyz"), "content").unwrap();
        let tool = DocumentLoadTool::new(workspace.clone(), &test_tools_config());

        let args = r#"{"path": "test.xyz"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("Unsupported format"));
        assert!(msg.contains("pdf"));
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // A 100-byte file must be rejected when `document_max_bytes` is 50.
    #[tokio::test]
    async fn test_document_load_rejects_too_large() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_large");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("big.pdf"), vec![0u8; 100]).unwrap();

        let mut config = test_tools_config();
        config.document_max_bytes = 50; // 50 bytes limit
        let tool = DocumentLoadTool::new(workspace.clone(), &config);

        let args = r#"{"path": "big.pdf"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too large"));
        let _ = std::fs::remove_dir_all(&workspace);
    }

    #[tokio::test]
    async fn test_document_load_file_not_found() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_notfound");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());

        let args = r#"{"path": "nonexistent.pdf"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    // --- AudioTranscribeTool tests ---

    // Schema exposes `path` (required) and `language` parameters.
    #[test]
    fn test_audio_transcribe_tool_schema() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_schema");
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);
        assert_eq!(tool.name(), "transcribe_audio");
        let schema = tool.schema();
        assert_eq!(schema.name, "transcribe_audio");
        let params = &schema.parameters;
        assert!(params["properties"]["path"].is_object());
        assert!(params["properties"]["language"].is_object());
        assert_eq!(params["required"][0], "path");
    }

    #[tokio::test]
    async fn test_audio_transcribe_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_traversal");
        let _ = std::fs::create_dir_all(&workspace);
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);

        let args = r#"{"path": "../../../etc/passwd.mp3"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("path traversal"));
    }

    // An existing file with a non-audio extension (.txt) is rejected.
    #[tokio::test]
    async fn test_audio_transcribe_rejects_unsupported_format() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_format");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("test.txt"), "not audio").unwrap();
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace.clone(), None);

        let args = r#"{"path": "test.txt"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Unsupported audio")
        );
        let _ = std::fs::remove_dir_all(&workspace);
    }

    #[tokio::test]
    async fn test_audio_transcribe_file_not_found() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_notfound");
        let _ = std::fs::create_dir_all(&workspace);
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);

        let args = r#"{"path": "nonexistent.mp3"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }
}