1use crate::types::{StopReason, ToolResult};
3use oxi_ai::{ContentBlock, Message, TextContent};
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7
8#[derive(Debug, Clone, Serialize, Deserialize, Default)]
16pub struct AgentState {
17 pub messages: Vec<Message>,
19 pub iteration: usize,
21 pub stop_reason: Option<StopReason>,
23 pub tool_results: Vec<ToolResult>,
25 pub total_tokens: usize,
27 pub input_tokens: usize,
29 pub output_tokens: usize,
31}
32
33impl AgentState {
34 pub fn new() -> Self {
36 Self::default()
37 }
38
39 pub fn add_user_message(&mut self, content: String) {
41 self.messages
42 .push(Message::User(oxi_ai::UserMessage::new(content)));
43 }
44
45 pub fn add_assistant_message(&mut self, content: String) {
47 let mut assistant =
48 oxi_ai::AssistantMessage::new(oxi_ai::Api::AnthropicMessages, "agent", "agent-model");
49 assistant.content = vec![ContentBlock::Text(TextContent::new(content))];
50 self.messages.push(Message::Assistant(assistant));
51 }
52
53 pub fn add_tool_result(&mut self, tool_call_id: String, content: String) {
55 let content_for_result = content.clone();
56 let tool_result_msg = oxi_ai::ToolResultMessage::new(
57 tool_call_id.clone(),
58 "tool",
59 vec![ContentBlock::Text(TextContent::new(content))],
60 );
61 self.messages
62 .push(oxi_ai::Message::ToolResult(tool_result_msg));
63 self.tool_results
64 .push(ToolResult::success(tool_call_id, content_for_result));
65 }
66
67 pub fn increment_iteration(&mut self) {
69 self.iteration += 1;
70 }
71
72 pub fn set_stop_reason(&mut self, reason: StopReason) {
74 self.stop_reason = Some(reason);
75 }
76
77 pub fn record_usage(&mut self, input: usize, output: usize) {
79 self.input_tokens += input;
80 self.output_tokens += output;
81 self.total_tokens += input + output;
82 }
83
84 pub fn clear(&mut self) {
86 self.messages.clear();
87 self.iteration = 0;
88 self.stop_reason = None;
89 self.tool_results.clear();
90 self.total_tokens = 0;
91 self.input_tokens = 0;
92 self.output_tokens = 0;
93 }
94
95 pub fn replace_messages(&mut self, messages: Vec<Message>) {
97 self.messages = messages;
98 }
99
100 pub fn estimate_tokens(&self) -> usize {
102 let json = serde_json::to_string(&self.messages).unwrap_or_default();
103 json.len() / 4 }
105
106 pub fn is_complete(&self) -> bool {
108 self.stop_reason.is_some()
109 }
110}
111
112#[derive(Default, Clone)]
114pub struct SharedState {
115 state: Arc<RwLock<AgentState>>,
116}
117
118impl SharedState {
119 pub fn new() -> Self {
121 Self::default()
122 }
123
124 pub fn get_state(&self) -> AgentState {
126 self.state.read().clone()
127 }
128
129 pub fn update<F>(&self, f: F)
131 where
132 F: FnOnce(&mut AgentState),
133 {
134 let mut state = self.state.write();
135 f(&mut state);
136 }
137
138 pub fn reset(&self) {
140 let mut state = self.state.write();
141 state.clear();
142 }
143}