// devsper_memory/supermemory.rs

use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::io::{self, Read};

/// Input payload for the `rank` command (read as JSON from stdin).
#[derive(Debug, Deserialize)]
struct RankRequest {
    /// Free-text query used for lexical (token-overlap) scoring.
    query_text: String,
    /// Optional query embedding; when present, blended with the lexical score.
    query_embedding: Option<Vec<f32>>,
    /// Maximum number of results to return (values below 1 are treated as 1).
    top_k: usize,
    /// When > 0, candidates whose final score falls below this are dropped.
    min_similarity: f32,
    /// Blend factor between embedding and lexical score; clamped to [0, 1].
    embed_weight: f32,
    /// Memories to score against the query.
    candidates: Vec<Candidate>,
}

/// A single memory item to be scored, deduplicated, and ranked.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Stable identifier echoed back in the ranked output.
    id: String,
    /// Memory body; drives lexical scoring, the dedup signature, and
    /// context rendering.
    content: String,
    /// Tags; contribute a smaller share of the lexical score. The
    /// "user_injection" tag excludes the item from context formatting.
    tags: Vec<String>,
    /// Optional embedding; used only when the query supplies one of the
    /// same dimensionality.
    embedding: Option<Vec<f32>>,
    /// Optional RFC3339 timestamp used for the recency bonus and tie-breaks.
    timestamp: Option<String>,
    /// Optional memory type ("research", "artifact", ...) used for type
    /// weighting; rendered as "general" when missing or empty.
    memory_type: Option<String>,
    /// Optional originating task label; rendered as "general" when missing
    /// or empty.
    source_task: Option<String>,
}

/// One entry in the ranked output: candidate id plus its final score.
#[derive(Debug, Serialize)]
struct Ranked {
    id: String,
    /// Final blended score (lexical/embedding mix, type weight, recency bonus).
    score: f32,
}

/// Output payload for the `rank` command (written as JSON to stdout).
#[derive(Debug, Serialize)]
struct RankResponse {
    /// Deduplicated candidates, best first, capped at `top_k`.
    ranked: Vec<Ranked>,
}

/// A user-supplied snippet injected into the context with high priority.
#[derive(Debug, Deserialize)]
struct Injection {
    content: String,
    /// Accepted from the wire format but not read by the visible formatting
    /// code.
    tags: Option<Vec<String>>,
}

/// Input payload for the `format_context` command.
#[derive(Debug, Deserialize)]
struct FormatContextRequest {
    /// High-priority user injections, rendered first.
    user_injections: Vec<Injection>,
    /// Already-ranked memories, rendered in the order given.
    ranked_candidates: Vec<Candidate>,
}

/// Output payload for the `format_context` command.
#[derive(Debug, Serialize)]
struct FormatContextResponse {
    /// Newline-joined plain-text context block.
    context: String,
}

/// Truncates `s` to at most `max_chars` characters (Unicode scalar values,
/// not bytes), appending "..." only when truncation actually occurred.
///
/// Strings with `max_chars` or fewer characters are returned unchanged.
/// `chars().take()` can never split a multi-byte UTF-8 sequence, which the
/// original hand-rolled `char_indices` bookkeeping was guarding against.
fn truncate_chars(s: &str, max_chars: usize) -> String {
    if s.chars().count() <= max_chars {
        s.to_string()
    } else {
        let head: String = s.chars().take(max_chars).collect();
        format!("{head}...")
    }
}

73fn token_set(s: &str, re: &Regex) -> HashSet<String> {
74    re.find_iter(s)
75        .map(|m| m.as_str().to_ascii_lowercase())
76        .filter(|t| t.len() >= 2)
77        .collect()
78}
79
/// Fraction of `query_terms` that also occur in `candidate_terms`.
///
/// Returns 0.0 when either set is empty, so an empty query never divides
/// by zero and an empty candidate never matches.
fn overlap_score(query_terms: &HashSet<String>, candidate_terms: &HashSet<String>) -> f32 {
    if query_terms.is_empty() || candidate_terms.is_empty() {
        return 0.0;
    }
    let shared = query_terms.intersection(candidate_terms).count();
    shared as f32 / query_terms.len() as f32
}

93fn signature_tokens(s: &str, re: &Regex, max_tokens: usize) -> String {
94    let mut tokens: Vec<String> = re
95        .find_iter(s)
96        .map(|m| m.as_str().to_ascii_lowercase())
97        .filter(|t| t.len() >= 2)
98        .collect();
99    tokens.sort_unstable();
100    tokens.dedup();
101    if tokens.len() > max_tokens {
102        tokens.truncate(max_tokens);
103    }
104    tokens.join(" ")
105}
106
/// Cosine similarity between two equal-length vectors.
///
/// Returns 0.0 for empty input, mismatched lengths, zero-magnitude
/// vectors, or a non-finite quotient.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }
    if norm_a <= 0.0 || norm_b <= 0.0 {
        return 0.0;
    }
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    let sim = dot / denom;
    if sim.is_finite() { sim } else { 0.0 }
}

134fn parse_rfc3339_epoch_seconds(ts: &str) -> Option<i64> {
135    // Timestamp comes from Python (`datetime.isoformat()`), typically RFC3339-ish.
136    // We use the RFC3339 parser and fall back to None on failure.
137    let dt = chrono::DateTime::parse_from_rfc3339(ts).ok()?;
138    Some(dt.timestamp())
139}
140
/// Multiplicative score boost per memory type (case-insensitive):
/// research notes rank highest, then artifacts, then semantic memories.
/// Episodic and unrecognized types receive the neutral weight 1.0.
fn memory_type_weight(memory_type: Option<&str>) -> f32 {
    let kind = memory_type.unwrap_or_default().to_ascii_lowercase();
    match kind.as_str() {
        "research" => 1.15,
        "artifact" => 1.10,
        "semantic" => 1.05,
        // "episodic", "", and anything unknown share the neutral weight.
        _ => 1.00,
    }
}

151pub fn run_ranking() {
152    let mut argv = std::env::args();
153    let _bin = argv.next();
154    let cmd = argv.next().unwrap_or_else(|| "rank".to_string());
155
156    if cmd != "rank" {
157        eprintln!("Unsupported command: {}", cmd);
158        std::process::exit(2);
159    }
160
161    let mut input = String::new();
162    if io::stdin().read_to_string(&mut input).is_err() {
163        std::process::exit(2);
164    }
165
166    let token_re = Regex::new(r"\w+").expect("valid token regex");
167    if cmd == "rank" {
168        let req: RankRequest = match serde_json::from_str(&input) {
169            Ok(v) => v,
170            Err(e) => {
171                eprintln!("Invalid JSON input: {e}");
172                std::process::exit(2);
173            }
174        };
175
176        let query_terms = token_set(&req.query_text, &token_re);
177        let embed_weight = req.embed_weight.clamp(0.0, 1.0);
178        let top_k = if req.top_k < 1 { 1 } else { req.top_k };
179
180        // Parse timestamps once so we can normalize recency.
181        let mut ts_vals: Vec<i64> = req
182            .candidates
183            .iter()
184            .filter_map(|c| c.timestamp.as_deref().and_then(parse_rfc3339_epoch_seconds))
185            .collect();
186        ts_vals.sort_unstable();
187        ts_vals.dedup();
188        let min_ts = ts_vals.first().copied().unwrap_or(0);
189        let max_ts = ts_vals.last().copied().unwrap_or(0);
190        let ts_span = (max_ts - min_ts) as f32;
191
192        let has_query_embedding = req.query_embedding.is_some();
193        let recency_weight = if has_query_embedding { 0.02f32 } else { 0.05f32 };
194        let mut per_key: HashMap<String, (Candidate, f32, Option<i64>)> = HashMap::new();
195
196        for c in req.candidates.into_iter() {
197            let content_tokens = token_set(&c.content, &token_re);
198            let tags_text = c.tags.join(" ");
199            let tags_tokens = token_set(&tags_text, &token_re);
200
201            let content_score = overlap_score(&query_terms, &content_tokens);
202            let tag_score = overlap_score(&query_terms, &tags_tokens);
203            let lexical = 0.8f32 * content_score + 0.2f32 * tag_score;
204
205            let mut base_score = lexical;
206            if let (Some(qe), Some(ce)) = (&req.query_embedding, &c.embedding) {
207                if !ce.is_empty() && qe.len() == ce.len() {
208                    let cos = cosine_sim(qe, &ce).max(0.0);
209                    base_score = (embed_weight * cos) + ((1.0 - embed_weight) * lexical);
210                }
211            }
212
213            // Type weighting (devsper routing intent).
214            let type_mult = memory_type_weight(c.memory_type.as_deref());
215            base_score *= type_mult;
216
217            // Recency normalization (tie-break when scores are close).
218            let ts = c
219                .timestamp
220                .as_deref()
221                .and_then(parse_rfc3339_epoch_seconds);
222            let recency_norm = if let Some(tsv) = ts {
223                if ts_span > 0.0 {
224                    ((tsv - min_ts) as f32) / ts_span
225                } else {
226                    0.0
227                }
228            } else {
229                0.0
230            };
231
232            let final_score = base_score + (recency_weight * recency_norm);
233
234            if req.min_similarity > 0.0 && final_score < req.min_similarity {
235                continue;
236            }
237
238            // Deterministic dedup key (normalized content tokens).
239            let dedup_key = signature_tokens(&c.content, &token_re, 80);
240            let replace = match per_key.get(&dedup_key) {
241                None => true,
242                Some((best_cand, best_score, best_ts)) => {
243                    let best_score_cmp =
244                        best_score.partial_cmp(&final_score).unwrap_or(std::cmp::Ordering::Equal);
245                    if best_score_cmp == std::cmp::Ordering::Greater {
246                        false
247                    } else if best_score_cmp == std::cmp::Ordering::Less {
248                        true
249                    } else {
250                        let best_ts_val = best_ts.unwrap_or(i64::MIN);
251                        let this_ts_val = ts.unwrap_or(i64::MIN);
252                        if this_ts_val != best_ts_val {
253                            this_ts_val > best_ts_val
254                        } else {
255                            c.id < best_cand.id
256                        }
257                    }
258                }
259            };
260
261            if replace {
262                per_key.insert(dedup_key, (c, final_score, ts));
263            }
264        }
265
266        let mut uniques: Vec<(Candidate, f32, Option<i64>)> = per_key
267            .into_iter()
268            .map(|(_, v)| v)
269            .collect();
270
271        // Deterministic ordering:
272        // final_score desc, timestamp desc, id asc.
273        uniques.sort_by(|a, b| {
274            let ord_score = b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal);
275            if ord_score != std::cmp::Ordering::Equal {
276                return ord_score;
277            }
278            let ta = a.2.unwrap_or(i64::MIN);
279            let tb = b.2.unwrap_or(i64::MIN);
280            if ta != tb {
281                return tb.cmp(&ta);
282            }
283            a.0.id.cmp(&b.0.id)
284        });
285
286        let mut ranked: Vec<Ranked> = Vec::new();
287        for (c, score, _) in uniques.into_iter().take(top_k) {
288            ranked.push(Ranked { id: c.id, score });
289        }
290
291        let resp = RankResponse { ranked };
292        let out = serde_json::to_string(&resp).unwrap_or_else(|_| "{\"ranked\":[]}".to_string());
293        print!("{out}");
294    } else if cmd == "format_context" {
295        let req: FormatContextRequest = match serde_json::from_str(&input) {
296            Ok(v) => v,
297            Err(e) => {
298                eprintln!("Invalid JSON input: {e}");
299                std::process::exit(2);
300            }
301        };
302
303        let mut lines: Vec<String> = Vec::new();
304        if !req.user_injections.is_empty() {
305            lines.push("USER INJECTIONS (high priority):".to_string());
306            for inj in req.user_injections.into_iter() {
307                let truncated = {
308                    let max = 1000usize;
309                    let content = inj.content;
310                    if content.chars().count() > max {
311                        format!("{}...", content.chars().take(max).collect::<String>())
312                    } else {
313                        content
314                    }
315                };
316                lines.push(format!("- {}", truncated));
317            }
318        }
319
320        let relevant: Vec<&Candidate> = req
321            .ranked_candidates
322            .iter()
323            .filter(|c| {
324                // Skip duplicates: if candidate is itself a user injection, don't render it under relevant memories.
325                let tags = c.tags.as_slice();
326                !tags.iter().any(|t| t == "user_injection")
327            })
328            .collect();
329
330        if !relevant.is_empty() {
331            if !lines.is_empty() {
332                lines.push("".to_string());
333            }
334            lines.push("RELEVANT MEMORY (previous research notes, findings, artifacts):".to_string());
335            for c in relevant.into_iter() {
336                let mtype = c
337                    .memory_type
338                    .as_deref()
339                    .filter(|s| !s.is_empty())
340                    .unwrap_or("general");
341                let src = c
342                    .source_task
343                    .as_deref()
344                    .filter(|s| !s.is_empty())
345                    .unwrap_or("general");
346                let content = &c.content;
347                let max = 500usize;
348                let truncated = if content.chars().count() > max {
349                    format!("{}...", content.chars().take(max).collect::<String>())
350                } else {
351                    content.clone()
352                };
353                lines.push(format!(
354                    "- [{}] {}: {}",
355                    mtype,
356                    src,
357                    truncated
358                ));
359            }
360        }
361
362        let context = lines.join("\n");
363        let resp = FormatContextResponse { context };
364        let out = serde_json::to_string(&resp).unwrap_or_else(|_| "{\"context\":\"\"}".to_string());
365        print!("{out}");
366    } else {
367        eprintln!("Unsupported command: {}", cmd);
368        std::process::exit(2);
369    }
370}
371