claude_agent/session/
queue.rs1use std::collections::VecDeque;
4use std::sync::Arc;
5
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use super::state::SessionId;
12use super::types::EnvironmentContext;
13
/// Maximum number of items the queue may hold at once; `InputQueue::enqueue`
/// returns `QueueError::Full` once this many items are pending.
const MAX_QUEUE_SIZE: usize = 100;
/// Size budget for one batch produced by `InputQueue::merge_all`.
///
/// Despite the name, the budget is compared against `String::len()`, so it is
/// measured in UTF-8 bytes rather than in `char`s.
const MAX_MERGE_CHARS: usize = 100_000;
16
/// Failure modes of the input queue.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueueError {
    /// The queue cannot accept another item.
    Full,
    /// A merge would exceed the size limit.
    ///
    /// NOTE(review): no code visible in this file constructs this variant.
    MergeLimitExceeded,
}

impl std::fmt::Display for QueueError {
    /// Writes the fixed, human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Full => "queue is full",
            Self::MergeLimitExceeded => "merge size limit exceeded",
        };
        f.write_str(message)
    }
}

impl std::error::Error for QueueError {}
33
/// One input item held by `InputQueue`, tagged with the session it belongs to.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct QueuedInput {
    /// Identifier of this item; `QueuedInput::new` assigns a fresh v4 UUID.
    pub id: Uuid,
    /// Session the input was created for.
    pub session_id: SessionId,
    /// Text payload of the input.
    pub content: String,
    /// Optional environment context; `None` unless set via `with_environment`.
    pub environment: Option<EnvironmentContext>,
    /// UTC timestamp taken when the item was constructed by `QueuedInput::new`.
    pub created_at: DateTime<Utc>,
}
42
43impl QueuedInput {
44 pub fn new(session_id: SessionId, content: impl Into<String>) -> Self {
45 Self {
46 id: Uuid::new_v4(),
47 session_id,
48 content: content.into(),
49 environment: None,
50 created_at: Utc::now(),
51 }
52 }
53
54 pub fn with_environment(mut self, env: EnvironmentContext) -> Self {
55 self.environment = Some(env);
56 self
57 }
58}
59
/// Result of `InputQueue::merge_all`: a batch of queued inputs combined into one.
#[derive(Clone, Debug)]
pub struct MergedInput {
    /// Ids of the items that went into the batch, in queue order.
    pub ids: Vec<Uuid>,
    /// Contents of the batched items joined with `"\n"`.
    pub content: String,
    /// Environment of the last batched item that carried one, if any did.
    pub environment: Option<EnvironmentContext>,
}
66
/// First-in, first-out queue of pending inputs.
///
/// Items are appended at the back by `enqueue` and taken from the front by
/// `dequeue` / `merge_all`.
#[derive(Debug)]
pub struct InputQueue {
    /// Pending items; the front is the oldest.
    items: VecDeque<QueuedInput>,
}
71
72impl Default for InputQueue {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78impl InputQueue {
79 pub fn new() -> Self {
80 Self {
81 items: VecDeque::with_capacity(16),
82 }
83 }
84
85 pub fn enqueue(&mut self, input: QueuedInput) -> Result<Uuid, QueueError> {
86 if self.items.len() >= MAX_QUEUE_SIZE {
87 return Err(QueueError::Full);
88 }
89 let id = input.id;
90 self.items.push_back(input);
91 Ok(id)
92 }
93
94 pub fn cancel(&mut self, id: Uuid) -> Option<QueuedInput> {
95 self.items
96 .iter()
97 .position(|i| i.id == id)
98 .and_then(|pos| self.items.remove(pos))
99 }
100
101 pub fn cancel_all(&mut self) -> Vec<QueuedInput> {
102 self.items.drain(..).collect()
103 }
104
105 pub fn pending(&self) -> impl Iterator<Item = &QueuedInput> {
106 self.items.iter()
107 }
108
109 pub fn pending_count(&self) -> usize {
110 self.items.len()
111 }
112
113 pub fn is_empty(&self) -> bool {
114 self.items.is_empty()
115 }
116
117 pub fn merge_all(&mut self) -> Option<MergedInput> {
118 if self.items.is_empty() {
119 return None;
120 }
121
122 let mut ids = Vec::with_capacity(self.items.len());
123 let mut total_len = 0;
124 let mut contents = Vec::with_capacity(self.items.len());
125 let mut environment = None;
126
127 while let Some(item) = self.items.pop_front() {
128 let item_len = item.content.len();
129 if total_len + item_len > MAX_MERGE_CHARS && !contents.is_empty() {
130 self.items.push_front(item);
131 break;
132 }
133 ids.push(item.id);
134 total_len += item_len + 1;
135 environment = item.environment.or(environment);
136 contents.push(item.content);
137 }
138
139 let content = contents.join("\n");
140 Some(MergedInput {
141 ids,
142 content,
143 environment,
144 })
145 }
146
147 pub fn dequeue(&mut self) -> Option<QueuedInput> {
148 self.items.pop_front()
149 }
150}
151
/// Cloneable handle to an `InputQueue` kept behind `Arc<RwLock<_>>`.
///
/// Cloning the handle clones the `Arc`, so all clones operate on the same
/// underlying queue.
#[derive(Clone)]
pub struct SharedInputQueue {
    /// The shared queue, guarded by Tokio's async read/write lock.
    inner: Arc<RwLock<InputQueue>>,
}
156
157impl SharedInputQueue {
158 pub fn new() -> Self {
159 Self {
160 inner: Arc::new(RwLock::new(InputQueue::new())),
161 }
162 }
163
164 pub async fn enqueue(&self, input: QueuedInput) -> Result<Uuid, QueueError> {
165 self.inner.write().await.enqueue(input)
166 }
167
168 pub async fn cancel(&self, id: Uuid) -> Option<QueuedInput> {
169 self.inner.write().await.cancel(id)
170 }
171
172 pub async fn cancel_all(&self) -> Vec<QueuedInput> {
173 self.inner.write().await.cancel_all()
174 }
175
176 pub async fn pending_count(&self) -> usize {
177 self.inner.read().await.pending_count()
178 }
179
180 pub async fn is_empty(&self) -> bool {
181 self.inner.read().await.is_empty()
182 }
183
184 pub async fn merge_all(&self) -> Option<MergedInput> {
185 self.inner.write().await.merge_all()
186 }
187
188 pub async fn dequeue(&self) -> Option<QueuedInput> {
189 self.inner.write().await.dequeue()
190 }
191
192 pub async fn pending_ids(&self) -> Vec<Uuid> {
193 self.inner.read().await.pending().map(|i| i.id).collect()
194 }
195}
196
197impl Default for SharedInputQueue {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl std::fmt::Debug for SharedInputQueue {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 f.debug_struct("SharedInputQueue").finish_non_exhaustive()
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a queue holding one item per entry of `contents`, in order.
    fn queue_of(session_id: SessionId, contents: &[&str]) -> InputQueue {
        let mut queue = InputQueue::new();
        for text in contents {
            queue.enqueue(QueuedInput::new(session_id, *text)).unwrap();
        }
        queue
    }

    /// An enqueued item comes back from `dequeue` with the same id and content.
    #[test]
    fn test_queue_enqueue_dequeue() {
        let mut queue = InputQueue::new();
        let id = queue
            .enqueue(QueuedInput::new(SessionId::new(), "Hello"))
            .unwrap();
        assert_eq!(queue.pending_count(), 1);

        let front = queue.dequeue().unwrap();
        assert_eq!(front.id, id);
        assert_eq!(front.content, "Hello");
        assert!(queue.is_empty());
    }

    /// The queue accepts exactly `MAX_QUEUE_SIZE` items and rejects the next.
    #[test]
    fn test_queue_size_limit() {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        for index in 0..MAX_QUEUE_SIZE {
            let accepted = queue.enqueue(QueuedInput::new(session_id, format!("Message {}", index)));
            assert!(accepted.is_ok());
        }

        let rejected = queue.enqueue(QueuedInput::new(session_id, "Overflow"));
        assert_eq!(rejected, Err(QueueError::Full));
    }

    /// Cancelling by id removes only the matching item.
    #[test]
    fn test_queue_cancel() {
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        let first_id = queue
            .enqueue(QueuedInput::new(session_id, "First"))
            .unwrap();
        let _second_id = queue
            .enqueue(QueuedInput::new(session_id, "Second"))
            .unwrap();
        assert_eq!(queue.pending_count(), 2);

        let removed = queue.cancel(first_id);
        assert!(removed.is_some());
        assert_eq!(removed.map(|item| item.content).as_deref(), Some("First"));
        assert_eq!(queue.pending_count(), 1);
    }

    /// Merging a single item yields its content unchanged.
    #[test]
    fn test_queue_merge_single() {
        let mut queue = queue_of(SessionId::new(), &["Only one"]);

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert_eq!(merged.content, "Only one");
        assert!(queue.is_empty());
    }

    /// Several items are joined with newlines in queue order.
    #[test]
    fn test_queue_merge_multiple() {
        let mut queue = queue_of(SessionId::new(), &["First", "Second", "Third"]);

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 3);
        assert_eq!(merged.content, "First\nSecond\nThird");
        assert!(queue.is_empty());
    }

    /// The environment of the most recent item wins when merging.
    #[test]
    fn test_queue_merge_with_environment() {
        let env_on_branch = |branch: &str| EnvironmentContext {
            git_branch: Some(branch.to_string()),
            ..Default::default()
        };
        let session_id = SessionId::new();
        let mut queue = InputQueue::new();

        queue
            .enqueue(QueuedInput::new(session_id, "First").with_environment(env_on_branch("main")))
            .unwrap();
        queue
            .enqueue(
                QueuedInput::new(session_id, "Second").with_environment(env_on_branch("feature")),
            )
            .unwrap();

        let merged = queue.merge_all().unwrap();
        let environment = merged.environment.unwrap();
        assert_eq!(environment.git_branch, Some("feature".to_string()));
    }

    /// Merging an empty queue produces nothing.
    #[test]
    fn test_queue_merge_empty() {
        let mut queue = InputQueue::new();
        assert!(queue.merge_all().is_none());
    }

    /// `cancel_all` returns every item and leaves the queue empty.
    #[test]
    fn test_queue_cancel_all() {
        let mut queue = queue_of(SessionId::new(), &["First", "Second"]);

        let removed = queue.cancel_all();
        assert_eq!(removed.len(), 2);
        assert!(queue.is_empty());
    }

    /// Two items that together exceed the budget are not merged into one
    /// batch; everything after the first stays queued.
    #[test]
    fn test_queue_merge_size_limit() {
        let large_content = "x".repeat(MAX_MERGE_CHARS / 2 + 1);
        let mut queue = queue_of(
            SessionId::new(),
            &[large_content.as_str(), large_content.as_str(), "Small"],
        );

        let merged = queue.merge_all().unwrap();
        assert_eq!(merged.ids.len(), 1);
        assert!(!queue.is_empty());
        assert_eq!(queue.pending_count(), 2);
    }

    /// The shared handle forwards enqueue/cancel to the underlying queue.
    #[tokio::test]
    async fn test_shared_queue() {
        let shared = SharedInputQueue::new();

        let id = shared
            .enqueue(QueuedInput::new(SessionId::new(), "Test"))
            .await
            .unwrap();
        assert_eq!(shared.pending_count().await, 1);

        let removed = shared.cancel(id).await;
        assert!(removed.is_some());
        assert!(shared.is_empty().await);
    }
}