//! Turn-level optimization state (`ai_agents_runtime/optimization/turn.rs`):
//! [`TurnOptimizationContext`] and [`TransitionCandidate`].

1use std::collections::HashMap;
2
3use ai_agents_state::Transition;
4use serde_json::Value;
5use uuid::Uuid;
6
7use super::branch::RuntimeOptimizationKind;
8
9/// Tracks lifecycle and staged writes for one optimized runtime turn.
10#[allow(dead_code)]
11#[derive(Debug, Clone)]
12pub struct TurnOptimizationContext {
13    /// Stable ID shared by branch and maintenance work for the current turn.
14    pub turn_id: Uuid,
15    /// Processed user input after the input pipeline.
16    pub processed_input: String,
17    /// Context values produced by the input process pipeline.
18    pub input_context: HashMap<String, Value>,
19    /// Context writes that must only commit if the selected path wins.
20    pub staged_context_writes: HashMap<String, Value>,
21    /// Whether actor memory and relationship loading ran for this turn.
22    pub pre_turn_lifecycle_completed: bool,
23    /// Whether the user message has been written to memory.
24    pub user_message_committed: bool,
25    /// Whether post-turn maintenance has been scheduled or completed.
26    pub post_turn_lifecycle_completed: bool,
27    /// Redispatch depth used to avoid repeated lifecycle work.
28    pub redispatch_depth: u32,
29    /// Number of speculative LLM calls used in this turn.
30    pub speculative_llm_calls_used: u32,
31    /// Maximum speculative LLM calls allowed in this turn.
32    pub max_speculative_llm_calls: u32,
33}
34
35#[allow(dead_code)]
36impl TurnOptimizationContext {
37    /// Creates a new turn context with no staged writes.
38    pub fn new(
39        processed_input: impl Into<String>,
40        input_context: HashMap<String, Value>,
41        max_speculative_llm_calls: u32,
42    ) -> Self {
43        Self {
44            turn_id: Uuid::new_v4(),
45            processed_input: processed_input.into(),
46            input_context,
47            staged_context_writes: HashMap::new(),
48            pre_turn_lifecycle_completed: false,
49            user_message_committed: false,
50            post_turn_lifecycle_completed: false,
51            redispatch_depth: 0,
52            speculative_llm_calls_used: 0,
53            max_speculative_llm_calls,
54        }
55    }
56
57    /// Returns true when another speculative LLM call can be scheduled.
58    pub fn reserve_speculative_llm_call(&mut self) -> bool {
59        if self.speculative_llm_calls_used >= self.max_speculative_llm_calls {
60            return false;
61        }
62        self.speculative_llm_calls_used += 1;
63        true
64    }
65
66    pub fn reserve_speculative_llm_call_for(&mut self, _kind: RuntimeOptimizationKind) -> bool {
67        self.reserve_speculative_llm_call()
68    }
69
70    pub fn release_or_mark_failed_reservation(&mut self, _kind: RuntimeOptimizationKind) {}
71
72    pub fn can_schedule_branch(&self, active_tasks: usize, max_parallel_tasks: usize) -> bool {
73        active_tasks < max_parallel_tasks
74    }
75
76    pub fn stage_context_write(&mut self, key: impl Into<String>, value: Value) {
77        self.staged_context_writes.insert(key.into(), value);
78    }
79
80    pub fn take_staged_context_writes(&mut self) -> HashMap<String, Value> {
81        std::mem::take(&mut self.staged_context_writes)
82    }
83
84    /// Returns true when the turn can schedule the requested number of branch calls.
85    pub fn reserve_speculative_llm_calls(&mut self, count: u32) -> bool {
86        if self.speculative_llm_calls_used + count > self.max_speculative_llm_calls {
87            return false;
88        }
89        self.speculative_llm_calls_used += count;
90        true
91    }
92
93    /// Marks the root user message as committed.
94    pub fn mark_user_message_committed(&mut self) {
95        self.user_message_committed = true;
96    }
97
98    /// Marks post-turn lifecycle work as completed.
99    pub fn mark_post_turn_lifecycle_completed(&mut self) {
100        self.post_turn_lifecycle_completed = true;
101    }
102
103    /// Enters a redispatch scope.
104    pub fn enter_redispatch(&mut self) {
105        self.redispatch_depth += 1;
106    }
107
108    /// Leaves a redispatch scope.
109    pub fn exit_redispatch(&mut self) {
110        self.redispatch_depth = self.redispatch_depth.saturating_sub(1);
111    }
112}
113
114/// Selected transition with enough data to commit side effects later.
115#[derive(Debug, Clone)]
116pub struct TransitionCandidate {
117    /// State path where the transition was selected.
118    pub from_state: String,
119    /// Selected transition definition.
120    pub transition: Transition,
121    /// Reason recorded in state history and hooks.
122    pub reason: String,
123}
124
125impl TransitionCandidate {
126    /// Creates a candidate from the current state and transition.
127    pub fn new(
128        from_state: impl Into<String>,
129        transition: Transition,
130        reason: impl Into<String>,
131    ) -> Self {
132        Self {
133            from_state: from_state.into(),
134            transition,
135            reason: reason.into(),
136        }
137    }
138
139    /// Returns the transition target string from the YAML definition.
140    pub fn target(&self) -> &str {
141        &self.transition.to
142    }
143}