use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
use super::{Callable, CallableRegistry, DynCallable};
use crate::kernel::ids::{CallableType, ExecutionId, SpawnMode};
use crate::kernel::TokenUsage;
/// Coarse cost classification attached to a callable descriptor.
///
/// Serialized with `snake_case` variant names. The variants are declared from
/// cheapest to most expensive; `is_cost_tier_within` in this file ranks them
/// in this same order. The default tier is [`CostTier::Medium`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CostTier {
/// Lowest tier (rank 0 in `is_cost_tier_within`).
Free,
/// Rank 1.
Low,
/// Rank 2; used when no tier is specified.
#[default]
Medium,
/// Rank 3.
High,
/// Highest tier (rank 4).
Premium,
}
/// Metadata describing a callable, used for registration and discovery.
///
/// Serialized with `camelCase` field names. `Option` fields are omitted from
/// the serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallableDescriptor {
/// Name of the callable; `CallableInvoker` keys its descriptor map by this.
pub name: String,
/// Optional human-readable description.
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Kind of callable; compared for equality by discovery queries.
pub callable_type: CallableType,
/// Optional JSON value describing the input (left `None` by `from_callable`).
#[serde(skip_serializing_if = "Option::is_none")]
pub input_schema: Option<serde_json::Value>,
/// Optional JSON value describing the output (left `None` by `from_callable`).
#[serde(skip_serializing_if = "Option::is_none")]
pub output_schema: Option<serde_json::Value>,
/// Free-form tags matched by discovery queries; empty when absent from input.
#[serde(default)]
pub tags: Vec<String>,
/// Spawn-capability flag set via `with_spawn_capability`; `false` when absent.
#[serde(default)]
pub can_spawn_children: bool,
/// Cost tier; falls back to `CostTier::default()` when absent from input.
#[serde(default)]
pub cost_tier: CostTier,
/// Optional latency figure in milliseconds (per the field name); nothing in
/// this file populates it.
#[serde(skip_serializing_if = "Option::is_none")]
pub avg_latency_ms: Option<u64>,
}
impl CallableDescriptor {
    /// Builds a descriptor from a callable's own name and description.
    ///
    /// Schemas and latency are left unset, the tag list is empty, spawning is
    /// disabled and the cost tier is `CostTier::Medium`; use the `with_*`
    /// builders to override those.
    pub fn from_callable(callable: &dyn Callable, callable_type: CallableType) -> Self {
        let name = callable.name().to_string();
        let description = callable.description().map(String::from);
        Self {
            name,
            description,
            callable_type,
            input_schema: None,
            output_schema: None,
            tags: Vec::new(),
            can_spawn_children: false,
            cost_tier: CostTier::Medium,
            avg_latency_ms: None,
        }
    }

    /// Returns the descriptor with its tag list replaced by `tags`.
    pub fn with_tags(self, tags: Vec<String>) -> Self {
        Self { tags, ..self }
    }

    /// Returns the descriptor with its cost tier replaced by `tier`.
    pub fn with_cost_tier(self, tier: CostTier) -> Self {
        Self {
            cost_tier: tier,
            ..self
        }
    }

    /// Returns the descriptor with its spawn-capability flag set to `can_spawn`.
    pub fn with_spawn_capability(self, can_spawn: bool) -> Self {
        Self {
            can_spawn_children: can_spawn,
            ..self
        }
    }
}
/// A request to run a callable by name.
///
/// Serialized with `camelCase` field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallableInvocation {
/// Name looked up in the registry by `CallableInvoker::invoke`.
pub callable_name: String,
/// Input string handed to the callable's `run`.
pub input: String,
/// Optional key/value context; not read by `CallableInvoker::invoke` in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<HashMap<String, String>>,
/// How to run the callable; `SpawnMode::default()` when absent from input.
#[serde(default)]
pub spawn_mode: SpawnMode,
/// Priority; defaults to 50 (`default_priority`) when absent from input.
/// Not read by `CallableInvoker::invoke` in this file.
#[serde(default = "default_priority")]
pub priority: u8,
/// Optional timeout in milliseconds (per the field name).
/// NOTE(review): `CallableInvoker::invoke` in this file does not enforce it.
#[serde(skip_serializing_if = "Option::is_none")]
pub timeout_ms: Option<u64>,
}
/// Serde default for `CallableInvocation::priority`.
fn default_priority() -> u8 {
    const DEFAULT_PRIORITY: u8 = 50;
    DEFAULT_PRIORITY
}
/// Outcome of a `CallableInvoker::invoke` call.
///
/// Serialized with `camelCase` field names; `None` fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CallableInvocationResult {
/// `true` for `success` and `child_spawned` results, `false` for `failure`.
pub success: bool,
/// Output text; set only by the `success` constructor.
#[serde(skip_serializing_if = "Option::is_none")]
pub output: Option<String>,
/// Error message; set only by the `failure` constructor.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Execution id; set only by the `child_spawned` constructor.
#[serde(skip_serializing_if = "Option::is_none")]
pub child_execution_id: Option<ExecutionId>,
/// Elapsed time in milliseconds as measured by the caller of the constructor.
pub duration_ms: u64,
/// Token usage; every constructor in this file leaves it `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_usage: Option<TokenUsage>,
}
impl CallableInvocationResult {
    /// Shared constructor behind the public ones; `token_usage` is always left unset.
    fn build(
        success: bool,
        output: Option<String>,
        error: Option<String>,
        child_execution_id: Option<ExecutionId>,
        duration_ms: u64,
    ) -> Self {
        Self {
            success,
            output,
            error,
            child_execution_id,
            duration_ms,
            token_usage: None,
        }
    }

    /// Successful result carrying `output`.
    pub fn success(output: String, duration_ms: u64) -> Self {
        Self::build(true, Some(output), None, None, duration_ms)
    }

    /// Failed result carrying an error message and no output.
    pub fn failure(error: impl Into<String>, duration_ms: u64) -> Self {
        Self::build(false, None, Some(error.into()), None, duration_ms)
    }

    /// Successful result that carries only the id of a spawned child execution.
    pub fn child_spawned(execution_id: ExecutionId, duration_ms: u64) -> Self {
        Self::build(true, None, None, Some(execution_id), duration_ms)
    }
}
/// How `ResourceAllocation::child_allocation` derives a child's budget.
///
/// Serialized with `snake_case` variant names; the default is `SharedPool`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResourceAllocationStrategy {
/// Child gets half of the remaining token and time budgets.
EqualSplit,
/// Child keeps the parent's limits unchanged.
#[default]
SharedPool,
/// Child gets 80% of the remaining token budget; the time limit is unchanged.
Priority,
/// Child gets half of the remaining token budget; the time limit is unchanged.
Proportional,
}
/// Upper limits for an execution; `None` means the corresponding limit is not set.
///
/// Serialized with `camelCase` field names; `None` fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceBudget {
/// Token limit checked by `ResourceAllocation::has_token_budget`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u64>,
/// Time limit in milliseconds checked by `ResourceAllocation::has_time_budget`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_time_ms: Option<u64>,
/// Cost limit in cents (per the field name); no check in this file reads it.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_cost_cents: Option<f64>,
/// Child-count limit checked by `ResourceAllocation::can_spawn_child`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_children: Option<u32>,
/// Depth limit checked by `ResourceAllocation::can_discover_deeper`;
/// defaults to 3 (`default_max_depth`) when absent from input.
#[serde(default = "default_max_depth")]
pub max_discovery_depth: u32,
}
/// Default for `ResourceBudget::max_discovery_depth` (serde and `Default`).
fn default_max_depth() -> u32 {
    const DEFAULT_MAX_DEPTH: u32 = 3;
    DEFAULT_MAX_DEPTH
}
impl Default for ResourceBudget {
    /// No limits set, except the discovery depth, which uses `default_max_depth()`.
    fn default() -> Self {
        let max_discovery_depth = default_max_depth();
        Self {
            max_discovery_depth,
            max_children: None,
            max_cost_cents: None,
            max_time_ms: None,
            max_tokens: None,
        }
    }
}
/// A budget together with the usage recorded against it.
///
/// Serialized with `camelCase` field names; the usage counters default to
/// zero when absent from input.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ResourceAllocation {
/// Strategy used by `child_allocation` to derive a child's budget.
pub strategy: ResourceAllocationStrategy,
/// Limits the usage counters are checked against.
pub budget: ResourceBudget,
/// Tokens recorded via `record_tokens`.
#[serde(default)]
pub used_tokens: u64,
/// Milliseconds recorded via `record_time`.
#[serde(default)]
pub used_time_ms: u64,
/// Cost in cents (per the field name); no method in this file updates it.
#[serde(default)]
pub used_cost_cents: f64,
/// Number of children recorded via `record_child_spawn`.
#[serde(default)]
pub children_spawned: u32,
/// Depth counter, advanced by `increment_depth` / `child_allocation`.
#[serde(default)]
pub current_depth: u32,
}
impl ResourceAllocation {
    /// Creates an allocation with the given strategy and budget and all usage
    /// counters at zero.
    pub fn new(strategy: ResourceAllocationStrategy, budget: ResourceBudget) -> Self {
        Self {
            strategy,
            budget,
            used_tokens: 0,
            used_time_ms: 0,
            used_cost_cents: 0.0,
            children_spawned: 0,
            current_depth: 0,
        }
    }

    /// Returns `true` while fewer than `max_children` children have been
    /// recorded, or always when no child limit is set.
    pub fn can_spawn_child(&self) -> bool {
        match self.budget.max_children {
            Some(max) => self.children_spawned < max,
            None => true,
        }
    }

    /// Returns `true` while the current depth is below `max_discovery_depth`.
    pub fn can_discover_deeper(&self) -> bool {
        self.current_depth < self.budget.max_discovery_depth
    }

    /// Returns `true` if `tokens` more tokens fit within the token limit
    /// (always `true` when no limit is set).
    ///
    /// A request so large that `used + tokens` overflows `u64` is reported as
    /// not fitting instead of panicking or wrapping around.
    pub fn has_token_budget(&self, tokens: u64) -> bool {
        Self::fits(self.budget.max_tokens, self.used_tokens, tokens)
    }

    /// Returns `true` if `time_ms` more milliseconds fit within the time limit
    /// (always `true` when no limit is set). Overflow is treated as "does not fit".
    pub fn has_time_budget(&self, time_ms: u64) -> bool {
        Self::fits(self.budget.max_time_ms, self.used_time_ms, time_ms)
    }

    /// Adds `tokens` to the token usage counter, saturating at `u64::MAX`.
    pub fn record_tokens(&mut self, tokens: u64) {
        self.used_tokens = self.used_tokens.saturating_add(tokens);
    }

    /// Adds `time_ms` to the time usage counter, saturating at `u64::MAX`.
    pub fn record_time(&mut self, time_ms: u64) {
        self.used_time_ms = self.used_time_ms.saturating_add(time_ms);
    }

    /// Records one spawned child, saturating at `u32::MAX`.
    pub fn record_child_spawn(&mut self) {
        self.children_spawned = self.children_spawned.saturating_add(1);
    }

    /// Advances the depth counter by one, saturating at `u32::MAX`.
    pub fn increment_depth(&mut self) {
        self.current_depth = self.current_depth.saturating_add(1);
    }

    /// Derives the allocation handed to a child execution.
    ///
    /// The child is a copy of `self` one level deeper. Depending on the strategy:
    ///
    /// * `EqualSplit`: token and time limits become half of what the parent has left.
    /// * `SharedPool`: limits and usage are copied unchanged.
    /// * `Priority`: the token limit becomes 80% of what the parent has left.
    /// * `Proportional`: the token limit becomes half of what the parent has left
    ///   (currently the same token share as `EqualSplit`).
    ///
    /// Whenever a limit is re-derived from the parent's *remaining* budget, the
    /// matching usage counter of the child is reset to zero. Otherwise the
    /// parent's usage would be counted twice: once when computing the remainder
    /// and again when the child checks its own budget.
    pub fn child_allocation(&self) -> Self {
        let mut child = self.clone();
        child.increment_depth();
        match self.strategy {
            ResourceAllocationStrategy::EqualSplit => {
                Self::carve_share(&mut child.budget.max_tokens, &mut child.used_tokens, 50);
                Self::carve_share(&mut child.budget.max_time_ms, &mut child.used_time_ms, 50);
            }
            ResourceAllocationStrategy::SharedPool => {
                // The child works against the same limits and usage snapshot.
            }
            ResourceAllocationStrategy::Priority => {
                Self::carve_share(&mut child.budget.max_tokens, &mut child.used_tokens, 80);
            }
            ResourceAllocationStrategy::Proportional => {
                Self::carve_share(&mut child.budget.max_tokens, &mut child.used_tokens, 50);
            }
        }
        child
    }

    /// Overflow-safe check that `used + requested` stays within `limit`.
    fn fits(limit: Option<u64>, used: u64, requested: u64) -> bool {
        match limit {
            Some(max) => matches!(used.checked_add(requested), Some(total) if total <= max),
            None => true,
        }
    }

    /// When `limit` is set, replaces it with `percent`% (rounded down) of the
    /// budget still unused and resets `used`, so the new limit is a fresh
    /// budget. Does nothing when `limit` is `None`. `percent` must be <= 100.
    fn carve_share(limit: &mut Option<u64>, used: &mut u64, percent: u64) {
        if let Some(max) = *limit {
            let remaining = max.saturating_sub(*used);
            // Widen to u128 so `remaining * percent` cannot overflow.
            let share = u128::from(remaining) * u128::from(percent) / 100;
            // With percent <= 100 the share never exceeds `remaining`, so the
            // narrowing cast is lossless.
            *limit = Some(share as u64);
            *used = 0;
        }
    }
}
/// Filter passed to `CallableInvoker::discover`.
///
/// Every filter that is `Some` must match for a descriptor to be returned.
/// Serialized with `camelCase` field names; `None` filters are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DiscoveryQuery {
/// Keep only descriptors whose `callable_type` equals this value.
#[serde(skip_serializing_if = "Option::is_none")]
pub callable_type: Option<CallableType>,
/// Keep only descriptors carrying at least one of these tags.
/// `Some` of an empty list therefore matches nothing.
#[serde(skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
/// Glob pattern (`*`, `?`) the descriptor name is tested against via `matches_glob`.
#[serde(skip_serializing_if = "Option::is_none")]
pub name_pattern: Option<String>,
/// Keep only descriptors whose cost tier is at or below this tier.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_cost_tier: Option<CostTier>,
/// Maximum number of descriptors returned; defaults to 10 (`default_limit`).
#[serde(default = "default_limit")]
pub limit: usize,
}
/// Default for `DiscoveryQuery::limit` (serde and `Default`).
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 10;
    DEFAULT_LIMIT
}
impl Default for DiscoveryQuery {
    /// A query with no filters and the limit from `default_limit()`.
    fn default() -> Self {
        let limit = default_limit();
        Self {
            limit,
            max_cost_tier: None,
            name_pattern: None,
            tags: None,
            callable_type: None,
        }
    }
}
/// Result of `CallableInvoker::discover`.
///
/// Serialized with `camelCase` field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DiscoveryResult {
/// Matching descriptors, truncated to the query's `limit`.
pub callables: Vec<CallableDescriptor>,
/// Number of matching descriptors before truncation.
pub total_count: usize,
/// The query that produced this result.
pub query: DiscoveryQuery,
}
/// Runs callables from a registry and answers discovery queries over
/// separately registered descriptors.
pub struct CallableInvoker {
// Source of runnable callables (`get`, `invoke`, `list`).
registry: CallableRegistry,
// Descriptors keyed by name; filled only through `register_descriptor`,
// so it is not automatically in sync with `registry`.
descriptors: HashMap<String, CallableDescriptor>,
}
impl CallableInvoker {
    /// Creates an invoker over `registry` with no descriptors registered.
    pub fn new(registry: CallableRegistry) -> Self {
        Self {
            registry,
            descriptors: HashMap::new(),
        }
    }

    /// Registers (or replaces) the descriptor stored under `descriptor.name`.
    ///
    /// Descriptors are what `discover` searches; registering one does not add
    /// anything to the callable registry.
    pub fn register_descriptor(&mut self, descriptor: CallableDescriptor) {
        self.descriptors.insert(descriptor.name.clone(), descriptor);
    }

    /// Looks up a callable by name in the registry.
    pub fn get(&self, name: &str) -> Option<DynCallable> {
        self.registry.get(name)
    }

    /// Runs the callable named in `invocation` and reports the outcome.
    ///
    /// * Unknown name: a failure result with a "not found" message.
    /// * `SpawnMode::Child` with `background == true`: nothing is run here; a
    ///   result carrying a freshly created `ExecutionId` is returned.
    /// * Otherwise (`Inline`, or a foreground child): the callable is awaited
    ///   and its output or error is converted into the result.
    ///
    /// `duration_ms` is the wall-clock time since this method was entered.
    /// `context`, `priority` and `timeout_ms` are not consulted.
    pub async fn invoke(&self, invocation: CallableInvocation) -> CallableInvocationResult {
        /// Milliseconds elapsed since `start`.
        fn elapsed_ms(start: Instant) -> u64 {
            start.elapsed().as_millis() as u64
        }

        let start = Instant::now();
        let callable = match self.registry.get(&invocation.callable_name) {
            Some(found) => found,
            None => {
                return CallableInvocationResult::failure(
                    format!("Callable '{}' not found", invocation.callable_name),
                    elapsed_ms(start),
                );
            }
        };

        // Only a background child skips execution; every other mode runs the
        // callable to completion right here.
        if matches!(
            invocation.spawn_mode,
            SpawnMode::Child {
                background: true,
                ..
            }
        ) {
            let execution_id = ExecutionId::new();
            return CallableInvocationResult::child_spawned(execution_id, elapsed_ms(start));
        }

        match callable.run(&invocation.input).await {
            Ok(output) => CallableInvocationResult::success(output, elapsed_ms(start)),
            Err(e) => CallableInvocationResult::failure(e.to_string(), elapsed_ms(start)),
        }
    }

    /// Returns the registered descriptors matching every filter set in `query`.
    ///
    /// Matches are sorted by name before the `limit` is applied, so the same
    /// query always yields the same descriptors in the same order (the backing
    /// `HashMap` has no stable iteration order). `total_count` is the number of
    /// matches before truncation.
    pub fn discover(&self, query: DiscoveryQuery) -> DiscoveryResult {
        let mut hits: Vec<&CallableDescriptor> = self
            .descriptors
            .values()
            .filter(|desc| {
                if let Some(ref wanted) = query.callable_type {
                    if &desc.callable_type != wanted {
                        return false;
                    }
                }
                if let Some(ref tags) = query.tags {
                    // Any-of semantics: one shared tag is enough. An empty
                    // tag list therefore matches nothing.
                    if !tags.iter().any(|t| desc.tags.contains(t)) {
                        return false;
                    }
                }
                if let Some(ref pattern) = query.name_pattern {
                    if !matches_glob(&desc.name, pattern) {
                        return false;
                    }
                }
                if let Some(ref max_tier) = query.max_cost_tier {
                    if !is_cost_tier_within(&desc.cost_tier, max_tier) {
                        return false;
                    }
                }
                true
            })
            .collect();

        // Names are unique (they are the map keys), so an unstable sort is
        // fully deterministic.
        hits.sort_unstable_by(|a, b| a.name.cmp(&b.name));

        let total_count = hits.len();
        // Clone only the descriptors that survive the limit.
        let callables: Vec<CallableDescriptor> =
            hits.into_iter().take(query.limit).cloned().collect();

        DiscoveryResult {
            callables,
            total_count,
            query,
        }
    }

    /// Lists the names known to the callable registry.
    pub fn list(&self) -> Vec<String> {
        self.registry.list()
    }
}
/// Tests `name` against a simple glob `pattern`.
///
/// * `*` matches any run of characters, including an empty one.
/// * `?` matches exactly one character.
/// * Every other character matches itself.
///
/// The whole name must be consumed by the pattern. Matching is done per
/// `char`, iteratively with single-star backtracking, so no recursion or
/// intermediate strings are needed.
fn matches_glob(name: &str, pattern: &str) -> bool {
    let name: Vec<char> = name.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut n = 0; // next unmatched position in `name`
    let mut p = 0; // next unmatched position in `pattern`
    // Most recent `*`: (pattern index just after it, name index where the
    // text following the star is currently assumed to start).
    let mut last_star: Option<(usize, usize)> = None;

    while n < name.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing; widen on mismatch.
                last_star = Some((p + 1, n));
                p += 1;
            }
            Some(c) if c == '?' || c == name[n] => {
                n += 1;
                p += 1;
            }
            _ => match last_star {
                // Mismatch (or pattern exhausted): let the last star absorb
                // one more character and retry from just after it.
                Some((after_star, resume)) => {
                    p = after_star;
                    n = resume + 1;
                    last_star = Some((after_star, resume + 1));
                }
                None => return false,
            },
        }
    }

    // The name is consumed; only stars may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}
/// Returns `true` when `tier` is no more expensive than `max_tier`.
///
/// Tiers are ranked `Free < Low < Medium < High < Premium`.
fn is_cost_tier_within(tier: &CostTier, max_tier: &CostTier) -> bool {
    /// Numeric rank of a tier; a higher rank means a more expensive tier.
    fn rank(t: &CostTier) -> u8 {
        match t {
            CostTier::Free => 0,
            CostTier::Low => 1,
            CostTier::Medium => 2,
            CostTier::High => 3,
            CostTier::Premium => 4,
        }
    }

    rank(tier) <= rank(max_tier)
}
// Unit tests for descriptors, resource accounting, invocation and discovery.
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::Arc;
// Minimal callable that ignores its input and returns a fixed output.
struct TestCallable {
name: String,
output: String,
}
#[async_trait]
impl Callable for TestCallable {
fn name(&self) -> &str {
&self.name
}
async fn run(&self, _input: &str) -> anyhow::Result<String> {
Ok(self.output.clone())
}
}
// The builder methods override the defaults set by `from_callable`.
#[test]
fn test_callable_descriptor() {
let callable = TestCallable {
name: "test".to_string(),
output: "output".to_string(),
};
let desc = CallableDescriptor::from_callable(&callable, CallableType::Agent)
.with_tags(vec!["research".to_string(), "analysis".to_string()])
.with_cost_tier(CostTier::High)
.with_spawn_capability(true);
assert_eq!(desc.name, "test");
assert_eq!(desc.callable_type, CallableType::Agent);
assert_eq!(desc.tags.len(), 2);
assert_eq!(desc.cost_tier, CostTier::High);
assert!(desc.can_spawn_children);
}
// Token and child limits are enforced against the recorded usage.
#[test]
fn test_resource_allocation() {
let budget = ResourceBudget {
max_tokens: Some(1000),
max_time_ms: Some(5000),
max_children: Some(3),
..Default::default()
};
let mut allocation =
ResourceAllocation::new(ResourceAllocationStrategy::EqualSplit, budget);
assert!(allocation.can_spawn_child());
assert!(allocation.has_token_budget(500));
allocation.record_tokens(400);
allocation.record_child_spawn();
// 400 used: 500 more fits within 1000, 700 more does not.
assert!(allocation.has_token_budget(500));
assert!(!allocation.has_token_budget(700));
assert!(allocation.can_spawn_child());
allocation.record_child_spawn();
allocation.record_child_spawn();
// Three children recorded against a limit of three.
assert!(!allocation.can_spawn_child());
}
// EqualSplit hands a child half of the unused token budget, one level deeper.
#[test]
fn test_child_allocation() {
let budget = ResourceBudget {
max_tokens: Some(1000),
..Default::default()
};
let allocation = ResourceAllocation::new(ResourceAllocationStrategy::EqualSplit, budget);
let child = allocation.child_allocation();
assert_eq!(child.current_depth, 1);
assert_eq!(child.budget.max_tokens, Some(500)); }
// An inline invocation of a registered callable returns its output.
#[tokio::test]
async fn test_callable_invoker() {
let registry = CallableRegistry::new();
let callable = Arc::new(TestCallable {
name: "test".to_string(),
output: "test output".to_string(),
});
registry.register("test".to_string(), callable);
let invoker = CallableInvoker::new(registry);
let invocation = CallableInvocation {
callable_name: "test".to_string(),
input: "input".to_string(),
context: None,
spawn_mode: SpawnMode::Inline,
priority: 50,
timeout_ms: None,
};
let result = invoker.invoke(invocation).await;
assert!(result.success);
assert_eq!(result.output, Some("test output".to_string()));
}
// Discovery filters by tag and by maximum cost tier; no filter returns everything.
#[test]
fn test_discovery() {
let registry = CallableRegistry::new();
let mut invoker = CallableInvoker::new(registry);
invoker.register_descriptor(
CallableDescriptor::from_callable(
&TestCallable {
name: "research-agent".to_string(),
output: "".to_string(),
},
CallableType::Agent,
)
.with_tags(vec!["research".to_string()])
.with_cost_tier(CostTier::Medium),
);
invoker.register_descriptor(
CallableDescriptor::from_callable(
&TestCallable {
name: "analysis-agent".to_string(),
output: "".to_string(),
},
CallableType::Agent,
)
.with_tags(vec!["analysis".to_string()])
.with_cost_tier(CostTier::High),
);
let result = invoker.discover(DiscoveryQuery {
tags: Some(vec!["research".to_string()]),
..Default::default()
});
assert_eq!(result.callables.len(), 1);
assert_eq!(result.callables[0].name, "research-agent");
// Only the Medium-tier descriptor is within a Medium ceiling.
let result = invoker.discover(DiscoveryQuery {
max_cost_tier: Some(CostTier::Medium),
..Default::default()
});
assert_eq!(result.callables.len(), 1);
let result = invoker.discover(DiscoveryQuery::default());
assert_eq!(result.total_count, 2);
}
// Tier comparison is inclusive of the maximum tier.
#[test]
fn test_cost_tier_comparison() {
assert!(is_cost_tier_within(&CostTier::Free, &CostTier::High));
assert!(is_cost_tier_within(&CostTier::Medium, &CostTier::Medium));
assert!(!is_cost_tier_within(&CostTier::High, &CostTier::Low));
}
}