agent_conversation_state/
lib.rs1use serde_json::Value;
15use std::fmt;
16
/// The lifecycle stage a conversation is currently in.
///
/// [`Phase::Finished`] and every [`Phase::Error`] are terminal
/// (see [`Phase::is_terminal`]). The `Display` form of each variant is a
/// snake_case label; `Error` is rendered as `error: <message>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Phase {
    /// Displayed as `idle`.
    Idle,
    /// Displayed as `gathering_info`.
    GatheringInfo,
    /// Displayed as `thinking`.
    Thinking,
    /// Displayed as `calling_tools`.
    CallingTools,
    /// Displayed as `responding`.
    Responding,
    /// Displayed as `finished`; terminal.
    Finished,
    /// Carries a free-form message; displayed as `error: <message>`; terminal.
    Error(String),
}

impl fmt::Display for Phase {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Phase::Idle => write!(f, "idle"),
            Phase::GatheringInfo => write!(f, "gathering_info"),
            Phase::Thinking => write!(f, "thinking"),
            Phase::CallingTools => write!(f, "calling_tools"),
            Phase::Responding => write!(f, "responding"),
            Phase::Finished => write!(f, "finished"),
            Phase::Error(msg) => write!(f, "error: {}", msg),
        }
    }
}

impl Phase {
    /// Returns `true` for the terminal phases: [`Phase::Finished`] and any
    /// [`Phase::Error`], regardless of its message.
    pub fn is_terminal(&self) -> bool {
        matches!(self, Phase::Finished | Phase::Error(_))
    }
}
46
47#[derive(Debug, Clone)]
49pub struct Transition {
50 pub from: Phase,
51 pub to: Phase,
52 pub metadata: Option<Value>,
53}
54
55pub struct ConversationState {
57 phase: Phase,
58 history: Vec<Transition>,
59 context: std::collections::HashMap<String, Value>,
60}
61
62impl ConversationState {
63 pub fn new() -> Self {
64 Self { phase: Phase::Idle, history: Vec::new(), context: std::collections::HashMap::new() }
65 }
66
67 pub fn phase(&self) -> &Phase { &self.phase }
68
69 pub fn transition(&mut self, to: Phase) {
71 let from = self.phase.clone();
72 self.history.push(Transition { from, to: to.clone(), metadata: None });
73 self.phase = to;
74 }
75
76 pub fn transition_with(&mut self, to: Phase, metadata: Value) {
78 let from = self.phase.clone();
79 self.history.push(Transition { from, to: to.clone(), metadata: Some(metadata) });
80 self.phase = to;
81 }
82
83 pub fn set_ctx<V: Into<Value>>(&mut self, key: &str, value: V) {
85 self.context.insert(key.to_string(), value.into());
86 }
87
88 pub fn get_ctx(&self, key: &str) -> Option<&Value> { self.context.get(key) }
89
90 pub fn transition_count(&self) -> usize { self.history.len() }
91 pub fn history(&self) -> &[Transition] { &self.history }
92
93 pub fn is_terminal(&self) -> bool { self.phase.is_terminal() }
94
95 pub fn previous_phase(&self) -> Option<&Phase> {
97 self.history.last().map(|t| &t.from)
98 }
99
100 pub fn reset(&mut self) {
102 self.phase = Phase::Idle;
103 self.history.clear();
104 self.context.clear();
105 }
106}
107
108impl Default for ConversationState {
109 fn default() -> Self { Self::new() }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn starts_idle() {
        let s = ConversationState::new();
        assert_eq!(s.phase(), &Phase::Idle);
        assert_eq!(s.transition_count(), 0);
        assert!(s.previous_phase().is_none());
        assert!(!s.is_terminal());
    }

    #[test]
    fn default_matches_new() {
        let s = ConversationState::default();
        assert_eq!(s.phase(), &Phase::Idle);
        assert_eq!(s.transition_count(), 0);
        assert!(s.get_ctx("anything").is_none());
    }

    #[test]
    fn transition_changes_phase() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        assert_eq!(s.phase(), &Phase::Thinking);
    }

    #[test]
    fn history_records_transitions() {
        let mut s = ConversationState::new();
        s.transition(Phase::GatheringInfo);
        s.transition(Phase::Thinking);
        assert_eq!(s.transition_count(), 2);
        assert_eq!(s.history()[0].from, Phase::Idle);
        assert_eq!(s.history()[0].to, Phase::GatheringInfo);
        assert_eq!(s.history()[1].from, Phase::GatheringInfo);
        assert_eq!(s.history()[1].to, Phase::Thinking);
        // Plain `transition` never attaches metadata.
        assert!(s.history().iter().all(|t| t.metadata.is_none()));
    }

    #[test]
    fn previous_phase() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        assert_eq!(s.previous_phase(), Some(&Phase::Idle));
        s.transition(Phase::Responding);
        assert_eq!(s.previous_phase(), Some(&Phase::Thinking));
    }

    #[test]
    fn terminal_phases() {
        assert!(Phase::Finished.is_terminal());
        assert!(Phase::Error("oops".into()).is_terminal());
        assert!(!Phase::Thinking.is_terminal());
    }

    #[test]
    fn is_terminal_on_state() {
        let mut s = ConversationState::new();
        s.transition(Phase::Finished);
        assert!(s.is_terminal());

        let mut e = ConversationState::new();
        e.transition(Phase::Error("boom".into()));
        assert!(e.is_terminal());
    }

    #[test]
    fn context_store() {
        let mut s = ConversationState::new();
        s.set_ctx("user_id", json!("u123"));
        assert_eq!(s.get_ctx("user_id").unwrap(), "u123");
    }

    #[test]
    fn context_overwrites_and_converts_into_value() {
        let mut s = ConversationState::new();
        s.set_ctx("n", 1);
        s.set_ctx("n", 2);
        assert_eq!(s.get_ctx("n"), Some(&json!(2)));
    }

    #[test]
    fn context_missing_key() {
        let s = ConversationState::new();
        assert!(s.get_ctx("nope").is_none());
    }

    #[test]
    fn transition_with_metadata() {
        let mut s = ConversationState::new();
        s.transition_with(Phase::CallingTools, json!({"tool": "search"}));
        assert_eq!(s.phase(), &Phase::CallingTools);
        assert_eq!(s.history()[0].from, Phase::Idle);
        assert_eq!(s.history()[0].to, Phase::CallingTools);
        assert_eq!(s.history()[0].metadata, Some(json!({"tool": "search"})));
    }

    #[test]
    fn reset() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        s.set_ctx("key", json!(1));
        s.reset();
        assert_eq!(s.phase(), &Phase::Idle);
        assert_eq!(s.transition_count(), 0);
        assert!(s.previous_phase().is_none());
        assert!(s.get_ctx("key").is_none());
    }

    #[test]
    fn phase_display() {
        assert_eq!(Phase::Idle.to_string(), "idle");
        assert_eq!(Phase::GatheringInfo.to_string(), "gathering_info");
        assert_eq!(Phase::Thinking.to_string(), "thinking");
        assert_eq!(Phase::CallingTools.to_string(), "calling_tools");
        assert_eq!(Phase::Responding.to_string(), "responding");
        assert_eq!(Phase::Finished.to_string(), "finished");
        assert_eq!(Phase::Error("bad".into()).to_string(), "error: bad");
    }

    #[test]
    fn multiple_transitions() {
        let mut s = ConversationState::new();
        s.transition(Phase::GatheringInfo);
        s.transition(Phase::Thinking);
        s.transition(Phase::CallingTools);
        s.transition(Phase::Responding);
        s.transition(Phase::Finished);
        assert_eq!(s.transition_count(), 5);
        assert_eq!(s.previous_phase(), Some(&Phase::Responding));
        assert!(s.is_terminal());
    }
}