1pub mod spawn_agent;
2pub mod ssrf;
3pub mod web_search;
4
5use anyhow::Result;
6use async_trait::async_trait;
7use once_cell::sync::Lazy;
8use readability::extractor;
9use regex::Regex;
10use serde_json::{Value, json};
11use std::fs;
12use std::io::Cursor;
13use std::path::PathBuf;
14use std::sync::Arc;
15use tracing::debug;
16
17use super::providers::ToolSchema;
18use crate::config::{Config, SearchProviderType};
19use crate::memory::MemoryManager;
20
21use spawn_agent::{SpawnAgentTool, SpawnContext};
22use web_search::{SearchRouter, WebSearchTool};
23
/// Outcome of a single tool invocation, tagged with the provider-assigned
/// call id so the result can be routed back to the originating tool call.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Id of the tool call this result answers (as issued by the provider).
    pub call_id: String,
    /// Raw textual output produced by the tool.
    pub output: String,
}
29
/// Trust level a tool requires before the agent may run it.
///
/// Ordered `Safe < Elevated < Admin`; explicit discriminants keep the
/// ordering stable, and serde serializes the lowercase variant names.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
#[serde(rename_all = "lowercase")]
pub enum PermissionLevel {
    /// Read-only / side-effect-free tools; the default for every tool.
    Safe = 0,
    /// Tools that can change state or have broader reach; need extra approval.
    Elevated = 1,
    /// Fully privileged tools.
    Admin = 2,
}
46
47impl Default for PermissionLevel {
48 fn default() -> Self {
49 Self::Safe
50 }
51}
52
53impl std::fmt::Display for PermissionLevel {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 Self::Safe => f.write_str("safe"),
57 Self::Elevated => f.write_str("elevated"),
58 Self::Admin => f.write_str("admin"),
59 }
60 }
61}
62
/// Common interface implemented by every agent tool.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Stable identifier used to route tool calls (e.g. "web_fetch").
    fn name(&self) -> &str;
    /// JSON-schema description of the tool advertised to the model provider.
    fn schema(&self) -> ToolSchema;
    /// Runs the tool with the raw JSON `arguments` string and returns its
    /// textual output (or an error the agent surfaces to the model).
    async fn execute(&self, arguments: &str) -> Result<String>;

    /// Trust level required to run this tool; defaults to `Safe`.
    fn permission_level(&self) -> PermissionLevel {
        PermissionLevel::Safe
    }

    /// Optional provider-specific annotations attached to the schema;
    /// `None` by default.
    fn annotations(&self) -> Option<Value> {
        None
    }
}
81
/// Assembles the default set of always-available ("safe") tools.
///
/// The exact set depends on `config`: memory search (indexed when a
/// [`MemoryManager`] is supplied, file-grep fallback otherwise), memory_get,
/// web_fetch, and — when configured — web search, document loading, wiki
/// claim tools, and audio transcription.
///
/// Returns an error only when a mandatory tool fails to initialize (filter
/// compilation or the web_fetch HTTP client); optional tools log a warning
/// and are skipped instead.
pub fn create_safe_tools(
    config: &Config,
    memory: Option<Arc<MemoryManager>>,
) -> Result<Vec<Box<dyn Tool>>> {
    use super::hardcoded_filters;
    use super::tool_filters::CompiledToolFilter;

    let workspace = config.workspace_path();

    // Prefer the indexed memory search whenever a manager is available.
    let memory_search_tool: Box<dyn Tool> = if let Some(ref mem) = memory {
        Box::new(MemorySearchToolWithIndex::new(Arc::clone(mem)))
    } else {
        Box::new(MemorySearchTool::new(workspace.clone()))
    };

    // User-configured web_fetch filter (permissive when absent), always
    // merged with the hardcoded deny rules so they cannot be disabled.
    let web_fetch_filter = config
        .tools
        .filters
        .get("web_fetch")
        .map(CompiledToolFilter::compile)
        .unwrap_or_else(|| Ok(CompiledToolFilter::permissive()))?
        .merge_hardcoded(
            hardcoded_filters::WEB_FETCH_DENY_SUBSTRINGS,
            hardcoded_filters::WEB_FETCH_DENY_PATTERNS,
        )?;

    let mut tools: Vec<Box<dyn Tool>> = vec![
        memory_search_tool,
        Box::new(MemoryGetTool::new(workspace.clone())),
        Box::new(WebFetchTool::new(
            config.tools.web_fetch_max_bytes,
            web_fetch_filter,
        )?),
    ];

    // Optional: web search, unless explicitly disabled via provider "none".
    if let Some(ref ws_config) = config.tools.web_search
        && !matches!(ws_config.provider, SearchProviderType::None)
    {
        match SearchRouter::from_config(ws_config) {
            Ok(router) => tools.push(Box::new(WebSearchTool::new(Arc::new(router)))),
            Err(e) => tracing::warn!("Web search init failed: {e}"),
        }
    }

    tools.push(Box::new(DocumentLoadTool::new(workspace, &config.tools)));

    // Optional: wiki claim tools; requires both the config flag and a memory
    // manager (the wiki store lives in the memory database).
    if config.memory.wiki_enabled
        && let Some(ref mem) = memory
    {
        match crate::memory::wiki::WikiStore::new(
            mem.db_path(),
            config.memory.wiki_fresh_days,
            config.memory.wiki_stale_days,
        ) {
            Ok(store) => {
                let store = Arc::new(store);
                tools.push(Box::new(WikiAddTool::new(Arc::clone(&store))));
                tools.push(Box::new(WikiSearchTool::new(Arc::clone(&store))));
                tools.push(Box::new(WikiStatusTool::new(store)));
            }
            Err(e) => tracing::warn!("Wiki store init failed: {e}"),
        }
    }

    // Optional: audio transcription, only when at least one STT provider can
    // resolve its credentials from the environment.
    if let Some(ref stt_config) = config.tools.stt {
        let env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
        let registry = crate::media::SttRegistry::from_config(stt_config, &env_vars);
        if registry.has_providers() {
            let audio_cache = if config.tools.media_cache_enabled {
                Some(crate::media::cache::MediaCache::new(
                    config.workspace_path().join(".cache").join("media"),
                    config.tools.media_cache_max_mb,
                ))
            } else {
                None
            };
            tools.push(Box::new(AudioTranscribeTool::new(
                Arc::new(registry),
                config.workspace_path(),
                audio_cache,
            )));
        } else {
            tracing::debug!("STT configured but no providers available (missing API keys?)");
        }
    }

    Ok(tools)
}
182
183pub fn create_spawn_agent_tool(config: Config, memory: Arc<MemoryManager>) -> Box<dyn Tool> {
195 Box::new(SpawnAgentTool::from_config(config, memory))
196}
197
198pub fn create_spawn_agent_tool_at_depth(
202 config: Config,
203 memory: Arc<MemoryManager>,
204 depth: u8,
205) -> Option<Box<dyn Tool>> {
206 let max_depth = config.agent.max_spawn_depth.unwrap_or(1);
207
208 if depth >= max_depth {
209 return None;
211 }
212
213 let tool = SpawnAgentTool::new(SpawnContext {
214 depth,
215 config,
216 memory,
217 model: None,
218 max_depth,
219 });
220
221 Some(Box::new(tool))
222}
223
/// Fallback memory search that greps `MEMORY.md` and `memory/*.md`
/// directly on disk; used when no indexed [`MemoryManager`] is available.
pub struct MemorySearchTool {
    /// Root of the agent workspace containing the memory files.
    workspace: PathBuf,
}

impl MemorySearchTool {
    /// Creates a searcher rooted at `workspace`.
    pub fn new(workspace: PathBuf) -> Self {
        Self { workspace }
    }
}
234
235#[async_trait]
236impl Tool for MemorySearchTool {
237 fn name(&self) -> &str {
238 "memory_search"
239 }
240
241 fn schema(&self) -> ToolSchema {
242 ToolSchema {
243 name: "memory_search".to_string(),
244 description: "Search the memory index for relevant information".to_string(),
245 parameters: json!({
246 "type": "object",
247 "properties": {
248 "query": {
249 "type": "string",
250 "description": "The search query"
251 },
252 "limit": {
253 "type": "integer",
254 "description": "Maximum number of results (default: 5)"
255 }
256 },
257 "required": ["query"]
258 }),
259 }
260 }
261
262 async fn execute(&self, arguments: &str) -> Result<String> {
263 let args: Value = serde_json::from_str(arguments)?;
264 let query = args["query"]
265 .as_str()
266 .ok_or_else(|| anyhow::anyhow!("Missing query"))?;
267 let limit = args["limit"].as_u64().unwrap_or(5) as usize;
268
269 debug!("Memory search: {} (limit: {})", query, limit);
270
271 let mut results = Vec::new();
274
275 let memory_file = self.workspace.join("MEMORY.md");
276 if memory_file.exists()
277 && let Ok(content) = fs::read_to_string(&memory_file)
278 {
279 for (i, line) in content.lines().enumerate() {
280 if line.to_lowercase().contains(&query.to_lowercase()) {
281 results.push(format!("MEMORY.md:{}: {}", i + 1, line));
282 if results.len() >= limit {
283 break;
284 }
285 }
286 }
287 }
288
289 let memory_dir = self.workspace.join("memory");
291 if memory_dir.exists()
292 && let Ok(entries) = fs::read_dir(&memory_dir)
293 {
294 for entry in entries.filter_map(|e| e.ok()) {
295 if results.len() >= limit {
296 break;
297 }
298
299 let path = entry.path();
300 if path.extension().map(|e| e == "md").unwrap_or(false)
301 && let Ok(content) = fs::read_to_string(&path)
302 {
303 let filename = path.file_name().unwrap().to_string_lossy();
304 for (i, line) in content.lines().enumerate() {
305 if line.to_lowercase().contains(&query.to_lowercase()) {
306 results.push(format!("memory/{}:{}: {}", filename, i + 1, line));
307 if results.len() >= limit {
308 break;
309 }
310 }
311 }
312 }
313 }
314 }
315
316 if results.is_empty() {
317 Ok("No results found".to_string())
318 } else {
319 Ok(results.join("\n"))
320 }
321 }
322}
323
/// Memory search backed by the indexed [`MemoryManager`] (FTS, plus hybrid
/// semantic search when embeddings are available).
pub struct MemorySearchToolWithIndex {
    /// Shared handle to the memory index.
    memory: Arc<MemoryManager>,
}

impl MemorySearchToolWithIndex {
    /// Creates a searcher over the given memory manager.
    pub fn new(memory: Arc<MemoryManager>) -> Self {
        Self { memory }
    }
}
334
335#[async_trait]
336impl Tool for MemorySearchToolWithIndex {
337 fn name(&self) -> &str {
338 "memory_search"
339 }
340
341 fn schema(&self) -> ToolSchema {
342 let description = if self.memory.has_embeddings() {
343 "Search the memory index using hybrid semantic + keyword search for relevant information"
344 } else {
345 "Search the memory index for relevant information"
346 };
347
348 ToolSchema {
349 name: "memory_search".to_string(),
350 description: description.to_string(),
351 parameters: json!({
352 "type": "object",
353 "properties": {
354 "query": {
355 "type": "string",
356 "description": "The search query"
357 },
358 "limit": {
359 "type": "integer",
360 "description": "Maximum number of results (default: 5)"
361 }
362 },
363 "required": ["query"]
364 }),
365 }
366 }
367
368 async fn execute(&self, arguments: &str) -> Result<String> {
369 let args: Value = serde_json::from_str(arguments)?;
370 let query = args["query"]
371 .as_str()
372 .ok_or_else(|| anyhow::anyhow!("Missing query"))?;
373 let limit = args["limit"].as_u64().unwrap_or(5) as usize;
374
375 let search_type = if self.memory.has_embeddings() {
376 "hybrid"
377 } else {
378 "FTS"
379 };
380 debug!(
381 "Memory search ({}): {} (limit: {})",
382 search_type, query, limit
383 );
384
385 let results = self.memory.search(query, limit)?;
386
387 if results.is_empty() {
388 return Ok("No results found".to_string());
389 }
390
391 let formatted: Vec<String> = results
393 .iter()
394 .enumerate()
395 .map(|(i, chunk)| {
396 let preview: String = chunk.content.chars().take(200).collect();
397 let preview = preview.replace('\n', " ");
398 format!(
399 "{}. [{}:{}-{}] (score: {:.3})\n {}{}",
400 i + 1,
401 chunk.file,
402 chunk.line_start,
403 chunk.line_end,
404 chunk.score,
405 preview,
406 if chunk.content.len() > 200 { "..." } else { "" }
407 )
408 })
409 .collect();
410
411 Ok(formatted.join("\n\n"))
412 }
413}
414
/// Workspace-confined file reader used to pull specific line ranges out of
/// memory files after a `memory_search` hit.
pub struct MemoryGetTool {
    /// Root of the agent workspace; reads are restricted to this tree.
    workspace: PathBuf,
}
419
420impl MemoryGetTool {
421 pub fn new(workspace: PathBuf) -> Self {
422 Self { workspace }
423 }
424
425 fn resolve_path(&self, path: &str) -> PathBuf {
426 if path.starts_with("memory/") || path == "MEMORY.md" || path == "HEARTBEAT.md" {
428 self.workspace.join(path)
429 } else {
430 PathBuf::from(shellexpand::tilde(path).to_string())
431 }
432 }
433
434 fn is_within_workspace(&self, resolved: &std::path::Path) -> bool {
437 let workspace_canonical = match self.workspace.canonicalize() {
438 Ok(p) => p,
439 Err(_) => return false,
440 };
441 if let Ok(canonical) = resolved.canonicalize() {
443 return canonical.starts_with(&workspace_canonical);
444 }
445 if let Some(parent) = resolved.parent()
447 && let Ok(parent_canonical) = parent.canonicalize()
448 {
449 return parent_canonical.starts_with(&workspace_canonical);
450 }
451 false
452 }
453}
454
#[async_trait]
impl Tool for MemoryGetTool {
    fn name(&self) -> &str {
        "memory_get"
    }

    fn schema(&self) -> ToolSchema {
        ToolSchema {
            name: "memory_get".to_string(),
            description: "Safe snippet read from MEMORY.md or memory/*.md with optional line range; use after memory_search to pull only the needed lines and keep context small.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Path to the file (e.g., 'MEMORY.md' or 'memory/2024-01-15.md')"
                    },
                    "from": {
                        "type": "integer",
                        "description": "Starting line number (1-indexed, default: 1)"
                    },
                    "lines": {
                        "type": "integer",
                        "description": "Number of lines to read (default: 50)"
                    }
                },
                "required": ["path"]
            }),
        }
    }

    /// Reads a clamped line range from a memory file after path-safety
    /// checks (null bytes, `..` traversal, workspace containment).
    async fn execute(&self, arguments: &str) -> Result<String> {
        let args: Value = serde_json::from_str(arguments)?;
        let path = args["path"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("Missing path"))?;

        if path.contains('\0') {
            anyhow::bail!("Invalid path: null bytes not allowed");
        }

        // Clamp the request: 1-indexed start, at most 10k lines per call.
        let from = args["from"].as_u64().unwrap_or(1).max(1) as usize;
        let lines_count = (args["lines"].as_u64().unwrap_or(50) as usize).min(10_000);

        let resolved_path = self.resolve_path(path);

        // Reject any `..` component even after resolution.
        if resolved_path
            .components()
            .any(|c| matches!(c, std::path::Component::ParentDir))
        {
            anyhow::bail!("Invalid path: path traversal not allowed");
        }

        if !self.is_within_workspace(&resolved_path) {
            anyhow::bail!("Access denied: path is outside workspace");
        }

        debug!(
            "Memory get: {} (from: {}, lines: {})",
            resolved_path.display(),
            from,
            lines_count
        );

        // A missing file is reported as a normal result, not an error.
        if !resolved_path.exists() {
            return Ok(format!("File not found: {}", path));
        }

        let content = fs::read_to_string(&resolved_path)?;
        let lines: Vec<&str> = content.lines().collect();
        let total_lines = lines.len();

        // Convert to a 0-indexed, bounds-clamped [start, end) window.
        let start = (from - 1).min(total_lines);
        let end = (start + lines_count).min(total_lines);

        if start >= total_lines {
            return Ok(format!(
                "Line {} is past end of file ({} lines)",
                from, total_lines
            ));
        }

        // Emit numbered lines so follow-up requests can reference them.
        let selected: Vec<String> = lines[start..end]
            .iter()
            .enumerate()
            .map(|(i, line)| format!("{:4}\t{}", start + i + 1, line))
            .collect();

        let header = format!(
            "# {} (lines {}-{} of {})\n",
            path,
            start + 1,
            end,
            total_lines
        );
        Ok(header + &selected.join("\n"))
    }
}
557
/// Extracts text from documents (PDF, DOCX, EPUB, HTML, ...) via the
/// configured loader registry, with size limits and an optional disk cache.
pub struct DocumentLoadTool {
    /// Registry mapping file extensions to text extractors.
    loaders: crate::media::DocumentLoaders,
    /// Workspace root used to resolve relative paths.
    workspace: PathBuf,
    /// Maximum input file size accepted, in bytes.
    max_bytes: usize,
    /// Maximum output length before truncation; 0 disables truncation.
    output_max_chars: usize,
    /// Optional on-disk cache of previous extractions.
    cache: Option<crate::media::cache::MediaCache>,
}
566
567impl DocumentLoadTool {
568 pub fn new(workspace: PathBuf, config: &crate::config::ToolsConfig) -> Self {
569 let loaders = match config.document_loaders {
570 Some(ref custom) => crate::media::DocumentLoaders::with_custom(custom),
571 None => crate::media::DocumentLoaders::new(),
572 };
573 let cache = if config.media_cache_enabled {
574 Some(crate::media::cache::MediaCache::new(
575 workspace.join(".cache").join("media"),
576 config.media_cache_max_mb,
577 ))
578 } else {
579 None
580 };
581 Self {
582 loaders,
583 workspace,
584 max_bytes: config.document_max_bytes,
585 output_max_chars: config.tool_output_max_chars,
586 cache,
587 }
588 }
589
590 fn validate_path(&self, path_str: &str) -> Result<PathBuf> {
591 if path_str.contains('\0') {
592 anyhow::bail!("Invalid path: null bytes not allowed");
593 }
594 let expanded = shellexpand::tilde(path_str).to_string();
595 let resolved = if std::path::Path::new(&expanded).is_absolute() {
596 PathBuf::from(expanded)
597 } else {
598 self.workspace.join(expanded)
599 };
600 if resolved
601 .components()
602 .any(|c| matches!(c, std::path::Component::ParentDir))
603 {
604 anyhow::bail!("Invalid path: path traversal not allowed");
605 }
606 Ok(resolved)
607 }
608}
609
610#[async_trait]
611impl Tool for DocumentLoadTool {
612 fn name(&self) -> &str {
613 "document_load"
614 }
615
616 fn schema(&self) -> ToolSchema {
617 ToolSchema {
618 name: "document_load".to_string(),
619 description: "Extract text content from PDF, DOCX, EPUB, or HTML documents. Returns the document text.".to_string(),
620 parameters: json!({
621 "type": "object",
622 "properties": {
623 "path": {
624 "type": "string",
625 "description": "Path to the document file (relative to workspace or absolute)"
626 }
627 },
628 "required": ["path"]
629 }),
630 }
631 }
632
633 async fn execute(&self, arguments: &str) -> Result<String> {
634 let args: Value = serde_json::from_str(arguments)?;
635 let path_str = args["path"]
636 .as_str()
637 .ok_or_else(|| anyhow::anyhow!("Missing path"))?;
638
639 let resolved = self.validate_path(path_str)?;
640
641 if !resolved.exists() {
642 anyhow::bail!("File not found: {}", path_str);
643 }
644
645 let metadata = fs::metadata(&resolved)?;
646 if metadata.len() as usize > self.max_bytes {
647 anyhow::bail!(
648 "File too large: {} bytes (max: {} bytes / {}MB)",
649 metadata.len(),
650 self.max_bytes,
651 self.max_bytes / 1_048_576
652 );
653 }
654
655 let ext = resolved.extension().and_then(|e| e.to_str()).unwrap_or("");
656 if !self.loaders.has_loader(ext) {
657 let supported = self.loaders.supported_extensions().join(", ");
658 anyhow::bail!("Unsupported format: .{}. Supported: {}", ext, supported);
659 }
660
661 if let Some(ref cache) = self.cache
663 && let Some(cached) = cache.get(&resolved)
664 {
665 return Ok(cached);
666 }
667
668 debug!("Loading document: {} ({})", resolved.display(), ext);
669 let text = self.loaders.extract_text(&resolved)?;
670
671 if let Some(ref cache) = self.cache {
672 let _ = cache.put(&resolved, &text);
673 }
674
675 if self.output_max_chars > 0 && text.len() > self.output_max_chars {
676 let truncated = truncate_on_char_boundary(&text, self.output_max_chars);
677 Ok(format!(
678 "{}\n\n[Truncated, {} chars total]",
679 truncated,
680 text.len()
681 ))
682 } else {
683 Ok(text)
684 }
685 }
686}
687
/// Transcribes audio files to text through the configured STT providers,
/// with an optional on-disk transcript cache.
pub struct AudioTranscribeTool {
    /// Registry of available speech-to-text providers.
    registry: Arc<crate::media::SttRegistry>,
    /// Workspace root used to resolve relative paths.
    workspace: PathBuf,
    /// Optional on-disk cache of previous transcriptions.
    cache: Option<crate::media::cache::MediaCache>,
}
694
695impl AudioTranscribeTool {
696 pub fn new(
697 registry: Arc<crate::media::SttRegistry>,
698 workspace: PathBuf,
699 cache: Option<crate::media::cache::MediaCache>,
700 ) -> Self {
701 Self {
702 registry,
703 workspace,
704 cache,
705 }
706 }
707
708 fn validate_path(&self, path_str: &str) -> Result<PathBuf> {
709 if path_str.contains('\0') {
710 anyhow::bail!("Invalid path: null bytes not allowed");
711 }
712 let expanded = shellexpand::tilde(path_str).to_string();
713 let resolved = if std::path::Path::new(&expanded).is_absolute() {
714 PathBuf::from(expanded)
715 } else {
716 self.workspace.join(expanded)
717 };
718 if resolved
719 .components()
720 .any(|c| matches!(c, std::path::Component::ParentDir))
721 {
722 anyhow::bail!("Invalid path: path traversal not allowed");
723 }
724 Ok(resolved)
725 }
726}
727
728#[async_trait]
729impl Tool for AudioTranscribeTool {
730 fn name(&self) -> &str {
731 "transcribe_audio"
732 }
733
734 fn schema(&self) -> ToolSchema {
735 ToolSchema {
736 name: "transcribe_audio".to_string(),
737 description: "Transcribe audio files (MP3, M4A, WAV, OGG, FLAC, WEBM) to text using speech-to-text.".to_string(),
738 parameters: json!({
739 "type": "object",
740 "properties": {
741 "path": {
742 "type": "string",
743 "description": "Path to the audio file"
744 },
745 "language": {
746 "type": "string",
747 "description": "Language hint (ISO 639-1, e.g., 'en', 'zh', 'ja'). Default: 'en'"
748 }
749 },
750 "required": ["path"]
751 }),
752 }
753 }
754
755 async fn execute(&self, arguments: &str) -> Result<String> {
756 let args: Value = serde_json::from_str(arguments)?;
757 let path_str = args["path"]
758 .as_str()
759 .ok_or_else(|| anyhow::anyhow!("Missing path"))?;
760
761 let resolved = self.validate_path(path_str)?;
762
763 if !resolved.exists() {
764 anyhow::bail!("File not found: {}", path_str);
765 }
766
767 let mime_type = crate::media::audio::mime_type_from_path(&resolved);
768 if mime_type == "audio/octet-stream" {
769 let ext = resolved.extension().and_then(|e| e.to_str()).unwrap_or("?");
770 anyhow::bail!(
771 "Unsupported audio format: .{}. Supported: ogg, opus, mp3, m4a, wav, webm, flac",
772 ext
773 );
774 }
775
776 if let Some(ref cache) = self.cache
778 && let Some(cached) = cache.get(&resolved)
779 {
780 return Ok(cached);
781 }
782
783 let audio_data = fs::read(&resolved)?;
784 debug!(
785 "Transcribing audio: {} ({} bytes, {})",
786 resolved.display(),
787 audio_data.len(),
788 mime_type
789 );
790
791 let text = self.registry.transcribe(&audio_data, mime_type).await?;
792
793 if let Some(ref cache) = self.cache {
794 let _ = cache.put(&resolved, &text);
795 }
796
797 Ok(text)
798 }
799}
800
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 char boundary, so slicing can never panic inside a
/// multibyte character.
///
/// Rewritten on stable APIs: the previous `str::floor_char_boundary` call is
/// a nightly-only unstable feature (`round_char_boundary`).
fn truncate_on_char_boundary(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Walk backwards (at most 3 steps for UTF-8) to the nearest boundary.
    let mut end = max_bytes;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    &s[..end]
}
804
/// Single entry point for web_fetch URL validation; delegates to
/// [`ssrf::validate_url`] so the initial request and every redirect hop go
/// through the same SSRF checks.
async fn validate_web_fetch_url(url: &str) -> Result<reqwest::Url> {
    ssrf::validate_url(url).await
}
810
/// Upper bound on redirect hops web_fetch will follow before aborting.
const MAX_WEB_FETCH_REDIRECTS: usize = 10;
812
813fn should_follow_redirect(status: reqwest::StatusCode) -> bool {
814 matches!(
815 status,
816 reqwest::StatusCode::MOVED_PERMANENTLY
817 | reqwest::StatusCode::FOUND
818 | reqwest::StatusCode::SEE_OTHER
819 | reqwest::StatusCode::TEMPORARY_REDIRECT
820 | reqwest::StatusCode::PERMANENT_REDIRECT
821 )
822}
823
824async fn resolve_and_validate_redirect_target(
825 current: &reqwest::Url,
826 location: &str,
827) -> Result<reqwest::Url> {
828 let candidate = current
829 .join(location)
830 .map_err(|e| anyhow::anyhow!("Invalid redirect target '{}': {}", location, e))?;
831 validate_web_fetch_url(candidate.as_str()).await
832}
833
834fn extract_fallback_text(html: &str) -> String {
835 static SCRIPT_RE: Lazy<Regex> =
836 Lazy::new(|| Regex::new(r"(?is)<script[^>]*>.*?</script>").expect("valid script regex"));
837 static STYLE_RE: Lazy<Regex> =
838 Lazy::new(|| Regex::new(r"(?is)<style[^>]*>.*?</style>").expect("valid style regex"));
839 static TAG_RE: Lazy<Regex> =
840 Lazy::new(|| Regex::new(r"(?is)<[^>]+>").expect("valid tag regex"));
841 static WS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\s+").expect("valid whitespace regex"));
842
843 let no_scripts = SCRIPT_RE.replace_all(html, " ");
844 let no_styles = STYLE_RE.replace_all(&no_scripts, " ");
845 let no_tags = TAG_RE.replace_all(&no_styles, " ");
846 WS_RE.replace_all(no_tags.trim(), " ").to_string()
847}
848
849fn extract_readable_text(html: &str, url: &reqwest::Url) -> String {
850 let mut cursor = Cursor::new(html.as_bytes());
851 match extractor::extract(&mut cursor, url) {
852 Ok(product) => {
853 let text = product.text.trim();
854 if text.is_empty() {
855 return extract_fallback_text(html);
856 }
857
858 let title = product.title.trim();
859 if title.is_empty() {
860 text.to_string()
861 } else {
862 format!("# {}\n\n{}", title, text)
863 }
864 }
865 Err(e) => {
866 debug!("Readability extraction failed for {}: {}", url, e);
867 extract_fallback_text(html)
868 }
869 }
870}
871
/// Fetches URLs with SSRF validation, manual redirect handling, size
/// limits, and readable-text extraction for HTML.
pub struct WebFetchTool {
    /// HTTP client built with automatic redirects disabled (each hop is
    /// validated manually).
    client: reqwest::Client,
    /// Maximum bytes of extracted text returned to the model.
    max_bytes: usize,
    /// Compiled allow/deny filter applied to the requested URL.
    filter: super::tool_filters::CompiledToolFilter,
}
878
impl WebFetchTool {
    /// Builds the tool with a client whose automatic redirects are DISABLED
    /// so that every hop can be SSRF-validated by hand.
    pub fn new(max_bytes: usize, filter: super::tool_filters::CompiledToolFilter) -> Result<Self> {
        let client = reqwest::Client::builder()
            .redirect(reqwest::redirect::Policy::none())
            .build()?;

        Ok(Self {
            client,
            max_bytes,
            filter,
        })
    }

    /// Fetches `current_url`, following up to [`MAX_WEB_FETCH_REDIRECTS`]
    /// redirects while re-validating EVERY redirect target against the SSRF
    /// rules before requesting it. Returns the final response together with
    /// the URL it was actually served from.
    async fn fetch_with_validated_redirects(
        &self,
        mut current_url: reqwest::Url,
    ) -> Result<(reqwest::Response, reqwest::Url)> {
        for redirect_count in 0..=MAX_WEB_FETCH_REDIRECTS {
            let response = self
                .client
                .get(current_url.clone())
                .header("User-Agent", "LocalGPT/0.1")
                .send()
                .await?;

            // Any non-redirect status is the final answer.
            if !should_follow_redirect(response.status()) {
                return Ok((response, current_url));
            }

            if redirect_count == MAX_WEB_FETCH_REDIRECTS {
                anyhow::bail!(
                    "Too many redirects (>{}) while fetching {}",
                    MAX_WEB_FETCH_REDIRECTS,
                    current_url
                );
            }

            let location = response
                .headers()
                .get(reqwest::header::LOCATION)
                .ok_or_else(|| {
                    anyhow::anyhow!(
                        "Redirect response {} missing Location header",
                        response.status()
                    )
                })?
                .to_str()
                .map_err(|_| anyhow::anyhow!("Redirect Location header is not valid UTF-8"))?;

            // Resolve relative Locations against the current URL, then
            // re-run SSRF validation on the absolute target before fetching.
            let next_url = resolve_and_validate_redirect_target(&current_url, location).await?;
            debug!(
                "Following redirect {}: {} -> {}",
                redirect_count + 1,
                current_url,
                next_url
            );
            current_url = next_url;
        }

        unreachable!("redirect loop should return or bail")
    }
}
941
#[async_trait]
impl Tool for WebFetchTool {
    fn name(&self) -> &str {
        "web_fetch"
    }

    fn schema(&self) -> ToolSchema {
        ToolSchema {
            name: "web_fetch".to_string(),
            description: "Fetch content from a URL".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "The URL to fetch"
                    }
                },
                "required": ["url"]
            }),
        }
    }

    /// Fetches a URL with filter + SSRF validation on every hop, extracts
    /// readable text from HTML, and truncates the output to `max_bytes`.
    async fn execute(&self, arguments: &str) -> Result<String> {
        let args: Value = serde_json::from_str(arguments)?;
        let url = args["url"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("Missing url"))?;

        // Configured + hardcoded deny rules run before any network activity.
        self.filter.check(url, "web_fetch", "url")?;

        let parsed_url = validate_web_fetch_url(url).await?;
        debug!("Fetching URL: {}", parsed_url);

        let (response, final_url) = self.fetch_with_validated_redirects(parsed_url).await?;

        let status = response.status();
        let content_type = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_string();

        // Allow downloading up to twice the output budget: HTML markup is
        // stripped below, so the extracted text can still fit in max_bytes.
        let download_limit = self.max_bytes * 2;

        // Fast reject when the server declares an oversized body up front.
        if let Some(content_length) = response.content_length()
            && content_length as usize > download_limit
        {
            anyhow::bail!(
                "Response too large ({} bytes, limit {})",
                content_length,
                download_limit
            );
        }

        // Stream the body and enforce the limit even when Content-Length was
        // absent or understated.
        let mut body_bytes = Vec::new();
        let mut stream = response.bytes_stream();
        use futures::StreamExt;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            body_bytes.extend_from_slice(&chunk);
            if body_bytes.len() > download_limit {
                anyhow::bail!(
                    "Response too large (>{} bytes), download aborted",
                    download_limit
                );
            }
        }
        let body = String::from_utf8_lossy(&body_bytes).to_string();
        // HTML gets readability extraction; other content is passed through.
        let extracted =
            if content_type.contains("text/html") || content_type.contains("application/xhtml") {
                extract_readable_text(&body, &final_url)
            } else {
                body
            };

        let truncated = if extracted.len() > self.max_bytes {
            let prefix = truncate_on_char_boundary(&extracted, self.max_bytes);
            format!(
                "{}...\n\n[Truncated, {} bytes total]",
                prefix,
                extracted.len()
            )
        } else {
            extracted
        };

        Ok(format!(
            "Status: {}\nURL: {}\nContent-Type: {}\n\n{}",
            status, final_url, content_type, truncated
        ))
    }
}
1042
1043pub fn extract_tool_detail(tool_name: &str, arguments: &str) -> Option<String> {
1046 let args: Value = serde_json::from_str(arguments).ok()?;
1047
1048 match tool_name {
1049 "edit_file" | "write_file" | "read_file" | "replace" => args
1050 .get("path")
1051 .or_else(|| args.get("file_path"))
1052 .and_then(|v| v.as_str())
1053 .map(|s| s.to_string()),
1054 "bash" | "run_shell_command" => args.get("command").and_then(|v| v.as_str()).map(|s| {
1055 if s.len() > 60 {
1056 format!("{}...", &s[..57])
1057 } else {
1058 s.to_string()
1059 }
1060 }),
1061 "memory_search" => args
1062 .get("query")
1063 .and_then(|v| v.as_str())
1064 .map(|s| format!("\"{}\"", s)),
1065 "web_fetch" => args
1066 .get("url")
1067 .or_else(|| args.get("prompt"))
1068 .and_then(|v| v.as_str())
1069 .map(|s| s.to_string()),
1070 "web_search" | "google_web_search" => args
1071 .get("query")
1072 .and_then(|v| v.as_str())
1073 .map(|s| format!("\"{}\"", s)),
1074 "grep_search" | "glob" => args
1075 .get("pattern")
1076 .and_then(|v| v.as_str())
1077 .map(|s| format!("\"{}\"", s)),
1078 "list_directory" => args
1079 .get("dir_path")
1080 .and_then(|v| v.as_str())
1081 .map(|s| s.to_string()),
1082 "codebase_investigator" => args
1083 .get("objective")
1084 .and_then(|v| v.as_str())
1085 .map(|s| s.to_string()),
1086 "document_load" => args
1087 .get("path")
1088 .and_then(|v| v.as_str())
1089 .map(|s| s.to_string()),
1090 "transcribe_audio" => args
1091 .get("path")
1092 .and_then(|v| v.as_str())
1093 .map(|s| s.to_string()),
1094
1095 "gen_spawn_primitive" => {
1097 let name = args.get("name").and_then(|v| v.as_str());
1098 let shape = args.get("shape").and_then(|v| v.as_str()).unwrap_or("?");
1099 name.map(|n| format!("{} ({})", n, shape))
1100 }
1101 "gen_spawn_batch" => args
1102 .get("entities")
1103 .and_then(|v| v.as_array())
1104 .map(|arr| format!("{} entities", arr.len())),
1105 "gen_modify_batch" => args
1106 .get("entities")
1107 .and_then(|v| v.as_array())
1108 .map(|arr| format!("{} entities", arr.len())),
1109 "gen_delete_batch" => args
1110 .get("names")
1111 .and_then(|v| v.as_array())
1112 .map(|arr| format!("{} entities", arr.len())),
1113 "gen_spawn_mesh" => args
1114 .get("name")
1115 .and_then(|v| v.as_str())
1116 .map(|s| s.to_string()),
1117 "gen_modify_entity" => args
1118 .get("name")
1119 .and_then(|v| v.as_str())
1120 .map(|s| s.to_string()),
1121 "gen_delete_entity" => args
1122 .get("name")
1123 .and_then(|v| v.as_str())
1124 .map(|s| s.to_string()),
1125 "gen_entity_info" => args
1126 .get("name")
1127 .and_then(|v| v.as_str())
1128 .map(|s| s.to_string()),
1129 "gen_set_light" => args
1130 .get("name")
1131 .and_then(|v| v.as_str())
1132 .map(|s| s.to_string()),
1133 "gen_load_gltf" => args
1134 .get("path")
1135 .and_then(|v| v.as_str())
1136 .map(|s| s.to_string()),
1137 "gen_export_screenshot" => args
1138 .get("path")
1139 .and_then(|v| v.as_str())
1140 .map(|s| s.to_string()),
1141 "gen_export_gltf" => args
1142 .get("path")
1143 .and_then(|v| v.as_str())
1144 .map(|s| s.to_string()),
1145 "gen_save_world" => args
1146 .get("name")
1147 .and_then(|v| v.as_str())
1148 .map(|s| format!("'{}'", s)),
1149 "gen_load_world" => args
1150 .get("path")
1151 .and_then(|v| v.as_str())
1152 .map(|s| s.to_string()),
1153 "gen_export_world" => args
1154 .get("format")
1155 .and_then(|v| v.as_str())
1156 .map(|f| format!("format: {}", f)),
1157
1158 "gen_audio_emitter" => args
1160 .get("name")
1161 .and_then(|v| v.as_str())
1162 .map(|s| s.to_string()),
1163 "gen_modify_audio" => args
1164 .get("name")
1165 .and_then(|v| v.as_str())
1166 .map(|s| s.to_string()),
1167
1168 "gen_add_behavior" => {
1170 let entity = args.get("entity").and_then(|v| v.as_str());
1171 let behavior_type = args
1172 .get("behavior")
1173 .and_then(|b| b.get("type"))
1174 .and_then(|v| v.as_str());
1175 match (entity, behavior_type) {
1176 (Some(e), Some(t)) => Some(format!("{} [{}]", e, t)),
1177 (Some(e), None) => Some(e.to_string()),
1178 _ => None,
1179 }
1180 }
1181 "gen_remove_behavior" => args
1182 .get("entity")
1183 .and_then(|v| v.as_str())
1184 .map(|s| s.to_string()),
1185 "gen_list_behaviors" => args
1186 .get("entity")
1187 .and_then(|v| v.as_str())
1188 .map(|s| s.to_string()),
1189
1190 "gen_scene_info"
1192 | "gen_screenshot"
1193 | "gen_set_camera"
1194 | "gen_set_environment"
1195 | "gen_set_ambience"
1196 | "gen_audio_info"
1197 | "gen_pause_behaviors"
1198 | "gen_clear_scene" => None,
1199
1200 _ => None,
1201 }
1202}
1203
/// Stores structured knowledge claims (with optional evidence) in the wiki
/// store; similar claims are deduplicated by the store itself.
pub struct WikiAddTool {
    /// Shared handle to the wiki claim store.
    store: Arc<crate::memory::wiki::WikiStore>,
}

impl WikiAddTool {
    /// Creates the tool over an existing wiki store.
    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
        Self { store }
    }
}
1216
1217#[async_trait]
1218impl Tool for WikiAddTool {
1219 fn name(&self) -> &str {
1220 "wiki_add"
1221 }
1222
1223 fn schema(&self) -> ToolSchema {
1224 ToolSchema {
1225 name: "wiki_add".to_string(),
1226 description: "Add or update a structured knowledge claim with optional evidence. Deduplicates similar claims automatically.".to_string(),
1227 parameters: json!({
1228 "type": "object",
1229 "properties": {
1230 "text": {
1231 "type": "string",
1232 "description": "The claim text"
1233 },
1234 "category": {
1235 "type": "string",
1236 "enum": ["fact", "preference", "decision", "question"],
1237 "description": "Claim category (default: fact)"
1238 },
1239 "confidence": {
1240 "type": "number",
1241 "description": "Confidence score 0.0-1.0 (default: 0.8)"
1242 },
1243 "evidence_source": {
1244 "type": "string",
1245 "description": "Source of evidence (file path, URL, session ID)"
1246 },
1247 "evidence_excerpt": {
1248 "type": "string",
1249 "description": "Relevant text excerpt from the source"
1250 }
1251 },
1252 "required": ["text"]
1253 }),
1254 }
1255 }
1256
1257 async fn execute(&self, arguments: &str) -> Result<String> {
1258 let args: Value = serde_json::from_str(arguments)?;
1259 let text = args["text"]
1260 .as_str()
1261 .ok_or_else(|| anyhow::anyhow!("Missing text"))?;
1262
1263 let category = args["category"]
1264 .as_str()
1265 .map(crate::memory::wiki::ClaimCategory::parse)
1266 .transpose()?
1267 .unwrap_or(crate::memory::wiki::ClaimCategory::Fact);
1268
1269 let confidence = args["confidence"].as_f64().unwrap_or(0.8) as f32;
1270 let evidence_source = args["evidence_source"].as_str();
1271 let evidence_excerpt = args["evidence_excerpt"].as_str();
1272
1273 let id = self.store.add_claim(
1274 text,
1275 category,
1276 confidence,
1277 evidence_source,
1278 evidence_excerpt,
1279 )?;
1280
1281 Ok(format!("Claim stored (id: {}, category: {})", id, category))
1282 }
1283}
1284
/// Tool that searches structured knowledge claims by text, category,
/// and freshness.
pub struct WikiSearchTool {
    // Shared claim store; the field is also read by the `Tool` impl, so
    // its name must stay stable.
    store: Arc<crate::memory::wiki::WikiStore>,
}

impl WikiSearchTool {
    /// Creates the tool backed by the given shared claim store.
    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
        Self { store }
    }
}
1295
#[async_trait]
impl Tool for WikiSearchTool {
    /// Stable tool identifier exposed to the model.
    fn name(&self) -> &str {
        "wiki_search"
    }

    /// JSON schema advertised to the model: `query` is required;
    /// `category`, `include_stale`, and `limit` are optional filters.
    fn schema(&self) -> ToolSchema {
        ToolSchema {
            name: "wiki_search".to_string(),
            description: "Search structured knowledge claims by text, category, or freshness."
                .to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search query"
                    },
                    "category": {
                        "type": "string",
                        "enum": ["fact", "preference", "decision", "question"],
                        "description": "Filter by category (optional)"
                    },
                    "include_stale": {
                        "type": "boolean",
                        "description": "Include stale claims (default: false)"
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum results (default: 10)"
                    }
                },
                "required": ["query"]
            }),
        }
    }

    /// Searches the claim store and renders a numbered, human-readable
    /// result list, or "No claims found" when nothing matches.
    async fn execute(&self, arguments: &str) -> Result<String> {
        let args: Value = serde_json::from_str(arguments)?;
        // `query` is the only required argument.
        let query = args["query"]
            .as_str()
            .ok_or_else(|| anyhow::anyhow!("Missing query"))?;

        // Optional category filter; an unknown category name is an error
        // (parse returns Err), not silently ignored.
        let category = args["category"]
            .as_str()
            .map(crate::memory::wiki::ClaimCategory::parse)
            .transpose()?;

        let include_stale = args["include_stale"].as_bool().unwrap_or(false);
        // NOTE(review): `limit` is taken as-is from the model with no upper
        // cap; presumably the store bounds its own result set — confirm.
        let limit = args["limit"].as_u64().unwrap_or(10) as usize;

        let claims = self.store.search(query, category, include_stale, limit)?;

        if claims.is_empty() {
            return Ok("No claims found".to_string());
        }

        // Render each claim as a 1-based numbered entry:
        //   "<n>. [<8-char id prefix>] (<category>, <status>, conf: <c>) <freshness>"
        // followed by the claim text and up to three evidence excerpts,
        // each excerpt truncated to 80 characters.
        let formatted: Vec<String> = claims
            .iter()
            .enumerate()
            .map(|(i, c)| {
                let freshness = self.store.freshness(c.updated_at);
                let evidence_summary = if c.evidence.is_empty() {
                    String::new()
                } else {
                    format!(
                        "\n Evidence ({}):\n{}",
                        c.evidence.len(),
                        c.evidence
                            .iter()
                            .take(3)
                            .map(|e| format!(
                                " - [{}] {}",
                                e.source,
                                // Truncate by chars, not bytes, so multi-byte
                                // UTF-8 excerpts never split a code point.
                                e.excerpt.chars().take(80).collect::<String>()
                            ))
                            .collect::<Vec<_>>()
                            .join("\n")
                    )
                };
                format!(
                    "{}. [{}] ({}, {}, conf: {:.1}) {freshness}\n {}{}",
                    i + 1,
                    c.id.chars().take(8).collect::<String>(),
                    c.category,
                    c.status,
                    c.confidence,
                    c.text,
                    evidence_summary,
                    freshness = freshness,
                )
            })
            .collect();

        // Entries are separated by a blank line.
        Ok(formatted.join("\n\n"))
    }
}
1393
/// Tool that reports a health overview of the knowledge base.
pub struct WikiStatusTool {
    // Shared claim store; the field is also read by the `Tool` impl, so
    // its name must stay stable.
    store: Arc<crate::memory::wiki::WikiStore>,
}

impl WikiStatusTool {
    /// Creates the tool backed by the given shared claim store.
    pub fn new(store: Arc<crate::memory::wiki::WikiStore>) -> Self {
        Self { store }
    }
}
1404
1405#[async_trait]
1406impl Tool for WikiStatusTool {
1407 fn name(&self) -> &str {
1408 "wiki_status"
1409 }
1410
1411 fn schema(&self) -> ToolSchema {
1412 ToolSchema {
1413 name: "wiki_status".to_string(),
1414 description: "Get knowledge base health overview: total claims, breakdown by category/status/freshness, top stale claims.".to_string(),
1415 parameters: json!({
1416 "type": "object",
1417 "properties": {},
1418 "required": []
1419 }),
1420 }
1421 }
1422
1423 async fn execute(&self, _arguments: &str) -> Result<String> {
1424 let status = self.store.status()?;
1425
1426 let mut out = format!(
1427 "## Knowledge Base Status\n\nTotal claims: {}\n",
1428 status.total_claims
1429 );
1430
1431 if !status.by_category.is_empty() {
1432 out.push_str("\n**By category:**\n");
1433 for (cat, count) in &status.by_category {
1434 out.push_str(&format!("- {}: {}\n", cat, count));
1435 }
1436 }
1437
1438 if !status.by_status.is_empty() {
1439 out.push_str("\n**By status:**\n");
1440 for (st, count) in &status.by_status {
1441 out.push_str(&format!("- {}: {}\n", st, count));
1442 }
1443 }
1444
1445 out.push_str("\n**By freshness:**\n");
1446 for (freshness, count) in &status.by_freshness {
1447 out.push_str(&format!("- {}: {}\n", freshness, count));
1448 }
1449
1450 if !status.top_stale.is_empty() {
1451 out.push_str("\n**Top stale claims:**\n");
1452 for c in &status.top_stale {
1453 out.push_str(&format!(
1454 "- [{}] {}\n",
1455 c.id.chars().take(8).collect::<String>(),
1456 c.text
1457 ));
1458 }
1459 }
1460
1461 Ok(out)
1462 }
1463}
1464
#[cfg(test)]
mod tests {
    use super::*;

    // extract_readable_text should keep body prose and strip
    // <script>/<style> content from the rendered text.
    #[test]
    fn test_extract_readable_text_removes_html() {
        let html = r#"
        <html><head><style>.x{display:none}</style></head>
        <body><script>alert(1)</script><h1>Title</h1><p>Hello <b>world</b>.</p></body></html>
        "#;
        let url = reqwest::Url::parse("https://example.com/test").unwrap();
        let text = extract_readable_text(html, &url);
        assert!(text.contains("Hello world"));
        assert!(!text.contains("alert(1)"));
    }

    // SSRF guard: a redirect from a public host to a loopback address
    // must be rejected with the private/reserved-IP error.
    #[tokio::test]
    async fn test_redirect_target_validation_blocks_private_ip() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let err = resolve_and_validate_redirect_target(&current, "http://127.0.0.1/admin").await;
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(
            msg.contains("private/reserved IP"),
            "expected SSRF block message, got: {msg}"
        );
    }

    // Relative redirect targets resolve against the current (public)
    // URL and are allowed through.
    #[tokio::test]
    async fn test_redirect_target_validation_allows_relative_public_ip_target() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let next = resolve_and_validate_redirect_target(&current, "/next")
            .await
            .unwrap();
        assert_eq!(next.as_str(), "https://93.184.216.34/next");
    }

    // Non-HTTP(S) schemes (e.g. file://) are never valid redirect targets.
    #[tokio::test]
    async fn test_redirect_target_validation_blocks_non_http_scheme() {
        let current = reqwest::Url::parse("https://93.184.216.34/start").unwrap();
        let err = resolve_and_validate_redirect_target(&current, "file:///etc/passwd").await;
        assert!(err.is_err());
        let msg = err.unwrap_err().to_string();
        assert!(msg.contains("Only http/https"));
    }

    // memory_get must refuse paths that escape the workspace via `..`.
    #[tokio::test]
    async fn test_memory_get_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_workspace");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = MemoryGetTool::new(workspace);

        let args = r#"{"path": "memory/../../../etc/passwd"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("path traversal"));
    }

    // Paths containing an embedded NUL are rejected. The raw string keeps
    // the \u0000 escape intact so serde_json decodes it to a real NUL byte.
    #[tokio::test]
    async fn test_memory_get_rejects_null_bytes() {
        let workspace = std::env::temp_dir().join("localgpt_test_workspace");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = MemoryGetTool::new(workspace);

        let args = r#"{"path": "memory/\u0000evil.md"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
    }

    // An absurdly large `lines` value must not blow up; the file content
    // is still returned successfully.
    #[tokio::test]
    async fn test_memory_get_caps_lines_parameter() {
        let workspace = std::env::temp_dir().join("localgpt_test_mg_lines");
        let _ = std::fs::create_dir_all(workspace.join("memory"));
        std::fs::write(workspace.join("MEMORY.md"), "line1\nline2\nline3\n").unwrap();
        let tool = MemoryGetTool::new(workspace.clone());

        let args = r#"{"path": "MEMORY.md", "lines": 999999999}"#;
        let result = tool.execute(args).await.unwrap();
        assert!(result.contains("line1"));
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // Default tools configuration shared by the document/audio tool tests.
    fn test_tools_config() -> crate::config::ToolsConfig {
        crate::config::ToolsConfig::default()
    }

    // Sanity-check the advertised schema: name matches and `path` is the
    // first required property.
    #[test]
    fn test_document_load_tool_schema() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_schema");
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());
        assert_eq!(tool.name(), "document_load");
        let schema = tool.schema();
        assert_eq!(schema.name, "document_load");
        let params = &schema.parameters;
        assert!(params["properties"]["path"].is_object());
        assert_eq!(params["required"][0], "path");
    }

    // document_load must refuse paths that escape the workspace via `..`.
    #[tokio::test]
    async fn test_document_load_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_traversal");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());

        let args = r#"{"path": "../../../etc/passwd"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("path traversal"));
    }

    // Unknown file extensions are rejected, and the error lists the
    // supported formats (e.g. pdf).
    #[tokio::test]
    async fn test_document_load_rejects_unsupported_format() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_format");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("test.xyz"), "content").unwrap();
        let tool = DocumentLoadTool::new(workspace.clone(), &test_tools_config());

        let args = r#"{"path": "test.xyz"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("Unsupported format"));
        assert!(msg.contains("pdf"));
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // Files over the configured byte cap are rejected: the cap is shrunk
    // to 50 bytes, below the 100-byte fixture.
    #[tokio::test]
    async fn test_document_load_rejects_too_large() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_large");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("big.pdf"), vec![0u8; 100]).unwrap();

        let mut config = test_tools_config();
        config.document_max_bytes = 50; let tool = DocumentLoadTool::new(workspace.clone(), &config);

        let args = r#"{"path": "big.pdf"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too large"));
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // A missing file yields a "not found" error, not a panic.
    #[tokio::test]
    async fn test_document_load_file_not_found() {
        let workspace = std::env::temp_dir().join("localgpt_test_doc_notfound");
        let _ = std::fs::create_dir_all(&workspace);
        let tool = DocumentLoadTool::new(workspace, &test_tools_config());

        let args = r#"{"path": "nonexistent.pdf"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    // Sanity-check the transcription tool's schema: name matches, `path`
    // is required, and `language` is an optional property.
    #[test]
    fn test_audio_transcribe_tool_schema() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_schema");
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);
        assert_eq!(tool.name(), "transcribe_audio");
        let schema = tool.schema();
        assert_eq!(schema.name, "transcribe_audio");
        let params = &schema.parameters;
        assert!(params["properties"]["path"].is_object());
        assert!(params["properties"]["language"].is_object());
        assert_eq!(params["required"][0], "path");
    }

    // transcribe_audio must refuse paths that escape the workspace.
    #[tokio::test]
    async fn test_audio_transcribe_rejects_path_traversal() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_traversal");
        let _ = std::fs::create_dir_all(&workspace);
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);

        let args = r#"{"path": "../../../etc/passwd.mp3"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("path traversal"));
    }

    // Non-audio extensions are rejected with an "Unsupported audio" error.
    #[tokio::test]
    async fn test_audio_transcribe_rejects_unsupported_format() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_format");
        let _ = std::fs::create_dir_all(&workspace);
        std::fs::write(workspace.join("test.txt"), "not audio").unwrap();
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace.clone(), None);

        let args = r#"{"path": "test.txt"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Unsupported audio")
        );
        let _ = std::fs::remove_dir_all(&workspace);
    }

    // A missing audio file yields a "not found" error, not a panic.
    #[tokio::test]
    async fn test_audio_transcribe_file_not_found() {
        let workspace = std::env::temp_dir().join("localgpt_test_audio_notfound");
        let _ = std::fs::create_dir_all(&workspace);
        let registry = Arc::new(crate::media::SttRegistry::new(
            crate::media::SttConfig::default(),
        ));
        let tool = AudioTranscribeTool::new(registry, workspace, None);

        let args = r#"{"path": "nonexistent.mp3"}"#;
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }
}