use crate::budget::BudgetLimits;
use crate::error::ToolError;
use crate::session::DeferredToolLoadAuthority;
use crate::types::{Message, ToolNameSet};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Identifier of an asynchronous operation.
///
/// Thin newtype over a [`Uuid`]; [`OperationId::new`] mints the value via
/// `crate::time_compat::new_uuid_v7`. The derived `Ord` compares the inner
/// UUID, so for v7 IDs ordering presumably tracks creation time — confirm
/// before relying on it. As a serde newtype struct it serializes as the bare
/// inner UUID value.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct OperationId(pub Uuid);
/// Wait semantics attached to an async operation reference (see [`AsyncOpRef`]).
///
/// Serialized in snake_case as `"barrier"` / `"detached"`.
/// NOTE(review): presumably `Barrier` means the dispatcher waits for the
/// operation before continuing and `Detached` means it does not — confirm
/// against the runtime that consumes these.
///
/// Variant order is load-bearing: the derived `Ord` sorts `Barrier` before `Detached`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitPolicy {
    Barrier,
    Detached,
}
/// Reference to an asynchronous operation together with its [`WaitPolicy`].
///
/// Instances are carried in [`ToolDispatchOutcome::async_ops`].
/// Field order is load-bearing: the derived `Ord` compares `operation_id`
/// first, then `wait_policy`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AsyncOpRef {
    /// The operation being referenced.
    pub operation_id: OperationId,
    /// How the reference should be waited on.
    pub wait_policy: WaitPolicy,
}
impl WaitPolicy {
pub fn barrier() -> Self {
Self::Barrier
}
pub fn detached() -> Self {
Self::Detached
}
}
impl AsyncOpRef {
    /// Wraps `operation_id` in a reference using [`WaitPolicy::Barrier`].
    pub fn barrier(operation_id: OperationId) -> Self {
        let wait_policy = WaitPolicy::barrier();
        Self {
            operation_id,
            wait_policy,
        }
    }

    /// Wraps `operation_id` in a reference using [`WaitPolicy::Detached`].
    pub fn detached(operation_id: OperationId) -> Self {
        let wait_policy = WaitPolicy::detached();
        Self {
            operation_id,
            wait_policy,
        }
    }
}
/// Effect a tool dispatch hands back to its session
/// (carried in [`ToolDispatchOutcome::session_effects`]).
///
/// Internally tagged: each variant serializes as an object whose
/// `"effect_type"` field holds the snake_case variant name, e.g.
/// `{"effect_type":"grant_manage_mob","mob_id":"..."}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "effect_type", rename_all = "snake_case")]
pub enum SessionEffect {
    /// NOTE(review): presumably grants the session management rights over
    /// the mob identified by `mob_id` — confirm with the consumer.
    GrantManageMob { mob_id: String },
    /// Request to load deferred tools under the listed authorities
    /// (meaning defined by `DeferredToolLoadAuthority` — confirm there).
    RequestDeferredTools {
        authorities: Vec<DeferredToolLoadAuthority>,
    },
    /// Assistant blocks to append; presumably to the session transcript — confirm.
    AppendAssistantBlocks {
        blocks: Vec<crate::types::AssistantBlock>,
    },
}
/// Coarse classification of a [`ToolError`] that terminated a dispatch.
///
/// Produced by the `From<&ToolError>` impl. There is one kind per
/// `ToolError` variant, except that `ExecutionFailed` and
/// `ExecutionFailedWithData` both map to [`Self::ExecutionFailed`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTerminalErrorKind {
    NotFound,
    Unavailable,
    InvalidArguments,
    ExecutionFailed,
    Timeout,
    AccessDenied,
    Other,
    CallbackPending,
}
/// Classifies a [`ToolError`] into its [`ToolDispatchTerminalErrorKind`].
///
/// The match deliberately has no wildcard arm: a new `ToolError` variant
/// must fail to compile here until someone decides how to classify it.
impl From<&ToolError> for ToolDispatchTerminalErrorKind {
    fn from(error: &ToolError) -> Self {
        match error {
            ToolError::CallbackPending { .. } => Self::CallbackPending,
            ToolError::Timeout { .. } => Self::Timeout,
            // Both execution-failure shapes collapse into a single kind.
            ToolError::ExecutionFailed { .. } | ToolError::ExecutionFailedWithData { .. } => {
                Self::ExecutionFailed
            }
            ToolError::InvalidArguments { .. } => Self::InvalidArguments,
            ToolError::AccessDenied { .. } => Self::AccessDenied,
            ToolError::Unavailable { .. } => Self::Unavailable,
            ToolError::NotFound { .. } => Self::NotFound,
            ToolError::Other(_) => Self::Other,
        }
    }
}
/// Runtime-side reason a tool dispatch ended.
///
/// Attached to a [`ToolDispatchOutcome`] by `terminal_tool_outcome_for_error`.
/// Outcomes built from plain tool results carry no cause, even when the
/// tool's own result is flagged as an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTerminalCause {
    /// The runtime raised a [`ToolError`] of the given kind.
    RuntimeToolError { kind: ToolDispatchTerminalErrorKind },
}
impl ToolDispatchTerminalCause {
#[must_use]
pub fn runtime_tool_error(error: &ToolError) -> Self {
Self::RuntimeToolError {
kind: ToolDispatchTerminalErrorKind::from(error),
}
}
#[must_use]
pub fn is_runtime_tool_timeout(self) -> bool {
matches!(
self,
Self::RuntimeToolError {
kind: ToolDispatchTerminalErrorKind::Timeout
}
)
}
}
/// Result of dispatching one tool call.
///
/// Besides the tool result, it carries any async operations the call
/// started and any effects the session should apply.
#[derive(Debug, Clone)]
pub struct ToolDispatchOutcome {
    /// The tool result for this dispatch.
    pub result: crate::types::ToolResult,
    /// Async operations referenced by this dispatch.
    pub async_ops: Vec<AsyncOpRef>,
    /// Effects for the owning session to apply.
    pub session_effects: Vec<SessionEffect>,
    // Private so that code outside this module cannot forge a runtime cause.
    // In this module it is set only by `terminal_tool_outcome_for_error` and
    // cleared by `clear_terminal_cause`.
    terminal_cause: Option<ToolDispatchTerminalCause>,
}
/// Timeout applied to a tool dispatch.
///
/// `Default` and `Finite` both carry a concrete duration and behave the same
/// in [`timeout`](Self::timeout). The distinction presumably records whether
/// the limit came from a default or from explicit configuration — confirm
/// with callers. `Disabled` reports no timeout at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTimeoutPolicy {
    Default { timeout: std::time::Duration },
    Disabled,
    Finite { timeout: std::time::Duration },
}

impl ToolDispatchTimeoutPolicy {
    /// Returns the configured timeout, or `None` for [`Self::Disabled`].
    #[must_use]
    pub fn timeout(self) -> Option<std::time::Duration> {
        match self {
            Self::Disabled => None,
            Self::Finite { timeout } | Self::Default { timeout } => Some(timeout),
        }
    }

    /// Returns the timeout in whole milliseconds, or `None` for [`Self::Disabled`].
    ///
    /// Sub-millisecond remainders are truncated. A millisecond count that does
    /// not fit in `u64` saturates to `u64::MAX` instead of wrapping.
    #[must_use]
    pub fn timeout_ms(self) -> Option<u64> {
        let limit = self.timeout()?;
        let whole_millis: u128 = limit.as_millis();
        Some(whole_millis.try_into().unwrap_or(u64::MAX))
    }
}
impl ToolDispatchOutcome {
    /// Assembles an outcome from its parts with no runtime terminal cause.
    pub fn new(
        result: crate::types::ToolResult,
        async_ops: Vec<AsyncOpRef>,
        session_effects: Vec<SessionEffect>,
    ) -> Self {
        // Only this module can attach a terminal cause; `new` never does.
        let terminal_cause = None;
        Self {
            result,
            async_ops,
            session_effects,
            terminal_cause,
        }
    }

    /// Wraps a finished result with no async ops and no session effects.
    pub fn sync_result(result: crate::types::ToolResult) -> Self {
        Self::new(result, vec![], vec![])
    }

    /// Returns the runtime terminal cause, if one was recorded.
    #[must_use]
    pub fn terminal_cause(&self) -> Option<ToolDispatchTerminalCause> {
        self.terminal_cause
    }

    /// Returns `true` when the recorded terminal cause is a runtime tool timeout.
    ///
    /// A tool that merely reports a timeout in its own error result does not count.
    #[must_use]
    pub fn is_runtime_tool_timeout(&self) -> bool {
        matches!(self.terminal_cause(), Some(cause) if cause.is_runtime_tool_timeout())
    }

    /// Forgets any recorded runtime terminal cause, leaving the result untouched.
    pub(crate) fn clear_terminal_cause(&mut self) {
        self.terminal_cause = None;
    }
}
/// Treats a bare tool result as a synchronous outcome.
///
/// The outcome has no async ops, no session effects and no terminal cause.
impl From<crate::types::ToolResult> for ToolDispatchOutcome {
    fn from(result: crate::types::ToolResult) -> Self {
        Self::new(result, Vec::new(), Vec::new())
    }
}
pub fn terminal_tool_outcome_for_error(
tool_use_id: impl Into<String>,
error: ToolError,
) -> ToolDispatchOutcome {
let terminal_cause = ToolDispatchTerminalCause::runtime_tool_error(&error);
let payload = error.to_error_payload();
let serialized = serde_json::to_string(&payload)
.unwrap_or_else(|_| "{\"error\":\"tool_error\",\"message\":\"tool error\"}".to_string());
let mut outcome = ToolDispatchOutcome::sync_result(crate::types::ToolResult::new(
tool_use_id.into(),
serialized,
true,
));
outcome.terminal_cause = Some(terminal_cause);
outcome
}
impl OperationId {
pub fn new() -> Self {
Self(crate::time_compat::new_uuid_v7())
}
}
impl Default for OperationId {
fn default() -> Self {
Self::new()
}
}
/// Formats the identifier exactly as its inner UUID's `Display` output.
impl std::fmt::Display for OperationId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let uuid = &self.0;
        write!(f, "{uuid}")
    }
}
/// Category of work an operation performs.
///
/// Serialized in snake_case: `"tool_call"`, `"shell_command"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WorkKind {
    ToolCall,
    ShellCommand,
}
/// Shape of the results an operation produces.
///
/// Serialized in snake_case: `"single"`, `"stream"`, `"batch"`.
/// NOTE(review): how consumers treat each shape is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResultShape {
    Single,
    Stream,
    Batch,
}
/// Which conversation context a spawned or child operation receives.
///
/// Adjacently tagged: `{"type":"full_history"}`, `{"type":"last_turns","value":5}`,
/// `{"type":"summary","value":{"max_tokens":1000}}`. The default is `FullHistory`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum ContextStrategy {
    /// Entire history (the default).
    #[default]
    FullHistory,
    /// NOTE(review): presumably only the most recent N turns — confirm.
    LastTurns(u32),
    /// NOTE(review): presumably a summary bounded by `max_tokens` — confirm.
    Summary { max_tokens: u32 },
    /// Caller-supplied messages.
    Custom { messages: Vec<Message> },
}
/// How budget is divided across fork branches.
///
/// Adjacently tagged: `{"type":"equal"}`, `{"type":"fixed","value":5000}`.
/// The default is `Equal`. Derives `PartialEq`/`Eq` like the sibling
/// [`ToolAccessPolicy`], so policies can be compared directly instead of
/// destructured with `match`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum ForkBudgetPolicy {
    /// Default policy. NOTE(review): presumably an even split across branches — confirm.
    #[default]
    Equal,
    /// NOTE(review): presumably a split proportional to some per-branch weight — confirm.
    Proportional,
    /// Fixed amount per branch. The unit is not stated here (the unit tests
    /// call it `tokens`) — confirm.
    Fixed(u64),
    /// NOTE(review): presumably gives branches whatever budget remains — confirm.
    Remaining,
}
/// Which tools an operation may use.
///
/// Adjacently tagged: `{"type":"inherit"}`, and `{"type":"allow_list","value":[...]}`
/// or `{"type":"deny_list","value":[...]}`, where the tool-name set serializes
/// as an array. A JSON schema is derived when the `schema` feature is enabled.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum ToolAccessPolicy {
    /// Default. NOTE(review): presumably reuses the parent's tool access — confirm.
    #[default]
    Inherit,
    /// Only the named tools.
    AllowList(ToolNameSet),
    /// Everything except the named tools.
    DenyList(ToolNameSet),
}
/// Execution policy for one operation (see [`OperationSpec::policy`]).
///
/// `Default` gives no timeout and both flags `false`. Derives
/// `PartialEq`/`Eq` for parity with [`OperationResult`], so policies can be
/// compared directly.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct OperationPolicy {
    /// Timeout in milliseconds. NOTE(review): `None` presumably means unlimited — confirm.
    pub timeout_ms: Option<u64>,
    /// NOTE(review): presumably cancels this operation when its parent is cancelled — confirm.
    pub cancel_on_parent_cancel: bool,
    /// NOTE(review): presumably persists results for checkpointing — confirm.
    pub checkpoint_results: bool,
}
/// Full description of an operation to run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationSpec {
    /// Identifier of this operation.
    pub id: OperationId,
    /// What kind of work it performs.
    pub kind: WorkKind,
    /// Shape of the results it produces.
    pub result_shape: ResultShape,
    /// Timeout / cancellation / checkpoint policy.
    pub policy: OperationPolicy,
    /// Budget set aside for this operation.
    pub budget_reservation: BudgetLimits,
    /// NOTE(review): presumably the nesting depth, compared against
    /// `ConcurrencyLimits::max_depth` — confirm.
    pub depth: u32,
    /// Operations this one depends on.
    pub depends_on: Vec<OperationId>,
    /// Optional context strategy; `None` semantics are decided by the caller.
    pub context: Option<ContextStrategy>,
    /// Optional tool access; `None` semantics are decided by the caller.
    pub tool_access: Option<ToolAccessPolicy>,
}
/// Result of a finished operation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OperationResult {
    /// Identifier of the operation that produced this result.
    pub id: OperationId,
    /// Result content.
    pub content: String,
    /// Whether `content` describes an error.
    pub is_error: bool,
    /// Elapsed time in milliseconds.
    pub duration_ms: u64,
    /// Tokens consumed by the operation.
    pub tokens_used: u64,
}
/// Lifecycle event emitted for an operation.
///
/// Internally tagged: each event is an object whose `"type"` field holds the
/// snake_case variant name (`"started"`, `"progress"`, `"completed"`, `"failed"`,
/// `"cancelled"`) alongside the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpEvent {
    /// The operation started.
    Started { id: OperationId, kind: WorkKind },
    /// Progress update. NOTE(review): `percent` appears to be a fraction in
    /// [0, 1] (the unit test pairs 0.5 with "50% complete") despite its name — confirm.
    Progress {
        id: OperationId,
        message: String,
        percent: Option<f32>,
    },
    /// The operation finished. NOTE(review): `result.id` duplicates `id`;
    /// presumably they always match — confirm.
    Completed {
        id: OperationId,
        result: OperationResult,
    },
    /// The operation failed with the given message.
    Failed { id: OperationId, error: String },
    /// The operation was cancelled.
    Cancelled { id: OperationId },
}
/// Caps on operation nesting and parallelism.
///
/// Enforcement happens outside this file. Derives `PartialEq`/`Eq` (every
/// field is an integer), so a configured value can be compared with
/// [`ConcurrencyLimits::default`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConcurrencyLimits {
    /// Maximum nesting depth. NOTE(review): presumably checked against
    /// `OperationSpec::depth` — confirm.
    pub max_depth: u32,
    /// Maximum operations in flight at once.
    pub max_concurrent_ops: usize,
    /// Maximum agents running at once.
    pub max_concurrent_agents: usize,
    /// Maximum children a single agent may have.
    pub max_children_per_agent: usize,
}

impl Default for ConcurrencyLimits {
    /// Depth 3, 32 concurrent ops, 8 concurrent agents, 5 children per agent.
    fn default() -> Self {
        Self {
            max_depth: 3,
            max_concurrent_ops: 32,
            max_concurrent_agents: 8,
            max_children_per_agent: 5,
        }
    }
}
/// Request to spawn a sub-agent.
///
/// The derived `Default` gives an empty prompt, `ContextStrategy::FullHistory`,
/// `ToolAccessPolicy::Inherit`, the default `BudgetLimits`, `allow_spawn: false`
/// and no system prompt.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SpawnSpec {
    /// Prompt for the spawned agent.
    pub prompt: String,
    /// Context handed to the spawned agent.
    pub context: ContextStrategy,
    /// Tools the spawned agent may use.
    pub tool_access: ToolAccessPolicy,
    /// Budget for the spawned agent.
    pub budget: BudgetLimits,
    /// NOTE(review): presumably whether the spawned agent may spawn further agents — confirm.
    pub allow_spawn: bool,
    /// Optional system prompt override.
    pub system_prompt: Option<String>,
}
/// One named branch of a fork.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForkBranch {
    /// Branch name.
    pub name: String,
    /// Prompt for this branch.
    pub prompt: String,
    /// Optional tool access. NOTE(review): `None` presumably inherits — confirm.
    pub tool_access: Option<ToolAccessPolicy>,
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    // Unit tests for the constructors, serde shapes and dispatch-outcome helpers above.
    use super::*;

    #[test]
    fn barrier_constructor_produces_barrier_policy() {
        assert_eq!(WaitPolicy::barrier(), WaitPolicy::Barrier);
        let op_ref = AsyncOpRef::barrier(OperationId::new());
        assert_eq!(op_ref.wait_policy, WaitPolicy::Barrier);
    }

    #[test]
    fn detached_constructor_produces_detached_policy() {
        assert_eq!(WaitPolicy::detached(), WaitPolicy::Detached);
        let op_ref = AsyncOpRef::detached(OperationId::new());
        assert_eq!(op_ref.wait_policy, WaitPolicy::Detached);
    }

    #[test]
    fn test_operation_id_encoding() {
        let id = OperationId::new();
        let json = serde_json::to_string(&id).unwrap();
        let parsed: OperationId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_work_kind_serialization() {
        assert_eq!(
            serde_json::to_value(WorkKind::ToolCall).unwrap(),
            "tool_call"
        );
        assert_eq!(
            serde_json::to_value(WorkKind::ShellCommand).unwrap(),
            "shell_command"
        );
    }

    #[test]
    fn test_context_strategy_serialization() {
        let full = ContextStrategy::FullHistory;
        let json = serde_json::to_value(&full).unwrap();
        assert_eq!(json["type"], "full_history");
        let last = ContextStrategy::LastTurns(5);
        let json = serde_json::to_value(&last).unwrap();
        assert_eq!(json["type"], "last_turns");
        assert_eq!(json["value"], 5);
        let summary = ContextStrategy::Summary { max_tokens: 1000 };
        let json = serde_json::to_value(&summary).unwrap();
        assert_eq!(json["type"], "summary");
        assert_eq!(json["value"]["max_tokens"], 1000);
        let parsed: ContextStrategy = serde_json::from_value(json).unwrap();
        match parsed {
            ContextStrategy::Summary { max_tokens } => assert_eq!(max_tokens, 1000),
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_fork_budget_policy_serialization() {
        let policies = vec![
            (ForkBudgetPolicy::Equal, "equal"),
            (ForkBudgetPolicy::Proportional, "proportional"),
            (ForkBudgetPolicy::Remaining, "remaining"),
        ];
        for (policy, expected_type) in policies {
            let json = serde_json::to_value(&policy).unwrap();
            assert_eq!(json["type"], expected_type);
        }
        let fixed = ForkBudgetPolicy::Fixed(5000);
        let json = serde_json::to_value(&fixed).unwrap();
        assert_eq!(json["type"], "fixed");
        assert_eq!(json["value"], 5000);
        let parsed: ForkBudgetPolicy = serde_json::from_value(json).unwrap();
        match parsed {
            ForkBudgetPolicy::Fixed(tokens) => assert_eq!(tokens, 5000),
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_tool_access_policy_serialization() {
        let inherit = ToolAccessPolicy::Inherit;
        let json = serde_json::to_value(&inherit).unwrap();
        assert_eq!(json["type"], "inherit");
        let allow = ToolAccessPolicy::AllowList(["read_file", "write_file"].into_iter().collect());
        let json = serde_json::to_value(&allow).unwrap();
        assert_eq!(json["type"], "allow_list");
        assert!(json["value"].is_array());
        let deny = ToolAccessPolicy::DenyList(["dangerous_tool"].into_iter().collect());
        let json = serde_json::to_value(&deny).unwrap();
        assert_eq!(json["type"], "deny_list");
        assert!(json["value"].is_array());
        let parsed: ToolAccessPolicy = serde_json::from_value(json).unwrap();
        match parsed {
            ToolAccessPolicy::DenyList(tools) => {
                assert_eq!(tools.len(), 1);
                assert!(tools.contains("dangerous_tool"));
            }
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_op_event_serialization() {
        let events = vec![
            OpEvent::Started {
                id: OperationId::new(),
                kind: WorkKind::ToolCall,
            },
            OpEvent::Progress {
                id: OperationId::new(),
                message: "50% complete".to_string(),
                percent: Some(0.5),
            },
            OpEvent::Completed {
                id: OperationId::new(),
                result: OperationResult {
                    id: OperationId::new(),
                    content: "result".to_string(),
                    is_error: false,
                    duration_ms: 100,
                    tokens_used: 50,
                },
            },
            OpEvent::Failed {
                id: OperationId::new(),
                error: "timeout".to_string(),
            },
            OpEvent::Cancelled {
                id: OperationId::new(),
            },
        ];
        for event in events {
            let json = serde_json::to_value(&event).unwrap();
            assert!(json.get("type").is_some());
            let _: OpEvent = serde_json::from_value(json).unwrap();
        }
    }

    #[test]
    fn test_concurrency_limits_default() {
        let limits = ConcurrencyLimits::default();
        assert_eq!(limits.max_depth, 3);
        assert_eq!(limits.max_concurrent_ops, 32);
        assert_eq!(limits.max_concurrent_agents, 8);
        assert_eq!(limits.max_children_per_agent, 5);
    }

    #[test]
    fn session_effect_grant_manage_mob_serde_round_trip() {
        let effect = SessionEffect::GrantManageMob {
            mob_id: "test-mob".into(),
        };
        let json = serde_json::to_value(&effect).unwrap();
        let parsed: SessionEffect = serde_json::from_value(json).unwrap();
        assert_eq!(effect, parsed);
    }

    #[test]
    fn tool_dispatch_outcome_with_session_effects() {
        let result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome = ToolDispatchOutcome::new(
            result,
            vec![],
            vec![SessionEffect::GrantManageMob {
                mob_id: "mob-1".into(),
            }],
        );
        assert_eq!(outcome.session_effects.len(), 1);
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn tool_dispatch_outcome_sync_result_has_empty_effects() {
        let result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome = ToolDispatchOutcome::sync_result(result);
        assert!(outcome.session_effects.is_empty());
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn terminal_tool_outcome_carries_runtime_timeout_cause() {
        let outcome = terminal_tool_outcome_for_error("t1", ToolError::timeout("slow_tool", 50));
        assert!(outcome.result.is_error);
        assert!(outcome.is_runtime_tool_timeout());
        assert_eq!(
            outcome.terminal_cause(),
            Some(ToolDispatchTerminalCause::RuntimeToolError {
                kind: ToolDispatchTerminalErrorKind::Timeout,
            })
        );
    }

    #[test]
    fn tool_authored_error_result_has_no_runtime_terminal_cause() {
        let result =
            crate::types::ToolResult::new("t1".into(), "{\"error\":\"timeout\"}".into(), true);
        let outcome = ToolDispatchOutcome::sync_result(result);
        assert!(outcome.result.is_error);
        assert!(!outcome.is_runtime_tool_timeout());
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn timeout_policy_exposes_duration_and_saturates_millis() {
        use std::time::Duration;
        let default = ToolDispatchTimeoutPolicy::Default {
            timeout: Duration::from_millis(1500),
        };
        assert_eq!(default.timeout(), Some(Duration::from_millis(1500)));
        assert_eq!(default.timeout_ms(), Some(1500));
        let finite = ToolDispatchTimeoutPolicy::Finite {
            timeout: Duration::from_micros(2500),
        };
        // Sub-millisecond remainders are truncated, not rounded.
        assert_eq!(finite.timeout_ms(), Some(2));
        assert_eq!(ToolDispatchTimeoutPolicy::Disabled.timeout(), None);
        assert_eq!(ToolDispatchTimeoutPolicy::Disabled.timeout_ms(), None);
        // Millisecond counts beyond u64 saturate instead of wrapping.
        let huge = ToolDispatchTimeoutPolicy::Finite {
            timeout: Duration::MAX,
        };
        assert_eq!(huge.timeout_ms(), Some(u64::MAX));
    }

    #[test]
    fn clear_terminal_cause_removes_runtime_classification() {
        let mut outcome =
            terminal_tool_outcome_for_error("t1", ToolError::timeout("slow_tool", 50));
        assert!(outcome.is_runtime_tool_timeout());
        assert!(outcome.async_ops.is_empty());
        assert!(outcome.session_effects.is_empty());
        outcome.clear_terminal_cause();
        assert_eq!(outcome.terminal_cause(), None);
        assert!(!outcome.is_runtime_tool_timeout());
        // Clearing the cause leaves the error result itself untouched.
        assert!(outcome.result.is_error);
    }

    #[test]
    fn tool_error_timeout_maps_to_timeout_kind() {
        let error = ToolError::timeout("slow_tool", 50);
        assert_eq!(
            ToolDispatchTerminalErrorKind::from(&error),
            ToolDispatchTerminalErrorKind::Timeout
        );
    }
}