1use anyhow::Result;
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::time::{Duration, Instant};
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32pub enum SessionLane {
33 Control,
35 Query,
37 Execute,
39 Generate,
41}
42
43impl SessionLane {
44 pub fn priority(&self) -> u8 {
46 match self {
47 SessionLane::Control => 0,
48 SessionLane::Query => 1,
49 SessionLane::Execute => 2,
50 SessionLane::Generate => 3,
51 }
52 }
53
54 pub fn from_tool_name(tool_name: &str) -> Self {
56 match tool_name {
57 "read" | "glob" | "ls" | "grep" | "list_files" | "search" | "web_fetch"
58 | "web_search" => SessionLane::Query,
59 "bash" | "write" | "edit" | "delete" | "move" | "copy" | "execute" => {
60 SessionLane::Execute
61 }
62 _ => SessionLane::Execute,
63 }
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
73pub enum TaskHandlerMode {
74 #[default]
76 Internal,
77 External,
79 Hybrid,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct LaneHandlerConfig {
86 pub mode: TaskHandlerMode,
88 pub timeout_ms: u64,
90}
91
92impl Default for LaneHandlerConfig {
93 fn default() -> Self {
94 Self {
95 mode: TaskHandlerMode::Internal,
96 timeout_ms: 60_000,
97 }
98 }
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct ExternalTask {
108 pub task_id: String,
110 pub session_id: String,
112 pub lane: SessionLane,
114 pub command_type: String,
116 pub payload: serde_json::Value,
118 pub timeout_ms: u64,
120 #[serde(skip)]
122 pub created_at: Option<Instant>,
123}
124
125impl ExternalTask {
126 pub fn is_timed_out(&self) -> bool {
128 self.created_at
129 .map(|t| t.elapsed() > Duration::from_millis(self.timeout_ms))
130 .unwrap_or(false)
131 }
132
133 pub fn remaining_ms(&self) -> u64 {
135 self.created_at
136 .map(|t| {
137 let elapsed = t.elapsed().as_millis() as u64;
138 self.timeout_ms.saturating_sub(elapsed)
139 })
140 .unwrap_or(self.timeout_ms)
141 }
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct ExternalTaskResult {
147 pub success: bool,
149 pub result: serde_json::Value,
151 pub error: Option<String>,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
161#[serde(rename_all = "camelCase")]
162pub struct SessionQueueConfig {
163 #[serde(default = "default_control_concurrency")]
165 pub control_max_concurrency: usize,
166 #[serde(default = "default_query_concurrency")]
168 pub query_max_concurrency: usize,
169 #[serde(default = "default_execute_concurrency")]
171 pub execute_max_concurrency: usize,
172 #[serde(default = "default_generate_concurrency")]
174 pub generate_max_concurrency: usize,
175 #[serde(default)]
177 pub lane_handlers: HashMap<SessionLane, LaneHandlerConfig>,
178
179 #[serde(default)]
184 pub enable_dlq: bool,
185 #[serde(default)]
187 pub dlq_max_size: Option<usize>,
188 #[serde(default)]
190 pub enable_metrics: bool,
191 #[serde(default)]
193 pub enable_alerts: bool,
194 #[serde(default)]
196 pub default_timeout_ms: Option<u64>,
197 #[serde(default)]
199 pub storage_path: Option<std::path::PathBuf>,
200
201 #[serde(default)]
206 pub retry_policy: Option<RetryPolicyConfig>,
207 #[serde(default)]
209 pub rate_limit: Option<RateLimitConfig>,
210 #[serde(default)]
212 pub priority_boost: Option<PriorityBoostConfig>,
213 #[serde(default)]
215 pub pressure_threshold: Option<usize>,
216 #[serde(default)]
218 pub lane_timeouts: HashMap<SessionLane, u64>,
219}
220
221#[derive(Debug, Clone, Serialize, Deserialize)]
223#[serde(rename_all = "camelCase")]
224pub struct RetryPolicyConfig {
225 pub strategy: String,
227 #[serde(default = "default_max_retries")]
229 pub max_retries: u32,
230 #[serde(default = "default_initial_delay_ms")]
232 pub initial_delay_ms: u64,
233 #[serde(default)]
235 pub fixed_delay_ms: Option<u64>,
236}
237
/// Serde default for `RetryPolicyConfig::max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 3;

/// Serde default for `RetryPolicyConfig::initial_delay_ms`, in milliseconds.
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Referenced by `#[serde(default = "default_max_retries")]`.
fn default_max_retries() -> u32 {
    DEFAULT_MAX_RETRIES
}

/// Referenced by `#[serde(default = "default_initial_delay_ms")]`.
fn default_initial_delay_ms() -> u64 {
    DEFAULT_INITIAL_DELAY_MS
}
245
246#[derive(Debug, Clone, Serialize, Deserialize)]
248#[serde(rename_all = "camelCase")]
249pub struct RateLimitConfig {
250 pub limit_type: String,
252 #[serde(default)]
254 pub max_operations: Option<u64>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
259#[serde(rename_all = "camelCase")]
260pub struct PriorityBoostConfig {
261 pub strategy: String,
263 #[serde(default)]
265 pub deadline_ms: Option<u64>,
266}
267
// Serde defaults for the `*_max_concurrency` fields of `SessionQueueConfig`.
// NOTE(review): these are higher than the values `SessionQueueConfig::default()` uses
// (2/4/2/1) — confirm the divergence is intended.
const DEFAULT_CONTROL_CONCURRENCY: usize = 4;
const DEFAULT_QUERY_CONCURRENCY: usize = 12;
const DEFAULT_EXECUTE_CONCURRENCY: usize = 4;
const DEFAULT_GENERATE_CONCURRENCY: usize = 2;

/// Serde default for `control_max_concurrency`.
fn default_control_concurrency() -> usize {
    DEFAULT_CONTROL_CONCURRENCY
}

/// Serde default for `query_max_concurrency`.
fn default_query_concurrency() -> usize {
    DEFAULT_QUERY_CONCURRENCY
}

/// Serde default for `execute_max_concurrency`.
fn default_execute_concurrency() -> usize {
    DEFAULT_EXECUTE_CONCURRENCY
}

/// Serde default for `generate_max_concurrency`.
fn default_generate_concurrency() -> usize {
    DEFAULT_GENERATE_CONCURRENCY
}
283
284impl Default for SessionQueueConfig {
285 fn default() -> Self {
286 Self {
287 control_max_concurrency: 2,
288 query_max_concurrency: 4,
289 execute_max_concurrency: 2,
290 generate_max_concurrency: 1,
291 lane_handlers: HashMap::new(),
292 enable_dlq: false,
293 dlq_max_size: None,
294 enable_metrics: false,
295 enable_alerts: false,
296 default_timeout_ms: None,
297 storage_path: None,
298 retry_policy: None,
299 rate_limit: None,
300 priority_boost: None,
301 pressure_threshold: None,
302 lane_timeouts: HashMap::new(),
303 }
304 }
305}
306
307impl SessionQueueConfig {
308 pub fn max_concurrency(&self, lane: SessionLane) -> usize {
310 match lane {
311 SessionLane::Control => self.control_max_concurrency,
312 SessionLane::Query => self.query_max_concurrency,
313 SessionLane::Execute => self.execute_max_concurrency,
314 SessionLane::Generate => self.generate_max_concurrency,
315 }
316 }
317
318 pub fn handler_config(&self, lane: SessionLane) -> LaneHandlerConfig {
320 self.lane_handlers.get(&lane).cloned().unwrap_or_default()
321 }
322
323 pub fn with_dlq(mut self, max_size: Option<usize>) -> Self {
325 self.enable_dlq = true;
326 self.dlq_max_size = max_size;
327 self
328 }
329
330 pub fn with_metrics(mut self) -> Self {
332 self.enable_metrics = true;
333 self
334 }
335
336 pub fn with_alerts(mut self) -> Self {
338 self.enable_alerts = true;
339 self
340 }
341
342 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
344 self.default_timeout_ms = Some(timeout_ms);
345 self
346 }
347
348 pub fn with_storage(mut self, path: impl Into<std::path::PathBuf>) -> Self {
350 self.storage_path = Some(path.into());
351 self
352 }
353
354 pub fn with_lane_features(mut self) -> Self {
356 self.enable_dlq = true;
357 self.dlq_max_size = Some(1000);
358 self.enable_metrics = true;
359 self.enable_alerts = true;
360 self.default_timeout_ms = Some(60_000);
361 self
362 }
363}
364
365#[async_trait]
371pub trait SessionCommand: Send + Sync {
372 async fn execute(&self) -> Result<serde_json::Value>;
374
375 fn command_type(&self) -> &str;
377
378 fn payload(&self) -> serde_json::Value {
380 serde_json::json!({})
381 }
382}
383
384#[derive(Debug, Clone, Serialize, Deserialize)]
390pub struct LaneStatus {
391 pub lane: SessionLane,
392 pub pending: usize,
393 pub active: usize,
394 pub max_concurrency: usize,
395 pub handler_mode: TaskHandlerMode,
396}
397
398#[derive(Debug, Clone, Default, Serialize, Deserialize)]
400pub struct SessionQueueStats {
401 pub total_pending: usize,
402 pub total_active: usize,
403 pub external_pending: usize,
404 pub lanes: HashMap<String, LaneStatus>,
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_handler_mode_default() {
        let mode = TaskHandlerMode::default();
        assert_eq!(mode, TaskHandlerMode::Internal);
    }

    #[test]
    fn test_lane_handler_config_default() {
        let config = LaneHandlerConfig::default();
        assert_eq!(config.mode, TaskHandlerMode::Internal);
        assert_eq!(config.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);
    }

    #[test]
    fn test_session_lane_from_tool_name() {
        for tool in [
            "read",
            "glob",
            "ls",
            "grep",
            "list_files",
            "search",
            "web_fetch",
            "web_search",
        ] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Query, "tool {}", tool);
        }
        for tool in ["bash", "write", "edit", "delete", "move", "copy", "execute"] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Execute, "tool {}", tool);
        }
        // Unknown tools fall back to the Execute lane.
        assert_eq!(SessionLane::from_tool_name("unknown_tool"), SessionLane::Execute);
        assert_eq!(SessionLane::from_tool_name(""), SessionLane::Execute);
    }

    fn task_with(timeout_ms: u64, created_at: Option<Instant>) -> ExternalTask {
        ExternalTask {
            task_id: "test".to_string(),
            session_id: "session".to_string(),
            lane: SessionLane::Query,
            command_type: "read".to_string(),
            payload: serde_json::json!({}),
            timeout_ms,
            created_at,
        }
    }

    #[test]
    fn test_external_task_timeout() {
        let task = task_with(100, Some(Instant::now()));

        assert!(!task.is_timed_out());
        assert!(task.remaining_ms() <= 100);
    }

    #[test]
    fn test_external_task_without_created_at() {
        // `created_at` is skipped by serde, so deserialized tasks look like this.
        let task = task_with(250, None);
        assert!(!task.is_timed_out());
        assert_eq!(task.remaining_ms(), 250);
    }

    #[test]
    fn test_external_task_expired() {
        // `checked_sub` can fail shortly after boot on some platforms; skip in that case.
        if let Some(past) = Instant::now().checked_sub(Duration::from_secs(2)) {
            let task = task_with(100, Some(past));
            assert!(task.is_timed_out());
            assert_eq!(task.remaining_ms(), 0);
        }
    }

    #[test]
    fn test_session_queue_config_default() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.control_max_concurrency, 2);
        assert_eq!(config.query_max_concurrency, 4);
        assert_eq!(config.execute_max_concurrency, 2);
        assert_eq!(config.generate_max_concurrency, 1);
        assert!(!config.enable_dlq);
        assert!(!config.enable_metrics);
        assert!(!config.enable_alerts);
    }

    #[test]
    fn test_session_queue_config_max_concurrency() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.max_concurrency(SessionLane::Control), 2);
        assert_eq!(config.max_concurrency(SessionLane::Query), 4);
        assert_eq!(config.max_concurrency(SessionLane::Execute), 2);
        assert_eq!(config.max_concurrency(SessionLane::Generate), 1);
    }

    #[test]
    fn test_session_queue_config_handler_config() {
        let config = SessionQueueConfig::default();
        let handler = config.handler_config(SessionLane::Execute);
        assert_eq!(handler.mode, TaskHandlerMode::Internal);
        assert_eq!(handler.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_queue_config_handler_config_override() {
        let mut config = SessionQueueConfig::default();
        config.lane_handlers.insert(
            SessionLane::Generate,
            LaneHandlerConfig {
                mode: TaskHandlerMode::External,
                timeout_ms: 5_000,
            },
        );

        let handler = config.handler_config(SessionLane::Generate);
        assert_eq!(handler.mode, TaskHandlerMode::External);
        assert_eq!(handler.timeout_ms, 5_000);

        // Lanes without an override still get the default handler.
        assert_eq!(
            config.handler_config(SessionLane::Query).mode,
            TaskHandlerMode::Internal
        );
    }

    #[test]
    fn test_session_queue_config_builders() {
        let config = SessionQueueConfig::default()
            .with_dlq(Some(500))
            .with_metrics()
            .with_alerts()
            .with_timeout(30_000);

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(500));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(30_000));
    }

    #[test]
    fn test_session_queue_config_with_lane_features() {
        let config = SessionQueueConfig::default().with_lane_features();

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(1000));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(60_000));
    }

    #[test]
    fn test_session_queue_config_camel_case_deserialization() {
        let config: SessionQueueConfig =
            serde_json::from_str(r#"{"queryMaxConcurrency": 8}"#).expect("valid config JSON");
        assert_eq!(config.query_max_concurrency, 8);
        assert!(!config.enable_dlq);
    }

    #[test]
    fn test_session_queue_config_serde_round_trip() {
        let original = SessionQueueConfig::default().with_lane_features();
        let json = serde_json::to_string(&original).expect("config serializes");
        let parsed: SessionQueueConfig = serde_json::from_str(&json).expect("config deserializes");

        assert_eq!(parsed.control_max_concurrency, original.control_max_concurrency);
        assert_eq!(parsed.query_max_concurrency, original.query_max_concurrency);
        assert_eq!(parsed.execute_max_concurrency, original.execute_max_concurrency);
        assert_eq!(parsed.generate_max_concurrency, original.generate_max_concurrency);
        assert_eq!(parsed.enable_dlq, original.enable_dlq);
        assert_eq!(parsed.dlq_max_size, original.dlq_max_size);
        assert_eq!(parsed.default_timeout_ms, original.default_timeout_ms);
    }

    #[test]
    fn test_external_task_result() {
        let result = ExternalTaskResult {
            success: true,
            result: serde_json::json!({"output": "hello"}),
            error: None,
        };
        assert!(result.success);
        assert!(result.error.is_none());
    }
}