// devsper_memory/supermemory.rs
use regex::Regex;
2use serde::{Deserialize, Serialize};
3use std::collections::{HashMap, HashSet};
4use std::io::{self, Read};
5
6#[derive(Debug, Deserialize)]
7struct RankRequest {
8 query_text: String,
9 query_embedding: Option<Vec<f32>>,
10 top_k: usize,
11 min_similarity: f32,
12 embed_weight: f32,
13 candidates: Vec<Candidate>,
14}
15
16#[derive(Debug, Deserialize)]
17struct Candidate {
18 id: String,
19 content: String,
20 tags: Vec<String>,
21 embedding: Option<Vec<f32>>,
22 timestamp: Option<String>,
23 memory_type: Option<String>,
24 source_task: Option<String>,
25}
26
27#[derive(Debug, Serialize)]
28struct Ranked {
29 id: String,
30 score: f32,
31}
32
33#[derive(Debug, Serialize)]
34struct RankResponse {
35 ranked: Vec<Ranked>,
36}
37
38#[derive(Debug, Deserialize)]
39struct Injection {
40 content: String,
41 tags: Option<Vec<String>>,
42}
43
44#[derive(Debug, Deserialize)]
45struct FormatContextRequest {
46 user_injections: Vec<Injection>,
47 ranked_candidates: Vec<Candidate>,
48}
49
50#[derive(Debug, Serialize)]
51struct FormatContextResponse {
52 context: String,
53}
54
/// Limits `s` to at most `max_chars` characters (Unicode scalar values).
///
/// If `s` contains more than `max_chars` characters, the first `max_chars`
/// characters are kept and `"..."` is appended. Otherwise `s` is returned
/// unchanged. The cut is always placed on a `char` boundary, so slicing
/// multi-byte text cannot panic.
fn truncate_chars(s: &str, max_chars: usize) -> String {
    // `nth(max_chars)` yields the byte offset of the first character that does
    // not fit; `None` means the whole string already fits. This walks the
    // string once instead of counting characters in a separate pass.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...", &s[..cut]),
        None => s.to_string(),
    }
}
72
73fn token_set(s: &str, re: &Regex) -> HashSet<String> {
74 re.find_iter(s)
75 .map(|m| m.as_str().to_ascii_lowercase())
76 .filter(|t| t.len() >= 2)
77 .collect()
78}
79
/// Fraction of `query_terms` that also occur in `candidate_terms`.
///
/// Returns a value in `[0, 1]`, or `0.0` when either set is empty.
fn overlap_score(query_terms: &HashSet<String>, candidate_terms: &HashSet<String>) -> f32 {
    if query_terms.is_empty() || candidate_terms.is_empty() {
        return 0.0;
    }
    let shared = query_terms.intersection(candidate_terms).count();
    // Normalised by the query size only, so long candidates are not penalised.
    shared as f32 / query_terms.len() as f32
}
92
93fn signature_tokens(s: &str, re: &Regex, max_tokens: usize) -> String {
94 let mut tokens: Vec<String> = re
95 .find_iter(s)
96 .map(|m| m.as_str().to_ascii_lowercase())
97 .filter(|t| t.len() >= 2)
98 .collect();
99 tokens.sort_unstable();
100 tokens.dedup();
101 if tokens.len() > max_tokens {
102 tokens.truncate(max_tokens);
103 }
104 tokens.join(" ")
105}
106
/// Cosine similarity of two vectors.
///
/// Returns `0.0` instead of an error or a non-finite value when the vectors
/// are empty, differ in length, have a zero norm, or when the computation
/// overflows or involves NaN.
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // One pass accumulating the dot product and both squared norms.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b.iter()).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );

    if norm_a_sq <= 0.0 || norm_b_sq <= 0.0 {
        return 0.0;
    }

    // The product of two tiny positive roots can still underflow to zero.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let sim = dot / denom;
    if sim.is_finite() {
        sim
    } else {
        0.0
    }
}
133
134fn parse_rfc3339_epoch_seconds(ts: &str) -> Option<i64> {
135 let dt = chrono::DateTime::parse_from_rfc3339(ts).ok()?;
138 Some(dt.timestamp())
139}
140
/// Score multiplier for a memory category (compared ASCII case-insensitively).
///
/// `research` → 1.15, `artifact` → 1.10, `semantic` → 1.05; `episodic`, any
/// other value and a missing type all yield the neutral 1.00.
fn memory_type_weight(memory_type: Option<&str>) -> f32 {
    let kind = memory_type.unwrap_or("");
    let is = |name: &str| kind.eq_ignore_ascii_case(name);

    if is("research") {
        1.15
    } else if is("artifact") {
        1.10
    } else if is("semantic") {
        1.05
    } else {
        // Covers "episodic" as well as unknown or absent types.
        1.00
    }
}
150
151pub fn run_ranking() {
152 let mut argv = std::env::args();
153 let _bin = argv.next();
154 let cmd = argv.next().unwrap_or_else(|| "rank".to_string());
155
156 if cmd != "rank" {
157 eprintln!("Unsupported command: {}", cmd);
158 std::process::exit(2);
159 }
160
161 let mut input = String::new();
162 if io::stdin().read_to_string(&mut input).is_err() {
163 std::process::exit(2);
164 }
165
166 let token_re = Regex::new(r"\w+").expect("valid token regex");
167 if cmd == "rank" {
168 let req: RankRequest = match serde_json::from_str(&input) {
169 Ok(v) => v,
170 Err(e) => {
171 eprintln!("Invalid JSON input: {e}");
172 std::process::exit(2);
173 }
174 };
175
176 let query_terms = token_set(&req.query_text, &token_re);
177 let embed_weight = req.embed_weight.clamp(0.0, 1.0);
178 let top_k = if req.top_k < 1 { 1 } else { req.top_k };
179
180 let mut ts_vals: Vec<i64> = req
182 .candidates
183 .iter()
184 .filter_map(|c| c.timestamp.as_deref().and_then(parse_rfc3339_epoch_seconds))
185 .collect();
186 ts_vals.sort_unstable();
187 ts_vals.dedup();
188 let min_ts = ts_vals.first().copied().unwrap_or(0);
189 let max_ts = ts_vals.last().copied().unwrap_or(0);
190 let ts_span = (max_ts - min_ts) as f32;
191
192 let has_query_embedding = req.query_embedding.is_some();
193 let recency_weight = if has_query_embedding { 0.02f32 } else { 0.05f32 };
194 let mut per_key: HashMap<String, (Candidate, f32, Option<i64>)> = HashMap::new();
195
196 for c in req.candidates.into_iter() {
197 let content_tokens = token_set(&c.content, &token_re);
198 let tags_text = c.tags.join(" ");
199 let tags_tokens = token_set(&tags_text, &token_re);
200
201 let content_score = overlap_score(&query_terms, &content_tokens);
202 let tag_score = overlap_score(&query_terms, &tags_tokens);
203 let lexical = 0.8f32 * content_score + 0.2f32 * tag_score;
204
205 let mut base_score = lexical;
206 if let (Some(qe), Some(ce)) = (&req.query_embedding, &c.embedding) {
207 if !ce.is_empty() && qe.len() == ce.len() {
208 let cos = cosine_sim(qe, &ce).max(0.0);
209 base_score = (embed_weight * cos) + ((1.0 - embed_weight) * lexical);
210 }
211 }
212
213 let type_mult = memory_type_weight(c.memory_type.as_deref());
215 base_score *= type_mult;
216
217 let ts = c
219 .timestamp
220 .as_deref()
221 .and_then(parse_rfc3339_epoch_seconds);
222 let recency_norm = if let Some(tsv) = ts {
223 if ts_span > 0.0 {
224 ((tsv - min_ts) as f32) / ts_span
225 } else {
226 0.0
227 }
228 } else {
229 0.0
230 };
231
232 let final_score = base_score + (recency_weight * recency_norm);
233
234 if req.min_similarity > 0.0 && final_score < req.min_similarity {
235 continue;
236 }
237
238 let dedup_key = signature_tokens(&c.content, &token_re, 80);
240 let replace = match per_key.get(&dedup_key) {
241 None => true,
242 Some((best_cand, best_score, best_ts)) => {
243 let best_score_cmp =
244 best_score.partial_cmp(&final_score).unwrap_or(std::cmp::Ordering::Equal);
245 if best_score_cmp == std::cmp::Ordering::Greater {
246 false
247 } else if best_score_cmp == std::cmp::Ordering::Less {
248 true
249 } else {
250 let best_ts_val = best_ts.unwrap_or(i64::MIN);
251 let this_ts_val = ts.unwrap_or(i64::MIN);
252 if this_ts_val != best_ts_val {
253 this_ts_val > best_ts_val
254 } else {
255 c.id < best_cand.id
256 }
257 }
258 }
259 };
260
261 if replace {
262 per_key.insert(dedup_key, (c, final_score, ts));
263 }
264 }
265
266 let mut uniques: Vec<(Candidate, f32, Option<i64>)> = per_key
267 .into_iter()
268 .map(|(_, v)| v)
269 .collect();
270
271 uniques.sort_by(|a, b| {
274 let ord_score = b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal);
275 if ord_score != std::cmp::Ordering::Equal {
276 return ord_score;
277 }
278 let ta = a.2.unwrap_or(i64::MIN);
279 let tb = b.2.unwrap_or(i64::MIN);
280 if ta != tb {
281 return tb.cmp(&ta);
282 }
283 a.0.id.cmp(&b.0.id)
284 });
285
286 let mut ranked: Vec<Ranked> = Vec::new();
287 for (c, score, _) in uniques.into_iter().take(top_k) {
288 ranked.push(Ranked { id: c.id, score });
289 }
290
291 let resp = RankResponse { ranked };
292 let out = serde_json::to_string(&resp).unwrap_or_else(|_| "{\"ranked\":[]}".to_string());
293 print!("{out}");
294 } else if cmd == "format_context" {
295 let req: FormatContextRequest = match serde_json::from_str(&input) {
296 Ok(v) => v,
297 Err(e) => {
298 eprintln!("Invalid JSON input: {e}");
299 std::process::exit(2);
300 }
301 };
302
303 let mut lines: Vec<String> = Vec::new();
304 if !req.user_injections.is_empty() {
305 lines.push("USER INJECTIONS (high priority):".to_string());
306 for inj in req.user_injections.into_iter() {
307 let truncated = {
308 let max = 1000usize;
309 let content = inj.content;
310 if content.chars().count() > max {
311 format!("{}...", content.chars().take(max).collect::<String>())
312 } else {
313 content
314 }
315 };
316 lines.push(format!("- {}", truncated));
317 }
318 }
319
320 let relevant: Vec<&Candidate> = req
321 .ranked_candidates
322 .iter()
323 .filter(|c| {
324 let tags = c.tags.as_slice();
326 !tags.iter().any(|t| t == "user_injection")
327 })
328 .collect();
329
330 if !relevant.is_empty() {
331 if !lines.is_empty() {
332 lines.push("".to_string());
333 }
334 lines.push("RELEVANT MEMORY (previous research notes, findings, artifacts):".to_string());
335 for c in relevant.into_iter() {
336 let mtype = c
337 .memory_type
338 .as_deref()
339 .filter(|s| !s.is_empty())
340 .unwrap_or("general");
341 let src = c
342 .source_task
343 .as_deref()
344 .filter(|s| !s.is_empty())
345 .unwrap_or("general");
346 let content = &c.content;
347 let max = 500usize;
348 let truncated = if content.chars().count() > max {
349 format!("{}...", content.chars().take(max).collect::<String>())
350 } else {
351 content.clone()
352 };
353 lines.push(format!(
354 "- [{}] {}: {}",
355 mtype,
356 src,
357 truncated
358 ));
359 }
360 }
361
362 let context = lines.join("\n");
363 let resp = FormatContextResponse { context };
364 let out = serde_json::to_string(&resp).unwrap_or_else(|_| "{\"context\":\"\"}".to_string());
365 print!("{out}");
366 } else {
367 eprintln!("Unsupported command: {}", cmd);
368 std::process::exit(2);
369 }
370}
371