1use crate::budget::BudgetLimits;
6use crate::session::ToolVisibilityWitness;
7use crate::types::Message;
8use serde::{Deserialize, Serialize};
9use std::collections::{BTreeMap, BTreeSet};
10use uuid::Uuid;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct OperationId(pub Uuid);
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub enum WaitPolicy {
23 Barrier,
25 Detached,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
35pub struct AsyncOpRef {
36 pub operation_id: OperationId,
37 pub wait_policy: WaitPolicy,
38}
39
40impl WaitPolicy {
41 pub fn barrier() -> Self {
43 Self::Barrier
44 }
45
46 pub fn detached() -> Self {
48 Self::Detached
49 }
50}
51
52impl AsyncOpRef {
53 pub fn barrier(operation_id: OperationId) -> Self {
55 Self {
56 operation_id,
57 wait_policy: WaitPolicy::barrier(),
58 }
59 }
60
61 pub fn detached(operation_id: OperationId) -> Self {
63 Self {
64 operation_id,
65 wait_policy: WaitPolicy::detached(),
66 }
67 }
68}
69
70#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
83#[serde(tag = "effect_type", rename_all = "snake_case")]
84pub enum SessionEffect {
85 GrantManageMob { mob_id: String },
87 RequestDeferredTools {
89 names: BTreeSet<String>,
90 witnesses: BTreeMap<String, ToolVisibilityWitness>,
91 },
92}
93
94#[derive(Debug, Clone)]
95pub struct ToolDispatchOutcome {
96 pub result: crate::types::ToolResult,
98 pub async_ops: Vec<AsyncOpRef>,
103 pub session_effects: Vec<SessionEffect>,
109}
110
111impl ToolDispatchOutcome {
112 pub fn sync_result(result: crate::types::ToolResult) -> Self {
114 Self {
115 result,
116 async_ops: Vec::new(),
117 session_effects: Vec::new(),
118 }
119 }
120}
121
122impl From<crate::types::ToolResult> for ToolDispatchOutcome {
123 fn from(result: crate::types::ToolResult) -> Self {
124 Self::sync_result(result)
125 }
126}
127
128impl OperationId {
129 pub fn new() -> Self {
131 Self(Uuid::now_v7())
132 }
133}
134
135impl Default for OperationId {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl std::fmt::Display for OperationId {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(f, "{}", self.0)
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
149#[serde(rename_all = "snake_case")]
150pub enum WorkKind {
151 ToolCall,
153 ShellCommand,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
159#[serde(rename_all = "snake_case")]
160pub enum ResultShape {
161 Single,
163 Stream,
165 Batch,
167}
168
169#[derive(Debug, Clone, Default, Serialize, Deserialize)]
171#[serde(tag = "type", content = "value", rename_all = "snake_case")]
172pub enum ContextStrategy {
173 #[default]
175 FullHistory,
176 LastTurns(u32),
178 Summary { max_tokens: u32 },
180 Custom { messages: Vec<Message> },
182}
183
184#[derive(Debug, Clone, Default, Serialize, Deserialize)]
186#[serde(tag = "type", content = "value", rename_all = "snake_case")]
187pub enum ForkBudgetPolicy {
188 #[default]
190 Equal,
191 Proportional,
193 Fixed(u64),
195 Remaining,
197}
198
199#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
201#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
202#[serde(tag = "type", content = "value", rename_all = "snake_case")]
203pub enum ToolAccessPolicy {
204 #[default]
206 Inherit,
207 AllowList(Vec<String>),
209 DenyList(Vec<String>),
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize, Default)]
215pub struct OperationPolicy {
216 pub timeout_ms: Option<u64>,
218 pub cancel_on_parent_cancel: bool,
220 pub checkpoint_results: bool,
222}
223
224#[derive(Debug, Clone, Serialize, Deserialize)]
226pub struct OperationSpec {
227 pub id: OperationId,
228 pub kind: WorkKind,
229 pub result_shape: ResultShape,
230 pub policy: OperationPolicy,
231 pub budget_reservation: BudgetLimits,
232 pub depth: u32,
233 pub depends_on: Vec<OperationId>,
234 pub context: Option<ContextStrategy>,
235 pub tool_access: Option<ToolAccessPolicy>,
236}
237
238#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
240pub struct OperationResult {
241 pub id: OperationId,
242 pub content: String,
243 pub is_error: bool,
244 pub duration_ms: u64,
245 pub tokens_used: u64,
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
250#[serde(tag = "type", rename_all = "snake_case")]
251pub enum OpEvent {
252 Started { id: OperationId, kind: WorkKind },
254
255 Progress {
257 id: OperationId,
258 message: String,
259 percent: Option<f32>,
260 },
261
262 Completed {
264 id: OperationId,
265 result: OperationResult,
266 },
267
268 Failed { id: OperationId, error: String },
270
271 Cancelled { id: OperationId },
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct ConcurrencyLimits {
278 pub max_depth: u32,
280 pub max_concurrent_ops: usize,
282 pub max_concurrent_agents: usize,
284 pub max_children_per_agent: usize,
286}
287
288impl Default for ConcurrencyLimits {
289 fn default() -> Self {
290 Self {
291 max_depth: 3,
292 max_concurrent_ops: 32,
293 max_concurrent_agents: 8,
294 max_children_per_agent: 5,
295 }
296 }
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize, Default)]
301pub struct SpawnSpec {
302 pub prompt: String,
304 pub context: ContextStrategy,
306 pub tool_access: ToolAccessPolicy,
308 pub budget: BudgetLimits,
310 pub allow_spawn: bool,
312 pub system_prompt: Option<String>,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
318pub struct ForkBranch {
319 pub name: String,
321 pub prompt: String,
323 pub tool_access: Option<ToolAccessPolicy>,
325}
326
#[cfg(test)]
// Tests unwrap on purpose: a failed (de)serialization should fail the test.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// `barrier()` constructors yield `WaitPolicy::Barrier`.
    #[test]
    fn barrier_constructor_produces_barrier_policy() {
        assert_eq!(WaitPolicy::barrier(), WaitPolicy::Barrier);
        let AsyncOpRef { wait_policy, .. } = AsyncOpRef::barrier(OperationId::new());
        assert_eq!(wait_policy, WaitPolicy::Barrier);
    }

    /// `detached()` constructors yield `WaitPolicy::Detached`.
    #[test]
    fn detached_constructor_produces_detached_policy() {
        assert_eq!(WaitPolicy::detached(), WaitPolicy::Detached);
        let AsyncOpRef { wait_policy, .. } = AsyncOpRef::detached(OperationId::new());
        assert_eq!(wait_policy, WaitPolicy::Detached);
    }

    /// An `OperationId` survives a JSON string round trip.
    #[test]
    fn test_operation_id_encoding() {
        let original = OperationId::new();
        let encoded = serde_json::to_string(&original).unwrap();

        let decoded = serde_json::from_str::<OperationId>(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// `WorkKind` variants serialize as bare snake_case strings.
    #[test]
    fn test_work_kind_serialization() {
        let expectations = [
            (WorkKind::ToolCall, "tool_call"),
            (WorkKind::ShellCommand, "shell_command"),
        ];

        for (kind, wire_name) in expectations {
            assert_eq!(serde_json::to_value(kind).unwrap(), wire_name);
        }
    }

    /// `ContextStrategy` uses the adjacent `type` / `value` representation.
    #[test]
    fn test_context_strategy_serialization() {
        let full_json = serde_json::to_value(ContextStrategy::FullHistory).unwrap();
        assert_eq!(full_json["type"], "full_history");

        let last_json = serde_json::to_value(ContextStrategy::LastTurns(5)).unwrap();
        assert_eq!(last_json["type"], "last_turns");
        assert_eq!(last_json["value"], 5);

        let summary_json =
            serde_json::to_value(ContextStrategy::Summary { max_tokens: 1000 }).unwrap();
        assert_eq!(summary_json["type"], "summary");
        assert_eq!(summary_json["value"]["max_tokens"], 1000);

        // Round trip the summary variant back into the enum.
        let decoded = serde_json::from_value::<ContextStrategy>(summary_json).unwrap();
        if let ContextStrategy::Summary { max_tokens } = decoded {
            assert_eq!(max_tokens, 1000);
        } else {
            unreachable!("Wrong variant");
        }
    }

    /// `ForkBudgetPolicy` uses the adjacent `type` / `value` representation.
    #[test]
    fn test_fork_budget_policy_serialization() {
        let payload_free = [
            (ForkBudgetPolicy::Equal, "equal"),
            (ForkBudgetPolicy::Proportional, "proportional"),
            (ForkBudgetPolicy::Remaining, "remaining"),
        ];

        for (policy, expected_type) in payload_free {
            let encoded = serde_json::to_value(&policy).unwrap();
            assert_eq!(encoded["type"], expected_type);
        }

        let fixed_json = serde_json::to_value(ForkBudgetPolicy::Fixed(5000)).unwrap();
        assert_eq!(fixed_json["type"], "fixed");
        assert_eq!(fixed_json["value"], 5000);

        // Round trip the fixed variant back into the enum.
        let decoded = serde_json::from_value::<ForkBudgetPolicy>(fixed_json).unwrap();
        if let ForkBudgetPolicy::Fixed(tokens) = decoded {
            assert_eq!(tokens, 5000);
        } else {
            unreachable!("Wrong variant");
        }
    }

    /// `ToolAccessPolicy` uses the adjacent `type` / `value` representation.
    #[test]
    fn test_tool_access_policy_serialization() {
        let inherit_json = serde_json::to_value(ToolAccessPolicy::Inherit).unwrap();
        assert_eq!(inherit_json["type"], "inherit");

        let allowed_tools = vec!["read_file".to_string(), "write_file".to_string()];
        let allow_json = serde_json::to_value(ToolAccessPolicy::AllowList(allowed_tools)).unwrap();
        assert_eq!(allow_json["type"], "allow_list");
        assert!(allow_json["value"].is_array());

        let denied_tools = vec!["dangerous_tool".to_string()];
        let deny_json = serde_json::to_value(ToolAccessPolicy::DenyList(denied_tools)).unwrap();
        assert_eq!(deny_json["type"], "deny_list");
        assert!(deny_json["value"].is_array());

        // Round trip the deny list back into the enum.
        let decoded = serde_json::from_value::<ToolAccessPolicy>(deny_json).unwrap();
        if let ToolAccessPolicy::DenyList(tools) = decoded {
            assert_eq!(tools.len(), 1);
            assert_eq!(tools[0], "dangerous_tool");
        } else {
            unreachable!("Wrong variant");
        }
    }

    /// Every `OpEvent` variant carries a `type` tag and deserializes again.
    #[test]
    fn test_op_event_serialization() {
        let started = OpEvent::Started {
            id: OperationId::new(),
            kind: WorkKind::ToolCall,
        };
        let progress = OpEvent::Progress {
            id: OperationId::new(),
            message: "50% complete".to_string(),
            percent: Some(0.5),
        };
        let completed = OpEvent::Completed {
            id: OperationId::new(),
            result: OperationResult {
                id: OperationId::new(),
                content: "result".to_string(),
                is_error: false,
                duration_ms: 100,
                tokens_used: 50,
            },
        };
        let failed = OpEvent::Failed {
            id: OperationId::new(),
            error: "timeout".to_string(),
        };
        let cancelled = OpEvent::Cancelled {
            id: OperationId::new(),
        };

        for event in [started, progress, completed, failed, cancelled] {
            let encoded = serde_json::to_value(&event).unwrap();
            assert!(encoded.get("type").is_some());

            // Only checks that decoding succeeds; the value is not compared.
            let _decoded: OpEvent = serde_json::from_value(encoded).unwrap();
        }
    }

    /// The hand-written `Default` produces the documented limits.
    #[test]
    fn test_concurrency_limits_default() {
        let ConcurrencyLimits {
            max_depth,
            max_concurrent_ops,
            max_concurrent_agents,
            max_children_per_agent,
        } = ConcurrencyLimits::default();

        assert_eq!(max_depth, 3);
        assert_eq!(max_concurrent_ops, 32);
        assert_eq!(max_concurrent_agents, 8);
        assert_eq!(max_children_per_agent, 5);
    }

    /// `SessionEffect::GrantManageMob` survives a JSON value round trip.
    #[test]
    fn session_effect_grant_manage_mob_serde_round_trip() {
        let original = SessionEffect::GrantManageMob {
            mob_id: "test-mob".into(),
        };
        let encoded = serde_json::to_value(&original).unwrap();
        let decoded = serde_json::from_value::<SessionEffect>(encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// An outcome built by hand keeps the session effects it was given.
    #[test]
    fn tool_dispatch_outcome_with_session_effects() {
        let grant = SessionEffect::GrantManageMob {
            mob_id: "mob-1".into(),
        };
        let outcome = ToolDispatchOutcome {
            result: crate::types::ToolResult::new("t1".into(), "ok".into(), false),
            async_ops: Vec::new(),
            session_effects: vec![grant],
        };
        assert_eq!(outcome.session_effects.len(), 1);
    }

    /// `sync_result` leaves the session effects empty.
    #[test]
    fn tool_dispatch_outcome_sync_result_has_empty_effects() {
        let tool_result = crate::types::ToolResult::new("t1".into(), "ok".into(), false);
        let outcome = ToolDispatchOutcome::sync_result(tool_result);
        assert!(outcome.session_effects.is_empty());
    }
}