1use crate::types::RecallFilter;
8
/// Renders `s` as a single-quoted SQL string literal.
///
/// Every embedded single quote is doubled (`'` becomes `''`), the standard SQL
/// escape, so the value can never terminate the literal early.
pub fn sql_str(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('\'');
    for ch in s.chars() {
        if ch == '\'' {
            quoted.push_str("''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
13
/// Renders an embedding as a SQL `make_array(...)` call expression.
///
/// Each component is written with `f32`'s `Display` formatting and the components
/// are separated by `", "`. An empty slice yields `make_array()`.
///
/// NOTE(review): `Display` renders non-finite values as `NaN` / `inf`, which are not
/// SQL numeric literals. Callers are presumably expected to pass finite embeddings.
fn make_array(embedding: &[f32]) -> String {
    let components: Vec<String> = embedding.iter().map(|value| value.to_string()).collect();
    format!("make_array({})", components.join(", "))
}
28
29fn filter_conditions(filter: &RecallFilter, default_non_archived: bool) -> Vec<String> {
30 let mut conds = Vec::new();
31 if filter.statuses.is_empty() {
32 if default_non_archived {
33 conds.push("status <> 'archived'".to_string());
34 }
35 } else {
36 let list = filter
37 .statuses
38 .iter()
39 .map(|s| sql_str(s.as_str()))
40 .collect::<Vec<_>>()
41 .join(", ");
42 conds.push(format!("status IN ({list})"));
43 }
44 if !filter.realms.is_empty() {
45 let list = filter
46 .realms
47 .iter()
48 .map(|r| sql_str(r))
49 .collect::<Vec<_>>()
50 .join(", ");
51 conds.push(format!("realm IN ({list})"));
52 }
53 if let Some(t) = filter.memory_type {
54 conds.push(format!("memory_type = {}", sql_str(t.as_str())));
55 }
56 if let Some(min) = filter.importance_min {
57 conds.push(format!("importance >= {}", min as f64));
58 }
59 if let Some(since) = &filter.since {
60 conds.push(format!("created_at >= {}", sql_str(since)));
61 }
62 if let Some(until) = &filter.until {
63 conds.push(format!("created_at <= {}", sql_str(until)));
64 }
65 for tag in &filter.tags {
66 let needle = format!("%{}%", tag.replace('%', "").replace('_', ""));
67 conds.push(format!("tags LIKE {}", sql_str(&needle)));
68 }
69 if !filter.include_invalidated {
75 match &filter.as_of {
76 Some(t) => {
77 conds.push(format!("(invalid_at IS NULL OR invalid_at > {})", sql_str(t)));
78 conds.push(format!("(valid_at IS NULL OR valid_at <= {})", sql_str(t)));
79 }
80 None => conds.push("invalid_at IS NULL".to_string()),
81 }
82 }
83 conds
84}
85
86pub fn recall_sql(
89 node_table: &str,
90 embedding: &[f32],
91 filter: &RecallFilter,
92 limit: usize,
93 ann_threshold: Option<f64>,
94) -> String {
95 let arr = make_array(embedding);
96 let mut conds = filter_conditions(filter, true);
97 if conds.is_empty() {
98 conds.push("1 = 1".to_string());
99 }
100 let where_clause = conds.join(" AND ");
101 let ann = match ann_threshold {
108 Some(t) if t > 0.0 => format!(" AND cosine_distance(embedding, {arr}) < {t}"),
109 _ => String::new(),
110 };
111 format!(
112 "WITH latest AS (\n \
113 SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n), \
114 scored AS (\n \
115 SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, \
116 cosine_distance(embedding, {arr}) AS distance\n \
117 FROM latest WHERE __rn = 1{ann}\n) \
118 SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, distance, \
119 ({rw} * (1 - distance) + {iw} * importance) AS score \
120 FROM scored WHERE {where_clause} ORDER BY score DESC LIMIT {limit}",
121 nt = node_table,
122 arr = arr,
123 where_clause = where_clause,
124 limit = limit,
125 rw = crate::RELEVANCE_WEIGHT,
126 iw = crate::IMPORTANCE_WEIGHT,
127 )
128}
129
130pub fn list_sql(node_table: &str, filter: &RecallFilter, limit: usize, offset: usize) -> String {
132 let mut conds = filter_conditions(filter, true);
133 if conds.is_empty() {
134 conds.push("1 = 1".to_string());
135 }
136 let where_clause = conds.join(" AND ");
137 format!(
138 "WITH latest AS (\n \
139 SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
140 SELECT id, memory_type, title, content_preview, tags, importance, status, realm, created_at \
141 FROM latest WHERE __rn = 1 AND {where_clause} ORDER BY created_at DESC LIMIT {limit} OFFSET {offset}",
142 nt = node_table,
143 where_clause = where_clause,
144 limit = limit,
145 offset = offset,
146 )
147}
148
149pub fn latest_node_sql(node_table: &str, node_id: &str) -> String {
152 format!(
153 "WITH latest AS (\n \
154 SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
155 SELECT id, labels, realm, memory_type, title, content, content_preview, tags, importance, status, \
156 source_session_id, source_run_id, embedding, created_at, updated_at, \
157 valid_at, invalid_at, superseded_by, provenance, topic_key \
158 FROM latest WHERE __rn = 1 AND id = {idv}",
159 nt = node_table,
160 idv = sql_str(node_id),
161 )
162}
163
/// Splits a free-text query into distinct lowercase keyword tokens.
///
/// Tokens are maximal runs of Unicode alphanumeric characters. Runs shorter than two
/// characters are dropped, and duplicates are removed while keeping first-occurrence
/// order.
///
/// Lowercasing is full Unicode (`str::to_lowercase`), matching the Unicode-aware
/// splitting. An ASCII-only lowercase would leave tokens like `"ÄPFEL"` partly
/// uppercase, so they could never match text passed through SQL `lower(...)`.
pub fn tokenize_query(q: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    for tok in q.split(|c: char| !c.is_alphanumeric()) {
        if tok.chars().count() < 2 {
            continue;
        }
        let t = tok.to_lowercase();
        if !out.contains(&t) {
            out.push(t);
        }
    }
    out
}
179
180pub fn keyword_recall_sql(
184 node_table: &str,
185 tokens: &[String],
186 filter: &RecallFilter,
187 limit: usize,
188) -> String {
189 let mut conds = filter_conditions(filter, true);
190 if conds.is_empty() {
191 conds.push("1 = 1".to_string());
192 }
193 let where_clause = conds.join(" AND ");
194 let score_expr = if tokens.is_empty() {
195 "0".to_string()
196 } else {
197 tokens
198 .iter()
199 .map(|t| {
200 let needle = sql_str(&format!("%{}%", t.replace('%', "").replace('_', "")));
201 format!(
202 "(CASE WHEN lower(content) LIKE {n} OR lower(title) LIKE {n} \
203 OR lower(tags) LIKE {n} THEN 1 ELSE 0 END)",
204 n = needle
205 )
206 })
207 .collect::<Vec<_>>()
208 .join(" + ")
209 };
210 format!(
211 "WITH latest AS (\n \
212 SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n), \
213 scored AS (\n \
214 SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, \
215 ({score}) AS kw_score\n \
216 FROM latest WHERE __rn = 1\n) \
217 SELECT id, memory_type, title, content, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at, kw_score \
218 FROM scored WHERE {where_clause} AND kw_score > 0 ORDER BY kw_score DESC, importance DESC LIMIT {limit}",
219 nt = node_table,
220 score = score_expr,
221 where_clause = where_clause,
222 limit = limit,
223 )
224}
225
226pub fn neighbors_sql(
231 edge_table: &str,
232 seed_ids: &[String],
233 realms: &[String],
234 limit: usize,
235) -> String {
236 let ids = seed_ids
237 .iter()
238 .map(|s| sql_str(s))
239 .collect::<Vec<_>>()
240 .join(", ");
241 let mut where_parts = vec![format!("(src IN ({ids}) OR dst IN ({ids}))")];
242 if !realms.is_empty() {
243 let rl = realms
244 .iter()
245 .map(|r| sql_str(r))
246 .collect::<Vec<_>>()
247 .join(", ");
248 where_parts.push(format!("realm IN ({rl})"));
249 }
250 format!(
251 "SELECT src, dst, type, realm, target_namespace FROM {et} WHERE {wc} LIMIT {limit}",
252 et = edge_table,
253 wc = where_parts.join(" AND "),
254 limit = limit,
255 )
256}
257
258pub fn nodes_by_id_sql(node_table: &str, ids: &[String]) -> String {
261 let idlist = ids.iter().map(|s| sql_str(s)).collect::<Vec<_>>().join(", ");
262 format!(
263 "WITH latest AS (\n \
264 SELECT *, row_number() OVER (PARTITION BY id ORDER BY updated_at DESC) AS __rn FROM {nt}\n) \
265 SELECT id, memory_type, title, content_preview, tags, importance, status, realm, created_at, valid_at, invalid_at \
266 FROM latest WHERE __rn = 1 AND id IN ({idlist}) AND invalid_at IS NULL",
267 nt = node_table,
268 idlist = idlist,
269 )
270}
271
#[cfg(test)]
mod tests {
    //! Unit tests for the SQL builders: they check the generated text, not execution.
    use super::*;
    use crate::types::{MemoryStatus, MemoryType};

    #[test]
    fn sql_str_escapes_quotes() {
        assert_eq!(sql_str("a'b"), "'a''b'");
    }

    #[test]
    fn sql_str_handles_empty_and_repeated_quotes() {
        assert_eq!(sql_str(""), "''");
        assert_eq!(sql_str("''"), "''''''");
    }

    #[test]
    fn make_array_formats_components() {
        assert_eq!(make_array(&[1.0, -2.5]), "make_array(1, -2.5)");
        assert_eq!(make_array(&[]), "make_array()");
    }

    #[test]
    fn recall_sql_has_cosine_and_blend() {
        let f = RecallFilter {
            realms: vec!["proj".into(), "global".into()],
            memory_type: Some(MemoryType::Fact),
            ..Default::default()
        };
        let s = recall_sql("memory_nodes", &[0.1, 0.2], &f, 8, None);
        assert!(s.contains("cosine_distance(embedding, make_array(0.1, 0.2))"));
        assert!(s.contains("0.7 * (1 - distance) + 0.3 * importance"));
        assert!(s.contains("realm IN ('proj', 'global')"));
        assert!(s.contains("memory_type = 'fact'"));
        assert!(s.contains("status <> 'archived'"));
        assert!(s.contains("invalid_at IS NULL"));
        assert!(s.contains(", valid_at, invalid_at, distance,"));
        assert!(s.trim_end().ends_with("LIMIT 8"));
    }

    #[test]
    fn recall_sql_as_of_bounds_validity_interval() {
        let f = RecallFilter {
            as_of: Some("2026-01-01T00:00:00Z".into()),
            ..Default::default()
        };
        let s = recall_sql("memory_nodes", &[0.1], &f, 5, None);
        assert!(s.contains("(invalid_at IS NULL OR invalid_at > '2026-01-01T00:00:00Z')"));
        assert!(s.contains("(valid_at IS NULL OR valid_at <= '2026-01-01T00:00:00Z')"));
    }

    #[test]
    fn recall_sql_adds_ann_threshold_when_set() {
        let f = RecallFilter::default();
        let s = recall_sql("memory_nodes", &[0.1, 0.2], &f, 5, Some(0.4));
        assert!(s.contains("cosine_distance(embedding, make_array(0.1, 0.2)) < 0.4"));
        let s2 = recall_sql("memory_nodes", &[0.1, 0.2], &f, 5, None);
        assert!(!s2.contains("< 0.4"));
    }

    #[test]
    fn recall_sql_include_invalidated_drops_validity_guard() {
        let f = RecallFilter {
            include_invalidated: true,
            ..Default::default()
        };
        let s = recall_sql("memory_nodes", &[0.1], &f, 5, None);
        assert!(!s.contains("invalid_at IS NULL"));
    }

    #[test]
    fn list_sql_respects_status_filter() {
        let f = RecallFilter {
            statuses: vec![MemoryStatus::Active],
            ..Default::default()
        };
        let s = list_sql("memory_nodes", &f, 50, 10);
        assert!(s.contains("status IN ('active')"));
        assert!(s.contains("OFFSET 10"));
    }

    #[test]
    fn latest_node_sql_escapes_id() {
        let s = latest_node_sql("memory_nodes", "memory:o'neil");
        assert!(s.contains("FROM memory_nodes"));
        assert!(s.contains("id = 'memory:o''neil'"));
    }

    #[test]
    fn tokenize_query_lowercases_splits_and_dedups() {
        let toks = tokenize_query("Kyma uses PGVECTOR; kyma!!");
        assert_eq!(toks, vec!["kyma", "uses", "pgvector"]);
    }

    #[test]
    fn tokenize_query_drops_single_char_tokens() {
        assert_eq!(tokenize_query("a bc d"), vec!["bc"]);
    }

    #[test]
    fn keyword_recall_builds_like_scoring() {
        let toks = tokenize_query("pgvector index");
        let f = RecallFilter {
            realms: vec!["proj".into()],
            ..Default::default()
        };
        let s = keyword_recall_sql("memory_nodes", &toks, &f, 10);
        assert!(s.contains("lower(content) LIKE '%pgvector%'"));
        assert!(s.contains("lower(tags) LIKE '%index%'"));
        assert!(s.contains("AS kw_score"));
        assert!(s.contains("kw_score > 0"));
        assert!(s.contains("invalid_at IS NULL"));
        assert!(s.contains("realm IN ('proj')"));
    }

    #[test]
    fn neighbors_sql_both_directions_and_realm() {
        let s = neighbors_sql(
            "memory_edges",
            &["memory:a".into(), "memory:b".into()],
            &["proj".into()],
            100,
        );
        assert!(s.contains("src IN ('memory:a', 'memory:b')"));
        assert!(s.contains("OR dst IN ('memory:a', 'memory:b')"));
        assert!(s.contains("realm IN ('proj')"));
        assert!(s.trim_end().ends_with("LIMIT 100"));
    }

    #[test]
    fn neighbors_sql_omits_realm_filter_when_empty() {
        let s = neighbors_sql("memory_edges", &["memory:a".into()], &[], 5);
        assert_eq!(
            s,
            "SELECT src, dst, type, realm, target_namespace FROM memory_edges \
             WHERE (src IN ('memory:a') OR dst IN ('memory:a')) LIMIT 5"
        );
    }

    #[test]
    fn nodes_by_id_filters_invalidated() {
        let s = nodes_by_id_sql("memory_nodes", &["memory:a".into()]);
        assert!(s.contains("id IN ('memory:a')"));
        assert!(s.contains("invalid_at IS NULL"));
    }
}
384}