exo_temporal/
short_term.rs

1//! Short-term volatile memory buffer
2
3use crate::types::{TemporalPattern, PatternId};
4use dashmap::DashMap;
5use parking_lot::RwLock;
6use std::collections::VecDeque;
7use std::sync::Arc;
8
/// Configuration for [`ShortTermBuffer`].
#[derive(Debug, Clone, PartialEq)]
pub struct ShortTermConfig {
    /// Nominal capacity of the buffer, in patterns.
    ///
    /// Used to preallocate storage and as the denominator of the usage ratio
    /// (`len / max_capacity`). It is not a hard limit: inserting past it is
    /// allowed.
    pub max_capacity: usize,
    /// Usage ratio (`len / max_capacity`) at or above which consolidation
    /// should be triggered, e.g. `0.8` means "80% full".
    pub consolidation_threshold: f32,
}

impl Default for ShortTermConfig {
    /// Nominal capacity of 10,000 patterns, consolidating at 80% usage.
    fn default() -> Self {
        Self {
            max_capacity: 10_000,
            consolidation_threshold: 0.8,
        }
    }
}
26
27/// Short-term volatile memory buffer
28pub struct ShortTermBuffer {
29    /// Pattern storage (FIFO queue)
30    patterns: Arc<RwLock<VecDeque<TemporalPattern>>>,
31    /// Index for fast lookup by ID
32    index: DashMap<PatternId, usize>,
33    /// Configuration
34    config: ShortTermConfig,
35}
36
37impl ShortTermBuffer {
38    /// Create new short-term buffer
39    pub fn new(config: ShortTermConfig) -> Self {
40        Self {
41            patterns: Arc::new(RwLock::new(VecDeque::with_capacity(config.max_capacity))),
42            index: DashMap::new(),
43            config,
44        }
45    }
46
47    /// Insert pattern into buffer
48    pub fn insert(&self, temporal_pattern: TemporalPattern) -> PatternId {
49        let id = temporal_pattern.pattern.id;
50        let mut patterns = self.patterns.write();
51
52        // Add to queue
53        let position = patterns.len();
54        patterns.push_back(temporal_pattern);
55
56        // Update index
57        self.index.insert(id, position);
58
59        id
60    }
61
62    /// Get pattern by ID
63    pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
64        let index = self.index.get(id)?;
65        let patterns = self.patterns.read();
66        patterns.get(*index).cloned()
67    }
68
69    /// Get mutable pattern by ID
70    pub fn get_mut<F, R>(&self, id: &PatternId, f: F) -> Option<R>
71    where
72        F: FnOnce(&mut TemporalPattern) -> R,
73    {
74        let index = *self.index.get(id)?;
75        let mut patterns = self.patterns.write();
76        patterns.get_mut(index).map(f)
77    }
78
79    /// Update pattern
80    pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
81        let id = temporal_pattern.pattern.id;
82        if let Some(index) = self.index.get(&id) {
83            let mut patterns = self.patterns.write();
84            if let Some(p) = patterns.get_mut(*index) {
85                *p = temporal_pattern;
86                return true;
87            }
88        }
89        false
90    }
91
92    /// Check if should trigger consolidation
93    pub fn should_consolidate(&self) -> bool {
94        let patterns = self.patterns.read();
95        let usage = patterns.len() as f32 / self.config.max_capacity as f32;
96        usage >= self.config.consolidation_threshold
97    }
98
99    /// Get current size
100    pub fn len(&self) -> usize {
101        self.patterns.read().len()
102    }
103
104    /// Check if empty
105    pub fn is_empty(&self) -> bool {
106        self.patterns.read().is_empty()
107    }
108
109    /// Drain all patterns (for consolidation)
110    pub fn drain(&self) -> Vec<TemporalPattern> {
111        let mut patterns = self.patterns.write();
112        self.index.clear();
113        patterns.drain(..).collect()
114    }
115
116    /// Drain patterns matching predicate
117    pub fn drain_filter<F>(&self, mut predicate: F) -> Vec<TemporalPattern>
118    where
119        F: FnMut(&TemporalPattern) -> bool,
120    {
121        let mut patterns = self.patterns.write();
122        let mut result = Vec::new();
123        let mut i = 0;
124
125        while i < patterns.len() {
126            if predicate(&patterns[i]) {
127                let temporal_pattern = patterns.remove(i).unwrap();
128                self.index.remove(&temporal_pattern.pattern.id);
129                result.push(temporal_pattern);
130                // Don't increment i, as we removed an element
131            } else {
132                // Update index since positions shifted
133                self.index.insert(patterns[i].pattern.id, i);
134                i += 1;
135            }
136        }
137
138        result
139    }
140
141    /// Get all patterns (for iteration)
142    pub fn all(&self) -> Vec<TemporalPattern> {
143        self.patterns.read().iter().cloned().collect()
144    }
145
146    /// Clear all patterns
147    pub fn clear(&self) {
148        self.patterns.write().clear();
149        self.index.clear();
150    }
151
152    /// Get statistics
153    pub fn stats(&self) -> ShortTermStats {
154        let patterns = self.patterns.read();
155        let size = patterns.len();
156        let capacity = self.config.max_capacity;
157        let usage = size as f32 / capacity as f32;
158
159        // Compute average salience
160        let total_salience: f32 = patterns.iter().map(|p| p.pattern.salience).sum();
161        let avg_salience = if size > 0 {
162            total_salience / size as f32
163        } else {
164            0.0
165        };
166
167        ShortTermStats {
168            size,
169            capacity,
170            usage,
171            avg_salience,
172        }
173    }
174}
175
176impl Default for ShortTermBuffer {
177    fn default() -> Self {
178        Self::new(ShortTermConfig::default())
179    }
180}
181
/// Snapshot of short-term buffer statistics, as returned by
/// [`ShortTermBuffer::stats`].
#[derive(Debug, Clone, PartialEq)]
pub struct ShortTermStats {
    /// Current number of patterns.
    pub size: usize,
    /// Configured nominal capacity (`max_capacity`).
    pub capacity: usize,
    /// Usage ratio `size / capacity`.
    ///
    /// Normally in `0.0..=1.0`, but it can exceed `1.0` because the capacity is
    /// not enforced on insert. It is not finite when `capacity` is zero.
    pub usage: f32,
    /// Mean salience over stored patterns (`0.0` when the buffer is empty).
    pub avg_salience: f32,
}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// Build a pattern whose embedding is the single value `value`.
    fn pattern(value: f32) -> TemporalPattern {
        TemporalPattern::from_embedding(vec![value], Metadata::new())
    }

    #[test]
    fn test_short_term_buffer() {
        let buffer = ShortTermBuffer::default();

        let temporal_pattern =
            TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let id = temporal_pattern.pattern.id;

        buffer.insert(temporal_pattern);

        assert_eq!(buffer.len(), 1);
        assert!(buffer.get(&id).is_some());

        let patterns = buffer.drain();
        assert_eq!(patterns.len(), 1);
        assert!(buffer.is_empty());
        assert!(buffer.get(&id).is_none());
    }

    #[test]
    fn test_consolidation_threshold() {
        let config = ShortTermConfig {
            max_capacity: 10,
            consolidation_threshold: 0.8,
        };
        let buffer = ShortTermBuffer::new(config);

        // Add 7 patterns (70% full).
        for i in 0..7 {
            buffer.insert(pattern(i as f32));
        }

        assert!(!buffer.should_consolidate());

        // Add 1 more (80% full).
        buffer.insert(pattern(7.0));

        assert!(buffer.should_consolidate());
    }

    #[test]
    fn test_drain_filter_keeps_index_consistent() {
        let buffer = ShortTermBuffer::default();
        let ids: Vec<PatternId> = (0..5).map(|i| buffer.insert(pattern(i as f32))).collect();

        let (second, fourth) = (ids[1], ids[3]);
        let removed = buffer.drain_filter(|p| p.pattern.id == second || p.pattern.id == fourth);

        assert_eq!(removed.len(), 2);
        assert_eq!(buffer.len(), 3);
        for (i, id) in ids.iter().enumerate() {
            let found = buffer.get(id);
            if i == 1 || i == 3 {
                assert!(found.is_none());
            } else {
                // Survivors shifted position; lookups must still hit the right pattern.
                assert!(found.map_or(false, |p| p.pattern.id == *id));
            }
        }
    }

    #[test]
    fn test_update_and_get_mut() {
        let buffer = ShortTermBuffer::default();
        let id = buffer.insert(pattern(1.0));

        let mut changed = buffer.get(&id).expect("pattern was just inserted");
        changed.pattern.salience = 0.25;
        assert!(buffer.update(changed));
        assert_eq!(buffer.get_mut(&id, |p| p.pattern.salience), Some(0.25));

        // Unknown IDs are rejected rather than inserted.
        assert!(!buffer.update(pattern(2.0)));
        assert_eq!(buffer.len(), 1);
    }

    #[test]
    fn test_clear_and_stats() {
        let buffer = ShortTermBuffer::default();
        let id = buffer.insert(pattern(1.0));

        buffer.clear();
        assert!(buffer.is_empty());
        assert!(buffer.get(&id).is_none());

        let stats = buffer.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.avg_salience, 0.0);
    }
}