claude_agent/session/
queue.rs1use std::collections::VecDeque;
4use std::sync::Arc;
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use super::state::SessionId;
12use super::types::EnvironmentContext;
13
14const MAX_QUEUE_SIZE: usize = 100;
15const MAX_MERGE_CHARS: usize = 100_000;
16
17#[derive(Debug, Clone, PartialEq, Eq)]
18pub enum QueueError {
19 Full,
20}
21
22impl std::fmt::Display for QueueError {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 match self {
25 Self::Full => write!(f, "queue is full"),
26 }
27 }
28}
29
30impl std::error::Error for QueueError {}
31
32#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct QueuedInput {
34 pub id: Uuid,
35 pub session_id: SessionId,
36 pub content: String,
37 pub environment: Option<EnvironmentContext>,
38 pub created_at: DateTime<Utc>,
39}
40
41impl QueuedInput {
42 pub fn new(session_id: SessionId, content: impl Into<String>) -> Self {
43 Self {
44 id: Uuid::new_v4(),
45 session_id,
46 content: content.into(),
47 environment: None,
48 created_at: Utc::now(),
49 }
50 }
51
52 pub fn environment(mut self, env: EnvironmentContext) -> Self {
53 self.environment = Some(env);
54 self
55 }
56}
57
58#[derive(Clone, Debug)]
59pub struct MergedInput {
60 pub ids: Vec<Uuid>,
61 pub content: String,
62 pub environment: Option<EnvironmentContext>,
63}
64
65#[derive(Debug)]
66pub struct InputQueue {
67 items: VecDeque<QueuedInput>,
68}
69
70impl Default for InputQueue {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl InputQueue {
77 pub fn new() -> Self {
78 Self {
79 items: VecDeque::with_capacity(16),
80 }
81 }
82
83 pub fn enqueue(&mut self, input: QueuedInput) -> Result<Uuid, QueueError> {
84 if self.items.len() >= MAX_QUEUE_SIZE {
88 return Err(QueueError::Full);
89 }
90 let id = input.id;
91 self.items.push_back(input);
92 Ok(id)
93 }
94
95 pub fn cancel(&mut self, id: Uuid) -> Option<QueuedInput> {
96 self.items
97 .iter()
98 .position(|i| i.id == id)
99 .and_then(|pos| self.items.remove(pos))
100 }
101
102 pub fn cancel_all(&mut self) -> Vec<QueuedInput> {
103 self.items.drain(..).collect()
104 }
105
106 pub fn pending(&self) -> impl Iterator<Item = &QueuedInput> {
107 self.items.iter()
108 }
109
110 pub fn pending_count(&self) -> usize {
111 self.items.len()
112 }
113
114 pub fn is_empty(&self) -> bool {
115 self.items.is_empty()
116 }
117
118 pub fn merge_all(&mut self) -> Option<MergedInput> {
119 if self.items.is_empty() {
120 return None;
121 }
122
123 let mut ids = Vec::with_capacity(self.items.len());
124 let mut total_len = 0;
125 let mut contents = Vec::with_capacity(self.items.len());
126 let mut environment = None;
127
128 while let Some(item) = self.items.pop_front() {
129 let item_len = item.content.len();
130 if total_len + item_len > MAX_MERGE_CHARS && !contents.is_empty() {
131 self.items.push_front(item);
132 break;
133 }
134 ids.push(item.id);
135 total_len += item_len + 1;
136 environment = item.environment.or(environment);
137 contents.push(item.content);
138 }
139
140 let content = contents.join("\n");
141 Some(MergedInput {
142 ids,
143 content,
144 environment,
145 })
146 }
147
148 pub fn dequeue(&mut self) -> Option<QueuedInput> {
149 self.items.pop_front()
150 }
151}
152
153#[derive(Clone)]
154pub struct SharedInputQueue {
155 inner: Arc<RwLock<InputQueue>>,
156}
157
158impl SharedInputQueue {
159 pub fn new() -> Self {
160 Self {
161 inner: Arc::new(RwLock::new(InputQueue::new())),
162 }
163 }
164
165 pub async fn enqueue(&self, input: QueuedInput) -> Result<Uuid, QueueError> {
166 self.inner.write().await.enqueue(input)
167 }
168
169 pub async fn cancel(&self, id: Uuid) -> Option<QueuedInput> {
170 self.inner.write().await.cancel(id)
171 }
172
173 pub async fn cancel_all(&self) -> Vec<QueuedInput> {
174 self.inner.write().await.cancel_all()
175 }
176
177 pub async fn pending_count(&self) -> usize {
178 self.inner.read().await.pending_count()
179 }
180
181 pub async fn is_empty(&self) -> bool {
182 self.inner.read().await.is_empty()
183 }
184
185 pub async fn merge_all(&self) -> Option<MergedInput> {
186 self.inner.write().await.merge_all()
187 }
188
189 pub async fn dequeue(&self) -> Option<QueuedInput> {
190 self.inner.write().await.dequeue()
191 }
192
193 pub async fn pending_ids(&self) -> Vec<Uuid> {
194 self.inner.read().await.pending().map(|i| i.id).collect()
195 }
196}
197
198impl Default for SharedInputQueue {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl std::fmt::Debug for SharedInputQueue {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("SharedInputQueue").finish_non_exhaustive()
207 }
208}
209
210#[cfg(test)]
211mod tests {
212 use super::*;
213
214 #[test]
215 fn test_queue_enqueue_dequeue() {
216 let mut queue = InputQueue::new();
217 let session_id = SessionId::new();
218
219 let input = QueuedInput::new(session_id, "Hello");
220 let id = queue.enqueue(input).unwrap();
221
222 assert_eq!(queue.pending_count(), 1);
223
224 let dequeued = queue.dequeue().unwrap();
225 assert_eq!(dequeued.id, id);
226 assert_eq!(dequeued.content, "Hello");
227 assert!(queue.is_empty());
228 }
229
230 #[test]
231 fn test_queue_size_limit() {
232 let mut queue = InputQueue::new();
233 let session_id = SessionId::new();
234
235 for i in 0..MAX_QUEUE_SIZE {
236 let input = QueuedInput::new(session_id, format!("Message {}", i));
237 assert!(queue.enqueue(input).is_ok());
238 }
239
240 let input = QueuedInput::new(session_id, "Overflow");
241 assert_eq!(queue.enqueue(input), Err(QueueError::Full));
242 }
243
244 #[test]
245 fn test_queue_cancel() {
246 let mut queue = InputQueue::new();
247 let session_id = SessionId::new();
248
249 let input1 = QueuedInput::new(session_id, "First");
250 let id1 = queue.enqueue(input1).unwrap();
251
252 let input2 = QueuedInput::new(session_id, "Second");
253 let _id2 = queue.enqueue(input2).unwrap();
254
255 assert_eq!(queue.pending_count(), 2);
256
257 let cancelled = queue.cancel(id1);
258 assert!(cancelled.is_some());
259 assert_eq!(cancelled.unwrap().content, "First");
260 assert_eq!(queue.pending_count(), 1);
261 }
262
263 #[test]
264 fn test_queue_merge_single() {
265 let mut queue = InputQueue::new();
266 let session_id = SessionId::new();
267
268 let input = QueuedInput::new(session_id, "Only one");
269 queue.enqueue(input).unwrap();
270
271 let merged = queue.merge_all().unwrap();
272 assert_eq!(merged.ids.len(), 1);
273 assert_eq!(merged.content, "Only one");
274 assert!(queue.is_empty());
275 }
276
277 #[test]
278 fn test_queue_merge_multiple() {
279 let mut queue = InputQueue::new();
280 let session_id = SessionId::new();
281
282 queue
283 .enqueue(QueuedInput::new(session_id, "First"))
284 .unwrap();
285 queue
286 .enqueue(QueuedInput::new(session_id, "Second"))
287 .unwrap();
288 queue
289 .enqueue(QueuedInput::new(session_id, "Third"))
290 .unwrap();
291
292 let merged = queue.merge_all().unwrap();
293 assert_eq!(merged.ids.len(), 3);
294 assert_eq!(merged.content, "First\nSecond\nThird");
295 assert!(queue.is_empty());
296 }
297
298 #[test]
299 fn test_queue_merge_with_environment() {
300 let mut queue = InputQueue::new();
301 let session_id = SessionId::new();
302
303 let env1 = EnvironmentContext {
304 git_branch: Some("main".to_string()),
305 ..Default::default()
306 };
307 let env2 = EnvironmentContext {
308 git_branch: Some("feature".to_string()),
309 ..Default::default()
310 };
311
312 queue
313 .enqueue(QueuedInput::new(session_id, "First").environment(env1))
314 .unwrap();
315 queue
316 .enqueue(QueuedInput::new(session_id, "Second").environment(env2))
317 .unwrap();
318
319 let merged = queue.merge_all().unwrap();
320 assert_eq!(
321 merged.environment.unwrap().git_branch,
322 Some("feature".to_string())
323 );
324 }
325
326 #[test]
327 fn test_queue_merge_empty() {
328 let mut queue = InputQueue::new();
329 assert!(queue.merge_all().is_none());
330 }
331
332 #[test]
333 fn test_queue_cancel_all() {
334 let mut queue = InputQueue::new();
335 let session_id = SessionId::new();
336
337 queue
338 .enqueue(QueuedInput::new(session_id, "First"))
339 .unwrap();
340 queue
341 .enqueue(QueuedInput::new(session_id, "Second"))
342 .unwrap();
343
344 let cancelled = queue.cancel_all();
345 assert_eq!(cancelled.len(), 2);
346 assert!(queue.is_empty());
347 }
348
349 #[test]
350 fn test_queue_merge_size_limit() {
351 let mut queue = InputQueue::new();
352 let session_id = SessionId::new();
353
354 let large_content = "x".repeat(MAX_MERGE_CHARS / 2 + 1);
355 queue
356 .enqueue(QueuedInput::new(session_id, large_content.clone()))
357 .unwrap();
358 queue
359 .enqueue(QueuedInput::new(session_id, large_content.clone()))
360 .unwrap();
361 queue
362 .enqueue(QueuedInput::new(session_id, "Small"))
363 .unwrap();
364
365 let merged = queue.merge_all().unwrap();
366 assert_eq!(merged.ids.len(), 1);
367 assert!(!queue.is_empty());
368 assert_eq!(queue.pending_count(), 2);
369 }
370
371 #[tokio::test]
372 async fn test_shared_queue() {
373 let queue = SharedInputQueue::new();
374 let session_id = SessionId::new();
375
376 let id = queue
377 .enqueue(QueuedInput::new(session_id, "Test"))
378 .await
379 .unwrap();
380 assert_eq!(queue.pending_count().await, 1);
381
382 let cancelled = queue.cancel(id).await;
383 assert!(cancelled.is_some());
384 assert!(queue.is_empty().await);
385 }
386}