1use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use thiserror::Error;
23use tracing::{debug, info, instrument, warn};
24
25use crate::sandbox::Sandbox;
27
28#[derive(Error, Debug)]
35#[non_exhaustive]
36pub enum ToolError {
37 #[error("Tool '{0}' not found")]
38 NotFound(String),
39
40 #[error("Tool '{0}' execution failed: {1}")]
41 ExecutionFailed(String, String),
42
43 #[error("Invalid arguments for tool '{0}': {1}")]
44 InvalidArguments(String, String),
45
46 #[allow(dead_code)]
47 #[error("Policy denied: {0}")]
48 PolicyDenied(String),
49
50 #[allow(dead_code)]
51 #[error("Sandbox violation: {0}")]
52 SandboxViolation(String),
53
54 #[error("IO error: {0}")]
55 Io(#[from] std::io::Error),
56}
57
58pub type ToolResultValue<T> = std::result::Result<T, ToolError>;
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct JsonSchema {
65 #[serde(rename = "type")]
66 pub schema_type: String,
67 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub description: Option<String>,
69 #[serde(default, skip_serializing_if = "Option::is_none")]
70 pub properties: Option<HashMap<String, JsonSchema>>,
71 #[serde(default, skip_serializing_if = "Option::is_none")]
72 pub required: Option<Vec<String>>,
73 #[serde(default, skip_serializing_if = "Option::is_none")]
74 pub items: Option<Box<JsonSchema>>,
75 #[serde(default, skip_serializing_if = "Option::is_none")]
76 pub enum_values: Option<Vec<String>>,
77}
78
79impl JsonSchema {
80 pub fn string(description: &str) -> Self {
82 Self {
83 schema_type: "string".to_string(),
84 description: Some(description.to_string()),
85 properties: None,
86 required: None,
87 items: None,
88 enum_values: None,
89 }
90 }
91
92 pub fn object(properties: HashMap<String, JsonSchema>, required: Vec<String>) -> Self {
94 Self {
95 schema_type: "object".to_string(),
96 description: None,
97 properties: Some(properties),
98 required: Some(required),
99 items: None,
100 enum_values: None,
101 }
102 }
103
104 #[allow(dead_code)]
106 pub fn array(items: JsonSchema, description: &str) -> Self {
107 Self {
108 schema_type: "array".to_string(),
109 description: Some(description.to_string()),
110 properties: None,
111 required: None,
112 items: Some(Box::new(items)),
113 enum_values: None,
114 }
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
120pub struct ToolDefinition {
121 pub name: String,
123 pub description: String,
125 pub parameters: JsonSchema,
127 #[serde(default)]
129 pub requires_approval: bool,
130 #[serde(default)]
132 pub category: ToolCategory,
133}
134
135impl ToolDefinition {
136 #[allow(dead_code)]
139 pub fn to_openai_tool(&self) -> serde_json::Value {
140 serde_json::json!({
141 "type": "function",
142 "function": {
143 "name": self.name,
144 "description": self.description,
145 "parameters": self.parameters
146 }
147 })
148 }
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
156#[non_exhaustive]
157pub enum ToolCategory {
158 #[default]
159 General,
160 Shell,
161 FileSystem,
162 Network,
163 CodeAnalysis,
164 WebSearch,
165 Mcp,
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct ToolCall {
171 pub name: String,
173 pub arguments: serde_json::Value,
175 #[serde(default)]
177 pub id: Option<String>,
178}
179
180#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct ToolResult {
183 pub tool_name: String,
185 pub success: bool,
187 pub output: String,
189 #[serde(default, skip_serializing_if = "Option::is_none")]
191 pub error: Option<String>,
192 #[serde(default, skip_serializing_if = "Option::is_none")]
194 pub exit_code: Option<i32>,
195 #[serde(default, skip_serializing_if = "Option::is_none")]
197 pub duration_ms: Option<u64>,
198}
199
200#[async_trait::async_trait]
204pub trait ToolImpl: Send + Sync {
205 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult>;
207
208 fn definition(&self) -> &ToolDefinition;
210
211 fn name(&self) -> &str {
213 &self.definition().name
214 }
215}
216
217#[derive(Clone)]
221pub struct ToolRegistry {
222 tools: HashMap<String, Arc<dyn ToolImpl>>,
223}
224
225impl ToolRegistry {
226 pub fn new() -> Self {
228 Self {
229 tools: HashMap::new(),
230 }
231 }
232
233 pub fn register(&mut self, tool: Arc<dyn ToolImpl>) {
235 let name = tool.name().to_string();
236 info!(tool = %name, category = ?tool.definition().category, "Tool registered");
237 self.tools.insert(name, tool);
238 }
239
240 pub fn get(&self, name: &str) -> Option<&Arc<dyn ToolImpl>> {
242 self.tools.get(name)
243 }
244
245 #[allow(dead_code)]
247 pub fn has(&self, name: &str) -> bool {
248 self.tools.contains_key(name)
249 }
250
251 #[allow(dead_code)]
253 pub fn definitions(&self) -> Vec<ToolDefinition> {
254 self.tools
255 .values()
256 .map(|t| t.definition().clone())
257 .collect()
258 }
259
260 #[allow(dead_code)]
262 pub fn to_openai_tools(&self) -> Vec<serde_json::Value> {
263 self.tools
264 .values()
265 .map(|t| t.definition().to_openai_tool())
266 .collect()
267 }
268
269 #[allow(dead_code)]
271 pub fn len(&self) -> usize {
272 self.tools.len()
273 }
274
275 #[allow(dead_code)]
277 pub fn is_empty(&self) -> bool {
278 self.tools.is_empty()
279 }
280
281 #[instrument(skip(self), fields(tool = %call.name))]
283 pub async fn execute(&self, call: ToolCall) -> ToolResultValue<ToolResult> {
284 let start = std::time::Instant::now();
285
286 let tool = self
287 .get(&call.name)
288 .ok_or_else(|| ToolError::NotFound(call.name.clone()))?;
289
290 info!(tool = %call.name, "Executing tool call");
291 debug!(
292 tool = %call.name,
293 args = %call.arguments,
294 "Tool call arguments"
295 );
296
297 let mut result = tool.execute(call.arguments).await?;
298 result.duration_ms = Some(start.elapsed().as_millis() as u64);
299
300 if result.success {
301 info!(
302 tool = %call.name,
303 duration_ms = result.duration_ms.unwrap_or(0),
304 "Tool executed successfully"
305 );
306 debug!(
307 tool = %call.name,
308 output_len = result.output.len(),
309 "Tool result output"
310 );
311 } else {
312 warn!(
313 tool = %call.name,
314 error = %result.error.as_deref().unwrap_or("unknown"),
315 "Tool execution failed"
316 );
317 }
318
319 Ok(result)
320 }
321
322 pub fn with_default_tools() -> Self {
324 let mut registry = Self::new();
325 registry.register(Arc::new(ShellTool::new()));
326 registry.register(Arc::new(ReadFileTool::new()));
327 registry.register(Arc::new(WriteFileTool::new()));
328 registry.register(Arc::new(WebFetchTool::new()));
329 registry.register(Arc::new(WebSearchTool::new()));
330 registry
331 }
332
333 #[allow(dead_code)]
335 pub fn with_web_search_config(
336 endpoint: &str,
337 engine: &str,
338 max_results: usize,
339 fetch_content: bool,
340 ) -> Self {
341 let mut registry = Self::new();
342 registry.register(Arc::new(ShellTool::new()));
343 registry.register(Arc::new(ReadFileTool::new()));
344 registry.register(Arc::new(WriteFileTool::new()));
345 registry.register(Arc::new(WebFetchTool::new()));
346 registry.register(Arc::new(WebSearchTool::with_config(
347 endpoint.to_string(),
348 engine.to_string(),
349 max_results,
350 fetch_content,
351 )));
352 registry
353 }
354
355 pub fn with_config(config: &crate::config::Config) -> Self {
357 let mut registry = Self::new();
358 registry.register(Arc::new(ShellTool::new()));
359 registry.register(Arc::new(ReadFileTool::new()));
360 registry.register(Arc::new(WriteFileTool::new()));
361 registry.register(Arc::new(WebFetchTool::new()));
362 registry.register(Arc::new(WebSearchTool::with_config(
363 config.web_search.endpoint.clone(),
364 config.web_search.engine.clone(),
365 config.web_search.max_results,
366 config.web_search.fetch_content,
367 )));
368 registry
369 }
370}
371
372impl Default for ToolRegistry {
373 fn default() -> Self {
374 Self::with_default_tools()
375 }
376}
377
378pub struct ShellTool {
382 definition: ToolDefinition,
383 sandbox: Option<Sandbox>,
384}
385
386impl ShellTool {
387 pub fn new() -> Self {
388 Self::default()
389 }
390
391 #[allow(dead_code)]
392 pub fn new_with_sandbox(sandbox: Sandbox) -> Self {
393 Self {
394 sandbox: Some(sandbox),
395 ..Self::default()
396 }
397 }
398}
399
400impl Default for ShellTool {
401 fn default() -> Self {
402 let mut properties = HashMap::new();
403 properties.insert(
404 "command".to_string(),
405 JsonSchema::string("The shell command to execute"),
406 );
407 properties.insert(
408 "timeout_secs".to_string(),
409 JsonSchema {
410 schema_type: "integer".to_string(),
411 description: Some("Timeout in seconds (default: 30)".to_string()),
412 properties: None,
413 required: None,
414 items: None,
415 enum_values: None,
416 },
417 );
418 properties.insert(
419 "workdir".to_string(),
420 JsonSchema::string("Working directory (default: current)"),
421 );
422
423 Self {
424 definition: ToolDefinition {
425 name: "shell_exec".to_string(),
426 description: "Execute a shell command and return its output. Use for running scripts, compiling code, or any command-line operation. Runs in a sandboxed environment.".to_string(),
427 parameters: JsonSchema::object(
428 properties,
429 vec!["command".to_string()],
430 ),
431 requires_approval: true,
432 category: ToolCategory::Shell,
433 },
434 sandbox: None,
435 }
436 }
437}
438
439#[async_trait::async_trait]
440impl ToolImpl for ShellTool {
441 fn definition(&self) -> &ToolDefinition {
442 &self.definition
443 }
444
445 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
446 let command = args
447 .get("command")
448 .and_then(|v| v.as_str())
449 .ok_or_else(|| {
450 ToolError::InvalidArguments(
451 "shell_exec".to_string(),
452 "missing 'command' argument".to_string(),
453 )
454 })?;
455
456 let timeout_secs = args
457 .get("timeout_secs")
458 .and_then(|v| v.as_u64())
459 .unwrap_or(30);
460
461 let workdir = if let Some(sandbox) = &self.sandbox {
463 sandbox.workdir().to_string_lossy().to_string()
464 } else {
465 args.get("workdir")
466 .and_then(|v| v.as_str())
467 .map(|s| s.to_string())
468 .unwrap_or_else(|| {
469 std::env::current_dir()
470 .unwrap_or_default()
471 .to_string_lossy()
472 .to_string()
473 })
474 };
475
476 let result = run_shell_command(command, timeout_secs, Some(workdir)).await?;
478
479 Ok(result)
480 }
481}
482
483pub struct ReadFileTool {
485 definition: ToolDefinition,
486}
487
488impl ReadFileTool {
489 pub fn new() -> Self {
490 Self::default()
491 }
492}
493
494impl Default for ReadFileTool {
495 fn default() -> Self {
496 let mut properties = HashMap::new();
497 properties.insert(
498 "path".to_string(),
499 JsonSchema::string("Absolute path to the file to read"),
500 );
501 properties.insert(
502 "max_bytes".to_string(),
503 JsonSchema {
504 schema_type: "integer".to_string(),
505 description: Some("Maximum bytes to read (default: 65536)".to_string()),
506 properties: None,
507 required: None,
508 items: None,
509 enum_values: None,
510 },
511 );
512
513 Self {
514 definition: ToolDefinition {
515 name: "read_file".to_string(),
516 description: "Read the contents of a file from the filesystem. Returns the file content as text.".to_string(),
517 parameters: JsonSchema::object(
518 properties,
519 vec!["path".to_string()],
520 ),
521 requires_approval: false,
522 category: ToolCategory::FileSystem,
523 },
524 }
525 }
526}
527
528#[async_trait::async_trait]
529impl ToolImpl for ReadFileTool {
530 fn definition(&self) -> &ToolDefinition {
531 &self.definition
532 }
533
534 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
535 let path = args.get("path").and_then(|v| v.as_str()).ok_or_else(|| {
536 ToolError::InvalidArguments(
537 "read_file".to_string(),
538 "missing 'path' argument".to_string(),
539 )
540 })?;
541
542 let max_bytes = args
543 .get("max_bytes")
544 .and_then(|v| v.as_u64())
545 .unwrap_or(65536) as usize;
546
547 let content = tokio::fs::read_to_string(path).await.map_err(|e| {
548 ToolError::ExecutionFailed("read_file".to_string(), format!("Cannot read file: {}", e))
549 })?;
550
551 let truncated = if content.len() > max_bytes {
552 format!(
553 "{}...\n[truncated at {} bytes]",
554 &content[..max_bytes],
555 max_bytes
556 )
557 } else {
558 content
559 };
560
561 Ok(ToolResult {
562 tool_name: "read_file".to_string(),
563 success: true,
564 output: truncated,
565 error: None,
566 exit_code: None,
567 duration_ms: None,
568 })
569 }
570}
571
572pub struct WriteFileTool {
574 definition: ToolDefinition,
575}
576
577impl WriteFileTool {
578 pub fn new() -> Self {
579 Self::default()
580 }
581}
582
583impl Default for WriteFileTool {
584 fn default() -> Self {
585 let mut properties = HashMap::new();
586 properties.insert(
587 "path".to_string(),
588 JsonSchema::string("Absolute path to the file to write"),
589 );
590 properties.insert(
591 "content".to_string(),
592 JsonSchema::string("The content to write to the file"),
593 );
594 properties.insert(
595 "append".to_string(),
596 JsonSchema {
597 schema_type: "boolean".to_string(),
598 description: Some(
599 "If true, append instead of overwrite (default: false)".to_string(),
600 ),
601 properties: None,
602 required: None,
603 items: None,
604 enum_values: None,
605 },
606 );
607
608 Self {
609 definition: ToolDefinition {
610 name: "write_file".to_string(),
611 description: "Write content to a file. Creates parent directories if they don't exist. Can append to existing files.".to_string(),
612 parameters: JsonSchema::object(
613 properties,
614 vec!["path".to_string(), "content".to_string()],
615 ),
616 requires_approval: true,
617 category: ToolCategory::FileSystem,
618 },
619 }
620 }
621}
622
623#[async_trait::async_trait]
624impl ToolImpl for WriteFileTool {
625 fn definition(&self) -> &ToolDefinition {
626 &self.definition
627 }
628
629 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
630 let path = args.get("path").and_then(|v| v.as_str()).ok_or_else(|| {
631 ToolError::InvalidArguments(
632 "write_file".to_string(),
633 "missing 'path' argument".to_string(),
634 )
635 })?;
636
637 let content = args
638 .get("content")
639 .and_then(|v| v.as_str())
640 .ok_or_else(|| {
641 ToolError::InvalidArguments(
642 "write_file".to_string(),
643 "missing 'content' argument".to_string(),
644 )
645 })?;
646
647 let append = args
648 .get("append")
649 .and_then(|v| v.as_bool())
650 .unwrap_or(false);
651
652 if let Some(parent) = std::path::Path::new(path).parent() {
654 tokio::fs::create_dir_all(parent).await.map_err(|e| {
655 ToolError::ExecutionFailed(
656 "write_file".to_string(),
657 format!("Cannot create directories: {}", e),
658 )
659 })?;
660 }
661
662 if append {
663 let mut file = tokio::fs::OpenOptions::new()
664 .append(true)
665 .create(true)
666 .open(path)
667 .await
668 .map_err(|e| {
669 ToolError::ExecutionFailed(
670 "write_file".to_string(),
671 format!("Cannot open file for append: {}", e),
672 )
673 })?;
674 tokio::io::AsyncWriteExt::write_all(&mut file, content.as_bytes())
675 .await
676 .map_err(|e| {
677 ToolError::ExecutionFailed(
678 "write_file".to_string(),
679 format!("Cannot write to file: {}", e),
680 )
681 })?;
682 } else {
683 tokio::fs::write(path, content).await.map_err(|e| {
684 ToolError::ExecutionFailed(
685 "write_file".to_string(),
686 format!("Cannot write file: {}", e),
687 )
688 })?;
689 }
690
691 Ok(ToolResult {
692 tool_name: "write_file".to_string(),
693 success: true,
694 output: format!("Successfully wrote {} bytes to {}", content.len(), path),
695 error: None,
696 exit_code: None,
697 duration_ms: None,
698 })
699 }
700}
701
702pub struct WebFetchTool {
704 definition: ToolDefinition,
705}
706
707impl WebFetchTool {
708 pub fn new() -> Self {
709 Self::default()
710 }
711}
712
713impl Default for WebFetchTool {
714 fn default() -> Self {
715 let mut properties = HashMap::new();
716 properties.insert("url".to_string(), JsonSchema::string("The URL to fetch"));
717 properties.insert(
718 "max_bytes".to_string(),
719 JsonSchema {
720 schema_type: "integer".to_string(),
721 description: Some("Maximum bytes to read (default: 131072)".to_string()),
722 properties: None,
723 required: None,
724 items: None,
725 enum_values: None,
726 },
727 );
728
729 Self {
730 definition: ToolDefinition {
731 name: "web_fetch".to_string(),
732 description: "Fetch a URL and return its content as text. Use for reading web pages, APIs, or documentation.".to_string(),
733 parameters: JsonSchema::object(
734 properties,
735 vec!["url".to_string()],
736 ),
737 requires_approval: false,
738 category: ToolCategory::Network,
739 },
740 }
741 }
742}
743
744#[async_trait::async_trait]
745impl ToolImpl for WebFetchTool {
746 fn definition(&self) -> &ToolDefinition {
747 &self.definition
748 }
749
750 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
751 let url = args.get("url").and_then(|v| v.as_str()).ok_or_else(|| {
752 ToolError::InvalidArguments(
753 "web_fetch".to_string(),
754 "missing 'url' argument".to_string(),
755 )
756 })?;
757
758 let max_bytes = args
759 .get("max_bytes")
760 .and_then(|v| v.as_u64())
761 .unwrap_or(131072) as usize;
762
763 let client = reqwest::Client::builder()
764 .timeout(std::time::Duration::from_secs(30))
765 .user_agent("RavenClaws/0.9.2")
766 .build()
767 .map_err(|e| {
768 ToolError::ExecutionFailed("web_fetch".to_string(), format!("HTTP client: {}", e))
769 })?;
770
771 let response = client.get(url).send().await.map_err(|e| {
772 ToolError::ExecutionFailed("web_fetch".to_string(), format!("Request failed: {}", e))
773 })?;
774
775 let status = response.status();
776 let content_type = response
777 .headers()
778 .get(reqwest::header::CONTENT_TYPE)
779 .and_then(|v| v.to_str().ok())
780 .unwrap_or("unknown")
781 .to_string();
782
783 let body = response.text().await.map_err(|e| {
784 ToolError::ExecutionFailed(
785 "web_fetch".to_string(),
786 format!("Failed to read response body: {}", e),
787 )
788 })?;
789
790 let truncated = if body.len() > max_bytes {
791 format!(
792 "{}...\n[truncated at {} bytes]",
793 &body[..max_bytes],
794 max_bytes
795 )
796 } else {
797 body
798 };
799
800 Ok(ToolResult {
801 tool_name: "web_fetch".to_string(),
802 success: status.is_success(),
803 output: format!(
804 "Status: {}\nContent-Type: {}\n\n{}",
805 status.as_u16(),
806 content_type,
807 truncated
808 ),
809 error: if status.is_success() {
810 None
811 } else {
812 Some(format!("HTTP {}", status.as_u16()))
813 },
814 exit_code: Some(status.as_u16() as i32),
815 duration_ms: None,
816 })
817 }
818}
819
820pub struct WebSearchTool {
822 definition: ToolDefinition,
823 search_endpoint: String,
824 search_engine: String,
825 max_results: usize,
826 fetch_content: bool,
827}
828
829impl WebSearchTool {
830 pub fn new() -> Self {
831 Self::default()
832 }
833
834 pub fn with_config(
835 endpoint: String,
836 engine: String,
837 max_results: usize,
838 fetch_content: bool,
839 ) -> Self {
840 let mut properties = HashMap::new();
841 properties.insert("query".to_string(), JsonSchema::string("The search query"));
842 properties.insert(
843 "max_results".to_string(),
844 JsonSchema {
845 schema_type: "integer".to_string(),
846 description: Some(
847 "Maximum number of search results to return (default: 5)".to_string(),
848 ),
849 properties: None,
850 required: None,
851 items: None,
852 enum_values: None,
853 },
854 );
855 properties.insert(
856 "fetch_content".to_string(),
857 JsonSchema {
858 schema_type: "boolean".to_string(),
859 description: Some(
860 "Whether to fetch and extract content from each result (default: true)"
861 .to_string(),
862 ),
863 properties: None,
864 required: None,
865 items: None,
866 enum_values: None,
867 },
868 );
869
870 Self {
871 definition: ToolDefinition {
872 name: "web_search".to_string(),
873 description: "Search the web for information. Returns a list of results with titles, URLs, and snippets. Can optionally fetch and extract readable content from each result.".to_string(),
874 parameters: JsonSchema::object(
875 properties,
876 vec!["query".to_string()],
877 ),
878 requires_approval: false,
879 category: ToolCategory::WebSearch,
880 },
881 search_endpoint: endpoint,
882 search_engine: engine,
883 max_results,
884 fetch_content,
885 }
886 }
887}
888
889impl Default for WebSearchTool {
890 fn default() -> Self {
891 Self::with_config(
892 "https://searx.be".to_string(),
893 "duckduckgo".to_string(),
894 5,
895 true,
896 )
897 }
898}
899
900impl WebSearchTool {
901 async fn search_searxng(
903 &self,
904 query: &str,
905 max_results: usize,
906 ) -> ToolResultValue<Vec<SearchResult>> {
907 let client = reqwest::Client::builder()
908 .timeout(std::time::Duration::from_secs(15))
909 .user_agent("RavenClaws/0.9.2")
910 .build()
911 .map_err(|e| {
912 ToolError::ExecutionFailed("web_search".to_string(), format!("HTTP client: {}", e))
913 })?;
914
915 let url = format!(
916 "{}/search?q={}&format=json&language=en&pageno=1",
917 self.search_endpoint.trim_end_matches('/'),
918 urlencoding(query)
919 );
920
921 let response = client.get(&url).send().await.map_err(|e| {
922 ToolError::ExecutionFailed(
923 "web_search".to_string(),
924 format!("Search request failed: {}", e),
925 )
926 })?;
927
928 if !response.status().is_success() {
929 return Err(ToolError::ExecutionFailed(
930 "web_search".to_string(),
931 format!("Search API returned HTTP {}", response.status().as_u16()),
932 ));
933 }
934
935 let body: serde_json::Value = response.json().await.map_err(|e| {
936 ToolError::ExecutionFailed(
937 "web_search".to_string(),
938 format!("Failed to parse search results: {}", e),
939 )
940 })?;
941
942 let results = body["results"]
943 .as_array()
944 .map(|arr| {
945 arr.iter()
946 .take(max_results)
947 .filter_map(|r| {
948 let title = r["title"].as_str().unwrap_or("").to_string();
949 let url = r["url"].as_str().unwrap_or("").to_string();
950 let snippet = r["content"].as_str().unwrap_or("").to_string();
951 if title.is_empty() && url.is_empty() {
952 None
953 } else {
954 Some(SearchResult {
955 title,
956 url,
957 snippet,
958 })
959 }
960 })
961 .collect::<Vec<_>>()
962 })
963 .unwrap_or_default();
964
965 Ok(results)
966 }
967
968 async fn search_duckduckgo(
970 &self,
971 query: &str,
972 max_results: usize,
973 ) -> ToolResultValue<Vec<SearchResult>> {
974 let client = reqwest::Client::builder()
975 .timeout(std::time::Duration::from_secs(15))
976 .user_agent("Mozilla/5.0 (compatible; RavenClaws/0.9.2)")
977 .build()
978 .map_err(|e| {
979 ToolError::ExecutionFailed("web_search".to_string(), format!("HTTP client: {}", e))
980 })?;
981
982 let url = format!("https://html.duckduckgo.com/html/?q={}", urlencoding(query));
983
984 let response = client.get(&url).send().await.map_err(|e| {
985 ToolError::ExecutionFailed(
986 "web_search".to_string(),
987 format!("Search request failed: {}", e),
988 )
989 })?;
990
991 let body = response.text().await.map_err(|e| {
992 ToolError::ExecutionFailed(
993 "web_search".to_string(),
994 format!("Failed to read search results: {}", e),
995 )
996 })?;
997
998 let mut results = Vec::new();
1000 let mut pos = 0;
1001 let result_class = "result__a";
1002
1003 while results.len() < max_results {
1004 let link_start = match body[pos..].find(result_class) {
1006 Some(i) => pos + i,
1007 None => break,
1008 };
1009
1010 let a_start = match body[link_start..].find("<a ") {
1012 Some(i) => link_start + i,
1013 None => break,
1014 };
1015 let a_end = match body[a_start..].find("</a>") {
1016 Some(i) => a_start + i,
1017 None => break,
1018 };
1019
1020 let a_tag = &body[a_start..a_end];
1021
1022 let url = extract_href(a_tag).unwrap_or_default();
1024 let title = a_tag.rsplit('>').next().unwrap_or("").trim().to_string();
1026
1027 let snippet_start = match body[a_end..].find("result__snippet") {
1029 Some(i) => a_end + i,
1030 None => {
1031 results.push(SearchResult {
1032 title,
1033 url,
1034 snippet: String::new(),
1035 });
1036 pos = a_end + 1;
1037 continue;
1038 }
1039 };
1040 let snippet_close = match body[snippet_start..].find("</a>") {
1041 Some(i) => snippet_start + i,
1042 None => {
1043 results.push(SearchResult {
1044 title,
1045 url,
1046 snippet: String::new(),
1047 });
1048 pos = a_end + 1;
1049 continue;
1050 }
1051 };
1052 let snippet_html = &body[snippet_start..snippet_close];
1053 let snippet = strip_html_tags(snippet_html).trim().to_string();
1054
1055 if !url.is_empty() || !title.is_empty() {
1056 results.push(SearchResult {
1057 title,
1058 url,
1059 snippet,
1060 });
1061 }
1062
1063 pos = a_end + 1;
1064 }
1065
1066 Ok(results)
1067 }
1068}
1069
/// One hit returned by a search backend.
#[allow(dead_code)]
struct SearchResult {
    /// Result title text.
    title: String,
    /// Result link.
    url: String,
    /// Short excerpt for the result; may be empty.
    snippet: String,
}
1077
1078#[async_trait::async_trait]
1079impl ToolImpl for WebSearchTool {
1080 fn definition(&self) -> &ToolDefinition {
1081 &self.definition
1082 }
1083
1084 async fn execute(&self, args: serde_json::Value) -> ToolResultValue<ToolResult> {
1085 let query = args.get("query").and_then(|v| v.as_str()).ok_or_else(|| {
1086 ToolError::InvalidArguments(
1087 "web_search".to_string(),
1088 "missing 'query' argument".to_string(),
1089 )
1090 })?;
1091
1092 let max_results = args
1093 .get("max_results")
1094 .and_then(|v| v.as_u64())
1095 .unwrap_or(self.max_results as u64) as usize;
1096
1097 let fetch_content = args
1098 .get("fetch_content")
1099 .and_then(|v| v.as_bool())
1100 .unwrap_or(self.fetch_content);
1101
1102 let results = match self.search_engine.as_str() {
1104 "searxng" => self.search_searxng(query, max_results).await?,
1105 _ => self.search_duckduckgo(query, max_results).await?,
1106 };
1107
1108 if results.is_empty() {
1109 return Ok(ToolResult {
1110 tool_name: "web_search".to_string(),
1111 success: true,
1112 output: "No search results found.".to_string(),
1113 error: None,
1114 exit_code: None,
1115 duration_ms: None,
1116 });
1117 }
1118
1119 let mut output = String::new();
1121 for (i, result) in results.iter().enumerate() {
1122 output.push_str(&format!(
1123 "[{}] **{}**\n URL: {}\n Snippet: {}\n",
1124 i + 1,
1125 result.title,
1126 result.url,
1127 result.snippet
1128 ));
1129
1130 if fetch_content && !result.url.is_empty() {
1131 match fetch_and_extract_content(&result.url, 8192).await {
1132 Ok(content) => {
1133 output.push_str(&format!(" Content: {}\n", content));
1134 }
1135 Err(e) => {
1136 output.push_str(&format!(" Content: (unavailable: {})\n", e));
1137 }
1138 }
1139 }
1140 }
1141
1142 Ok(ToolResult {
1143 tool_name: "web_search".to_string(),
1144 success: true,
1145 output,
1146 error: None,
1147 exit_code: None,
1148 duration_ms: None,
1149 })
1150 }
1151}
1152
1153async fn fetch_and_extract_content(url: &str, max_bytes: usize) -> ToolResultValue<String> {
1157 let client = reqwest::Client::builder()
1158 .timeout(std::time::Duration::from_secs(15))
1159 .user_agent("Mozilla/5.0 (compatible; RavenClaws/0.9.2)")
1160 .build()
1161 .map_err(|e| {
1162 ToolError::ExecutionFailed("web_fetch".to_string(), format!("HTTP client: {}", e))
1163 })?;
1164
1165 let response = client.get(url).send().await.map_err(|e| {
1166 ToolError::ExecutionFailed("web_fetch".to_string(), format!("Request failed: {}", e))
1167 })?;
1168
1169 if !response.status().is_success() {
1170 return Err(ToolError::ExecutionFailed(
1171 "web_fetch".to_string(),
1172 format!("HTTP {}", response.status().as_u16()),
1173 ));
1174 }
1175
1176 let body = response.text().await.map_err(|e| {
1177 ToolError::ExecutionFailed(
1178 "web_fetch".to_string(),
1179 format!("Failed to read response: {}", e),
1180 )
1181 })?;
1182
1183 Ok(html_to_text(&body, max_bytes))
1184}
1185
/// Reduces an HTML document to plain text.
///
/// * Tags are dropped; the contents of `<script>` and `<style>` are skipped.
/// * The contents of `<title>` are collected separately and, when non-empty,
///   emitted as a leading `Title: …` line.
/// * Opening block-level tags (`br`, `p`, `tr`, `div`, `li`, `h1`–`h6`,
///   `pre`) insert a newline unless the output already ends in whitespace.
/// * Runs of ASCII whitespace collapse to one space.
/// * The entities `&nbsp; &lt; &gt; &amp; &quot; &#39; &apos;` are decoded;
///   any other `&` is kept literally.
///
/// `max_chars` bounds the body text in **bytes** (the title is not counted);
/// a character that would exceed the limit stops processing. Input is handled
/// as UTF-8 throughout, so non-ASCII text survives intact.
fn html_to_text(html: &str, max_chars: usize) -> String {
    const BLOCK_TAGS: [&str; 12] = [
        "br", "p", "tr", "div", "li", "h1", "h2", "h3", "h4", "h5", "h6", "pre",
    ];
    const ENTITIES: [(&str, &str); 7] = [
        ("&nbsp;", " "),
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&amp;", "&"),
        ("&quot;", "\""),
        ("&#39;", "'"),
        ("&apos;", "'"),
    ];

    /// Splits `rest` at the (case-insensitive) closing tag for `name`,
    /// returning the element content and the text after the closing tag.
    /// With no closing tag, everything is content and nothing remains.
    fn split_at_closing_tag<'a>(rest: &'a str, name: &str) -> (&'a str, &'a str) {
        let closer = format!("</{}", name);
        let found = rest
            .as_bytes()
            .windows(closer.len())
            .position(|window| window.eq_ignore_ascii_case(closer.as_bytes()));
        match found {
            // `start` points at an ASCII `<`, so it is a valid str boundary.
            Some(start) => {
                let after = match rest[start..].find('>') {
                    Some(gt) => &rest[start + gt + 1..],
                    None => "",
                };
                (&rest[..start], after)
            }
            None => (rest, ""),
        }
    }

    let mut text = String::new();
    let mut title = String::new();
    let mut last_was_space = true;
    let mut rest = html;

    while let Some(ch) = rest.chars().next() {
        if ch == '<' {
            // An unterminated tag swallows the remainder of the input.
            let close = match rest.find('>') {
                Some(i) => i,
                None => break,
            };
            let inner = &rest[1..close];
            rest = &rest[close + 1..];

            if inner.starts_with('/') {
                continue;
            }
            // Match on the element name only, so attribute values such as
            // `class="style-x"` or `title="…"` are never mistaken for tags.
            let name: String = inner
                .chars()
                .take_while(|c| c.is_ascii_alphanumeric())
                .map(|c| c.to_ascii_lowercase())
                .collect();

            match name.as_str() {
                "script" | "style" => {
                    let (_, after) = split_at_closing_tag(rest, &name);
                    rest = after;
                }
                "title" => {
                    let (content, after) = split_at_closing_tag(rest, "title");
                    title.push_str(content);
                    rest = after;
                }
                tag if BLOCK_TAGS.contains(&tag) && !last_was_space => {
                    text.push('\n');
                    last_was_space = true;
                }
                _ => {}
            }
            continue;
        }

        if ch == '&' {
            let known = ENTITIES
                .iter()
                .find(|(entity, _)| rest.starts_with(*entity));
            if let Some(&(entity, decoded)) = known {
                if text.len() + decoded.len() > max_chars {
                    break;
                }
                text.push_str(decoded);
                last_was_space = decoded == " ";
                rest = &rest[entity.len()..];
                continue;
            }
            // Unknown entity or bare ampersand: fall through and keep the `&`.
        }

        if ch.is_ascii_whitespace() {
            if !last_was_space {
                text.push(' ');
                last_was_space = true;
            }
            rest = &rest[1..];
            continue;
        }

        if text.len() + ch.len_utf8() > max_chars {
            break;
        }
        text.push(ch);
        last_was_space = false;
        rest = &rest[ch.len_utf8()..];
    }

    let title = title.trim();
    let text = text.trim();

    if title.is_empty() {
        text.to_string()
    } else {
        format!("Title: {}\n\n{}", title, text)
    }
}
1342
/// Removes everything between `<` and `>` from `input` and decodes a small
/// set of HTML entities in what remains.
///
/// `&amp;` is decoded last so that an escaped entity such as `&amp;lt;`
/// becomes the literal text `&lt;` instead of being decoded twice into `<`.
fn strip_html_tags(input: &str) -> String {
    let mut output = String::with_capacity(input.len());
    let mut in_tag = false;
    for c in input.chars() {
        match c {
            '<' => in_tag = true,
            '>' => in_tag = false,
            _ if !in_tag => output.push(c),
            _ => {}
        }
    }

    output
        .replace("&lt;", "<")
        .replace("&gt;", ">")
        .replace("&quot;", "\"")
        .replace("&#39;", "'")
        .replace("&#x27;", "'")
        .replace("&apos;", "'")
        .replace("&nbsp;", " ")
        .replace("&amp;", "&")
}
1367
/// Pulls the value of the first double-quoted `href` attribute out of `a_tag`.
///
/// Protocol-relative links (`//host/…`) are given an `https:` scheme;
/// site-relative links (`/path`) yield `None`, as does a missing or
/// unterminated attribute.
fn extract_href(a_tag: &str) -> Option<String> {
    let (_, after_attr) = a_tag.split_once("href=\"")?;
    let (href, _) = after_attr.split_once('"')?;

    if href.starts_with("//") {
        Some(format!("https:{}", href))
    } else if href.starts_with('/') {
        None
    } else {
        Some(href.to_string())
    }
}
1385
/// Percent-encodes `input` byte by byte.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged; every other
/// byte (including the space and each byte of a multi-byte UTF-8 character) is written as
/// `%XX` with upper-case hexadecimal digits.
fn urlencoding(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Worst case every byte expands to three characters.
    let mut encoded = String::with_capacity(input.len() * 3);
    for &octet in input.as_bytes() {
        let unreserved =
            octet.is_ascii_alphanumeric() || matches!(octet, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(octet));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(octet >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(octet & 0x0F)]));
        }
    }
    encoded
}
1402
#[allow(dead_code)]
/// Regex-based detector that scans free-form text for phrases that look like a request to
/// run one of the known tools (for example "Let me read the file ..." or
/// "Use the shell_exec tool ...") and turns each match into a `ToolCall`.
pub struct ToolCallDetector {
    /// Detection patterns; `detect` tries them in this order.
    patterns: Vec<DetectorPattern>,
}
1430
#[allow(dead_code)]
/// One detection rule used by `ToolCallDetector`: a regular expression plus instructions for
/// turning its capture groups into a tool name and arguments.
struct DetectorPattern {
    /// Expression matched against the text.
    regex: regex_lite::Regex,
    /// Fixed tool name for this rule; `None` means the name is taken from capture group 1.
    tool_name: Option<String>,
    /// JSON key under which the captured argument is stored; `None` means the captured text
    /// is first tried as JSON, with a `command`/`input` fallback.
    arg_key: Option<String>,
    /// Index of the capture group that holds the argument text.
    arg_group: usize,
}
1443
1444#[allow(dead_code)]
1445impl ToolCallDetector {
1446 pub fn new() -> Self {
1448 let patterns = vec![
1450 DetectorPattern {
1453 regex: regex_lite::Regex::new(
1454 r"(?i)(?:^|[.!?]\s+)(?:use|run|call|invoke)\s+(?:the\s+)?(\w+)\s+(?:tool|command|function)(?:\s+with\s+(?:args|arguments|parameters))?\s*:?\s*(.+?)(?:\.|$|\n)"
1455 ).expect("valid regex"),
1456 tool_name: None, arg_key: None,
1458 arg_group: 2,
1459 },
1460 DetectorPattern {
1462 regex: regex_lite::Regex::new(
1463 r"(?i)(?:I'?ll|I\s+will|let\s+me)\s+use\s+(?:the\s+)?(\w+)\s+(?:tool|command|function)\s+to\s+(?:run|execute|do)\s*:?\s*(.+?)(?:\.|$|\n)"
1464 ).expect("valid regex"),
1465 tool_name: None,
1466 arg_key: Some("command".to_string()),
1467 arg_group: 2,
1468 },
1469 DetectorPattern {
1471 regex: regex_lite::Regex::new(
1472 r"(?i)(?:let\s+me|I'?ll|I\s+will)\s+(?:read|open|check)\s+(?:the\s+)?file\s+(.+?)(?:\.|$|\n)"
1473 ).expect("valid regex"),
1474 tool_name: Some("read_file".to_string()),
1475 arg_key: Some("path".to_string()),
1476 arg_group: 1,
1477 },
1478 DetectorPattern {
1480 regex: regex_lite::Regex::new(
1481 r"(?i)(?:let\s+me|I'?ll|I\s+will)\s+(?:search|look\s+up|find|google)\s+(?:for\s+)?(.+?)(?:\.|$|\n)"
1482 ).expect("valid regex"),
1483 tool_name: Some("web_search".to_string()),
1484 arg_key: Some("query".to_string()),
1485 arg_group: 1,
1486 },
1487 DetectorPattern {
1489 regex: regex_lite::Regex::new(
1490 r"(?i)(?:let\s+me|I'?ll|I\s+will)\s+(?:fetch|get|download)\s+(https?://\S+)(?:\.|$|\n|\s)"
1491 ).expect("valid regex"),
1492 tool_name: Some("web_fetch".to_string()),
1493 arg_key: Some("url".to_string()),
1494 arg_group: 1,
1495 },
1496 ];
1497
1498 Self { patterns }
1499 }
1500
1501 pub fn detect(&self, text: &str) -> Vec<ToolCall> {
1505 let mut seen = std::collections::HashSet::new();
1506 let mut calls = Vec::new();
1507
1508 for pattern in &self.patterns {
1509 for cap in pattern.regex.captures_iter(text) {
1510 let tool_name = match &pattern.tool_name {
1511 Some(name) => name.clone(),
1512 None => cap
1513 .get(1)
1514 .map(|m| m.as_str().to_string())
1515 .unwrap_or_default(),
1516 };
1517
1518 if !Self::is_known_tool(&tool_name) {
1520 continue;
1521 }
1522
1523 let arg_value = cap
1524 .get(pattern.arg_group)
1525 .map(|m| m.as_str().trim().to_string())
1526 .unwrap_or_default();
1527
1528 if arg_value.is_empty() {
1529 continue;
1530 }
1531
1532 let arguments = match &pattern.arg_key {
1534 Some(key) => {
1535 serde_json::json!({ key: arg_value })
1536 }
1537 None => {
1538 serde_json::from_str(&arg_value).unwrap_or_else(
1540 |_| serde_json::json!({ "command": arg_value, "input": arg_value }),
1541 )
1542 }
1543 };
1544
1545 let key = format!("{}:{:?}", tool_name, arguments);
1547 if seen.contains(&key) {
1548 continue;
1549 }
1550 seen.insert(key);
1551
1552 calls.push(ToolCall {
1553 name: tool_name,
1554 arguments,
1555 id: None,
1556 });
1557 }
1558 }
1559
1560 calls
1561 }
1562
1563 fn is_known_tool(name: &str) -> bool {
1565 matches!(
1566 name,
1567 "shell_exec" | "read_file" | "write_file" | "web_fetch" | "web_search"
1568 )
1569 }
1570}
1571
/// `Default` simply delegates to `ToolCallDetector::new`.
impl Default for ToolCallDetector {
    /// Returns the same detector as `ToolCallDetector::new()`.
    fn default() -> Self {
        Self::new()
    }
}
1577
1578async fn run_shell_command(
1582 command: &str,
1583 timeout_secs: u64,
1584 workdir: Option<String>,
1585) -> ToolResultValue<ToolResult> {
1586 use tokio::process::Command;
1587
1588 let shell = if cfg!(target_os = "windows") {
1589 "cmd.exe"
1590 } else {
1591 "sh"
1592 };
1593 let flag = if cfg!(target_os = "windows") {
1594 "/C"
1595 } else {
1596 "-c"
1597 };
1598
1599 let mut cmd = Command::new(shell);
1600 cmd.arg(flag).arg(command);
1601
1602 if let Some(dir) = &workdir {
1603 cmd.current_dir(dir);
1604 }
1605
1606 let output = tokio::time::timeout(std::time::Duration::from_secs(timeout_secs), cmd.output())
1607 .await
1608 .map_err(|_| {
1609 ToolError::ExecutionFailed(
1610 "shell_exec".to_string(),
1611 format!("Command timed out after {} seconds", timeout_secs),
1612 )
1613 })?
1614 .map_err(|e| {
1615 ToolError::ExecutionFailed(
1616 "shell_exec".to_string(),
1617 format!("Failed to execute: {}", e),
1618 )
1619 })?;
1620
1621 let stdout = String::from_utf8_lossy(&output.stdout).to_string();
1622 let stderr = String::from_utf8_lossy(&output.stderr).to_string();
1623 let exit_code = output.status.code().unwrap_or(-1);
1624
1625 let mut output_text = String::new();
1626 if !stdout.is_empty() {
1627 output_text.push_str(&stdout);
1628 }
1629 if !stderr.is_empty() {
1630 if !output_text.is_empty() {
1631 output_text.push_str("\n--- stderr ---\n");
1632 }
1633 output_text.push_str(&stderr);
1634 }
1635
1636 const MAX_OUTPUT: usize = 65536;
1638 if output_text.len() > MAX_OUTPUT {
1639 output_text = format!(
1640 "{}...\n[truncated at {} bytes]",
1641 &output_text[..MAX_OUTPUT],
1642 MAX_OUTPUT
1643 );
1644 }
1645
1646 Ok(ToolResult {
1647 tool_name: "shell_exec".to_string(),
1648 success: exit_code == 0,
1649 output: output_text,
1650 error: if exit_code != 0 {
1651 Some(format!("Exit code: {}", exit_code))
1652 } else {
1653 None
1654 },
1655 exit_code: Some(exit_code),
1656 duration_ms: None,
1657 })
1658}
1659
#[cfg(test)]
/// Unit tests for the tool registry, the built-in tools, the HTML/URL helpers and the
/// natural-language tool-call detector.
mod tests {
    use super::*;

    // ---------------------------------------------------------------------------------------
    // Tool registry
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_tool_registry_empty() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_tool_registry_register() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(ShellTool::new()));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        assert!(registry.has("shell_exec"));
    }

    #[test]
    fn test_tool_registry_default_tools() {
        let registry = ToolRegistry::with_default_tools();
        assert_eq!(registry.len(), 5);
        assert!(registry.has("shell_exec"));
        assert!(registry.has("read_file"));
        assert!(registry.has("write_file"));
        assert!(registry.has("web_fetch"));
        assert!(registry.has("web_search"));
    }

    #[test]
    fn test_tool_definitions() {
        let registry = ToolRegistry::with_default_tools();
        let defs = registry.definitions();
        assert_eq!(defs.len(), 5);

        let shell_def = defs.iter().find(|d| d.name == "shell_exec").unwrap();
        assert!(shell_def.description.contains("shell command"));
        assert!(shell_def.requires_approval);
        assert_eq!(shell_def.category, ToolCategory::Shell);
    }

    #[test]
    fn test_tool_not_found() {
        let registry = ToolRegistry::new();
        let result = registry.get("nonexistent");
        assert!(result.is_none());
    }

    // ---------------------------------------------------------------------------------------
    // Tool definitions
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_shell_tool_definition() {
        let tool = ShellTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "shell_exec");
        assert!(def.requires_approval);
    }

    #[test]
    fn test_read_file_tool_definition() {
        let tool = ReadFileTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "read_file");
        assert!(!def.requires_approval);
    }

    #[test]
    fn test_write_file_tool_definition() {
        let tool = WriteFileTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "write_file");
        assert!(def.requires_approval);
    }

    #[test]
    fn test_web_fetch_tool_definition() {
        let tool = WebFetchTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "web_fetch");
        assert!(!def.requires_approval);
    }

    // ---------------------------------------------------------------------------------------
    // Data types: serialization, schemas, errors, categories
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_tool_call_serialization() {
        let call = ToolCall {
            name: "shell_exec".to_string(),
            arguments: serde_json::json!({"command": "echo hello"}),
            id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&call).unwrap();
        assert!(json.contains("shell_exec"));
        assert!(json.contains("echo hello"));
        assert!(json.contains("call_123"));
    }

    #[test]
    fn test_tool_result_serialization() {
        let result = ToolResult {
            tool_name: "shell_exec".to_string(),
            success: true,
            output: "hello\n".to_string(),
            error: None,
            exit_code: Some(0),
            duration_ms: Some(42),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("shell_exec"));
        assert!(json.contains("hello"));
        assert!(json.contains("42"));
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult {
            tool_name: "shell_exec".to_string(),
            success: false,
            output: String::new(),
            error: Some("Exit code: 1".to_string()),
            exit_code: Some(1),
            duration_ms: Some(10),
        };

        assert!(!result.success);
        assert_eq!(result.exit_code, Some(1));
    }

    #[test]
    fn test_json_schema_string() {
        let schema = JsonSchema::string("A test string");
        assert_eq!(schema.schema_type, "string");
        assert_eq!(schema.description.unwrap(), "A test string");
    }

    #[test]
    fn test_json_schema_object() {
        let mut props = HashMap::new();
        props.insert("name".to_string(), JsonSchema::string("The name"));
        let schema = JsonSchema::object(props, vec!["name".to_string()]);
        assert_eq!(schema.schema_type, "object");
        assert!(schema.properties.unwrap().contains_key("name"));
    }

    #[test]
    fn test_tool_error_not_found() {
        let err = ToolError::NotFound("test_tool".to_string());
        assert_eq!(format!("{}", err), "Tool 'test_tool' not found");
    }

    #[test]
    fn test_tool_error_execution_failed() {
        let err = ToolError::ExecutionFailed("test".to_string(), "oops".to_string());
        assert_eq!(format!("{}", err), "Tool 'test' execution failed: oops");
    }

    #[test]
    fn test_tool_error_invalid_arguments() {
        let err = ToolError::InvalidArguments("test".to_string(), "bad arg".to_string());
        assert_eq!(
            format!("{}", err),
            "Invalid arguments for tool 'test': bad arg"
        );
    }

    #[test]
    fn test_tool_error_policy_denied() {
        let err = ToolError::PolicyDenied("not allowed".to_string());
        assert_eq!(format!("{}", err), "Policy denied: not allowed");
    }

    #[test]
    fn test_tool_error_sandbox_violation() {
        let err = ToolError::SandboxViolation("escape attempt".to_string());
        assert_eq!(format!("{}", err), "Sandbox violation: escape attempt");
    }

    #[test]
    fn test_tool_category_default() {
        let cat = ToolCategory::default();
        assert_eq!(cat, ToolCategory::General);
    }

    #[test]
    fn test_tool_category_serialization() {
        let cat = ToolCategory::Shell;
        let json = serde_json::to_string(&cat).unwrap();
        assert_eq!(json, "\"Shell\"");
    }

    #[test]
    fn test_tool_definition_requires_approval_default() {
        let def = ToolDefinition {
            name: "test".to_string(),
            description: "test".to_string(),
            parameters: JsonSchema::string("test"),
            requires_approval: false,
            category: ToolCategory::General,
        };
        assert!(!def.requires_approval);
    }

    // ---------------------------------------------------------------------------------------
    // Tool execution
    // ---------------------------------------------------------------------------------------

    #[tokio::test]
    async fn test_shell_exec_success() {
        let tool = ShellTool::new();
        let args = serde_json::json!({"command": "echo hello"});
        let result = tool.execute(args).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("hello"));
        assert_eq!(result.exit_code, Some(0));
    }

    #[tokio::test]
    async fn test_shell_exec_failure() {
        let tool = ShellTool::new();
        let args = serde_json::json!({"command": "exit 42"});
        let result = tool.execute(args).await.unwrap();
        assert!(!result.success);
        assert_eq!(result.exit_code, Some(42));
    }

    #[tokio::test]
    async fn test_shell_exec_missing_command() {
        let tool = ShellTool::new();
        let args = serde_json::json!({});
        let err = tool.execute(args).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_, _)));
    }

    #[tokio::test]
    async fn test_read_file_not_found() {
        let tool = ReadFileTool::new();
        let args = serde_json::json!({"path": "/tmp/nonexistent_file_ravenclaws_test"});
        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolError::ExecutionFailed(_, _)
        ));
    }

    #[tokio::test]
    async fn test_read_file_missing_path() {
        let tool = ReadFileTool::new();
        let args = serde_json::json!({});
        let err = tool.execute(args).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_, _)));
    }

    #[tokio::test]
    async fn test_write_file_missing_args() {
        let tool = WriteFileTool::new();
        let args = serde_json::json!({});
        let err = tool.execute(args).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_, _)));
    }

    #[tokio::test]
    async fn test_web_fetch_missing_url() {
        let tool = WebFetchTool::new();
        let args = serde_json::json!({});
        let err = tool.execute(args).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_, _)));
    }

    #[tokio::test]
    async fn test_write_and_read_file() {
        // Each file test gets its own directory: tests run in parallel, and sharing one
        // directory lets one test's cleanup remove it while the other is still using it.
        let dir =
            std::env::temp_dir().join(format!("ravenclaws_test_write_{}", std::process::id()));
        let path = dir.join("test_write.txt");
        let path_str = path.to_string_lossy().to_string();

        let write_tool = WriteFileTool::new();
        let args = serde_json::json!({"path": path_str, "content": "Hello, RavenClaws!"});
        let result = write_tool.execute(args).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("18 bytes"));

        let read_tool = ReadFileTool::new();
        let args = serde_json::json!({"path": path_str});
        let result = read_tool.execute(args).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output.trim(), "Hello, RavenClaws!");

        // Best-effort cleanup.
        let _ = tokio::fs::remove_file(&path).await;
        let _ = tokio::fs::remove_dir(dir).await;
    }

    #[tokio::test]
    async fn test_write_file_append() {
        // Separate directory from `test_write_and_read_file` to avoid a cleanup race.
        let dir =
            std::env::temp_dir().join(format!("ravenclaws_test_append_{}", std::process::id()));
        let path = dir.join("test_append.txt");
        let path_str = path.to_string_lossy().to_string();

        let write_tool = WriteFileTool::new();
        let args = serde_json::json!({"path": path_str, "content": "line1\n"});
        write_tool.execute(args).await.unwrap();

        let args = serde_json::json!({"path": path_str, "content": "line2\n", "append": true});
        let result = write_tool.execute(args).await.unwrap();
        assert!(result.success);

        let read_tool = ReadFileTool::new();
        let args = serde_json::json!({"path": path_str});
        let result = read_tool.execute(args).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("line1"));
        assert!(result.output.contains("line2"));

        // Best-effort cleanup.
        let _ = tokio::fs::remove_file(&path).await;
        let _ = tokio::fs::remove_dir(dir).await;
    }

    #[tokio::test]
    async fn test_tool_registry_execute() {
        let registry = ToolRegistry::with_default_tools();
        let call = ToolCall {
            name: "shell_exec".to_string(),
            arguments: serde_json::json!({"command": "echo hello"}),
            id: None,
        };
        let result = registry.execute(call).await.unwrap();
        assert!(result.success);
        assert!(result.output.contains("hello"));
    }

    #[tokio::test]
    async fn test_tool_registry_execute_not_found() {
        let registry = ToolRegistry::new();
        let call = ToolCall {
            name: "nonexistent".to_string(),
            arguments: serde_json::json!({}),
            id: None,
        };
        let err = registry.execute(call).await.unwrap_err();
        assert!(matches!(err, ToolError::NotFound(_)));
    }

    // ---------------------------------------------------------------------------------------
    // Web search tool
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_web_search_tool_definition() {
        let tool = WebSearchTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "web_search");
        assert!(!def.requires_approval);
        assert_eq!(def.category, ToolCategory::WebSearch);
        assert!(def.description.contains("Search the web"));
    }

    #[test]
    fn test_web_search_tool_with_config() {
        let tool = WebSearchTool::with_config(
            "http://localhost:8888".to_string(),
            "searxng".to_string(),
            10,
            false,
        );
        let def = tool.definition();
        assert_eq!(def.name, "web_search");
        assert_eq!(tool.search_endpoint, "http://localhost:8888");
        assert_eq!(tool.search_engine, "searxng");
        assert_eq!(tool.max_results, 10);
        assert!(!tool.fetch_content);
    }

    #[tokio::test]
    async fn test_web_search_missing_query() {
        let tool = WebSearchTool::new();
        let args = serde_json::json!({});
        let err = tool.execute(args).await.unwrap_err();
        assert!(matches!(err, ToolError::InvalidArguments(_, _)));
    }

    #[test]
    fn test_web_search_tool_registry() {
        let registry = ToolRegistry::with_default_tools();
        assert!(registry.has("web_search"));
        let defs = registry.definitions();
        let search_def = defs.iter().find(|d| d.name == "web_search").unwrap();
        assert_eq!(search_def.category, ToolCategory::WebSearch);
    }

    #[test]
    fn test_web_search_tool_with_config_registry() {
        let registry =
            ToolRegistry::with_web_search_config("http://localhost:8888", "searxng", 10, false);
        assert!(registry.has("web_search"));
        assert!(registry.has("shell_exec"));
        assert!(registry.has("read_file"));
        assert!(registry.has("write_file"));
        assert!(registry.has("web_fetch"));
        assert_eq!(registry.len(), 5);
    }

    // ---------------------------------------------------------------------------------------
    // HTML / URL helpers
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_html_to_text_strips_tags() {
        let html = "<html><body><p>Hello, world!</p></body></html>";
        let text = html_to_text(html, 1000);
        assert!(text.contains("Hello, world!"));
        assert!(!text.contains("<p>"));
        assert!(!text.contains("</p>"));
    }

    #[test]
    fn test_html_to_text_extracts_title() {
        let html = "<html><head><title>Test Page</title></head><body><p>Content</p></body></html>";
        let text = html_to_text(html, 1000);
        assert!(text.contains("Test Page"));
        assert!(text.contains("Content"));
    }

    #[test]
    fn test_html_to_text_strips_script_and_style() {
        let html = "<html><head><script>alert('xss');</script><style>.cls{}</style></head><body><p>Visible</p></body></html>";
        let text = html_to_text(html, 1000);
        assert!(text.contains("Visible"));
        assert!(!text.contains("alert"));
        assert!(!text.contains(".cls"));
    }

    #[test]
    fn test_html_to_text_handles_entities() {
        // The input must contain *encoded* entities; the assertion checks the decoded text.
        let html = "<p>foo &amp; bar &lt; baz &gt; qux</p>";
        let text = html_to_text(html, 1000);
        assert!(text.contains("foo & bar < baz > qux") || text.contains("foo & bar"));
    }

    #[test]
    fn test_html_to_text_respects_max_chars() {
        let html = "<p>Hello World This Is A Test</p>";
        let text = html_to_text(html, 5);
        assert!(text.len() <= 5);
    }

    #[test]
    fn test_html_to_text_empty_input() {
        assert_eq!(html_to_text("", 1000), "");
    }

    #[test]
    fn test_html_to_text_no_html() {
        let text = html_to_text("Just plain text", 1000);
        assert_eq!(text, "Just plain text");
    }

    #[test]
    fn test_strip_html_tags_basic() {
        let result = strip_html_tags("<b>bold</b> and <i>italic</i>");
        assert_eq!(result, "bold and italic");
    }

    #[test]
    fn test_strip_html_tags_with_entities() {
        // Encoded input, decoded expectation. A literal `<` here would be read as the start
        // of a tag and swallow the rest of the string.
        let result = strip_html_tags("foo &amp; bar &lt; baz");
        assert_eq!(result, "foo & bar < baz");
    }

    #[test]
    fn test_extract_href_basic() {
        let result = extract_href(r#"<a href="https://example.com">link</a>"#);
        assert_eq!(result, Some("https://example.com".to_string()));
    }

    #[test]
    fn test_extract_href_protocol_relative() {
        let result = extract_href(r#"<a href="//example.com/path">link</a>"#);
        assert_eq!(result, Some("https://example.com/path".to_string()));
    }

    #[test]
    fn test_extract_href_relative() {
        let result = extract_href(r#"<a href="/relative/path">link</a>"#);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_href_no_match() {
        let result = extract_href("<span>no link here</span>");
        assert_eq!(result, None);
    }

    #[test]
    fn test_urlencoding_basic() {
        assert_eq!(urlencoding("hello world"), "hello%20world");
        assert_eq!(urlencoding("foo/bar"), "foo%2Fbar");
        assert_eq!(urlencoding("simple"), "simple");
    }

    #[test]
    fn test_fetch_and_extract_content_invalid_url() {
        let result = tokio_test::block_on(fetch_and_extract_content("http://0.0.0.0:1", 1000));
        assert!(result.is_err());
    }

    // ---------------------------------------------------------------------------------------
    // Tool-call detector
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_tool_call_detector_shell_exec() {
        let detector = ToolCallDetector::new();
        let text = "I'll use the shell_exec tool to run: ls -la";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 1, "Should detect one tool call");
        assert_eq!(calls[0].name, "shell_exec");
        assert_eq!(calls[0].arguments["command"], "ls -la");
    }

    #[test]
    fn test_tool_call_detector_read_file() {
        let detector = ToolCallDetector::new();
        let text = "Let me read the file /etc/hostname";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 1, "Should detect one tool call");
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[0].arguments["path"], "/etc/hostname");
    }

    #[test]
    fn test_tool_call_detector_web_search() {
        let detector = ToolCallDetector::new();
        let text = "I'll search for Rust programming language";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 1, "Should detect one tool call");
        assert_eq!(calls[0].name, "web_search");
        assert!(calls[0].arguments["query"]
            .as_str()
            .unwrap()
            .contains("Rust"));
    }

    #[test]
    fn test_tool_call_detector_web_fetch() {
        let detector = ToolCallDetector::new();
        let text = "I'll fetch https://example.com/api";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 1, "Should detect one tool call");
        assert_eq!(calls[0].name, "web_fetch");
        assert_eq!(calls[0].arguments["url"], "https://example.com/api");
    }

    #[test]
    fn test_tool_call_detector_use_tool_syntax() {
        let detector = ToolCallDetector::new();
        let text = "Use the shell_exec tool with args: echo hello world";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 1, "Should detect one tool call");
        assert_eq!(calls[0].name, "shell_exec");
    }

    #[test]
    fn test_tool_call_detector_no_false_positives() {
        let detector = ToolCallDetector::new();
        let text = "I think we should consider using a different approach here.";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 0, "Should not detect any tool calls");
    }

    #[test]
    fn test_tool_call_detector_empty_text() {
        let detector = ToolCallDetector::new();
        let calls = detector.detect("");
        assert_eq!(calls.len(), 0);
    }

    #[test]
    fn test_tool_call_detector_multiple_calls() {
        let detector = ToolCallDetector::new();
        let text = "Let me read the file /etc/hosts. Then I'll search for DNS configuration.";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 2, "Should detect two tool calls");
        assert_eq!(calls[0].name, "read_file");
        assert_eq!(calls[1].name, "web_search");
    }

    #[test]
    fn test_tool_call_detector_unknown_tool_skipped() {
        let detector = ToolCallDetector::new();
        let text = "Use the nonexistent_tool tool with args: something";
        let calls = detector.detect(text);
        assert_eq!(calls.len(), 0, "Should skip unknown tools");
    }

    #[test]
    fn test_tool_call_detector_is_known_tool() {
        assert!(ToolCallDetector::is_known_tool("shell_exec"));
        assert!(ToolCallDetector::is_known_tool("read_file"));
        assert!(ToolCallDetector::is_known_tool("write_file"));
        assert!(ToolCallDetector::is_known_tool("web_fetch"));
        assert!(ToolCallDetector::is_known_tool("web_search"));
        assert!(!ToolCallDetector::is_known_tool("unknown_tool"));
    }

    #[test]
    fn test_tool_call_detector_default() {
        let detector = ToolCallDetector::default();
        let calls = detector.detect("I'll use the shell_exec tool to run: echo test");
        assert_eq!(calls.len(), 1);
    }
}