// exo_temporal/short_term.rs

use std::collections::VecDeque;
use std::sync::Arc;

use dashmap::DashMap;
use dashmap::mapref::entry::Entry;
use parking_lot::RwLock;

use crate::types::{PatternId, TemporalPattern};

/// Tuning knobs for [`ShortTermBuffer`].
#[derive(Clone, Debug)]
pub struct ShortTermConfig {
    /// Nominal number of patterns the buffer is sized for. The backing deque is
    /// pre-allocated to this size and usage ratios are computed against it; `insert`
    /// does not reject patterns beyond it.
    pub max_capacity: usize,
    /// Usage ratio (`len / max_capacity`) at or above which
    /// [`ShortTermBuffer::should_consolidate`] reports `true`.
    pub consolidation_threshold: f32,
}

impl Default for ShortTermConfig {
    /// Room for 10 000 patterns, consolidating once the buffer is 80% full.
    fn default() -> Self {
        ShortTermConfig {
            consolidation_threshold: 0.8,
            max_capacity: 10_000,
        }
    }
}

27pub struct ShortTermBuffer {
29 patterns: Arc<RwLock<VecDeque<TemporalPattern>>>,
31 index: DashMap<PatternId, usize>,
33 config: ShortTermConfig,
35}
36
37impl ShortTermBuffer {
38 pub fn new(config: ShortTermConfig) -> Self {
40 Self {
41 patterns: Arc::new(RwLock::new(VecDeque::with_capacity(config.max_capacity))),
42 index: DashMap::new(),
43 config,
44 }
45 }
46
47 pub fn insert(&self, temporal_pattern: TemporalPattern) -> PatternId {
49 let id = temporal_pattern.pattern.id;
50 let mut patterns = self.patterns.write();
51
52 let position = patterns.len();
54 patterns.push_back(temporal_pattern);
55
56 self.index.insert(id, position);
58
59 id
60 }
61
62 pub fn get(&self, id: &PatternId) -> Option<TemporalPattern> {
64 let index = self.index.get(id)?;
65 let patterns = self.patterns.read();
66 patterns.get(*index).cloned()
67 }
68
69 pub fn get_mut<F, R>(&self, id: &PatternId, f: F) -> Option<R>
71 where
72 F: FnOnce(&mut TemporalPattern) -> R,
73 {
74 let index = *self.index.get(id)?;
75 let mut patterns = self.patterns.write();
76 patterns.get_mut(index).map(f)
77 }
78
79 pub fn update(&self, temporal_pattern: TemporalPattern) -> bool {
81 let id = temporal_pattern.pattern.id;
82 if let Some(index) = self.index.get(&id) {
83 let mut patterns = self.patterns.write();
84 if let Some(p) = patterns.get_mut(*index) {
85 *p = temporal_pattern;
86 return true;
87 }
88 }
89 false
90 }
91
92 pub fn should_consolidate(&self) -> bool {
94 let patterns = self.patterns.read();
95 let usage = patterns.len() as f32 / self.config.max_capacity as f32;
96 usage >= self.config.consolidation_threshold
97 }
98
99 pub fn len(&self) -> usize {
101 self.patterns.read().len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.patterns.read().is_empty()
107 }
108
109 pub fn drain(&self) -> Vec<TemporalPattern> {
111 let mut patterns = self.patterns.write();
112 self.index.clear();
113 patterns.drain(..).collect()
114 }
115
116 pub fn drain_filter<F>(&self, mut predicate: F) -> Vec<TemporalPattern>
118 where
119 F: FnMut(&TemporalPattern) -> bool,
120 {
121 let mut patterns = self.patterns.write();
122 let mut result = Vec::new();
123 let mut i = 0;
124
125 while i < patterns.len() {
126 if predicate(&patterns[i]) {
127 let temporal_pattern = patterns.remove(i).unwrap();
128 self.index.remove(&temporal_pattern.pattern.id);
129 result.push(temporal_pattern);
130 } else {
132 self.index.insert(patterns[i].pattern.id, i);
134 i += 1;
135 }
136 }
137
138 result
139 }
140
141 pub fn all(&self) -> Vec<TemporalPattern> {
143 self.patterns.read().iter().cloned().collect()
144 }
145
146 pub fn clear(&self) {
148 self.patterns.write().clear();
149 self.index.clear();
150 }
151
152 pub fn stats(&self) -> ShortTermStats {
154 let patterns = self.patterns.read();
155 let size = patterns.len();
156 let capacity = self.config.max_capacity;
157 let usage = size as f32 / capacity as f32;
158
159 let total_salience: f32 = patterns.iter().map(|p| p.pattern.salience).sum();
161 let avg_salience = if size > 0 {
162 total_salience / size as f32
163 } else {
164 0.0
165 };
166
167 ShortTermStats {
168 size,
169 capacity,
170 usage,
171 avg_salience,
172 }
173 }
174}
175
176impl Default for ShortTermBuffer {
177 fn default() -> Self {
178 Self::new(ShortTermConfig::default())
179 }
180}
181
/// Point-in-time snapshot produced by [`ShortTermBuffer::stats`].
#[derive(Clone, Debug)]
pub struct ShortTermStats {
    /// Number of buffered patterns.
    pub size: usize,
    /// The configured `max_capacity`.
    pub capacity: usize,
    /// `size / capacity`.
    pub usage: f32,
    /// Mean `pattern.salience` across buffered patterns; `0.0` when the buffer is empty.
    pub avg_salience: f32,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// Round trip: insert, look up, then drain everything.
    #[test]
    fn test_short_term_buffer() {
        let buffer = ShortTermBuffer::default();

        let temporal_pattern = TemporalPattern::from_embedding(vec![1.0, 2.0, 3.0], Metadata::new());
        let id = temporal_pattern.pattern.id;

        buffer.insert(temporal_pattern);

        assert_eq!(buffer.len(), 1);
        assert!(buffer.get(&id).is_some());

        let patterns = buffer.drain();
        assert_eq!(patterns.len(), 1);
        assert!(buffer.is_empty());
    }

    /// 7/10 = 0.7 is below the 0.8 threshold; the 8th insert reaches exactly 0.8.
    #[test]
    fn test_consolidation_threshold() {
        let config = ShortTermConfig {
            max_capacity: 10,
            consolidation_threshold: 0.8,
        };
        let buffer = ShortTermBuffer::new(config);

        for i in 0..7 {
            let temporal_pattern = TemporalPattern::from_embedding(vec![i as f32], Metadata::new());
            buffer.insert(temporal_pattern);
        }

        assert!(!buffer.should_consolidate());

        let temporal_pattern = TemporalPattern::from_embedding(vec![8.0], Metadata::new());
        buffer.insert(temporal_pattern);

        assert!(buffer.should_consolidate());
    }

    /// After removing entries from the middle, every survivor must still be reachable
    /// through the index (positions shift when earlier entries are removed).
    #[test]
    fn test_drain_filter_keeps_index_consistent() {
        use std::collections::HashSet;

        let buffer = ShortTermBuffer::default();
        let ids: Vec<PatternId> = (0..6)
            .map(|i| buffer.insert(TemporalPattern::from_embedding(vec![i as f32], Metadata::new())))
            .collect();
        let doomed: HashSet<PatternId> = ids.iter().copied().step_by(2).collect();

        let removed = buffer.drain_filter(|p| doomed.contains(&p.pattern.id));
        assert_eq!(removed.len(), 3);
        assert_eq!(buffer.len(), 3);

        for id in &ids {
            let found = buffer.get(id).map(|p| p.pattern.id);
            if doomed.contains(id) {
                assert!(found.is_none());
            } else {
                assert!(found == Some(*id));
            }
        }
    }

    /// `update` replaces only known ids; `get_mut` edits in place and returns `f`'s result.
    #[test]
    fn test_update_and_get_mut() {
        let buffer = ShortTermBuffer::default();
        let id = buffer.insert(TemporalPattern::from_embedding(vec![1.0], Metadata::new()));

        let mut changed = buffer.get(&id).expect("pattern was just inserted");
        changed.pattern.salience = 0.25;
        assert!(buffer.update(changed));
        assert_eq!(buffer.get(&id).map(|p| p.pattern.salience), Some(0.25));

        let seen = buffer.get_mut(&id, |p| {
            p.pattern.salience = 0.75;
            p.pattern.salience
        });
        assert_eq!(seen, Some(0.75));
        assert_eq!(buffer.get(&id).map(|p| p.pattern.salience), Some(0.75));

        let stranger = TemporalPattern::from_embedding(vec![2.0], Metadata::new());
        let stranger_id = stranger.pattern.id;
        assert!(!buffer.update(stranger));
        assert!(buffer.get(&stranger_id).is_none());
        assert!(buffer.get_mut(&stranger_id, |_| ()).is_none());
        assert_eq!(buffer.len(), 1);
    }

    /// `clear` empties both storage and index; stats of an empty buffer are all zero.
    #[test]
    fn test_clear_and_stats() {
        let buffer = ShortTermBuffer::default();
        let id = buffer.insert(TemporalPattern::from_embedding(vec![1.0], Metadata::new()));
        assert_eq!(buffer.stats().size, 1);

        buffer.clear();

        assert!(buffer.is_empty());
        assert!(buffer.get(&id).is_none());
        let stats = buffer.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.capacity, 10_000);
        assert_eq!(stats.usage, 0.0);
        assert_eq!(stats.avg_salience, 0.0);
    }
}