1use std::collections::VecDeque;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use tokio::sync::{RwLock, Semaphore};
8use uuid::Uuid;
9
10use super::queue::{MergedInput, QueueError, QueuedInput, SharedInputQueue};
11use super::state::{Session, SessionConfig, SessionId};
12use super::types::{CompactRecord, Plan, PlanStatus, TodoItem, ToolExecution};
13
/// Maximum number of tool executions retained per session log; once reached,
/// the oldest entry is evicted to make room for each new one.
const MAX_EXECUTION_LOG_SIZE: usize = 1000;
15
/// Bounded FIFO log of tool executions, guarded by an async `RwLock`.
///
/// Holds at most `MAX_EXECUTION_LOG_SIZE` entries; `append` evicts the oldest
/// entry when the log is full.
#[derive(Debug)]
struct ToolExecutionLog {
    /// Entries in insertion order: oldest at the front, newest at the back.
    entries: RwLock<VecDeque<ToolExecution>>,
}
20
21impl Default for ToolExecutionLog {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl ToolExecutionLog {
28 fn new() -> Self {
29 Self {
30 entries: RwLock::new(VecDeque::with_capacity(64)),
31 }
32 }
33
34 async fn append(&self, exec: ToolExecution) {
35 let mut entries = self.entries.write().await;
36 if entries.len() >= MAX_EXECUTION_LOG_SIZE {
37 entries.pop_front();
38 }
39 entries.push_back(exec);
40 }
41
42 async fn with_entries<F, R>(&self, f: F) -> R
43 where
44 F: FnOnce(&VecDeque<ToolExecution>) -> R,
45 {
46 let entries = self.entries.read().await;
47 f(&entries)
48 }
49
50 async fn for_plan(&self, plan_id: Uuid) -> Vec<ToolExecution> {
51 self.entries
52 .read()
53 .await
54 .iter()
55 .filter(|e| e.plan_id == Some(plan_id))
56 .cloned()
57 .collect()
58 }
59
60 async fn len(&self) -> usize {
61 self.entries.read().await.len()
62 }
63
64 async fn clear(&self) {
65 self.entries.write().await.clear();
66 }
67}
68
/// Shared state behind a `ToolState` handle.
struct ToolStateInner {
    /// Session id captured at construction time; not updated if the session is replaced.
    id: SessionId,
    /// The session (current plan, todos, compact history, ...) behind an async lock.
    session: RwLock<Session>,
    /// Bounded log of recorded tool executions.
    executions: ToolExecutionLog,
    /// Queue of pending inputs for this session (see `ToolState::enqueue`).
    input_queue: SharedInputQueue,
    /// Semaphore created with a single permit; holding it means "currently executing".
    execution_lock: Semaphore,
    /// Set while an `ExecutionGuard` is alive, so `is_executing` can be read without awaiting.
    executing: AtomicBool,
}
77
78impl std::fmt::Debug for ToolStateInner {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("ToolStateInner")
81 .field("id", &self.id)
82 .field("executions", &self.executions)
83 .field("executing", &self.executing.load(Ordering::Relaxed))
84 .finish_non_exhaustive()
85 }
86}
87
88impl ToolStateInner {
89 fn new(session_id: SessionId) -> Self {
90 Self {
91 id: session_id,
92 session: RwLock::new(Session::with_id(session_id, SessionConfig::default())),
93 executions: ToolExecutionLog::new(),
94 input_queue: SharedInputQueue::new(),
95 execution_lock: Semaphore::new(1),
96 executing: AtomicBool::new(false),
97 }
98 }
99
100 fn from_session(session: Session) -> Self {
101 let id = session.id;
102 Self {
103 id,
104 session: RwLock::new(session),
105 executions: ToolExecutionLog::new(),
106 input_queue: SharedInputQueue::new(),
107 execution_lock: Semaphore::new(1),
108 executing: AtomicBool::new(false),
109 }
110 }
111}
112
/// Cheaply clonable handle to per-session tool state.
///
/// Every clone shares the same inner state through an `Arc`, so changes made via
/// one handle are visible through all others.
#[derive(Debug, Clone)]
pub struct ToolState(Arc<ToolStateInner>);
116
117impl ToolState {
118 pub fn new(session_id: SessionId) -> Self {
119 Self(Arc::new(ToolStateInner::new(session_id)))
120 }
121
122 pub fn from_session(session: Session) -> Self {
123 Self(Arc::new(ToolStateInner::from_session(session)))
124 }
125
126 #[inline]
127 pub fn session_id(&self) -> SessionId {
128 self.0.id
129 }
130
131 pub async fn session(&self) -> Session {
132 self.0.session.read().await.clone()
133 }
134
135 pub async fn update_session(&self, session: Session) {
136 *self.0.session.write().await = session;
137 }
138
139 pub async fn enter_plan_mode(&self, name: Option<String>) -> Plan {
140 self.0.session.write().await.enter_plan_mode(name).clone()
141 }
142
143 pub async fn current_plan(&self) -> Option<Plan> {
144 self.0.session.read().await.current_plan.clone()
145 }
146
147 pub async fn update_plan_content(&self, content: String) {
148 self.0.session.write().await.update_plan_content(content);
149 }
150
151 pub async fn exit_plan_mode(&self) -> Option<Plan> {
152 self.0.session.write().await.exit_plan_mode()
153 }
154
155 pub async fn cancel_plan(&self) -> Option<Plan> {
156 self.0.session.write().await.cancel_plan()
157 }
158
159 #[inline]
160 pub async fn is_in_plan_mode(&self) -> bool {
161 self.0.session.read().await.is_in_plan_mode()
162 }
163
164 pub async fn set_todos(&self, todos: Vec<TodoItem>) {
165 self.0.session.write().await.set_todos(todos);
166 }
167
168 pub async fn todos(&self) -> Vec<TodoItem> {
169 self.0.session.read().await.todos.clone()
170 }
171
172 #[inline]
173 pub async fn todos_in_progress_count(&self) -> usize {
174 self.0.session.read().await.todos_in_progress_count()
175 }
176
177 pub async fn record_tool_execution(&self, mut exec: ToolExecution) {
178 let plan_id = {
179 let session = self.0.session.read().await;
180 if let Some(ref plan) = session.current_plan
181 && plan.status == PlanStatus::Executing
182 {
183 Some(plan.id)
184 } else {
185 None
186 }
187 };
188 exec.plan_id = plan_id;
189 self.0.executions.append(exec).await;
190 }
191
192 pub async fn with_tool_executions<F, R>(&self, f: F) -> R
193 where
194 F: FnOnce(&VecDeque<ToolExecution>) -> R,
195 {
196 self.0.executions.with_entries(f).await
197 }
198
199 pub async fn tool_executions_for_plan(&self, plan_id: Uuid) -> Vec<ToolExecution> {
200 self.0.executions.for_plan(plan_id).await
201 }
202
203 pub async fn execution_log_len(&self) -> usize {
204 self.0.executions.len().await
205 }
206
207 pub async fn clear_execution_log(&self) {
208 self.0.executions.clear().await;
209 }
210
211 pub async fn record_compact(&self, record: CompactRecord) {
212 self.0.session.write().await.record_compact(record);
213 }
214
215 pub async fn with_compact_history<F, R>(&self, f: F) -> R
216 where
217 F: FnOnce(&[CompactRecord]) -> R,
218 {
219 let session = self.0.session.read().await;
220 f(&session.compact_history)
221 }
222
223 #[inline]
224 pub async fn session_snapshot(&self) -> (SessionId, usize, Option<Plan>) {
225 let session = self.0.session.read().await;
226 (
227 session.id,
228 session.todos.len(),
229 session.current_plan.clone(),
230 )
231 }
232
233 #[inline]
234 pub async fn execution_state(&self) -> (SessionId, bool, usize) {
235 let session = self.0.session.read().await;
236 (
237 session.id,
238 session.is_in_plan_mode(),
239 session.todos_in_progress_count(),
240 )
241 }
242
243 pub async fn record_execution_with_todos(
244 &self,
245 mut exec: ToolExecution,
246 todos: Option<Vec<TodoItem>>,
247 ) {
248 let plan_id = {
249 let mut session = self.0.session.write().await;
250 let plan_id = if let Some(ref plan) = session.current_plan
251 && plan.status == PlanStatus::Executing
252 {
253 Some(plan.id)
254 } else {
255 None
256 };
257 if let Some(todos) = todos {
258 session.set_todos(todos);
259 }
260 plan_id
261 };
262 exec.plan_id = plan_id;
263 self.0.executions.append(exec).await;
264 }
265
266 pub async fn enqueue(&self, content: impl Into<String>) -> Result<Uuid, QueueError> {
267 let input = QueuedInput::new(self.session_id(), content);
268 self.0.input_queue.enqueue(input).await
269 }
270
271 pub async fn dequeue_or_merge(&self) -> Option<MergedInput> {
272 self.0.input_queue.merge_all().await
273 }
274
275 pub async fn pending_count(&self) -> usize {
276 self.0.input_queue.pending_count().await
277 }
278
279 pub async fn cancel_pending(&self, id: Uuid) -> bool {
280 self.0.input_queue.cancel(id).await.is_some()
281 }
282
283 pub async fn cancel_all_pending(&self) -> usize {
284 self.0.input_queue.cancel_all().await.len()
285 }
286
287 pub fn is_executing(&self) -> bool {
288 self.0.executing.load(Ordering::Acquire)
289 }
290
291 pub async fn acquire_execution(&self) -> ExecutionGuard<'_> {
292 let permit = self
293 .0
294 .execution_lock
295 .acquire()
296 .await
297 .expect("semaphore should not be closed");
298 self.0.executing.store(true, Ordering::Release);
299 ExecutionGuard {
300 permit,
301 executing: &self.0.executing,
302 }
303 }
304
305 pub fn try_acquire_execution(&self) -> Option<ExecutionGuard<'_>> {
306 self.0.execution_lock.try_acquire().ok().map(|permit| {
307 self.0.executing.store(true, Ordering::Release);
308 ExecutionGuard {
309 permit,
310 executing: &self.0.executing,
311 }
312 })
313 }
314
315 pub async fn with_session<F, R>(&self, f: F) -> R
316 where
317 F: FnOnce(&Session) -> R,
318 {
319 let session = self.0.session.read().await;
320 f(&session)
321 }
322
323 pub async fn with_session_mut<F, R>(&self, f: F) -> R
324 where
325 F: FnOnce(&mut Session) -> R,
326 {
327 let mut session = self.0.session.write().await;
328 f(&mut session)
329 }
330
331 pub async fn compact(
332 &self,
333 client: &crate::Client,
334 keep_messages: usize,
335 ) -> crate::Result<crate::types::CompactResult> {
336 let mut session = self.0.session.write().await;
337 session.compact(client, keep_messages).await
338 }
339}
340
341pub struct ExecutionGuard<'a> {
342 #[allow(dead_code)]
343 permit: tokio::sync::SemaphorePermit<'a>,
344 executing: &'a AtomicBool,
345}
346
347impl Drop for ExecutionGuard<'_> {
348 fn drop(&mut self) {
349 self.executing.store(false, Ordering::Release);
350 }
351}
352
353impl Default for ToolState {
354 fn default() -> Self {
355 Self::new(SessionId::default())
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_plan_lifecycle() {
        let state = ToolState::new(SessionId::new());

        let plan = state.enter_plan_mode(Some("Test Plan".to_string())).await;
        assert_eq!(plan.status, PlanStatus::Draft);
        assert!(state.is_in_plan_mode().await);

        state
            .update_plan_content("Step 1\nStep 2".to_string())
            .await;

        let approved = state.exit_plan_mode().await;
        assert!(approved.is_some());
        assert_eq!(approved.unwrap().status, PlanStatus::Approved);
    }

    #[tokio::test]
    async fn test_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![
            TodoItem::new(session_id, "Task 1", "Doing task 1"),
            TodoItem::new(session_id, "Task 2", "Doing task 2"),
        ];

        state.set_todos(todos).await;
        let loaded = state.todos().await;
        assert_eq!(loaded.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({"command": "ls"}))
            .with_output("file1\nfile2", false)
            .with_duration(100);

        state.record_tool_execution(exec).await;

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 1);

        let name = state
            .with_tool_executions(|e| e.front().map(|x| x.tool_name.clone()))
            .await;
        assert_eq!(name, Some("Bash".to_string()));
    }

    #[tokio::test]
    async fn test_session_persistence_ready() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        state.set_todos(todos).await;
        state.enter_plan_mode(Some("My Plan".to_string())).await;

        let session = state.session().await;
        assert_eq!(session.todos.len(), 1);
        assert!(session.current_plan.is_some());
        assert_eq!(
            session.current_plan.unwrap().name,
            Some("My Plan".to_string())
        );
    }

    #[tokio::test]
    async fn test_resume_from_session() {
        let session_id = SessionId::new();

        let mut session = Session::new(SessionConfig::default());
        session.id = session_id;
        session.set_todos(vec![TodoItem::new(
            session_id,
            "Resumed Task",
            "Working on it",
        )]);
        session.enter_plan_mode(Some("Resumed Plan".to_string()));

        let state = ToolState::from_session(session);

        let todos = state.todos().await;
        assert_eq!(todos.len(), 1);
        assert_eq!(todos[0].content, "Resumed Task");

        let plan = state.current_plan().await;
        assert!(plan.is_some());
        assert_eq!(plan.unwrap().name, Some("Resumed Plan".to_string()));
    }

    #[tokio::test]
    async fn test_concurrent_execution_recording() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let state = state.clone();
                let sid = session_id;
                tokio::spawn(async move {
                    let exec =
                        ToolExecution::new(sid, format!("Tool{}", i), serde_json::json!({"id": i}));
                    state.record_tool_execution(exec).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_execution_log_limit() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let overflow = 100;
        for i in 0..MAX_EXECUTION_LOG_SIZE + overflow {
            let exec = ToolExecution::new(session_id, format!("Tool{}", i), serde_json::json!({}));
            state.record_tool_execution(exec).await;
        }

        let count = state.with_tool_executions(|e| e.len()).await;
        assert_eq!(count, MAX_EXECUTION_LOG_SIZE);

        // Exact names: a `contains("100")` check would also accept "Tool1000".."Tool1099".
        let (first, last) = state
            .with_tool_executions(|e| {
                (
                    e.front().map(|x| x.tool_name.clone()),
                    e.back().map(|x| x.tool_name.clone()),
                )
            })
            .await;
        assert_eq!(first, Some(format!("Tool{}", overflow)));
        assert_eq!(
            last,
            Some(format!("Tool{}", MAX_EXECUTION_LOG_SIZE + overflow - 1))
        );
    }

    #[tokio::test]
    async fn test_clear_execution_log() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "Read", serde_json::json!({}));
        state.record_tool_execution(exec).await;
        assert_eq!(state.execution_log_len().await, 1);

        state.clear_execution_log().await;
        assert_eq!(state.execution_log_len().await, 0);
    }

    #[tokio::test]
    async fn test_record_execution_with_todos() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        let exec = ToolExecution::new(session_id, "TodoWrite", serde_json::json!({}));
        let todos = vec![TodoItem::new(session_id, "Task 1", "Doing task 1")];
        state.record_execution_with_todos(exec, Some(todos)).await;
        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 1);

        // `None` must leave the existing todo list untouched.
        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({}));
        state.record_execution_with_todos(exec, None).await;
        assert_eq!(state.todos().await.len(), 1);
        assert_eq!(state.execution_log_len().await, 2);
    }

    #[tokio::test]
    async fn test_no_plan_id_outside_executing_plan() {
        let session_id = SessionId::new();
        let state = ToolState::new(session_id);

        // A freshly entered plan is a Draft, not Executing, so executions stay untagged.
        state.enter_plan_mode(Some("Draft".to_string())).await;
        let exec = ToolExecution::new(session_id, "Bash", serde_json::json!({}));
        state.record_tool_execution(exec).await;

        let tagged = state
            .with_tool_executions(|e| e.iter().any(|x| x.plan_id.is_some()))
            .await;
        assert!(!tagged);
    }

    #[tokio::test]
    async fn test_execution_guard() {
        let state = ToolState::new(SessionId::new());
        assert!(!state.is_executing());

        let guard = state.acquire_execution().await;
        assert!(state.is_executing());
        assert!(state.try_acquire_execution().is_none());
        drop(guard);

        assert!(!state.is_executing());
        let guard = state.try_acquire_execution();
        assert!(guard.is_some());
        assert!(state.is_executing());
    }
}