1use super::engine::{RhaiScriptEngine, ScriptContext, ScriptEngineConfig};
10use anyhow::{Result, anyhow};
11#[allow(unused_imports)]
12use rhai::{Dynamic, Engine, Map, Scope};
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17use tracing::info;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
25#[serde(rename_all = "lowercase")]
26#[derive(Default)]
27pub enum ParameterType {
28 #[default]
29 String,
30 Integer,
31 Float,
32 Boolean,
33 Array,
34 Object,
35 Any,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct ToolParameter {
41 pub name: String,
43 #[serde(default)]
45 pub param_type: ParameterType,
46 #[serde(default)]
48 pub description: String,
49 #[serde(default)]
51 pub required: bool,
52 pub default: Option<serde_json::Value>,
54 pub enum_values: Option<Vec<serde_json::Value>>,
56 pub minimum: Option<f64>,
58 pub maximum: Option<f64>,
60 pub min_length: Option<usize>,
62 pub max_length: Option<usize>,
64 pub pattern: Option<String>,
66}
67
68impl ToolParameter {
69 pub fn new(name: &str, param_type: ParameterType) -> Self {
70 Self {
71 name: name.to_string(),
72 param_type,
73 description: String::new(),
74 required: false,
75 default: None,
76 enum_values: None,
77 minimum: None,
78 maximum: None,
79 min_length: None,
80 max_length: None,
81 pattern: None,
82 }
83 }
84
85 pub fn required(mut self) -> Self {
86 self.required = true;
87 self
88 }
89
90 pub fn with_description(mut self, desc: &str) -> Self {
91 self.description = desc.to_string();
92 self
93 }
94
95 pub fn with_default<T: Serialize>(mut self, value: T) -> Self {
96 self.default = serde_json::to_value(value).ok();
97 self
98 }
99
100 pub fn with_enum(mut self, values: Vec<serde_json::Value>) -> Self {
101 self.enum_values = Some(values);
102 self
103 }
104
105 pub fn with_range(mut self, min: f64, max: f64) -> Self {
106 self.minimum = Some(min);
107 self.maximum = Some(max);
108 self
109 }
110
111 pub fn validate(&self, value: &serde_json::Value) -> Result<()> {
113 match (&self.param_type, value) {
115 (ParameterType::String, serde_json::Value::String(_)) => {}
116 (ParameterType::Integer, serde_json::Value::Number(n)) if n.is_i64() => {}
117 (ParameterType::Float, serde_json::Value::Number(_)) => {}
118 (ParameterType::Boolean, serde_json::Value::Bool(_)) => {}
119 (ParameterType::Array, serde_json::Value::Array(_)) => {}
120 (ParameterType::Object, serde_json::Value::Object(_)) => {}
121 (ParameterType::Any, _) => {}
122 (ParameterType::String, serde_json::Value::Null) if !self.required => {}
123 _ => {
124 return Err(anyhow!(
125 "Parameter '{}' has invalid type, expected {:?}",
126 self.name,
127 self.param_type
128 ));
129 }
130 }
131
132 if let Some(ref enum_values) = self.enum_values
134 && !enum_values.contains(value)
135 {
136 return Err(anyhow!(
137 "Parameter '{}' value must be one of {:?}",
138 self.name,
139 enum_values
140 ));
141 }
142
143 if let serde_json::Value::Number(n) = value
145 && let Some(f) = n.as_f64()
146 {
147 if let Some(min) = self.minimum
148 && f < min
149 {
150 return Err(anyhow!("Parameter '{}' must be >= {}", self.name, min));
151 }
152 if let Some(max) = self.maximum
153 && f > max
154 {
155 return Err(anyhow!("Parameter '{}' must be <= {}", self.name, max));
156 }
157 }
158
159 if let serde_json::Value::String(s) = value {
161 if let Some(min) = self.min_length
162 && s.len() < min
163 {
164 return Err(anyhow!(
165 "Parameter '{}' length must be >= {}",
166 self.name,
167 min
168 ));
169 }
170 if let Some(max) = self.max_length
171 && s.len() > max
172 {
173 return Err(anyhow!(
174 "Parameter '{}' length must be <= {}",
175 self.name,
176 max
177 ));
178 }
179 if let Some(ref pattern) = self.pattern {
181 let re = regex::Regex::new(pattern)
182 .map_err(|e| anyhow!("Invalid regex pattern: {}", e))?;
183 if !re.is_match(s) {
184 return Err(anyhow!(
185 "Parameter '{}' does not match pattern: {}",
186 self.name,
187 pattern
188 ));
189 }
190 }
191 }
192
193 if let serde_json::Value::Array(arr) = value {
195 if let Some(min) = self.min_length
196 && arr.len() < min
197 {
198 return Err(anyhow!(
199 "Parameter '{}' array length must be >= {}",
200 self.name,
201 min
202 ));
203 }
204 if let Some(max) = self.max_length
205 && arr.len() > max
206 {
207 return Err(anyhow!(
208 "Parameter '{}' array length must be <= {}",
209 self.name,
210 max
211 ));
212 }
213 }
214
215 Ok(())
216 }
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
225pub struct ScriptToolDefinition {
226 pub id: String,
228 pub name: String,
230 pub description: String,
232 pub parameters: Vec<ToolParameter>,
234 pub script: String,
236 #[serde(default = "default_entry_function")]
238 pub entry_function: String,
239 #[serde(default = "default_true")]
241 pub enable_cache: bool,
242 #[serde(default = "default_timeout")]
244 pub timeout_ms: u64,
245 #[serde(default)]
247 pub tags: Vec<String>,
248 #[serde(default)]
250 pub metadata: HashMap<String, String>,
251}
252
/// Serde default for `ScriptToolDefinition::entry_function`: the script
/// function invoked when a definition does not name one.
fn default_entry_function() -> String {
    String::from("execute")
}
256
/// Serde default for opt-out boolean flags such as
/// `ScriptToolDefinition::enable_cache`: always `true`.
fn default_true() -> bool {
    !false
}
260
/// Serde default for `ScriptToolDefinition::timeout_ms`: 30 seconds,
/// expressed in milliseconds.
fn default_timeout() -> u64 {
    30 * 1_000
}
264
265impl ScriptToolDefinition {
266 pub fn new(id: &str, name: &str, script: &str) -> Self {
267 Self {
268 id: id.to_string(),
269 name: name.to_string(),
270 description: String::new(),
271 parameters: Vec::new(),
272 script: script.to_string(),
273 entry_function: "execute".to_string(),
274 enable_cache: true,
275 timeout_ms: 30000,
276 tags: Vec::new(),
277 metadata: HashMap::new(),
278 }
279 }
280
281 pub fn with_description(mut self, desc: &str) -> Self {
282 self.description = desc.to_string();
283 self
284 }
285
286 pub fn with_parameter(mut self, param: ToolParameter) -> Self {
287 self.parameters.push(param);
288 self
289 }
290
291 pub fn with_entry(mut self, function: &str) -> Self {
292 self.entry_function = function.to_string();
293 self
294 }
295
296 pub fn with_tag(mut self, tag: &str) -> Self {
297 self.tags.push(tag.to_string());
298 self
299 }
300
301 pub fn validate_input(&self, input: &HashMap<String, serde_json::Value>) -> Result<()> {
303 for param in &self.parameters {
304 if let Some(value) = input.get(¶m.name) {
305 param.validate(value)?;
306 } else if param.required && param.default.is_none() {
307 return Err(anyhow!("Required parameter '{}' is missing", param.name));
308 }
309 }
310 Ok(())
311 }
312
313 pub fn apply_defaults(&self, input: &mut HashMap<String, serde_json::Value>) {
315 for param in &self.parameters {
316 if !input.contains_key(¶m.name)
317 && let Some(ref default) = param.default
318 {
319 input.insert(param.name.clone(), default.clone());
320 }
321 }
322 }
323
324 pub fn to_json_schema(&self) -> serde_json::Value {
326 let mut properties = serde_json::Map::new();
327 let mut required = Vec::new();
328
329 for param in &self.parameters {
330 let mut prop = serde_json::Map::new();
331
332 let type_str = match param.param_type {
333 ParameterType::String => "string",
334 ParameterType::Integer => "integer",
335 ParameterType::Float => "number",
336 ParameterType::Boolean => "boolean",
337 ParameterType::Array => "array",
338 ParameterType::Object => "object",
339 ParameterType::Any => "any",
340 };
341
342 prop.insert("type".to_string(), serde_json::json!(type_str));
343
344 if !param.description.is_empty() {
345 prop.insert(
346 "description".to_string(),
347 serde_json::json!(param.description),
348 );
349 }
350
351 if let Some(ref enum_values) = param.enum_values {
352 prop.insert("enum".to_string(), serde_json::json!(enum_values));
353 }
354
355 if let Some(min) = param.minimum {
356 prop.insert("minimum".to_string(), serde_json::json!(min));
357 }
358
359 if let Some(max) = param.maximum {
360 prop.insert("maximum".to_string(), serde_json::json!(max));
361 }
362
363 properties.insert(param.name.clone(), serde_json::Value::Object(prop));
364
365 if param.required {
366 required.push(param.name.clone());
367 }
368 }
369
370 serde_json::json!({
371 "type": "object",
372 "properties": properties,
373 "required": required
374 })
375 }
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize)]
384pub struct ToolExecutionResult {
385 pub tool_id: String,
387 pub success: bool,
389 pub result: serde_json::Value,
391 pub error: Option<String>,
393 pub execution_time_ms: u64,
395 pub logs: Vec<String>,
397}
398
399pub struct ScriptToolRegistry {
405 engine: Arc<RhaiScriptEngine>,
407 tools: Arc<RwLock<HashMap<String, ScriptToolDefinition>>>,
409}
410
411impl ScriptToolRegistry {
412 pub fn new(engine_config: ScriptEngineConfig) -> Result<Self> {
414 let engine = Arc::new(RhaiScriptEngine::new(engine_config)?);
415 Ok(Self {
416 engine,
417 tools: Arc::new(RwLock::new(HashMap::new())),
418 })
419 }
420
421 pub fn with_engine(engine: Arc<RhaiScriptEngine>) -> Self {
423 Self {
424 engine,
425 tools: Arc::new(RwLock::new(HashMap::new())),
426 }
427 }
428
429 pub async fn register(&self, tool: ScriptToolDefinition) -> Result<()> {
431 if tool.enable_cache {
433 let script_id = format!("tool_{}", tool.id);
434 self.engine
435 .compile_and_cache(&script_id, &tool.name, &tool.script)
436 .await?;
437 }
438
439 let mut tools = self.tools.write().await;
441 info!("Registered script tool: {} ({})", tool.name, tool.id);
442 tools.insert(tool.id.clone(), tool);
443
444 Ok(())
445 }
446
447 pub async fn register_batch(&self, tools: Vec<ScriptToolDefinition>) -> Result<Vec<String>> {
449 let mut registered = Vec::new();
450 for tool in tools {
451 let id = tool.id.clone();
452 self.register(tool).await?;
453 registered.push(id);
454 }
455 Ok(registered)
456 }
457
458 pub async fn load_from_yaml(&self, path: &str) -> Result<String> {
460 let content = tokio::fs::read_to_string(path).await?;
461 let tool: ScriptToolDefinition = serde_yaml::from_str(&content)?;
462 let id = tool.id.clone();
463 self.register(tool).await?;
464 Ok(id)
465 }
466
467 pub async fn load_from_json(&self, path: &str) -> Result<String> {
469 let content = tokio::fs::read_to_string(path).await?;
470 let tool: ScriptToolDefinition = serde_json::from_str(&content)?;
471 let id = tool.id.clone();
472 self.register(tool).await?;
473 Ok(id)
474 }
475
476 pub async fn load_from_directory(&self, dir_path: &str) -> Result<Vec<String>> {
478 let mut loaded = Vec::new();
479 let mut entries = tokio::fs::read_dir(dir_path).await?;
480
481 while let Some(entry) = entries.next_entry().await? {
482 let path = entry.path();
483 if let Some(ext) = path.extension() {
484 let id = match ext.to_str() {
485 Some("yaml") | Some("yml") => {
486 self.load_from_yaml(path.to_str().unwrap()).await.ok()
487 }
488 Some("json") => self.load_from_json(path.to_str().unwrap()).await.ok(),
489 _ => None,
490 };
491 if let Some(id) = id {
492 loaded.push(id);
493 }
494 }
495 }
496
497 info!("Loaded {} tools from directory: {}", loaded.len(), dir_path);
498 Ok(loaded)
499 }
500
501 pub async fn execute(
503 &self,
504 tool_id: &str,
505 input: HashMap<String, serde_json::Value>,
506 ) -> Result<ToolExecutionResult> {
507 let start_time = std::time::Instant::now();
508
509 let tools = self.tools.read().await;
511 let tool = tools
512 .get(tool_id)
513 .ok_or_else(|| anyhow!("Tool not found: {}", tool_id))?
514 .clone();
515 drop(tools);
516
517 let mut params = input;
519 tool.apply_defaults(&mut params);
520
521 tool.validate_input(¶ms)?;
523
524 let mut context = ScriptContext::new();
526 for (key, value) in ¶ms {
527 context.set_variable(key, value.clone())?;
528 }
529
530 context.set_variable("params", serde_json::json!(params))?;
532
533 let script_id = format!("tool_{}", tool_id);
535
536 if tool.enable_cache {
537 let input_value = serde_json::json!(params);
539 match self
540 .engine
541 .call_function::<serde_json::Value>(
542 &script_id,
543 &tool.entry_function,
544 vec![input_value],
545 &context,
546 )
547 .await
548 {
549 Ok(value) => Ok(ToolExecutionResult {
550 tool_id: tool_id.to_string(),
551 success: true,
552 result: value,
553 error: None,
554 execution_time_ms: start_time.elapsed().as_millis() as u64,
555 logs: Vec::new(),
556 }),
557 Err(_e) => {
558 let script_result = self.engine.execute_compiled(&script_id, &context).await?;
560 if script_result.success {
561 Ok(ToolExecutionResult {
562 tool_id: tool_id.to_string(),
563 success: true,
564 result: script_result.value,
565 error: None,
566 execution_time_ms: start_time.elapsed().as_millis() as u64,
567 logs: script_result.logs,
568 })
569 } else {
570 Ok(ToolExecutionResult {
571 tool_id: tool_id.to_string(),
572 success: false,
573 result: serde_json::Value::Null,
574 error: script_result.error,
575 execution_time_ms: start_time.elapsed().as_millis() as u64,
576 logs: script_result.logs,
577 })
578 }
579 }
580 }
581 } else {
582 let script_result = self.engine.execute(&tool.script, &context).await?;
583 Ok(ToolExecutionResult {
584 tool_id: tool_id.to_string(),
585 success: script_result.success,
586 result: script_result.value,
587 error: script_result.error,
588 execution_time_ms: start_time.elapsed().as_millis() as u64,
589 logs: script_result.logs,
590 })
591 }
592 }
593
594 pub async fn get_tool(&self, tool_id: &str) -> Option<ScriptToolDefinition> {
596 let tools = self.tools.read().await;
597 tools.get(tool_id).cloned()
598 }
599
600 pub async fn list_tools(&self) -> Vec<ScriptToolDefinition> {
602 let tools = self.tools.read().await;
603 tools.values().cloned().collect()
604 }
605
606 pub async fn list_tools_by_tag(&self, tag: &str) -> Vec<ScriptToolDefinition> {
608 let tools = self.tools.read().await;
609 tools
610 .values()
611 .filter(|t| t.tags.contains(&tag.to_string()))
612 .cloned()
613 .collect()
614 }
615
616 pub async fn unregister(&self, tool_id: &str) -> bool {
618 let mut tools = self.tools.write().await;
619 let removed = tools.remove(tool_id).is_some();
620
621 if removed {
622 let script_id = format!("tool_{}", tool_id);
624 self.engine.remove_cached(&script_id).await;
625 info!("Unregistered script tool: {}", tool_id);
626 }
627
628 removed
629 }
630
631 pub async fn clear(&self) {
633 let mut tools = self.tools.write().await;
634 tools.clear();
635 self.engine.clear_cache().await;
636 }
637
638 pub async fn tool_count(&self) -> usize {
640 let tools = self.tools.read().await;
641 tools.len()
642 }
643
644 pub async fn generate_tool_schemas(&self) -> Vec<serde_json::Value> {
646 let tools = self.tools.read().await;
647 tools
648 .values()
649 .map(|tool| {
650 serde_json::json!({
651 "name": tool.name,
652 "description": tool.description,
653 "parameters": tool.to_json_schema()
654 })
655 })
656 .collect()
657 }
658}
659
660pub struct ToolBuilder {
666 definition: ScriptToolDefinition,
667}
668
669impl ToolBuilder {
670 pub fn new(id: &str, name: &str) -> Self {
671 Self {
672 definition: ScriptToolDefinition::new(id, name, ""),
673 }
674 }
675
676 pub fn description(mut self, desc: &str) -> Self {
677 self.definition.description = desc.to_string();
678 self
679 }
680
681 pub fn script(mut self, script: &str) -> Self {
682 self.definition.script = script.to_string();
683 self
684 }
685
686 pub fn entry(mut self, function: &str) -> Self {
687 self.definition.entry_function = function.to_string();
688 self
689 }
690
691 pub fn param(mut self, param: ToolParameter) -> Self {
692 self.definition.parameters.push(param);
693 self
694 }
695
696 pub fn string_param(self, name: &str, required: bool) -> Self {
697 let mut param = ToolParameter::new(name, ParameterType::String);
698 if required {
699 param = param.required();
700 }
701 self.param(param)
702 }
703
704 pub fn int_param(self, name: &str, required: bool) -> Self {
705 let mut param = ToolParameter::new(name, ParameterType::Integer);
706 if required {
707 param = param.required();
708 }
709 self.param(param)
710 }
711
712 pub fn bool_param(self, name: &str, required: bool) -> Self {
713 let mut param = ToolParameter::new(name, ParameterType::Boolean);
714 if required {
715 param = param.required();
716 }
717 self.param(param)
718 }
719
720 pub fn tag(mut self, tag: &str) -> Self {
721 self.definition.tags.push(tag.to_string());
722 self
723 }
724
725 pub fn timeout(mut self, timeout_ms: u64) -> Self {
726 self.definition.timeout_ms = timeout_ms;
727 self
728 }
729
730 pub fn build(self) -> ScriptToolDefinition {
731 self.definition
732 }
733}
734
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_tool_registration() {
        let registry = ScriptToolRegistry::new(ScriptEngineConfig::default()).unwrap();

        let tool = ToolBuilder::new("add", "Add Numbers")
            .description("Adds two numbers together")
            .string_param("a", true)
            .string_param("b", true)
            .script(
                r#"
            fn execute(params) {
                let a = params.a.parse_int();
                let b = params.b.parse_int();
                #{
                    result: a + b,
                    operation: "addition"
                }
            }
        "#,
            )
            .build();

        registry.register(tool).await.unwrap();

        assert_eq!(registry.tool_count().await, 1);
        assert!(registry.get_tool("add").await.is_some());

        // Unregistering removes the tool exactly once.
        assert!(registry.unregister("add").await);
        assert!(!registry.unregister("add").await);
        assert_eq!(registry.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let registry = ScriptToolRegistry::new(ScriptEngineConfig::default()).unwrap();

        let tool = ScriptToolDefinition::new(
            "multiply",
            "Multiply",
            r#"
            let result = params.x * params.y;
            result
        "#,
        )
        .with_parameter(ToolParameter::new("x", ParameterType::Integer).required())
        .with_parameter(ToolParameter::new("y", ParameterType::Integer).required());

        registry.register(tool).await.unwrap();

        let mut input = HashMap::new();
        input.insert("x".to_string(), serde_json::json!(6));
        input.insert("y".to_string(), serde_json::json!(7));

        let result = registry.execute("multiply", input).await.unwrap();

        assert!(result.success);
        assert_eq!(result.result, serde_json::json!(42));
    }

    // Pure validation: no async runtime needed.
    #[test]
    fn test_parameter_validation() {
        let param = ToolParameter::new("age", ParameterType::Integer)
            .required()
            .with_range(0.0, 150.0);

        // In range, including both inclusive bounds.
        assert!(param.validate(&serde_json::json!(25)).is_ok());
        assert!(param.validate(&serde_json::json!(0)).is_ok());
        assert!(param.validate(&serde_json::json!(150)).is_ok());

        // Out of range.
        assert!(param.validate(&serde_json::json!(200)).is_err());
        assert!(param.validate(&serde_json::json!(-1)).is_err());

        // Wrong type, including non-integral numbers and null for a required parameter.
        assert!(param.validate(&serde_json::json!("not a number")).is_err());
        assert!(param.validate(&serde_json::json!(25.5)).is_err());
        assert!(param.validate(&serde_json::Value::Null).is_err());

        // Enum membership.
        let sort = ToolParameter::new("sort", ParameterType::String)
            .with_enum(vec![serde_json::json!("date"), serde_json::json!("name")]);
        assert!(sort.validate(&serde_json::json!("date")).is_ok());
        assert!(sort.validate(&serde_json::json!("size")).is_err());
    }

    #[tokio::test]
    async fn test_tool_with_defaults() {
        let registry = ScriptToolRegistry::new(ScriptEngineConfig::default()).unwrap();

        let tool = ScriptToolDefinition::new(
            "greet",
            "Greeting",
            r#"
            let name = params.name;
            let greeting = params.greeting;
            greeting + ", " + name + "!"
        "#,
        )
        .with_parameter(ToolParameter::new("name", ParameterType::String).required())
        .with_parameter(
            ToolParameter::new("greeting", ParameterType::String).with_default("Hello"),
        );

        registry.register(tool).await.unwrap();

        // "greeting" is omitted so its default must be applied.
        let mut input = HashMap::new();
        input.insert("name".to_string(), serde_json::json!("World"));

        let result = registry.execute("greet", input).await.unwrap();

        assert!(result.success);
        assert_eq!(result.result, serde_json::json!("Hello, World!"));
    }

    // Schema generation is synchronous: no async runtime needed.
    #[test]
    fn test_tool_json_schema() {
        let tool = ToolBuilder::new("search", "Search")
            .description("Search for items")
            .param(
                ToolParameter::new("query", ParameterType::String)
                    .required()
                    .with_description("Search query"),
            )
            .param(
                ToolParameter::new("limit", ParameterType::Integer)
                    .with_default(10)
                    .with_range(1.0, 100.0),
            )
            .param(
                ToolParameter::new("sort", ParameterType::String).with_enum(vec![
                    serde_json::json!("relevance"),
                    serde_json::json!("date"),
                    serde_json::json!("name"),
                ]),
            )
            .script("")
            .build();

        let schema = tool.to_json_schema();

        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["query"].is_object());
        assert_eq!(schema["properties"]["query"]["type"], "string");
        assert_eq!(schema["properties"]["query"]["description"], "Search query");
        assert_eq!(schema["properties"]["limit"]["type"], "integer");
        assert_eq!(schema["properties"]["limit"]["minimum"], 1.0);
        assert_eq!(schema["properties"]["limit"]["maximum"], 100.0);
        assert_eq!(
            schema["properties"]["sort"]["enum"],
            serde_json::json!(["relevance", "date", "name"])
        );
        assert_eq!(schema["required"], serde_json::json!(["query"]));
    }

    #[test]
    fn test_tool_builder() {
        let tool = ToolBuilder::new("test", "Test Tool")
            .description("A test tool")
            .string_param("input", true)
            .int_param("count", false)
            .bool_param("verbose", false)
            .tag("test")
            .tag("example")
            .timeout(5000)
            .script("input")
            .build();

        assert_eq!(tool.id, "test");
        assert_eq!(tool.parameters.len(), 3);
        assert!(tool.parameters[0].required);
        assert!(!tool.parameters[1].required);
        assert_eq!(tool.tags.len(), 2);
        assert_eq!(tool.timeout_ms, 5000);
        assert_eq!(tool.script, "input");
        // Builder-created tools keep the serde defaults.
        assert_eq!(tool.entry_function, "execute");
        assert!(tool.enable_cache);
    }
}