1use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::time::Instant;
13
14use super::{Callable, CallableRegistry, DynCallable};
15use crate::kernel::ids::{CallableType, ExecutionId, SpawnMode};
16use crate::kernel::TokenUsage;
17
18#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
21#[serde(rename_all = "snake_case")]
22pub enum CostTier {
23 Free,
24 Low,
25 #[default]
26 Medium,
27 High,
28 Premium,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
34#[serde(rename_all = "camelCase")]
35pub struct CallableDescriptor {
36 pub name: String,
38
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub description: Option<String>,
42
43 pub callable_type: CallableType,
45
46 #[serde(skip_serializing_if = "Option::is_none")]
48 pub input_schema: Option<serde_json::Value>,
49
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub output_schema: Option<serde_json::Value>,
53
54 #[serde(default)]
56 pub tags: Vec<String>,
57
58 #[serde(default)]
60 pub can_spawn_children: bool,
61
62 #[serde(default)]
64 pub cost_tier: CostTier,
65
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub avg_latency_ms: Option<u64>,
69}
70
71impl CallableDescriptor {
72 pub fn from_callable(callable: &dyn Callable, callable_type: CallableType) -> Self {
74 Self {
75 name: callable.name().to_string(),
76 description: callable.description().map(String::from),
77 callable_type,
78 input_schema: None,
79 output_schema: None,
80 tags: Vec::new(),
81 can_spawn_children: false,
82 cost_tier: CostTier::Medium,
83 avg_latency_ms: None,
84 }
85 }
86
87 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
89 self.tags = tags;
90 self
91 }
92
93 pub fn with_cost_tier(mut self, tier: CostTier) -> Self {
95 self.cost_tier = tier;
96 self
97 }
98
99 pub fn with_spawn_capability(mut self, can_spawn: bool) -> Self {
101 self.can_spawn_children = can_spawn;
102 self
103 }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct CallableInvocation {
111 pub callable_name: String,
113
114 pub input: String,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
119 pub context: Option<HashMap<String, String>>,
120
121 #[serde(default)]
123 pub spawn_mode: SpawnMode,
124
125 #[serde(default = "default_priority")]
127 pub priority: u8,
128
129 #[serde(skip_serializing_if = "Option::is_none")]
131 pub timeout_ms: Option<u64>,
132}
133
/// Serde fallback for `CallableInvocation::priority` when the field is absent.
fn default_priority() -> u8 {
    const DEFAULT_PRIORITY: u8 = 50;
    DEFAULT_PRIORITY
}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
141#[serde(rename_all = "camelCase")]
142pub struct CallableInvocationResult {
143 pub success: bool,
145
146 #[serde(skip_serializing_if = "Option::is_none")]
148 pub output: Option<String>,
149
150 #[serde(skip_serializing_if = "Option::is_none")]
152 pub error: Option<String>,
153
154 #[serde(skip_serializing_if = "Option::is_none")]
156 pub child_execution_id: Option<ExecutionId>,
157
158 pub duration_ms: u64,
160
161 #[serde(skip_serializing_if = "Option::is_none")]
163 pub token_usage: Option<TokenUsage>,
164}
165
166impl CallableInvocationResult {
167 pub fn success(output: String, duration_ms: u64) -> Self {
169 Self {
170 success: true,
171 output: Some(output),
172 error: None,
173 child_execution_id: None,
174 duration_ms,
175 token_usage: None,
176 }
177 }
178
179 pub fn failure(error: impl Into<String>, duration_ms: u64) -> Self {
181 Self {
182 success: false,
183 output: None,
184 error: Some(error.into()),
185 child_execution_id: None,
186 duration_ms,
187 token_usage: None,
188 }
189 }
190
191 pub fn child_spawned(execution_id: ExecutionId, duration_ms: u64) -> Self {
193 Self {
194 success: true,
195 output: None,
196 error: None,
197 child_execution_id: Some(execution_id),
198 duration_ms,
199 token_usage: None,
200 }
201 }
202}
203
204#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
207#[serde(rename_all = "snake_case")]
208pub enum ResourceAllocationStrategy {
209 EqualSplit,
211 #[default]
213 SharedPool,
214 Priority,
216 Proportional,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct ResourceBudget {
225 #[serde(skip_serializing_if = "Option::is_none")]
227 pub max_tokens: Option<u64>,
228
229 #[serde(skip_serializing_if = "Option::is_none")]
231 pub max_time_ms: Option<u64>,
232
233 #[serde(skip_serializing_if = "Option::is_none")]
235 pub max_cost_cents: Option<f64>,
236
237 #[serde(skip_serializing_if = "Option::is_none")]
239 pub max_children: Option<u32>,
240
241 #[serde(default = "default_max_depth")]
243 pub max_discovery_depth: u32,
244}
245
/// Default for `ResourceBudget::max_discovery_depth`, used by serde and `Default`.
fn default_max_depth() -> u32 {
    const DEFAULT_MAX_DISCOVERY_DEPTH: u32 = 3;
    DEFAULT_MAX_DISCOVERY_DEPTH
}
249
250impl Default for ResourceBudget {
251 fn default() -> Self {
252 Self {
253 max_tokens: None,
254 max_time_ms: None,
255 max_cost_cents: None,
256 max_children: None,
257 max_discovery_depth: default_max_depth(),
258 }
259 }
260}
261
262#[derive(Debug, Clone, Serialize, Deserialize)]
265#[serde(rename_all = "camelCase")]
266pub struct ResourceAllocation {
267 pub strategy: ResourceAllocationStrategy,
269
270 pub budget: ResourceBudget,
272
273 #[serde(default)]
275 pub used_tokens: u64,
276
277 #[serde(default)]
279 pub used_time_ms: u64,
280
281 #[serde(default)]
283 pub used_cost_cents: f64,
284
285 #[serde(default)]
287 pub children_spawned: u32,
288
289 #[serde(default)]
291 pub current_depth: u32,
292}
293
294impl ResourceAllocation {
295 pub fn new(strategy: ResourceAllocationStrategy, budget: ResourceBudget) -> Self {
297 Self {
298 strategy,
299 budget,
300 used_tokens: 0,
301 used_time_ms: 0,
302 used_cost_cents: 0.0,
303 children_spawned: 0,
304 current_depth: 0,
305 }
306 }
307
308 pub fn can_spawn_child(&self) -> bool {
310 match self.budget.max_children {
311 Some(max) => self.children_spawned < max,
312 None => true,
313 }
314 }
315
316 pub fn can_discover_deeper(&self) -> bool {
318 self.current_depth < self.budget.max_discovery_depth
319 }
320
321 pub fn has_token_budget(&self, tokens: u64) -> bool {
323 match self.budget.max_tokens {
324 Some(max) => self.used_tokens + tokens <= max,
325 None => true,
326 }
327 }
328
329 pub fn has_time_budget(&self, time_ms: u64) -> bool {
331 match self.budget.max_time_ms {
332 Some(max) => self.used_time_ms + time_ms <= max,
333 None => true,
334 }
335 }
336
337 pub fn record_tokens(&mut self, tokens: u64) {
339 self.used_tokens += tokens;
340 }
341
342 pub fn record_time(&mut self, time_ms: u64) {
344 self.used_time_ms += time_ms;
345 }
346
347 pub fn record_child_spawn(&mut self) {
349 self.children_spawned += 1;
350 }
351
352 pub fn increment_depth(&mut self) {
354 self.current_depth += 1;
355 }
356
357 pub fn child_allocation(&self) -> Self {
359 let mut child = self.clone();
360 child.increment_depth();
361
362 match self.strategy {
364 ResourceAllocationStrategy::EqualSplit => {
365 if let Some(max) = child.budget.max_tokens {
367 let remaining = max.saturating_sub(self.used_tokens);
368 child.budget.max_tokens = Some(remaining / 2);
369 }
370 if let Some(max) = child.budget.max_time_ms {
371 let remaining = max.saturating_sub(self.used_time_ms);
372 child.budget.max_time_ms = Some(remaining / 2);
373 }
374 }
375 ResourceAllocationStrategy::SharedPool => {
376 }
378 ResourceAllocationStrategy::Priority => {
379 if let Some(max) = child.budget.max_tokens {
381 let remaining = max.saturating_sub(self.used_tokens);
382 child.budget.max_tokens = Some((remaining * 80) / 100);
383 }
384 }
385 ResourceAllocationStrategy::Proportional => {
386 if let Some(max) = child.budget.max_tokens {
389 let remaining = max.saturating_sub(self.used_tokens);
390 child.budget.max_tokens = Some(remaining / 2);
391 }
392 }
393 }
394
395 child
396 }
397}
398
399#[derive(Debug, Clone, Serialize, Deserialize)]
402#[serde(rename_all = "camelCase")]
403pub struct DiscoveryQuery {
404 #[serde(skip_serializing_if = "Option::is_none")]
406 pub callable_type: Option<CallableType>,
407
408 #[serde(skip_serializing_if = "Option::is_none")]
410 pub tags: Option<Vec<String>>,
411
412 #[serde(skip_serializing_if = "Option::is_none")]
414 pub name_pattern: Option<String>,
415
416 #[serde(skip_serializing_if = "Option::is_none")]
418 pub max_cost_tier: Option<CostTier>,
419
420 #[serde(default = "default_limit")]
422 pub limit: usize,
423}
424
/// Default for `DiscoveryQuery::limit`, used by serde and `Default`.
fn default_limit() -> usize {
    const DEFAULT_DISCOVERY_LIMIT: usize = 10;
    DEFAULT_DISCOVERY_LIMIT
}
428
429impl Default for DiscoveryQuery {
430 fn default() -> Self {
431 Self {
432 callable_type: None,
433 tags: None,
434 name_pattern: None,
435 max_cost_tier: None,
436 limit: default_limit(),
437 }
438 }
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
444#[serde(rename_all = "camelCase")]
445pub struct DiscoveryResult {
446 pub callables: Vec<CallableDescriptor>,
448
449 pub total_count: usize,
451
452 pub query: DiscoveryQuery,
454}
455
456pub struct CallableInvoker {
460 registry: CallableRegistry,
461 descriptors: HashMap<String, CallableDescriptor>,
462}
463
464impl CallableInvoker {
465 pub fn new(registry: CallableRegistry) -> Self {
467 Self {
468 registry,
469 descriptors: HashMap::new(),
470 }
471 }
472
473 pub fn register_descriptor(&mut self, descriptor: CallableDescriptor) {
475 self.descriptors.insert(descriptor.name.clone(), descriptor);
476 }
477
478 pub fn get(&self, name: &str) -> Option<DynCallable> {
480 self.registry.get(name)
481 }
482
483 pub async fn invoke(&self, invocation: CallableInvocation) -> CallableInvocationResult {
485 let start = Instant::now();
486
487 let callable = match self.registry.get(&invocation.callable_name) {
488 Some(c) => c,
489 None => {
490 return CallableInvocationResult::failure(
491 format!("Callable '{}' not found", invocation.callable_name),
492 start.elapsed().as_millis() as u64,
493 );
494 }
495 };
496
497 match invocation.spawn_mode {
498 SpawnMode::Inline => {
499 match callable.run(&invocation.input).await {
501 Ok(output) => CallableInvocationResult::success(
502 output,
503 start.elapsed().as_millis() as u64,
504 ),
505 Err(e) => CallableInvocationResult::failure(
506 e.to_string(),
507 start.elapsed().as_millis() as u64,
508 ),
509 }
510 }
511 SpawnMode::Child { background, .. } => {
512 if background {
513 let execution_id = ExecutionId::new();
515 CallableInvocationResult::child_spawned(
517 execution_id,
518 start.elapsed().as_millis() as u64,
519 )
520 } else {
521 match callable.run(&invocation.input).await {
523 Ok(output) => CallableInvocationResult::success(
524 output,
525 start.elapsed().as_millis() as u64,
526 ),
527 Err(e) => CallableInvocationResult::failure(
528 e.to_string(),
529 start.elapsed().as_millis() as u64,
530 ),
531 }
532 }
533 }
534 }
535 }
536
537 pub fn discover(&self, query: DiscoveryQuery) -> DiscoveryResult {
539 let mut matches: Vec<CallableDescriptor> = self
540 .descriptors
541 .values()
542 .filter(|desc| {
543 if let Some(ref t) = query.callable_type {
545 if &desc.callable_type != t {
546 return false;
547 }
548 }
549
550 if let Some(ref tags) = query.tags {
552 if !tags.iter().any(|t| desc.tags.contains(t)) {
553 return false;
554 }
555 }
556
557 if let Some(ref pattern) = query.name_pattern {
559 if !matches_glob(&desc.name, pattern) {
560 return false;
561 }
562 }
563
564 if let Some(ref max_tier) = query.max_cost_tier {
566 if !is_cost_tier_within(&desc.cost_tier, max_tier) {
567 return false;
568 }
569 }
570
571 true
572 })
573 .cloned()
574 .collect();
575
576 let total_count = matches.len();
577 matches.truncate(query.limit);
578
579 DiscoveryResult {
580 callables: matches,
581 total_count,
582 query,
583 }
584 }
585
586 pub fn list(&self) -> Vec<String> {
588 self.registry.list()
589 }
590}
591
/// Returns `true` if the whole of `name` matches the glob `pattern`.
///
/// `*` matches any run of characters (including none), `?` matches exactly one
/// character, and every other character must match literally (case-sensitive,
/// per Unicode scalar value).
///
/// Implemented as a linear scan that, on mismatch, backtracks to the most recent
/// `*` and lets it absorb one more character. After a `*`, every *suffix* of the
/// remaining name must be considered. Worst case O(len(name) * len(pattern)),
/// with no recursion.
fn matches_glob(name: &str, pattern: &str) -> bool {
    let name: Vec<char> = name.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut n = 0; // position in `name`
    let mut p = 0; // position in `pattern`
    // (pattern index just after the last '*', name index that '*' currently extends to)
    let mut last_star: Option<(usize, usize)> = None;

    while n < name.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting '*' match the empty string.
                last_star = Some((p + 1, n));
                p += 1;
            }
            Some(c) if c == '?' || c == name[n] => {
                p += 1;
                n += 1;
            }
            _ => match last_star {
                Some((after_star, star_end)) => {
                    // Mismatch: widen the last '*' by one character and retry.
                    p = after_star;
                    n = star_end + 1;
                    last_star = Some((after_star, star_end + 1));
                }
                None => return false,
            },
        }
    }

    // Name exhausted: only trailing '*'s may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}
637
638fn is_cost_tier_within(tier: &CostTier, max_tier: &CostTier) -> bool {
640 let tier_value = match tier {
641 CostTier::Free => 0,
642 CostTier::Low => 1,
643 CostTier::Medium => 2,
644 CostTier::High => 3,
645 CostTier::Premium => 4,
646 };
647 let max_value = match max_tier {
648 CostTier::Free => 0,
649 CostTier::Low => 1,
650 CostTier::Medium => 2,
651 CostTier::High => 3,
652 CostTier::Premium => 4,
653 };
654 tier_value <= max_value
655}
656
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;

    // Minimal callable: ignores its input and always returns `output`.
    struct TestCallable {
        name: String,
        output: String,
    }

    #[async_trait]
    impl Callable for TestCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _input: &str) -> anyhow::Result<String> {
            Ok(self.output.clone())
        }
    }

    // Builder methods overwrite the defaults chosen by `from_callable`.
    #[test]
    fn test_callable_descriptor() {
        let callable = TestCallable {
            name: "test".to_string(),
            output: "output".to_string(),
        };

        let desc = CallableDescriptor::from_callable(&callable, CallableType::Agent)
            .with_tags(vec!["research".to_string(), "analysis".to_string()])
            .with_cost_tier(CostTier::High)
            .with_spawn_capability(true);

        assert_eq!(desc.name, "test");
        assert_eq!(desc.callable_type, CallableType::Agent);
        assert_eq!(desc.tags.len(), 2);
        assert_eq!(desc.cost_tier, CostTier::High);
        assert!(desc.can_spawn_children);
    }

    // Token and child-spawn accounting against a budget of 1000 tokens / 3 children.
    #[test]
    fn test_resource_allocation() {
        let budget = ResourceBudget {
            max_tokens: Some(1000),
            max_time_ms: Some(5000),
            max_children: Some(3),
            ..Default::default()
        };

        let mut allocation =
            ResourceAllocation::new(ResourceAllocationStrategy::EqualSplit, budget);

        assert!(allocation.can_spawn_child());
        assert!(allocation.has_token_budget(500));

        allocation.record_tokens(400);
        allocation.record_child_spawn();

        // 400 used: 400 + 500 fits in 1000, 400 + 700 does not.
        assert!(allocation.has_token_budget(500));
        assert!(!allocation.has_token_budget(700));
        assert!(allocation.can_spawn_child());

        // Third spawn reaches max_children = 3.
        allocation.record_child_spawn();
        allocation.record_child_spawn();
        assert!(!allocation.can_spawn_child());
    }

    // EqualSplit from an unused parent: child is one level deeper with half the tokens.
    #[test]
    fn test_child_allocation() {
        let budget = ResourceBudget {
            max_tokens: Some(1000),
            ..Default::default()
        };

        let allocation = ResourceAllocation::new(ResourceAllocationStrategy::EqualSplit, budget);
        let child = allocation.child_allocation();

        assert_eq!(child.current_depth, 1);
        assert_eq!(child.budget.max_tokens, Some(500));
    }

    // Inline invocation returns the callable's output through the registry lookup.
    #[tokio::test]
    async fn test_callable_invoker() {
        let registry = CallableRegistry::new();
        let callable = Arc::new(TestCallable {
            name: "test".to_string(),
            output: "test output".to_string(),
        });
        registry.register("test".to_string(), callable);

        let invoker = CallableInvoker::new(registry);

        let invocation = CallableInvocation {
            callable_name: "test".to_string(),
            input: "input".to_string(),
            context: None,
            spawn_mode: SpawnMode::Inline,
            priority: 50,
            timeout_ms: None,
        };

        let result = invoker.invoke(invocation).await;
        assert!(result.success);
        assert_eq!(result.output, Some("test output".to_string()));
    }

    // Discovery filters: by tag, by cost ceiling, and unfiltered.
    #[test]
    fn test_discovery() {
        let registry = CallableRegistry::new();
        let mut invoker = CallableInvoker::new(registry);

        invoker.register_descriptor(
            CallableDescriptor::from_callable(
                &TestCallable {
                    name: "research-agent".to_string(),
                    output: "".to_string(),
                },
                CallableType::Agent,
            )
            .with_tags(vec!["research".to_string()])
            .with_cost_tier(CostTier::Medium),
        );

        invoker.register_descriptor(
            CallableDescriptor::from_callable(
                &TestCallable {
                    name: "analysis-agent".to_string(),
                    output: "".to_string(),
                },
                CallableType::Agent,
            )
            .with_tags(vec!["analysis".to_string()])
            .with_cost_tier(CostTier::High),
        );

        // Only research-agent carries the "research" tag.
        let result = invoker.discover(DiscoveryQuery {
            tags: Some(vec!["research".to_string()]),
            ..Default::default()
        });
        assert_eq!(result.callables.len(), 1);
        assert_eq!(result.callables[0].name, "research-agent");

        // A Medium ceiling excludes the High-tier analysis-agent.
        let result = invoker.discover(DiscoveryQuery {
            max_cost_tier: Some(CostTier::Medium),
            ..Default::default()
        });
        assert_eq!(result.callables.len(), 1);

        // No filters: both descriptors match.
        let result = invoker.discover(DiscoveryQuery::default());
        assert_eq!(result.total_count, 2);
    }

    // Tier ranking is inclusive at the ceiling.
    #[test]
    fn test_cost_tier_comparison() {
        assert!(is_cost_tier_within(&CostTier::Free, &CostTier::High));
        assert!(is_cost_tier_within(&CostTier::Medium, &CostTier::Medium));
        assert!(!is_cost_tier_within(&CostTier::High, &CostTier::Low));
    }
}