brainwires_tool_runtime/orchestrator/
engine.rs1use std::collections::HashMap;
31
32#[cfg(feature = "orchestrator")]
33use std::sync::{Arc, Mutex};
34#[cfg(feature = "orchestrator")]
35use std::time::Instant;
36
37#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
38use std::cell::RefCell;
39#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
40use std::rc::Rc;
41#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
42use web_time::Instant;
43
44use rhai::{Engine, EvalAltResult, Scope};
45
46use super::sandbox::ExecutionLimits;
47use super::types::{OrchestratorError, OrchestratorResult, ToolCall};
48
/// Maximum expression nesting depth accepted at the global (script) level.
///
/// Passed as the first argument of `Engine::set_max_expr_depths`.
const MAX_EXPR_DEPTH: usize = 64;

/// Second limit passed to `Engine::set_max_expr_depths`.
///
/// NOTE(review): in Rhai this second argument bounds expression nesting
/// *inside function bodies*; it does not cap call-stack (recursion) depth,
/// which Rhai controls separately via `set_max_call_levels`. Confirm the
/// constant's name matches the intended limit.
const MAX_CALL_DEPTH: usize = 64;
58
/// Growable list shared between `ToolOrchestrator::execute` and the tool
/// closures it registers on the script engine (thread-safe flavour).
#[cfg(feature = "orchestrator")]
pub type SharedVec<T> = Arc<Mutex<Vec<T>>>;

/// Growable list shared between `ToolOrchestrator::execute` and the tool
/// closures it registers (single-threaded flavour for `orchestrator-wasm`).
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
pub type SharedVec<T> = Rc<RefCell<Vec<T>>>;

/// Mutable counter shared with the registered tool closures; `execute` uses
/// it to count tool invocations (thread-safe flavour).
#[cfg(feature = "orchestrator")]
pub type SharedCounter = Arc<Mutex<usize>>;

/// Mutable counter shared with the registered tool closures; `execute` uses
/// it to count tool invocations (single-threaded flavour).
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
pub type SharedCounter = Rc<RefCell<usize>>;

/// Host-side implementation of a tool: receives the script argument as JSON
/// and returns the tool's textual output or an error message
/// (thread-safe flavour, must be `Send + Sync`).
#[cfg(feature = "orchestrator")]
pub type ToolExecutor = Arc<dyn Fn(serde_json::Value) -> Result<String, String> + Send + Sync>;

/// Host-side implementation of a tool: receives the script argument as JSON
/// and returns the tool's textual output or an error message
/// (single-threaded flavour).
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
pub type ToolExecutor = Rc<dyn Fn(serde_json::Value) -> Result<String, String>>;
90
/// Creates an empty [`SharedVec`].
///
/// Both flavours (`Arc<Mutex<Vec<T>>>` and `Rc<RefCell<Vec<T>>>`) implement
/// `Default` for every `T`, so one definition serves either feature set.
#[cfg(any(feature = "orchestrator", feature = "orchestrator-wasm"))]
fn new_shared_vec<T>() -> SharedVec<T> {
    Default::default()
}
104
/// Creates a [`SharedCounter`] starting at zero.
///
/// `usize::default()` is `0`, and both the `Arc<Mutex<_>>` and
/// `Rc<RefCell<_>>` flavours implement `Default`, so one definition serves
/// either feature set.
#[cfg(any(feature = "orchestrator", feature = "orchestrator-wasm"))]
fn new_shared_counter() -> SharedCounter {
    Default::default()
}
114
/// Returns another handle to the same shared allocation (reference-count
/// bump only; the pointee is not copied).
#[cfg(feature = "orchestrator")]
fn clone_shared<T: ?Sized>(shared: &Arc<T>) -> Arc<T> {
    Clone::clone(shared)
}

/// Returns another handle to the same shared allocation (reference-count
/// bump only; the pointee is not copied).
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
fn clone_shared<T: ?Sized>(shared: &Rc<T>) -> Rc<T> {
    Clone::clone(shared)
}
124
/// Returns a cloned snapshot of the shared vector's current contents.
///
/// # Panics
///
/// Panics if the mutex was poisoned by a panic while it was held.
#[cfg(feature = "orchestrator")]
fn lock_vec<T: Clone>(shared: &SharedVec<T>) -> Vec<T> {
    let guard = shared
        .lock()
        .expect("orchestrator results lock poisoned");
    guard.to_vec()
}

/// Returns a cloned snapshot of the shared vector's current contents.
///
/// # Panics
///
/// Panics if the vector is currently mutably borrowed.
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
fn lock_vec<T: Clone>(shared: &SharedVec<T>) -> Vec<T> {
    let snapshot = shared.borrow();
    snapshot.to_vec()
}
137
/// Appends `item` to the shared vector.
///
/// # Panics
///
/// Panics if the mutex was poisoned by a panic while it was held.
#[cfg(feature = "orchestrator")]
fn push_to_vec<T>(shared: &SharedVec<T>, item: T) {
    let mut items = shared
        .lock()
        .expect("orchestrator results lock poisoned");
    items.push(item);
}

/// Appends `item` to the shared vector.
///
/// # Panics
///
/// Panics if the vector is already borrowed.
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
fn push_to_vec<T>(shared: &SharedVec<T>, item: T) {
    let mut items = shared.borrow_mut();
    items.push(item);
}
150
/// Increments the shared counter unless it has already reached `max`.
///
/// Returns `Err(())` (leaving the counter untouched) when the budget is
/// exhausted, `Ok(())` after a successful increment.
///
/// # Panics
///
/// Panics if the mutex was poisoned by a panic while it was held.
#[cfg(feature = "orchestrator")]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut used = shared
        .lock()
        .expect("orchestrator step counter lock poisoned");
    if *used < max {
        *used += 1;
        Ok(())
    } else {
        Err(())
    }
}

/// Increments the shared counter unless it has already reached `max`.
///
/// Returns `Err(())` (leaving the counter untouched) when the budget is
/// exhausted, `Ok(())` after a successful increment.
///
/// # Panics
///
/// Panics if the counter is already borrowed.
#[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
fn increment_counter(shared: &SharedCounter, max: usize) -> Result<(), ()> {
    let mut used = shared.borrow_mut();
    if *used < max {
        *used += 1;
        Ok(())
    } else {
        Err(())
    }
}
173
174pub struct ToolOrchestrator {
196 #[allow(dead_code)]
197 engine: Engine,
198 executors: HashMap<String, ToolExecutor>,
199}
200
201impl ToolOrchestrator {
202 #[must_use]
204 pub fn new() -> Self {
205 let mut engine = Engine::new();
206
207 engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
209
210 Self {
211 engine,
212 executors: HashMap::new(),
213 }
214 }
215
216 #[cfg(feature = "orchestrator")]
218 pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
219 where
220 F: Fn(serde_json::Value) -> Result<String, String> + Send + Sync + 'static,
221 {
222 self.executors.insert(name.into(), Arc::new(executor));
223 }
224
225 #[cfg(all(feature = "orchestrator-wasm", not(feature = "orchestrator")))]
227 pub fn register_executor<F>(&mut self, name: impl Into<String>, executor: F)
228 where
229 F: Fn(serde_json::Value) -> Result<String, String> + 'static,
230 {
231 self.executors.insert(name.into(), Rc::new(executor));
232 }
233
234 pub fn execute(
240 &self,
241 script: &str,
242 limits: ExecutionLimits,
243 ) -> Result<OrchestratorResult, OrchestratorError> {
244 let start_time = Instant::now();
245 let tool_calls: SharedVec<ToolCall> = new_shared_vec();
246 let call_count: SharedCounter = new_shared_counter();
247
248 let mut engine = Engine::new();
250
251 engine.set_max_operations(limits.max_operations);
253 engine.set_max_string_size(limits.max_string_size);
254 engine.set_max_array_size(limits.max_array_size);
255 engine.set_max_map_size(limits.max_map_size);
256 engine.set_max_expr_depths(MAX_EXPR_DEPTH, MAX_CALL_DEPTH);
257
258 let timeout_ms = limits.timeout_ms;
260 let progress_start = Instant::now();
261 engine.on_progress(move |_ops| {
262 let elapsed = u64::try_from(progress_start.elapsed().as_millis()).unwrap_or(u64::MAX);
264 if elapsed > timeout_ms {
265 Some(rhai::Dynamic::from("timeout"))
266 } else {
267 None
268 }
269 });
270
271 for (name, executor) in &self.executors {
273 let exec = clone_shared(executor);
274 let calls = clone_shared(&tool_calls);
275 let count = clone_shared(&call_count);
276 let max_calls = limits.max_tool_calls;
277 let tool_name = name.clone();
278
279 engine.register_fn(name.as_str(), move |input: rhai::Dynamic| -> String {
281 let call_start = Instant::now();
282
283 if increment_counter(&count, max_calls).is_err() {
285 return format!("ERROR: Maximum tool calls ({max_calls}) exceeded");
286 }
287
288 let json_input = dynamic_to_json(&input);
290
291 let (output, success) = match exec(json_input.clone()) {
293 Ok(result) => (result, true),
294 Err(e) => (format!("Tool error: {e}"), false),
295 };
296
297 let duration_ms =
299 u64::try_from(call_start.elapsed().as_millis()).unwrap_or(u64::MAX);
300 let call = ToolCall::new(
301 tool_name.clone(),
302 json_input,
303 output.clone(),
304 success,
305 duration_ms,
306 );
307 push_to_vec(&calls, call);
308
309 output
310 });
311 }
312
313 let ast = engine
315 .compile(script)
316 .map_err(|e| OrchestratorError::CompilationError(e.to_string()))?;
317
318 let mut scope = Scope::new();
320 let result = engine
321 .eval_ast_with_scope::<rhai::Dynamic>(&mut scope, &ast)
322 .map_err(|e| match *e {
323 EvalAltResult::ErrorTooManyOperations(_) => {
324 OrchestratorError::MaxOperationsExceeded(limits.max_operations)
325 }
326 EvalAltResult::ErrorTerminated(_, _) => {
327 OrchestratorError::Timeout(limits.timeout_ms)
328 }
329 _ => OrchestratorError::ExecutionError(e.to_string()),
330 })?;
331
332 let execution_time_ms = u64::try_from(start_time.elapsed().as_millis()).unwrap_or(u64::MAX);
333
334 let output = if result.is_string() {
336 result.into_string().unwrap_or_default()
337 } else if result.is_unit() {
338 String::new()
339 } else {
340 format!("{result:?}")
341 };
342
343 let calls = lock_vec(&tool_calls);
344 Ok(OrchestratorResult::success(
345 output,
346 calls,
347 execution_time_ms,
348 ))
349 }
350
351 #[must_use]
353 pub fn registered_tools(&self) -> Vec<&str> {
354 self.executors.keys().map(String::as_str).collect()
355 }
356}
357
358impl Default for ToolOrchestrator {
359 fn default() -> Self {
360 Self::new()
361 }
362}
363
364pub fn dynamic_to_json(value: &rhai::Dynamic) -> serde_json::Value {
375 if value.is_string() {
376 serde_json::Value::String(value.clone().into_string().unwrap_or_default())
377 } else if value.is_int() {
378 serde_json::Value::Number(serde_json::Number::from(
379 value.clone().as_int().unwrap_or(0),
380 ))
381 } else if value.is_float() {
382 serde_json::json!(value.clone().as_float().unwrap_or(0.0))
383 } else if value.is_bool() {
384 serde_json::Value::Bool(value.clone().as_bool().unwrap_or(false))
385 } else if value.is_array() {
386 let arr: Vec<rhai::Dynamic> = value.clone().into_array().unwrap_or_default();
387 serde_json::Value::Array(arr.iter().map(dynamic_to_json).collect())
388 } else if value.is_map() {
389 let map: rhai::Map = value.clone().cast();
390 let mut json_map = serde_json::Map::new();
391 for (k, v) in &map {
392 json_map.insert(k.to_string(), dynamic_to_json(v));
393 }
394 serde_json::Value::Object(json_map)
395 } else if value.is_unit() {
396 serde_json::Value::Null
397 } else {
398 serde_json::Value::String(format!("{value:?}"))
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_orchestrator_creation() {
        let orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());
    }

    #[test]
    fn test_register_executor() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("test_tool", |_| Ok("success".to_string()));
        assert!(orchestrator.registered_tools().contains(&"test_tool"));
    }

    #[test]
    fn test_simple_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 1 + 2; x", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "3");
    }

    #[test]
    fn test_string_interpolation() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute(
                r#"let name = "world"; `Hello, ${name}!`"#,
                ExecutionLimits::default(),
            )
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "Hello, world!");
    }

    #[test]
    fn test_tool_execution() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("greet", |input| {
            let name = input.as_str().unwrap_or("stranger");
            Ok(format!("Hello, {name}!"))
        });

        let result = orchestrator
            .execute(r#"greet("Claude")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Hello, Claude!");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].tool_name, "greet");
    }

    #[test]
    fn test_max_operations_limit() {
        let orchestrator = ToolOrchestrator::new();
        let limits = ExecutionLimits::default().with_max_operations(10);

        let result =
            orchestrator.execute("let sum = 0; for i in 0..1000 { sum += i; } sum", limits);

        assert!(matches!(
            result,
            Err(OrchestratorError::MaxOperationsExceeded(_))
        ));
    }

    #[test]
    fn test_compilation_error() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator.execute(
            "this is not valid rhai syntax {{{{",
            ExecutionLimits::default(),
        );

        assert!(matches!(
            result,
            Err(OrchestratorError::CompilationError(_))
        ));
    }

    #[test]
    fn test_multiple_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();

        orchestrator.register_executor("add", |input| {
            if let Some(arr) = input.as_array() {
                let sum: i64 = arr.iter().filter_map(|v| v.as_i64()).sum();
                Ok(sum.to_string())
            } else {
                Err("Expected array".to_string())
            }
        });

        let script = r#"
            let a = add([1, 2, 3]);
            let b = add([4, 5, 6]);
            `Sum1: ${a}, Sum2: ${b}`
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 2);
        assert!(result.output.contains("Sum1: 6"));
        assert!(result.output.contains("Sum2: 15"));
    }

    #[test]
    fn test_tool_error_handling() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("fail_tool", |_| Err("Intentional failure".to_string()));

        let result = orchestrator
            .execute(r#"fail_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        // A failing tool is reported to the script as a value; the run itself succeeds.
        assert!(result.success);
        assert!(result.output.contains("Tool error"));
        assert_eq!(result.tool_calls.len(), 1);
        assert!(!result.tool_calls[0].success);
    }

    #[test]
    fn test_max_tool_calls_limit() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("count", |_| Ok("1".to_string()));

        let limits = ExecutionLimits::default().with_max_tool_calls(3);
        let script = r#"
            let a = count("1");
            let b = count("2");
            let c = count("3");
            count("4")
        "#;

        let result = orchestrator.execute(script, limits).unwrap();

        assert!(
            result.output.contains("Maximum tool calls"),
            "Expected error message about max tool calls, got: {}",
            result.output
        );
        // Only the calls within budget are recorded, and each of them succeeded.
        assert_eq!(result.tool_calls.len(), 3);
        assert!(result.tool_calls.iter().all(|call| call.success));
    }

    #[test]
    fn test_tool_with_map_input() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("get_value", |input| {
            if let Some(obj) = input.as_object() {
                if let Some(key) = obj.get("key").and_then(|v| v.as_str()) {
                    Ok(format!("Got key: {key}"))
                } else {
                    Err("Missing key field".to_string())
                }
            } else {
                Err("Expected object".to_string())
            }
        });

        let result = orchestrator
            .execute(
                r#"get_value(#{ key: "test_key" })"#,
                ExecutionLimits::default(),
            )
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "Got key: test_key");
    }

    #[test]
    fn test_loop_with_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("double", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok((n * 2).to_string())
        });

        let script = r#"
            let results = [];
            for i in 1..4 {
                results.push(double(i));
            }
            results
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 3);
    }

    #[test]
    fn test_conditional_tool_calls() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("check", |input| {
            let n = input.as_i64().unwrap_or(0);
            Ok(if n > 5 { "big" } else { "small" }.to_string())
        });

        let script = r#"
            let x = 10;
            if x > 5 {
                check(x)
            } else {
                "skipped"
            }
        "#;

        let result = orchestrator
            .execute(script, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "big");
        assert_eq!(result.tool_calls.len(), 1);
    }

    #[test]
    fn test_empty_script() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_unit_return() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute("let x = 5;", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_dynamic_to_json_types() {
        use rhai::Dynamic;

        let d = Dynamic::from("hello".to_string());
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!("hello"));

        let d = Dynamic::from(42_i64);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(42));

        let d = Dynamic::from(2.5_f64);
        let j = dynamic_to_json(&d);
        assert!((j.as_f64().unwrap() - 2.5).abs() < 0.001);

        let d = Dynamic::from(true);
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::json!(true));

        let d = Dynamic::UNIT;
        let j = dynamic_to_json(&d);
        assert_eq!(j, serde_json::Value::Null);
    }

    #[test]
    fn test_execution_time_recorded() {
        let orchestrator = ToolOrchestrator::new();
        let result = orchestrator
            .execute(
                "let sum = 0; for i in 0..100 { sum += i; } sum",
                ExecutionLimits::default(),
            )
            .unwrap();

        assert!(result.success);
        assert!(result.execution_time_ms < 10000);
    }

    #[test]
    fn test_tool_call_duration_recorded() {
        let mut orchestrator = ToolOrchestrator::new();
        orchestrator.register_executor("slow_tool", |_| {
            std::thread::sleep(std::time::Duration::from_millis(10));
            Ok("done".to_string())
        });

        let result = orchestrator
            .execute(r#"slow_tool("test")"#, ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.tool_calls.len(), 1);
        assert!(result.tool_calls[0].duration_ms >= 10);
    }

    #[test]
    fn test_default_impl() {
        let orchestrator = ToolOrchestrator::default();
        assert!(orchestrator.registered_tools().is_empty());

        let result = orchestrator
            .execute("1 + 1", ExecutionLimits::default())
            .unwrap();
        assert!(result.success);
        assert_eq!(result.output, "2");
    }

    #[test]
    fn test_timeout_error() {
        let orchestrator = ToolOrchestrator::new();

        let limits = ExecutionLimits::default()
            .with_timeout_ms(1)
            .with_max_operations(1_000_000);

        let result = orchestrator.execute(
            r#"
            let sum = 0;
            for i in 0..1000000 {
                sum += i;
            }
            sum
            "#,
            limits,
        );

        match result {
            Err(OrchestratorError::Timeout(ms)) => assert_eq!(ms, 1),
            other => panic!("Expected Timeout error, got: {other:?}"),
        }
    }

    #[test]
    fn test_runtime_error() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator.execute("undefined_variable", ExecutionLimits::default());

        match result {
            Err(OrchestratorError::ExecutionError(msg)) => {
                assert!(msg.contains("undefined_variable") || msg.contains("not found"));
            }
            other => panic!("Expected ExecutionError, got: {other:?}"),
        }
    }

    #[test]
    fn test_registered_tools() {
        let mut orchestrator = ToolOrchestrator::new();
        assert!(orchestrator.registered_tools().is_empty());

        orchestrator.register_executor("tool_a", |_| Ok("a".to_string()));
        orchestrator.register_executor("tool_b", |_| Ok("b".to_string()));

        let tools = orchestrator.registered_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.contains(&"tool_a"));
        assert!(tools.contains(&"tool_b"));
    }

    #[test]
    fn test_dynamic_to_json_array() {
        use rhai::Dynamic;

        let arr: Vec<Dynamic> = vec![
            Dynamic::from(1_i64),
            Dynamic::from(2_i64),
            Dynamic::from(3_i64),
        ];
        let d = Dynamic::from(arr);
        let j = dynamic_to_json(&d);

        assert_eq!(j, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_dynamic_to_json_map() {
        use rhai::{Dynamic, Map};

        let mut map = Map::new();
        map.insert("key".into(), Dynamic::from("value".to_string()));
        map.insert("num".into(), Dynamic::from(42_i64));
        let d = Dynamic::from(map);
        let j = dynamic_to_json(&d);

        assert!(j.is_object());
        let obj = j.as_object().unwrap();
        assert_eq!(obj.get("key").unwrap(), &serde_json::json!("value"));
        assert_eq!(obj.get("num").unwrap(), &serde_json::json!(42));
    }

    #[test]
    fn test_non_string_result() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator
            .execute("42", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert_eq!(result.output, "42");
    }

    #[test]
    fn test_array_result() {
        let orchestrator = ToolOrchestrator::new();

        let result = orchestrator
            .execute("[1, 2, 3]", ExecutionLimits::default())
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains('1'));
        assert!(result.output.contains('2'));
        assert!(result.output.contains('3'));
    }

    #[test]
    fn test_dynamic_to_json_fallback() {
        use rhai::Dynamic;

        #[derive(Clone)]
        struct CustomType {
            // Never read: the field only gives the opaque type some payload.
            #[allow(dead_code)]
            value: i32,
        }

        let custom = CustomType { value: 42 };
        let d = Dynamic::from(custom);
        let j = dynamic_to_json(&d);

        assert!(j.is_string());
        let s = j.as_str().unwrap();
        assert!(!s.is_empty());
    }
}