claude_agent/session/state/
mod.rs1mod config;
4mod enums;
5mod ids;
6mod message;
7mod policy;
8
9pub use config::SessionConfig;
10pub use enums::{SessionMode, SessionState, SessionType};
11pub use ids::{MessageId, SessionId};
12pub use message::{MessageMetadata, SessionMessage, ThinkingMetadata, ToolResultMeta};
13pub use policy::{PermissionMode, PermissionPolicy, ToolLimits};
14
15use std::collections::HashMap;
16
17use chrono::{DateTime, Utc};
18use serde::{Deserialize, Serialize};
19
20use crate::session::types::{CompactRecord, Plan, TodoItem, TodoStatus};
21use crate::types::{CacheControl, CacheTtl, ContentBlock, Message, Role, TokenUsage, Usage};
22
/// Cap on the number of [`CompactRecord`]s kept in `Session::compact_history`;
/// `Session::record_compact` evicts the oldest record(s) to stay within it.
const MAX_COMPACT_HISTORY_SIZE: usize = 50;
24
/// A conversation session: configuration, lifecycle state, the message tree,
/// token/usage accounting and auxiliary state (todos, plan, compaction history).
///
/// Messages are stored flat in `messages`; the conversation tree is formed by
/// each message's `parent_id`, and `current_leaf_id` marks the tip of the
/// active branch.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Session {
    /// Identifier of this session.
    pub id: SessionId,
    /// Parent session id; set for sessions built with `Session::new_subagent`, `None` otherwise.
    pub parent_id: Option<SessionId>,
    /// Whether this is the main session or a subagent (with its type/description).
    pub session_type: SessionType,
    /// Tenant identifier; initialised to `None` and not assigned in this module.
    pub tenant_id: Option<String>,
    /// Session mode, copied from `config.mode` at construction.
    pub mode: SessionMode,
    /// Lifecycle state; starts as `SessionState::Created` and changes via `set_state`.
    pub state: SessionState,
    /// Configuration the session was created with.
    pub config: SessionConfig,
    /// Permission policy, copied from `config.permission_policy` at construction.
    pub permission_policy: PermissionPolicy,
    /// All stored messages (every branch), in insertion order.
    pub messages: Vec<SessionMessage>,
    /// Id of the newest message on the active branch; `None` when empty.
    pub current_leaf_id: Option<MessageId>,
    /// Latest summary, set by `update_summary` or by `compact`.
    pub summary: Option<String>,
    /// Usage accumulated from message usage (`add_message`) and `update_usage`.
    pub total_usage: TokenUsage,
    /// Input tokens reported by the most recent `update_usage` call; reset to 0
    /// by `compact` and compared against the limit in `should_compact`.
    #[serde(default)]
    pub current_input_tokens: u64,
    /// Accumulated cost in USD; initialised to 0.0 and not updated in this module.
    pub total_cost_usd: f64,
    /// Hash of the static context; initialised to `None` and not set in this module.
    pub static_context_hash: Option<String>,
    /// Construction time.
    pub created_at: DateTime<Utc>,
    /// Time of the last mutation made through this type's methods.
    pub updated_at: DateTime<Utc>,
    /// `created_at + config.ttl_secs` when a TTL is configured; see `is_expired`.
    pub expires_at: Option<DateTime<Utc>>,
    /// Error description; initialised to `None` and not set in this module.
    pub error: Option<String>,
    /// Current todo list, replaced wholesale by `set_todos`.
    #[serde(default)]
    pub todos: Vec<TodoItem>,
    /// Active plan while in plan mode (see `enter_plan_mode` / `exit_plan_mode`).
    #[serde(default)]
    pub current_plan: Option<Plan>,
    /// History of compactions, capped at `MAX_COMPACT_HISTORY_SIZE` by `record_compact`.
    #[serde(default)]
    pub compact_history: Vec<CompactRecord>,
}
54
55impl Session {
56 pub fn new(config: SessionConfig) -> Self {
57 Self::with_id(SessionId::new(), config)
58 }
59
60 pub fn with_id(id: SessionId, config: SessionConfig) -> Self {
61 Self::init(id, None, SessionType::Main, config)
62 }
63
64 pub fn new_subagent(
65 parent_id: SessionId,
66 agent_type: impl Into<String>,
67 description: impl Into<String>,
68 config: SessionConfig,
69 ) -> Self {
70 let session_type = SessionType::Subagent {
71 agent_type: agent_type.into(),
72 description: description.into(),
73 };
74 Self::init(SessionId::new(), Some(parent_id), session_type, config)
75 }
76
77 fn init(
78 id: SessionId,
79 parent_id: Option<SessionId>,
80 session_type: SessionType,
81 config: SessionConfig,
82 ) -> Self {
83 let now = Utc::now();
84 let expires_at = config
85 .ttl_secs
86 .map(|ttl| now + chrono::Duration::seconds(ttl as i64));
87
88 Self {
89 id,
90 parent_id,
91 session_type,
92 tenant_id: None,
93 mode: config.mode.clone(),
94 state: SessionState::Created,
95 permission_policy: config.permission_policy.clone(),
96 config,
97 messages: Vec::with_capacity(32),
98 current_leaf_id: None,
99 summary: None,
100 total_usage: TokenUsage::default(),
101 current_input_tokens: 0,
102 total_cost_usd: 0.0,
103 static_context_hash: None,
104 created_at: now,
105 updated_at: now,
106 expires_at,
107 error: None,
108 todos: Vec::with_capacity(8),
109 current_plan: None,
110 compact_history: Vec::new(),
111 }
112 }
113
114 pub fn is_subagent(&self) -> bool {
115 matches!(self.session_type, SessionType::Subagent { .. })
116 }
117
118 pub fn is_running(&self) -> bool {
119 matches!(
120 self.state,
121 SessionState::Active | SessionState::WaitingForTools
122 )
123 }
124
125 pub fn is_terminal(&self) -> bool {
126 matches!(
127 self.state,
128 SessionState::Completed | SessionState::Failed | SessionState::Cancelled
129 )
130 }
131
132 pub fn is_expired(&self) -> bool {
133 self.expires_at.is_some_and(|expires| Utc::now() > expires)
134 }
135
136 pub fn add_message(&mut self, mut message: SessionMessage) {
137 if let Some(leaf) = &self.current_leaf_id {
138 message.parent_id = Some(leaf.clone());
139 }
140 self.current_leaf_id = Some(message.id.clone());
141 if let Some(usage) = &message.usage {
142 self.total_usage.add(usage);
143 }
144 self.messages.push(message);
145 self.updated_at = Utc::now();
146 }
147
148 pub fn get_current_branch(&self) -> Vec<&SessionMessage> {
149 let index: HashMap<&MessageId, &SessionMessage> =
150 self.messages.iter().map(|m| (&m.id, m)).collect();
151
152 let mut result = Vec::new();
153 let mut current_id = self.current_leaf_id.as_ref();
154
155 while let Some(id) = current_id {
156 if let Some(&msg) = index.get(id) {
157 result.push(msg);
158 current_id = msg.parent_id.as_ref();
159 } else {
160 break;
161 }
162 }
163
164 result.reverse();
165 result
166 }
167
168 pub fn to_api_messages(&self) -> Vec<Message> {
170 self.to_api_messages_with_cache(Some(CacheTtl::FiveMinutes))
171 }
172
173 pub fn to_api_messages_with_cache(&self, ttl: Option<CacheTtl>) -> Vec<Message> {
178 let branch = self.get_current_branch();
179 if branch.is_empty() {
180 return Vec::new();
181 }
182
183 let mut messages: Vec<Message> = branch.iter().map(|m| m.to_api_message()).collect();
184
185 if let Some(ttl) = ttl {
186 self.apply_cache_breakpoint(&mut messages, ttl);
187 }
188
189 messages
190 }
191
192 fn apply_cache_breakpoint(&self, messages: &mut [Message], ttl: CacheTtl) {
198 let last_user_idx = messages
199 .iter()
200 .enumerate()
201 .rev()
202 .find(|(_, m)| m.role == Role::User)
203 .map(|(i, _)| i);
204
205 if let Some(idx) = last_user_idx {
206 messages[idx].set_cache_on_last_block(CacheControl::ephemeral().with_ttl(ttl));
207 }
208 }
209
210 pub fn branch_length(&self) -> usize {
211 self.get_current_branch().len()
212 }
213
214 pub fn set_state(&mut self, state: SessionState) {
215 self.state = state;
216 self.updated_at = Utc::now();
217 }
218
219 pub fn set_todos(&mut self, todos: Vec<TodoItem>) {
220 self.todos = todos;
221 self.updated_at = Utc::now();
222 }
223
224 pub fn todos_in_progress_count(&self) -> usize {
225 self.todos
226 .iter()
227 .filter(|t| t.status == TodoStatus::InProgress)
228 .count()
229 }
230
231 pub fn enter_plan_mode(&mut self, name: Option<String>) -> &Plan {
232 let mut plan = Plan::new(self.id);
233 if let Some(n) = name {
234 plan = plan.with_name(n);
235 }
236 self.current_plan = Some(plan);
237 self.updated_at = Utc::now();
238 self.current_plan.as_ref().expect("plan was just set")
239 }
240
241 pub fn update_plan_content(&mut self, content: String) {
242 if let Some(ref mut plan) = self.current_plan {
243 plan.content = content;
244 self.updated_at = Utc::now();
245 }
246 }
247
248 pub fn exit_plan_mode(&mut self) -> Option<Plan> {
249 if let Some(ref mut plan) = self.current_plan {
250 plan.approve();
251 self.updated_at = Utc::now();
252 }
253 self.current_plan.take()
254 }
255
256 pub fn cancel_plan(&mut self) -> Option<Plan> {
257 if let Some(ref mut plan) = self.current_plan {
258 plan.cancel();
259 self.updated_at = Utc::now();
260 }
261 self.current_plan.take()
262 }
263
264 pub fn is_in_plan_mode(&self) -> bool {
265 self.current_plan
266 .as_ref()
267 .is_some_and(|p| !p.status.is_terminal())
268 }
269
270 pub fn record_compact(&mut self, record: CompactRecord) {
271 if self.compact_history.len() >= MAX_COMPACT_HISTORY_SIZE {
272 self.compact_history.remove(0);
273 }
274 self.compact_history.push(record);
275 self.updated_at = Utc::now();
276 }
277
278 pub fn update_summary(&mut self, summary: impl Into<String>) {
279 self.summary = Some(summary.into());
280 self.updated_at = Utc::now();
281 }
282
283 pub fn add_user_message(&mut self, content: impl Into<String>) {
284 let msg = SessionMessage::user(vec![ContentBlock::text(content.into())]);
285 self.add_message(msg);
286 }
287
288 pub fn add_assistant_message(&mut self, content: Vec<ContentBlock>, usage: Option<Usage>) {
289 let mut msg = SessionMessage::assistant(content);
290 if let Some(u) = usage {
291 msg = msg.with_usage(TokenUsage {
292 input_tokens: u.input_tokens as u64,
293 output_tokens: u.output_tokens as u64,
294 cache_read_input_tokens: u.cache_read_input_tokens.unwrap_or(0) as u64,
295 cache_creation_input_tokens: u.cache_creation_input_tokens.unwrap_or(0) as u64,
296 });
297 }
298 self.add_message(msg);
299 }
300
301 pub fn add_tool_results(&mut self, results: Vec<crate::types::ToolResultBlock>) {
302 let content: Vec<ContentBlock> =
303 results.into_iter().map(ContentBlock::ToolResult).collect();
304 let msg = SessionMessage::user(content);
305 self.add_message(msg);
306 }
307
308 pub fn current_tokens(&self) -> u64 {
309 self.current_input_tokens
310 }
311
312 pub fn should_compact(&self, max_tokens: u64, threshold: f32, keep_messages: usize) -> bool {
313 self.messages.len() > keep_messages
314 && self.current_input_tokens as f32 > max_tokens as f32 * threshold
315 }
316
317 pub fn update_usage(&mut self, usage: &Usage) {
318 self.current_input_tokens = usage.input_tokens as u64;
319 self.total_usage.input_tokens += usage.input_tokens as u64;
320 self.total_usage.output_tokens += usage.output_tokens as u64;
321 if let Some(cache_read) = usage.cache_read_input_tokens {
322 self.total_usage.cache_read_input_tokens += cache_read as u64;
323 }
324 if let Some(cache_creation) = usage.cache_creation_input_tokens {
325 self.total_usage.cache_creation_input_tokens += cache_creation as u64;
326 }
327 }
328
329 pub async fn compact(
330 &mut self,
331 client: &crate::Client,
332 keep_messages: usize,
333 ) -> crate::Result<crate::types::CompactResult> {
334 use crate::client::ModelType;
335 use crate::client::messages::CreateMessageRequest;
336 use crate::types::CompactResult;
337
338 if self.messages.len() <= keep_messages {
339 return Ok(CompactResult::NotNeeded);
340 }
341
342 let tokens_before = self.current_input_tokens;
343 let split_point = self.messages.len() - keep_messages;
344 let to_summarize: Vec<_> = self.messages[..split_point].to_vec();
345 let to_keep: Vec<_> = self.messages[split_point..].to_vec();
346
347 let summary_prompt = Self::format_for_summary(&to_summarize);
348 let model = client.adapter().model(ModelType::Small).to_string();
349 let request = CreateMessageRequest::new(&model, vec![Message::user(&summary_prompt)])
350 .with_max_tokens(2000);
351 let response = client.send(request).await?;
352 let summary = response.text();
353
354 let original_count = self.messages.len();
355
356 self.messages.clear();
357 self.current_leaf_id = None;
358
359 let summary_msg = SessionMessage::user(vec![ContentBlock::text(format!(
360 "[Previous conversation summary]\n{}",
361 summary
362 ))])
363 .as_compact_summary();
364 self.add_message(summary_msg);
365
366 for mut msg in to_keep {
367 msg.parent_id = self.current_leaf_id.clone();
368 self.current_leaf_id = Some(msg.id.clone());
369 self.messages.push(msg);
370 }
371
372 self.current_input_tokens = 0;
375 self.summary = Some(summary.clone());
376 self.updated_at = Utc::now();
377
378 let record = CompactRecord::new(self.id)
379 .with_counts(original_count, self.messages.len())
380 .with_summary(summary.clone())
381 .with_saved_tokens(tokens_before as usize);
382 self.record_compact(record);
383
384 Ok(CompactResult::Compacted {
385 original_count,
386 new_count: self.messages.len(),
387 saved_tokens: tokens_before as usize,
388 summary,
389 })
390 }
391
392 fn format_for_summary(messages: &[SessionMessage]) -> String {
393 let estimated_capacity = messages.len() * 500 + 200;
394 let mut formatted = String::with_capacity(estimated_capacity.min(32768));
395 formatted.push_str(
396 "Summarize this conversation concisely. \
397 Preserve key decisions, code changes, file paths, and important context:\n\n",
398 );
399
400 for msg in messages {
401 let role = match msg.role {
402 Role::User => "User",
403 Role::Assistant => "Assistant",
404 };
405 formatted.push_str(role);
406 formatted.push_str(":\n");
407
408 for block in &msg.content {
409 if let Some(text) = block.as_text() {
410 if text.len() > 800 {
411 formatted.push_str(&text[..800]);
412 formatted.push_str("... [truncated]\n");
413 } else {
414 formatted.push_str(text);
415 formatted.push('\n');
416 }
417 }
418 }
419 formatted.push('\n');
420 }
421
422 formatted
423 }
424
425 pub fn clear_messages(&mut self) {
426 self.messages.clear();
427 self.current_leaf_id = None;
428 self.updated_at = Utc::now();
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    /// Shorthand for a session built from the default configuration.
    fn default_session() -> Session {
        Session::new(SessionConfig::default())
    }

    /// Collects whether each API message produced by `messages` carries cache control.
    fn cache_flags(messages: &[Message]) -> Vec<bool> {
        messages.iter().map(|m| m.has_cache_control()).collect()
    }

    #[test]
    fn test_session_creation() {
        let fresh = default_session();

        assert_eq!(fresh.state, SessionState::Created);
        assert!(fresh.messages.is_empty());
        assert!(fresh.current_leaf_id.is_none());
    }

    #[test]
    fn test_add_message() {
        let mut s = default_session();
        s.add_message(SessionMessage::user(vec![ContentBlock::text("Hello")]));

        assert_eq!(s.messages.len(), 1);
        assert!(s.current_leaf_id.is_some());
    }

    #[test]
    fn test_message_tree() {
        let mut s = default_session();
        s.add_message(SessionMessage::user(vec![ContentBlock::text("Hello")]));
        s.add_message(SessionMessage::assistant(vec![ContentBlock::text("Hi there!")]));

        let path = s.get_current_branch();
        assert_eq!(path.len(), 2);
        assert_eq!(path[0].role, Role::User);
        assert_eq!(path[1].role, Role::Assistant);
    }

    #[test]
    fn test_session_expiry() {
        // A zero-second TTL means the session is expired as soon as any time passes.
        let s = Session::new(SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        });

        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(s.is_expired());
    }

    #[test]
    fn test_token_usage_accumulation() {
        let mut s = default_session();

        for &(text, input, output) in &[("Response 1", 100, 50), ("Response 2", 150, 75)] {
            let usage = TokenUsage {
                input_tokens: input,
                output_tokens: output,
                ..Default::default()
            };
            s.add_message(SessionMessage::assistant(vec![ContentBlock::text(text)]).with_usage(usage));
        }

        assert_eq!(s.total_usage.input_tokens, 250);
        assert_eq!(s.total_usage.output_tokens, 125);
    }

    #[test]
    fn test_compact_history_limit() {
        let mut s = default_session();
        let owner = s.id;

        for i in 0..MAX_COMPACT_HISTORY_SIZE + 10 {
            let record = CompactRecord::new(owner).with_summary(format!("Summary {}", i));
            s.record_compact(record);
        }

        assert_eq!(s.compact_history.len(), MAX_COMPACT_HISTORY_SIZE);
        // The ten oldest records were evicted, so the survivors start at "Summary 10".
        assert!(s.compact_history[0].summary.contains("10"));
    }

    #[test]
    fn test_exit_plan_mode_takes_ownership() {
        let mut s = default_session();
        s.enter_plan_mode(Some("Test Plan".to_string()));

        assert!(s.exit_plan_mode().is_some());
        assert!(s.current_plan.is_none());
    }

    #[test]
    fn test_message_caching_applies_to_last_user_turn() {
        let mut s = default_session();
        s.add_user_message("First question");
        s.add_message(SessionMessage::assistant(vec![ContentBlock::text(
            "First answer",
        )]));
        s.add_user_message("Second question");

        let flags = cache_flags(&s.to_api_messages());

        // Only the final user turn carries the cache breakpoint.
        assert_eq!(flags, vec![false, false, true]);
    }

    #[test]
    fn test_message_caching_disabled() {
        let mut s = default_session();
        s.add_user_message("Question");

        let flags = cache_flags(&s.to_api_messages_with_cache(None));

        assert_eq!(flags, vec![false]);
    }

    #[test]
    fn test_message_caching_empty_session() {
        assert!(default_session().to_api_messages().is_empty());
    }

    #[test]
    fn test_message_caching_assistant_only() {
        let mut s = default_session();
        s.add_message(SessionMessage::assistant(vec![ContentBlock::text("Hi")]));

        let flags = cache_flags(&s.to_api_messages());

        assert_eq!(flags, vec![false]);
    }
}
583}