agcodex_core/subagents/
context.rs

1//! Execution context and progress tracking for AGCodex agents
2//!
3//! This module provides:
4//! - Shared execution context with AST cache and findings
5//! - Real-time progress reporting with ETA calculation
//! - Context save/restore via serialized snapshots (compression not yet implemented)
7//! - Inter-agent messaging with priorities
8//! - Performance metrics tracking
9
10use crate::code_tools::ast_agent_tools::Location as SourceLocation;
11use crate::modes::OperatingMode;
12use agcodex_ast::types::ParsedAst;
13use chrono::DateTime;
14use chrono::Utc;
15use dashmap::DashMap;
16use serde::Deserialize;
17use serde::Serialize;
18use std::collections::HashMap;
19use std::path::PathBuf;
20use std::sync::Arc;
21use std::sync::atomic::AtomicBool;
22use std::sync::atomic::AtomicU8;
23use std::sync::atomic::AtomicUsize;
24use std::sync::atomic::Ordering;
25use std::time::Duration;
26use std::time::Instant;
27use thiserror::Error;
28use tokio::sync::RwLock;
29use tokio::sync::mpsc;
30use uuid::Uuid;
31
32/// Errors that can occur in context operations
33#[derive(Error, Debug)]
34pub enum ContextError {
35    #[error("context snapshot version mismatch: expected {expected}, got {actual}")]
36    VersionMismatch { expected: u32, actual: u32 },
37
38    #[error("compression failed: {0}")]
39    CompressionFailed(String),
40
41    #[error("decompression failed: {0}")]
42    DecompressionFailed(String),
43
44    #[error("serialization failed: {0}")]
45    SerializationFailed(String),
46
47    #[error("message channel closed")]
48    ChannelClosed,
49
50    #[error("operation cancelled")]
51    Cancelled,
52
53    #[error("metric calculation failed: {0}")]
54    MetricError(String),
55
56    #[error("I/O error: {0}")]
57    Io(#[from] std::io::Error),
58}
59
60pub type ContextResult<T> = Result<T, ContextError>;
61
62/// Finding discovered by an agent (renamed to avoid conflict with agents::Finding)
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ContextFinding {
65    pub id: Uuid,
66    pub agent: String,
67    pub severity: FindingSeverity,
68    pub category: String,
69    pub message: String,
70    pub location: Option<SourceLocation>,
71    pub suggestion: Option<String>,
72    pub confidence: f32, // 0.0 to 1.0
73    pub timestamp: DateTime<Utc>,
74}
75
76/// Severity levels for findings
77#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
78pub enum FindingSeverity {
79    Info,
80    Warning,
81    Error,
82    Critical,
83}
84
85/// Shared execution context for agents
86#[derive(Clone)]
87pub struct AgentContext {
88    /// AST cache for parsed files
89    pub ast_cache: Arc<DashMap<PathBuf, ParsedAst>>,
90
91    /// Shared findings from all agents
92    pub shared_findings: Arc<RwLock<Vec<ContextFinding>>>,
93
94    /// User parameters for the session
95    pub parameters: Arc<HashMap<String, serde_json::Value>>,
96
97    /// Current operating mode
98    pub mode: OperatingMode,
99
100    /// Progress tracker for this context
101    pub progress: Arc<ProgressTracker>,
102
103    /// Session history access (optional)
104    pub session_history: Arc<RwLock<Option<Vec<String>>>>,
105
106    /// Cancellation token for graceful shutdown
107    pub cancellation_token: Arc<CancellationToken>,
108
109    /// Message bus for inter-agent communication
110    message_bus: Arc<MessageBus>,
111
112    /// Execution metrics
113    metrics: Arc<ExecutionMetrics>,
114
115    /// Context metadata
116    metadata: Arc<DashMap<String, serde_json::Value>>,
117}
118
119impl AgentContext {
120    /// Create a new agent context
121    pub fn new(mode: OperatingMode, parameters: HashMap<String, serde_json::Value>) -> Self {
122        Self {
123            ast_cache: Arc::new(DashMap::new()),
124            shared_findings: Arc::new(RwLock::new(Vec::new())),
125            parameters: Arc::new(parameters),
126            mode,
127            progress: Arc::new(ProgressTracker::new()),
128            session_history: Arc::new(RwLock::new(None)),
129            cancellation_token: Arc::new(CancellationToken::new()),
130            message_bus: Arc::new(MessageBus::new()),
131            metrics: Arc::new(ExecutionMetrics::new()),
132            metadata: Arc::new(DashMap::new()),
133        }
134    }
135
136    /// Add a finding to the shared context
137    pub async fn add_finding(&self, finding: ContextFinding) -> ContextResult<()> {
138        self.check_cancelled()?;
139        let mut findings = self.shared_findings.write().await;
140        findings.push(finding);
141        self.metrics.increment_findings();
142        Ok(())
143    }
144
145    /// Get all findings matching a severity level
146    pub async fn get_findings_by_severity(&self, severity: FindingSeverity) -> Vec<ContextFinding> {
147        let findings = self.shared_findings.read().await;
148        findings
149            .iter()
150            .filter(|f| f.severity == severity)
151            .cloned()
152            .collect()
153    }
154
155    /// Cache a parsed AST
156    pub fn cache_ast(&self, path: PathBuf, ast: ParsedAst) {
157        self.ast_cache.insert(path, ast);
158        self.metrics.increment_files_processed();
159    }
160
161    /// Get a cached AST if available
162    pub fn get_cached_ast(&self, path: &PathBuf) -> Option<ParsedAst> {
163        self.ast_cache.get(path).map(|entry| entry.clone())
164    }
165
166    /// Check if operation has been cancelled
167    pub fn check_cancelled(&self) -> ContextResult<()> {
168        if self.cancellation_token.is_cancelled() {
169            Err(ContextError::Cancelled)
170        } else {
171            Ok(())
172        }
173    }
174
175    /// Send a message to other agents
176    pub async fn send_message(&self, message: AgentMessage) -> ContextResult<()> {
177        self.message_bus.send(message).await
178    }
179
180    /// Subscribe to messages for a specific agent
181    pub fn subscribe(&self, agent_name: String) -> MessageReceiver {
182        self.message_bus.subscribe(agent_name)
183    }
184
185    /// Create a snapshot of the current context
186    pub async fn snapshot(&self) -> ContextResult<AgentContextSnapshot> {
187        let findings = self.shared_findings.read().await;
188        let history = self.session_history.read().await;
189
190        let snapshot = AgentContextSnapshot {
191            version: SNAPSHOT_VERSION,
192            timestamp: Utc::now(),
193            findings: findings.clone(),
194            parameters: (*self.parameters).clone(),
195            mode: self.mode,
196            session_history: history.clone(),
197            metadata: self.export_metadata(),
198            metrics: self.metrics.snapshot(),
199        };
200
201        Ok(snapshot)
202    }
203
204    /// Restore context from a snapshot
205    pub async fn restore(&mut self, snapshot: AgentContextSnapshot) -> ContextResult<()> {
206        if snapshot.version != SNAPSHOT_VERSION {
207            return Err(ContextError::VersionMismatch {
208                expected: SNAPSHOT_VERSION,
209                actual: snapshot.version,
210            });
211        }
212
213        *self.shared_findings.write().await = snapshot.findings;
214        self.parameters = Arc::new(snapshot.parameters);
215        self.mode = snapshot.mode;
216        *self.session_history.write().await = snapshot.session_history;
217        self.import_metadata(snapshot.metadata);
218        self.metrics.restore(snapshot.metrics);
219
220        Ok(())
221    }
222
223    /// Export metadata as a HashMap
224    fn export_metadata(&self) -> HashMap<String, serde_json::Value> {
225        self.metadata
226            .iter()
227            .map(|entry| (entry.key().clone(), entry.value().clone()))
228            .collect()
229    }
230
231    /// Import metadata from a HashMap
232    fn import_metadata(&self, metadata: HashMap<String, serde_json::Value>) {
233        self.metadata.clear();
234        for (key, value) in metadata {
235            self.metadata.insert(key, value);
236        }
237    }
238
239    /// Get execution metrics
240    pub fn metrics(&self) -> ExecutionMetricsSnapshot {
241        self.metrics.snapshot()
242    }
243}
244
245/// Progress tracker for real-time reporting
246pub struct ProgressTracker {
247    stages: RwLock<Vec<ProgressStage>>,
248    current_stage: AtomicUsize,
249    progress: AtomicU8,
250    tx: mpsc::UnboundedSender<ProgressEvent>,
251    rx: RwLock<mpsc::UnboundedReceiver<ProgressEvent>>,
252    start_time: Instant,
253    stage_history: RwLock<Vec<StageHistory>>,
254}
255
256impl Default for ProgressTracker {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262impl ProgressTracker {
263    /// Create a new progress tracker
264    pub fn new() -> Self {
265        let (tx, rx) = mpsc::unbounded_channel();
266        Self {
267            stages: RwLock::new(Vec::new()),
268            current_stage: AtomicUsize::new(0),
269            progress: AtomicU8::new(0),
270            tx,
271            rx: RwLock::new(rx),
272            start_time: Instant::now(),
273            stage_history: RwLock::new(Vec::new()),
274        }
275    }
276
277    /// Set the stages for this operation
278    pub async fn set_stages(&self, stages: Vec<ProgressStage>) {
279        *self.stages.write().await = stages;
280        self.current_stage.store(0, Ordering::SeqCst);
281        self.progress.store(0, Ordering::SeqCst);
282    }
283
284    /// Move to the next stage
285    pub async fn next_stage(&self) -> ContextResult<()> {
286        let stages = self.stages.read().await;
287        let current = self.current_stage.load(Ordering::SeqCst);
288
289        if current < stages.len() {
290            // Record completion of current stage
291            if current > 0 {
292                let mut history = self.stage_history.write().await;
293                history.push(StageHistory {
294                    stage_index: current - 1,
295                    duration: self.start_time.elapsed(),
296                    completed_at: Instant::now(),
297                });
298            }
299
300            self.current_stage.fetch_add(1, Ordering::SeqCst);
301            self.progress.store(0, Ordering::SeqCst);
302
303            let _ = self.tx.send(ProgressEvent::StageChanged {
304                stage: current + 1,
305                total_stages: stages.len(),
306            });
307        }
308
309        Ok(())
310    }
311
312    /// Update progress within current stage (0-100)
313    pub fn update_progress(&self, progress: u8) {
314        let clamped = progress.min(100);
315        self.progress.store(clamped, Ordering::SeqCst);
316
317        let _ = self.tx.send(ProgressEvent::Progress {
318            percentage: clamped,
319            stage: self.current_stage.load(Ordering::SeqCst),
320        });
321    }
322
323    /// Set a detailed status message
324    pub fn set_status(&self, message: String) {
325        let _ = self.tx.send(ProgressEvent::Status { message });
326    }
327
328    /// Calculate ETA based on historical data
329    pub async fn calculate_eta(&self) -> Option<Duration> {
330        let stages = self.stages.read().await;
331        let current_stage = self.current_stage.load(Ordering::SeqCst);
332        let history = self.stage_history.read().await;
333
334        if stages.is_empty() || current_stage >= stages.len() {
335            return None;
336        }
337
338        // Calculate average time per stage from history
339        if history.is_empty() {
340            // Estimate based on current progress
341            let elapsed = self.start_time.elapsed();
342            let progress = self.progress.load(Ordering::SeqCst) as f64 / 100.0;
343
344            if progress > 0.01 {
345                let estimated_total = elapsed.as_secs_f64() / progress;
346                let remaining = estimated_total - elapsed.as_secs_f64();
347                return Some(Duration::from_secs_f64(remaining));
348            }
349        } else {
350            // Use historical data for better estimates
351            let avg_stage_time = history
352                .iter()
353                .map(|h| h.duration.as_secs_f64())
354                .sum::<f64>()
355                / history.len() as f64;
356
357            let remaining_stages = stages.len() - current_stage;
358            let estimated_remaining = avg_stage_time * remaining_stages as f64;
359
360            return Some(Duration::from_secs_f64(estimated_remaining));
361        }
362
363        None
364    }
365
366    /// Get current progress info
367    pub async fn get_info(&self) -> ProgressInfo {
368        let stages = self.stages.read().await;
369        let current_stage = self.current_stage.load(Ordering::SeqCst);
370        let progress = self.progress.load(Ordering::SeqCst);
371
372        ProgressInfo {
373            current_stage,
374            total_stages: stages.len(),
375            stage_progress: progress,
376            current_stage_name: stages.get(current_stage).map(|s| s.name.clone()),
377            eta: self.calculate_eta().await,
378            elapsed: self.start_time.elapsed(),
379        }
380    }
381
382    /// Receive progress updates
383    pub async fn recv(&self) -> Option<ProgressEvent> {
384        self.rx.write().await.recv().await
385    }
386}
387
388/// A stage in the progress tracking
389#[derive(Debug, Clone, Serialize, Deserialize)]
390pub struct ProgressStage {
391    pub name: String,
392    pub description: String,
393    pub weight: f32, // Relative weight for overall progress calculation
394}
395
/// Timing record written by `ProgressTracker::next_stage` for a finished stage.
#[derive(Debug, Clone)]
// Not every field is read back yet; they are kept for diagnostics.
#[allow(dead_code)]
struct StageHistory {
    /// Index of the stage this record describes.
    stage_index: usize,
    /// Time attributed to the stage; averaged by `calculate_eta`.
    duration: Duration,
    /// Moment the record was written.
    completed_at: Instant,
}
404
/// Event published by a [`ProgressTracker`].
///
/// Named `ProgressEvent` to avoid clashing with `orchestrator::ProgressUpdate`.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// Moved to stage index `stage` of `total_stages`.
    StageChanged { stage: usize, total_stages: usize },
    /// In-stage progress (0–100) for stage index `stage`.
    Progress { percentage: u8, stage: usize },
    /// Free-form status text.
    Status { message: String },
    /// Operation finished (not emitted by `ProgressTracker` itself).
    Completed,
    /// Operation failed (not emitted by `ProgressTracker` itself).
    Failed { error: String },
}
414
/// Point-in-time view returned by `ProgressTracker::get_info`.
#[derive(Debug, Clone)]
pub struct ProgressInfo {
    /// Index of the running stage.
    pub current_stage: usize,
    /// Number of configured stages.
    pub total_stages: usize,
    /// Progress inside the running stage, 0–100.
    pub stage_progress: u8,
    /// Name of the running stage, if the index is in range.
    pub current_stage_name: Option<String>,
    /// Estimated time remaining, when one can be computed.
    pub eta: Option<Duration>,
    /// Time since the tracker was created.
    pub elapsed: Duration,
}
425
426/// Context snapshot for save/restore (renamed to avoid conflict with orchestrator::ContextSnapshot)
427#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct AgentContextSnapshot {
429    pub version: u32,
430    pub timestamp: DateTime<Utc>,
431    pub findings: Vec<ContextFinding>,
432    pub parameters: HashMap<String, serde_json::Value>,
433    pub mode: OperatingMode,
434    pub session_history: Option<Vec<String>>,
435    pub metadata: HashMap<String, serde_json::Value>,
436    pub metrics: ExecutionMetricsSnapshot,
437}
438
439const SNAPSHOT_VERSION: u32 = 1;
440
441impl AgentContextSnapshot {
442    /// Compress the snapshot using zstd
443    pub fn compress(&self) -> ContextResult<Vec<u8>> {
444        let config = bincode::config::standard();
445        let serialized = bincode::serde::encode_to_vec(self, config)
446            .map_err(|e| ContextError::SerializationFailed(e.to_string()))?;
447
448        // Note: zstd compression would be added here if the crate is added to dependencies
449        // For now, return uncompressed
450        Ok(serialized)
451    }
452
453    /// Decompress a snapshot
454    pub fn decompress(data: &[u8]) -> ContextResult<Self> {
455        // Note: zstd decompression would be added here if the crate is added to dependencies
456        // For now, treat as uncompressed
457        let config = bincode::config::standard();
458        let (snapshot, _) = bincode::serde::decode_from_slice(data, config)
459            .map_err(|e| ContextError::SerializationFailed(e.to_string()))?;
460        Ok(snapshot)
461    }
462
463    /// Merge with another snapshot (for parallel agent results)
464    pub fn merge(&mut self, other: AgentContextSnapshot) -> ContextResult<()> {
465        // Merge findings
466        self.findings.extend(other.findings);
467
468        // Merge metadata (other overwrites self for conflicts)
469        for (key, value) in other.metadata {
470            self.metadata.insert(key, value);
471        }
472
473        // Merge metrics
474        self.metrics.merge(other.metrics);
475
476        self.timestamp = Utc::now();
477        Ok(())
478    }
479}
480
481/// Inter-agent message
482#[derive(Debug, Clone, Serialize, Deserialize)]
483pub struct AgentMessage {
484    pub id: Uuid,
485    pub from: String,
486    pub to: MessageTarget,
487    pub message_type: MessageType,
488    pub priority: MessagePriority,
489    pub payload: serde_json::Value,
490    pub timestamp: DateTime<Utc>,
491}
492
493/// Message targeting
494#[derive(Debug, Clone, Serialize, Deserialize)]
495pub enum MessageTarget {
496    Agent(String),
497    Broadcast,
498    Group(Vec<String>),
499}
500
501/// Message types
502#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
503pub enum MessageType {
504    Info,
505    Warning,
506    Error,
507    Result,
508    Request,
509    Response,
510    Coordination,
511}
512
513/// Message priority levels
514#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
515pub enum MessagePriority {
516    Low,
517    Normal,
518    High,
519    Critical,
520}
521
522/// Message bus for inter-agent communication
523struct MessageBus {
524    subscribers: Arc<DashMap<String, mpsc::UnboundedSender<AgentMessage>>>,
525    _broadcast_tx: mpsc::UnboundedSender<AgentMessage>,
526    _broadcast_rx: Arc<RwLock<mpsc::UnboundedReceiver<AgentMessage>>>,
527}
528
529impl MessageBus {
530    fn new() -> Self {
531        let (broadcast_tx, broadcast_rx) = mpsc::unbounded_channel();
532        Self {
533            subscribers: Arc::new(DashMap::new()),
534            _broadcast_tx: broadcast_tx,
535            _broadcast_rx: Arc::new(RwLock::new(broadcast_rx)),
536        }
537    }
538
539    async fn send(&self, message: AgentMessage) -> ContextResult<()> {
540        match &message.to {
541            MessageTarget::Agent(name) => {
542                if let Some(tx) = self.subscribers.get(name) {
543                    tx.send(message).map_err(|_| ContextError::ChannelClosed)?;
544                }
545            }
546            MessageTarget::Broadcast => {
547                for entry in self.subscribers.iter() {
548                    let _ = entry.value().send(message.clone());
549                }
550            }
551            MessageTarget::Group(agents) => {
552                for agent in agents {
553                    if let Some(tx) = self.subscribers.get(agent) {
554                        let _ = tx.send(message.clone());
555                    }
556                }
557            }
558        }
559        Ok(())
560    }
561
562    fn subscribe(&self, agent_name: String) -> MessageReceiver {
563        let (tx, rx) = mpsc::unbounded_channel();
564        self.subscribers.insert(agent_name.clone(), tx);
565        MessageReceiver { rx, agent_name }
566    }
567}
568
569/// Message receiver for an agent
570pub struct MessageReceiver {
571    rx: mpsc::UnboundedReceiver<AgentMessage>,
572    agent_name: String,
573}
574
575impl MessageReceiver {
576    /// Receive the next message
577    pub async fn recv(&mut self) -> Option<AgentMessage> {
578        self.rx.recv().await
579    }
580
581    /// Try to receive without blocking
582    pub fn try_recv(&mut self) -> Option<AgentMessage> {
583        self.rx.try_recv().ok()
584    }
585
586    /// Get the agent name for this receiver
587    pub fn agent_name(&self) -> &str {
588        &self.agent_name
589    }
590}
591
/// Counters describing the work done inside an [`AgentContext`].
///
/// The counters are plain statistics updated with relaxed atomics; they do
/// not synchronize anything else.
struct ExecutionMetrics {
    /// When collection began; basis for elapsed time and throughput.
    start_time: Instant,
    /// Number of ASTs cached.
    files_processed: AtomicUsize,
    /// Number of findings added.
    findings_generated: AtomicUsize,
    /// Memory figure (units unspecified); only `restore` writes it here.
    memory_allocated: AtomicUsize,
    /// Successful AST-cache lookups.
    cache_hits: AtomicUsize,
    /// Failed AST-cache lookups.
    cache_misses: AtomicUsize,
}
601
602impl ExecutionMetrics {
603    fn new() -> Self {
604        Self {
605            start_time: Instant::now(),
606            files_processed: AtomicUsize::new(0),
607            findings_generated: AtomicUsize::new(0),
608            memory_allocated: AtomicUsize::new(0),
609            cache_hits: AtomicUsize::new(0),
610            cache_misses: AtomicUsize::new(0),
611        }
612    }
613
614    fn increment_files_processed(&self) {
615        self.files_processed.fetch_add(1, Ordering::Relaxed);
616    }
617
618    fn increment_findings(&self) {
619        self.findings_generated.fetch_add(1, Ordering::Relaxed);
620    }
621
622    fn _record_cache_hit(&self) {
623        self.cache_hits.fetch_add(1, Ordering::Relaxed);
624    }
625
626    fn _record_cache_miss(&self) {
627        self.cache_misses.fetch_add(1, Ordering::Relaxed);
628    }
629
630    fn snapshot(&self) -> ExecutionMetricsSnapshot {
631        let cache_total =
632            self.cache_hits.load(Ordering::Relaxed) + self.cache_misses.load(Ordering::Relaxed);
633
634        let cache_hit_rate = if cache_total > 0 {
635            self.cache_hits.load(Ordering::Relaxed) as f64 / cache_total as f64
636        } else {
637            0.0
638        };
639
640        ExecutionMetricsSnapshot {
641            elapsed: self.start_time.elapsed(),
642            files_processed: self.files_processed.load(Ordering::Relaxed),
643            findings_generated: self.findings_generated.load(Ordering::Relaxed),
644            memory_allocated: self.memory_allocated.load(Ordering::Relaxed),
645            cache_hit_rate,
646            files_per_second: self.calculate_throughput(),
647        }
648    }
649
650    fn calculate_throughput(&self) -> f64 {
651        let elapsed = self.start_time.elapsed().as_secs_f64();
652        if elapsed > 0.0 {
653            self.files_processed.load(Ordering::Relaxed) as f64 / elapsed
654        } else {
655            0.0
656        }
657    }
658
659    fn restore(&self, snapshot: ExecutionMetricsSnapshot) {
660        self.files_processed
661            .store(snapshot.files_processed, Ordering::Relaxed);
662        self.findings_generated
663            .store(snapshot.findings_generated, Ordering::Relaxed);
664        self.memory_allocated
665            .store(snapshot.memory_allocated, Ordering::Relaxed);
666    }
667}
668
669/// Snapshot of execution metrics
670#[derive(Debug, Clone, Serialize, Deserialize)]
671pub struct ExecutionMetricsSnapshot {
672    pub elapsed: Duration,
673    pub files_processed: usize,
674    pub findings_generated: usize,
675    pub memory_allocated: usize,
676    pub cache_hit_rate: f64,
677    pub files_per_second: f64,
678}
679
680impl ExecutionMetricsSnapshot {
681    /// Merge with another metrics snapshot
682    pub fn merge(&mut self, other: ExecutionMetricsSnapshot) {
683        self.files_processed += other.files_processed;
684        self.findings_generated += other.findings_generated;
685        self.memory_allocated = self.memory_allocated.max(other.memory_allocated);
686
687        // Weighted average for cache hit rate
688        let total = self.files_processed + other.files_processed;
689        if total > 0 {
690            self.cache_hit_rate = (self.cache_hit_rate * self.files_processed as f64
691                + other.cache_hit_rate * other.files_processed as f64)
692                / total as f64;
693        }
694    }
695}
696
697/// Cancellation token for graceful shutdown
698#[derive(Debug, Clone)]
699pub struct CancellationToken {
700    inner: Arc<CancellationTokenInner>,
701}
702
703#[derive(Debug)]
704struct CancellationTokenInner {
705    cancelled: AtomicBool,
706    waiters: RwLock<Vec<tokio::sync::oneshot::Sender<()>>>,
707}
708
709impl Default for CancellationToken {
710    fn default() -> Self {
711        Self::new()
712    }
713}
714
715impl CancellationToken {
716    /// Create a new cancellation token
717    pub fn new() -> Self {
718        Self {
719            inner: Arc::new(CancellationTokenInner {
720                cancelled: AtomicBool::new(false),
721                waiters: RwLock::new(Vec::new()),
722            }),
723        }
724    }
725
726    /// Cancel all operations
727    pub async fn cancel(&self) {
728        self.inner.cancelled.store(true, Ordering::SeqCst);
729        let mut waiters = self.inner.waiters.write().await;
730        for waiter in waiters.drain(..) {
731            let _ = waiter.send(());
732        }
733    }
734
735    /// Check if cancelled
736    pub fn is_cancelled(&self) -> bool {
737        self.inner.cancelled.load(Ordering::SeqCst)
738    }
739
740    /// Wait for cancellation
741    pub async fn cancelled(&self) {
742        // Early check for already cancelled
743        if self.is_cancelled() {
744            return;
745        }
746
747        let (tx, rx) = tokio::sync::oneshot::channel();
748
749        // Add waiter while holding the lock
750        {
751            let mut waiters = self.inner.waiters.write().await;
752
753            // Double-check cancellation state after acquiring the lock
754            // This prevents the race condition where cancellation happens
755            // between the first check and adding the waiter
756            if self.is_cancelled() {
757                return;
758            }
759
760            waiters.push(tx);
761        }
762
763        // Now wait for cancellation signal
764        let _ = rx.await;
765    }
766
767    /// Create a child token that cancels when parent cancels
768    pub fn child(&self) -> CancellationToken {
769        let child = CancellationToken::new();
770        let child_clone = child.clone();
771        let parent = self.clone();
772
773        tokio::spawn(async move {
774            parent.cancelled().await;
775            child_clone.cancel().await;
776        });
777
778        child
779    }
780}
781
#[cfg(test)]
mod tests {
    use super::*;

    /// A finding from `test-agent` whose message is "Test finding".
    fn sample_finding(
        severity: FindingSeverity,
        category: &str,
        confidence: f32,
    ) -> ContextFinding {
        ContextFinding {
            id: Uuid::new_v4(),
            agent: "test-agent".to_string(),
            severity,
            category: category.to_string(),
            message: "Test finding".to_string(),
            location: None,
            suggestion: None,
            confidence,
            timestamp: Utc::now(),
        }
    }

    /// A stage with weight 1.0.
    fn unit_stage(name: &str, description: &str) -> ProgressStage {
        ProgressStage {
            name: name.to_string(),
            description: description.to_string(),
            weight: 1.0,
        }
    }

    #[tokio::test]
    async fn test_agent_context_creation() {
        let context = AgentContext::new(OperatingMode::Build, HashMap::new());

        assert_eq!(context.mode, OperatingMode::Build);
        assert!(!context.cancellation_token.is_cancelled());
    }

    #[tokio::test]
    async fn test_finding_management() {
        let context = AgentContext::new(OperatingMode::Review, HashMap::new());
        context
            .add_finding(sample_finding(FindingSeverity::Warning, "code-quality", 0.8))
            .await
            .unwrap();

        let warnings = context
            .get_findings_by_severity(FindingSeverity::Warning)
            .await;
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].message, "Test finding");
    }

    #[tokio::test]
    async fn test_progress_tracking() {
        let tracker = ProgressTracker::new();
        tracker
            .set_stages(vec![
                unit_stage("Analysis", "Analyzing code"),
                unit_stage("Processing", "Processing results"),
            ])
            .await;
        tracker.update_progress(50);

        let info = tracker.get_info().await;
        assert_eq!(info.current_stage, 0);
        assert_eq!(info.total_stages, 2);
        assert_eq!(info.stage_progress, 50);
    }

    #[tokio::test]
    async fn test_context_snapshot() {
        let context = AgentContext::new(OperatingMode::Plan, HashMap::new());
        context
            .add_finding(sample_finding(FindingSeverity::Info, "test", 1.0))
            .await
            .unwrap();

        let snapshot = context.snapshot().await.unwrap();
        assert_eq!(snapshot.findings.len(), 1);
        assert_eq!(snapshot.mode, OperatingMode::Plan);

        // Round-trip through the byte encoding.
        let bytes = snapshot.compress().unwrap();
        let decoded = AgentContextSnapshot::decompress(&bytes).unwrap();
        assert_eq!(decoded.findings.len(), 1);
    }

    #[tokio::test]
    async fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());

        let waiter = {
            let observer = token.clone();
            tokio::spawn(async move {
                observer.cancelled().await;
            })
        };

        tokio::time::sleep(Duration::from_millis(10)).await;
        token.cancel().await;

        waiter.await.unwrap();
        assert!(token.is_cancelled());
    }

    #[tokio::test]
    async fn test_message_bus() {
        let context = AgentContext::new(OperatingMode::Build, HashMap::new());
        let mut inbox = context.subscribe("agent-1".to_string());

        let outgoing = AgentMessage {
            id: Uuid::new_v4(),
            from: "agent-2".to_string(),
            to: MessageTarget::Agent("agent-1".to_string()),
            message_type: MessageType::Info,
            priority: MessagePriority::Normal,
            payload: serde_json::json!({"test": "data"}),
            timestamp: Utc::now(),
        };
        context.send_message(outgoing).await.unwrap();

        let delivered = inbox.recv().await.unwrap();
        assert_eq!(delivered.from, "agent-2");
        assert_eq!(delivered.payload["test"], "data");
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let context = AgentContext::new(OperatingMode::Build, HashMap::new());

        // Simulate two processed files and one finding.
        for _ in 0..2 {
            context.metrics.increment_files_processed();
        }
        context.metrics.increment_findings();

        let report = context.metrics();
        assert_eq!(report.files_processed, 2);
        assert_eq!(report.findings_generated, 1);
    }
}