oxi_agent/state.rs
//! Agent state management: the serializable [`AgentState`] and its thread-safe [`SharedState`] wrapper.
2use crate::types::{StopReason, ToolResult};
3use oxi_ai::{ContentBlock, Message, TextContent};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
/// Agent execution state.
///
/// Tracks the full lifecycle of an agent conversation including messages,
/// token usage, tool results, and iteration progress.
///
/// Derives `Serialize`/`Deserialize` for session persistence and
/// cross-process state transfer (e.g. oxios supervisor serialization).
/// The three `last_*` observation fields are omitted from the serialized
/// form while `None` and default to `None` when absent on input.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AgentState {
    /// Conversation message history (user, assistant, and tool-result messages).
    pub messages: Vec<Message>,
    /// Current agent loop iteration (incremented after each assistant turn).
    pub iteration: usize,
    /// The reason the last turn stopped, if any.
    pub stop_reason: Option<StopReason>,
    /// Accumulated results from tool executions in the current conversation.
    pub tool_results: Vec<ToolResult>,
    /// Cumulative token count (input + output) across all turns.
    pub total_tokens: usize,
    /// Cumulative prompt / input tokens across all turns.
    pub input_tokens: usize,
    /// Cumulative completion / output tokens across all turns.
    pub output_tokens: usize,
    /// **Most-recent** provider-reported input-token count for a single turn.
    ///
    /// Unlike [`Self::input_tokens`] (which is cumulative across all turns),
    /// this is overwritten by `AgentState::record_provider_turn`, which is
    /// meant to be called once per `ProviderEvent::Done` with the count that
    /// turn actually sent to the model. It is the **ground-truth** signal
    /// used by compaction when available — see [`Self::current_token_source`].
    ///
    /// `None` until the first observation is recorded, and again after the
    /// state is cleared.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_input_tokens: Option<usize>,
    /// Heuristic token estimate (`bytes/4`) that was current when
    /// `last_input_tokens` was last set. Kept alongside the real count so the
    /// bytes/4 drift described in #28 can be detected and reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_estimate_at_report: Option<usize>,
    /// Divergence factor (reported / estimate) computed by
    /// `AgentState::record_provider_turn`. Values above 1.0 mean the
    /// `bytes/4` heuristic under-counted the context.
    ///
    /// NOTE(review): serde_json writes non-finite `f64` values as `null`,
    /// which deserializes back as `None`; keep this value finite if it must
    /// survive a JSON round-trip.
    ///
    /// `None` until the first divergence observation.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_estimate_divergence: Option<f64>,
}
53
/// Where the current context-size figure (used for compaction decisions)
/// comes from.
///
/// Listed from most to least trustworthy:
///
/// * [`TokenSource::Real`] — the input-token count the provider reported for
///   the most recent completed turn; treated as ground truth.
/// * [`TokenSource::Heuristic`] — the legacy estimate derived from the
///   serialized message JSON (`len / 4`).
/// * [`TokenSource::None`] — cold start: no size has been observed at all.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenSource {
    /// Input-token count reported by the provider on the latest `Done`.
    Real(usize),
    /// Byte-length / 4 estimate; only dependable for token-sparse prose.
    Heuristic(usize),
    /// Nothing measured yet.
    None,
}
69
70impl AgentState {
71 /// Create a new, default-initialized agent state.
72 pub fn new() -> Self {
73 Self::default()
74 }
75
76 /// Add a user message
77 pub fn add_user_message(&mut self, content: String) {
78 self.messages
79 .push(Message::User(oxi_ai::UserMessage::new(content)));
80 }
81
82 /// Add an assistant message
83 pub fn add_assistant_message(&mut self, content: String) {
84 let mut assistant =
85 oxi_ai::AssistantMessage::new(oxi_ai::Api::AnthropicMessages, "agent", "agent-model");
86 assistant.content = vec![ContentBlock::Text(TextContent::new(content))];
87 self.messages.push(Message::Assistant(assistant));
88 }
89
90 /// Add a tool result message to both the message history and the tool results list.
91 pub fn add_tool_result(&mut self, tool_call_id: String, content: String) {
92 let content_for_result = content.clone();
93 let tool_result_msg = oxi_ai::ToolResultMessage::new(
94 tool_call_id.clone(),
95 "tool",
96 vec![ContentBlock::Text(TextContent::new(content))],
97 );
98 self.messages
99 .push(oxi_ai::Message::ToolResult(tool_result_msg));
100 self.tool_results
101 .push(ToolResult::success(tool_call_id, content_for_result));
102 }
103
104 /// Increment the iteration counter after an assistant turn completes.
105 pub fn increment_iteration(&mut self) {
106 self.iteration += 1;
107 }
108
109 /// Record the reason the last turn stopped.
110 pub fn set_stop_reason(&mut self, reason: StopReason) {
111 self.stop_reason = Some(reason);
112 }
113
114 /// Accumulate token usage from a completed LLM call.
115 ///
116 /// `input` is **the input-token count for the just-completed turn** —
117 /// NOT a per-turn delta. The provider reports a fresh per-turn
118 /// `usage.input_tokens` on every `Done`; we accumulate it into
119 /// [`Self::input_tokens`] for lifetime accounting and **also** cache
120 /// it as the most-recent observation via [`Self::record_provider_turn`].
121 pub fn record_usage(&mut self, input: usize, output: usize) {
122 self.input_tokens += input;
123 self.output_tokens += output;
124 self.total_tokens += input + output;
125 }
126
127 /// Record the most recent provider-reported input-token count and the
128 /// heuristic estimate that was current at the time of the report, so
129 /// the loop can use the real count for compaction decisions and
130 /// surface drift (issue #28 gap 2).
131 ///
132 /// `input_tokens` is the value from `ProviderEvent::Done.message.usage.input`.
133 /// `estimate_at_report` is the bytes/4 estimate taken at the same moment
134 /// (i.e. the value the legacy path *would* have used for compaction).
135 pub fn record_provider_turn(&mut self, input_tokens: usize, estimate_at_report: usize) {
136 self.last_input_tokens = Some(input_tokens);
137 self.last_estimate_at_report = Some(estimate_at_report);
138 // Divergence = reported / estimate. A value > 1.0 means the
139 // heuristic under-counted; the documented failure in #28 saw
140 // ~3.5×. Guard against the zero-estimate case to avoid Inf/NaN.
141 self.last_estimate_divergence = if estimate_at_report > 0 {
142 Some(input_tokens as f64 / estimate_at_report as f64)
143 } else if input_tokens > 0 {
144 // An estimate of 0 against a non-zero report is the worst-case
145 // divergence: the heuristic was essentially blind to the
146 // context. Surface this as a high multiplier.
147 Some(f64::INFINITY)
148 } else {
149 Some(1.0)
150 };
151 }
152
153 /// Current best estimate of context size, tagged with its source.
154 ///
155 /// - `Real(n)` if the last completed turn reported `usage.input_tokens`.
156 /// - `Heuristic(n)` only before the first `Done` is observed (cold start).
157 /// - `None` if there are no messages yet.
158 ///
159 /// Callers (notably `maybe_compact`) should **prefer** `Real` and **only**
160 /// fall back to `Heuristic` on cold start. See issue #28.
161 pub fn current_token_source(&self) -> TokenSource {
162 if let Some(real) = self.last_input_tokens {
163 TokenSource::Real(real)
164 } else if !self.messages.is_empty() {
165 TokenSource::Heuristic(self.estimate_tokens())
166 } else {
167 TokenSource::None
168 }
169 }
170
171 /// Clear all state, resetting for a new conversation.
172 pub fn clear(&mut self) {
173 self.messages.clear();
174 self.iteration = 0;
175 self.stop_reason = None;
176 self.tool_results.clear();
177 self.total_tokens = 0;
178 self.input_tokens = 0;
179 self.output_tokens = 0;
180 self.last_input_tokens = None;
181 self.last_estimate_at_report = None;
182 self.last_estimate_divergence = None;
183 }
184
185 /// Replace the entire message history (used after context compaction).
186 pub fn replace_messages(&mut self, messages: Vec<Message>) {
187 self.messages = messages;
188 }
189
190 /// Rough token-count estimate based on the serialized message JSON length.
191 pub fn estimate_tokens(&self) -> usize {
192 let json = serde_json::to_string(&self.messages).unwrap_or_default();
193 json.len() / 4 // Rough approximation
194 }
195
196 /// Returns `true` if the agent has signaled a stop reason.
197 pub fn is_complete(&self) -> bool {
198 self.stop_reason.is_some()
199 }
200}
201
202/// Thread-safe agent state wrapper.
203#[derive(Default, Clone)]
204pub struct SharedState {
205 state: Arc<RwLock<AgentState>>,
206}
207
208impl SharedState {
209 /// Create a new SharedState with default (empty) agent state.
210 pub fn new() -> Self {
211 Self::default()
212 }
213
214 /// Obtain a snapshot of the current agent state.
215 pub fn get_state(&self) -> AgentState {
216 self.state.read().clone()
217 }
218
219 /// Mutably update the agent state under a write lock.
220 pub fn update<F>(&self, f: F)
221 where
222 F: FnOnce(&mut AgentState),
223 {
224 let mut state = self.state.write();
225 f(&mut state);
226 }
227
228 /// Reset the state for a new conversation (delegates to [`AgentState::clear`]).
229 pub fn reset(&self) {
230 let mut state = self.state.write();
231 state.clear();
232 }
233}