1use crate::cache::MemoryCache;
5use crate::decision_gate::{DecisionGate, GateConfig, SaveDecision};
6use crate::knowledge::KnowledgeCache;
7use crate::vector_search;
8use crate::{MemoryEntry, SearchResult};
9
/// One user/agent turn pair offered to a [`MemoryStrategy`] for possible storage.
///
/// `Default` is derived so callers (and tests) can fill in only the fields they
/// care about with struct-update syntax; `PartialEq` allows direct comparison.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Exchange {
    /// The user's message text.
    pub user_turn: String,
    /// The agent's reply text.
    pub agent_turn: String,
    /// Session identifier; copied onto every entry built from this exchange.
    pub session_id: String,
    /// Turn index within the session (not read by the strategies in this module).
    pub turn_number: u32,
    /// Timestamp copied onto every entry built from this exchange.
    /// NOTE(review): units are not fixed here — presumably seconds since the Unix
    /// epoch; confirm with callers.
    pub timestamp: f64,
    /// Optional precomputed embedding for the user turn.
    pub user_embedding: Option<Vec<f32>>,
    /// Optional precomputed embedding for the agent turn.
    pub agent_embedding: Option<Vec<f32>>,
}
25
26#[derive(Debug, Clone)]
28pub struct StrategyOutput {
29 pub entries: Vec<MemoryEntry>,
30 pub entity_updates: Vec<EntityUpdate>,
31 pub skipped: Option<SkipReason>,
32}
33
/// Why a strategy chose not to save an exchange.
///
/// `PartialEq`/`Eq` let callers compare reasons directly instead of only
/// pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SkipReason {
    /// The decision gate rejected the user turn.
    Trivial,
    /// A sufficiently similar memory is already stored.
    Duplicate,
    /// A score fell below a strategy-specific threshold (not produced by the
    /// strategies in this module).
    BelowThreshold,
    /// Free-form reason for custom strategies.
    Custom(String),
}
41
/// An entity a strategy wants recorded alongside its memory entries.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EntityUpdate {
    /// Canonical entity name.
    pub name: String,
    /// Free-form entity kind (e.g. `"person"`).
    pub entity_type: String,
    /// Alternative names that refer to the same entity.
    pub aliases: Vec<String>,
}
48
/// How [`SaveEveryExchange`] turns an exchange into stored entries.
///
/// Fieldless and `Copy`, so it also derives `Eq` and `Hash` for use as a map or
/// set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SaveAs {
    /// One entry holding only the user turn and its embedding.
    UserTurn,
    /// One entry holding only the agent turn and its embedding.
    AgentTurn,
    /// Two entries: the user turn, then the agent turn.
    Both,
    /// One entry with both turns joined by a `---` line and the two embeddings
    /// averaged.
    Combined,
}
57
58pub trait MemoryStoreView {
60 fn search(&self, embedding: &[f32], k: usize) -> Vec<SearchResult>;
61 fn memory_count(&self) -> usize;
62 fn entity_count(&self) -> usize;
63}
64
65pub struct CacheStoreView<'a> {
70 cache: &'a MemoryCache,
71 knowledge: &'a KnowledgeCache,
72}
73
74impl<'a> CacheStoreView<'a> {
75 pub fn new(cache: &'a MemoryCache, knowledge: &'a KnowledgeCache) -> Self {
76 Self { cache, knowledge }
77 }
78}
79
80impl MemoryStoreView for CacheStoreView<'_> {
81 fn search(&self, embedding: &[f32], k: usize) -> Vec<SearchResult> {
82 let scored = vector_search::cosine_similarity_batch_prenorm(
83 embedding, &self.cache.embeddings, &self.cache.norms, &self.cache.tombstones,
84 );
85 vector_search::top_k(scored, k)
86 .into_iter()
87 .map(|(idx, score)| SearchResult {
88 score,
89 chunk: self.cache.chunks[idx].clone(),
90 index: idx,
91 timestamp: self.cache.timestamps[idx],
92 source_channel: self.cache.source_channels[idx].clone(),
93 activation: self.cache.activation_weights[idx],
94 })
95 .collect()
96 }
97
98 fn memory_count(&self) -> usize {
99 self.cache.len()
100 }
101
102 fn entity_count(&self) -> usize {
103 self.knowledge.entities.len()
104 }
105}
106
107pub trait MemoryStrategy: Send + Sync {
112 fn evaluate(
113 &self,
114 exchange: &Exchange,
115 store: &dyn MemoryStoreView,
116 ) -> StrategyOutput;
117}
118
119fn make_entry(
124 text: String,
125 embedding: Vec<f32>,
126 source_channel: &str,
127 exchange: &Exchange,
128) -> MemoryEntry {
129 MemoryEntry {
130 chunk: text,
131 embedding,
132 source_channel: source_channel.to_string(),
133 timestamp: exchange.timestamp,
134 session_id: exchange.session_id.clone(),
135 tags: String::new(),
136 }
137}
138
/// Element-wise mean of the user and agent embeddings.
///
/// A missing (`None`) or empty embedding carries no information, so it is
/// ignored and the other side is returned unchanged. Averaging against an
/// empty vector would otherwise truncate the result to nothing via `zip`.
/// Returns an empty vector when neither side has data. If both sides are
/// non-empty but differ in length, only the overlapping prefix is averaged.
fn average_embeddings(a: &Option<Vec<f32>>, b: &Option<Vec<f32>>) -> Vec<f32> {
    let a = a.as_deref().filter(|v| !v.is_empty());
    let b = b.as_deref().filter(|v| !v.is_empty());
    match (a, b) {
        (Some(va), Some(vb)) => va.iter().zip(vb).map(|(x, y)| (x + y) / 2.0).collect(),
        (Some(v), None) | (None, Some(v)) => v.to_vec(),
        (None, None) => Vec::new(),
    }
}
152
153pub struct SaveEveryExchange {
158 pub gate: DecisionGate,
159 pub save_as: SaveAs,
160}
161
162impl Default for SaveEveryExchange {
163 fn default() -> Self {
164 Self {
165 gate: DecisionGate::new(GateConfig::default()),
166 save_as: SaveAs::Combined,
167 }
168 }
169}
170
171impl MemoryStrategy for SaveEveryExchange {
172 fn evaluate(
173 &self,
174 exchange: &Exchange,
175 _store: &dyn MemoryStoreView,
176 ) -> StrategyOutput {
177 if let SaveDecision::Skip(_) = self.gate.should_save(&exchange.user_turn) {
179 return StrategyOutput {
180 entries: Vec::new(),
181 entity_updates: Vec::new(),
182 skipped: Some(SkipReason::Trivial),
183 };
184 }
185
186 let entries = match self.save_as {
187 SaveAs::Combined => {
188 let text = format!("{}\n---\n{}", exchange.user_turn, exchange.agent_turn);
189 let emb = average_embeddings(&exchange.user_embedding, &exchange.agent_embedding);
190 vec![make_entry(text, emb, "conversation", exchange)]
191 }
192 SaveAs::UserTurn => {
193 let emb = exchange.user_embedding.clone().unwrap_or_default();
194 vec![make_entry(exchange.user_turn.clone(), emb, "conversation", exchange)]
195 }
196 SaveAs::AgentTurn => {
197 let emb = exchange.agent_embedding.clone().unwrap_or_default();
198 vec![make_entry(exchange.agent_turn.clone(), emb, "conversation", exchange)]
199 }
200 SaveAs::Both => {
201 let u_emb = exchange.user_embedding.clone().unwrap_or_default();
202 let a_emb = exchange.agent_embedding.clone().unwrap_or_default();
203 vec![
204 make_entry(exchange.user_turn.clone(), u_emb, "conversation", exchange),
205 make_entry(exchange.agent_turn.clone(), a_emb, "conversation", exchange),
206 ]
207 }
208 };
209
210 StrategyOutput {
211 entries,
212 entity_updates: Vec::new(),
213 skipped: None,
214 }
215 }
216}
217
218pub struct SaveOnSemanticShift {
223 pub gate: DecisionGate,
224 pub shift_threshold: f32,
225 pub lookback_k: usize,
226}
227
228impl Default for SaveOnSemanticShift {
229 fn default() -> Self {
230 Self {
231 gate: DecisionGate::new(GateConfig::default()),
232 shift_threshold: 0.25,
233 lookback_k: 5,
234 }
235 }
236}
237
238impl MemoryStrategy for SaveOnSemanticShift {
239 fn evaluate(
240 &self,
241 exchange: &Exchange,
242 store: &dyn MemoryStoreView,
243 ) -> StrategyOutput {
244 if let SaveDecision::Skip(_) = self.gate.should_save(&exchange.user_turn) {
246 return StrategyOutput {
247 entries: Vec::new(),
248 entity_updates: Vec::new(),
249 skipped: Some(SkipReason::Trivial),
250 };
251 }
252
253 let embedding = match &exchange.user_embedding {
255 Some(e) => e,
256 None => {
257 let text = format!("{}\n---\n{}", exchange.user_turn, exchange.agent_turn);
259 return StrategyOutput {
260 entries: vec![make_entry(text, Vec::new(), "conversation", exchange)],
261 entity_updates: Vec::new(),
262 skipped: None,
263 };
264 }
265 };
266
267 let results = store.search(embedding, self.lookback_k);
269 if let Some(top) = results.first() {
270 if top.score > (1.0 - self.shift_threshold) {
271 return StrategyOutput {
272 entries: Vec::new(),
273 entity_updates: Vec::new(),
274 skipped: Some(SkipReason::Duplicate),
275 };
276 }
277 }
278
279 let text = format!("{}\n---\n{}", exchange.user_turn, exchange.agent_turn);
281 let emb = average_embeddings(&exchange.user_embedding, &exchange.agent_embedding);
282 StrategyOutput {
283 entries: vec![make_entry(text, emb, "conversation", exchange)],
284 entity_updates: Vec::new(),
285 skipped: None,
286 }
287 }
288}
289
/// Built-in phrases that mark a user turn as correcting the agent.
/// All cues are lowercase, because the user turn is lowercased before matching.
const DEFAULT_CORRECTION_CUES: &[&str] = &[
    // Short negations and hedges.
    "no,",
    "no ",
    "actually,",
    "actually ",
    // Explicit disagreement.
    "thats wrong",
    "not quite",
    // Restatements and clarifications.
    "correction:",
    "to clarify",
    "i meant",
    "what i meant",
    "let me clarify",
    "to be clear",
];
299
300pub struct SaveOnUserCorrection {
301 pub base: Box<dyn MemoryStrategy>,
302 pub correction_cues: Vec<String>,
303}
304
305impl SaveOnUserCorrection {
306 pub fn new(base: Box<dyn MemoryStrategy>) -> Self {
307 Self {
308 base,
309 correction_cues: DEFAULT_CORRECTION_CUES.iter().map(|s| s.to_string()).collect(),
310 }
311 }
312}
313
314impl MemoryStrategy for SaveOnUserCorrection {
315 fn evaluate(
316 &self,
317 exchange: &Exchange,
318 store: &dyn MemoryStoreView,
319 ) -> StrategyOutput {
320 let lower = exchange.user_turn.to_lowercase();
321 let is_correction = self.correction_cues.iter().any(|cue| {
322 lower.starts_with(cue) || lower.contains(cue)
323 });
324
325 if is_correction {
326 let text = format!("{}\n---\n{}", exchange.user_turn, exchange.agent_turn);
328 let emb = average_embeddings(&exchange.user_embedding, &exchange.agent_embedding);
329 return StrategyOutput {
330 entries: vec![make_entry(text, emb, "correction", exchange)],
331 entity_updates: Vec::new(),
332 skipped: None,
333 };
334 }
335
336 self.base.evaluate(exchange, store)
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector_search::{compute_norm, cosine_similarity_batch_prenorm, top_k};

    // In-memory `MemoryStoreView` for driving strategies. Scores come from the
    // same prenormalized cosine helper the cache-backed view uses.
    struct TestStoreView {
        embeddings: Vec<Vec<f32>>,
        chunks: Vec<String>,
        norms: Vec<f32>,
        tombstones: Vec<u8>,
    }

    impl TestStoreView {
        fn new() -> Self {
            Self {
                embeddings: Vec::new(),
                chunks: Vec::new(),
                norms: Vec::new(),
                tombstones: Vec::new(),
            }
        }

        // Appends a memory with its precomputed norm and tombstone flag 0
        // (presumably "live" — matches how the store's tests use it).
        fn add(&mut self, chunk: &str, embedding: Vec<f32>) {
            let norm = compute_norm(&embedding);
            self.embeddings.push(embedding);
            self.chunks.push(chunk.to_string());
            self.norms.push(norm);
            self.tombstones.push(0);
        }
    }

    impl MemoryStoreView for TestStoreView {
        fn search(&self, query: &[f32], k: usize) -> Vec<SearchResult> {
            let scored = cosine_similarity_batch_prenorm(
                query,
                &self.embeddings,
                &self.norms,
                &self.tombstones,
            );
            let top = top_k(scored, k);
            top.into_iter()
                .map(|(idx, score)| SearchResult {
                    score,
                    chunk: self.chunks[idx].clone(),
                    index: idx,
                    timestamp: 0.0,
                    source_channel: "test".to_string(),
                    activation: 1.0,
                })
                .collect()
        }

        fn memory_count(&self) -> usize {
            self.embeddings.len()
        }

        fn entity_count(&self) -> usize {
            0
        }
    }

    // A non-trivial exchange with orthogonal user/agent embeddings.
    fn substantive_exchange() -> Exchange {
        Exchange {
            user_turn: "Tell me about the deployment architecture for our microservices"
                .to_string(),
            agent_turn:
                "The deployment uses Kubernetes with three namespaces for staging, QA, and production"
                    .to_string(),
            session_id: "sess-1".to_string(),
            turn_number: 1,
            timestamp: 1000000.0,
            user_embedding: Some(vec![1.0, 0.0, 0.0, 0.0]),
            agent_embedding: Some(vec![0.0, 1.0, 0.0, 0.0]),
        }
    }

    // An exchange the default decision gate rejects as trivial.
    fn trivial_exchange() -> Exchange {
        Exchange {
            user_turn: "ok".to_string(),
            agent_turn: "Got it!".to_string(),
            session_id: "sess-1".to_string(),
            turn_number: 2,
            timestamp: 1000001.0,
            user_embedding: Some(vec![0.1, 0.1, 0.0, 0.0]),
            agent_embedding: None,
        }
    }

    #[test]
    fn test_save_every_exchange_combined() {
        let strategy = SaveEveryExchange::default();
        let store = TestStoreView::new();
        let exchange = substantive_exchange();

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
        assert!(output.entries[0].chunk.contains("deployment architecture"));
        assert!(output.entries[0].chunk.contains("---"));
        assert!(output.entries[0].chunk.contains("Kubernetes"));
        assert_eq!(output.entries[0].embedding.len(), 4);
        // Average of [1,0,0,0] and [0,1,0,0].
        assert!((output.entries[0].embedding[0] - 0.5).abs() < 1e-6);
        assert!((output.entries[0].embedding[1] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_save_every_exchange_trivial_skip() {
        let strategy = SaveEveryExchange::default();
        let store = TestStoreView::new();
        let exchange = trivial_exchange();

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.entries.is_empty());
        assert!(matches!(output.skipped, Some(SkipReason::Trivial)));
    }

    #[test]
    fn test_save_every_exchange_both() {
        let strategy = SaveEveryExchange {
            gate: DecisionGate::new(GateConfig::default()),
            save_as: SaveAs::Both,
        };
        let store = TestStoreView::new();
        let exchange = substantive_exchange();

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 2);
        assert!(output.entries[0].chunk.contains("deployment architecture"));
        assert!(output.entries[1].chunk.contains("Kubernetes"));
    }

    #[test]
    fn test_save_every_exchange_user_only() {
        let strategy = SaveEveryExchange {
            gate: DecisionGate::new(GateConfig::default()),
            save_as: SaveAs::UserTurn,
        };
        let store = TestStoreView::new();
        let exchange = substantive_exchange();

        let output = strategy.evaluate(&exchange, &store);
        assert_eq!(output.entries.len(), 1);
        assert!(output.entries[0].chunk.contains("deployment architecture"));
        assert!(!output.entries[0].chunk.contains("Kubernetes"));
        assert_eq!(output.entries[0].embedding, vec![1.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_save_every_exchange_agent_only() {
        let strategy = SaveEveryExchange {
            gate: DecisionGate::new(GateConfig::default()),
            save_as: SaveAs::AgentTurn,
        };
        let store = TestStoreView::new();
        let exchange = substantive_exchange();

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
        assert_eq!(output.entries[0].chunk, exchange.agent_turn);
        assert_eq!(output.entries[0].embedding, vec![0.0, 1.0, 0.0, 0.0]);
        assert_eq!(output.entries[0].session_id, "sess-1");
        assert_eq!(output.entries[0].timestamp, exchange.timestamp);
        assert_eq!(output.entries[0].source_channel, "conversation");
    }

    #[test]
    fn test_semantic_shift_novel() {
        let strategy = SaveOnSemanticShift::default();
        let mut store = TestStoreView::new();
        // Orthogonal to the exchange's user embedding, so not a duplicate.
        store.add("The weather is nice today", vec![0.0, 0.0, 1.0, 0.0]);

        let exchange = substantive_exchange();
        let output = strategy.evaluate(&exchange, &store);

        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
    }

    #[test]
    fn test_semantic_shift_duplicate() {
        let strategy = SaveOnSemanticShift::default();
        let mut store = TestStoreView::new();
        // Identical to the exchange's user embedding.
        store.add("deployment architecture details", vec![1.0, 0.0, 0.0, 0.0]);

        let exchange = substantive_exchange();
        let output = strategy.evaluate(&exchange, &store);

        assert!(output.entries.is_empty());
        assert!(matches!(output.skipped, Some(SkipReason::Duplicate)));
    }

    #[test]
    fn test_semantic_shift_no_embedding() {
        let strategy = SaveOnSemanticShift::default();
        let store = TestStoreView::new();

        let mut exchange = substantive_exchange();
        exchange.user_embedding = None;

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
    }

    #[test]
    fn test_correction_detected() {
        let base = SaveEveryExchange::default();
        let strategy = SaveOnUserCorrection::new(Box::new(base));
        let store = TestStoreView::new();

        let exchange = Exchange {
            user_turn: "Actually, thats wrong. The answer is 42".to_string(),
            agent_turn: "You're right, I apologize. The answer is indeed 42.".to_string(),
            session_id: "sess-1".to_string(),
            turn_number: 3,
            timestamp: 1000002.0,
            user_embedding: Some(vec![0.5, 0.5, 0.0, 0.0]),
            agent_embedding: None,
        };

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
        assert_eq!(output.entries[0].source_channel, "correction");
    }

    #[test]
    fn test_correction_is_case_insensitive() {
        let strategy = SaveOnUserCorrection::new(Box::new(SaveEveryExchange::default()));
        let store = TestStoreView::new();
        let mut exchange = substantive_exchange();
        exchange.user_turn = "ACTUALLY, I MEANT the staging namespace".to_string();

        let output = strategy.evaluate(&exchange, &store);
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
        assert_eq!(output.entries[0].source_channel, "correction");
        // The stored text keeps the user's original casing.
        assert!(output.entries[0].chunk.starts_with("ACTUALLY, I MEANT"));
    }

    #[test]
    fn test_correction_delegates_to_base() {
        let base = SaveEveryExchange::default();
        let strategy = SaveOnUserCorrection::new(Box::new(base));
        let store = TestStoreView::new();

        let exchange = substantive_exchange();
        let output = strategy.evaluate(&exchange, &store);

        // No cue present, so the base strategy's output comes through unchanged.
        assert!(output.skipped.is_none());
        assert_eq!(output.entries.len(), 1);
        assert_eq!(output.entries[0].source_channel, "conversation");
    }

    #[test]
    fn test_correction_falls_back_to_base_skip() {
        let strategy = SaveOnUserCorrection::new(Box::new(SaveEveryExchange::default()));
        let store = TestStoreView::new();

        let output = strategy.evaluate(&trivial_exchange(), &store);
        assert!(output.entries.is_empty());
        assert!(matches!(output.skipped, Some(SkipReason::Trivial)));
    }

    #[test]
    fn test_correction_wrapping_shift() {
        let mut store = TestStoreView::new();
        store.add("deployment stuff", vec![1.0, 0.0, 0.0, 0.0]);

        let base = SaveOnSemanticShift::default();
        let strategy = SaveOnUserCorrection::new(Box::new(base));

        // Same embedding as the stored memory, but a correction: must be saved.
        let correction = Exchange {
            user_turn: "No, thats wrong. The deployment uses ECS not EKS".to_string(),
            agent_turn: "Corrected: the deployment uses ECS".to_string(),
            session_id: "sess-1".to_string(),
            turn_number: 4,
            timestamp: 1000003.0,
            user_embedding: Some(vec![1.0, 0.0, 0.0, 0.0]),
            agent_embedding: None,
        };
        let output = strategy.evaluate(&correction, &store);
        assert!(output.skipped.is_none(), "correction should bypass shift");
        assert_eq!(output.entries.len(), 1);
        assert_eq!(output.entries[0].source_channel, "correction");

        // Same embedding, no cue: the wrapped shift strategy flags a duplicate.
        let non_correction = Exchange {
            user_turn: "Tell me about the deployment architecture for our microservices"
                .to_string(),
            agent_turn: "The deployment uses Kubernetes".to_string(),
            session_id: "sess-1".to_string(),
            turn_number: 5,
            timestamp: 1000004.0,
            user_embedding: Some(vec![1.0, 0.0, 0.0, 0.0]),
            agent_embedding: None,
        };
        let output2 = strategy.evaluate(&non_correction, &store);
        assert!(matches!(output2.skipped, Some(SkipReason::Duplicate)));
    }

    #[test]
    fn test_skip_reason_returned() {
        let store = TestStoreView::new();

        let s1 = SaveEveryExchange::default();
        let out1 = s1.evaluate(&trivial_exchange(), &store);
        assert!(matches!(out1.skipped, Some(SkipReason::Trivial)));

        let mut dup_store = TestStoreView::new();
        dup_store.add("exact match", vec![1.0, 0.0, 0.0, 0.0]);
        let s2 = SaveOnSemanticShift::default();
        let out2 = s2.evaluate(&substantive_exchange(), &dup_store);
        assert!(matches!(out2.skipped, Some(SkipReason::Duplicate)));

        let custom = SkipReason::Custom("test reason".to_string());
        assert!(matches!(custom, SkipReason::Custom(_)));

        let below = SkipReason::BelowThreshold;
        assert!(matches!(below, SkipReason::BelowThreshold));
    }

    #[test]
    fn test_average_embeddings_one_side_missing() {
        let v = vec![0.25, 0.75];
        assert_eq!(average_embeddings(&Some(v.clone()), &None), v);
        assert_eq!(average_embeddings(&None, &Some(v.clone())), v);
        assert!(average_embeddings(&None, &None).is_empty());
    }

    #[test]
    fn test_entity_updates() {
        let output = StrategyOutput {
            entries: Vec::new(),
            entity_updates: vec![EntityUpdate {
                name: "Alice".to_string(),
                entity_type: "person".to_string(),
                aliases: vec!["my friend".to_string()],
            }],
            skipped: None,
        };
        assert_eq!(output.entity_updates.len(), 1);
        assert_eq!(output.entity_updates[0].name, "Alice");
        assert_eq!(output.entity_updates[0].aliases, vec!["my friend"]);
    }

    #[test]
    fn test_record_with_strategy() {
        use crate::{AgentMemory, HDF5Memory, MemoryConfig};
        let dir = tempfile::TempDir::new().unwrap();
        let config = MemoryConfig::new(dir.path().join("test.h5"), "agent-test", 4);
        let mut mem = HDF5Memory::create(config).unwrap();
        mem.set_strategy(Box::new(SaveEveryExchange::default()));
        let exchange = Exchange {
            user_turn: "Tell me about the deployment architecture for microservices".into(),
            agent_turn: "It uses Kubernetes".into(),
            session_id: "s1".into(),
            turn_number: 1,
            timestamp: 1e6,
            user_embedding: Some(vec![1.0, 0.0, 0.0, 0.0]),
            agent_embedding: Some(vec![0.0, 1.0, 0.0, 0.0]),
        };
        let out = mem.record(exchange).unwrap();
        assert!(out.skipped.is_none());
        assert_eq!(mem.count(), 1);
    }

    #[test]
    fn test_record_trivial_skip() {
        use crate::{AgentMemory, HDF5Memory, MemoryConfig};
        let dir = tempfile::TempDir::new().unwrap();
        let config = MemoryConfig::new(dir.path().join("test.h5"), "agent-test", 4);
        let mut mem = HDF5Memory::create(config).unwrap();
        mem.set_strategy(Box::new(SaveEveryExchange::default()));
        let exchange = Exchange {
            user_turn: "ok".into(),
            agent_turn: "Got it!".into(),
            session_id: "s1".into(),
            turn_number: 2,
            timestamp: 1e6,
            user_embedding: None,
            agent_embedding: None,
        };
        let out = mem.record(exchange).unwrap();
        assert!(out.skipped.is_some());
        assert_eq!(mem.count(), 0);
    }
}