claude_agent/session/state/
mod.rs

1//! Session state management.
2
3mod config;
4mod enums;
5mod ids;
6mod message;
7mod policy;
8
9pub use config::SessionConfig;
10pub use enums::{SessionMode, SessionState, SessionType};
11pub use ids::{MessageId, SessionId};
12pub use message::{MessageMetadata, SessionMessage, ThinkingMetadata, ToolResultMeta};
13pub use policy::{PermissionMode, PermissionPolicy, ToolLimits};
14
15use std::collections::HashMap;
16
17use chrono::{DateTime, Utc};
18use serde::{Deserialize, Serialize};
19
20use crate::session::types::{CompactRecord, Plan, TodoItem, TodoStatus};
21use crate::types::{Message, TokenUsage};
22
23#[derive(Clone, Debug, Serialize, Deserialize)]
24pub struct Session {
25    pub id: SessionId,
26    pub parent_id: Option<SessionId>,
27    pub session_type: SessionType,
28    pub tenant_id: Option<String>,
29    pub mode: SessionMode,
30    pub state: SessionState,
31    pub config: SessionConfig,
32    pub permission_policy: PermissionPolicy,
33    pub messages: Vec<SessionMessage>,
34    pub current_leaf_id: Option<MessageId>,
35    pub summary: Option<String>,
36    pub total_usage: TokenUsage,
37    pub total_cost_usd: f64,
38    pub static_context_hash: Option<String>,
39    pub created_at: DateTime<Utc>,
40    pub updated_at: DateTime<Utc>,
41    pub expires_at: Option<DateTime<Utc>>,
42    pub error: Option<String>,
43    #[serde(default)]
44    pub todos: Vec<TodoItem>,
45    #[serde(default)]
46    pub current_plan: Option<Plan>,
47    #[serde(default)]
48    pub compact_history: Vec<CompactRecord>,
49}
50
51impl Session {
52    pub fn new(config: SessionConfig) -> Self {
53        Self::with_id(SessionId::new(), config)
54    }
55
56    pub fn with_id(id: SessionId, config: SessionConfig) -> Self {
57        let now = Utc::now();
58        let expires_at = config
59            .ttl_secs
60            .map(|ttl| now + chrono::Duration::seconds(ttl as i64));
61
62        Self {
63            id,
64            parent_id: None,
65            session_type: SessionType::Main,
66            tenant_id: None,
67            mode: config.mode.clone(),
68            state: SessionState::Created,
69            config: config.clone(),
70            permission_policy: config.permission_policy.clone(),
71            messages: Vec::new(),
72            current_leaf_id: None,
73            summary: None,
74            total_usage: TokenUsage::default(),
75            total_cost_usd: 0.0,
76            static_context_hash: None,
77            created_at: now,
78            updated_at: now,
79            expires_at,
80            error: None,
81            todos: Vec::new(),
82            current_plan: None,
83            compact_history: Vec::new(),
84        }
85    }
86
87    pub fn new_subagent(
88        parent_id: SessionId,
89        agent_type: impl Into<String>,
90        description: impl Into<String>,
91        config: SessionConfig,
92    ) -> Self {
93        let now = Utc::now();
94        let expires_at = config
95            .ttl_secs
96            .map(|ttl| now + chrono::Duration::seconds(ttl as i64));
97
98        Self {
99            id: SessionId::new(),
100            parent_id: Some(parent_id),
101            session_type: SessionType::Subagent {
102                agent_type: agent_type.into(),
103                description: description.into(),
104            },
105            tenant_id: None,
106            mode: config.mode.clone(),
107            state: SessionState::Created,
108            config: config.clone(),
109            permission_policy: config.permission_policy.clone(),
110            messages: Vec::new(),
111            current_leaf_id: None,
112            summary: None,
113            total_usage: TokenUsage::default(),
114            total_cost_usd: 0.0,
115            static_context_hash: None,
116            created_at: now,
117            updated_at: now,
118            expires_at,
119            error: None,
120            todos: Vec::new(),
121            current_plan: None,
122            compact_history: Vec::new(),
123        }
124    }
125
126    pub fn is_subagent(&self) -> bool {
127        matches!(self.session_type, SessionType::Subagent { .. })
128    }
129
130    pub fn is_running(&self) -> bool {
131        matches!(
132            self.state,
133            SessionState::Active | SessionState::WaitingForTools
134        )
135    }
136
137    pub fn is_terminal(&self) -> bool {
138        matches!(
139            self.state,
140            SessionState::Completed | SessionState::Failed | SessionState::Cancelled
141        )
142    }
143
144    pub fn is_expired(&self) -> bool {
145        self.expires_at.is_some_and(|expires| Utc::now() > expires)
146    }
147
148    pub fn add_message(&mut self, mut message: SessionMessage) {
149        if let Some(leaf) = &self.current_leaf_id {
150            message.parent_id = Some(leaf.clone());
151        }
152        self.current_leaf_id = Some(message.id.clone());
153        if let Some(usage) = &message.usage {
154            self.total_usage.add(usage);
155        }
156        self.messages.push(message);
157        self.updated_at = Utc::now();
158    }
159
160    pub fn get_current_branch(&self) -> Vec<&SessionMessage> {
161        let index: HashMap<&MessageId, &SessionMessage> =
162            self.messages.iter().map(|m| (&m.id, m)).collect();
163
164        let mut result = Vec::new();
165        let mut current_id = self.current_leaf_id.as_ref();
166
167        while let Some(id) = current_id {
168            if let Some(&msg) = index.get(id) {
169                result.push(msg);
170                current_id = msg.parent_id.as_ref();
171            } else {
172                break;
173            }
174        }
175
176        result.reverse();
177        result
178    }
179
180    pub fn to_api_messages(&self) -> Vec<Message> {
181        self.get_current_branch()
182            .into_iter()
183            .map(|m| m.to_api_message())
184            .collect()
185    }
186
187    pub fn branch_length(&self) -> usize {
188        self.get_current_branch().len()
189    }
190
191    pub fn set_state(&mut self, state: SessionState) {
192        self.state = state;
193        self.updated_at = Utc::now();
194    }
195
196    pub fn set_todos(&mut self, todos: Vec<TodoItem>) {
197        self.todos = todos;
198        self.updated_at = Utc::now();
199    }
200
201    pub fn todos_in_progress_count(&self) -> usize {
202        self.todos
203            .iter()
204            .filter(|t| t.status == TodoStatus::InProgress)
205            .count()
206    }
207
208    pub fn enter_plan_mode(&mut self, name: Option<String>) -> &Plan {
209        let mut plan = Plan::new(self.id);
210        if let Some(n) = name {
211            plan = plan.with_name(n);
212        }
213        self.current_plan = Some(plan);
214        self.updated_at = Utc::now();
215        self.current_plan.as_ref().expect("plan was just set")
216    }
217
218    pub fn update_plan_content(&mut self, content: String) {
219        if let Some(ref mut plan) = self.current_plan {
220            plan.content = content;
221            self.updated_at = Utc::now();
222        }
223    }
224
225    pub fn exit_plan_mode(&mut self) -> Option<Plan> {
226        if let Some(ref mut plan) = self.current_plan {
227            plan.approve();
228            self.updated_at = Utc::now();
229        }
230        self.current_plan.clone()
231    }
232
233    pub fn cancel_plan(&mut self) -> Option<Plan> {
234        if let Some(ref mut plan) = self.current_plan {
235            plan.cancel();
236            self.updated_at = Utc::now();
237        }
238        self.current_plan.take()
239    }
240
241    pub fn is_in_plan_mode(&self) -> bool {
242        self.current_plan
243            .as_ref()
244            .is_some_and(|p| !p.status.is_terminal())
245    }
246
247    pub fn record_compact(&mut self, record: CompactRecord) {
248        self.compact_history.push(record);
249        self.updated_at = Utc::now();
250    }
251
252    pub fn update_summary(&mut self, summary: impl Into<String>) {
253        self.summary = Some(summary.into());
254        self.updated_at = Utc::now();
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    /// A new session starts in `Created` with no messages and no leaf.
    #[test]
    fn test_session_creation() {
        let session = Session::new(SessionConfig::default());

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
        assert!(session.current_leaf_id.is_none());
    }

    /// Adding one message stores it and sets the current leaf.
    #[test]
    fn test_add_message() {
        let mut session = Session::new(SessionConfig::default());

        session.add_message(SessionMessage::user(vec![ContentBlock::text("Hello")]));

        assert_eq!(session.messages.len(), 1);
        assert!(session.current_leaf_id.is_some());
    }

    /// Two consecutive messages form a root-to-leaf branch in insertion order.
    #[test]
    fn test_message_tree() {
        let mut session = Session::new(SessionConfig::default());

        session.add_message(SessionMessage::user(vec![ContentBlock::text("Hello")]));
        session.add_message(SessionMessage::assistant(vec![ContentBlock::text(
            "Hi there!",
        )]));

        let branch = session.get_current_branch();
        assert_eq!(branch.len(), 2);
        assert_eq!(branch[0].role, Role::User);
        assert_eq!(branch[1].role, Role::Assistant);
    }

    /// A zero-second TTL makes the session expired once any time has passed.
    #[test]
    fn test_session_expiry() {
        let session = Session::new(SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        });

        // Expiry is a strict comparison, so let the clock advance first.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(session.is_expired());
    }

    /// Usage attached to each added message is summed into `total_usage`.
    #[test]
    fn test_token_usage_accumulation() {
        let mut session = Session::new(SessionConfig::default());

        for (text, input_tokens, output_tokens) in
            [("Response 1", 100, 50), ("Response 2", 150, 75)]
        {
            let usage = TokenUsage {
                input_tokens,
                output_tokens,
                ..Default::default()
            };
            let reply = SessionMessage::assistant(vec![ContentBlock::text(text)]).with_usage(usage);
            session.add_message(reply);
        }

        assert_eq!(session.total_usage.input_tokens, 250);
        assert_eq!(session.total_usage.output_tokens, 125);
    }
}