Skip to main content

oxi_agent/
state.rs

//! Agent state management.
2use crate::types::{StopReason, ToolResult};
3use oxi_ai::{ContentBlock, Message, TextContent};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8/// Agent execution state
9///
10/// Tracks the full lifecycle of an agent conversation including messages,
11/// token usage, tool results, and iteration progress.
12///
13/// Derives `Serialize`/`Deserialize` for session persistence and
14/// cross-process state transfer (e.g. oxios supervisor serialization).
15#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentState {
17    /// Conversation message history (user, assistant, and tool-result messages).
18    pub messages: Vec<Message>,
19    /// Current agent loop iteration (incremented after each assistant turn).
20    pub iteration: usize,
21    /// The reason the last turn stopped, if any.
22    pub stop_reason: Option<StopReason>,
23    /// Accumulated results from tool executions in the current conversation.
24    pub tool_results: Vec<ToolResult>,
25    /// Cumulative token count (input + output) across all turns.
26    pub total_tokens: usize,
27    /// Cumulative prompt / input tokens across all turns.
28    pub input_tokens: usize,
29    /// Cumulative completion / output tokens across all turns.
30    pub output_tokens: usize,
31}
32
33impl AgentState {
34    /// Create a new, default-initialized agent state.
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Add a user message
40    pub fn add_user_message(&mut self, content: String) {
41        self.messages
42            .push(Message::User(oxi_ai::UserMessage::new(content)));
43    }
44
45    /// Add an assistant message
46    pub fn add_assistant_message(&mut self, content: String) {
47        let mut assistant =
48            oxi_ai::AssistantMessage::new(oxi_ai::Api::AnthropicMessages, "agent", "agent-model");
49        assistant.content = vec![ContentBlock::Text(TextContent::new(content))];
50        self.messages.push(Message::Assistant(assistant));
51    }
52
53    /// Add a tool result message to both the message history and the tool results list.
54    pub fn add_tool_result(&mut self, tool_call_id: String, content: String) {
55        let content_for_result = content.clone();
56        let tool_result_msg = oxi_ai::ToolResultMessage::new(
57            tool_call_id.clone(),
58            "tool",
59            vec![ContentBlock::Text(TextContent::new(content))],
60        );
61        self.messages
62            .push(oxi_ai::Message::ToolResult(tool_result_msg));
63        self.tool_results
64            .push(ToolResult::success(tool_call_id, content_for_result));
65    }
66
67    /// Increment the iteration counter after an assistant turn completes.
68    pub fn increment_iteration(&mut self) {
69        self.iteration += 1;
70    }
71
72    /// Record the reason the last turn stopped.
73    pub fn set_stop_reason(&mut self, reason: StopReason) {
74        self.stop_reason = Some(reason);
75    }
76
77    /// Accumulate token usage from a completed LLM call.
78    pub fn record_usage(&mut self, input: usize, output: usize) {
79        self.input_tokens += input;
80        self.output_tokens += output;
81        self.total_tokens += input + output;
82    }
83
84    /// Clear all state, resetting for a new conversation.
85    pub fn clear(&mut self) {
86        self.messages.clear();
87        self.iteration = 0;
88        self.stop_reason = None;
89        self.tool_results.clear();
90        self.total_tokens = 0;
91        self.input_tokens = 0;
92        self.output_tokens = 0;
93    }
94
95    /// Replace the entire message history (used after context compaction).
96    pub fn replace_messages(&mut self, messages: Vec<Message>) {
97        self.messages = messages;
98    }
99
100    /// Rough token-count estimate based on the serialized message JSON length.
101    pub fn estimate_tokens(&self) -> usize {
102        let json = serde_json::to_string(&self.messages).unwrap_or_default();
103        json.len() / 4 // Rough approximation
104    }
105
106    /// Returns `true` if the agent has signaled a stop reason.
107    pub fn is_complete(&self) -> bool {
108        self.stop_reason.is_some()
109    }
110}
111
112/// Thread-safe agent state wrapper.
113#[derive(Default, Clone)]
114pub struct SharedState {
115    state: Arc<RwLock<AgentState>>,
116}
117
118impl SharedState {
119    /// Create a new SharedState with default (empty) agent state.
120    pub fn new() -> Self {
121        Self::default()
122    }
123
124    /// Obtain a snapshot of the current agent state.
125    pub fn get_state(&self) -> AgentState {
126        self.state.read().clone()
127    }
128
129    /// Mutably update the agent state under a write lock.
130    pub fn update<F>(&self, f: F)
131    where
132        F: FnOnce(&mut AgentState),
133    {
134        let mut state = self.state.write();
135        f(&mut state);
136    }
137
138    /// Reset the state for a new conversation (delegates to [`AgentState::clear`]).
139    pub fn reset(&self) {
140        let mut state = self.state.write();
141        state.clear();
142    }
143}