claude_agent/session/state/
mod.rs

1//! Session state management.
2
3mod config;
4mod enums;
5mod ids;
6mod message;
7mod policy;
8
9pub use config::SessionConfig;
10pub use enums::{SessionMode, SessionState, SessionType};
11pub use ids::{MessageId, SessionId};
12pub use message::{MessageMetadata, SessionMessage, ThinkingMetadata, ToolResultMeta};
13pub use policy::{PermissionMode, PermissionPolicy, ToolLimits};
14
15use std::collections::HashMap;
16
17use chrono::{DateTime, Utc};
18use serde::{Deserialize, Serialize};
19
20use crate::session::types::{CompactRecord, Plan, TodoItem, TodoStatus};
21use crate::types::{CacheControl, CacheTtl, ContentBlock, Message, Role, TokenUsage, Usage};
22
/// Upper bound on the number of entries kept in `Session::compact_history`;
/// `Session::record_compact` evicts the oldest entry once this size is reached.
const MAX_COMPACT_HISTORY_SIZE: usize = 50;
24
25#[derive(Clone, Debug, Serialize, Deserialize)]
26pub struct Session {
27    pub id: SessionId,
28    pub parent_id: Option<SessionId>,
29    pub session_type: SessionType,
30    pub tenant_id: Option<String>,
31    pub mode: SessionMode,
32    pub state: SessionState,
33    pub config: SessionConfig,
34    pub permission_policy: PermissionPolicy,
35    pub messages: Vec<SessionMessage>,
36    pub current_leaf_id: Option<MessageId>,
37    pub summary: Option<String>,
38    pub total_usage: TokenUsage,
39    #[serde(default)]
40    pub current_input_tokens: u64,
41    pub total_cost_usd: f64,
42    pub static_context_hash: Option<String>,
43    pub created_at: DateTime<Utc>,
44    pub updated_at: DateTime<Utc>,
45    pub expires_at: Option<DateTime<Utc>>,
46    pub error: Option<String>,
47    #[serde(default)]
48    pub todos: Vec<TodoItem>,
49    #[serde(default)]
50    pub current_plan: Option<Plan>,
51    #[serde(default)]
52    pub compact_history: Vec<CompactRecord>,
53}
54
55impl Session {
56    pub fn new(config: SessionConfig) -> Self {
57        Self::with_id(SessionId::new(), config)
58    }
59
60    pub fn with_id(id: SessionId, config: SessionConfig) -> Self {
61        Self::init(id, None, SessionType::Main, config)
62    }
63
64    pub fn new_subagent(
65        parent_id: SessionId,
66        agent_type: impl Into<String>,
67        description: impl Into<String>,
68        config: SessionConfig,
69    ) -> Self {
70        let session_type = SessionType::Subagent {
71            agent_type: agent_type.into(),
72            description: description.into(),
73        };
74        Self::init(SessionId::new(), Some(parent_id), session_type, config)
75    }
76
77    fn init(
78        id: SessionId,
79        parent_id: Option<SessionId>,
80        session_type: SessionType,
81        config: SessionConfig,
82    ) -> Self {
83        let now = Utc::now();
84        let expires_at = config
85            .ttl_secs
86            .map(|ttl| now + chrono::Duration::seconds(ttl as i64));
87
88        Self {
89            id,
90            parent_id,
91            session_type,
92            tenant_id: None,
93            mode: config.mode.clone(),
94            state: SessionState::Created,
95            permission_policy: config.permission_policy.clone(),
96            config,
97            messages: Vec::with_capacity(32),
98            current_leaf_id: None,
99            summary: None,
100            total_usage: TokenUsage::default(),
101            current_input_tokens: 0,
102            total_cost_usd: 0.0,
103            static_context_hash: None,
104            created_at: now,
105            updated_at: now,
106            expires_at,
107            error: None,
108            todos: Vec::with_capacity(8),
109            current_plan: None,
110            compact_history: Vec::new(),
111        }
112    }
113
114    pub fn is_subagent(&self) -> bool {
115        matches!(self.session_type, SessionType::Subagent { .. })
116    }
117
118    pub fn is_running(&self) -> bool {
119        matches!(
120            self.state,
121            SessionState::Active | SessionState::WaitingForTools
122        )
123    }
124
125    pub fn is_terminal(&self) -> bool {
126        matches!(
127            self.state,
128            SessionState::Completed | SessionState::Failed | SessionState::Cancelled
129        )
130    }
131
132    pub fn is_expired(&self) -> bool {
133        self.expires_at.is_some_and(|expires| Utc::now() > expires)
134    }
135
136    pub fn add_message(&mut self, mut message: SessionMessage) {
137        if let Some(leaf) = &self.current_leaf_id {
138            message.parent_id = Some(leaf.clone());
139        }
140        self.current_leaf_id = Some(message.id.clone());
141        if let Some(usage) = &message.usage {
142            self.total_usage.add(usage);
143        }
144        self.messages.push(message);
145        self.updated_at = Utc::now();
146    }
147
148    pub fn get_current_branch(&self) -> Vec<&SessionMessage> {
149        let index: HashMap<&MessageId, &SessionMessage> =
150            self.messages.iter().map(|m| (&m.id, m)).collect();
151
152        let mut result = Vec::new();
153        let mut current_id = self.current_leaf_id.as_ref();
154
155        while let Some(id) = current_id {
156            if let Some(&msg) = index.get(id) {
157                result.push(msg);
158                current_id = msg.parent_id.as_ref();
159            } else {
160                break;
161            }
162        }
163
164        result.reverse();
165        result
166    }
167
168    /// Convert session messages to API format with default caching (5m TTL).
169    pub fn to_api_messages(&self) -> Vec<Message> {
170        self.to_api_messages_with_cache(Some(CacheTtl::FiveMinutes))
171    }
172
173    /// Convert session messages to API format with optional caching.
174    ///
175    /// Per Anthropic best practices, caches the last user message with the specified TTL.
176    /// Pass `None` to disable caching.
177    pub fn to_api_messages_with_cache(&self, ttl: Option<CacheTtl>) -> Vec<Message> {
178        let branch = self.get_current_branch();
179        if branch.is_empty() {
180            return Vec::new();
181        }
182
183        let mut messages: Vec<Message> = branch.iter().map(|m| m.to_api_message()).collect();
184
185        if let Some(ttl) = ttl {
186            self.apply_cache_breakpoint(&mut messages, ttl);
187        }
188
189        messages
190    }
191
192    /// Apply cache breakpoint to the last user message.
193    ///
194    /// Per Anthropic best practices for multi-turn conversations,
195    /// only the last user message needs cache_control to enable
196    /// caching of the entire conversation history before it.
197    fn apply_cache_breakpoint(&self, messages: &mut [Message], ttl: CacheTtl) {
198        let last_user_idx = messages
199            .iter()
200            .enumerate()
201            .rev()
202            .find(|(_, m)| m.role == Role::User)
203            .map(|(i, _)| i);
204
205        if let Some(idx) = last_user_idx {
206            messages[idx].set_cache_on_last_block(CacheControl::ephemeral().with_ttl(ttl));
207        }
208    }
209
210    pub fn branch_length(&self) -> usize {
211        self.get_current_branch().len()
212    }
213
214    pub fn set_state(&mut self, state: SessionState) {
215        self.state = state;
216        self.updated_at = Utc::now();
217    }
218
219    pub fn set_todos(&mut self, todos: Vec<TodoItem>) {
220        self.todos = todos;
221        self.updated_at = Utc::now();
222    }
223
224    pub fn todos_in_progress_count(&self) -> usize {
225        self.todos
226            .iter()
227            .filter(|t| t.status == TodoStatus::InProgress)
228            .count()
229    }
230
231    pub fn enter_plan_mode(&mut self, name: Option<String>) -> &Plan {
232        let mut plan = Plan::new(self.id);
233        if let Some(n) = name {
234            plan = plan.with_name(n);
235        }
236        self.current_plan = Some(plan);
237        self.updated_at = Utc::now();
238        self.current_plan.as_ref().expect("plan was just set")
239    }
240
241    pub fn update_plan_content(&mut self, content: String) {
242        if let Some(ref mut plan) = self.current_plan {
243            plan.content = content;
244            self.updated_at = Utc::now();
245        }
246    }
247
248    pub fn exit_plan_mode(&mut self) -> Option<Plan> {
249        if let Some(ref mut plan) = self.current_plan {
250            plan.approve();
251            self.updated_at = Utc::now();
252        }
253        self.current_plan.take()
254    }
255
256    pub fn cancel_plan(&mut self) -> Option<Plan> {
257        if let Some(ref mut plan) = self.current_plan {
258            plan.cancel();
259            self.updated_at = Utc::now();
260        }
261        self.current_plan.take()
262    }
263
264    pub fn is_in_plan_mode(&self) -> bool {
265        self.current_plan
266            .as_ref()
267            .is_some_and(|p| !p.status.is_terminal())
268    }
269
270    pub fn record_compact(&mut self, record: CompactRecord) {
271        if self.compact_history.len() >= MAX_COMPACT_HISTORY_SIZE {
272            self.compact_history.remove(0);
273        }
274        self.compact_history.push(record);
275        self.updated_at = Utc::now();
276    }
277
278    pub fn update_summary(&mut self, summary: impl Into<String>) {
279        self.summary = Some(summary.into());
280        self.updated_at = Utc::now();
281    }
282
283    pub fn add_user_message(&mut self, content: impl Into<String>) {
284        let msg = SessionMessage::user(vec![ContentBlock::text(content.into())]);
285        self.add_message(msg);
286    }
287
288    pub fn add_assistant_message(&mut self, content: Vec<ContentBlock>, usage: Option<Usage>) {
289        let mut msg = SessionMessage::assistant(content);
290        if let Some(u) = usage {
291            msg = msg.with_usage(TokenUsage {
292                input_tokens: u.input_tokens as u64,
293                output_tokens: u.output_tokens as u64,
294                cache_read_input_tokens: u.cache_read_input_tokens.unwrap_or(0) as u64,
295                cache_creation_input_tokens: u.cache_creation_input_tokens.unwrap_or(0) as u64,
296            });
297        }
298        self.add_message(msg);
299    }
300
301    pub fn add_tool_results(&mut self, results: Vec<crate::types::ToolResultBlock>) {
302        let content: Vec<ContentBlock> =
303            results.into_iter().map(ContentBlock::ToolResult).collect();
304        let msg = SessionMessage::user(content);
305        self.add_message(msg);
306    }
307
308    pub fn current_tokens(&self) -> u64 {
309        self.current_input_tokens
310    }
311
312    pub fn should_compact(&self, max_tokens: u64, threshold: f32, keep_messages: usize) -> bool {
313        self.messages.len() > keep_messages
314            && self.current_input_tokens as f32 > max_tokens as f32 * threshold
315    }
316
317    pub fn update_usage(&mut self, usage: &Usage) {
318        self.current_input_tokens = usage.input_tokens as u64;
319        self.total_usage.input_tokens += usage.input_tokens as u64;
320        self.total_usage.output_tokens += usage.output_tokens as u64;
321        if let Some(cache_read) = usage.cache_read_input_tokens {
322            self.total_usage.cache_read_input_tokens += cache_read as u64;
323        }
324        if let Some(cache_creation) = usage.cache_creation_input_tokens {
325            self.total_usage.cache_creation_input_tokens += cache_creation as u64;
326        }
327    }
328
329    pub async fn compact(
330        &mut self,
331        client: &crate::Client,
332        keep_messages: usize,
333    ) -> crate::Result<crate::types::CompactResult> {
334        use crate::client::ModelType;
335        use crate::client::messages::CreateMessageRequest;
336        use crate::types::CompactResult;
337
338        if self.messages.len() <= keep_messages {
339            return Ok(CompactResult::NotNeeded);
340        }
341
342        let tokens_before = self.current_input_tokens;
343        let split_point = self.messages.len() - keep_messages;
344        let to_summarize: Vec<_> = self.messages[..split_point].to_vec();
345        let to_keep: Vec<_> = self.messages[split_point..].to_vec();
346
347        let summary_prompt = Self::format_for_summary(&to_summarize);
348        let model = client.adapter().model(ModelType::Small).to_string();
349        let request = CreateMessageRequest::new(&model, vec![Message::user(&summary_prompt)])
350            .with_max_tokens(2000);
351        let response = client.send(request).await?;
352        let summary = response.text();
353
354        let original_count = self.messages.len();
355
356        self.messages.clear();
357        self.current_leaf_id = None;
358
359        let summary_msg = SessionMessage::user(vec![ContentBlock::text(format!(
360            "[Previous conversation summary]\n{}",
361            summary
362        ))])
363        .as_compact_summary();
364        self.add_message(summary_msg);
365
366        for mut msg in to_keep {
367            msg.parent_id = self.current_leaf_id.clone();
368            self.current_leaf_id = Some(msg.id.clone());
369            self.messages.push(msg);
370        }
371
372        // Reset to 0: actual value will be set by next API call's update_usage().
373        // This also prevents immediate re-compaction since should_compact() returns false when 0.
374        self.current_input_tokens = 0;
375        self.summary = Some(summary.clone());
376        self.updated_at = Utc::now();
377
378        let record = CompactRecord::new(self.id)
379            .with_counts(original_count, self.messages.len())
380            .with_summary(summary.clone())
381            .with_saved_tokens(tokens_before as usize);
382        self.record_compact(record);
383
384        Ok(CompactResult::Compacted {
385            original_count,
386            new_count: self.messages.len(),
387            saved_tokens: tokens_before as usize,
388            summary,
389        })
390    }
391
392    fn format_for_summary(messages: &[SessionMessage]) -> String {
393        let estimated_capacity = messages.len() * 500 + 200;
394        let mut formatted = String::with_capacity(estimated_capacity.min(32768));
395        formatted.push_str(
396            "Summarize this conversation concisely. \
397             Preserve key decisions, code changes, file paths, and important context:\n\n",
398        );
399
400        for msg in messages {
401            let role = match msg.role {
402                Role::User => "User",
403                Role::Assistant => "Assistant",
404            };
405            formatted.push_str(role);
406            formatted.push_str(":\n");
407
408            for block in &msg.content {
409                if let Some(text) = block.as_text() {
410                    if text.len() > 800 {
411                        formatted.push_str(&text[..800]);
412                        formatted.push_str("... [truncated]\n");
413                    } else {
414                        formatted.push_str(text);
415                        formatted.push('\n');
416                    }
417                }
418            }
419            formatted.push('\n');
420        }
421
422        formatted
423    }
424
425    pub fn clear_messages(&mut self) {
426        self.messages.clear();
427        self.current_leaf_id = None;
428        self.updated_at = Utc::now();
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    /// Session built from the default configuration.
    fn default_session() -> Session {
        Session::new(SessionConfig::default())
    }

    /// User message holding a single text block.
    fn user_text(text: &str) -> SessionMessage {
        SessionMessage::user(vec![ContentBlock::text(text)])
    }

    /// Assistant message holding a single text block.
    fn assistant_text(text: &str) -> SessionMessage {
        SessionMessage::assistant(vec![ContentBlock::text(text)])
    }

    /// Assistant message carrying the given input/output token usage.
    fn assistant_with_usage(text: &str, input_tokens: u64, output_tokens: u64) -> SessionMessage {
        assistant_text(text).with_usage(TokenUsage {
            input_tokens,
            output_tokens,
            ..Default::default()
        })
    }

    #[test]
    fn test_session_creation() {
        let session = default_session();

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
        assert!(session.current_leaf_id.is_none());
    }

    #[test]
    fn test_add_message() {
        let mut session = default_session();

        session.add_message(user_text("Hello"));

        assert_eq!(session.messages.len(), 1);
        assert!(session.current_leaf_id.is_some());
    }

    #[test]
    fn test_message_tree() {
        let mut session = default_session();
        session.add_message(user_text("Hello"));
        session.add_message(assistant_text("Hi there!"));

        let branch = session.get_current_branch();

        // The branch is returned oldest-first.
        let expected_roles = [Role::User, Role::Assistant];
        assert_eq!(branch.len(), expected_roles.len());
        for (message, role) in branch.iter().zip(expected_roles.iter()) {
            assert_eq!(&message.role, role);
        }
    }

    #[test]
    fn test_session_expiry() {
        let session = Session::new(SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        });

        // With a zero TTL the expiry equals the creation instant; wait a
        // moment so that "now" is strictly later.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(session.is_expired());
    }

    #[test]
    fn test_token_usage_accumulation() {
        let mut session = default_session();

        session.add_message(assistant_with_usage("Response 1", 100, 50));
        session.add_message(assistant_with_usage("Response 2", 150, 75));

        assert_eq!(session.total_usage.input_tokens, 250);
        assert_eq!(session.total_usage.output_tokens, 125);
    }

    #[test]
    fn test_compact_history_limit() {
        let mut session = default_session();
        let total_records = MAX_COMPACT_HISTORY_SIZE + 10;

        for i in 0..total_records {
            let summary = format!("Summary {}", i);
            session.record_compact(CompactRecord::new(session.id).with_summary(summary));
        }

        // The ten oldest records were evicted, so the first survivor is #10.
        assert_eq!(session.compact_history.len(), MAX_COMPACT_HISTORY_SIZE);
        assert!(session.compact_history[0].summary.contains("10"));
    }

    #[test]
    fn test_exit_plan_mode_takes_ownership() {
        let mut session = default_session();
        session.enter_plan_mode(Some("Test Plan".to_string()));

        let plan = session.exit_plan_mode();

        assert!(plan.is_some());
        assert!(session.current_plan.is_none());
    }

    #[test]
    fn test_message_caching_applies_to_last_user_turn() {
        let mut session = default_session();
        session.add_user_message("First question");
        session.add_message(assistant_text("First answer"));
        session.add_user_message("Second question");

        let cached: Vec<bool> = session
            .to_api_messages()
            .iter()
            .map(|m| m.has_cache_control())
            .collect();

        // Three messages, and only the final user turn carries cache control.
        assert_eq!(cached, [false, false, true]);
    }

    #[test]
    fn test_message_caching_disabled() {
        let mut session = default_session();
        session.add_user_message("Question");

        // Pass None to disable caching
        let messages = session.to_api_messages_with_cache(None);

        assert_eq!(messages.len(), 1);
        assert!(!messages[0].has_cache_control());
    }

    #[test]
    fn test_message_caching_empty_session() {
        let messages = default_session().to_api_messages();

        assert!(messages.is_empty());
    }

    #[test]
    fn test_message_caching_assistant_only() {
        let mut session = default_session();
        session.add_message(assistant_text("Hi"));

        let messages = session.to_api_messages();

        assert_eq!(messages.len(), 1);
        assert!(!messages[0].has_cache_control());
    }
}