Skip to main content

agent_conversation_state/
lib.rs

/*!
agent-conversation-state: track state and phase transitions for LLM agents.

```rust
use agent_conversation_state::{ConversationState, Phase};

let mut state = ConversationState::new();
assert_eq!(state.phase(), &Phase::Idle);
state.transition(Phase::GatheringInfo);
assert_eq!(state.phase(), &Phase::GatheringInfo);
```
*/
13
14use serde_json::Value;
15use std::fmt;
16
/// Phase in the agent conversation lifecycle.
///
/// [`Phase::Finished`] and [`Phase::Error`] are terminal (see [`Phase::is_terminal`]);
/// every other variant is non-terminal. The `Display` impl renders each variant as a
/// snake_case name, e.g. `gathering_info`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Phase {
    /// The phase a freshly created conversation state starts in.
    Idle,
    GatheringInfo,
    Thinking,
    CallingTools,
    Responding,
    /// Terminal success phase.
    Finished,
    /// Terminal failure phase carrying a message; displayed as `error: <message>`.
    Error(String),
}

impl fmt::Display for Phase {
    /// Writes the snake_case name of the phase, or `error: <message>` for [`Phase::Error`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Phase::Idle => "idle",
            Phase::GatheringInfo => "gathering_info",
            Phase::Thinking => "thinking",
            Phase::CallingTools => "calling_tools",
            Phase::Responding => "responding",
            Phase::Finished => "finished",
            // The only variant with a payload needs formatting rather than a fixed name.
            Phase::Error(msg) => return write!(f, "error: {}", msg),
        };
        f.write_str(name)
    }
}

impl Phase {
    /// Returns `true` for [`Phase::Finished`] and any [`Phase::Error`], `false` otherwise.
    #[must_use]
    pub fn is_terminal(&self) -> bool {
        matches!(self, Phase::Finished | Phase::Error(_))
    }
}
46
47/// A state transition record.
48#[derive(Debug, Clone)]
49pub struct Transition {
50    pub from: Phase,
51    pub to: Phase,
52    pub metadata: Option<Value>,
53}
54
55/// Tracks conversation phase and context data.
56pub struct ConversationState {
57    phase: Phase,
58    history: Vec<Transition>,
59    context: std::collections::HashMap<String, Value>,
60}
61
62impl ConversationState {
63    pub fn new() -> Self {
64        Self { phase: Phase::Idle, history: Vec::new(), context: std::collections::HashMap::new() }
65    }
66
67    pub fn phase(&self) -> &Phase { &self.phase }
68
69    /// Transition to a new phase.
70    pub fn transition(&mut self, to: Phase) {
71        let from = self.phase.clone();
72        self.history.push(Transition { from, to: to.clone(), metadata: None });
73        self.phase = to;
74    }
75
76    /// Transition with metadata.
77    pub fn transition_with(&mut self, to: Phase, metadata: Value) {
78        let from = self.phase.clone();
79        self.history.push(Transition { from, to: to.clone(), metadata: Some(metadata) });
80        self.phase = to;
81    }
82
83    /// Set a context value.
84    pub fn set_ctx<V: Into<Value>>(&mut self, key: &str, value: V) {
85        self.context.insert(key.to_string(), value.into());
86    }
87
88    pub fn get_ctx(&self, key: &str) -> Option<&Value> { self.context.get(key) }
89
90    pub fn transition_count(&self) -> usize { self.history.len() }
91    pub fn history(&self) -> &[Transition] { &self.history }
92
93    pub fn is_terminal(&self) -> bool { self.phase.is_terminal() }
94
95    /// Previous phase (before last transition).
96    pub fn previous_phase(&self) -> Option<&Phase> {
97        self.history.last().map(|t| &t.from)
98    }
99
100    /// Reset to Idle.
101    pub fn reset(&mut self) {
102        self.phase = Phase::Idle;
103        self.history.clear();
104        self.context.clear();
105    }
106}
107
108impl Default for ConversationState {
109    fn default() -> Self { Self::new() }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn starts_idle() {
        let s = ConversationState::new();
        assert_eq!(s.phase(), &Phase::Idle);
    }

    #[test]
    fn default_matches_new() {
        let s = ConversationState::default();
        assert_eq!(s.phase(), &Phase::Idle);
        assert_eq!(s.transition_count(), 0);
        assert!(s.history().is_empty());
    }

    #[test]
    fn transition_changes_phase() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        assert_eq!(s.phase(), &Phase::Thinking);
    }

    #[test]
    fn history_records_transitions() {
        let mut s = ConversationState::new();
        s.transition(Phase::GatheringInfo);
        s.transition(Phase::Thinking);
        assert_eq!(s.transition_count(), 2);
        assert_eq!(s.history()[0].from, Phase::Idle);
        assert_eq!(s.history()[0].to, Phase::GatheringInfo);
        assert_eq!(s.history()[1].from, Phase::GatheringInfo);
        assert_eq!(s.history()[1].to, Phase::Thinking);
    }

    #[test]
    fn previous_phase() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        assert_eq!(s.previous_phase(), Some(&Phase::Idle));
    }

    #[test]
    fn previous_phase_tracks_latest_transition() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        s.transition(Phase::Responding);
        assert_eq!(s.previous_phase(), Some(&Phase::Thinking));
    }

    #[test]
    fn fresh_state_has_no_previous_phase() {
        let s = ConversationState::new();
        assert!(s.previous_phase().is_none());
    }

    #[test]
    fn terminal_phases() {
        assert!(Phase::Finished.is_terminal());
        assert!(Phase::Error("oops".into()).is_terminal());
        assert!(!Phase::Thinking.is_terminal());
    }

    #[test]
    fn is_terminal_on_state() {
        let mut s = ConversationState::new();
        s.transition(Phase::Finished);
        assert!(s.is_terminal());
    }

    #[test]
    fn context_store() {
        let mut s = ConversationState::new();
        s.set_ctx("user_id", json!("u123"));
        assert_eq!(s.get_ctx("user_id").unwrap(), "u123");
    }

    #[test]
    fn context_overwrites_existing_key() {
        let mut s = ConversationState::new();
        s.set_ctx("k", json!(1));
        s.set_ctx("k", json!(2));
        assert_eq!(s.get_ctx("k"), Some(&json!(2)));
    }

    #[test]
    fn context_accepts_into_value() {
        let mut s = ConversationState::new();
        s.set_ctx("n", 42);
        assert_eq!(s.get_ctx("n"), Some(&json!(42)));
    }

    #[test]
    fn context_missing_key() {
        let s = ConversationState::new();
        assert!(s.get_ctx("nope").is_none());
    }

    #[test]
    fn transition_with_metadata() {
        let mut s = ConversationState::new();
        s.transition_with(Phase::CallingTools, json!({"tool": "search"}));
        assert_eq!(s.phase(), &Phase::CallingTools);
        assert_eq!(s.history()[0].metadata, Some(json!({"tool": "search"})));
    }

    #[test]
    fn plain_transition_has_no_metadata() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        assert!(s.history()[0].metadata.is_none());
    }

    #[test]
    fn reset() {
        let mut s = ConversationState::new();
        s.transition(Phase::Thinking);
        s.set_ctx("key", json!(1));
        s.reset();
        assert_eq!(s.phase(), &Phase::Idle);
        assert_eq!(s.transition_count(), 0);
        assert!(s.previous_phase().is_none());
        assert!(s.get_ctx("key").is_none());
    }

    #[test]
    fn phase_display() {
        assert_eq!(Phase::Idle.to_string(), "idle");
        assert_eq!(Phase::GatheringInfo.to_string(), "gathering_info");
        assert_eq!(Phase::Error("bad".into()).to_string(), "error: bad");
    }

    #[test]
    fn multiple_transitions() {
        let mut s = ConversationState::new();
        s.transition(Phase::GatheringInfo);
        s.transition(Phase::Thinking);
        s.transition(Phase::CallingTools);
        s.transition(Phase::Responding);
        s.transition(Phase::Finished);
        assert_eq!(s.transition_count(), 5);
        assert!(s.is_terminal());
    }
}