1use anyhow::Result;
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::time::{Duration, Instant};
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32pub enum SessionLane {
33 Control,
35 Query,
37 Execute,
39 Generate,
41}
42
43impl SessionLane {
44 pub fn priority(&self) -> u8 {
46 match self {
47 SessionLane::Control => 0,
48 SessionLane::Query => 1,
49 SessionLane::Execute => 2,
50 SessionLane::Generate => 3,
51 }
52 }
53
54 pub fn from_tool_name(tool_name: &str) -> Self {
56 match tool_name {
57 "read" | "glob" | "ls" | "grep" | "list_files" | "search" => SessionLane::Query,
58 "bash" | "write" | "edit" | "delete" | "move" | "copy" | "execute" => {
59 SessionLane::Execute
60 }
61 _ => SessionLane::Execute,
62 }
63 }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
72pub enum TaskHandlerMode {
73 #[default]
75 Internal,
76 External,
78 Hybrid,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct LaneHandlerConfig {
85 pub mode: TaskHandlerMode,
87 pub timeout_ms: u64,
89}
90
91impl Default for LaneHandlerConfig {
92 fn default() -> Self {
93 Self {
94 mode: TaskHandlerMode::Internal,
95 timeout_ms: 60_000,
96 }
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ExternalTask {
107 pub task_id: String,
109 pub session_id: String,
111 pub lane: SessionLane,
113 pub command_type: String,
115 pub payload: serde_json::Value,
117 pub timeout_ms: u64,
119 #[serde(skip)]
121 pub created_at: Option<Instant>,
122}
123
124impl ExternalTask {
125 pub fn is_timed_out(&self) -> bool {
127 self.created_at
128 .map(|t| t.elapsed() > Duration::from_millis(self.timeout_ms))
129 .unwrap_or(false)
130 }
131
132 pub fn remaining_ms(&self) -> u64 {
134 self.created_at
135 .map(|t| {
136 let elapsed = t.elapsed().as_millis() as u64;
137 self.timeout_ms.saturating_sub(elapsed)
138 })
139 .unwrap_or(self.timeout_ms)
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ExternalTaskResult {
146 pub success: bool,
148 pub result: serde_json::Value,
150 pub error: Option<String>,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct SessionQueueConfig {
161 pub control_max_concurrency: usize,
163 pub query_max_concurrency: usize,
165 pub execute_max_concurrency: usize,
167 pub generate_max_concurrency: usize,
169 #[serde(default)]
171 pub lane_handlers: HashMap<SessionLane, LaneHandlerConfig>,
172
173 #[serde(default)]
178 pub enable_dlq: bool,
179 #[serde(default)]
181 pub dlq_max_size: Option<usize>,
182 #[serde(default)]
184 pub enable_metrics: bool,
185 #[serde(default)]
187 pub enable_alerts: bool,
188 #[serde(default)]
190 pub default_timeout_ms: Option<u64>,
191 #[serde(default)]
193 pub storage_path: Option<std::path::PathBuf>,
194}
195
196impl Default for SessionQueueConfig {
197 fn default() -> Self {
198 Self {
199 control_max_concurrency: 2,
200 query_max_concurrency: 4,
201 execute_max_concurrency: 2,
202 generate_max_concurrency: 1,
203 lane_handlers: HashMap::new(),
204 enable_dlq: false,
205 dlq_max_size: None,
206 enable_metrics: false,
207 enable_alerts: false,
208 default_timeout_ms: None,
209 storage_path: None,
210 }
211 }
212}
213
214impl SessionQueueConfig {
215 pub fn max_concurrency(&self, lane: SessionLane) -> usize {
217 match lane {
218 SessionLane::Control => self.control_max_concurrency,
219 SessionLane::Query => self.query_max_concurrency,
220 SessionLane::Execute => self.execute_max_concurrency,
221 SessionLane::Generate => self.generate_max_concurrency,
222 }
223 }
224
225 pub fn handler_config(&self, lane: SessionLane) -> LaneHandlerConfig {
227 self.lane_handlers.get(&lane).cloned().unwrap_or_default()
228 }
229
230 pub fn with_dlq(mut self, max_size: Option<usize>) -> Self {
232 self.enable_dlq = true;
233 self.dlq_max_size = max_size;
234 self
235 }
236
237 pub fn with_metrics(mut self) -> Self {
239 self.enable_metrics = true;
240 self
241 }
242
243 pub fn with_alerts(mut self) -> Self {
245 self.enable_alerts = true;
246 self
247 }
248
249 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
251 self.default_timeout_ms = Some(timeout_ms);
252 self
253 }
254
255 pub fn with_storage(mut self, path: impl Into<std::path::PathBuf>) -> Self {
257 self.storage_path = Some(path.into());
258 self
259 }
260
261 pub fn with_lane_features(mut self) -> Self {
263 self.enable_dlq = true;
264 self.dlq_max_size = Some(1000);
265 self.enable_metrics = true;
266 self.enable_alerts = true;
267 self.default_timeout_ms = Some(60_000);
268 self
269 }
270}
271
272#[async_trait]
278pub trait SessionCommand: Send + Sync {
279 async fn execute(&self) -> Result<serde_json::Value>;
281
282 fn command_type(&self) -> &str;
284
285 fn payload(&self) -> serde_json::Value {
287 serde_json::json!({})
288 }
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct LaneStatus {
298 pub lane: SessionLane,
299 pub pending: usize,
300 pub active: usize,
301 pub max_concurrency: usize,
302 pub handler_mode: TaskHandlerMode,
303}
304
305#[derive(Debug, Clone, Default, Serialize, Deserialize)]
307pub struct SessionQueueStats {
308 pub total_pending: usize,
309 pub total_active: usize,
310 pub external_pending: usize,
311 pub lanes: HashMap<String, LaneStatus>,
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Query-lane `ExternalTask` with the given timing fields.
    fn external_task(timeout_ms: u64, created_at: Option<Instant>) -> ExternalTask {
        ExternalTask {
            task_id: "test".to_string(),
            session_id: "session".to_string(),
            lane: SessionLane::Query,
            command_type: "read".to_string(),
            payload: serde_json::json!({}),
            timeout_ms,
            created_at,
        }
    }

    #[test]
    fn test_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);
    }

    #[test]
    fn test_lane_from_tool_name() {
        for tool in ["read", "glob", "ls", "grep", "list_files", "search"] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Query, "tool {:?}", tool);
        }
        // Mutating tools and unknown names both land on Execute.
        for tool in ["bash", "write", "edit", "delete", "move", "copy", "execute", "unknown", ""] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Execute, "tool {:?}", tool);
        }
    }

    #[test]
    fn test_task_handler_mode_default() {
        let mode = TaskHandlerMode::default();
        assert_eq!(mode, TaskHandlerMode::Internal);
    }

    #[test]
    fn test_lane_handler_config_default() {
        let config = LaneHandlerConfig::default();
        assert_eq!(config.mode, TaskHandlerMode::Internal);
        assert_eq!(config.timeout_ms, 60_000);
    }

    #[test]
    fn test_external_task_timeout() {
        // A generous timeout keeps this from flaking when the test thread is
        // descheduled on a loaded CI machine (the old 100 ms budget could expire).
        let task = external_task(60_000, Some(Instant::now()));

        assert!(!task.is_timed_out());
        let remaining = task.remaining_ms();
        assert!(remaining > 0);
        assert!(remaining <= 60_000);
    }

    #[test]
    fn test_external_task_without_start_time() {
        let task = external_task(100, None);
        assert!(!task.is_timed_out());
        assert_eq!(task.remaining_ms(), 100);
    }

    #[test]
    fn test_external_task_expired() {
        // checked_sub can return None if the monotonic clock origin is very recent.
        if let Some(started) = Instant::now().checked_sub(Duration::from_millis(200)) {
            let task = external_task(100, Some(started));
            assert!(task.is_timed_out());
            assert_eq!(task.remaining_ms(), 0);
        }
    }

    #[test]
    fn test_external_task_created_at_not_serialized() {
        let task = external_task(100, Some(Instant::now()));
        let json = serde_json::to_value(&task).unwrap();
        assert!(json.get("created_at").is_none());

        let back: ExternalTask = serde_json::from_value(json).unwrap();
        assert!(back.created_at.is_none());
        assert_eq!(back.timeout_ms, 100);
    }

    #[test]
    fn test_session_queue_config_default() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.control_max_concurrency, 2);
        assert_eq!(config.query_max_concurrency, 4);
        assert_eq!(config.execute_max_concurrency, 2);
        assert_eq!(config.generate_max_concurrency, 1);
        assert!(!config.enable_dlq);
        assert!(!config.enable_metrics);
        assert!(!config.enable_alerts);
    }

    #[test]
    fn test_session_queue_config_max_concurrency() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.max_concurrency(SessionLane::Control), 2);
        assert_eq!(config.max_concurrency(SessionLane::Query), 4);
        assert_eq!(config.max_concurrency(SessionLane::Execute), 2);
        assert_eq!(config.max_concurrency(SessionLane::Generate), 1);
    }

    #[test]
    fn test_session_queue_config_handler_config() {
        let config = SessionQueueConfig::default();
        let handler = config.handler_config(SessionLane::Execute);
        assert_eq!(handler.mode, TaskHandlerMode::Internal);
        assert_eq!(handler.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_queue_config_handler_config_override() {
        let mut config = SessionQueueConfig::default();
        config.lane_handlers.insert(
            SessionLane::Generate,
            LaneHandlerConfig {
                mode: TaskHandlerMode::External,
                timeout_ms: 5_000,
            },
        );

        let handler = config.handler_config(SessionLane::Generate);
        assert_eq!(handler.mode, TaskHandlerMode::External);
        assert_eq!(handler.timeout_ms, 5_000);

        // Lanes without an override still fall back to the default.
        let fallback = config.handler_config(SessionLane::Query);
        assert_eq!(fallback.mode, TaskHandlerMode::Internal);
        assert_eq!(fallback.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_queue_config_builders() {
        let config = SessionQueueConfig::default()
            .with_dlq(Some(500))
            .with_metrics()
            .with_alerts()
            .with_timeout(30_000)
            .with_storage("/tmp/session-queue");

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(500));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(30_000));
        assert_eq!(
            config.storage_path,
            Some(std::path::PathBuf::from("/tmp/session-queue"))
        );
    }

    #[test]
    fn test_session_queue_config_with_lane_features() {
        let config = SessionQueueConfig::default().with_lane_features();

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(1000));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(60_000));
    }

    #[test]
    fn test_external_task_result() {
        let result = ExternalTaskResult {
            success: true,
            result: serde_json::json!({"output": "hello"}),
            error: None,
        };
        assert!(result.success);
        assert!(result.error.is_none());
    }
}