claude_agent/session/
session_state.rs

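//! Thread-safe, cheaply clonable per-session tool state: the session itself,
//! a bounded tool-execution log, the pending input queue, and an execution lock.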
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use tokio::sync::{RwLock, Semaphore};
8use uuid::Uuid;
9
10use super::queue::{MergedInput, QueueError, QueuedInput, SharedInputQueue};
11use super::state::{Session, SessionConfig, SessionId};
12use super::types::{CompactRecord, Plan, PlanStatus, TodoItem, ToolExecution};
13
/// Maximum number of entries kept in the tool-execution log; once the log is
/// full, appending a new entry evicts the oldest one.
const MAX_EXECUTION_LOG_SIZE: usize = 1000;
15
16#[derive(Debug)]
17struct ToolExecutionLog {
18    entries: RwLock<VecDeque<ToolExecution>>,
19}
20
21impl Default for ToolExecutionLog {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl ToolExecutionLog {
28    fn new() -> Self {
29        Self {
30            entries: RwLock::new(VecDeque::with_capacity(64)),
31        }
32    }
33
34    async fn append(&self, exec: ToolExecution) {
35        let mut entries = self.entries.write().await;
36        if entries.len() >= MAX_EXECUTION_LOG_SIZE {
37            entries.pop_front();
38        }
39        entries.push_back(exec);
40    }
41
42    async fn with_entries<F, R>(&self, f: F) -> R
43    where
44        F: FnOnce(&VecDeque<ToolExecution>) -> R,
45    {
46        let entries = self.entries.read().await;
47        f(&entries)
48    }
49
50    async fn for_plan(&self, plan_id: Uuid) -> Vec<ToolExecution> {
51        self.entries
52            .read()
53            .await
54            .iter()
55            .filter(|e| e.plan_id == Some(plan_id))
56            .cloned()
57            .collect()
58    }
59
60    async fn len(&self) -> usize {
61        self.entries.read().await.len()
62    }
63
64    async fn clear(&self) {
65        self.entries.write().await.clear();
66    }
67}
68
69struct ToolStateInner {
70    id: SessionId,
71    session: RwLock<Session>,
72    executions: ToolExecutionLog,
73    input_queue: SharedInputQueue,
74    execution_lock: Semaphore,
75    executing: AtomicBool,
76}
77
78impl std::fmt::Debug for ToolStateInner {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("ToolStateInner")
81            .field("id", &self.id)
82            .field("executions", &self.executions)
83            .field("executing", &self.executing.load(Ordering::Relaxed))
84            .finish_non_exhaustive()
85    }
86}
87
88impl ToolStateInner {
89    fn new(session_id: SessionId) -> Self {
90        Self {
91            id: session_id,
92            session: RwLock::new(Session::with_id(session_id, SessionConfig::default())),
93            executions: ToolExecutionLog::new(),
94            input_queue: SharedInputQueue::new(),
95            execution_lock: Semaphore::new(1),
96            executing: AtomicBool::new(false),
97        }
98    }
99
100    fn from_session(session: Session) -> Self {
101        let id = session.id;
102        Self {
103            id,
104            session: RwLock::new(session),
105            executions: ToolExecutionLog::new(),
106            input_queue: SharedInputQueue::new(),
107            execution_lock: Semaphore::new(1),
108            executing: AtomicBool::new(false),
109        }
110    }
111}
112
113/// Thread-safe tool state handle.
114#[derive(Debug, Clone)]
115pub struct ToolState(Arc<ToolStateInner>);
116
117impl ToolState {
118    pub fn new(session_id: SessionId) -> Self {
119        Self(Arc::new(ToolStateInner::new(session_id)))
120    }
121
122    pub fn from_session(session: Session) -> Self {
123        Self(Arc::new(ToolStateInner::from_session(session)))
124    }
125
126    #[inline]
127    pub fn session_id(&self) -> SessionId {
128        self.0.id
129    }
130
131    pub async fn session(&self) -> Session {
132        self.0.session.read().await.clone()
133    }
134
135    pub async fn update_session(&self, session: Session) {
136        *self.0.session.write().await = session;
137    }
138
139    pub async fn enter_plan_mode(&self, name: Option<String>) -> Plan {
140        self.0.session.write().await.enter_plan_mode(name).clone()
141    }
142
143    pub async fn current_plan(&self) -> Option<Plan> {
144        self.0.session.read().await.current_plan.clone()
145    }
146
147    pub async fn update_plan_content(&self, content: String) {
148        self.0.session.write().await.update_plan_content(content);
149    }
150
151    pub async fn exit_plan_mode(&self) -> Option<Plan> {
152        self.0.session.write().await.exit_plan_mode()
153    }
154
155    pub async fn cancel_plan(&self) -> Option<Plan> {
156        self.0.session.write().await.cancel_plan()
157    }
158
159    #[inline]
160    pub async fn is_in_plan_mode(&self) -> bool {
161        self.0.session.read().await.is_in_plan_mode()
162    }
163
164    pub async fn set_todos(&self, todos: Vec<TodoItem>) {
165        self.0.session.write().await.set_todos(todos);
166    }
167
168    pub async fn todos(&self) -> Vec<TodoItem> {
169        self.0.session.read().await.todos.clone()
170    }
171
172    #[inline]
173    pub async fn todos_in_progress_count(&self) -> usize {
174        self.0.session.read().await.todos_in_progress_count()
175    }
176
177    pub async fn record_tool_execution(&self, mut exec: ToolExecution) {
178        let plan_id = {
179            let session = self.0.session.read().await;
180            if let Some(ref plan) = session.current_plan
181                && plan.status == PlanStatus::Executing
182            {
183                Some(plan.id)
184            } else {
185                None
186            }
187        };
188        exec.plan_id = plan_id;
189        self.0.executions.append(exec).await;
190    }
191
192    pub async fn with_tool_executions<F, R>(&self, f: F) -> R
193    where
194        F: FnOnce(&VecDeque<ToolExecution>) -> R,
195    {
196        self.0.executions.with_entries(f).await
197    }
198
199    pub async fn tool_executions_for_plan(&self, plan_id: Uuid) -> Vec<ToolExecution> {
200        self.0.executions.for_plan(plan_id).await
201    }
202
203    pub async fn execution_log_len(&self) -> usize {
204        self.0.executions.len().await
205    }
206
207    pub async fn clear_execution_log(&self) {
208        self.0.executions.clear().await;
209    }
210
211    pub async fn record_compact(&self, record: CompactRecord) {
212        self.0.session.write().await.record_compact(record);
213    }
214
215    pub async fn with_compact_history<F, R>(&self, f: F) -> R
216    where
217        F: FnOnce(&[CompactRecord]) -> R,
218    {
219        let session = self.0.session.read().await;
220        f(&session.compact_history)
221    }
222
223    #[inline]
224    pub async fn session_snapshot(&self) -> (SessionId, usize, Option<Plan>) {
225        let session = self.0.session.read().await;
226        (
227            session.id,
228            session.todos.len(),
229            session.current_plan.clone(),
230        )
231    }
232
233    #[inline]
234    pub async fn execution_state(&self) -> (SessionId, bool, usize) {
235        let session = self.0.session.read().await;
236        (
237            session.id,
238            session.is_in_plan_mode(),
239            session.todos_in_progress_count(),
240        )
241    }
242
243    pub async fn record_execution_with_todos(
244        &self,
245        mut exec: ToolExecution,
246        todos: Option<Vec<TodoItem>>,
247    ) {
248        let plan_id = {
249            let mut session = self.0.session.write().await;
250            let plan_id = if let Some(ref plan) = session.current_plan
251                && plan.status == PlanStatus::Executing
252            {
253                Some(plan.id)
254            } else {
255                None
256            };
257            if let Some(todos) = todos {
258                session.set_todos(todos);
259            }
260            plan_id
261        };
262        exec.plan_id = plan_id;
263        self.0.executions.append(exec).await;
264    }
265
266    pub async fn enqueue(&self, content: impl Into<String>) -> Result<Uuid, QueueError> {
267        let input = QueuedInput::new(self.session_id(), content);
268        self.0.input_queue.enqueue(input).await
269    }
270
271    pub async fn dequeue_or_merge(&self) -> Option<MergedInput> {
272        self.0.input_queue.merge_all().await
273    }
274
275    pub async fn pending_count(&self) -> usize {
276        self.0.input_queue.pending_count().await
277    }
278
279    pub async fn cancel_pending(&self, id: Uuid) -> bool {
280        self.0.input_queue.cancel(id).await.is_some()
281    }
282
283    pub async fn cancel_all_pending(&self) -> usize {
284        self.0.input_queue.cancel_all().await.len()
285    }
286
287    pub fn is_executing(&self) -> bool {
288        self.0.executing.load(Ordering::Acquire)
289    }
290
291    pub async fn acquire_execution(&self) -> ExecutionGuard<'_> {
292        let permit = self
293            .0
294            .execution_lock
295            .acquire()
296            .await
297            .expect("semaphore should not be closed");
298        self.0.executing.store(true, Ordering::Release);
299        ExecutionGuard {
300            permit,
301            executing: &self.0.executing,
302        }
303    }
304
305    pub fn try_acquire_execution(&self) -> Option<ExecutionGuard<'_>> {
306        self.0.execution_lock.try_acquire().ok().map(|permit| {
307            self.0.executing.store(true, Ordering::Release);
308            ExecutionGuard {
309                permit,
310                executing: &self.0.executing,
311            }
312        })
313    }
314
315    pub async fn with_session<F, R>(&self, f: F) -> R
316    where
317        F: FnOnce(&Session) -> R,
318    {
319        let session = self.0.session.read().await;
320        f(&session)
321    }
322
323    pub async fn with_session_mut<F, R>(&self, f: F) -> R
324    where
325        F: FnOnce(&mut Session) -> R,
326    {
327        let mut session = self.0.session.write().await;
328        f(&mut session)
329    }
330
331    pub async fn compact(
332        &self,
333        client: &crate::Client,
334        keep_messages: usize,
335    ) -> crate::Result<crate::types::CompactResult> {
336        let mut session = self.0.session.write().await;
337        session.compact(client, keep_messages).await
338    }
339}
340
341pub struct ExecutionGuard<'a> {
342    #[allow(dead_code)]
343    permit: tokio::sync::SemaphorePermit<'a>,
344    executing: &'a AtomicBool,
345}
346
347impl Drop for ExecutionGuard<'_> {
348    fn drop(&mut self) {
349        self.executing.store(false, Ordering::Release);
350    }
351}
352
353impl Default for ToolState {
354    fn default() -> Self {
355        Self::new(SessionId::default())
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// A plan starts as a draft and comes back approved when plan mode is exited.
    #[tokio::test]
    async fn test_plan_lifecycle() {
        let state = ToolState::new(SessionId::new());

        let plan = state.enter_plan_mode(Some("Test Plan".to_string())).await;
        assert_eq!(plan.status, PlanStatus::Draft);
        assert!(state.is_in_plan_mode().await);

        state
            .update_plan_content("Step 1\nStep 2".to_string())
            .await;

        let approved = state
            .exit_plan_mode()
            .await
            .expect("exiting plan mode should return the plan");
        assert_eq!(approved.status, PlanStatus::Approved);
    }

    /// Todos written through the handle can be read back.
    #[tokio::test]
    async fn test_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![
            TodoItem::new(session_id, "Task 1", "Doing task 1"),
            TodoItem::new(session_id, "Task 2", "Doing task 2"),
        ];

        state.set_todos(todos).await;
        let loaded = state.todos().await;
        assert_eq!(loaded.len(), 2);
    }

    /// A recorded execution shows up in the log with its tool name intact.
    #[tokio::test]
    async fn test_tool_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({"command": "ls"}))
            .with_output("file1\nfile2", false)
            .with_duration(100);

        state.record_tool_execution(exec).await;

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 1);

        let name = state
            .with_tool_executions(|e| e.front().map(|x| x.tool_name.clone()))
            .await;
        assert_eq!(name, Some("Bash".to_string()));
    }

    /// The session clone returned by `session()` carries todos and the plan.
    #[tokio::test]
    async fn test_session_persistence_ready() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        state.set_todos(todos).await;
        state.enter_plan_mode(Some("My Plan".to_string())).await;

        let session = state.session().await;
        assert_eq!(session.todos.len(), 1);
        let plan = session
            .current_plan
            .expect("plan mode was entered, so a plan should exist");
        assert_eq!(plan.name, Some("My Plan".to_string()));
    }

    /// State built from an existing session exposes that session's data.
    #[tokio::test]
    async fn test_resume_from_session() {
        let session_id = SessionId::new();

        let mut session = Session::new(SessionConfig::default());
        session.id = session_id;
        session.set_todos(vec![TodoItem::new(
            session_id,
            "Resumed Task",
            "Working on it",
        )]);
        session.enter_plan_mode(Some("Resumed Plan".to_string()));

        let state = ToolState::from_session(session);

        let todos = state.todos().await;
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].content, "Resumed Task");

        let plan = state
            .current_plan()
            .await
            .expect("resumed session should still have its plan");
        assert_eq!(plan.name, Some("Resumed Plan".to_string()));
    }

    /// Concurrent writers all land in the log.
    #[tokio::test]
    async fn test_concurrent_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let state = state.clone();
                let sid = session_id;
                tokio::spawn(async move {
                    let exec =
                        ToolExecution::new(sid, format!("Tool{}", i), serde_json::json!({"id": i}));
                    state.record_tool_execution(exec).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 10);
    }

    /// The log never grows past its cap and evicts oldest-first.
    #[tokio::test]
    async fn test_execution_log_limit() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        for i in 0..MAX_EXECUTION_LOG_SIZE + 100 {
            let exec = ToolExecution::new(session_id, format!("Tool{}", i), serde_json::json!({}));
            state.record_tool_execution(exec).await;
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, MAX_EXECUTION_LOG_SIZE);

        // Exactly the first 100 entries were evicted, so the oldest survivor is
        // `Tool100`. A substring check would also accept e.g. `Tool1000`.
        let first_name = state
            .with_tool_executions(|e| e.front().map(|x| x.tool_name.clone()))
            .await;
        assert_eq!(first_name, Some("Tool100".to_string()));
    }

    /// The guard raises the executing flag, excludes other acquirers, and
    /// restores both on drop.
    #[tokio::test]
    async fn test_execution_guard_toggles_flag() {
        let state = ToolState::new(SessionId::new());
        assert!(!state.is_executing());

        {
            let _guard = state.acquire_execution().await;
            assert!(state.is_executing());
            assert!(state.try_acquire_execution().is_none());
        }

        assert!(!state.is_executing());
        assert!(state.try_acquire_execution().is_some());
        assert!(!state.is_executing());
    }
}