1use crate::budget::BudgetLimits;
6use crate::error::ToolError;
7use crate::session::DeferredToolLoadAuthority;
8use crate::types::{Message, ToolNameSet};
9use serde::{Deserialize, Serialize};
10use uuid::Uuid;
11
/// Identifier for an asynchronous operation.
///
/// Newtype around a [`Uuid`]. [`OperationId::new`] (and `Default`) obtain the
/// value from `crate::time_compat::new_uuid_v7`; `Display` prints the inner UUID.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct OperationId(pub Uuid);
15
/// Wait behavior attached to an asynchronous operation reference ([`AsyncOpRef`]).
///
/// Serialized in `snake_case` (`"barrier"` / `"detached"`). Variant order is
/// significant: it defines the derived `PartialOrd`/`Ord` ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WaitPolicy {
    // NOTE(review): presumably the caller waits for this operation before
    // proceeding — confirm against the consumer of `AsyncOpRef`.
    Barrier,
    // NOTE(review): presumably the operation is not waited on — confirm.
    Detached,
}
28
/// Reference to an asynchronous operation together with its [`WaitPolicy`].
///
/// Collected in `ToolDispatchOutcome::async_ops`. Field order matters for the
/// derived `PartialOrd`/`Ord` (compared by `operation_id` first).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct AsyncOpRef {
    /// The referenced operation.
    pub operation_id: OperationId,
    /// How the referenced operation should be waited on.
    pub wait_policy: WaitPolicy,
}
39
40impl WaitPolicy {
41 pub fn barrier() -> Self {
43 Self::Barrier
44 }
45
46 pub fn detached() -> Self {
48 Self::Detached
49 }
50}
51
52impl AsyncOpRef {
53 pub fn barrier(operation_id: OperationId) -> Self {
55 Self {
56 operation_id,
57 wait_policy: WaitPolicy::barrier(),
58 }
59 }
60
61 pub fn detached(operation_id: OperationId) -> Self {
63 Self {
64 operation_id,
65 wait_policy: WaitPolicy::detached(),
66 }
67 }
68}
69
/// Session-level effect requested by a tool dispatch.
///
/// Carried in `ToolDispatchOutcome::session_effects`. Serialized as an
/// internally tagged object: the variant name (in `snake_case`) is stored under
/// the `"effect_type"` key, next to the variant's fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "effect_type", rename_all = "snake_case")]
pub enum SessionEffect {
    // NOTE(review): presumably grants the session permission to manage the
    // mob identified by `mob_id` — confirm against the effect's consumer.
    GrantManageMob { mob_id: String },
    /// Requests deferred tools, one entry per load authority.
    RequestDeferredTools {
        authorities: Vec<DeferredToolLoadAuthority>,
    },
    /// Appends the given assistant blocks (presumably to the session
    /// transcript — TODO confirm).
    AppendAssistantBlocks {
        blocks: Vec<crate::types::AssistantBlock>,
    },
}
97
/// Classification of a runtime [`ToolError`] that terminated a dispatch.
///
/// Produced by the `From<&ToolError>` impl. Each variant mirrors one
/// `ToolError` variant, except `ExecutionFailed`, which covers both
/// `ToolError::ExecutionFailed` and `ToolError::ExecutionFailedWithData`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTerminalErrorKind {
    NotFound,
    Unavailable,
    InvalidArguments,
    ExecutionFailed,
    Timeout,
    AccessDenied,
    Other,
    CallbackPending,
}
109
110impl From<&ToolError> for ToolDispatchTerminalErrorKind {
111 fn from(error: &ToolError) -> Self {
112 match error {
113 ToolError::NotFound { .. } => Self::NotFound,
114 ToolError::Unavailable { .. } => Self::Unavailable,
115 ToolError::InvalidArguments { .. } => Self::InvalidArguments,
116 ToolError::ExecutionFailed { .. } | ToolError::ExecutionFailedWithData { .. } => {
117 Self::ExecutionFailed
118 }
119 ToolError::Timeout { .. } => Self::Timeout,
120 ToolError::AccessDenied { .. } => Self::AccessDenied,
121 ToolError::Other(_) => Self::Other,
122 ToolError::CallbackPending { .. } => Self::CallbackPending,
123 }
124 }
125}
126
/// Why a dispatch ended with a runtime-generated terminal result.
///
/// Attached by `terminal_tool_outcome_for_error`. A result that a tool itself
/// marks as an error (e.g. via `ToolDispatchOutcome::sync_result`) carries no
/// cause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTerminalCause {
    /// The runtime produced a [`ToolError`] of the given kind.
    RuntimeToolError { kind: ToolDispatchTerminalErrorKind },
}
131
132impl ToolDispatchTerminalCause {
133 #[must_use]
134 pub fn runtime_tool_error(error: &ToolError) -> Self {
135 Self::RuntimeToolError {
136 kind: ToolDispatchTerminalErrorKind::from(error),
137 }
138 }
139
140 #[must_use]
141 pub fn is_runtime_tool_timeout(self) -> bool {
142 matches!(
143 self,
144 Self::RuntimeToolError {
145 kind: ToolDispatchTerminalErrorKind::Timeout
146 }
147 )
148 }
149}
150
/// Everything a single tool dispatch produced.
#[derive(Debug, Clone)]
pub struct ToolDispatchOutcome {
    /// The tool result for this dispatch.
    pub result: crate::types::ToolResult,
    /// Asynchronous operations started by the dispatch.
    pub async_ops: Vec<AsyncOpRef>,
    /// Session-level effects requested by the dispatch.
    pub session_effects: Vec<SessionEffect>,
    // Private: within this module it is set to `Some` only by
    // `terminal_tool_outcome_for_error` and reset by `clear_terminal_cause`;
    // read it through `terminal_cause()` / `is_runtime_tool_timeout()`.
    terminal_cause: Option<ToolDispatchTerminalCause>,
}
172
/// Timeout setting for a tool dispatch.
///
/// `timeout()` returns the duration for `Default` and `Finite`, and `None` for
/// `Disabled`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolDispatchTimeoutPolicy {
    // NOTE(review): presumably a runtime-chosen default timeout — confirm.
    Default { timeout: std::time::Duration },
    /// No timeout is applied.
    Disabled,
    // NOTE(review): presumably an explicitly configured timeout — confirm.
    Finite { timeout: std::time::Duration },
}
184
185impl ToolDispatchTimeoutPolicy {
186 #[must_use]
187 pub fn timeout(self) -> Option<std::time::Duration> {
188 match self {
189 Self::Default { timeout } | Self::Finite { timeout } => Some(timeout),
190 Self::Disabled => None,
191 }
192 }
193
194 #[must_use]
195 pub fn timeout_ms(self) -> Option<u64> {
196 self.timeout()
197 .map(|timeout| u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX))
198 }
199}
200
201impl ToolDispatchOutcome {
202 pub fn new(
204 result: crate::types::ToolResult,
205 async_ops: Vec<AsyncOpRef>,
206 session_effects: Vec<SessionEffect>,
207 ) -> Self {
208 Self {
209 result,
210 async_ops,
211 session_effects,
212 terminal_cause: None,
213 }
214 }
215
216 pub fn sync_result(result: crate::types::ToolResult) -> Self {
218 Self::new(result, Vec::new(), Vec::new())
219 }
220
221 #[must_use]
222 pub fn terminal_cause(&self) -> Option<ToolDispatchTerminalCause> {
223 self.terminal_cause
224 }
225
226 #[must_use]
227 pub fn is_runtime_tool_timeout(&self) -> bool {
228 self.terminal_cause
229 .is_some_and(ToolDispatchTerminalCause::is_runtime_tool_timeout)
230 }
231
232 pub(crate) fn clear_terminal_cause(&mut self) {
233 self.terminal_cause = None;
234 }
235}
236
237impl From<crate::types::ToolResult> for ToolDispatchOutcome {
238 fn from(result: crate::types::ToolResult) -> Self {
239 Self::sync_result(result)
240 }
241}
242
243pub fn terminal_tool_outcome_for_error(
246 tool_use_id: impl Into<String>,
247 error: ToolError,
248) -> ToolDispatchOutcome {
249 let terminal_cause = ToolDispatchTerminalCause::runtime_tool_error(&error);
250 let payload = error.to_error_payload();
251 let serialized = serde_json::to_string(&payload)
252 .unwrap_or_else(|_| "{\"error\":\"tool_error\",\"message\":\"tool error\"}".to_string());
253 let mut outcome = ToolDispatchOutcome::sync_result(crate::types::ToolResult::new(
254 tool_use_id.into(),
255 serialized,
256 true,
257 ));
258 outcome.terminal_cause = Some(terminal_cause);
259 outcome
260}
261
262impl OperationId {
263 pub fn new() -> Self {
265 Self(crate::time_compat::new_uuid_v7())
266 }
267}
268
269impl Default for OperationId {
270 fn default() -> Self {
271 Self::new()
272 }
273}
274
275impl std::fmt::Display for OperationId {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 write!(f, "{}", self.0)
278 }
279}
280
/// Category of work an operation performs.
///
/// Serialized as a `snake_case` string (`"tool_call"`, `"shell_command"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WorkKind {
    /// A tool invocation.
    ToolCall,
    /// A shell command.
    ShellCommand,
}
290
/// Shape of the result an operation produces.
///
/// Serialized as a `snake_case` string (`"single"`, `"stream"`, `"batch"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResultShape {
    // NOTE(review): presumably a single final value — confirm.
    Single,
    // NOTE(review): presumably incremental output — confirm.
    Stream,
    // NOTE(review): presumably a collection of results — confirm.
    Batch,
}
302
/// Strategy for selecting the conversation context given to an operation.
///
/// Adjacently tagged: the variant name (`snake_case`) goes under `"type"` and
/// any payload under `"value"`, e.g. `{"type":"last_turns","value":5}`.
/// Defaults to `FullHistory`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum ContextStrategy {
    #[default]
    /// Use the entire history.
    FullHistory,
    /// Use only the most recent turns (count given by the payload).
    LastTurns(u32),
    // NOTE(review): presumably a summarized history capped at `max_tokens` —
    // confirm against the consumer.
    Summary { max_tokens: u32 },
    /// Use exactly the supplied messages.
    Custom { messages: Vec<Message> },
}
317
318#[derive(Debug, Clone, Default, Serialize, Deserialize)]
320#[serde(tag = "type", content = "value", rename_all = "snake_case")]
321pub enum ForkBudgetPolicy {
322 #[default]
324 Equal,
325 Proportional,
327 Fixed(u64),
329 Remaining,
331}
332
/// Which tools an operation or spawned agent may use.
///
/// Adjacently tagged: the variant name (`snake_case`) goes under `"type"` and
/// the tool-name set under `"value"`. Defaults to `Inherit`. A JSON Schema is
/// derived when the `schema` feature is enabled.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value", rename_all = "snake_case")]
pub enum ToolAccessPolicy {
    #[default]
    // NOTE(review): presumably reuses the parent's tool access — confirm.
    Inherit,
    /// Only the named tools are permitted.
    AllowList(ToolNameSet),
    /// The named tools are forbidden.
    DenyList(ToolNameSet),
}
346
/// Execution policy for an operation.
///
/// The derived `Default` has no timeout and both flags `false`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct OperationPolicy {
    /// Timeout in milliseconds; `None` means no timeout is specified.
    pub timeout_ms: Option<u64>,
    /// Whether the operation should be cancelled when its parent is cancelled.
    pub cancel_on_parent_cancel: bool,
    /// Whether results should be checkpointed (mechanism defined elsewhere).
    pub checkpoint_results: bool,
}
357
/// Full description of an operation to run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationSpec {
    /// Identifier of this operation.
    pub id: OperationId,
    /// Category of work.
    pub kind: WorkKind,
    /// Expected shape of the result.
    pub result_shape: ResultShape,
    /// Timeout / cancellation / checkpoint policy.
    pub policy: OperationPolicy,
    /// Budget reserved for this operation.
    pub budget_reservation: BudgetLimits,
    // NOTE(review): presumably the nesting depth (compare
    // `ConcurrencyLimits::max_depth`) — confirm.
    pub depth: u32,
    /// Operations this one depends on.
    pub depends_on: Vec<OperationId>,
    /// Context strategy; `None` leaves the choice to the executor.
    pub context: Option<ContextStrategy>,
    /// Tool access policy; `None` leaves the choice to the executor.
    pub tool_access: Option<ToolAccessPolicy>,
}
371
/// Final result of an operation, as carried by `OpEvent::Completed`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OperationResult {
    /// Operation that produced this result.
    pub id: OperationId,
    /// Result content.
    pub content: String,
    /// Whether `content` describes an error.
    pub is_error: bool,
    /// Elapsed time in milliseconds.
    pub duration_ms: u64,
    /// Tokens consumed by the operation.
    pub tokens_used: u64,
}
381
/// Lifecycle event emitted for an operation.
///
/// Internally tagged: the `snake_case` variant name is stored under `"type"`
/// next to the variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OpEvent {
    /// The operation began executing.
    Started { id: OperationId, kind: WorkKind },

    /// Progress update with a human-readable message.
    Progress {
        id: OperationId,
        message: String,
        // NOTE(review): scale not enforced here; tests use 0.5 for "50%",
        // suggesting a 0.0..=1.0 fraction — confirm.
        percent: Option<f32>,
    },

    /// The operation finished and produced `result`.
    Completed {
        id: OperationId,
        result: OperationResult,
    },

    /// The operation failed with the given error text.
    Failed { id: OperationId, error: String },

    /// The operation was cancelled.
    Cancelled { id: OperationId },
}
408
/// Concurrency and nesting limits (default values are set by the `Default` impl).
///
/// Enforcement happens elsewhere; this type only carries the numbers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConcurrencyLimits {
    /// Maximum nesting depth.
    pub max_depth: u32,
    /// Maximum number of operations running at once.
    pub max_concurrent_ops: usize,
    /// Maximum number of agents running at once.
    pub max_concurrent_agents: usize,
    /// Maximum number of children a single agent may have.
    pub max_children_per_agent: usize,
}
421
422impl Default for ConcurrencyLimits {
423 fn default() -> Self {
424 Self {
425 max_depth: 3,
426 max_concurrent_ops: 32,
427 max_concurrent_agents: 8,
428 max_children_per_agent: 5,
429 }
430 }
431}
432
/// Specification for spawning a sub-agent.
///
/// The derived `Default` uses `ContextStrategy::FullHistory`,
/// `ToolAccessPolicy::Inherit`, `BudgetLimits::default()`, no spawning, and no
/// system prompt.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SpawnSpec {
    /// Prompt given to the spawned agent.
    pub prompt: String,
    /// How much conversation context the agent receives.
    pub context: ContextStrategy,
    /// Which tools the agent may use.
    pub tool_access: ToolAccessPolicy,
    /// Budget for the agent.
    pub budget: BudgetLimits,
    /// Whether the spawned agent may itself spawn further agents.
    pub allow_spawn: bool,
    /// Optional system prompt override.
    pub system_prompt: Option<String>,
}
449
/// One branch of a fork.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForkBranch {
    /// Branch name.
    pub name: String,
    /// Prompt for this branch.
    pub prompt: String,
    /// Optional per-branch tool access policy; `None` leaves it unspecified.
    pub tool_access: Option<ToolAccessPolicy>,
}
460
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    //! Unit tests for the operation / dispatch types: constructors, serde wire
    //! formats (including full round-trips), and terminal-cause bookkeeping.
    use super::*;

    #[test]
    fn barrier_constructor_produces_barrier_policy() {
        assert_eq!(WaitPolicy::barrier(), WaitPolicy::Barrier);
        let op_ref = AsyncOpRef::barrier(OperationId::new());
        assert_eq!(op_ref.wait_policy, WaitPolicy::Barrier);
    }

    #[test]
    fn detached_constructor_produces_detached_policy() {
        assert_eq!(WaitPolicy::detached(), WaitPolicy::Detached);
        let op_ref = AsyncOpRef::detached(OperationId::new());
        assert_eq!(op_ref.wait_policy, WaitPolicy::Detached);
    }

    #[test]
    fn test_operation_id_encoding() {
        let id = OperationId::new();
        let json = serde_json::to_string(&id).unwrap();

        let parsed: OperationId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_operation_id_display_matches_uuid() {
        let id = OperationId::new();
        assert_eq!(id.to_string(), id.0.to_string());
    }

    #[test]
    fn test_work_kind_serialization() {
        assert_eq!(
            serde_json::to_value(WorkKind::ToolCall).unwrap(),
            "tool_call"
        );
        assert_eq!(
            serde_json::to_value(WorkKind::ShellCommand).unwrap(),
            "shell_command"
        );

        // The wire form must also decode back to the same variant.
        let parsed: WorkKind = serde_json::from_str("\"shell_command\"").unwrap();
        assert_eq!(parsed, WorkKind::ShellCommand);
    }

    #[test]
    fn test_result_shape_serialization() {
        for (shape, expected) in [
            (ResultShape::Single, "single"),
            (ResultShape::Stream, "stream"),
            (ResultShape::Batch, "batch"),
        ] {
            let json = serde_json::to_value(&shape).unwrap();
            assert_eq!(json, expected);
            let parsed: ResultShape = serde_json::from_value(json).unwrap();
            assert_eq!(parsed, shape);
        }
    }

    #[test]
    fn test_context_strategy_serialization() {
        let full = ContextStrategy::FullHistory;
        let json = serde_json::to_value(&full).unwrap();
        assert_eq!(json["type"], "full_history");

        let last = ContextStrategy::LastTurns(5);
        let json = serde_json::to_value(&last).unwrap();
        assert_eq!(json["type"], "last_turns");
        assert_eq!(json["value"], 5);

        let summary = ContextStrategy::Summary { max_tokens: 1000 };
        let json = serde_json::to_value(&summary).unwrap();
        assert_eq!(json["type"], "summary");
        assert_eq!(json["value"]["max_tokens"], 1000);

        let parsed: ContextStrategy = serde_json::from_value(json).unwrap();
        match parsed {
            ContextStrategy::Summary { max_tokens } => assert_eq!(max_tokens, 1000),
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_fork_budget_policy_serialization() {
        let policies = vec![
            (ForkBudgetPolicy::Equal, "equal"),
            (ForkBudgetPolicy::Proportional, "proportional"),
            (ForkBudgetPolicy::Remaining, "remaining"),
        ];

        for (policy, expected_type) in policies {
            let json = serde_json::to_value(&policy).unwrap();
            assert_eq!(json["type"], expected_type);
        }

        let fixed = ForkBudgetPolicy::Fixed(5000);
        let json = serde_json::to_value(&fixed).unwrap();
        assert_eq!(json["type"], "fixed");
        assert_eq!(json["value"], 5000);

        let parsed: ForkBudgetPolicy = serde_json::from_value(json).unwrap();
        match parsed {
            ForkBudgetPolicy::Fixed(tokens) => assert_eq!(tokens, 5000),
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_tool_access_policy_serialization() {
        let inherit = ToolAccessPolicy::Inherit;
        let json = serde_json::to_value(&inherit).unwrap();
        assert_eq!(json["type"], "inherit");

        let allow = ToolAccessPolicy::AllowList(["read_file", "write_file"].into_iter().collect());
        let json = serde_json::to_value(&allow).unwrap();
        assert_eq!(json["type"], "allow_list");
        assert!(json["value"].is_array());

        let deny = ToolAccessPolicy::DenyList(["dangerous_tool"].into_iter().collect());
        let json = serde_json::to_value(&deny).unwrap();
        assert_eq!(json["type"], "deny_list");
        assert!(json["value"].is_array());

        let parsed: ToolAccessPolicy = serde_json::from_value(json).unwrap();
        match parsed {
            ToolAccessPolicy::DenyList(tools) => {
                assert_eq!(tools.len(), 1);
                assert!(tools.contains("dangerous_tool"));
            }
            _ => unreachable!("Wrong variant"),
        }
    }

    #[test]
    fn test_op_event_serialization() {
        let events = vec![
            OpEvent::Started {
                id: OperationId::new(),
                kind: WorkKind::ToolCall,
            },
            OpEvent::Progress {
                id: OperationId::new(),
                message: "50% complete".to_string(),
                percent: Some(0.5),
            },
            OpEvent::Completed {
                id: OperationId::new(),
                result: OperationResult {
                    id: OperationId::new(),
                    content: "result".to_string(),
                    is_error: false,
                    duration_ms: 100,
                    tokens_used: 50,
                },
            },
            OpEvent::Failed {
                id: OperationId::new(),
                error: "timeout".to_string(),
            },
            OpEvent::Cancelled {
                id: OperationId::new(),
            },
        ];
        let expected_tags = ["started", "progress", "completed", "failed", "cancelled"];
        assert_eq!(events.len(), expected_tags.len());

        for (event, expected_tag) in events.into_iter().zip(expected_tags) {
            let json = serde_json::to_value(&event).unwrap();
            assert_eq!(json["type"], expected_tag);

            // Decoding and re-encoding must reproduce the same JSON, not just parse.
            let parsed: OpEvent = serde_json::from_value(json.clone()).unwrap();
            assert_eq!(serde_json::to_value(&parsed).unwrap(), json);
        }
    }

    #[test]
    fn test_concurrency_limits_default() {
        let limits = ConcurrencyLimits::default();
        assert_eq!(limits.max_depth, 3);
        assert_eq!(limits.max_concurrent_ops, 32);
        assert_eq!(limits.max_concurrent_agents, 8);
        assert_eq!(limits.max_children_per_agent, 5);
    }

    #[test]
    fn session_effect_grant_manage_mob_serde_round_trip() {
        let effect = SessionEffect::GrantManageMob {
            mob_id: "test-mob".into(),
        };
        let json = serde_json::to_value(&effect).unwrap();
        assert_eq!(json["effect_type"], "grant_manage_mob");
        assert_eq!(json["mob_id"], "test-mob");
        let parsed: SessionEffect = serde_json::from_value(json).unwrap();
        assert_eq!(effect, parsed);
    }

    #[test]
    fn tool_dispatch_outcome_with_session_effects() {
        let result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome = ToolDispatchOutcome::new(
            result,
            vec![],
            vec![SessionEffect::GrantManageMob {
                mob_id: "mob-1".into(),
            }],
        );
        assert_eq!(outcome.session_effects.len(), 1);
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn tool_dispatch_outcome_sync_result_has_empty_effects() {
        let result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome = ToolDispatchOutcome::sync_result(result);
        assert!(outcome.session_effects.is_empty());
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn tool_result_converts_into_sync_outcome() {
        let result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome: ToolDispatchOutcome = result.into();
        assert!(outcome.async_ops.is_empty());
        assert!(outcome.session_effects.is_empty());
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn terminal_tool_outcome_carries_runtime_timeout_cause() {
        let outcome = terminal_tool_outcome_for_error("t1", ToolError::timeout("slow_tool", 50));

        assert!(outcome.result.is_error);
        assert!(outcome.is_runtime_tool_timeout());
        assert_eq!(
            outcome.terminal_cause(),
            Some(ToolDispatchTerminalCause::RuntimeToolError {
                kind: ToolDispatchTerminalErrorKind::Timeout,
            })
        );
    }

    #[test]
    fn clear_terminal_cause_removes_runtime_timeout() {
        let mut outcome =
            terminal_tool_outcome_for_error("t1", ToolError::timeout("slow_tool", 50));
        assert!(outcome.is_runtime_tool_timeout());

        outcome.clear_terminal_cause();
        assert!(!outcome.is_runtime_tool_timeout());
        assert_eq!(outcome.terminal_cause(), None);
        // The tool result itself is left untouched.
        assert!(outcome.result.is_error);
    }

    #[test]
    fn tool_authored_error_result_has_no_runtime_terminal_cause() {
        let result =
            crate::types::ToolResult::new("t1".into(), "{\"error\":\"timeout\"}".into(), true);
        let outcome = ToolDispatchOutcome::sync_result(result);

        assert!(outcome.result.is_error);
        assert!(!outcome.is_runtime_tool_timeout());
        assert_eq!(outcome.terminal_cause(), None);
    }

    #[test]
    fn timeout_policy_reports_duration_and_saturates_millis() {
        use std::time::Duration;

        let default_policy = ToolDispatchTimeoutPolicy::Default {
            timeout: Duration::from_secs(2),
        };
        assert_eq!(default_policy.timeout(), Some(Duration::from_secs(2)));
        assert_eq!(default_policy.timeout_ms(), Some(2_000));

        // Sub-millisecond remainders are truncated.
        let finite = ToolDispatchTimeoutPolicy::Finite {
            timeout: Duration::from_micros(1_999),
        };
        assert_eq!(finite.timeout_ms(), Some(1));

        assert_eq!(ToolDispatchTimeoutPolicy::Disabled.timeout(), None);
        assert_eq!(ToolDispatchTimeoutPolicy::Disabled.timeout_ms(), None);

        // Millisecond counts beyond u64 saturate instead of wrapping.
        let huge = ToolDispatchTimeoutPolicy::Finite {
            timeout: Duration::MAX,
        };
        assert_eq!(huge.timeout_ms(), Some(u64::MAX));
    }
}