Skip to main content

claude_agent/session/
session_state.rs

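//! Shared, thread-safe tool state for one session: the session itself, a bounded tool-execution log, the pending-input queue, and an exclusive execution lock.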
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use tokio::sync::{Notify, RwLock, Semaphore};
8use uuid::Uuid;
9
10use super::queue::{MergedInput, QueueError, QueuedInput, SharedInputQueue};
11use super::state::{Session, SessionConfig, SessionId};
12use super::types::{CompactRecord, Plan, PlanStatus, TodoItem, ToolExecution};
13
/// Maximum number of tool executions retained in the log; appending to a full
/// log evicts the oldest record first.
const MAX_EXECUTION_LOG_SIZE: usize = 1_000;
15
16#[derive(Debug)]
17struct ToolExecutionLog {
18    entries: RwLock<VecDeque<ToolExecution>>,
19}
20
21impl Default for ToolExecutionLog {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl ToolExecutionLog {
28    fn new() -> Self {
29        Self {
30            entries: RwLock::new(VecDeque::with_capacity(64)),
31        }
32    }
33
34    async fn append(&self, exec: ToolExecution) {
35        let mut entries = self.entries.write().await;
36        if entries.len() >= MAX_EXECUTION_LOG_SIZE {
37            entries.pop_front();
38        }
39        entries.push_back(exec);
40    }
41
42    async fn with_entries<F, R>(&self, f: F) -> R
43    where
44        F: FnOnce(&VecDeque<ToolExecution>) -> R,
45    {
46        let entries = self.entries.read().await;
47        f(&entries)
48    }
49
50    async fn len(&self) -> usize {
51        self.entries.read().await.len()
52    }
53
54    async fn clear(&self) {
55        self.entries.write().await.clear();
56    }
57}
58
59struct ToolStateInner {
60    id: SessionId,
61    session: RwLock<Session>,
62    executions: ToolExecutionLog,
63    input_queue: SharedInputQueue,
64    execution_lock: Semaphore,
65    executing: AtomicBool,
66    queue_notify: Notify,
67}
68
69impl std::fmt::Debug for ToolStateInner {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("ToolStateInner")
72            .field("id", &self.id)
73            .field("executions", &self.executions)
74            .field("executing", &self.executing.load(Ordering::Relaxed))
75            .finish_non_exhaustive()
76    }
77}
78
79impl ToolStateInner {
80    fn new(session_id: SessionId) -> Self {
81        Self {
82            id: session_id,
83            session: RwLock::new(Session::from_id(session_id, SessionConfig::default())),
84            executions: ToolExecutionLog::new(),
85            input_queue: SharedInputQueue::new(),
86            execution_lock: Semaphore::new(1),
87            executing: AtomicBool::new(false),
88            queue_notify: Notify::new(),
89        }
90    }
91
92    fn from_session(session: Session) -> Self {
93        let id = session.id;
94        Self {
95            id,
96            session: RwLock::new(session),
97            executions: ToolExecutionLog::new(),
98            input_queue: SharedInputQueue::new(),
99            execution_lock: Semaphore::new(1),
100            executing: AtomicBool::new(false),
101            queue_notify: Notify::new(),
102        }
103    }
104}
105
106/// Thread-safe tool state handle.
107#[derive(Debug, Clone)]
108pub struct ToolState(Arc<ToolStateInner>);
109
110impl ToolState {
111    pub fn new(session_id: SessionId) -> Self {
112        Self(Arc::new(ToolStateInner::new(session_id)))
113    }
114
115    pub fn from_session(session: Session) -> Self {
116        Self(Arc::new(ToolStateInner::from_session(session)))
117    }
118
119    #[inline]
120    pub fn session_id(&self) -> SessionId {
121        self.0.id
122    }
123
124    pub async fn session(&self) -> Session {
125        self.0.session.read().await.clone()
126    }
127
128    pub async fn update_session(&self, session: Session) {
129        *self.0.session.write().await = session;
130    }
131
132    pub async fn enter_plan_mode(&self, name: Option<String>) -> Plan {
133        self.0.session.write().await.enter_plan_mode(name).clone()
134    }
135
136    pub async fn current_plan(&self) -> Option<Plan> {
137        self.0.session.read().await.current_plan.clone()
138    }
139
140    pub async fn update_plan_content(&self, content: String) {
141        self.0.session.write().await.update_plan_content(content);
142    }
143
144    pub async fn exit_plan_mode(&self) -> Option<Plan> {
145        self.0.session.write().await.exit_plan_mode()
146    }
147
148    pub async fn cancel_plan(&self) -> Option<Plan> {
149        self.0.session.write().await.cancel_plan()
150    }
151
152    #[inline]
153    pub async fn is_in_plan_mode(&self) -> bool {
154        self.0.session.read().await.is_in_plan_mode()
155    }
156
157    pub async fn set_todos(&self, todos: Vec<TodoItem>) {
158        self.0.session.write().await.set_todos(todos);
159    }
160
161    pub async fn todos(&self) -> Vec<TodoItem> {
162        self.0.session.read().await.todos.clone()
163    }
164
165    #[inline]
166    pub async fn todos_in_progress_count(&self) -> usize {
167        self.0.session.read().await.todos_in_progress_count()
168    }
169
170    pub async fn record_tool_execution(&self, mut exec: ToolExecution) {
171        let plan_id = {
172            let session = self.0.session.read().await;
173            if let Some(ref plan) = session.current_plan
174                && plan.status == PlanStatus::Executing
175            {
176                Some(plan.id)
177            } else {
178                None
179            }
180        };
181        exec.plan_id = plan_id;
182        self.0.executions.append(exec).await;
183    }
184
185    pub async fn with_tool_executions<F, R>(&self, f: F) -> R
186    where
187        F: FnOnce(&VecDeque<ToolExecution>) -> R,
188    {
189        self.0.executions.with_entries(f).await
190    }
191
192    pub async fn execution_log_len(&self) -> usize {
193        self.0.executions.len().await
194    }
195
196    pub async fn clear_execution_log(&self) {
197        self.0.executions.clear().await;
198    }
199
200    pub async fn record_compact(&self, record: CompactRecord) {
201        self.0.session.write().await.record_compact(record);
202    }
203
204    pub async fn with_compact_history<F, R>(&self, f: F) -> R
205    where
206        F: FnOnce(&VecDeque<CompactRecord>) -> R,
207    {
208        let session = self.0.session.read().await;
209        f(&session.compact_history)
210    }
211
212    #[inline]
213    pub async fn session_snapshot(&self) -> (SessionId, usize, Option<Plan>) {
214        let session = self.0.session.read().await;
215        (
216            session.id,
217            session.todos.len(),
218            session.current_plan.clone(),
219        )
220    }
221
222    #[inline]
223    pub async fn execution_state(&self) -> (SessionId, bool, usize) {
224        let session = self.0.session.read().await;
225        (
226            session.id,
227            session.is_in_plan_mode(),
228            session.todos_in_progress_count(),
229        )
230    }
231
232    pub async fn record_execution_with_todos(
233        &self,
234        mut exec: ToolExecution,
235        todos: Option<Vec<TodoItem>>,
236    ) {
237        let plan_id = {
238            let mut session = self.0.session.write().await;
239            let plan_id = if let Some(ref plan) = session.current_plan
240                && plan.status == PlanStatus::Executing
241            {
242                Some(plan.id)
243            } else {
244                None
245            };
246            if let Some(todos) = todos {
247                session.set_todos(todos);
248            }
249            plan_id
250        };
251        exec.plan_id = plan_id;
252        self.0.executions.append(exec).await;
253    }
254
255    pub async fn enqueue(&self, content: impl Into<String>) -> Result<Uuid, QueueError> {
256        let input = QueuedInput::new(self.session_id(), content);
257        let result = self.0.input_queue.enqueue(input).await;
258        if result.is_ok() {
259            self.0.queue_notify.notify_waiters();
260        }
261        result
262    }
263
264    pub async fn dequeue_or_merge(&self) -> Option<MergedInput> {
265        self.0.input_queue.merge_all().await
266    }
267
268    pub async fn pending_count(&self) -> usize {
269        self.0.input_queue.pending_count().await
270    }
271
272    pub async fn cancel_pending(&self, id: Uuid) -> bool {
273        self.0.input_queue.cancel(id).await.is_some()
274    }
275
276    pub async fn cancel_all_pending(&self) -> usize {
277        self.0.input_queue.cancel_all().await.len()
278    }
279
280    pub fn is_executing(&self) -> bool {
281        self.0.executing.load(Ordering::Acquire)
282    }
283
284    pub async fn wait_for_queue_signal(&self) {
285        self.0.queue_notify.notified().await;
286    }
287
288    pub async fn acquire_execution(&self) -> ExecutionGuard<'_> {
289        let permit = self
290            .0
291            .execution_lock
292            .acquire()
293            .await
294            .expect("semaphore should not be closed");
295        self.0.executing.store(true, Ordering::Release);
296        ExecutionGuard {
297            permit,
298            executing: &self.0.executing,
299            queue_notify: &self.0.queue_notify,
300        }
301    }
302
303    pub fn try_acquire_execution(&self) -> Option<ExecutionGuard<'_>> {
304        self.0.execution_lock.try_acquire().ok().map(|permit| {
305            self.0.executing.store(true, Ordering::Release);
306            ExecutionGuard {
307                permit,
308                executing: &self.0.executing,
309                queue_notify: &self.0.queue_notify,
310            }
311        })
312    }
313
314    pub async fn with_session<F, R>(&self, f: F) -> R
315    where
316        F: FnOnce(&Session) -> R,
317    {
318        let session = self.0.session.read().await;
319        f(&session)
320    }
321
322    pub async fn with_session_mut<F, R>(&self, f: F) -> R
323    where
324        F: FnOnce(&mut Session) -> R,
325    {
326        let mut session = self.0.session.write().await;
327        f(&mut session)
328    }
329
330    pub async fn compact(
331        &self,
332        client: &crate::Client,
333        keep_messages: usize,
334    ) -> crate::Result<crate::types::CompactResult> {
335        let mut session = self.0.session.write().await;
336        session.compact(client, keep_messages).await
337    }
338}
339
340pub struct ExecutionGuard<'a> {
341    #[allow(dead_code)]
342    permit: tokio::sync::SemaphorePermit<'a>,
343    executing: &'a AtomicBool,
344    queue_notify: &'a Notify,
345}
346
347impl Drop for ExecutionGuard<'_> {
348    fn drop(&mut self) {
349        self.executing.store(false, Ordering::Release);
350        self.queue_notify.notify_waiters();
351    }
352}
353
354impl Default for ToolState {
355    fn default() -> Self {
356        Self::new(SessionId::default())
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plan_lifecycle() {
        let state = ToolState::new(SessionId::new());

        let plan = state.enter_plan_mode(Some("Test Plan".to_string())).await;
        assert_eq!(plan.status, PlanStatus::Draft);
        assert!(state.is_in_plan_mode().await);

        state
            .update_plan_content("Step 1\nStep 2".to_string())
            .await;

        let approved = state
            .exit_plan_mode()
            .await
            .expect("exiting plan mode should return the plan");
        assert_eq!(approved.status, PlanStatus::Approved);
    }

    #[tokio::test]
    async fn test_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![
            TodoItem::new(session_id, "Task 1", "Doing task 1"),
            TodoItem::new(session_id, "Task 2", "Doing task 2"),
        ];

        state.set_todos(todos).await;
        let loaded = state.todos().await;
        assert_eq!(loaded.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({"command": "ls"}))
            .output("file1\nfile2", false)
            .duration(100);

        state.record_tool_execution(exec).await;

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 1);

        let name = state
            .with_tool_executions(|e| e.front().map(|x| x.tool_name.clone()))
            .await;
        assert_eq!(name, Some("Bash".to_string()));
    }

    #[tokio::test]
    async fn test_execution_not_tagged_outside_executing_plan() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        // A freshly entered plan is a draft, not executing, so no plan id is attached.
        state.enter_plan_mode(None).await;
        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({}));
        state.record_tool_execution(exec).await;

        let tagged = state
            .with_tool_executions(|e| e.back().map(|x| x.plan_id.is_some()))
            .await;
        assert_eq!(tagged, Some(false));
    }

    #[tokio::test]
    async fn test_record_execution_with_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "TodoWrite", serde_json::json!({}));
        let todos = vec![TodoItem::new(session_id, "Task", "Doing task")];
        state.record_execution_with_todos(exec, Some(todos)).await;
        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 1);

        // `None` records the execution but leaves the todo list untouched.
        let exec = ToolExecution::new(session_id, "Read", serde_json::json!({}));
        state.record_execution_with_todos(exec, None).await;
        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 2);
    }

    #[tokio::test]
    async fn test_clear_execution_log() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        for i in 0..3 {
            let exec = ToolExecution::new(session_id, format!("Tool{}", i), serde_json::json!({}));
            state.record_tool_execution(exec).await;
        }
        assert_eq!(state.execution_log_len().await, 3);

        state.clear_execution_log().await;
        assert_eq!(state.execution_log_len().await, 0);
    }

    #[tokio::test]
    async fn test_session_persistence_ready() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        state.set_todos(todos).await;
        state.enter_plan_mode(Some("My Plan".to_string())).await;

        let session = state.session().await;
        assert_eq!(session.todos.len(), 1);
        assert!(session.current_plan.is_some());
        assert_eq!(
            session.current_plan.unwrap().name,
            Some("My Plan".to_string())
        );
    }

    #[tokio::test]
    async fn test_resume_from_session() {
        let session_id = SessionId::new();

        let mut session = Session::new(SessionConfig::default());
        session.id = session_id;
        session.set_todos(vec![TodoItem::new(
            session_id,
            "Resumed Task",
            "Working on it",
        )]);
        session.enter_plan_mode(Some("Resumed Plan".to_string()));

        let state = ToolState::from_session(session);

        let todos = state.todos().await;
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].content, "Resumed Task");

        let plan = state.current_plan().await;
        assert!(plan.is_some());
        assert_eq!(plan.unwrap().name, Some("Resumed Plan".to_string()));
    }

    #[tokio::test]
    async fn test_concurrent_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let state = state.clone();
                let sid = session_id;
                tokio::spawn(async move {
                    let exec =
                        ToolExecution::new(sid, format!("Tool{}", i), serde_json::json!({"id": i}));
                    state.record_tool_execution(exec).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_execution_log_limit() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);
        let total = MAX_EXECUTION_LOG_SIZE + 100;

        for i in 0..total {
            let exec = ToolExecution::new(session_id, format!("Tool{}", i), serde_json::json!({}));
            state.record_tool_execution(exec).await;
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, MAX_EXECUTION_LOG_SIZE);

        // The 100 oldest entries were evicted. Check both ends exactly, because a
        // substring match on "100" would also accept "Tool1000".."Tool1099".
        let (first, last) = state
            .with_tool_executions(|e| {
                (
                    e.front().map(|x| x.tool_name.clone()),
                    e.back().map(|x| x.tool_name.clone()),
                )
            })
            .await;
        assert_eq!(first.as_deref(), Some("Tool100"));
        assert_eq!(last, Some(format!("Tool{}", total - 1)));
    }

    #[tokio::test]
    async fn test_execution_guard_is_exclusive() {
        let state = ToolState::new(SessionId::new());
        assert!(!state.is_executing());

        let guard = state.acquire_execution().await;
        assert!(state.is_executing());
        assert!(state.try_acquire_execution().is_none());

        drop(guard);
        assert!(!state.is_executing());

        let retry = state.try_acquire_execution();
        assert!(retry.is_some());
        assert!(state.is_executing());
    }
}