Skip to main content

claude_agent/session/
queue.rs

1//! Input queue for handling concurrent user inputs.
2
3use std::collections::VecDeque;
4use std::sync::Arc;
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use super::state::SessionId;
12use super::types::EnvironmentContext;
13
/// Maximum number of inputs an `InputQueue` holds at once; `enqueue` rejects
/// further inputs with `QueueError::Full` once this many are pending.
const MAX_QUEUE_SIZE: usize = 100;

/// Size budget for a single `merge_all` result.
///
/// Compared against `String::len()` of each queued input (UTF-8 bytes, not
/// characters) plus one byte per `'\n'` separator. A single input larger than this
/// budget is still merged on its own.
const MAX_MERGE_CHARS: usize = 100_000;
16
/// Error returned when an input cannot be added to the queue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueueError {
    /// The queue already holds the maximum number of pending inputs.
    Full,
}

impl std::fmt::Display for QueueError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            QueueError::Full => "queue is full",
        };
        formatter.write_str(message)
    }
}

impl std::error::Error for QueueError {}
31
32#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct QueuedInput {
34    pub id: Uuid,
35    pub session_id: SessionId,
36    pub content: String,
37    pub environment: Option<EnvironmentContext>,
38    pub created_at: DateTime<Utc>,
39}
40
41impl QueuedInput {
42    pub fn new(session_id: SessionId, content: impl Into<String>) -> Self {
43        Self {
44            id: Uuid::new_v4(),
45            session_id,
46            content: content.into(),
47            environment: None,
48            created_at: Utc::now(),
49        }
50    }
51
52    pub fn environment(mut self, env: EnvironmentContext) -> Self {
53        self.environment = Some(env);
54        self
55    }
56}
57
58#[derive(Clone, Debug)]
59pub struct MergedInput {
60    pub ids: Vec<Uuid>,
61    pub content: String,
62    pub environment: Option<EnvironmentContext>,
63}
64
65#[derive(Debug)]
66pub struct InputQueue {
67    items: VecDeque<QueuedInput>,
68}
69
70impl Default for InputQueue {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl InputQueue {
77    pub fn new() -> Self {
78        Self {
79            items: VecDeque::with_capacity(16),
80        }
81    }
82
83    pub fn enqueue(&mut self, input: QueuedInput) -> Result<Uuid, QueueError> {
84        // A caller may read `pending_count()` (read lock) then call `enqueue()` (write lock).
85        // Between those calls another enqueue could succeed, so the queue could briefly
86        // hold MAX_QUEUE_SIZE + 1 items. This is bounded and harmless.
87        if self.items.len() >= MAX_QUEUE_SIZE {
88            return Err(QueueError::Full);
89        }
90        let id = input.id;
91        self.items.push_back(input);
92        Ok(id)
93    }
94
95    pub fn cancel(&mut self, id: Uuid) -> Option<QueuedInput> {
96        self.items
97            .iter()
98            .position(|i| i.id == id)
99            .and_then(|pos| self.items.remove(pos))
100    }
101
102    pub fn cancel_all(&mut self) -> Vec<QueuedInput> {
103        self.items.drain(..).collect()
104    }
105
106    pub fn pending(&self) -> impl Iterator<Item = &QueuedInput> {
107        self.items.iter()
108    }
109
110    pub fn pending_count(&self) -> usize {
111        self.items.len()
112    }
113
114    pub fn is_empty(&self) -> bool {
115        self.items.is_empty()
116    }
117
118    pub fn merge_all(&mut self) -> Option<MergedInput> {
119        if self.items.is_empty() {
120            return None;
121        }
122
123        let mut ids = Vec::with_capacity(self.items.len());
124        let mut total_len = 0;
125        let mut contents = Vec::with_capacity(self.items.len());
126        let mut environment = None;
127
128        while let Some(item) = self.items.pop_front() {
129            let item_len = item.content.len();
130            if total_len + item_len > MAX_MERGE_CHARS && !contents.is_empty() {
131                self.items.push_front(item);
132                break;
133            }
134            ids.push(item.id);
135            total_len += item_len + 1;
136            environment = item.environment.or(environment);
137            contents.push(item.content);
138        }
139
140        let content = contents.join("\n");
141        Some(MergedInput {
142            ids,
143            content,
144            environment,
145        })
146    }
147
148    pub fn dequeue(&mut self) -> Option<QueuedInput> {
149        self.items.pop_front()
150    }
151}
152
153#[derive(Clone)]
154pub struct SharedInputQueue {
155    inner: Arc<RwLock<InputQueue>>,
156}
157
158impl SharedInputQueue {
159    pub fn new() -> Self {
160        Self {
161            inner: Arc::new(RwLock::new(InputQueue::new())),
162        }
163    }
164
165    pub async fn enqueue(&self, input: QueuedInput) -> Result<Uuid, QueueError> {
166        self.inner.write().await.enqueue(input)
167    }
168
169    pub async fn cancel(&self, id: Uuid) -> Option<QueuedInput> {
170        self.inner.write().await.cancel(id)
171    }
172
173    pub async fn cancel_all(&self) -> Vec<QueuedInput> {
174        self.inner.write().await.cancel_all()
175    }
176
177    pub async fn pending_count(&self) -> usize {
178        self.inner.read().await.pending_count()
179    }
180
181    pub async fn is_empty(&self) -> bool {
182        self.inner.read().await.is_empty()
183    }
184
185    pub async fn merge_all(&self) -> Option<MergedInput> {
186        self.inner.write().await.merge_all()
187    }
188
189    pub async fn dequeue(&self) -> Option<QueuedInput> {
190        self.inner.write().await.dequeue()
191    }
192
193    pub async fn pending_ids(&self) -> Vec<Uuid> {
194        self.inner.read().await.pending().map(|i| i.id).collect()
195    }
196}
197
198impl Default for SharedInputQueue {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204impl std::fmt::Debug for SharedInputQueue {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("SharedInputQueue").finish_non_exhaustive()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_queue_enqueue_dequeue() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let input = QueuedInput::new(session_id, "Hello");
        let id = queue.enqueue(input).unwrap();

        assert_eq!(queue.pending_count(), 1);

        let dequeued = queue.dequeue().unwrap();
        assert_eq!(dequeued.id, id);
        assert_eq!(dequeued.content, "Hello");
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_dequeue_is_fifo() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let id1 = queue
            .enqueue(QueuedInput::new(session_id, "First"))
            .unwrap();
        let id2 = queue
            .enqueue(QueuedInput::new(session_id, "Second"))
            .unwrap();

        assert_eq!(queue.dequeue().unwrap().id, id1);
        assert_eq!(queue.dequeue().unwrap().id, id2);
        assert!(queue.dequeue().is_none());
    }

    #[test]
    fn test_queue_size_limit() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        for i in 0..MAX_QUEUE_SIZE {
            let input = QueuedInput::new(session_id, format!("Message {}", i));
            assert!(queue.enqueue(input).is_ok());
        }

        let input = QueuedInput::new(session_id, "Overflow");
        assert_eq!(queue.enqueue(input), Err(QueueError::Full));
        // A rejected enqueue must not grow the queue past the limit.
        assert_eq!(queue.pending_count(), MAX_QUEUE_SIZE);
    }

    #[test]
    fn test_queue_cancel() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let input1 = QueuedInput::new(session_id, "First");
        let id1 = queue.enqueue(input1).unwrap();

        let input2 = QueuedInput::new(session_id, "Second");
        let _id2 = queue.enqueue(input2).unwrap();

        assert_eq!(queue.pending_count(), 2);

        let cancelled = queue.cancel(id1);
        assert!(cancelled.is_some());
        assert_eq!(cancelled.unwrap().content, "First");
        assert_eq!(queue.pending_count(), 1);
    }

    #[test]
    fn test_queue_cancel_unknown_id() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        queue
            .enqueue(QueuedInput::new(session_id, "First"))
            .unwrap();

        assert!(queue.cancel(Uuid::new_v4()).is_none());
        assert_eq!(queue.pending_count(), 1);
    }

    #[test]
    fn test_queue_merge_single() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let input = QueuedInput::new(session_id, "Only one");
        queue.enqueue(input).unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert_eq!(merged.content, "Only one");
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_merge_multiple() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        queue
            .enqueue(QueuedInput::new(session_id, "First"))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Second"))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Third"))
            .unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 3);
        assert_eq!(merged.content, "First\nSecond\nThird");
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_merge_preserves_id_order() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let id1 = queue.enqueue(QueuedInput::new(session_id, "A")).unwrap();
        let id2 = queue.enqueue(QueuedInput::new(session_id, "B")).unwrap();
        let id3 = queue.enqueue(QueuedInput::new(session_id, "C")).unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids, vec![id1, id2, id3]);
    }

    #[test]
    fn test_queue_merge_with_environment() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let env1 = EnvironmentContext {
            git_branch: Some("main".to_string()),
            ..Default::default()
        };
        let env2 = EnvironmentContext {
            git_branch: Some("feature".to_string()),
            ..Default::default()
        };

        queue
            .enqueue(QueuedInput::new(session_id, "First").environment(env1))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Second").environment(env2))
            .unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(
            merged.environment.unwrap().git_branch,
            Some("feature".to_string())
        );
    }

    #[test]
    fn test_queue_merge_keeps_earlier_environment() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let env = EnvironmentContext {
            git_branch: Some("main".to_string()),
            ..Default::default()
        };

        queue
            .enqueue(QueuedInput::new(session_id, "First").environment(env))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Second"))
            .unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(
            merged.environment.unwrap().git_branch,
            Some("main".to_string())
        );
    }

    #[test]
    fn test_queue_merge_empty() {
        let mut queue = InputQueue::new();
        assert!(queue.merge_all().is_none());
    }

    #[test]
    fn test_queue_cancel_all() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        queue
            .enqueue(QueuedInput::new(session_id, "First"))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Second"))
            .unwrap();

        let cancelled = queue.cancel_all();
        assert_eq!(cancelled.len(), 2);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_merge_size_limit() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let large_content = "x".repeat(MAX_MERGE_CHARS / 2 + 1);
        queue
            .enqueue(QueuedInput::new(session_id, large_content.clone()))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, large_content.clone()))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Small"))
            .unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert!(!queue.is_empty());
        assert_eq!(queue.pending_count(), 2);
    }

    #[test]
    fn test_queue_merge_exactly_at_limit() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        // Joined length = a + 1 (separator) + b == MAX_MERGE_CHARS.
        let a = "a".repeat(MAX_MERGE_CHARS / 2);
        let b = "b".repeat(MAX_MERGE_CHARS - a.len() - 1);
        queue.enqueue(QueuedInput::new(session_id, a)).unwrap();
        queue.enqueue(QueuedInput::new(session_id, b)).unwrap();

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 2);
        assert_eq!(merged.content.len(), MAX_MERGE_CHARS);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_queue_merge_oversized_single_item() {
        let mut queue = InputQueue::new();
        let session_id = SessionId::new();

        let huge = "y".repeat(MAX_MERGE_CHARS + 1);
        queue
            .enqueue(QueuedInput::new(session_id, huge.clone()))
            .unwrap();
        queue
            .enqueue(QueuedInput::new(session_id, "Next"))
            .unwrap();

        // The oversized input is still merged on its own so it cannot block the queue.
        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert_eq!(merged.content, huge);
        assert_eq!(queue.pending_count(), 1);
    }

    #[tokio::test]
    async fn test_shared_queue() {
        let queue = SharedInputQueue::new();
        let session_id = SessionId::new();

        let id = queue
            .enqueue(QueuedInput::new(session_id, "Test"))
            .await
            .unwrap();
        assert_eq!(queue.pending_count().await, 1);

        let cancelled = queue.cancel(id).await;
        assert!(cancelled.is_some());
        assert!(queue.is_empty().await);
    }

    #[tokio::test]
    async fn test_shared_queue_pending_ids_and_merge() {
        let queue = SharedInputQueue::new();
        let session_id = SessionId::new();

        let id1 = queue
            .enqueue(QueuedInput::new(session_id, "One"))
            .await
            .unwrap();
        let id2 = queue
            .enqueue(QueuedInput::new(session_id, "Two"))
            .await
            .unwrap();

        assert_eq!(queue.pending_ids().await, vec![id1, id2]);

        let merged = queue.merge_all().await.unwrap();
        assert_eq!(merged.ids, vec![id1, id2]);
        assert_eq!(merged.content, "One\nTwo");
        assert!(queue.is_empty().await);
    }
}