Skip to main content

kyma_memory/
sql.rs

1//! SQL builders for memory recall / listing. The queries run server-side
2//! (where `KymaTable` + the `cosine_distance` UDF live) against the columnar
3//! memory tables. Each query first dedups to the latest version per node id
4//! (`row_number() ... ORDER BY updated_at DESC`), matching the append-only
5//! "latest-wins" model.
6
7use crate::types::RecallFilter;
8
/// Turn `s` into a single-quoted SQL string literal.
///
/// The value is wrapped in `'...'` and every single quote inside it is
/// doubled (`'` becomes `''`). No other character is altered.
pub fn sql_str(s: &str) -> String {
    // Two extra bytes for the surrounding quotes; doubled quotes may grow it further.
    let mut literal = String::with_capacity(s.len() + 2);
    literal.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            // Emit the escaping quote first, then the quote itself below.
            literal.push('\'');
        }
        literal.push(ch);
    }
    literal.push('\'');
    literal
}
13
/// Render an embedding as a DataFusion `make_array(...)` literal.
///
/// Components are written in input order, separated by `", "`, using `f32`'s
/// `Display` formatting. An empty slice renders as `make_array()`.
///
/// NOTE(review): `Display` prints whole-valued floats without a fractional
/// part (`1.0` -> `1`) and non-finite values as `NaN` / `inf`; confirm the
/// server-side SQL accepts those spellings or that callers never produce them.
fn make_array(embedding: &[f32]) -> String {
    // Scoped to this function so `write!` can target the `String` buffer
    // without relying on a file-level import.
    use std::fmt::Write as _;

    let mut sql = String::with_capacity(embedding.len() * 8 + 12);
    sql.push_str("make_array(");
    for (i, x) in embedding.iter().enumerate() {
        if i > 0 {
            sql.push_str(", ");
        }
        // f32 Display is the shortest round-trippable representation.
        // Formatting straight into the buffer avoids allocating a temporary
        // `String` for every component.
        write!(sql, "{x}").expect("formatting an f32 into a String cannot fail");
    }
    sql.push(')');
    sql
}
28
29fn filter_conditions(filter: &RecallFilter, default_non_archived: bool) -> Vec<String> {
30    let mut conds = Vec::new();
31    if filter.statuses.is_empty() {
32        if default_non_archived {
33            conds.push("status <> 'archived'".to_string());
34        }
35    } else {
36        let list = filter
37            .statuses
38            .iter()
39            .map(|s| sql_str(s.as_str()))
40            .collect::<Vec<_>>()
41            .join(", ");
42        conds.push(format!("status IN ({list})"));
43    }
44    if !filter.realms.is_empty() {
45        let list = filter
46            .realms
47            .iter()
48            .map(|r| sql_str(r))
49            .collect::<Vec<_>>()
50            .join(", ");
51        conds.push(format!("realm IN ({list})"));
52    }
53    if let Some(t) = filter.memory_type {
54        conds.push(format!("memory_type = {}", sql_str(t.as_str())));
55    }
56    if let Some(min) = filter.importance_min {
57        conds.push(format!("importance >= {}", min as f64));
58    }
59    if let Some(since) = &filter.since {
60        conds.push(format!("created_at >= {}", sql_str(since)));
61    }
62    if let Some(until) = &filter.until {
63        conds.push(format!("created_at <= {}", sql_str(until)));
64    }
65    for tag in &filter.tags {
66        let needle = format!("%{}%", tag.replace('%', "").replace('_', ""));
67        conds.push(format!("tags LIKE {}", sql_str(&needle)));
68    }
69    // Bi-temporal validity: exclude superseded/invalidated memories unless asked.
70    // With an `as_of` instant we bound both ends of the validity interval; the
71    // default ("now") just drops anything already invalidated. `valid_at` /
72    // `invalid_at` are NULL on memories written before bi-temporal support, and
73    // NULL is treated as "always valid" so older stores keep working.
74    if !filter.include_invalidated {
75        match &filter.as_of {
76            Some(t) => {
77                conds.push(format!("(invalid_at IS NULL OR invalid_at > {})", sql_str(t)));
78                conds.push(format!("(valid_at IS NULL OR valid_at <= {})", sql_str(t)));
79            }
80            None => conds.push("invalid_at IS NULL".to_string()),
81        }
82    }
83    conds
84}
85
86/// Semantic recall: dedup → cosine distance → blended re-rank
87/// (`relevance*0.7 + importance*0.3`) → top `limit`.
88pub fn recall_sql(
89    node_table: &str,
90    embedding: &[f32],
91    filter: &RecallFilter,
92    limit: usize,
93    ann_threshold: Option<f64>,
94) -> String {
95    let arr = make_array(embedding);
96    let mut conds = filter_conditions(filter, true);
97    if conds.is_empty() {
98        conds.push("1 = 1".to_string());
99    }
100    let where_clause = conds.join(" AND ");
101    // Optional native-ANN activation: when `ann_threshold` is set, add a
102    // `cosine_distance(embedding, q) < threshold` predicate. The planner pushes
103    // it down to the columnar scan, where per-extent centroid+radius stats
104    // prune extents (see kyma-exec/kyma-catalog). `None`/0 → exact full-scan.
105    // Correctness holds either way; the predicate only filters out memories
106    // already more distant than the threshold.
107    let ann = match ann_threshold {
108        Some(t) if t > 0.0 => format!(" AND cosine_distance(embedding, {arr}) < {t}"),
109        _ => String::new(),
110    };
111    format!(
112        "WITH latest AS (\n  \
113           SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n), \
114         scored AS (\n  \
115           SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, \
116                  cosine_distance(embedding, {arr}) AS distance\n  \
117           FROM latest WHERE __rn = 1{ann}\n) \
118         SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, distance, \
119                ({rw} * (1 - distance) + {iw} * importance) AS score \
120         FROM scored WHERE {where_clause} ORDER BY score DESC LIMIT {limit}",
121        nt = node_table,
122        arr = arr,
123        where_clause = where_clause,
124        limit = limit,
125        rw = crate::RELEVANCE_WEIGHT,
126        iw = crate::IMPORTANCE_WEIGHT,
127    )
128}
129
130/// Non-semantic listing with filters, newest first.
131pub fn list_sql(node_table: &str, filter: &RecallFilter, limit: usize, offset: usize) -> String {
132    let mut conds = filter_conditions(filter, true);
133    if conds.is_empty() {
134        conds.push("1 = 1".to_string());
135    }
136    let where_clause = conds.join(" AND ");
137    format!(
138        "WITH latest AS (\n  \
139           SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
140         SELECT id, memory_type, title, content_preview, tags, importance, status, realm, created_at \
141         FROM latest WHERE __rn = 1 AND {where_clause} ORDER BY created_at DESC LIMIT {limit} OFFSET {offset}",
142        nt = node_table,
143        where_clause = where_clause,
144        limit = limit,
145        offset = offset,
146    )
147}
148
149/// Fetch the full latest version of a single node (for read-then-append
150/// mutations like status/importance updates).
151pub fn latest_node_sql(node_table: &str, node_id: &str) -> String {
152    format!(
153        "WITH latest AS (\n  \
154           SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
155         SELECT id, labels, realm, memory_type, title, content, content_preview, tags, importance, status, \
156                source_session_id, source_run_id, embedding, created_at, updated_at, \
157                valid_at, invalid_at, superseded_by, provenance, topic_key \
158         FROM latest WHERE __rn = 1 AND id = {idv}",
159        nt = node_table,
160        idv = sql_str(node_id),
161    )
162}
163
/// Split a free-text query into distinct search tokens.
///
/// The intent is to mirror how the columnar writer tokenizes stored text so
/// that query tokens line up with the per-extent token index used for
/// pruning: split on every non-alphanumeric character, discard pieces
/// shorter than two characters, lowercase, and drop repeats while keeping
/// first-seen order.
///
/// NOTE(review): splitting uses Unicode `is_alphanumeric`, but lowercasing is
/// ASCII-only, so non-ASCII uppercase letters are kept as written; confirm
/// this matches the writer's tokenizer.
pub fn tokenize_query(q: &str) -> Vec<String> {
    q.split(|ch: char| !ch.is_alphanumeric())
        // Length is measured in characters, not bytes.
        .filter(|piece| piece.chars().count() >= 2)
        .map(str::to_ascii_lowercase)
        .fold(Vec::new(), |mut unique, token| {
            // Linear membership check keeps insertion order.
            if !unique.contains(&token) {
                unique.push(token);
            }
            unique
        })
}
179
180/// Keyword recall: rank latest memories by how many query tokens appear in
181/// `content`/`title`/`tags`. Each `LIKE` triggers the extent token-set pruning
182/// (`ContainsTokens`) so this is index-accelerated. Excludes rows with no hit.
183pub fn keyword_recall_sql(
184    node_table: &str,
185    tokens: &[String],
186    filter: &RecallFilter,
187    limit: usize,
188) -> String {
189    let mut conds = filter_conditions(filter, true);
190    if conds.is_empty() {
191        conds.push("1 = 1".to_string());
192    }
193    let where_clause = conds.join(" AND ");
194    let score_expr = if tokens.is_empty() {
195        "0".to_string()
196    } else {
197        tokens
198            .iter()
199            .map(|t| {
200                let needle = sql_str(&format!("%{}%", t.replace('%', "").replace('_', "")));
201                format!(
202                    "(CASE WHEN lower(content) LIKE {n} OR lower(title) LIKE {n} \
203                       OR lower(tags) LIKE {n} THEN 1 ELSE 0 END)",
204                    n = needle
205                )
206            })
207            .collect::<Vec<_>>()
208            .join(" + ")
209    };
210    format!(
211        "WITH latest AS (\n  \
212           SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n), \
213         scored AS (\n  \
214           SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, \
215                  ({score}) AS kw_score\n  \
216           FROM latest WHERE __rn = 1\n) \
217         SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, kw_score \
218         FROM scored WHERE {where_clause} AND kw_score > 0 ORDER BY kw_score DESC, importance DESC LIMIT {limit}",
219        nt = node_table,
220        score = score_expr,
221        where_clause = where_clause,
222        limit = limit,
223    )
224}
225
226/// One-hop neighbour edges for a set of seed node ids (both directions), used
227/// to graph-expand a contextual subgraph. Returns edge rows; the orchestrator
228/// derives the far endpoint per seed and follows `target_namespace` for
229/// cross-graph (catalog resource / trace) links. Caller must pass ≥1 seed.
230pub fn neighbors_sql(
231    edge_table: &str,
232    seed_ids: &[String],
233    realms: &[String],
234    limit: usize,
235) -> String {
236    let ids = seed_ids
237        .iter()
238        .map(|s| sql_str(s))
239        .collect::<Vec<_>>()
240        .join(", ");
241    let mut where_parts = vec![format!("(src IN ({ids}) OR dst IN ({ids}))")];
242    if !realms.is_empty() {
243        let rl = realms
244            .iter()
245            .map(|r| sql_str(r))
246            .collect::<Vec<_>>()
247            .join(", ");
248        where_parts.push(format!("realm IN ({rl})"));
249    }
250    format!(
251        "SELECT src, dst, type, realm, target_namespace FROM {et} WHERE {wc} LIMIT {limit}",
252        et = edge_table,
253        wc = where_parts.join(" AND "),
254        limit = limit,
255    )
256}
257
258/// Fetch the latest versions of specific (currently-valid) memory node ids —
259/// used to materialize graph-pulled neighbour memories. Caller must pass ≥1 id.
260pub fn nodes_by_id_sql(node_table: &str, ids: &[String]) -> String {
261    let idlist = ids.iter().map(|s| sql_str(s)).collect::<Vec<_>>().join(", ");
262    format!(
263        "WITH latest AS (\n  \
264           SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
265         SELECT id, memory_type, title, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at \
266         FROM latest WHERE __rn = 1 AND id IN ({idlist}) AND invalid_at IS NULL",
267        nt = node_table,
268        idlist = idlist,
269    )
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MemoryStatus, MemoryType};

    /// Node table name used by every node-query test.
    const NODES: &str = "memory_nodes";

    /// Assert that `sql` contains every fragment in `fragments`.
    fn assert_contains_all(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(fragment), "missing `{fragment}` in:\n{sql}");
        }
    }

    #[test]
    fn sql_str_escapes_quotes() {
        let quoted = sql_str("a'b");
        assert_eq!(quoted, "'a''b'");
    }

    #[test]
    fn recall_sql_has_cosine_and_blend() {
        let filter = RecallFilter {
            realms: vec!["proj".into(), "global".into()],
            memory_type: Some(MemoryType::Fact),
            ..Default::default()
        };
        let sql = recall_sql(NODES, &[0.1, 0.2], &filter, 8, None);
        assert_contains_all(
            &sql,
            &[
                "cosine_distance(embedding, make_array(0.1, 0.2))",
                "0.7 * (1 - distance) + 0.3 * importance",
                "realm IN ('proj', 'global')",
                "memory_type = 'fact'",
                "status <> 'archived'",
                // Bi-temporal: invalidated memories excluded by default + columns projected.
                "invalid_at IS NULL",
                ", valid_at, invalid_at, distance,",
            ],
        );
        assert!(sql.trim_end().ends_with("LIMIT 8"));
    }

    #[test]
    fn recall_sql_as_of_bounds_validity_interval() {
        let filter = RecallFilter {
            as_of: Some("2026-01-01T00:00:00Z".into()),
            ..Default::default()
        };
        let sql = recall_sql(NODES, &[0.1], &filter, 5, None);
        assert_contains_all(
            &sql,
            &[
                "(invalid_at IS NULL OR invalid_at > '2026-01-01T00:00:00Z')",
                "(valid_at IS NULL OR valid_at <= '2026-01-01T00:00:00Z')",
            ],
        );
    }

    #[test]
    fn recall_sql_adds_ann_threshold_when_set() {
        let filter = RecallFilter::default();

        let with_threshold = recall_sql(NODES, &[0.1, 0.2], &filter, 5, Some(0.4));
        assert!(with_threshold.contains("cosine_distance(embedding, make_array(0.1, 0.2)) < 0.4"));

        let without_threshold = recall_sql(NODES, &[0.1, 0.2], &filter, 5, None);
        assert!(!without_threshold.contains("< 0.4"));
    }

    #[test]
    fn recall_sql_include_invalidated_drops_validity_guard() {
        let filter = RecallFilter {
            include_invalidated: true,
            ..Default::default()
        };
        let sql = recall_sql(NODES, &[0.1], &filter, 5, None);
        assert!(!sql.contains("invalid_at IS NULL"));
    }

    #[test]
    fn list_sql_respects_status_filter() {
        let filter = RecallFilter {
            statuses: vec![MemoryStatus::Active],
            ..Default::default()
        };
        let sql = list_sql(NODES, &filter, 50, 10);
        assert_contains_all(&sql, &["status IN ('active')", "OFFSET 10"]);
    }

    #[test]
    fn tokenize_query_lowercases_splits_and_dedups() {
        let tokens = tokenize_query("Kyma uses PGVECTOR; kyma!!");
        assert_eq!(tokens, vec!["kyma", "uses", "pgvector"]);
    }

    #[test]
    fn keyword_recall_builds_like_scoring() {
        let tokens = tokenize_query("pgvector index");
        let filter = RecallFilter {
            realms: vec!["proj".into()],
            ..Default::default()
        };
        let sql = keyword_recall_sql(NODES, &tokens, &filter, 10);
        assert_contains_all(
            &sql,
            &[
                "lower(content) LIKE '%pgvector%'",
                "lower(tags) LIKE '%index%'",
                "AS kw_score",
                "kw_score > 0",
                // The validity guard threads through to keyword recall too.
                "invalid_at IS NULL",
                "realm IN ('proj')",
            ],
        );
    }

    #[test]
    fn neighbors_sql_both_directions_and_realm() {
        let seeds: Vec<String> = vec!["memory:a".into(), "memory:b".into()];
        let realms: Vec<String> = vec!["proj".into()];
        let sql = neighbors_sql("memory_edges", &seeds, &realms, 100);
        assert_contains_all(
            &sql,
            &[
                "src IN ('memory:a', 'memory:b')",
                "OR dst IN ('memory:a', 'memory:b')",
                "realm IN ('proj')",
            ],
        );
        assert!(sql.trim_end().ends_with("LIMIT 100"));
    }

    #[test]
    fn nodes_by_id_filters_invalidated() {
        let ids: Vec<String> = vec!["memory:a".into()];
        let sql = nodes_by_id_sql(NODES, &ids);
        assert_contains_all(&sql, &["id IN ('memory:a')", "invalid_at IS NULL"]);
    }
}