Skip to main content

claude_agent/session/state/
mod.rs

1//! Session state management.
2
3mod config;
4mod enums;
5mod ids;
6mod message;
7mod policy;
8
9pub use config::SessionConfig;
10pub use enums::{SessionState, SessionType};
11pub use ids::{MessageId, SessionId};
12pub use message::{MessageMetadata, SessionMessage, ThinkingMetadata, ToolResultMeta};
13pub use policy::{PermissionMode, SessionPermissions, SessionToolLimits};
14
15use std::collections::{HashMap, VecDeque};
16
17use chrono::{DateTime, Utc};
18use rust_decimal::Decimal;
19use serde::{Deserialize, Serialize};
20
21use crate::session::types::{CompactRecord, Plan, TodoItem, TodoStatus};
22use crate::types::{CacheControl, CacheTtl, ContentBlock, Message, Role, TokenUsage, Usage};
23
/// Maximum number of [`CompactRecord`]s kept in `Session::compact_history`.
///
/// `Session::record_compact` drops the oldest records to stay within this bound.
const MAX_COMPACT_HISTORY_SIZE: usize = 50;
25
/// A conversation session: its message tree, lifecycle state, configuration,
/// token accounting, and auxiliary agent state (todos, plan, compaction history).
///
/// Messages are stored flat in `messages`; the active conversation is the chain
/// obtained by following each message's `parent_id` back from `current_leaf_id`
/// (see `Session::current_branch`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Session {
    /// Unique identifier of this session.
    pub id: SessionId,
    /// Parent session for subagent sessions (set by `Session::new_subagent`);
    /// `None` for sessions built via `new` / `from_id`.
    pub parent_id: Option<SessionId>,
    /// Main session or subagent (the latter carries agent type and description).
    pub session_type: SessionType,
    /// Owning tenant. Always `None` at construction.
    /// NOTE(review): presumably assigned by callers / storage — confirm.
    pub tenant_id: Option<String>,
    /// Lifecycle state; starts as `SessionState::Created`.
    pub state: SessionState,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Effective permissions, initialised as a copy of `config.permissions`.
    pub permissions: SessionPermissions,
    /// Every message in insertion order, including any not on the active branch.
    pub messages: Vec<SessionMessage>,
    /// Id of the newest message on the active branch; `add_message` parents new
    /// messages to it.
    pub current_leaf_id: Option<MessageId>,
    /// Latest conversation summary (written by `update_summary` and `compact`).
    pub summary: Option<String>,
    /// Token usage accumulated from message usage and `update_usage` calls.
    pub total_usage: TokenUsage,
    /// Context size reported by the most recent `update_usage` call; reset to 0 by
    /// `compact`. Compared against the limit in `should_compact`.
    /// `#[serde(default)]` lets payloads without this field deserialize as 0.
    #[serde(default)]
    pub current_input_tokens: u64,
    /// Accumulated cost in USD.
    /// NOTE(review): not updated anywhere in this module — presumably maintained by callers.
    pub total_cost_usd: Decimal,
    /// NOTE(review): judging by the name, a hash of the static (system/tool) context;
    /// not read or written in this module — confirm against its users.
    pub static_context_hash: Option<String>,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Refreshed by most mutating `Session` methods.
    pub updated_at: DateTime<Utc>,
    /// Expiry derived from `config.ttl_secs` at construction; `None` means the
    /// session never expires (see `is_expired`).
    pub expires_at: Option<DateTime<Utc>>,
    /// Error description.
    /// NOTE(review): not set in this module — presumably recorded alongside
    /// `SessionState::Failed`; confirm.
    pub error: Option<String>,
    /// Task list, replaced wholesale by `set_todos`.
    #[serde(default)]
    pub todos: Vec<TodoItem>,
    /// Plan being drafted while in plan mode (see `enter_plan_mode`).
    #[serde(default)]
    pub current_plan: Option<Plan>,
    /// Most recent compaction records, oldest first, bounded by
    /// `MAX_COMPACT_HISTORY_SIZE`.
    #[serde(default)]
    pub compact_history: VecDeque<CompactRecord>,
}
54
55impl Session {
56    pub fn new(config: SessionConfig) -> Self {
57        Self::from_id(SessionId::new(), config)
58    }
59
60    pub fn from_id(id: SessionId, config: SessionConfig) -> Self {
61        Self::init(id, None, SessionType::Main, config)
62    }
63
64    pub fn new_subagent(
65        parent_id: SessionId,
66        agent_type: impl Into<String>,
67        description: impl Into<String>,
68        config: SessionConfig,
69    ) -> Self {
70        let session_type = SessionType::Subagent {
71            agent_type: agent_type.into(),
72            description: description.into(),
73        };
74        Self::init(SessionId::new(), Some(parent_id), session_type, config)
75    }
76
77    fn init(
78        id: SessionId,
79        parent_id: Option<SessionId>,
80        session_type: SessionType,
81        config: SessionConfig,
82    ) -> Self {
83        let now = Utc::now();
84        let expires_at = config
85            .ttl_secs
86            .map(|ttl| now + chrono::Duration::seconds(ttl as i64));
87
88        Self {
89            id,
90            parent_id,
91            session_type,
92            tenant_id: None,
93            state: SessionState::Created,
94            permissions: config.permissions.clone(),
95            config,
96            messages: Vec::with_capacity(32),
97            current_leaf_id: None,
98            summary: None,
99            total_usage: TokenUsage::default(),
100            current_input_tokens: 0,
101            total_cost_usd: Decimal::ZERO,
102            static_context_hash: None,
103            created_at: now,
104            updated_at: now,
105            expires_at,
106            error: None,
107            todos: Vec::with_capacity(8),
108            current_plan: None,
109            compact_history: VecDeque::new(),
110        }
111    }
112
113    pub fn is_subagent(&self) -> bool {
114        matches!(self.session_type, SessionType::Subagent { .. })
115    }
116
117    pub fn is_running(&self) -> bool {
118        matches!(
119            self.state,
120            SessionState::Active | SessionState::WaitingForTools
121        )
122    }
123
124    pub fn is_terminal(&self) -> bool {
125        matches!(
126            self.state,
127            SessionState::Completed | SessionState::Failed | SessionState::Cancelled
128        )
129    }
130
131    pub fn is_expired(&self) -> bool {
132        self.expires_at.is_some_and(|expires| Utc::now() > expires)
133    }
134
135    pub fn add_message(&mut self, mut message: SessionMessage) {
136        if let Some(leaf) = &self.current_leaf_id {
137            message.parent_id = Some(leaf.clone());
138        }
139        self.current_leaf_id = Some(message.id.clone());
140        if let Some(usage) = &message.usage {
141            self.total_usage.add(usage);
142        }
143        self.messages.push(message);
144        self.updated_at = Utc::now();
145    }
146
147    pub fn current_branch(&self) -> Vec<&SessionMessage> {
148        let index: HashMap<&MessageId, &SessionMessage> =
149            self.messages.iter().map(|m| (&m.id, m)).collect();
150
151        let mut result = Vec::new();
152        let mut current_id = self.current_leaf_id.as_ref();
153
154        while let Some(id) = current_id {
155            if let Some(&msg) = index.get(id) {
156                result.push(msg);
157                current_id = msg.parent_id.as_ref();
158            } else {
159                break;
160            }
161        }
162
163        result.reverse();
164        result
165    }
166
167    /// Convert session messages to API format with default caching (5m TTL).
168    pub fn to_api_messages(&self) -> Vec<Message> {
169        self.to_api_messages_with_cache(Some(CacheTtl::FiveMinutes))
170    }
171
172    /// Convert session messages to API format with optional caching.
173    ///
174    /// Per Anthropic best practices, caches the last user message with the specified TTL.
175    /// Pass `None` to disable caching.
176    pub fn to_api_messages_with_cache(&self, ttl: Option<CacheTtl>) -> Vec<Message> {
177        let branch = self.current_branch();
178        if branch.is_empty() {
179            return Vec::new();
180        }
181
182        let mut messages: Vec<Message> = branch.iter().map(|m| m.to_api_message()).collect();
183
184        if let Some(ttl) = ttl {
185            self.apply_cache_breakpoint(&mut messages, ttl);
186        }
187
188        messages
189    }
190
191    /// Apply cache breakpoint to the last user message.
192    ///
193    /// Per Anthropic best practices for multi-turn conversations,
194    /// only the last user message needs cache_control to enable
195    /// caching of the entire conversation history before it.
196    fn apply_cache_breakpoint(&self, messages: &mut [Message], ttl: CacheTtl) {
197        let last_user_idx = messages
198            .iter()
199            .enumerate()
200            .rev()
201            .find(|(_, m)| m.role == Role::User)
202            .map(|(i, _)| i);
203
204        if let Some(idx) = last_user_idx {
205            messages[idx].set_cache_on_last_block(CacheControl::ephemeral().ttl(ttl));
206        }
207    }
208
209    pub fn set_state(&mut self, state: SessionState) {
210        self.state = state;
211        self.updated_at = Utc::now();
212    }
213
214    pub fn set_todos(&mut self, todos: Vec<TodoItem>) {
215        self.todos = todos;
216        self.updated_at = Utc::now();
217    }
218
219    pub fn todos_in_progress_count(&self) -> usize {
220        self.todos
221            .iter()
222            .filter(|t| t.status == TodoStatus::InProgress)
223            .count()
224    }
225
226    pub fn enter_plan_mode(&mut self, name: Option<String>) -> &Plan {
227        let mut plan = Plan::new(self.id);
228        if let Some(n) = name {
229            plan = plan.name(n);
230        }
231        self.updated_at = Utc::now();
232        self.current_plan.insert(plan)
233    }
234
235    pub fn update_plan_content(&mut self, content: String) {
236        if let Some(ref mut plan) = self.current_plan {
237            plan.content = content;
238            self.updated_at = Utc::now();
239        }
240    }
241
242    pub fn exit_plan_mode(&mut self) -> Option<Plan> {
243        if let Some(ref mut plan) = self.current_plan {
244            plan.approve();
245            self.updated_at = Utc::now();
246        }
247        self.current_plan.take()
248    }
249
250    pub fn cancel_plan(&mut self) -> Option<Plan> {
251        if let Some(ref mut plan) = self.current_plan {
252            plan.cancel();
253            self.updated_at = Utc::now();
254        }
255        self.current_plan.take()
256    }
257
258    pub fn is_in_plan_mode(&self) -> bool {
259        self.current_plan
260            .as_ref()
261            .is_some_and(|p| !p.status.is_terminal())
262    }
263
264    pub fn record_compact(&mut self, record: CompactRecord) {
265        if self.compact_history.len() >= MAX_COMPACT_HISTORY_SIZE {
266            self.compact_history.pop_front();
267        }
268        self.compact_history.push_back(record);
269        self.updated_at = Utc::now();
270    }
271
272    pub fn update_summary(&mut self, summary: impl Into<String>) {
273        self.summary = Some(summary.into());
274        self.updated_at = Utc::now();
275    }
276
277    pub fn add_user_message(&mut self, content: impl Into<String>) {
278        let msg = SessionMessage::user(vec![ContentBlock::text(content.into())]);
279        self.add_message(msg);
280    }
281
282    pub fn add_assistant_message(&mut self, content: Vec<ContentBlock>, usage: Option<Usage>) {
283        let mut msg = SessionMessage::assistant(content);
284        if let Some(u) = usage {
285            msg = msg.usage(TokenUsage {
286                input_tokens: u.input_tokens as u64,
287                output_tokens: u.output_tokens as u64,
288                cache_read_input_tokens: u.cache_read_input_tokens.unwrap_or(0) as u64,
289                cache_creation_input_tokens: u.cache_creation_input_tokens.unwrap_or(0) as u64,
290            });
291        }
292        self.add_message(msg);
293    }
294
295    pub fn add_tool_results(&mut self, results: Vec<crate::types::ToolResultBlock>) {
296        let content: Vec<ContentBlock> =
297            results.into_iter().map(ContentBlock::ToolResult).collect();
298        let msg = SessionMessage::user(content);
299        self.add_message(msg);
300    }
301
302    pub fn should_compact(&self, max_tokens: u64, threshold: f32, keep_messages: usize) -> bool {
303        self.messages.len() > keep_messages
304            && self.current_input_tokens as f32 > max_tokens as f32 * threshold
305    }
306
307    pub fn update_usage(&mut self, usage: &Usage) {
308        self.current_input_tokens = usage.context_usage() as u64;
309        self.total_usage.add_usage(usage);
310    }
311
312    pub async fn compact(
313        &mut self,
314        client: &crate::Client,
315        keep_messages: usize,
316    ) -> crate::Result<crate::types::CompactResult> {
317        use crate::client::ModelType;
318        use crate::client::messages::CreateMessageRequest;
319        use crate::types::CompactResult;
320
321        if self.messages.len() <= keep_messages {
322            return Ok(CompactResult::NotNeeded);
323        }
324
325        let tokens_before = self.current_input_tokens;
326        let original_count = self.messages.len();
327        let split_point = original_count - keep_messages;
328        let to_summarize: Vec<_> = self.messages[..split_point].to_vec();
329        let to_keep: Vec<_> = self.messages[split_point..].to_vec();
330
331        let summary_prompt = Self::format_for_summary(&to_summarize);
332        let model = client.adapter().model(ModelType::Small).to_string();
333        let request = CreateMessageRequest::new(&model, vec![Message::user(&summary_prompt)])
334            .max_tokens(2000);
335        let response = client.send(request).await?;
336        let summary = response.text();
337
338        // Build new message list before modifying self (swap pattern for data safety)
339        let mut new_messages = Vec::with_capacity(1 + to_keep.len());
340        let summary_msg = SessionMessage::user(vec![ContentBlock::text(format!(
341            "[Previous conversation summary]\n{}",
342            summary
343        ))])
344        .as_compact_summary();
345        new_messages.push(summary_msg);
346
347        let mut new_leaf_id = Some(new_messages[0].id.clone());
348        for mut msg in to_keep {
349            msg.parent_id = new_leaf_id.clone();
350            new_leaf_id = Some(msg.id.clone());
351            new_messages.push(msg);
352        }
353
354        // Atomic swap: replace old state only after new state is fully constructed
355        self.messages = new_messages;
356        self.current_leaf_id = new_leaf_id;
357        // Reset to 0: actual value will be set by next API call's update_usage().
358        // This also prevents immediate re-compaction since should_compact() returns false when 0.
359        self.current_input_tokens = 0;
360        self.summary = Some(summary.clone());
361        self.updated_at = Utc::now();
362
363        let record = CompactRecord::new(self.id)
364            .counts(original_count, self.messages.len())
365            .summary(summary.clone())
366            .saved_tokens(tokens_before as usize);
367        self.record_compact(record);
368
369        Ok(CompactResult::Compacted {
370            original_count,
371            new_count: self.messages.len(),
372            saved_tokens: tokens_before as usize,
373            summary,
374        })
375    }
376
377    fn format_for_summary(messages: &[SessionMessage]) -> String {
378        let estimated_capacity = messages.len() * 500 + 200;
379        let mut formatted = String::with_capacity(estimated_capacity.min(32768));
380        formatted.push_str(
381            "Summarize this conversation concisely. \
382             Preserve key decisions, code changes, file paths, and important context:\n\n",
383        );
384
385        for msg in messages {
386            let role = match msg.role {
387                Role::User => "User",
388                Role::Assistant => "Assistant",
389            };
390            formatted.push_str(role);
391            formatted.push_str(":\n");
392
393            for block in &msg.content {
394                if let Some(text) = block.as_text() {
395                    if text.len() > 800 {
396                        let mut end = 800;
397                        while !text.is_char_boundary(end) {
398                            end -= 1;
399                        }
400                        formatted.push_str(&text[..end]);
401                        formatted.push_str("... [truncated]\n");
402                    } else {
403                        formatted.push_str(text);
404                        formatted.push('\n');
405                    }
406                }
407            }
408            formatted.push('\n');
409        }
410
411        formatted
412    }
413
414    pub fn clear_messages(&mut self) {
415        self.messages.clear();
416        self.current_leaf_id = None;
417        self.updated_at = Utc::now();
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    #[test]
    fn test_session_creation() {
        let config = SessionConfig::default();
        let session = Session::new(config);

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
        assert!(session.current_leaf_id.is_none());
    }

    #[test]
    fn test_subagent_session() {
        let parent = Session::new(SessionConfig::default());
        let child = Session::new_subagent(
            parent.id,
            "explore",
            "Find config files",
            SessionConfig::default(),
        );

        assert!(child.is_subagent());
        assert!(!parent.is_subagent());
        assert!(child.parent_id.is_some());
        assert!(parent.parent_id.is_none());
    }

    #[test]
    fn test_state_predicates() {
        let mut session = Session::new(SessionConfig::default());
        assert!(!session.is_running());
        assert!(!session.is_terminal());

        session.set_state(SessionState::WaitingForTools);
        assert!(session.is_running());
        assert!(!session.is_terminal());

        session.set_state(SessionState::Cancelled);
        assert!(session.is_terminal());
        assert!(!session.is_running());
    }

    #[test]
    fn test_add_message() {
        let mut session = Session::new(SessionConfig::default());

        let msg1 = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        session.add_message(msg1);

        assert_eq!(session.messages.len(), 1);
        assert!(session.current_leaf_id.is_some());
    }

    #[test]
    fn test_add_message_links_to_previous_leaf() {
        let mut session = Session::new(SessionConfig::default());
        session.add_user_message("Hello");
        let first_id = session.current_leaf_id.clone().unwrap();

        session.add_message(SessionMessage::assistant(vec![ContentBlock::text("Hi")]));

        assert_eq!(session.messages[1].parent_id, Some(first_id));
        assert_eq!(
            session.current_leaf_id.as_ref(),
            Some(&session.messages[1].id)
        );
    }

    #[test]
    fn test_message_tree() {
        let mut session = Session::new(SessionConfig::default());

        let user_msg = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        session.add_message(user_msg);

        let assistant_msg = SessionMessage::assistant(vec![ContentBlock::text("Hi there!")]);
        session.add_message(assistant_msg);

        let branch = session.current_branch();
        assert_eq!(branch.len(), 2);
        assert_eq!(branch[0].role, Role::User);
        assert_eq!(branch[1].role, Role::Assistant);
    }

    #[test]
    fn test_clear_messages_resets_branch() {
        let mut session = Session::new(SessionConfig::default());
        session.add_user_message("Hello");

        session.clear_messages();

        assert!(session.messages.is_empty());
        assert!(session.current_leaf_id.is_none());
        assert!(session.current_branch().is_empty());
    }

    #[test]
    fn test_session_expiry() {
        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = Session::new(config);

        // is_expired uses a strict `>` comparison, so let the clock move past expiry.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(session.is_expired());
    }

    #[test]
    fn test_token_usage_accumulation() {
        let mut session = Session::new(SessionConfig::default());

        let msg1 =
            SessionMessage::assistant(vec![ContentBlock::text("Response 1")]).usage(TokenUsage {
                input_tokens: 100,
                output_tokens: 50,
                ..Default::default()
            });
        session.add_message(msg1);

        let msg2 =
            SessionMessage::assistant(vec![ContentBlock::text("Response 2")]).usage(TokenUsage {
                input_tokens: 150,
                output_tokens: 75,
                ..Default::default()
            });
        session.add_message(msg2);

        assert_eq!(session.total_usage.input_tokens, 250);
        assert_eq!(session.total_usage.output_tokens, 125);
    }

    #[test]
    fn test_should_compact_requires_tokens_and_messages() {
        let mut session = Session::new(SessionConfig::default());
        for i in 0..5 {
            session.add_user_message(format!("msg {i}"));
        }
        session.current_input_tokens = 900;

        assert!(session.should_compact(1000, 0.8, 2));
        assert!(!session.should_compact(1000, 0.95, 2));
        assert!(!session.should_compact(1000, 0.8, 5));
    }

    #[test]
    fn test_compact_history_limit() {
        let mut session = Session::new(SessionConfig::default());

        for i in 0..MAX_COMPACT_HISTORY_SIZE + 10 {
            let record = CompactRecord::new(session.id).summary(format!("Summary {}", i));
            session.record_compact(record);
        }

        assert_eq!(session.compact_history.len(), MAX_COMPACT_HISTORY_SIZE);
        // The 10 oldest records were evicted, so the first survivor is #10.
        assert_eq!(session.compact_history[0].summary, "Summary 10");
        assert_eq!(
            session.compact_history[MAX_COMPACT_HISTORY_SIZE - 1].summary,
            format!("Summary {}", MAX_COMPACT_HISTORY_SIZE + 9)
        );
    }

    #[test]
    fn test_exit_plan_mode_takes_ownership() {
        let mut session = Session::new(SessionConfig::default());
        session.enter_plan_mode(Some("Test Plan".to_string()));

        let plan = session.exit_plan_mode();
        assert!(plan.is_some());
        assert!(session.current_plan.is_none());
    }

    #[test]
    fn test_cancel_plan_leaves_plan_mode() {
        let mut session = Session::new(SessionConfig::default());
        assert!(session.cancel_plan().is_none());

        session.enter_plan_mode(None);
        assert!(session.is_in_plan_mode());

        assert!(session.cancel_plan().is_some());
        assert!(!session.is_in_plan_mode());
        assert!(session.current_plan.is_none());
    }

    #[test]
    fn test_message_caching_applies_to_last_user_turn() {
        let mut session = Session::new(SessionConfig::default());

        session.add_user_message("First question");
        session.add_message(SessionMessage::assistant(vec![ContentBlock::text(
            "First answer",
        )]));
        session.add_user_message("Second question");

        let messages = session.to_api_messages();

        assert_eq!(messages.len(), 3);
        assert!(!messages[0].has_cache_control());
        assert!(!messages[1].has_cache_control());
        assert!(messages[2].has_cache_control());
    }

    #[test]
    fn test_message_caching_disabled() {
        let mut session = Session::new(SessionConfig::default());

        session.add_user_message("Question");

        // Pass None to disable caching
        let messages = session.to_api_messages_with_cache(None);

        assert_eq!(messages.len(), 1);
        assert!(!messages[0].has_cache_control());
    }

    #[test]
    fn test_message_caching_empty_session() {
        let session = Session::new(SessionConfig::default());
        let messages = session.to_api_messages();
        assert!(messages.is_empty());
    }

    #[test]
    fn test_message_caching_assistant_only() {
        let mut session = Session::new(SessionConfig::default());
        session.add_message(SessionMessage::assistant(vec![ContentBlock::text("Hi")]));

        let messages = session.to_api_messages();

        assert_eq!(messages.len(), 1);
        assert!(!messages[0].has_cache_control());
    }
}