claude_agent/session/
queue.rs

1//! Input queue for handling concurrent user inputs.
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use super::state::SessionId;
12use super::types::EnvironmentContext;
13
/// Maximum number of inputs that may be pending at the same time.
/// `InputQueue::enqueue` returns `QueueError::Full` once this many are queued.
const MAX_QUEUE_SIZE: usize = 100;
/// Upper bound on the size of one batch produced by `InputQueue::merge_all`.
/// Despite the name, the limit is compared against `String::len`, which counts
/// UTF-8 bytes rather than characters.
const MAX_MERGE_CHARS: usize = 100_000;
16
/// Failure modes of the input queue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueueError {
    /// The queue already holds the maximum number of pending inputs.
    Full,
    /// A merge would exceed the size limit.
    ///
    /// NOTE(review): no code visible in this file constructs this variant.
    MergeLimitExceeded,
}

impl std::fmt::Display for QueueError {
    /// Writes a short, lowercase, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Full => "queue is full",
            Self::MergeLimitExceeded => "merge size limit exceeded",
        };
        f.write_str(message)
    }
}

impl std::error::Error for QueueError {}
33
34#[derive(Clone, Debug, Serialize, Deserialize)]
35pub struct QueuedInput {
36    pub id: Uuid,
37    pub session_id: SessionId,
38    pub content: String,
39    pub environment: Option<EnvironmentContext>,
40    pub created_at: DateTime<Utc>,
41}
42
43impl QueuedInput {
44    pub fn new(session_id: SessionId, content: impl Into<String>) -> Self {
45        Self {
46            id: Uuid::new_v4(),
47            session_id,
48            content: content.into(),
49            environment: None,
50            created_at: Utc::now(),
51        }
52    }
53
54    pub fn with_environment(mut self, env: EnvironmentContext) -> Self {
55        self.environment = Some(env);
56        self
57    }
58}
59
/// Several queued inputs combined into one, as produced by `InputQueue::merge_all`.
#[derive(Clone, Debug)]
pub struct MergedInput {
    /// Ids of the inputs that were merged, in queue order.
    pub ids: Vec<Uuid>,
    /// Contents of the merged inputs joined with `\n`.
    pub content: String,
    /// Environment of the last merged input that carried one, if any did.
    pub environment: Option<EnvironmentContext>,
}
66
67#[derive(Debug)]
68pub struct InputQueue {
69    items: VecDeque<QueuedInput>,
70}
71
72impl Default for InputQueue {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78impl InputQueue {
79    pub fn new() -> Self {
80        Self {
81            items: VecDeque::with_capacity(16),
82        }
83    }
84
85    pub fn enqueue(&mut self, input: QueuedInput) -> Result<Uuid, QueueError> {
86        if self.items.len() >= MAX_QUEUE_SIZE {
87            return Err(QueueError::Full);
88        }
89        let id = input.id;
90        self.items.push_back(input);
91        Ok(id)
92    }
93
94    pub fn cancel(&mut self, id: Uuid) -> Option<QueuedInput> {
95        self.items
96            .iter()
97            .position(|i| i.id == id)
98            .and_then(|pos| self.items.remove(pos))
99    }
100
101    pub fn cancel_all(&mut self) -> Vec<QueuedInput> {
102        self.items.drain(..).collect()
103    }
104
105    pub fn pending(&self) -> impl Iterator<Item = &QueuedInput> {
106        self.items.iter()
107    }
108
109    pub fn pending_count(&self) -> usize {
110        self.items.len()
111    }
112
113    pub fn is_empty(&self) -> bool {
114        self.items.is_empty()
115    }
116
117    pub fn merge_all(&mut self) -> Option<MergedInput> {
118        if self.items.is_empty() {
119            return None;
120        }
121
122        let mut ids = Vec::with_capacity(self.items.len());
123        let mut total_len = 0;
124        let mut contents = Vec::with_capacity(self.items.len());
125        let mut environment = None;
126
127        while let Some(item) = self.items.pop_front() {
128            let item_len = item.content.len();
129            if total_len + item_len > MAX_MERGE_CHARS && !contents.is_empty() {
130                self.items.push_front(item);
131                break;
132            }
133            ids.push(item.id);
134            total_len += item_len + 1;
135            environment = item.environment.or(environment);
136            contents.push(item.content);
137        }
138
139        let content = contents.join("\n");
140        Some(MergedInput {
141            ids,
142            content,
143            environment,
144        })
145    }
146
147    pub fn dequeue(&mut self) -> Option<QueuedInput> {
148        self.items.pop_front()
149    }
150}
151
152#[derive(Clone)]
153pub struct SharedInputQueue {
154    inner: Arc<RwLock<InputQueue>>,
155}
156
157impl SharedInputQueue {
158    pub fn new() -> Self {
159        Self {
160            inner: Arc::new(RwLock::new(InputQueue::new())),
161        }
162    }
163
164    pub async fn enqueue(&self, input: QueuedInput) -> Result<Uuid, QueueError> {
165        self.inner.write().await.enqueue(input)
166    }
167
168    pub async fn cancel(&self, id: Uuid) -> Option<QueuedInput> {
169        self.inner.write().await.cancel(id)
170    }
171
172    pub async fn cancel_all(&self) -> Vec<QueuedInput> {
173        self.inner.write().await.cancel_all()
174    }
175
176    pub async fn pending_count(&self) -> usize {
177        self.inner.read().await.pending_count()
178    }
179
180    pub async fn is_empty(&self) -> bool {
181        self.inner.read().await.is_empty()
182    }
183
184    pub async fn merge_all(&self) -> Option<MergedInput> {
185        self.inner.write().await.merge_all()
186    }
187
188    pub async fn dequeue(&self) -> Option<QueuedInput> {
189        self.inner.write().await.dequeue()
190    }
191
192    pub async fn pending_ids(&self) -> Vec<Uuid> {
193        self.inner.read().await.pending().map(|i| i.id).collect()
194    }
195}
196
197impl Default for SharedInputQueue {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl std::fmt::Debug for SharedInputQueue {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        f.debug_struct("SharedInputQueue").finish_non_exhaustive()
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a queue holding one input per entry of `contents`, all for the
    /// same fresh session, and returns it with the ids in enqueue order.
    fn queue_of(contents: &[&str]) -> (InputQueue, Vec<Uuid>) {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();
        let ids = contents
            .iter()
            .map(|text| queue.enqueue(QueuedInput::new(session_id, *text)).unwrap())
            .collect();
        (queue, ids)
    }

    // A single enqueued input comes back unchanged from `dequeue`.
    #[test]
    fn test_queue_enqueue_dequeue() {
        let (mut queue, ids) = queue_of(&["Hello"]);

        assert_eq!(queue.pending_count(), 1);

        let front = queue.dequeue().unwrap();
        assert_eq!(front.id, ids[0]);
        assert_eq!(front.content, "Hello");
        assert!(queue.is_empty());
    }

    // The queue accepts exactly `MAX_QUEUE_SIZE` inputs and rejects the next one.
    #[test]
    fn test_queue_size_limit() {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        for index in 0..MAX_QUEUE_SIZE {
            let accepted =
                queue.enqueue(QueuedInput::new(session_id, format!("Message {}", index)));
            assert!(accepted.is_ok());
        }

        let rejected = queue.enqueue(QueuedInput::new(session_id, "Overflow"));
        assert_eq!(rejected, Err(QueueError::Full));
    }

    // Cancelling by id removes only the matching input.
    #[test]
    fn test_queue_cancel() {
        let (mut queue, ids) = queue_of(&["First", "Second"]);

        assert_eq!(queue.pending_count(), 2);

        let cancelled = queue.cancel(ids[0]);
        assert!(cancelled.is_some());
        assert_eq!(cancelled.unwrap().content, "First");
        assert_eq!(queue.pending_count(), 1);
    }

    // Merging a single input yields its content unchanged.
    #[test]
    fn test_queue_merge_single() {
        let (mut queue, _ids) = queue_of(&["Only one"]);

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert_eq!(merged.content, "Only one");
        assert!(queue.is_empty());
    }

    // Multiple inputs are joined with newlines in queue order.
    #[test]
    fn test_queue_merge_multiple() {
        let (mut queue, _ids) = queue_of(&["First", "Second", "Third"]);

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 3);
        assert_eq!(merged.content, "First\nSecond\nThird");
        assert!(queue.is_empty());
    }

    // The environment of the most recent input wins.
    #[test]
    fn test_queue_merge_with_environment() {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        let inputs = [("First", "main"), ("Second", "feature")];
        for &(text, branch) in &inputs {
            let env = EnvironmentContext {
                git_branch: Some(branch.to_string()),
                ..Default::default()
            };
            queue
                .enqueue(QueuedInput::new(session_id, text).with_environment(env))
                .unwrap();
        }

        let merged = queue.merge_all().unwrap();
        assert_eq!(
            merged.environment.unwrap().git_branch,
            Some("feature".to_string())
        );
    }

    // Merging an empty queue yields nothing.
    #[test]
    fn test_queue_merge_empty() {
        let mut queue = InputQueue::new();
        assert!(queue.merge_all().is_none());
    }

    // `cancel_all` returns every pending input and leaves the queue empty.
    #[test]
    fn test_queue_cancel_all() {
        let (mut queue, _ids) = queue_of(&["First", "Second"]);

        let cancelled = queue.cancel_all();
        assert_eq!(cancelled.len(), 2);
        assert!(queue.is_empty());
    }

    // Two inputs that together exceed the merge limit are not merged; the
    // remainder stays queued.
    #[test]
    fn test_queue_merge_size_limit() {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        let oversized_half = "x".repeat(MAX_MERGE_CHARS / 2 + 1);
        queue
            .enqueue(QueuedInput::new(session_id, oversized_half.clone()))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, oversized_half))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Small"))
            .unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert!(!queue.is_empty());
        assert_eq!(queue.pending_count(), 2);
    }

    // The shared handle forwards enqueue/cancel to the underlying queue.
    #[tokio::test]
    async fn test_shared_queue() {
        let session_id = SessionId::new();
        let queue = SharedInputQueue::new();

        let id = queue
            .enqueue(QueuedInput::new(session_id, "Test"))
            .await
            .unwrap();
        assert_eq!(queue.pending_count().await, 1);

        let cancelled = queue.cancel(id).await;
        assert!(cancelled.is_some());
        assert!(queue.is_empty().await);
    }
}