Skip to main content

kiromi_ai_memory/
context.rs

1// SPDX-License-Identifier: Apache-2.0 OR MIT
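//! Plan 12 phase I — `Memory::build_context(focus, opts)`.
//!
//! Walks the partition tree top-down from the tenant root to the focus
//! node, fetching the latest summary at each level whose level-index is
//! in `opts.include_summaries_at`, adds the tenant memo and the relevant
//! memory bodies, then returns a token-budget-bounded list of
//! [`ContextBlock`]s ready to feed into a prompt.
//!
//! Plan 15 adds optional U-curve ordering ([`ContextOrdering`]), a
//! snapshot-pinned `MemoryView::build_context`, and
//! `Memory::build_context_diff`.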
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::Result;
12use crate::graph::NodeRef;
13use crate::handle::{Memory, MemoryView};
14use crate::summarizer::SummaryStyle;
15use crate::summary::SummarySubject;
16
17/// Kind tag carried by each [`ContextBlock`].
18#[non_exhaustive]
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum ContextKind {
22    /// The whole-tenant memo.
23    TenantMemo,
24    /// Per-partition rollup at any depth.
25    PartitionSummary,
26    /// A raw memory body.
27    Memory,
28    /// A linked memory (one hop from the focus).
29    LinkedMemory,
30}
31
32/// One block in an assembled context.
33#[non_exhaustive]
34#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
35pub struct ContextBlock {
36    /// Block kind.
37    pub kind: ContextKind,
38    /// Anchor node — partition for summaries, memory for body / linked.
39    pub anchor: NodeRef,
40    /// Rendered prose.
41    pub text: String,
42    /// Estimated tokens (4 bytes/token heuristic).
43    pub tokens_estimated: u32,
44}
45
46/// Plan 15: how `build_context` orders the budget-filled block list.
47#[non_exhaustive]
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
49#[serde(rename_all = "snake_case")]
50pub enum ContextOrdering {
51    /// Highest importance first, descending. Slice 1 / Plan 12 default.
52    #[default]
53    TopDown,
54    /// U-curve: highest importance at start AND end, less important in
55    /// the middle. Optimises for LLM attention curves under long
56    /// contexts (per the 2024 "Lost in the Middle" line of work).
57    UCurve,
58}
59
60/// Caller-tunable knobs on [`Memory::build_context`].
61#[non_exhaustive]
62#[derive(Debug, Clone, Serialize, Deserialize)]
63#[serde(default)]
64pub struct ContextOpts {
65    /// Hard token budget for the assembled context. Default 4000.
66    pub budget_tokens: u32,
67    /// Include the tenant memo block when present. Default `true`.
68    pub include_tenant_memo: bool,
69    /// Partition-level indices whose summary to include. Default `[0, 1]`.
70    pub include_summaries_at: Vec<u32>,
71    /// Top-K memories to include (when focus is a partition). Default 5.
72    pub include_memories_top_k: u32,
73    /// Style preset for partition summaries. Default `Compact`.
74    pub style: SummaryStyle,
75    /// Plan 15: ordering applied to the post-budget block list. Default
76    /// [`ContextOrdering::TopDown`] preserves the slice 1 behaviour.
77    pub ordering: ContextOrdering,
78}
79
80impl Default for ContextOpts {
81    fn default() -> Self {
82        Self {
83            budget_tokens: 4_000,
84            include_tenant_memo: true,
85            include_summaries_at: vec![0, 1],
86            include_memories_top_k: 5,
87            style: SummaryStyle::Compact,
88            ordering: ContextOrdering::TopDown,
89        }
90    }
91}
92
93impl ContextOpts {
94    /// Set the budget.
95    #[must_use]
96    pub fn with_budget(mut self, n: u32) -> Self {
97        self.budget_tokens = n;
98        self
99    }
100    /// Set include-tenant-memo.
101    #[must_use]
102    pub fn with_include_tenant_memo(mut self, v: bool) -> Self {
103        self.include_tenant_memo = v;
104        self
105    }
106    /// Set top-K memories.
107    #[must_use]
108    pub fn with_top_k(mut self, n: u32) -> Self {
109        self.include_memories_top_k = n;
110        self
111    }
112    /// Set the summary style.
113    #[must_use]
114    pub fn with_style(mut self, s: SummaryStyle) -> Self {
115        self.style = s;
116        self
117    }
118    /// Plan 15: pick the ordering applied to the post-budget block list.
119    #[must_use]
120    pub fn with_ordering(mut self, o: ContextOrdering) -> Self {
121        self.ordering = o;
122        self
123    }
124}
125
/// Plan 15: re-order an importance-sorted block list (highest first) for
/// [`ContextOrdering::UCurve`].
///
/// The list is split at `ceil(len / 2)`, so the front part keeps the
/// extra element when the length is odd. The output then alternates one
/// element from the start of the front part with one from the end of the
/// back part. For example, `[A, B, C, D, E, F, G]` becomes
/// `[A, G, B, F, C, E, D]`. Lists of length 0–2 come back unchanged, and
/// the first element always stays first.
///
/// NOTE(review): this interleave puts the lowest-ranked element second
/// and ends on a mid-ranked one, rather than parking low-ranked elements
/// in the middle. Confirm that is the intended U shape.
fn u_curve_reorder<T>(input: Vec<T>) -> Vec<T> {
    let len = input.len();
    let mut front = input;
    // `front` keeps indices [0, ceil(len/2)); `back` receives the rest.
    let mut back = front.split_off(len.div_ceil(2));
    let mut out = Vec::with_capacity(len);
    for head in front {
        out.push(head);
        // `back` is never longer than `front`, so popping from its end
        // pairs each front element with the next-lowest remaining one.
        if let Some(tail) = back.pop() {
            out.push(tail);
        }
    }
    out
}
162
/// Estimate the token cost of `text` with a ~4 bytes-per-token heuristic.
///
/// Rounds up, so any non-empty text costs at least one token. This keeps
/// the hard `budget_tokens` cap from being under-charged by truncation
/// when many short blocks are summed. The result saturates at
/// `u32::MAX`. Callers that need exact ChatML counts should re-tokenize
/// the block text themselves.
fn estimate_tokens(text: &str) -> u32 {
    u32::try_from(text.len().div_ceil(4)).unwrap_or(u32::MAX)
}
168
169impl Memory {
170    /// Plan 12 — assemble a token-budget-bounded list of context
171    /// blocks rooted at `focus`. Walks the tenant tree top-down: tenant
172    /// memo → top-level summary → ... → focus → (optionally) linked
173    /// memories or top-K memories under the partition.
174    ///
175    /// ```no_run
176    /// # async fn _ex(mem: kiromi_ai_memory::Memory, r: kiromi_ai_memory::MemoryRef) -> kiromi_ai_memory::Result<()> {
177    /// use kiromi_ai_memory::{ContextOpts, graph::NodeRef};
178    /// let blocks = mem.build_context(NodeRef::Memory(r), ContextOpts::default()).await?;
179    /// # let _ = blocks; Ok(()) }
180    /// ```
181    pub async fn build_context(
182        &self,
183        focus: NodeRef,
184        opts: ContextOpts,
185    ) -> Result<Vec<ContextBlock>> {
186        let mut blocks: Vec<ContextBlock> = Vec::new();
187
188        // 1. Tenant memo.
189        if opts.include_tenant_memo
190            && let Some(memo) = self.tenant_memo().await?
191        {
192            let tokens = estimate_tokens(&memo);
193            blocks.push(ContextBlock {
194                kind: ContextKind::TenantMemo,
195                anchor: NodeRef::Partition(crate::partition::tenant_root_path()),
196                text: memo,
197                tokens_estimated: tokens,
198            });
199        }
200
201        // 2. Walk from root → focus partition, fetching partition
202        //    summaries at each requested level.
203        let focus_partition = match &focus {
204            NodeRef::Memory(r) => Some(r.partition.clone()),
205            NodeRef::Partition(p) => Some(p.clone()),
206            NodeRef::Summary(s) => s.subject.partition_path().cloned(),
207        };
208        if let Some(p) = focus_partition.as_ref() {
209            let mut chain: Vec<crate::partition::PartitionPath> = p.ancestors().collect();
210            chain.reverse();
211            chain.push(p.clone());
212            for path in chain {
213                let level = u32::try_from(path.depth().saturating_sub(1)).unwrap_or(0);
214                if !opts.include_summaries_at.contains(&level) {
215                    continue;
216                }
217                if let Some(rec) = self
218                    .latest_summary(&SummarySubject::Partition(path.clone()), &opts.style)
219                    .await?
220                {
221                    let text = rec.content.prose.clone();
222                    let tokens = estimate_tokens(&text);
223                    blocks.push(ContextBlock {
224                        kind: ContextKind::PartitionSummary,
225                        anchor: NodeRef::Partition(path.clone()),
226                        text,
227                        tokens_estimated: tokens,
228                    });
229                }
230            }
231        }
232
233        // 3. If focus is a memory, include its body + linked memories.
234        if let NodeRef::Memory(r) = &focus {
235            if let Ok(record) = self.get(r).await {
236                let text = record.content.as_str().to_string();
237                let tokens = estimate_tokens(&text);
238                blocks.push(ContextBlock {
239                    kind: ContextKind::Memory,
240                    anchor: NodeRef::Memory(r.clone()),
241                    text,
242                    tokens_estimated: tokens,
243                });
244            }
245            let links = self.links_of(r).await?;
246            for l in links {
247                let dst_ref = crate::memory::MemoryRef {
248                    id: l.dst,
249                    partition: r.partition.clone(),
250                };
251                if let Ok(rec) = self.get(&dst_ref).await {
252                    let text = rec.content.as_str().to_string();
253                    let tokens = estimate_tokens(&text);
254                    blocks.push(ContextBlock {
255                        kind: ContextKind::LinkedMemory,
256                        anchor: NodeRef::Memory(rec.r#ref.clone()),
257                        text,
258                        tokens_estimated: tokens,
259                    });
260                }
261            }
262        }
263
264        // 4. If focus is a partition, list top-K live memories.
265        if let NodeRef::Partition(p) = &focus {
266            let limit = opts.include_memories_top_k;
267            if limit > 0 {
268                let part = crate::partition::Partitions::from_path(p);
269                let page = self
270                    .list(
271                        part,
272                        crate::opts::ListOpts {
273                            limit,
274                            ..Default::default()
275                        },
276                    )
277                    .await?;
278                for mref in page.items {
279                    if let Ok(rec) = self.get(&mref).await {
280                        let text = rec.content.as_str().to_string();
281                        let tokens = estimate_tokens(&text);
282                        blocks.push(ContextBlock {
283                            kind: ContextKind::Memory,
284                            anchor: NodeRef::Memory(mref),
285                            text,
286                            tokens_estimated: tokens,
287                        });
288                    }
289                }
290            }
291        }
292
293        // 5. Greedy-fill: keep walk order, drop tail blocks past budget.
294        let mut accum: u32 = 0;
295        let mut out = Vec::with_capacity(blocks.len());
296        for b in blocks {
297            let next = accum.saturating_add(b.tokens_estimated);
298            if next > opts.budget_tokens {
299                break;
300            }
301            accum = next;
302            out.push(b);
303        }
304        // 6. Plan 15: optional U-curve reorder for long-context attention.
305        if matches!(opts.ordering, ContextOrdering::UCurve) {
306            out = u_curve_reorder(out);
307        }
308        Ok(out)
309    }
310}
311
312/// Plan 15: result of [`Memory::build_context_diff`].
313///
314/// Splits the post-budget block list into the parts that changed
315/// (`added`), dropped (`removed`), and survived (`kept`) since the
316/// supplied snapshot. Chat-style turn loops can re-emit only `added`
317/// to the LLM; `kept` carries anchors with no body so callers know
318/// which blocks the prior turn already saw.
319#[non_exhaustive]
320#[derive(Debug, Clone, Default, Serialize, Deserialize)]
321pub struct ContextDiff {
322    /// Blocks present now but absent at the snapshot.
323    pub added: Vec<ContextBlock>,
324    /// Anchors of blocks present at the snapshot but absent now.
325    pub removed: Vec<NodeRef>,
326    /// Anchors of blocks present in both — body intentionally omitted
327    /// so callers don't double-send.
328    pub kept: Vec<NodeRef>,
329    /// Cheap upper bound on the additional tokens the caller now needs
330    /// to budget for, summed over `added`.
331    pub tokens_estimated_added: u32,
332}
333
334impl MemoryView {
335    /// Plan 15: snapshot-pinned [`Memory::build_context`].
336    ///
337    /// Walks the same tenant-memo / partition-summary / focus chain as
338    /// the engine path but filters every memory + summary anchor through
339    /// the snapshot manifest, so only blocks that were live at snapshot
340    /// time appear. Greedy-fill + ordering match
341    /// [`Memory::build_context`].
342    pub async fn build_context(
343        &self,
344        focus: NodeRef,
345        opts: ContextOpts,
346    ) -> Result<Vec<ContextBlock>> {
347        // Re-run the engine path against a fresh handle then filter
348        // through the manifest.
349        let mem = Memory {
350            inner: std::sync::Arc::clone(&self.inner),
351        };
352        let raw = mem.build_context(focus, opts.clone()).await?;
353        let kept: Vec<ContextBlock> = raw
354            .into_iter()
355            .filter(|b| match &b.anchor {
356                NodeRef::Memory(r) => self.manifest.memory_ids.binary_search(&r.id).is_ok(),
357                NodeRef::Summary(s) => self.manifest.summary_ids.binary_search(&s.id).is_ok(),
358                // Partition + tenant-memo anchors aren't tied to live
359                // ids; carry them through unfiltered.
360                NodeRef::Partition(_) => true,
361            })
362            .collect();
363        // Apply U-curve once more so post-filter ordering matches the
364        // ordering the engine would produce against the same
365        // (now-filtered) input set.
366        if matches!(opts.ordering, ContextOrdering::UCurve) {
367            Ok(u_curve_reorder(kept))
368        } else {
369            Ok(kept)
370        }
371    }
372}
373
374impl Memory {
375    /// Plan 15: assemble a delta of [`ContextBlock`]s relative to a
376    /// prior snapshot.
377    ///
378    /// Builds the current context, builds the snapshot's view of the
379    /// same context, then partitions block anchors into:
380    /// - `added`: present now, absent at snapshot.
381    /// - `removed`: present at snapshot, absent now.
382    /// - `kept`: present in both.
383    ///
384    /// Chat-style turn loops can replay only `added` to the LLM and
385    /// rely on `kept` to remember which blocks the prior turn already
386    /// saw.
387    ///
388    /// ```no_run
389    /// # async fn _ex(mem: kiromi_ai_memory::Memory, s: kiromi_ai_memory::SnapshotRef) -> kiromi_ai_memory::Result<()> {
390    /// use kiromi_ai_memory::{ContextOpts, graph::NodeRef, PartitionPath};
391    /// let path: PartitionPath = "user=alex/topic=meetings".parse().unwrap();
392    /// let diff = mem.build_context_diff(NodeRef::Partition(path), &s, ContextOpts::default()).await?;
393    /// # let _ = diff; Ok(()) }
394    /// ```
395    pub async fn build_context_diff(
396        &self,
397        focus: NodeRef,
398        since: &crate::snapshot::SnapshotRef,
399        opts: ContextOpts,
400    ) -> Result<ContextDiff> {
401        let prior_view = self.at(since).await?;
402        let prior_blocks = prior_view
403            .build_context(focus.clone(), opts.clone())
404            .await?;
405        let now_blocks = self.build_context(focus, opts).await?;
406
407        let prior_anchors: std::collections::HashSet<NodeRef> =
408            prior_blocks.iter().map(|b| b.anchor.clone()).collect();
409        let now_anchors: std::collections::HashSet<NodeRef> =
410            now_blocks.iter().map(|b| b.anchor.clone()).collect();
411
412        let mut added: Vec<ContextBlock> = Vec::new();
413        let mut kept: Vec<NodeRef> = Vec::new();
414        for b in now_blocks {
415            if prior_anchors.contains(&b.anchor) {
416                kept.push(b.anchor);
417            } else {
418                added.push(b);
419            }
420        }
421        let removed: Vec<NodeRef> = prior_blocks
422            .into_iter()
423            .filter(|b| !now_anchors.contains(&b.anchor))
424            .map(|b| b.anchor)
425            .collect();
426        let tokens_estimated_added = added.iter().map(|b| b.tokens_estimated).sum();
427        Ok(ContextDiff {
428            added,
429            removed,
430            kept,
431            tokens_estimated_added,
432        })
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn u_curve_reorders_seven_alternating_front_back() {
        let v = vec!["A", "B", "C", "D", "E", "F", "G"];
        let got = u_curve_reorder(v);
        // Front half (ceiling on odd N) = [A,B,C,D]; back = [E,F,G].
        // Output alternates front-from-start with back-from-end:
        // A, G, B, F, C, E, then the leftover front element D.
        assert_eq!(got, vec!["A", "G", "B", "F", "C", "E", "D"]);
    }

    #[test]
    fn u_curve_handles_short_lists() {
        assert_eq!(u_curve_reorder::<i32>(vec![]), Vec::<i32>::new());
        assert_eq!(u_curve_reorder(vec![1]), vec![1]);
        assert_eq!(u_curve_reorder(vec![1, 2]), vec![1, 2]);
    }

    #[test]
    fn u_curve_even_lists_preserve_endpoints() {
        let v = vec![1, 2, 3, 4, 5, 6];
        let got = u_curve_reorder(v);
        // The highest-importance element always stays first. The last slot
        // holds the first element of the back half (4 here), not the
        // lowest-importance one.
        assert_eq!(got[0], 1);
        assert_eq!(got, vec![1, 6, 2, 5, 3, 4]);
    }

    #[test]
    fn u_curve_small_odd_and_even_lists() {
        assert_eq!(u_curve_reorder(vec![1, 2, 3]), vec![1, 3, 2]);
        assert_eq!(u_curve_reorder(vec![1, 2, 3, 4]), vec![1, 4, 2, 3]);
        assert_eq!(u_curve_reorder(vec![1, 2, 3, 4, 5]), vec![1, 5, 2, 4, 3]);
    }

    #[test]
    fn u_curve_is_a_permutation() {
        for n in 0..20_usize {
            let v: Vec<usize> = (0..n).collect();
            let mut got = u_curve_reorder(v.clone());
            assert_eq!(got.len(), n);
            got.sort_unstable();
            assert_eq!(got, v);
        }
    }

    #[test]
    fn estimate_tokens_scales_with_byte_length() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcdefgh"), 2);
        assert_eq!(estimate_tokens(&"x".repeat(400)), 100);
    }
}