Skip to main content

mofa_extra/rhai/
tools.rs

1//! Rhai 动态工具系统
2//!
3//! 允许通过 Rhai 脚本动态定义和执行工具,实现:
4//! - 脚本化的工具定义
5//! - 运行时工具注册
6//! - 工具参数验证
7//! - 工具执行沙箱
8
9use super::engine::{RhaiScriptEngine, ScriptContext, ScriptEngineConfig};
10use anyhow::{Result, anyhow};
11#[allow(unused_imports)]
12use rhai::{Dynamic, Engine, Map, Scope};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17use tracing::info;
18
19// ============================================================================
20// 工具参数定义
21// ============================================================================
22
23/// 参数类型
24#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(rename_all = "lowercase")]
26#[derive(Default)]
27pub enum ParameterType {
28    #[default]
29    String,
30    Integer,
31    Float,
32    Boolean,
33    Array,
34    Object,
35    Any,
36}
37
38/// 工具参数定义
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ToolParameter {
41    /// 参数名称
42    pub name: String,
43    /// 参数类型
44    #[serde(default)]
45    pub param_type: ParameterType,
46    /// 参数描述
47    #[serde(default)]
48    pub description: String,
49    /// 是否必需
50    #[serde(default)]
51    pub required: bool,
52    /// 默认值
53    pub default: Option<serde_json::Value>,
54    /// 枚举值(如果有)
55    pub enum_values: Option<Vec<serde_json::Value>>,
56    /// 最小值(数字类型)
57    pub minimum: Option<f64>,
58    /// 最大值(数字类型)
59    pub maximum: Option<f64>,
60    /// 最小长度(字符串/数组)
61    pub min_length: Option<usize>,
62    /// 最大长度(字符串/数组)
63    pub max_length: Option<usize>,
64    /// 正则表达式模式(字符串)
65    pub pattern: Option<String>,
66}
67
68impl ToolParameter {
69    pub fn new(name: &str, param_type: ParameterType) -> Self {
70        Self {
71            name: name.to_string(),
72            param_type,
73            description: String::new(),
74            required: false,
75            default: None,
76            enum_values: None,
77            minimum: None,
78            maximum: None,
79            min_length: None,
80            max_length: None,
81            pattern: None,
82        }
83    }
84
85    pub fn required(mut self) -> Self {
86        self.required = true;
87        self
88    }
89
90    pub fn with_description(mut self, desc: &str) -> Self {
91        self.description = desc.to_string();
92        self
93    }
94
95    pub fn with_default<T: Serialize>(mut self, value: T) -> Self {
96        self.default = serde_json::to_value(value).ok();
97        self
98    }
99
100    pub fn with_enum(mut self, values: Vec<serde_json::Value>) -> Self {
101        self.enum_values = Some(values);
102        self
103    }
104
105    pub fn with_range(mut self, min: f64, max: f64) -> Self {
106        self.minimum = Some(min);
107        self.maximum = Some(max);
108        self
109    }
110
111    /// 验证参数值
112    pub fn validate(&self, value: &serde_json::Value) -> Result<()> {
113        // 检查类型
114        match (&self.param_type, value) {
115            (ParameterType::String, serde_json::Value::String(_)) => {}
116            (ParameterType::Integer, serde_json::Value::Number(n)) if n.is_i64() => {}
117            (ParameterType::Float, serde_json::Value::Number(_)) => {}
118            (ParameterType::Boolean, serde_json::Value::Bool(_)) => {}
119            (ParameterType::Array, serde_json::Value::Array(_)) => {}
120            (ParameterType::Object, serde_json::Value::Object(_)) => {}
121            (ParameterType::Any, _) => {}
122            (ParameterType::String, serde_json::Value::Null) if !self.required => {}
123            _ => {
124                return Err(anyhow!(
125                    "Parameter '{}' has invalid type, expected {:?}",
126                    self.name,
127                    self.param_type
128                ));
129            }
130        }
131
132        // 检查枚举值
133        if let Some(ref enum_values) = self.enum_values
134            && !enum_values.contains(value)
135        {
136            return Err(anyhow!(
137                "Parameter '{}' value must be one of {:?}",
138                self.name,
139                enum_values
140            ));
141        }
142
143        // 检查数值范围
144        if let serde_json::Value::Number(n) = value
145            && let Some(f) = n.as_f64()
146        {
147            if let Some(min) = self.minimum
148                && f < min
149            {
150                return Err(anyhow!("Parameter '{}' must be >= {}", self.name, min));
151            }
152            if let Some(max) = self.maximum
153                && f > max
154            {
155                return Err(anyhow!("Parameter '{}' must be <= {}", self.name, max));
156            }
157        }
158
159        // 检查字符串长度
160        if let serde_json::Value::String(s) = value {
161            if let Some(min) = self.min_length
162                && s.len() < min
163            {
164                return Err(anyhow!(
165                    "Parameter '{}' length must be >= {}",
166                    self.name,
167                    min
168                ));
169            }
170            if let Some(max) = self.max_length
171                && s.len() > max
172            {
173                return Err(anyhow!(
174                    "Parameter '{}' length must be <= {}",
175                    self.name,
176                    max
177                ));
178            }
179            // 检查正则表达式
180            if let Some(ref pattern) = self.pattern {
181                let re = regex::Regex::new(pattern)
182                    .map_err(|e| anyhow!("Invalid regex pattern: {}", e))?;
183                if !re.is_match(s) {
184                    return Err(anyhow!(
185                        "Parameter '{}' does not match pattern: {}",
186                        self.name,
187                        pattern
188                    ));
189                }
190            }
191        }
192
193        // 检查数组长度
194        if let serde_json::Value::Array(arr) = value {
195            if let Some(min) = self.min_length
196                && arr.len() < min
197            {
198                return Err(anyhow!(
199                    "Parameter '{}' array length must be >= {}",
200                    self.name,
201                    min
202                ));
203            }
204            if let Some(max) = self.max_length
205                && arr.len() > max
206            {
207                return Err(anyhow!(
208                    "Parameter '{}' array length must be <= {}",
209                    self.name,
210                    max
211                ));
212            }
213        }
214
215        Ok(())
216    }
217}
218
219// ============================================================================
220// 脚本工具定义
221// ============================================================================
222
223/// 脚本工具定义
224#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct ScriptToolDefinition {
226    /// 工具 ID
227    pub id: String,
228    /// 工具名称
229    pub name: String,
230    /// 工具描述
231    pub description: String,
232    /// 参数定义
233    pub parameters: Vec<ToolParameter>,
234    /// 脚本源代码
235    pub script: String,
236    /// 入口函数名(默认 "execute")
237    #[serde(default = "default_entry_function")]
238    pub entry_function: String,
239    /// 是否启用缓存
240    #[serde(default = "default_true")]
241    pub enable_cache: bool,
242    /// 超时时间(毫秒)
243    #[serde(default = "default_timeout")]
244    pub timeout_ms: u64,
245    /// 工具标签
246    #[serde(default)]
247    pub tags: Vec<String>,
248    /// 元数据
249    #[serde(default)]
250    pub metadata: HashMap<String, String>,
251}
252
253fn default_entry_function() -> String {
254    "execute".to_string()
255}
256
257fn default_true() -> bool {
258    true
259}
260
261fn default_timeout() -> u64 {
262    30000
263}
264
265impl ScriptToolDefinition {
266    pub fn new(id: &str, name: &str, script: &str) -> Self {
267        Self {
268            id: id.to_string(),
269            name: name.to_string(),
270            description: String::new(),
271            parameters: Vec::new(),
272            script: script.to_string(),
273            entry_function: "execute".to_string(),
274            enable_cache: true,
275            timeout_ms: 30000,
276            tags: Vec::new(),
277            metadata: HashMap::new(),
278        }
279    }
280
281    pub fn with_description(mut self, desc: &str) -> Self {
282        self.description = desc.to_string();
283        self
284    }
285
286    pub fn with_parameter(mut self, param: ToolParameter) -> Self {
287        self.parameters.push(param);
288        self
289    }
290
291    pub fn with_entry(mut self, function: &str) -> Self {
292        self.entry_function = function.to_string();
293        self
294    }
295
296    pub fn with_tag(mut self, tag: &str) -> Self {
297        self.tags.push(tag.to_string());
298        self
299    }
300
301    /// 验证输入参数
302    pub fn validate_input(&self, input: &HashMap<String, serde_json::Value>) -> Result<()> {
303        for param in &self.parameters {
304            if let Some(value) = input.get(&param.name) {
305                param.validate(value)?;
306            } else if param.required && param.default.is_none() {
307                return Err(anyhow!("Required parameter '{}' is missing", param.name));
308            }
309        }
310        Ok(())
311    }
312
313    /// 获取带默认值的输入
314    pub fn apply_defaults(&self, input: &mut HashMap<String, serde_json::Value>) {
315        for param in &self.parameters {
316            if !input.contains_key(&param.name)
317                && let Some(ref default) = param.default
318            {
319                input.insert(param.name.clone(), default.clone());
320            }
321        }
322    }
323
324    /// 生成 JSON Schema 格式的参数描述
325    pub fn to_json_schema(&self) -> serde_json::Value {
326        let mut properties = serde_json::Map::new();
327        let mut required = Vec::new();
328
329        for param in &self.parameters {
330            let mut prop = serde_json::Map::new();
331
332            let type_str = match param.param_type {
333                ParameterType::String => "string",
334                ParameterType::Integer => "integer",
335                ParameterType::Float => "number",
336                ParameterType::Boolean => "boolean",
337                ParameterType::Array => "array",
338                ParameterType::Object => "object",
339                ParameterType::Any => "any",
340            };
341
342            prop.insert("type".to_string(), serde_json::json!(type_str));
343
344            if !param.description.is_empty() {
345                prop.insert(
346                    "description".to_string(),
347                    serde_json::json!(param.description),
348                );
349            }
350
351            if let Some(ref enum_values) = param.enum_values {
352                prop.insert("enum".to_string(), serde_json::json!(enum_values));
353            }
354
355            if let Some(min) = param.minimum {
356                prop.insert("minimum".to_string(), serde_json::json!(min));
357            }
358
359            if let Some(max) = param.maximum {
360                prop.insert("maximum".to_string(), serde_json::json!(max));
361            }
362
363            properties.insert(param.name.clone(), serde_json::Value::Object(prop));
364
365            if param.required {
366                required.push(param.name.clone());
367            }
368        }
369
370        serde_json::json!({
371            "type": "object",
372            "properties": properties,
373            "required": required
374        })
375    }
376}
377
378// ============================================================================
379// 工具执行结果
380// ============================================================================
381
382/// 工具执行结果
383#[derive(Debug, Clone, Serialize, Deserialize)]
384pub struct ToolExecutionResult {
385    /// 工具 ID
386    pub tool_id: String,
387    /// 是否成功
388    pub success: bool,
389    /// 返回值
390    pub result: serde_json::Value,
391    /// 错误信息
392    pub error: Option<String>,
393    /// 执行时间(毫秒)
394    pub execution_time_ms: u64,
395    /// 执行日志
396    pub logs: Vec<String>,
397}
398
399// ============================================================================
400// 脚本工具注册表
401// ============================================================================
402
403/// 脚本工具注册表
404pub struct ScriptToolRegistry {
405    /// 脚本引擎
406    engine: Arc<RhaiScriptEngine>,
407    /// 已注册的工具
408    tools: Arc<RwLock<HashMap<String, ScriptToolDefinition>>>,
409}
410
411impl ScriptToolRegistry {
412    /// 创建工具注册表
413    pub fn new(engine_config: ScriptEngineConfig) -> Result<Self> {
414        let engine = Arc::new(RhaiScriptEngine::new(engine_config)?);
415        Ok(Self {
416            engine,
417            tools: Arc::new(RwLock::new(HashMap::new())),
418        })
419    }
420
421    /// 使用已有引擎创建注册表
422    pub fn with_engine(engine: Arc<RhaiScriptEngine>) -> Self {
423        Self {
424            engine,
425            tools: Arc::new(RwLock::new(HashMap::new())),
426        }
427    }
428
429    /// 注册工具
430    pub async fn register(&self, tool: ScriptToolDefinition) -> Result<()> {
431        // 预编译脚本(如果启用缓存)
432        if tool.enable_cache {
433            let script_id = format!("tool_{}", tool.id);
434            self.engine
435                .compile_and_cache(&script_id, &tool.name, &tool.script)
436                .await?;
437        }
438
439        // 注册到工具表
440        let mut tools = self.tools.write().await;
441        info!("Registered script tool: {} ({})", tool.name, tool.id);
442        tools.insert(tool.id.clone(), tool);
443
444        Ok(())
445    }
446
447    /// 批量注册工具
448    pub async fn register_batch(&self, tools: Vec<ScriptToolDefinition>) -> Result<Vec<String>> {
449        let mut registered = Vec::new();
450        for tool in tools {
451            let id = tool.id.clone();
452            self.register(tool).await?;
453            registered.push(id);
454        }
455        Ok(registered)
456    }
457
458    /// 从 YAML 文件加载工具
459    pub async fn load_from_yaml(&self, path: &str) -> Result<String> {
460        let content = tokio::fs::read_to_string(path).await?;
461        let tool: ScriptToolDefinition = serde_yaml::from_str(&content)?;
462        let id = tool.id.clone();
463        self.register(tool).await?;
464        Ok(id)
465    }
466
467    /// 从 JSON 文件加载工具
468    pub async fn load_from_json(&self, path: &str) -> Result<String> {
469        let content = tokio::fs::read_to_string(path).await?;
470        let tool: ScriptToolDefinition = serde_json::from_str(&content)?;
471        let id = tool.id.clone();
472        self.register(tool).await?;
473        Ok(id)
474    }
475
476    /// 从目录加载所有工具
477    pub async fn load_from_directory(&self, dir_path: &str) -> Result<Vec<String>> {
478        let mut loaded = Vec::new();
479        let mut entries = tokio::fs::read_dir(dir_path).await?;
480
481        while let Some(entry) = entries.next_entry().await? {
482            let path = entry.path();
483            if let Some(ext) = path.extension() {
484                let id = match ext.to_str() {
485                    Some("yaml") | Some("yml") => {
486                        self.load_from_yaml(path.to_str().unwrap()).await.ok()
487                    }
488                    Some("json") => self.load_from_json(path.to_str().unwrap()).await.ok(),
489                    _ => None,
490                };
491                if let Some(id) = id {
492                    loaded.push(id);
493                }
494            }
495        }
496
497        info!("Loaded {} tools from directory: {}", loaded.len(), dir_path);
498        Ok(loaded)
499    }
500
501    /// 执行工具
502    pub async fn execute(
503        &self,
504        tool_id: &str,
505        input: HashMap<String, serde_json::Value>,
506    ) -> Result<ToolExecutionResult> {
507        let start_time = std::time::Instant::now();
508
509        // 获取工具定义
510        let tools = self.tools.read().await;
511        let tool = tools
512            .get(tool_id)
513            .ok_or_else(|| anyhow!("Tool not found: {}", tool_id))?
514            .clone();
515        drop(tools);
516
517        // 准备输入
518        let mut params = input;
519        tool.apply_defaults(&mut params);
520
521        // 验证输入
522        tool.validate_input(&params)?;
523
524        // 准备上下文
525        let mut context = ScriptContext::new();
526        for (key, value) in &params {
527            context.set_variable(key, value.clone())?;
528        }
529
530        // 将所有参数作为一个 object 传入
531        context.set_variable("params", serde_json::json!(params))?;
532
533        // 执行脚本
534        let script_id = format!("tool_{}", tool_id);
535
536        if tool.enable_cache {
537            // 尝试调用入口函数
538            let input_value = serde_json::json!(params);
539            match self
540                .engine
541                .call_function::<serde_json::Value>(
542                    &script_id,
543                    &tool.entry_function,
544                    vec![input_value],
545                    &context,
546                )
547                .await
548            {
549                Ok(value) => Ok(ToolExecutionResult {
550                    tool_id: tool_id.to_string(),
551                    success: true,
552                    result: value,
553                    error: None,
554                    execution_time_ms: start_time.elapsed().as_millis() as u64,
555                    logs: Vec::new(),
556                }),
557                Err(_e) => {
558                    // 如果函数调用失败,尝试直接执行
559                    let script_result = self.engine.execute_compiled(&script_id, &context).await?;
560                    if script_result.success {
561                        Ok(ToolExecutionResult {
562                            tool_id: tool_id.to_string(),
563                            success: true,
564                            result: script_result.value,
565                            error: None,
566                            execution_time_ms: start_time.elapsed().as_millis() as u64,
567                            logs: script_result.logs,
568                        })
569                    } else {
570                        Ok(ToolExecutionResult {
571                            tool_id: tool_id.to_string(),
572                            success: false,
573                            result: serde_json::Value::Null,
574                            error: script_result.error,
575                            execution_time_ms: start_time.elapsed().as_millis() as u64,
576                            logs: script_result.logs,
577                        })
578                    }
579                }
580            }
581        } else {
582            let script_result = self.engine.execute(&tool.script, &context).await?;
583            Ok(ToolExecutionResult {
584                tool_id: tool_id.to_string(),
585                success: script_result.success,
586                result: script_result.value,
587                error: script_result.error,
588                execution_time_ms: start_time.elapsed().as_millis() as u64,
589                logs: script_result.logs,
590            })
591        }
592    }
593
594    /// 获取工具定义
595    pub async fn get_tool(&self, tool_id: &str) -> Option<ScriptToolDefinition> {
596        let tools = self.tools.read().await;
597        tools.get(tool_id).cloned()
598    }
599
600    /// 列出所有工具
601    pub async fn list_tools(&self) -> Vec<ScriptToolDefinition> {
602        let tools = self.tools.read().await;
603        tools.values().cloned().collect()
604    }
605
606    /// 按标签过滤工具
607    pub async fn list_tools_by_tag(&self, tag: &str) -> Vec<ScriptToolDefinition> {
608        let tools = self.tools.read().await;
609        tools
610            .values()
611            .filter(|t| t.tags.contains(&tag.to_string()))
612            .cloned()
613            .collect()
614    }
615
616    /// 移除工具
617    pub async fn unregister(&self, tool_id: &str) -> bool {
618        let mut tools = self.tools.write().await;
619        let removed = tools.remove(tool_id).is_some();
620
621        if removed {
622            // 清除缓存的脚本
623            let script_id = format!("tool_{}", tool_id);
624            self.engine.remove_cached(&script_id).await;
625            info!("Unregistered script tool: {}", tool_id);
626        }
627
628        removed
629    }
630
631    /// 清空所有工具
632    pub async fn clear(&self) {
633        let mut tools = self.tools.write().await;
634        tools.clear();
635        self.engine.clear_cache().await;
636    }
637
638    /// 获取工具数量
639    pub async fn tool_count(&self) -> usize {
640        let tools = self.tools.read().await;
641        tools.len()
642    }
643
644    /// 生成所有工具的 JSON Schema 描述(用于 LLM function calling)
645    pub async fn generate_tool_schemas(&self) -> Vec<serde_json::Value> {
646        let tools = self.tools.read().await;
647        tools
648            .values()
649            .map(|tool| {
650                serde_json::json!({
651                    "name": tool.name,
652                    "description": tool.description,
653                    "parameters": tool.to_json_schema()
654                })
655            })
656            .collect()
657    }
658}
659
660// ============================================================================
661// 便捷构建器
662// ============================================================================
663
664/// 工具定义构建器
665pub struct ToolBuilder {
666    definition: ScriptToolDefinition,
667}
668
669impl ToolBuilder {
670    pub fn new(id: &str, name: &str) -> Self {
671        Self {
672            definition: ScriptToolDefinition::new(id, name, ""),
673        }
674    }
675
676    pub fn description(mut self, desc: &str) -> Self {
677        self.definition.description = desc.to_string();
678        self
679    }
680
681    pub fn script(mut self, script: &str) -> Self {
682        self.definition.script = script.to_string();
683        self
684    }
685
686    pub fn entry(mut self, function: &str) -> Self {
687        self.definition.entry_function = function.to_string();
688        self
689    }
690
691    pub fn param(mut self, param: ToolParameter) -> Self {
692        self.definition.parameters.push(param);
693        self
694    }
695
696    pub fn string_param(self, name: &str, required: bool) -> Self {
697        let mut param = ToolParameter::new(name, ParameterType::String);
698        if required {
699            param = param.required();
700        }
701        self.param(param)
702    }
703
704    pub fn int_param(self, name: &str, required: bool) -> Self {
705        let mut param = ToolParameter::new(name, ParameterType::Integer);
706        if required {
707            param = param.required();
708        }
709        self.param(param)
710    }
711
712    pub fn bool_param(self, name: &str, required: bool) -> Self {
713        let mut param = ToolParameter::new(name, ParameterType::Boolean);
714        if required {
715            param = param.required();
716        }
717        self.param(param)
718    }
719
720    pub fn tag(mut self, tag: &str) -> Self {
721        self.definition.tags.push(tag.to_string());
722        self
723    }
724
725    pub fn timeout(mut self, timeout_ms: u64) -> Self {
726        self.definition.timeout_ms = timeout_ms;
727        self
728    }
729
730    pub fn build(self) -> ScriptToolDefinition {
731        self.definition
732    }
733}
734
735// ============================================================================
736// 测试
737// ============================================================================
738
#[cfg(test)]
mod tests {
    use super::*;

    /// Entry-function script that adds two string-encoded integers.
    const ADD_SCRIPT: &str = r#"
                fn execute(params) {
                    let a = params.a.parse_int();
                    let b = params.b.parse_int();
                    #{
                        result: a + b,
                        operation: "addition"
                    }
                }
            "#;

    /// Top-level script multiplying `x` and `y`.
    const MULTIPLY_SCRIPT: &str = r#"
                let result = params.x * params.y;
                result
            "#;

    /// Top-level script building a greeting from `greeting` and `name`.
    const GREET_SCRIPT: &str = r#"
                let name = params.name;
                let greeting = params.greeting;
                greeting + ", " + name + "!"
            "#;

    /// Fresh registry backed by a default-configured engine.
    fn new_registry() -> ScriptToolRegistry {
        ScriptToolRegistry::new(ScriptEngineConfig::default()).unwrap()
    }

    /// Builds a tool input map from `(name, value)` pairs.
    fn input_of(pairs: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect()
    }

    #[tokio::test]
    async fn test_tool_registration() {
        let registry = new_registry();
        let add_tool = ToolBuilder::new("add", "Add Numbers")
            .description("Adds two numbers together")
            .string_param("a", true)
            .string_param("b", true)
            .script(ADD_SCRIPT)
            .build();

        registry.register(add_tool).await.unwrap();

        assert_eq!(registry.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let registry = new_registry();
        let x = ToolParameter::new("x", ParameterType::Integer).required();
        let y = ToolParameter::new("y", ParameterType::Integer).required();
        let multiply = ScriptToolDefinition::new("multiply", "Multiply", MULTIPLY_SCRIPT)
            .with_parameter(x)
            .with_parameter(y);
        registry.register(multiply).await.unwrap();

        let args = input_of(&[("x", serde_json::json!(6)), ("y", serde_json::json!(7))]);
        let outcome = registry.execute("multiply", args).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.result, serde_json::json!(42));
    }

    #[test]
    fn test_parameter_validation() {
        let age = ToolParameter::new("age", ParameterType::Integer)
            .required()
            .with_range(0.0, 150.0);

        let cases = [
            (serde_json::json!(25), true),              // in range
            (serde_json::json!(200), false),            // above the maximum
            (serde_json::json!("not a number"), false), // wrong type
        ];
        for (value, expect_ok) in cases {
            assert_eq!(age.validate(&value).is_ok(), expect_ok, "value: {}", value);
        }
    }

    #[tokio::test]
    async fn test_tool_with_defaults() {
        let registry = new_registry();
        let greet = ScriptToolDefinition::new("greet", "Greeting", GREET_SCRIPT)
            .with_parameter(ToolParameter::new("name", ParameterType::String).required())
            .with_parameter(
                ToolParameter::new("greeting", ParameterType::String).with_default("Hello"),
            );
        registry.register(greet).await.unwrap();

        // `greeting` is omitted, so its default must be applied.
        let args = input_of(&[("name", serde_json::json!("World"))]);
        let outcome = registry.execute("greet", args).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.result, serde_json::json!("Hello, World!"));
    }

    #[test]
    fn test_tool_json_schema() {
        let query = ToolParameter::new("query", ParameterType::String)
            .required()
            .with_description("Search query");
        let limit = ToolParameter::new("limit", ParameterType::Integer)
            .with_default(10)
            .with_range(1.0, 100.0);
        let sort_values = ["relevance", "date", "name"]
            .iter()
            .map(|v| serde_json::json!(v))
            .collect();
        let sort = ToolParameter::new("sort", ParameterType::String).with_enum(sort_values);

        let search = ToolBuilder::new("search", "Search")
            .description("Search for items")
            .param(query)
            .param(limit)
            .param(sort)
            .script("")
            .build();

        let schema = search.to_json_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        assert_eq!(schema["required"], serde_json::json!(["query"]));
    }

    #[test]
    fn test_tool_builder() {
        let built = ToolBuilder::new("test", "Test Tool")
            .description("A test tool")
            .string_param("input", true)
            .int_param("count", false)
            .bool_param("verbose", false)
            .tag("test")
            .tag("example")
            .timeout(5000)
            .script("input")
            .build();

        assert_eq!(built.id, "test");
        assert_eq!(built.parameters.len(), 3);
        assert_eq!(built.tags.len(), 2);
        assert_eq!(built.timeout_ms, 5000);
    }
}