Skip to main content

oxi_agent/
state.rs

//! Agent state management.
2use crate::types::{StopReason, ToolResult};
3use oxi_ai::{ContentBlock, Message, TextContent};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8/// Agent execution state
9///
10/// Tracks the full lifecycle of an agent conversation including messages,
11/// token usage, tool results, and iteration progress.
12///
13/// Derives `Serialize`/`Deserialize` for session persistence and
14/// cross-process state transfer (e.g. oxios supervisor serialization).
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentState {
17    /// Conversation message history (user, assistant, and tool-result messages).
18    pub messages: Vec<Message>,
19    /// Current agent loop iteration (incremented after each assistant turn).
20    pub iteration: usize,
21    /// The reason the last turn stopped, if any.
22    pub stop_reason: Option<StopReason>,
23    /// Accumulated results from tool executions in the current conversation.
24    pub tool_results: Vec<ToolResult>,
25    /// Cumulative token count (input + output) across all turns.
26    pub total_tokens: usize,
27    /// Cumulative prompt / input tokens across all turns.
28    pub input_tokens: usize,
29    /// Cumulative completion / output tokens across all turns.
30    pub output_tokens: usize,
31    /// **Most-recent** reported input-token count from a single LLM response.
32    ///
33    /// Unlike [`Self::input_tokens`] (which is cumulative across all turns),
34    /// this is overwritten on every `ProviderEvent::Done` with the count
35    /// that turn actually sent to the model. It is the **ground-truth** signal
36    /// used by compaction when available — see [`Self::current_token_source`].
37    ///
38    /// `None` until the first `Done` event has been observed.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub last_input_tokens: Option<usize>,
41    /// Heuristic token estimate (`bytes/4`) that was current when
42    /// `last_input_tokens` was last set. Used to detect the bytes/4 drift
43    /// described in #28 and surface it as a warning.
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    pub last_estimate_at_report: Option<usize>,
46    /// Divergence factor (reported / estimate) at the time of the last
47    /// `Done`. Surfaced in logs so the operator can see how badly the
48    /// `bytes/4` heuristic is undercounting on token-dense workloads.
49    /// `None` until first divergence observation.
50    #[serde(default, skip_serializing_if = "Option::is_none")]
51    pub last_estimate_divergence: Option<f64>,
52}
53
/// Where the number used for compaction / context-size decisions came from.
///
/// * [`TokenSource::Real`] — input-token count reported by the provider on
///   the most recent `Done` event; treated as ground truth.
/// * [`TokenSource::Heuristic`] — the legacy `serialized_json.len() / 4`
///   estimate.
/// * [`TokenSource::None`] — nothing has been sent yet, so the loop has no
///   size observation at all.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TokenSource {
    /// Cold start: no size observation exists yet.
    None,
    /// A `bytes/4` estimate; only trustworthy for token-sparse prose.
    Heuristic(usize),
    /// The provider's input-token count from the latest `Done` event.
    Real(usize),
}
69
70impl AgentState {
71    /// Create a new, default-initialized agent state.
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Add a user message
77    pub fn add_user_message(&mut self, content: String) {
78        self.messages
79            .push(Message::User(oxi_ai::UserMessage::new(content)));
80    }
81
82    /// Add an assistant message
83    pub fn add_assistant_message(&mut self, content: String) {
84        let mut assistant =
85            oxi_ai::AssistantMessage::new(oxi_ai::Api::AnthropicMessages, "agent", "agent-model");
86        assistant.content = vec![ContentBlock::Text(TextContent::new(content))];
87        self.messages.push(Message::Assistant(assistant));
88    }
89
90    /// Add a tool result message to both the message history and the tool results list.
91    pub fn add_tool_result(&mut self, tool_call_id: String, content: String) {
92        let content_for_result = content.clone();
93        let tool_result_msg = oxi_ai::ToolResultMessage::new(
94            tool_call_id.clone(),
95            "tool",
96            vec![ContentBlock::Text(TextContent::new(content))],
97        );
98        self.messages
99            .push(oxi_ai::Message::ToolResult(tool_result_msg));
100        self.tool_results
101            .push(ToolResult::success(tool_call_id, content_for_result));
102    }
103
104    /// Increment the iteration counter after an assistant turn completes.
105    pub fn increment_iteration(&mut self) {
106        self.iteration += 1;
107    }
108
109    /// Record the reason the last turn stopped.
110    pub fn set_stop_reason(&mut self, reason: StopReason) {
111        self.stop_reason = Some(reason);
112    }
113
114    /// Accumulate token usage from a completed LLM call.
115    ///
116    /// `input` is **the input-token count for the just-completed turn** —
117    /// NOT a per-turn delta. The provider reports a fresh per-turn
118    /// `usage.input_tokens` on every `Done`; we accumulate it into
119    /// [`Self::input_tokens`] for lifetime accounting and **also** cache
120    /// it as the most-recent observation via [`Self::record_provider_turn`].
121    pub fn record_usage(&mut self, input: usize, output: usize) {
122        self.input_tokens += input;
123        self.output_tokens += output;
124        self.total_tokens += input + output;
125    }
126
127    /// Record the most recent provider-reported input-token count and the
128    /// heuristic estimate that was current at the time of the report, so
129    /// the loop can use the real count for compaction decisions and
130    /// surface drift (issue #28 gap 2).
131    ///
132    /// `input_tokens` is the value from `ProviderEvent::Done.message.usage.input`.
133    /// `estimate_at_report` is the bytes/4 estimate taken at the same moment
134    /// (i.e. the value the legacy path *would* have used for compaction).
135    pub fn record_provider_turn(&mut self, input_tokens: usize, estimate_at_report: usize) {
136        self.last_input_tokens = Some(input_tokens);
137        self.last_estimate_at_report = Some(estimate_at_report);
138        // Divergence = reported / estimate. A value > 1.0 means the
139        // heuristic under-counted; the documented failure in #28 saw
140        // ~3.5×. Guard against the zero-estimate case to avoid Inf/NaN.
141        self.last_estimate_divergence = if estimate_at_report > 0 {
142            Some(input_tokens as f64 / estimate_at_report as f64)
143        } else if input_tokens > 0 {
144            // An estimate of 0 against a non-zero report is the worst-case
145            // divergence: the heuristic was essentially blind to the
146            // context. Surface this as a high multiplier.
147            Some(f64::INFINITY)
148        } else {
149            Some(1.0)
150        };
151    }
152
153    /// Current best estimate of context size, tagged with its source.
154    ///
155    /// - `Real(n)` if the last completed turn reported `usage.input_tokens`.
156    /// - `Heuristic(n)` only before the first `Done` is observed (cold start).
157    /// - `None` if there are no messages yet.
158    ///
159    /// Callers (notably `maybe_compact`) should **prefer** `Real` and **only**
160    /// fall back to `Heuristic` on cold start. See issue #28.
161    pub fn current_token_source(&self) -> TokenSource {
162        if let Some(real) = self.last_input_tokens {
163            TokenSource::Real(real)
164        } else if !self.messages.is_empty() {
165            TokenSource::Heuristic(self.estimate_tokens())
166        } else {
167            TokenSource::None
168        }
169    }
170
171    /// Clear all state, resetting for a new conversation.
172    pub fn clear(&mut self) {
173        self.messages.clear();
174        self.iteration = 0;
175        self.stop_reason = None;
176        self.tool_results.clear();
177        self.total_tokens = 0;
178        self.input_tokens = 0;
179        self.output_tokens = 0;
180        self.last_input_tokens = None;
181        self.last_estimate_at_report = None;
182        self.last_estimate_divergence = None;
183    }
184
185    /// Replace the entire message history (used after context compaction).
186    pub fn replace_messages(&mut self, messages: Vec<Message>) {
187        self.messages = messages;
188    }
189
190    /// Rough token-count estimate based on the serialized message JSON length.
191    pub fn estimate_tokens(&self) -> usize {
192        let json = serde_json::to_string(&self.messages).unwrap_or_default();
193        json.len() / 4 // Rough approximation
194    }
195
196    /// Returns `true` if the agent has signaled a stop reason.
197    pub fn is_complete(&self) -> bool {
198        self.stop_reason.is_some()
199    }
200}
201
202/// Thread-safe agent state wrapper.
203#[derive(Default, Clone)]
204pub struct SharedState {
205    state: Arc<RwLock<AgentState>>,
206}
207
208impl SharedState {
209    /// Create a new SharedState with default (empty) agent state.
210    pub fn new() -> Self {
211        Self::default()
212    }
213
214    /// Obtain a snapshot of the current agent state.
215    pub fn get_state(&self) -> AgentState {
216        self.state.read().clone()
217    }
218
219    /// Mutably update the agent state under a write lock.
220    pub fn update<F>(&self, f: F)
221    where
222        F: FnOnce(&mut AgentState),
223    {
224        let mut state = self.state.write();
225        f(&mut state);
226    }
227
228    /// Reset the state for a new conversation (delegates to [`AgentState::clear`]).
229    pub fn reset(&self) {
230        let mut state = self.state.write();
231        state.clear();
232    }
233}