1use anyhow::Result;
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::time::{Duration, Instant};
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
32pub enum SessionLane {
33 Control,
35 Query,
37 Execute,
39 Generate,
41}
42
43impl SessionLane {
44 pub fn priority(&self) -> u8 {
46 match self {
47 SessionLane::Control => 0,
48 SessionLane::Query => 1,
49 SessionLane::Execute => 2,
50 SessionLane::Generate => 3,
51 }
52 }
53
54 pub fn from_tool_name(tool_name: &str) -> Self {
56 match tool_name {
57 "read" | "glob" | "ls" | "grep" | "list_files" | "search" => SessionLane::Query,
58 "bash" | "write" | "edit" | "delete" | "move" | "copy" | "execute" => {
59 SessionLane::Execute
60 }
61 _ => SessionLane::Execute,
62 }
63 }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
72pub enum TaskHandlerMode {
73 #[default]
75 Internal,
76 External,
78 Hybrid,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct LaneHandlerConfig {
85 pub mode: TaskHandlerMode,
87 pub timeout_ms: u64,
89}
90
91impl Default for LaneHandlerConfig {
92 fn default() -> Self {
93 Self {
94 mode: TaskHandlerMode::Internal,
95 timeout_ms: 60_000,
96 }
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct ExternalTask {
107 pub task_id: String,
109 pub session_id: String,
111 pub lane: SessionLane,
113 pub command_type: String,
115 pub payload: serde_json::Value,
117 pub timeout_ms: u64,
119 #[serde(skip)]
121 pub created_at: Option<Instant>,
122}
123
124impl ExternalTask {
125 pub fn is_timed_out(&self) -> bool {
127 self.created_at
128 .map(|t| t.elapsed() > Duration::from_millis(self.timeout_ms))
129 .unwrap_or(false)
130 }
131
132 pub fn remaining_ms(&self) -> u64 {
134 self.created_at
135 .map(|t| {
136 let elapsed = t.elapsed().as_millis() as u64;
137 self.timeout_ms.saturating_sub(elapsed)
138 })
139 .unwrap_or(self.timeout_ms)
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct ExternalTaskResult {
146 pub success: bool,
148 pub result: serde_json::Value,
150 pub error: Option<String>,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
160#[serde(rename_all = "camelCase")]
161pub struct SessionQueueConfig {
162 #[serde(default = "default_control_concurrency")]
164 pub control_max_concurrency: usize,
165 #[serde(default = "default_query_concurrency")]
167 pub query_max_concurrency: usize,
168 #[serde(default = "default_execute_concurrency")]
170 pub execute_max_concurrency: usize,
171 #[serde(default = "default_generate_concurrency")]
173 pub generate_max_concurrency: usize,
174 #[serde(default)]
176 pub lane_handlers: HashMap<SessionLane, LaneHandlerConfig>,
177
178 #[serde(default)]
183 pub enable_dlq: bool,
184 #[serde(default)]
186 pub dlq_max_size: Option<usize>,
187 #[serde(default)]
189 pub enable_metrics: bool,
190 #[serde(default)]
192 pub enable_alerts: bool,
193 #[serde(default)]
195 pub default_timeout_ms: Option<u64>,
196 #[serde(default)]
198 pub storage_path: Option<std::path::PathBuf>,
199
200 #[serde(default)]
205 pub retry_policy: Option<RetryPolicyConfig>,
206 #[serde(default)]
208 pub rate_limit: Option<RateLimitConfig>,
209 #[serde(default)]
211 pub priority_boost: Option<PriorityBoostConfig>,
212 #[serde(default)]
214 pub pressure_threshold: Option<usize>,
215 #[serde(default)]
217 pub lane_timeouts: HashMap<SessionLane, u64>,
218}
219
220#[derive(Debug, Clone, Serialize, Deserialize)]
222#[serde(rename_all = "camelCase")]
223pub struct RetryPolicyConfig {
224 pub strategy: String,
226 #[serde(default = "default_max_retries")]
228 pub max_retries: u32,
229 #[serde(default = "default_initial_delay_ms")]
231 pub initial_delay_ms: u64,
232 #[serde(default)]
234 pub fixed_delay_ms: Option<u64>,
235}
236
/// Serde default for `RetryPolicyConfig::max_retries`: 3 retries.
fn default_max_retries() -> u32 {
    3
}

/// Serde default for `RetryPolicyConfig::initial_delay_ms`: 100 ms.
fn default_initial_delay_ms() -> u64 {
    100
}
244
245#[derive(Debug, Clone, Serialize, Deserialize)]
247#[serde(rename_all = "camelCase")]
248pub struct RateLimitConfig {
249 pub limit_type: String,
251 #[serde(default)]
253 pub max_operations: Option<u64>,
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
258#[serde(rename_all = "camelCase")]
259pub struct PriorityBoostConfig {
260 pub strategy: String,
262 #[serde(default)]
264 pub deadline_ms: Option<u64>,
265}
266
/// Serde default for `SessionQueueConfig::control_max_concurrency`.
///
/// NOTE(review): these four serde defaults (4 / 12 / 4 / 2) only apply when a field is
/// missing during deserialization. `SessionQueueConfig::default()` uses 2 / 4 / 2 / 1.
/// Confirm the mismatch is intentional.
fn default_control_concurrency() -> usize {
    4
}

/// Serde default for `SessionQueueConfig::query_max_concurrency`.
fn default_query_concurrency() -> usize {
    12
}

/// Serde default for `SessionQueueConfig::execute_max_concurrency`.
fn default_execute_concurrency() -> usize {
    4
}

/// Serde default for `SessionQueueConfig::generate_max_concurrency`.
fn default_generate_concurrency() -> usize {
    2
}
282
283impl Default for SessionQueueConfig {
284 fn default() -> Self {
285 Self {
286 control_max_concurrency: 2,
287 query_max_concurrency: 4,
288 execute_max_concurrency: 2,
289 generate_max_concurrency: 1,
290 lane_handlers: HashMap::new(),
291 enable_dlq: false,
292 dlq_max_size: None,
293 enable_metrics: false,
294 enable_alerts: false,
295 default_timeout_ms: None,
296 storage_path: None,
297 retry_policy: None,
298 rate_limit: None,
299 priority_boost: None,
300 pressure_threshold: None,
301 lane_timeouts: HashMap::new(),
302 }
303 }
304}
305
306impl SessionQueueConfig {
307 pub fn max_concurrency(&self, lane: SessionLane) -> usize {
309 match lane {
310 SessionLane::Control => self.control_max_concurrency,
311 SessionLane::Query => self.query_max_concurrency,
312 SessionLane::Execute => self.execute_max_concurrency,
313 SessionLane::Generate => self.generate_max_concurrency,
314 }
315 }
316
317 pub fn handler_config(&self, lane: SessionLane) -> LaneHandlerConfig {
319 self.lane_handlers.get(&lane).cloned().unwrap_or_default()
320 }
321
322 pub fn with_dlq(mut self, max_size: Option<usize>) -> Self {
324 self.enable_dlq = true;
325 self.dlq_max_size = max_size;
326 self
327 }
328
329 pub fn with_metrics(mut self) -> Self {
331 self.enable_metrics = true;
332 self
333 }
334
335 pub fn with_alerts(mut self) -> Self {
337 self.enable_alerts = true;
338 self
339 }
340
341 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
343 self.default_timeout_ms = Some(timeout_ms);
344 self
345 }
346
347 pub fn with_storage(mut self, path: impl Into<std::path::PathBuf>) -> Self {
349 self.storage_path = Some(path.into());
350 self
351 }
352
353 pub fn with_lane_features(mut self) -> Self {
355 self.enable_dlq = true;
356 self.dlq_max_size = Some(1000);
357 self.enable_metrics = true;
358 self.enable_alerts = true;
359 self.default_timeout_ms = Some(60_000);
360 self
361 }
362}
363
364#[async_trait]
370pub trait SessionCommand: Send + Sync {
371 async fn execute(&self) -> Result<serde_json::Value>;
373
374 fn command_type(&self) -> &str;
376
377 fn payload(&self) -> serde_json::Value {
379 serde_json::json!({})
380 }
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize)]
389pub struct LaneStatus {
390 pub lane: SessionLane,
391 pub pending: usize,
392 pub active: usize,
393 pub max_concurrency: usize,
394 pub handler_mode: TaskHandlerMode,
395}
396
397#[derive(Debug, Clone, Default, Serialize, Deserialize)]
399pub struct SessionQueueStats {
400 pub total_pending: usize,
401 pub total_active: usize,
402 pub external_pending: usize,
403 pub lanes: HashMap<String, LaneStatus>,
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExternalTask` with fixed identifiers and the given timing fields.
    fn task_with(timeout_ms: u64, created_at: Option<Instant>) -> ExternalTask {
        ExternalTask {
            task_id: "test".to_string(),
            session_id: "session".to_string(),
            lane: SessionLane::Query,
            command_type: "read".to_string(),
            payload: serde_json::json!({}),
            timeout_ms,
            created_at,
        }
    }

    #[test]
    fn test_session_lane_priority() {
        assert_eq!(SessionLane::Control.priority(), 0);
        assert_eq!(SessionLane::Query.priority(), 1);
        assert_eq!(SessionLane::Execute.priority(), 2);
        assert_eq!(SessionLane::Generate.priority(), 3);
    }

    #[test]
    fn test_session_lane_from_tool_name() {
        for tool in ["read", "glob", "ls", "grep", "list_files", "search"] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Query, "{tool}");
        }
        for tool in ["bash", "write", "edit", "delete", "move", "copy", "execute"] {
            assert_eq!(SessionLane::from_tool_name(tool), SessionLane::Execute, "{tool}");
        }
        // Unrecognized tools fall back to the Execute lane.
        assert_eq!(SessionLane::from_tool_name("unknown_tool"), SessionLane::Execute);
    }

    #[test]
    fn test_task_handler_mode_default() {
        let mode = TaskHandlerMode::default();
        assert_eq!(mode, TaskHandlerMode::Internal);
    }

    #[test]
    fn test_lane_handler_config_default() {
        let config = LaneHandlerConfig::default();
        assert_eq!(config.mode, TaskHandlerMode::Internal);
        assert_eq!(config.timeout_ms, 60_000);
    }

    #[test]
    fn test_external_task_timeout() {
        // Use a 60 s window rather than 100 ms. The assertion must not depend on the test
        // thread being scheduled promptly, which a loaded CI machine cannot guarantee.
        let task = task_with(60_000, Some(Instant::now()));
        assert!(!task.is_timed_out());
        assert!((1..=60_000).contains(&task.remaining_ms()));
    }

    #[test]
    fn test_external_task_without_created_at() {
        // With no start instant, the task never times out and keeps its full budget. This
        // is the state after deserialization, because `created_at` is `#[serde(skip)]`.
        let task = task_with(100, None);
        assert!(!task.is_timed_out());
        assert_eq!(task.remaining_ms(), 100);
    }

    #[test]
    fn test_external_task_expired() {
        // `checked_sub` returns None if the monotonic clock cannot represent an instant
        // 500 ms in the past. There is nothing meaningful to check in that case.
        if let Some(start) = Instant::now().checked_sub(Duration::from_millis(500)) {
            let task = task_with(100, Some(start));
            assert!(task.is_timed_out());
            assert_eq!(task.remaining_ms(), 0);
        }
    }

    #[test]
    fn test_session_queue_config_default() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.control_max_concurrency, 2);
        assert_eq!(config.query_max_concurrency, 4);
        assert_eq!(config.execute_max_concurrency, 2);
        assert_eq!(config.generate_max_concurrency, 1);
        assert!(!config.enable_dlq);
        assert!(!config.enable_metrics);
        assert!(!config.enable_alerts);
    }

    #[test]
    fn test_session_queue_config_max_concurrency() {
        let config = SessionQueueConfig::default();
        assert_eq!(config.max_concurrency(SessionLane::Control), 2);
        assert_eq!(config.max_concurrency(SessionLane::Query), 4);
        assert_eq!(config.max_concurrency(SessionLane::Execute), 2);
        assert_eq!(config.max_concurrency(SessionLane::Generate), 1);
    }

    #[test]
    fn test_session_queue_config_handler_config() {
        let config = SessionQueueConfig::default();
        let handler = config.handler_config(SessionLane::Execute);
        assert_eq!(handler.mode, TaskHandlerMode::Internal);
        assert_eq!(handler.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_queue_config_handler_config_override() {
        let mut config = SessionQueueConfig::default();
        config.lane_handlers.insert(
            SessionLane::Generate,
            LaneHandlerConfig {
                mode: TaskHandlerMode::External,
                timeout_ms: 5_000,
            },
        );

        let handler = config.handler_config(SessionLane::Generate);
        assert_eq!(handler.mode, TaskHandlerMode::External);
        assert_eq!(handler.timeout_ms, 5_000);

        // Lanes without an explicit entry still fall back to the default handler.
        let fallback = config.handler_config(SessionLane::Query);
        assert_eq!(fallback.mode, TaskHandlerMode::Internal);
        assert_eq!(fallback.timeout_ms, 60_000);
    }

    #[test]
    fn test_session_queue_config_builders() {
        let config = SessionQueueConfig::default()
            .with_dlq(Some(500))
            .with_metrics()
            .with_alerts()
            .with_timeout(30_000);

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(500));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(30_000));
    }

    #[test]
    fn test_session_queue_config_with_lane_features() {
        let config = SessionQueueConfig::default().with_lane_features();

        assert!(config.enable_dlq);
        assert_eq!(config.dlq_max_size, Some(1000));
        assert!(config.enable_metrics);
        assert!(config.enable_alerts);
        assert_eq!(config.default_timeout_ms, Some(60_000));
    }

    #[test]
    fn test_external_task_result() {
        let result = ExternalTaskResult {
            success: true,
            result: serde_json::json!({"output": "hello"}),
            error: None,
        };
        assert!(result.success);
        assert!(result.error.is_none());
    }
}