Skip to main content

cognis/agent/
memory.rs

//! Conversation memory — the per-agent message buffer.
//!
//! Memory implementations:
//! - [`Window`] — bounded FIFO with optional pinned system prompt.
//! - [`Buffer`] — unbounded; keeps every message.
//! - [`TokenBufferMemory`] — token-budgeted trim; drops oldest until under budget.
//! - [`SummaryMemory`] — summarizes older history with an LLM, drops the originals.
//! - [`SummaryBufferMemory`] — token budget + summary; the best of both.
//! - [`VectorMemory`] — semantic recall via a [`VectorStore`]; call
//!   [`VectorMemory::recall`] to fetch the most relevant past messages
//!   (they are not added to the seed automatically).
//! - [`EntityMemory`] — keeps a per-entity fact ledger and surfaces it in the seed.
//! - [`KnowledgeGraphMemory`] — extracts subject–predicate–object triples and
//!   surfaces them in the seed.
11
12use std::collections::VecDeque;
13use std::sync::Arc;
14
15use cognis_core::tokenizer::{CharTokenizer, Tokenizer};
16use cognis_core::{trim_messages, Message, TrimStrategy};
17
18use cognis_llm::chat::ChatOptions;
19use cognis_llm::Client;
20use cognis_rag::VectorStore;
21use tokio::sync::RwLock;
22
23/// Pluggable memory backend. The `Agent` reads via `seed()` to build
24/// initial state, and writes incremental messages via `write()`.
25pub trait Memory: Send + Sync {
26    /// All currently buffered messages.
27    fn read(&self) -> &[Message];
28
29    /// Append one message.
30    fn write(&mut self, msg: Message);
31
32    /// Clear all buffered messages (system pinned ones survive in the Window impl).
33    fn clear(&mut self);
34
35    /// Build the seed messages for a fresh graph run. Default: `read().to_vec()`.
36    fn seed(&self) -> Vec<Message> {
37        self.read().to_vec()
38    }
39}
40
41/// Bounded-capacity sliding window. Drops oldest non-system messages
42/// when capacity is hit. The system message (if pinned) is kept at
43/// index 0 across all writes and clears.
44#[derive(Debug, Clone)]
45pub struct Window {
46    capacity: usize,
47    system_pinned: Option<Message>,
48    buf: VecDeque<Message>,
49}
50
51impl Window {
52    /// New empty window with the given capacity (for non-system messages).
53    pub fn new(capacity: usize) -> Self {
54        Self {
55            capacity: capacity.max(1),
56            system_pinned: None,
57            buf: VecDeque::with_capacity(capacity),
58        }
59    }
60
61    /// Pin a system message that survives writes and clears.
62    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
63        self.system_pinned = Some(Message::system(prompt));
64        self
65    }
66}
67
68impl Memory for Window {
69    fn read(&self) -> &[Message] {
70        // Build a temp slice including system_pinned at the start. Since
71        // `&[Message]` requires contiguous storage and we keep system
72        // separate, we expose the buf only here. `seed()` (overridden
73        // below) handles the merge for callers that need both.
74        self.buf.as_slices().0
75    }
76
77    fn write(&mut self, msg: Message) {
78        if self.buf.len() >= self.capacity {
79            self.buf.pop_front();
80        }
81        self.buf.push_back(msg);
82    }
83
84    fn clear(&mut self) {
85        self.buf.clear();
86    }
87
88    fn seed(&self) -> Vec<Message> {
89        let mut out = Vec::with_capacity(self.buf.len() + 1);
90        if let Some(s) = &self.system_pinned {
91            out.push(s.clone());
92        }
93        out.extend(self.buf.iter().cloned());
94        out
95    }
96}
97
98// ---------------------------------------------------------------------------
99// Buffer — unbounded message store.
100// ---------------------------------------------------------------------------
101
102/// Unbounded memory: keeps every message ever written. Use when conversation
103/// length is small enough that token cost isn't a concern.
104#[derive(Debug, Default, Clone)]
105pub struct Buffer {
106    system_pinned: Option<Message>,
107    msgs: Vec<Message>,
108}
109
110impl Buffer {
111    /// Empty buffer.
112    pub fn new() -> Self {
113        Self::default()
114    }
115    /// Pin a system message at the head.
116    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
117        self.system_pinned = Some(Message::system(prompt));
118        self
119    }
120}
121
122impl Memory for Buffer {
123    fn read(&self) -> &[Message] {
124        &self.msgs
125    }
126    fn write(&mut self, msg: Message) {
127        self.msgs.push(msg);
128    }
129    fn clear(&mut self) {
130        self.msgs.clear();
131    }
132    fn seed(&self) -> Vec<Message> {
133        let mut out = Vec::with_capacity(self.msgs.len() + 1);
134        if let Some(s) = &self.system_pinned {
135            out.push(s.clone());
136        }
137        out.extend(self.msgs.iter().cloned());
138        out
139    }
140}
141
142// ---------------------------------------------------------------------------
143// TokenBufferMemory — drop oldest until under a token budget.
144// ---------------------------------------------------------------------------
145
146/// Token-budgeted memory: every `seed()` call trims the conversation
147/// down to `max_tokens` using the configured [`Tokenizer`]. The pinned
148/// system prompt (if any) is always kept.
149pub struct TokenBufferMemory {
150    system_pinned: Option<Message>,
151    msgs: Vec<Message>,
152    max_tokens: usize,
153    tokenizer: Arc<dyn Tokenizer>,
154    strategy: TrimStrategy,
155}
156
157impl TokenBufferMemory {
158    /// Build with the default `CharTokenizer` (chars-as-tokens; conservative).
159    pub fn new(max_tokens: usize) -> Self {
160        Self {
161            system_pinned: None,
162            msgs: Vec::new(),
163            max_tokens,
164            tokenizer: Arc::new(CharTokenizer),
165            strategy: TrimStrategy::First,
166        }
167    }
168
169    /// Override the tokenizer (e.g. plug in tiktoken).
170    pub fn with_tokenizer(mut self, t: Arc<dyn Tokenizer>) -> Self {
171        self.tokenizer = t;
172        self
173    }
174
175    /// Override the trim strategy. Default: drop oldest first.
176    pub fn with_strategy(mut self, s: TrimStrategy) -> Self {
177        self.strategy = s;
178        self
179    }
180
181    /// Pin a system message at the head of the seed.
182    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
183        self.system_pinned = Some(Message::system(prompt));
184        self
185    }
186}
187
188impl std::fmt::Debug for TokenBufferMemory {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        f.debug_struct("TokenBufferMemory")
191            .field("max_tokens", &self.max_tokens)
192            .field("strategy", &self.strategy)
193            .field("msgs", &self.msgs.len())
194            .finish()
195    }
196}
197
198impl Memory for TokenBufferMemory {
199    fn read(&self) -> &[Message] {
200        &self.msgs
201    }
202    fn write(&mut self, msg: Message) {
203        self.msgs.push(msg);
204    }
205    fn clear(&mut self) {
206        self.msgs.clear();
207    }
208    fn seed(&self) -> Vec<Message> {
209        let mut all = Vec::with_capacity(self.msgs.len() + 1);
210        if let Some(s) = &self.system_pinned {
211            all.push(s.clone());
212        }
213        all.extend(self.msgs.iter().cloned());
214        trim_messages(
215            &all,
216            self.max_tokens,
217            self.tokenizer.as_ref(),
218            self.strategy,
219        )
220    }
221}
222
223// ---------------------------------------------------------------------------
224// SummaryMemory — LLM-backed compression.
225// ---------------------------------------------------------------------------
226
227/// LLM-backed memory: when message count exceeds `threshold`, summarize the
228/// oldest `threshold/2` messages into a single system message (via the
229/// supplied [`Client`]) and drop the originals.
230///
231/// Summarization is **lazy** — it runs in `seed()` (called by the agent
232/// before each turn). That keeps `write()` synchronous, which is what
233/// the [`Memory`] trait requires.
234pub struct SummaryMemory {
235    system_pinned: Option<Message>,
236    msgs: Vec<Message>,
237    summary: Option<String>,
238    threshold: usize,
239    client: Client,
240    prompt: String,
241}
242
243impl SummaryMemory {
244    /// Build with the LLM client used to summarize, and the message count
245    /// at which compression kicks in.
246    pub fn new(client: Client, threshold: usize) -> Self {
247        Self {
248            system_pinned: None,
249            msgs: Vec::new(),
250            summary: None,
251            threshold,
252            client,
253            prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
254        }
255    }
256
257    /// Pin a system message at the head of the seed.
258    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
259        self.system_pinned = Some(Message::system(prompt));
260        self
261    }
262
263    /// Override the summarization prompt.
264    pub fn with_prompt(mut self, p: impl Into<String>) -> Self {
265        self.prompt = p.into();
266        self
267    }
268
269    /// Force compression now, regardless of threshold. Useful for tests
270    /// and explicit "compact" calls.
271    pub async fn compact(&mut self) -> cognis_core::Result<()> {
272        if self.msgs.len() < 2 {
273            return Ok(());
274        }
275        let half = self.msgs.len() / 2;
276        let to_summarize: Vec<Message> = self.msgs.drain(..half).collect();
277        let transcript = to_summarize
278            .iter()
279            .map(|m| format!("[{}] {}", role_label(m), m.content()))
280            .collect::<Vec<_>>()
281            .join("\n");
282        let request = format!("{}\n\nConversation:\n{transcript}", self.prompt);
283        let resp = self
284            .client
285            .chat(vec![Message::human(request)], ChatOptions::default())
286            .await?;
287        let new = resp.message.content().to_string();
288        self.summary = Some(match self.summary.take() {
289            Some(prev) => format!("{prev}\n\n{new}"),
290            None => new,
291        });
292        Ok(())
293    }
294}
295
/// Instruction prepended to the transcript when asking the LLM for a summary.
const DEFAULT_SUMMARY_PROMPT: &str = concat!(
    "Summarize the following conversation in a few sentences. ",
    "Preserve key facts, decisions, names, and unfinished work. ",
    "Output the summary only."
);
299
300fn role_label(m: &Message) -> &'static str {
301    match m {
302        Message::Human(_) => "user",
303        Message::Ai(_) => "assistant",
304        Message::System(_) => "system",
305        Message::Tool(_) => "tool",
306    }
307}
308
309impl Memory for SummaryMemory {
310    fn read(&self) -> &[Message] {
311        &self.msgs
312    }
313    fn write(&mut self, msg: Message) {
314        self.msgs.push(msg);
315        // If we're past the threshold, schedule compression on the next
316        // `seed()` (which is async-friendly via the agent's run loop).
317        // We just mark the threshold here — actual compaction happens via
318        // explicit `compact()` calls.
319    }
320    fn clear(&mut self) {
321        self.msgs.clear();
322        self.summary = None;
323    }
324    fn seed(&self) -> Vec<Message> {
325        let mut out = Vec::with_capacity(self.msgs.len() + 2);
326        if let Some(s) = &self.system_pinned {
327            out.push(s.clone());
328        }
329        if let Some(summary) = &self.summary {
330            out.push(Message::system(format!(
331                "Earlier conversation summary:\n{summary}"
332            )));
333        }
334        out.extend(self.msgs.iter().cloned());
335        out
336    }
337}
338
339impl SummaryMemory {
340    /// True when the buffer has grown past the configured threshold and a
341    /// `compact()` call would do work.
342    pub fn needs_compact(&self) -> bool {
343        self.msgs.len() >= self.threshold
344    }
345}
346
347// ---------------------------------------------------------------------------
348// SummaryBufferMemory — token-budgeted buffer with summarized overflow.
349// ---------------------------------------------------------------------------
350
351/// Hybrid memory: keeps the most recent messages whole, but compresses
352/// older ones into a running LLM-generated summary so the total seed
353/// stays under a token budget.
354///
355/// On every `compact()` call (or every `seed()` after a `compact()`),
356/// the oldest messages whose cumulative token cost would push the
357/// transcript over `max_tokens` are summarized into the running summary
358/// and dropped from the message list.
359pub struct SummaryBufferMemory {
360    system_pinned: Option<Message>,
361    msgs: Vec<Message>,
362    summary: Option<String>,
363    max_tokens: usize,
364    tokenizer: Arc<dyn Tokenizer>,
365    client: Client,
366    prompt: String,
367}
368
369impl SummaryBufferMemory {
370    /// Build with a token budget and the LLM client used to summarize
371    /// overflow.
372    pub fn new(client: Client, max_tokens: usize) -> Self {
373        Self {
374            system_pinned: None,
375            msgs: Vec::new(),
376            summary: None,
377            max_tokens,
378            tokenizer: Arc::new(CharTokenizer),
379            client,
380            prompt: DEFAULT_SUMMARY_PROMPT.to_string(),
381        }
382    }
383
384    /// Override the tokenizer.
385    pub fn with_tokenizer(mut self, t: Arc<dyn Tokenizer>) -> Self {
386        self.tokenizer = t;
387        self
388    }
389
390    /// Override the summarization prompt.
391    pub fn with_prompt(mut self, p: impl Into<String>) -> Self {
392        self.prompt = p.into();
393        self
394    }
395
396    /// Pin a system message at the head.
397    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
398        self.system_pinned = Some(Message::system(prompt));
399        self
400    }
401
402    /// Total token cost of the current seed (system + summary + msgs).
403    fn current_cost(&self) -> usize {
404        let mut total = 0;
405        if let Some(s) = &self.system_pinned {
406            total += self.tokenizer.count(s.content());
407        }
408        if let Some(s) = &self.summary {
409            total += self.tokenizer.count(s);
410        }
411        for m in &self.msgs {
412            total += self.tokenizer.count(m.content());
413        }
414        total
415    }
416
417    /// Force compression now: summarize the oldest messages until
418    /// `current_cost <= max_tokens`. Returns the number of messages
419    /// folded into the summary.
420    pub async fn compact(&mut self) -> cognis_core::Result<usize> {
421        if self.current_cost() <= self.max_tokens {
422            return Ok(0);
423        }
424        // Identify the oldest messages to summarize: take from the front
425        // until the remaining cost is within budget.
426        let mut to_summarize: Vec<Message> = Vec::new();
427        while self.current_cost_with(&self.msgs[to_summarize.len()..]) > self.max_tokens
428            && to_summarize.len() < self.msgs.len()
429        {
430            to_summarize.push(self.msgs[to_summarize.len()].clone());
431        }
432        if to_summarize.is_empty() {
433            return Ok(0);
434        }
435        let n = to_summarize.len();
436        let transcript = to_summarize
437            .iter()
438            .map(|m| format!("[{}] {}", role_label(m), m.content()))
439            .collect::<Vec<_>>()
440            .join("\n");
441        let request = format!("{}\n\nConversation:\n{transcript}", self.prompt);
442        let resp = self
443            .client
444            .chat(vec![Message::human(request)], ChatOptions::default())
445            .await?;
446        let new_summary = resp.message.content().to_string();
447        self.summary = Some(match self.summary.take() {
448            Some(prev) => format!("{prev}\n\n{new_summary}"),
449            None => new_summary,
450        });
451        // Drain compacted messages.
452        self.msgs.drain(..n);
453        Ok(n)
454    }
455
456    fn current_cost_with(&self, tail: &[Message]) -> usize {
457        let mut total = 0;
458        if let Some(s) = &self.system_pinned {
459            total += self.tokenizer.count(s.content());
460        }
461        if let Some(s) = &self.summary {
462            total += self.tokenizer.count(s);
463        }
464        for m in tail {
465            total += self.tokenizer.count(m.content());
466        }
467        total
468    }
469
470    /// True if a `compact()` would do work.
471    pub fn needs_compact(&self) -> bool {
472        self.current_cost() > self.max_tokens
473    }
474}
475
476impl Memory for SummaryBufferMemory {
477    fn read(&self) -> &[Message] {
478        &self.msgs
479    }
480    fn write(&mut self, msg: Message) {
481        self.msgs.push(msg);
482    }
483    fn clear(&mut self) {
484        self.msgs.clear();
485        self.summary = None;
486    }
487    fn seed(&self) -> Vec<Message> {
488        let mut out = Vec::with_capacity(self.msgs.len() + 2);
489        if let Some(s) = &self.system_pinned {
490            out.push(s.clone());
491        }
492        if let Some(summary) = &self.summary {
493            out.push(Message::system(format!(
494                "Earlier conversation summary:\n{summary}"
495            )));
496        }
497        out.extend(self.msgs.iter().cloned());
498        out
499    }
500}
501
502// ---------------------------------------------------------------------------
503// VectorMemory — semantic recall via a VectorStore.
504// ---------------------------------------------------------------------------
505
506/// Memory backed by a [`VectorStore`]. Every `write` adds the message text
507/// to the store (with role metadata). `seed()` returns the system pin
508/// only — agents wanting recall call [`VectorMemory::recall`] explicitly
509/// to pull in the top-k most relevant messages for the current query.
510pub struct VectorMemory {
511    system_pinned: Option<Message>,
512    store: Arc<RwLock<dyn VectorStore>>,
513    k: usize,
514}
515
516impl VectorMemory {
517    /// Wrap a vector store with default k=4.
518    pub fn new(store: Arc<RwLock<dyn VectorStore>>) -> Self {
519        Self {
520            system_pinned: None,
521            store,
522            k: 4,
523        }
524    }
525
526    /// Pin a system message at the head.
527    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
528        self.system_pinned = Some(Message::system(prompt));
529        self
530    }
531
532    /// Override how many memories to surface per recall.
533    pub fn with_k(mut self, k: usize) -> Self {
534        self.k = k;
535        self
536    }
537
538    /// Pull the top-k semantically-similar past messages for `query`.
539    pub async fn recall(&self, query: &str) -> cognis_core::Result<Vec<Message>> {
540        let hits = self
541            .store
542            .read()
543            .await
544            .similarity_search(query, self.k)
545            .await?;
546        Ok(hits
547            .into_iter()
548            .map(|h| {
549                let role = h
550                    .metadata
551                    .get("role")
552                    .and_then(|v| v.as_str())
553                    .unwrap_or("user");
554                match role {
555                    "assistant" => Message::ai(h.text),
556                    "system" => Message::system(h.text),
557                    _ => Message::human(h.text),
558                }
559            })
560            .collect())
561    }
562}
563
564impl Memory for VectorMemory {
565    fn read(&self) -> &[Message] {
566        // Vector memory has no on-disk message ordering — the store is
567        // keyed by similarity, not time. `read()` is nominal.
568        &[]
569    }
570    fn write(&mut self, msg: Message) {
571        // Synchronous interface — best-effort: spawn a task that persists.
572        let store = self.store.clone();
573        let m = msg.clone();
574        tokio::spawn(async move {
575            let mut meta = std::collections::HashMap::new();
576            meta.insert(
577                "role".into(),
578                serde_json::Value::String(role_label(&m).into()),
579            );
580            let _ = store
581                .write()
582                .await
583                .add_texts(vec![m.content().to_string()], Some(vec![meta]))
584                .await;
585        });
586    }
587    fn clear(&mut self) {
588        // Best-effort: spawn a delete-all. We don't expose `delete_all` on
589        // VectorStore yet, so this is a noop. Future: extend trait.
590    }
591    fn seed(&self) -> Vec<Message> {
592        let mut out = Vec::new();
593        if let Some(s) = &self.system_pinned {
594            out.push(s.clone());
595        }
596        out
597    }
598}
599
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn write_below_capacity() {
        let mut w = Window::new(5);
        w.write(Message::human("a"));
        w.write(Message::human("b"));
        assert_eq!(w.seed().len(), 2);
    }

    #[test]
    fn fifo_drop_above_capacity() {
        let mut w = Window::new(2);
        w.write(Message::human("1"));
        w.write(Message::human("2"));
        w.write(Message::human("3"));
        let seed = w.seed();
        assert_eq!(seed.len(), 2);
        assert_eq!(seed[0].content(), "2");
        assert_eq!(seed[1].content(), "3");
    }

    #[test]
    fn system_pinned_survives_clear() {
        let mut w = Window::new(5).with_system("you are helpful");
        w.write(Message::human("hi"));
        w.clear();
        let seed = w.seed();
        assert_eq!(seed.len(), 1);
        assert_eq!(seed[0].content(), "you are helpful");
    }

    #[test]
    fn system_pinned_at_index_0() {
        let mut w = Window::new(5).with_system("system!");
        w.write(Message::human("u1"));
        w.write(Message::human("u2"));
        let seed = w.seed();
        assert_eq!(seed.len(), 3);
        assert_eq!(seed[0].content(), "system!");
        assert_eq!(seed[1].content(), "u1");
        assert_eq!(seed[2].content(), "u2");
    }

    #[test]
    fn token_buffer_drops_oldest_until_under_budget() {
        // CharTokenizer counts chars. Budget 6 with 3-char messages.
        let mut m = TokenBufferMemory::new(6);
        m.write(Message::human("aaa"));
        m.write(Message::human("bbb"));
        m.write(Message::human("ccc"));
        let seed = m.seed();
        // Two messages fit (3 + 3 = 6); the third would push to 9 → dropped.
        assert_eq!(seed.len(), 2);
        // Oldest dropped → tail kept.
        assert_eq!(seed[0].content(), "bbb");
        assert_eq!(seed[1].content(), "ccc");
    }

    #[test]
    fn token_buffer_keeps_pinned_system() {
        let mut m = TokenBufferMemory::new(10).with_system("sys");
        m.write(Message::human("aaaa"));
        m.write(Message::human("bbbb"));
        let seed = m.seed();
        // System ("sys", 3 chars) is pinned; budget 10 leaves 7, which is room
        // for only one of the two 4-char messages (4 + 4 = 8 > 7).
        assert!(!seed.is_empty());
        assert_eq!(seed[0].content(), "sys");
    }

    #[test]
    fn token_buffer_with_strategy_last_drops_newest() {
        let mut m = TokenBufferMemory::new(6).with_strategy(TrimStrategy::Last);
        m.write(Message::human("aaa"));
        m.write(Message::human("bbb"));
        m.write(Message::human("ccc"));
        let seed = m.seed();
        assert_eq!(seed.len(), 2);
        // Newest dropped → head kept.
        assert_eq!(seed[0].content(), "aaa");
        assert_eq!(seed[1].content(), "bbb");
    }

    #[test]
    fn token_buffer_clear_removes_all() {
        let mut m = TokenBufferMemory::new(100);
        m.write(Message::human("a"));
        m.clear();
        assert!(m.seed().is_empty());
    }
}
693
694// ────────────────────────────────────────────────────────────────────────
695// EntityMemory — extracts entities + facts from messages, surfaces them
696// back into the seed as a system message.
697// ────────────────────────────────────────────────────────────────────────
698
/// An `(entity, fact)` pair produced by an extractor. The fact is free-form
/// text; the default extractor uses the sentence the entity appeared in.
pub type EntityFact = (String, String);

/// Shareable, thread-safe extractor that maps message text to the
/// `(entity, fact)` pairs it mentions.
pub type EntityExtractor = Arc<dyn for<'a> Fn(&'a str) -> Vec<EntityFact> + Send + Sync>;
705
706/// Buffers messages and maintains a per-entity fact ledger. Each `write`
707/// runs the extractor over the message content; the seed surfaces the
708/// ledger as a system-message preamble so the model can reference prior
709/// observations across turns.
710///
711/// Default extractor = capitalized-word heuristic: any token starting
712/// with an uppercase letter becomes an entity, paired with the sentence
713/// it appeared in. Plug in [`with_extractor`](EntityMemory::with_extractor)
714/// for an LLM-driven version.
715pub struct EntityMemory {
716    buf: Vec<Message>,
717    entities: std::collections::HashMap<String, Vec<String>>,
718    extractor: EntityExtractor,
719    system_pinned: Option<Message>,
720}
721
722impl EntityMemory {
723    /// Empty memory with the default capitalized-word extractor.
724    pub fn new() -> Self {
725        Self {
726            buf: Vec::new(),
727            entities: std::collections::HashMap::new(),
728            extractor: Arc::new(default_entity_extractor),
729            system_pinned: None,
730        }
731    }
732
733    /// Plug in a custom extractor (e.g. an LLM-backed NER).
734    pub fn with_extractor<F>(mut self, f: F) -> Self
735    where
736        F: Fn(&str) -> Vec<EntityFact> + Send + Sync + 'static,
737    {
738        self.extractor = Arc::new(f);
739        self
740    }
741
742    /// Pin a system prompt at the head of the seed.
743    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
744        self.system_pinned = Some(Message::system(prompt));
745        self
746    }
747
748    /// Inspect the current entity ledger.
749    pub fn entities(&self) -> &std::collections::HashMap<String, Vec<String>> {
750        &self.entities
751    }
752}
753
754impl Default for EntityMemory {
755    fn default() -> Self {
756        Self::new()
757    }
758}
759
760impl Memory for EntityMemory {
761    fn read(&self) -> &[Message] {
762        &self.buf
763    }
764    fn write(&mut self, msg: Message) {
765        for (entity, fact) in (self.extractor)(msg.content()) {
766            self.entities.entry(entity).or_default().push(fact);
767        }
768        self.buf.push(msg);
769    }
770    fn clear(&mut self) {
771        self.buf.clear();
772        self.entities.clear();
773    }
774    fn seed(&self) -> Vec<Message> {
775        let mut out = Vec::with_capacity(self.buf.len() + 2);
776        if let Some(s) = &self.system_pinned {
777            out.push(s.clone());
778        }
779        if !self.entities.is_empty() {
780            let mut keys: Vec<&String> = self.entities.keys().collect();
781            keys.sort();
782            let body = keys
783                .into_iter()
784                .map(|k| {
785                    let facts = self.entities.get(k).unwrap();
786                    let joined = facts.join("; ");
787                    format!("- {k}: {joined}")
788                })
789                .collect::<Vec<_>>()
790                .join("\n");
791            out.push(Message::system(format!("Known entities:\n{body}")));
792        }
793        out.extend(self.buf.iter().cloned());
794        out
795    }
796}
797
798fn default_entity_extractor(text: &str) -> Vec<EntityFact> {
799    // Common capitalized stopwords that lead sentences but aren't entities.
800    const STOPWORDS: &[&str] = &[
801        "The", "A", "An", "This", "That", "These", "Those", "It", "Its", "Their", "There", "Here",
802        "What", "Who", "Which", "When", "Where", "Why", "How", "And", "But", "Or", "If", "Then",
803    ];
804    let mut out = Vec::new();
805    for sentence in split_sentences(text) {
806        for tok in sentence.split_whitespace() {
807            // Strip surrounding punctuation (keeps internal apostrophes).
808            let trimmed: String = tok.trim_matches(|c: char| !c.is_alphanumeric()).to_string();
809            if trimmed.len() >= 2
810                && trimmed.chars().next().is_some_and(|c| c.is_uppercase())
811                && !STOPWORDS.contains(&trimmed.as_str())
812            {
813                out.push((trimmed, sentence.trim().to_string()));
814            }
815        }
816    }
817    out
818}
819
/// Split `text` after every `.`, `!` or `?`, keeping the terminator with its
/// sentence. Pieces are trimmed, and empty pieces are dropped. Trailing text
/// with no terminator becomes a final piece.
fn split_sentences(text: &str) -> Vec<&str> {
    text.split_inclusive(['.', '!', '?'])
        .map(str::trim)
        .filter(|piece| !piece.is_empty())
        .collect()
}
839
840// ────────────────────────────────────────────────────────────────────────
841// KnowledgeGraphMemory — buffers messages and extracts (S, P, O) triples,
842// surfaces them as a system-message KB.
843// ────────────────────────────────────────────────────────────────────────
844
/// A `(subject, predicate, object)` fact.
pub type Triple = (String, String, String);

/// Shareable, thread-safe extractor that maps message text to the triples
/// it states.
pub type TripleExtractor = Arc<dyn for<'a> Fn(&'a str) -> Vec<Triple> + Send + Sync>;
850
851/// Buffers messages and a triple store. Each `write` extracts
852/// triples; the seed prefixes a `Knowledge:` system message listing
853/// every triple. Plug in an LLM extractor for production use; the
854/// default handles "X is Y" / "X has Y" / "X are Y" patterns.
855pub struct KnowledgeGraphMemory {
856    buf: Vec<Message>,
857    triples: Vec<Triple>,
858    extractor: TripleExtractor,
859    system_pinned: Option<Message>,
860}
861
862impl KnowledgeGraphMemory {
863    /// Empty memory with the default regex extractor.
864    pub fn new() -> Self {
865        Self {
866            buf: Vec::new(),
867            triples: Vec::new(),
868            extractor: Arc::new(default_triple_extractor),
869            system_pinned: None,
870        }
871    }
872
873    /// Plug in a custom extractor.
874    pub fn with_extractor<F>(mut self, f: F) -> Self
875    where
876        F: Fn(&str) -> Vec<Triple> + Send + Sync + 'static,
877    {
878        self.extractor = Arc::new(f);
879        self
880    }
881
882    /// Pin a system prompt at the head of the seed.
883    pub fn with_system(mut self, prompt: impl Into<String>) -> Self {
884        self.system_pinned = Some(Message::system(prompt));
885        self
886    }
887
888    /// Inspect the triple store.
889    pub fn triples(&self) -> &[Triple] {
890        &self.triples
891    }
892}
893
894impl Default for KnowledgeGraphMemory {
895    fn default() -> Self {
896        Self::new()
897    }
898}
899
900impl Memory for KnowledgeGraphMemory {
901    fn read(&self) -> &[Message] {
902        &self.buf
903    }
904    fn write(&mut self, msg: Message) {
905        for t in (self.extractor)(msg.content()) {
906            // Dedupe — the same fact restated stays one triple.
907            if !self.triples.contains(&t) {
908                self.triples.push(t);
909            }
910        }
911        self.buf.push(msg);
912    }
913    fn clear(&mut self) {
914        self.buf.clear();
915        self.triples.clear();
916    }
917    fn seed(&self) -> Vec<Message> {
918        let mut out = Vec::with_capacity(self.buf.len() + 2);
919        if let Some(s) = &self.system_pinned {
920            out.push(s.clone());
921        }
922        if !self.triples.is_empty() {
923            let body = self
924                .triples
925                .iter()
926                .map(|(s, p, o)| format!("- ({s}, {p}, {o})"))
927                .collect::<Vec<_>>()
928                .join("\n");
929            out.push(Message::system(format!("Knowledge:\n{body}")));
930        }
931        out.extend(self.buf.iter().cloned());
932        out
933    }
934}
935
936fn default_triple_extractor(text: &str) -> Vec<Triple> {
937    let mut out = Vec::new();
938    for sentence in split_sentences(text) {
939        // Find linking verbs.
940        for predicate in [" is ", " are ", " has ", " have ", " was ", " were "] {
941            if let Some(idx) = sentence.find(predicate) {
942                let s = sentence[..idx].trim();
943                let o_raw = sentence[idx + predicate.len()..]
944                    .trim_end_matches(['.', '!', '?'])
945                    .trim();
946                if !s.is_empty() && !o_raw.is_empty() {
947                    out.push((
948                        s.to_string(),
949                        predicate.trim().to_string(),
950                        o_raw.to_string(),
951                    ));
952                    break; // one triple per sentence
953                }
954            }
955        }
956    }
957    out
958}
959
#[cfg(test)]
mod tests_entity_kg {
    use super::*;

    #[test]
    fn entity_memory_extracts_default() {
        let mut m = EntityMemory::new();
        m.write(Message::human("Ada writes Rust. Bob reviews Ada's PRs."));
        let ents = m.entities();
        assert!(
            ents.contains_key("Ada"),
            "got: {:?}",
            ents.keys().collect::<Vec<_>>()
        );
        assert!(ents.contains_key("Rust"));
        assert!(ents.contains_key("Bob"));
    }

    #[test]
    fn entity_memory_seed_includes_summary() {
        let mut m = EntityMemory::new();
        m.write(Message::human("Cognis is fast."));
        let seed = m.seed();
        // Expect: [system "Known entities:..."] + the original human message.
        assert_eq!(seed.len(), 2);
        assert!(matches!(seed[0], Message::System(_)));
        assert!(seed[0].content().contains("Cognis"));
    }

    #[test]
    fn entity_memory_with_custom_extractor() {
        let mut m = EntityMemory::new()
            .with_extractor(|_text: &str| vec![("forced".into(), "via custom extractor".into())]);
        m.write(Message::human("ignored"));
        assert!(m.entities().contains_key("forced"));
    }

    #[test]
    fn entity_memory_clear_drops_everything() {
        let mut m = EntityMemory::new();
        m.write(Message::human("Rust ships."));
        m.clear();
        assert!(m.entities().is_empty());
        assert!(m.read().is_empty());
    }

    #[test]
    fn kg_memory_extracts_is_pattern() {
        let mut m = KnowledgeGraphMemory::new();
        m.write(Message::human("Cognis is a Rust framework. Tokio is async."));
        let ts = m.triples();
        assert!(ts.contains(&("Cognis".into(), "is".into(), "a Rust framework".into())));
        assert!(ts.contains(&("Tokio".into(), "is".into(), "async".into())));
    }

    #[test]
    fn kg_memory_extracts_has_pattern() {
        let mut m = KnowledgeGraphMemory::new();
        m.write(Message::human("Ada has a cat."));
        assert_eq!(m.triples(), &[("Ada".into(), "has".into(), "a cat".into())]);
    }

    #[test]
    fn kg_memory_dedupes_repeated_triples() {
        let mut m = KnowledgeGraphMemory::new();
        m.write(Message::human("Rust is fast."));
        m.write(Message::human("Rust is fast."));
        assert_eq!(m.triples().len(), 1);
    }

    #[test]
    fn kg_memory_seed_includes_kb() {
        let mut m = KnowledgeGraphMemory::new();
        m.write(Message::human("Cognis is fast."));
        let seed = m.seed();
        assert_eq!(seed.len(), 2);
        assert!(matches!(seed[0], Message::System(_)));
        assert!(seed[0].content().contains("(Cognis, is, fast)"));
    }

    #[test]
    fn kg_memory_seed_omits_kb_when_nothing_extracted() {
        let mut m = KnowledgeGraphMemory::new();
        m.write(Message::human("hello there"));
        // No linking verb → no triples → only the human message is seeded.
        assert!(m.triples().is_empty());
        assert_eq!(m.seed().len(), 1);
    }

    #[test]
    fn kg_memory_seed_puts_pinned_system_before_kb() {
        let mut m = KnowledgeGraphMemory::new().with_system("be terse");
        m.write(Message::human("Rust is fast."));
        let seed = m.seed();
        assert_eq!(seed.len(), 3);
        assert_eq!(seed[0].content(), "be terse");
        assert!(seed[1].content().starts_with("Knowledge:\n"));
        assert_eq!(seed[2].content(), "Rust is fast.");
    }

    #[test]
    fn kg_memory_clear_keeps_pinned_system() {
        let mut m = KnowledgeGraphMemory::new().with_system("be terse");
        m.write(Message::human("Rust is fast."));
        m.clear();
        assert!(m.triples().is_empty());
        assert!(m.read().is_empty());
        let seed = m.seed();
        assert_eq!(seed.len(), 1);
        assert_eq!(seed[0].content(), "be terse");
    }

    #[test]
    fn kg_memory_with_custom_extractor() {
        let mut m = KnowledgeGraphMemory::new()
            .with_extractor(|_text: &str| vec![("X".into(), "rel".into(), "Y".into())]);
        m.write(Message::human("ignored"));
        assert_eq!(m.triples(), &[("X".into(), "rel".into(), "Y".into())]);
    }
}
1043
1044// ────────────────────────────────────────────────────────────────────────
1045// HybridMemory — combine N member memories. Writes fan out to every
1046// member; seed() concatenates each member's contribution in registration
1047// order. Use to compose specialized memories (e.g. recent buffer +
1048// long-term summary + entity ledger + semantic vector recall).
1049// ────────────────────────────────────────────────────────────────────────
1050
/// A memory composed of several member memories. Each `write` is
/// broadcast to every member; `seed` concatenates each member's
/// contribution in registration order.
///
/// Use to compose specialized memories — e.g. a `Window` for recent
/// turns plus a `SummaryMemory` for older context plus an `EntityMemory`
/// to surface known entities. Each member can do its own thing on write
/// (the Window will trim, the SummaryMemory will compact, etc.); the
/// agent sees a unified seed.
///
/// The seed is not deduplicated. A message retained by several members
/// appears once per member. System messages contributed by later members
/// come after the earlier members' messages; they are not moved to the front.
pub struct HybridMemory {
    // Member memories, in registration (and therefore seed) order.
    members: Vec<Box<dyn Memory>>,
    /// Tracks the raw write history so `read()` can return a `&[Message]`
    /// without materializing across members. Members own their own
    /// (possibly-transformed) buffers. This log grows with every write,
    /// however the members trim, and is only emptied by `clear()`.
    buf: Vec<Message>,
}
1067
1068impl Default for HybridMemory {
1069    fn default() -> Self {
1070        Self::new()
1071    }
1072}
1073
1074impl HybridMemory {
1075    /// Empty hybrid with no members. Add members via [`HybridMemory::with`].
1076    pub fn new() -> Self {
1077        Self {
1078            members: Vec::new(),
1079            buf: Vec::new(),
1080        }
1081    }
1082
1083    /// Append a member memory. Builder-style.
1084    pub fn with(mut self, member: impl Memory + 'static) -> Self {
1085        self.members.push(Box::new(member));
1086        self
1087    }
1088
1089    /// Number of members.
1090    pub fn member_count(&self) -> usize {
1091        self.members.len()
1092    }
1093}
1094
1095impl Memory for HybridMemory {
1096    fn read(&self) -> &[Message] {
1097        &self.buf
1098    }
1099    fn write(&mut self, msg: Message) {
1100        for m in &mut self.members {
1101            m.write(msg.clone());
1102        }
1103        self.buf.push(msg);
1104    }
1105    fn clear(&mut self) {
1106        for m in &mut self.members {
1107            m.clear();
1108        }
1109        self.buf.clear();
1110    }
1111    fn seed(&self) -> Vec<Message> {
1112        let mut out: Vec<Message> = Vec::new();
1113        for m in &self.members {
1114            out.extend(m.seed());
1115        }
1116        out
1117    }
1118}
1119
#[cfg(test)]
mod tests_hybrid {
    use super::*;

    #[test]
    fn write_fans_out_to_every_member() {
        let mut h = HybridMemory::new()
            .with(Buffer::new())
            .with(Window::new(10));
        h.write(Message::human("a"));
        h.write(Message::human("b"));
        assert_eq!(h.read().len(), 2);
        // Both members should have seen both writes.
        let seed = h.seed();
        // Buffer contributes 2 + Window contributes 2 = 4 (no dedup).
        assert_eq!(seed.len(), 4);
    }

    #[test]
    fn members_transform_writes_independently() {
        let mut h = HybridMemory::new()
            .with(Buffer::new())
            .with(Window::new(1));
        h.write(Message::human("a"));
        h.write(Message::human("b"));
        // The canonical log keeps everything...
        assert_eq!(h.read().len(), 2);
        // ...the Buffer keeps both, the one-slot Window only the newest.
        let contents: Vec<String> = h.seed().iter().map(|m| m.content().to_string()).collect();
        assert_eq!(contents, ["a", "b", "b"]);
    }

    #[test]
    fn late_member_does_not_replay_history() {
        let mut h = HybridMemory::new().with(Buffer::new());
        h.write(Message::human("a"));
        let mut h = h.with(Buffer::new());
        h.write(Message::human("b"));
        assert_eq!(h.member_count(), 2);
        // First Buffer saw a and b; the late one only saw b.
        let contents: Vec<String> = h.seed().iter().map(|m| m.content().to_string()).collect();
        assert_eq!(contents, ["a", "b", "b"]);
    }

    #[test]
    fn clear_empties_every_member() {
        let mut h = HybridMemory::new()
            .with(Buffer::new())
            .with(Window::new(10));
        h.write(Message::human("a"));
        h.clear();
        assert!(h.read().is_empty());
        assert!(h.seed().is_empty());
    }

    #[test]
    fn seed_concatenates_in_member_order() {
        let mut h = HybridMemory::new()
            .with(Buffer::new().with_system("recent context"))
            .with(EntityMemory::new());
        h.write(Message::human("Cognis is fast."));
        let seed = h.seed();
        // Buffer: system pin + 1 human msg → 2
        // EntityMemory: synthesized "Known entities" system + 1 human → 2
        assert_eq!(seed.len(), 4);
        // First member's contribution comes first.
        assert!(matches!(seed[0], Message::System(_)));
        assert_eq!(seed[0].content(), "recent context");
    }

    #[test]
    fn member_count_tracks_registrations() {
        let h = HybridMemory::new().with(Buffer::new()).with(Window::new(3));
        assert_eq!(h.member_count(), 2);
    }

    #[test]
    fn empty_hybrid_round_trips() {
        let mut h = HybridMemory::new();
        h.write(Message::human("a"));
        // No members → seed is empty (only members contribute).
        assert!(h.seed().is_empty());
        // But read() reflects the canonical write-buffer.
        assert_eq!(h.read().len(), 1);
        assert_eq!(h.member_count(), 0);
    }
}