1use std::collections::VecDeque;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use tokio::sync::{Notify, RwLock, Semaphore};
8use uuid::Uuid;
9
10use super::queue::{MergedInput, QueueError, QueuedInput, SharedInputQueue};
11use super::state::{Session, SessionConfig, SessionId};
12use super::types::{CompactRecord, Plan, PlanStatus, TodoItem, ToolExecution};
13
14const MAX_EXECUTION_LOG_SIZE: usize = 1000;
15
16#[derive(Debug)]
17struct ToolExecutionLog {
18 entries: RwLock<VecDeque<ToolExecution>>,
19}
20
21impl Default for ToolExecutionLog {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl ToolExecutionLog {
28 fn new() -> Self {
29 Self {
30 entries: RwLock::new(VecDeque::with_capacity(64)),
31 }
32 }
33
34 async fn append(&self, exec: ToolExecution) {
35 let mut entries = self.entries.write().await;
36 if entries.len() >= MAX_EXECUTION_LOG_SIZE {
37 entries.pop_front();
38 }
39 entries.push_back(exec);
40 }
41
42 async fn with_entries<F, R>(&self, f: F) -> R
43 where
44 F: FnOnce(&VecDeque<ToolExecution>) -> R,
45 {
46 let entries = self.entries.read().await;
47 f(&entries)
48 }
49
50 async fn len(&self) -> usize {
51 self.entries.read().await.len()
52 }
53
54 async fn clear(&self) {
55 self.entries.write().await.clear();
56 }
57}
58
59struct ToolStateInner {
60 id: SessionId,
61 session: RwLock<Session>,
62 executions: ToolExecutionLog,
63 input_queue: SharedInputQueue,
64 execution_lock: Semaphore,
65 executing: AtomicBool,
66 queue_notify: Notify,
67}
68
69impl std::fmt::Debug for ToolStateInner {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("ToolStateInner")
72 .field("id", &self.id)
73 .field("executions", &self.executions)
74 .field("executing", &self.executing.load(Ordering::Relaxed))
75 .finish_non_exhaustive()
76 }
77}
78
79impl ToolStateInner {
80 fn new(session_id: SessionId) -> Self {
81 Self {
82 id: session_id,
83 session: RwLock::new(Session::from_id(session_id, SessionConfig::default())),
84 executions: ToolExecutionLog::new(),
85 input_queue: SharedInputQueue::new(),
86 execution_lock: Semaphore::new(1),
87 executing: AtomicBool::new(false),
88 queue_notify: Notify::new(),
89 }
90 }
91
92 fn from_session(session: Session) -> Self {
93 let id = session.id;
94 Self {
95 id,
96 session: RwLock::new(session),
97 executions: ToolExecutionLog::new(),
98 input_queue: SharedInputQueue::new(),
99 execution_lock: Semaphore::new(1),
100 executing: AtomicBool::new(false),
101 queue_notify: Notify::new(),
102 }
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct ToolState(Arc<ToolStateInner>);
109
110impl ToolState {
111 pub fn new(session_id: SessionId) -> Self {
112 Self(Arc::new(ToolStateInner::new(session_id)))
113 }
114
115 pub fn from_session(session: Session) -> Self {
116 Self(Arc::new(ToolStateInner::from_session(session)))
117 }
118
119 #[inline]
120 pub fn session_id(&self) -> SessionId {
121 self.0.id
122 }
123
124 pub async fn session(&self) -> Session {
125 self.0.session.read().await.clone()
126 }
127
128 pub async fn update_session(&self, session: Session) {
129 *self.0.session.write().await = session;
130 }
131
132 pub async fn enter_plan_mode(&self, name: Option<String>) -> Plan {
133 self.0.session.write().await.enter_plan_mode(name).clone()
134 }
135
136 pub async fn current_plan(&self) -> Option<Plan> {
137 self.0.session.read().await.current_plan.clone()
138 }
139
140 pub async fn update_plan_content(&self, content: String) {
141 self.0.session.write().await.update_plan_content(content);
142 }
143
144 pub async fn exit_plan_mode(&self) -> Option<Plan> {
145 self.0.session.write().await.exit_plan_mode()
146 }
147
148 pub async fn cancel_plan(&self) -> Option<Plan> {
149 self.0.session.write().await.cancel_plan()
150 }
151
152 #[inline]
153 pub async fn is_in_plan_mode(&self) -> bool {
154 self.0.session.read().await.is_in_plan_mode()
155 }
156
157 pub async fn set_todos(&self, todos: Vec<TodoItem>) {
158 self.0.session.write().await.set_todos(todos);
159 }
160
161 pub async fn todos(&self) -> Vec<TodoItem> {
162 self.0.session.read().await.todos.clone()
163 }
164
165 #[inline]
166 pub async fn todos_in_progress_count(&self) -> usize {
167 self.0.session.read().await.todos_in_progress_count()
168 }
169
170 pub async fn record_tool_execution(&self, mut exec: ToolExecution) {
171 let plan_id = {
172 let session = self.0.session.read().await;
173 if let Some(ref plan) = session.current_plan
174 && plan.status == PlanStatus::Executing
175 {
176 Some(plan.id)
177 } else {
178 None
179 }
180 };
181 exec.plan_id = plan_id;
182 self.0.executions.append(exec).await;
183 }
184
185 pub async fn with_tool_executions<F, R>(&self, f: F) -> R
186 where
187 F: FnOnce(&VecDeque<ToolExecution>) -> R,
188 {
189 self.0.executions.with_entries(f).await
190 }
191
192 pub async fn execution_log_len(&self) -> usize {
193 self.0.executions.len().await
194 }
195
196 pub async fn clear_execution_log(&self) {
197 self.0.executions.clear().await;
198 }
199
200 pub async fn record_compact(&self, record: CompactRecord) {
201 self.0.session.write().await.record_compact(record);
202 }
203
204 pub async fn with_compact_history<F, R>(&self, f: F) -> R
205 where
206 F: FnOnce(&VecDeque<CompactRecord>) -> R,
207 {
208 let session = self.0.session.read().await;
209 f(&session.compact_history)
210 }
211
212 #[inline]
213 pub async fn session_snapshot(&self) -> (SessionId, usize, Option<Plan>) {
214 let session = self.0.session.read().await;
215 (
216 session.id,
217 session.todos.len(),
218 session.current_plan.clone(),
219 )
220 }
221
222 #[inline]
223 pub async fn execution_state(&self) -> (SessionId, bool, usize) {
224 let session = self.0.session.read().await;
225 (
226 session.id,
227 session.is_in_plan_mode(),
228 session.todos_in_progress_count(),
229 )
230 }
231
232 pub async fn record_execution_with_todos(
233 &self,
234 mut exec: ToolExecution,
235 todos: Option<Vec<TodoItem>>,
236 ) {
237 let plan_id = {
238 let mut session = self.0.session.write().await;
239 let plan_id = if let Some(ref plan) = session.current_plan
240 && plan.status == PlanStatus::Executing
241 {
242 Some(plan.id)
243 } else {
244 None
245 };
246 if let Some(todos) = todos {
247 session.set_todos(todos);
248 }
249 plan_id
250 };
251 exec.plan_id = plan_id;
252 self.0.executions.append(exec).await;
253 }
254
255 pub async fn enqueue(&self, content: impl Into<String>) -> Result<Uuid, QueueError> {
256 let input = QueuedInput::new(self.session_id(), content);
257 let result = self.0.input_queue.enqueue(input).await;
258 if result.is_ok() {
259 self.0.queue_notify.notify_waiters();
260 }
261 result
262 }
263
264 pub async fn dequeue_or_merge(&self) -> Option<MergedInput> {
265 self.0.input_queue.merge_all().await
266 }
267
268 pub async fn pending_count(&self) -> usize {
269 self.0.input_queue.pending_count().await
270 }
271
272 pub async fn cancel_pending(&self, id: Uuid) -> bool {
273 self.0.input_queue.cancel(id).await.is_some()
274 }
275
276 pub async fn cancel_all_pending(&self) -> usize {
277 self.0.input_queue.cancel_all().await.len()
278 }
279
280 pub fn is_executing(&self) -> bool {
281 self.0.executing.load(Ordering::Acquire)
282 }
283
284 pub async fn wait_for_queue_signal(&self) {
285 self.0.queue_notify.notified().await;
286 }
287
288 pub async fn acquire_execution(&self) -> ExecutionGuard<'_> {
289 let permit = self
290 .0
291 .execution_lock
292 .acquire()
293 .await
294 .expect("semaphore should not be closed");
295 self.0.executing.store(true, Ordering::Release);
296 ExecutionGuard {
297 permit,
298 executing: &self.0.executing,
299 queue_notify: &self.0.queue_notify,
300 }
301 }
302
303 pub fn try_acquire_execution(&self) -> Option<ExecutionGuard<'_>> {
304 self.0.execution_lock.try_acquire().ok().map(|permit| {
305 self.0.executing.store(true, Ordering::Release);
306 ExecutionGuard {
307 permit,
308 executing: &self.0.executing,
309 queue_notify: &self.0.queue_notify,
310 }
311 })
312 }
313
314 pub async fn with_session<F, R>(&self, f: F) -> R
315 where
316 F: FnOnce(&Session) -> R,
317 {
318 let session = self.0.session.read().await;
319 f(&session)
320 }
321
322 pub async fn with_session_mut<F, R>(&self, f: F) -> R
323 where
324 F: FnOnce(&mut Session) -> R,
325 {
326 let mut session = self.0.session.write().await;
327 f(&mut session)
328 }
329
330 pub async fn compact(
331 &self,
332 client: &crate::Client,
333 keep_messages: usize,
334 ) -> crate::Result<crate::types::CompactResult> {
335 let mut session = self.0.session.write().await;
336 session.compact(client, keep_messages).await
337 }
338}
339
340pub struct ExecutionGuard<'a> {
341 #[allow(dead_code)]
342 permit: tokio::sync::SemaphorePermit<'a>,
343 executing: &'a AtomicBool,
344 queue_notify: &'a Notify,
345}
346
347impl Drop for ExecutionGuard<'_> {
348 fn drop(&mut self) {
349 self.executing.store(false, Ordering::Release);
350 self.queue_notify.notify_waiters();
351 }
352}
353
354impl Default for ToolState {
355 fn default() -> Self {
356 Self::new(SessionId::default())
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Entering, editing and exiting plan mode yields an approved plan.
    #[tokio::test]
    async fn test_plan_lifecycle() {
        let state = ToolState::new(SessionId::new());

        let plan = state.enter_plan_mode(Some("Test Plan".to_string())).await;
        assert_eq!(plan.status, PlanStatus::Draft);
        assert!(state.is_in_plan_mode().await);

        state
            .update_plan_content("Step 1\nStep 2".to_string())
            .await;

        let approved = state
            .exit_plan_mode()
            .await
            .expect("exiting plan mode should return the plan");
        assert_eq!(approved.status, PlanStatus::Approved);
    }

    /// Todos round-trip through `set_todos` / `todos`.
    #[tokio::test]
    async fn test_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![
            TodoItem::new(session_id, "Task 1", "Doing task 1"),
            TodoItem::new(session_id, "Task 2", "Doing task 2"),
        ];

        state.set_todos(todos).await;
        let loaded = state.todos().await;
        assert_eq!(loaded.len(), 2);
    }

    /// A recorded execution is retrievable and untagged when no plan is executing.
    #[tokio::test]
    async fn test_tool_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({"command": "ls"}))
            .output("file1\nfile2", false)
            .duration(100);

        state.record_tool_execution(exec).await;

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 1);

        let name = state
            .with_tool_executions(|e| e.front().map(|x| x.tool_name.clone()))
            .await;
        assert_eq!(name, Some("Bash".to_string()));

        let untagged = state
            .with_tool_executions(|e| e.iter().all(|x| x.plan_id.is_none()))
            .await;
        assert!(untagged);
    }

    /// The session snapshot reflects todos and plan set through the handle.
    #[tokio::test]
    async fn test_session_persistence_ready() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        state.set_todos(todos).await;
        state.enter_plan_mode(Some("My Plan".to_string())).await;

        let session = state.session().await;
        assert_eq!(session.todos.len(), 1);
        let plan = session
            .current_plan
            .expect("plan mode should set a current plan");
        assert_eq!(plan.name, Some("My Plan".to_string()));
    }

    /// State built from an existing session exposes its todos and plan.
    #[tokio::test]
    async fn test_resume_from_session() {
        let session_id = SessionId::new();

        let mut session = Session::new(SessionConfig::default());
        session.id = session_id;
        session.set_todos(vec![TodoItem::new(
            session_id,
            "Resumed Task",
            "Working on it",
        )]);
        session.enter_plan_mode(Some("Resumed Plan".to_string()));

        let state = ToolState::from_session(session);

        let todos = state.todos().await;
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].content, "Resumed Task");

        let plan = state
            .current_plan()
            .await
            .expect("resumed session should keep its plan");
        assert_eq!(plan.name, Some("Resumed Plan".to_string()));
    }

    /// Concurrent recordings from cloned handles are all retained.
    #[tokio::test]
    async fn test_concurrent_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let state = state.clone();
                tokio::spawn(async move {
                    let exec = ToolExecution::new(
                        session_id,
                        format!("Tool{}", i),
                        serde_json::json!({"id": i}),
                    );
                    state.record_tool_execution(exec).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 10);
    }

    /// Overflowing the log evicts exactly the oldest entries, keeping FIFO order.
    #[tokio::test]
    async fn test_execution_log_limit() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        for i in 0..MAX_EXECUTION_LOG_SIZE + 100 {
            let exec = ToolExecution::new(session_id, format!("Tool{}", i), serde_json::json!({}));
            state.record_tool_execution(exec).await;
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, MAX_EXECUTION_LOG_SIZE);
        assert_eq!(state.execution_log_len().await, MAX_EXECUTION_LOG_SIZE);

        let (first, last) = state
            .with_tool_executions(|e| {
                (
                    e.front().map(|x| x.tool_name.clone()),
                    e.back().map(|x| x.tool_name.clone()),
                )
            })
            .await;
        // Exact match: `contains("100")` would also accept "Tool1000".."Tool1099".
        assert_eq!(first.as_deref(), Some("Tool100"));
        assert_eq!(last, Some(format!("Tool{}", MAX_EXECUTION_LOG_SIZE + 99)));
    }

    /// `record_execution_with_todos` logs the execution and only replaces todos when given.
    #[tokio::test]
    async fn test_record_execution_with_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        let exec = ToolExecution::new(session_id, "TodoWrite", serde_json::json!({}));
        state.record_execution_with_todos(exec, Some(todos)).await;

        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 1);

        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({}));
        state.record_execution_with_todos(exec, None).await;

        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 2);
    }

    /// Only one execution guard can be held at a time, and dropping it frees the slot.
    #[tokio::test]
    async fn test_execution_guard_is_exclusive() {
        let state = ToolState::new(SessionId::new());
        assert!(!state.is_executing());

        let guard = state
            .try_acquire_execution()
            .expect("first acquire should succeed");
        assert!(state.is_executing());
        assert!(state.try_acquire_execution().is_none());

        drop(guard);
        assert!(!state.is_executing());
        assert!(state.try_acquire_execution().is_some());
    }
}