// mag/memory_core/scoring.rs

1use std::borrow::Cow;
2use std::collections::HashSet;
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use super::{EventType, MemoryKind};
6
7#[derive(Debug, Clone, serde::Serialize)]
8pub struct ScoringParams {
9    pub rrf_k: f64,
10    pub rrf_weight_vec: f64,
11    pub rrf_weight_fts: f64,
12    pub abstention_min_text: f64,
13    pub graph_neighbor_factor: f64,
14    pub graph_min_edge_weight: f64,
15    pub word_overlap_weight: f64,
16    pub query_coverage_weight: f64,
17    pub jaccard_weight: f64,
18    pub importance_floor: f64,
19    pub importance_scale: f64,
20    pub context_tag_weight: f64,
21    pub time_decay_days: f64,
22    pub priority_base: f64,
23    pub priority_scale: f64,
24    pub feedback_heavy_suppress: f64,
25    pub feedback_strong_suppress: f64,
26    pub feedback_positive_scale: f64,
27    pub feedback_positive_cap: f64,
28    pub feedback_heavy_threshold: i64,
29    pub neighbor_word_overlap_weight: f64,
30    pub neighbor_importance_floor: f64,
31    pub neighbor_importance_scale: f64,
32    pub graph_seed_min: usize,
33    pub graph_seed_max: usize,
34    /// Fallback multiplier for candidates appearing in both vector and FTS
35    /// result lists during RRF fusion.  The pipeline now uses an adaptive
36    /// boost (1.3x – 1.8x) scaled by FTS rank, so this value is only used
37    /// as a fallback when the adaptive path cannot determine rank.
38    /// Default 1.5 (midpoint of the adaptive range).
39    pub dual_match_boost: f64,
40    /// Number of top candidates to rerank with the cross-encoder (default 30).
41    pub rerank_top_n: usize,
42    /// Blend weight: `alpha * rrf_score + (1 - alpha) * cross_encoder_score`.
43    /// Default 0.5 (equal weight).
44    pub rerank_blend_alpha: f64,
45    /// Multiplicative boost for PRECEDED_BY graph edges (temporal adjacency).
46    /// Default 1.5 — adjacent conversation turns get 50% more weight.
47    pub preceded_by_boost: f64,
48    /// Multiplicative boost for entity-related graph edges (RELATES_TO, SIMILAR_TO, etc.).
49    /// Default 1.3 — entity-connected memories get 30% more weight.
50    pub entity_relation_boost: f64,
51}
52
53impl Default for ScoringParams {
54    fn default() -> Self {
55        Self {
56            rrf_k: 60.0,
57            rrf_weight_vec: RRF_WEIGHT_VEC,
58            rrf_weight_fts: RRF_WEIGHT_FTS,
59            abstention_min_text: ABSTENTION_MIN_TEXT,
60            graph_neighbor_factor: GRAPH_NEIGHBOR_FACTOR,
61            graph_min_edge_weight: GRAPH_MIN_EDGE_WEIGHT,
62            word_overlap_weight: 0.75,
63            query_coverage_weight: 0.35,
64            jaccard_weight: 0.25,
65            importance_floor: 0.3,
66            importance_scale: 0.5,
67            context_tag_weight: 0.25,
68            time_decay_days: 0.0,
69            priority_base: 0.7,
70            priority_scale: 0.08,
71            feedback_heavy_suppress: 0.1,
72            feedback_strong_suppress: 0.3,
73            feedback_positive_scale: 0.05,
74            feedback_positive_cap: 1.3,
75            feedback_heavy_threshold: -3,
76            neighbor_word_overlap_weight: 0.5,
77            neighbor_importance_floor: 0.5,
78            neighbor_importance_scale: 0.5,
79            graph_seed_min: 5,
80            graph_seed_max: 8,
81            dual_match_boost: 1.5,
82            rerank_top_n: 30,
83            rerank_blend_alpha: 0.5,
84            preceded_by_boost: 1.5,
85            entity_relation_boost: 1.3,
86        }
87    }
88}
89
/// Test-only convenience: type weight for an event type given by name.
///
/// Parsing an `EventType` cannot fail (its error type is uninhabited), so the
/// error arm is an empty match.
#[cfg(test)]
pub fn type_weight(event_type: &str) -> f64 {
    match event_type.parse::<EventType>() {
        Ok(parsed) => parsed.type_weight(),
        Err(never) => match never {},
    }
}
97
98/// Returns the type weight for a typed `EventType` (avoids string round-trip).
99pub fn type_weight_et(et: &EventType) -> f64 {
100    et.type_weight()
101}
102
103pub fn priority_factor(priority: u8, scoring_params: &ScoringParams) -> f64 {
104    scoring_params.priority_base + (priority as f64 * scoring_params.priority_scale)
105}
106
/// Test-only convenience: [`time_decay_et`] with the event type given by name.
#[cfg(test)]
pub fn time_decay(created_at: &str, event_type: &str, scoring_params: &ScoringParams) -> f64 {
    // Parsing is infallible (uninhabited error type), hence the empty match.
    let parsed = match event_type.parse::<EventType>() {
        Ok(et) => et,
        Err(never) => match never {},
    };
    time_decay_et(created_at, &parsed, scoring_params)
}
114
115/// Time decay using a typed `EventType` (avoids string round-trip).
116pub fn time_decay_et(created_at: &str, et: &EventType, scoring_params: &ScoringParams) -> f64 {
117    if et.memory_kind() == MemoryKind::Semantic {
118        return 1.0;
119    }
120
121    if !scoring_params.time_decay_days.is_finite() || scoring_params.time_decay_days <= 0.0 {
122        return 1.0;
123    }
124
125    let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
126        Ok(duration) => duration.as_secs_f64(),
127        Err(_) => return 1.0,
128    };
129    let created = match parse_iso8601_to_unix_seconds(created_at) {
130        Some(value) => value,
131        None => return 1.0,
132    };
133    let age_seconds = (now - created).max(0.0);
134    let days_old = age_seconds / 86_400.0;
135    1.0 / (1.0 + (days_old / scoring_params.time_decay_days))
136}
137
/// Test-only: fraction of (stemmed, lowercased) query words longer than two
/// bytes that appear in `text`'s token set (built with `min_word_len = 3`).
/// Returns 0.0 when no query word survives filtering.
#[cfg(test)]
fn word_overlap(query_words: &[&str], text: &str) -> f64 {
    let mut query_stems = HashSet::new();
    for raw in query_words {
        let stem = simple_stem(&raw.trim().to_lowercase()).into_owned();
        if stem.len() > 2 {
            query_stems.insert(stem);
        }
    }
    if query_stems.is_empty() {
        return 0.0;
    }

    let text_stems = token_set(text, 3);
    let hits = query_stems
        .iter()
        .filter(|stem| text_stems.contains(*stem))
        .count();

    #[allow(clippy::cast_precision_loss)]
    let ratio = hits as f64 / query_stems.len() as f64;
    ratio
}
160
161pub fn jaccard_similarity(text_a: &str, text_b: &str, min_word_len: usize) -> f64 {
162    let a = token_set(text_a, min_word_len);
163    let b = token_set(text_b, min_word_len);
164
165    if a.is_empty() && b.is_empty() {
166        return 0.0;
167    }
168
169    let intersection = a.intersection(&b).count();
170    let union = a.union(&b).count();
171
172    if union == 0 {
173        0.0
174    } else {
175        #[allow(clippy::cast_precision_loss)]
176        let result = intersection as f64 / union as f64;
177        result
178    }
179}
180
/// Abstention threshold — collection-level gate on max text overlap.
/// Dense embeddings (bge-small-en-v1.5) produce 0.80+ cosine similarity even
/// for unrelated content, so vec_sim is NOT used for abstention.
/// Lowered from 0.30 to 0.15 to avoid dropping valid results for
/// numeric/synonym-heavy queries where word overlap is inherently lower.
pub const ABSTENTION_MIN_TEXT: f64 = 0.15;

/// Graph enrichment — re-enabled at a conservative factor.
/// Grid search over [0.0, 0.05, 0.1, 0.15, 0.2, 0.3] showed no regression on
/// LoCoMo (benchmark data has sparse relationship graphs). 0.1 was chosen as a
/// safe default: neighbors get at most 10% of the seed score, enough to
/// surface related memories in production without dominating rankings.
pub const GRAPH_NEIGHBOR_FACTOR: f64 = 0.1;

/// Default for `ScoringParams::graph_min_edge_weight`.
/// NOTE(review): presumably graph edges weighted below this are ignored during
/// enrichment — confirm against the enrichment code.
pub const GRAPH_MIN_EDGE_WEIGHT: f64 = 0.3;

/// Multiplicative boost for memories found via entity tag expansion.
/// Applied on top of the standard scoring pipeline. AutoMem uses +0.15 additive;
/// we use a multiplicative 1.15 to integrate with the existing RRF-based scoring.
pub const ENTITY_EXPANSION_BOOST: f64 = 1.15;

/// Weighted RRF fusion — weight of the vector-search list.
/// Equal weight for vector and FTS was optimal in grid search; the previous
/// bias (1.5 vec / 1.0 fts) was suboptimal on LongMemEval.
pub const RRF_WEIGHT_VEC: f64 = 1.0;
/// Weighted RRF fusion — weight of the full-text-search list (see `RRF_WEIGHT_VEC`).
pub const RRF_WEIGHT_FTS: f64 = 1.0;
205
206/// Feedback is an explicit user/system signal — asymmetric by design.
207/// Negative feedback aggressively suppresses (explicit downvote).
208/// Positive feedback gives only a mild boost (prevents displacing unrelated results).
209pub fn feedback_factor(feedback_score: i64, scoring_params: &ScoringParams) -> f64 {
210    if feedback_score < 0 {
211        if feedback_score <= scoring_params.feedback_heavy_threshold {
212            scoring_params.feedback_heavy_suppress // flagged for review — near-total suppress
213        } else {
214            scoring_params.feedback_strong_suppress // explicit negative — strong suppress
215        }
216    } else if feedback_score > 0 {
217        #[allow(clippy::cast_precision_loss)]
218        let result = (1.0 + (feedback_score as f64 * scoring_params.feedback_positive_scale))
219            .min(scoring_params.feedback_positive_cap);
220        result
221    } else {
222        1.0 // neutral (no feedback = no effect)
223    }
224}
225
/// Returns `true` if `word` is one of a fixed set of common English stopwords
/// that dilute BM25 scoring and word overlap.
///
/// Matching is exact and case-sensitive, so callers should lowercase first.
/// Used by both `token_set()` (scoring) and `build_fts5_query()` (FTS5).
pub(crate) fn is_stopword(word: &str) -> bool {
    // Must stay sorted (byte order) and duplicate-free for `binary_search`.
    const STOPWORDS: &[&str] = &[
        "a", "an", "and", "are", "as", "at", "be", "but", "by", "do", "for", "from", "had",
        "has", "have", "he", "her", "his", "how", "if", "in", "into", "is", "it", "its", "me",
        "my", "no", "not", "of", "on", "or", "our", "out", "she", "so", "than", "that", "the",
        "their", "them", "then", "there", "these", "they", "this", "to", "too", "up", "us",
        "very", "was", "we", "were", "what", "when", "which", "who", "will", "with", "would",
        "you", "your",
    ];
    STOPWORDS.binary_search(&word).is_ok()
}
295
296pub(crate) fn token_set(text: &str, min_word_len: usize) -> HashSet<String> {
297    let mut tokens = HashSet::new();
298
299    // Collect raw whitespace-split tokens first so we can detect adjacent
300    // all-uppercase acronyms (e.g. "CI CD") and emit their concatenation
301    // ("cicd") to match punctuated forms like "CI/CD".
302    let raws: Vec<&str> = text.split_whitespace().collect();
303    let mut i = 0;
304    while i < raws.len() {
305        let raw = raws[i];
306        let has_punctuation = raw.chars().any(|c| !c.is_alphanumeric());
307
308        for word in raw.split(|c: char| !c.is_alphanumeric()) {
309            let trimmed = word.trim();
310            if trimmed.len() < min_word_len {
311                // Preserve numeric tokens (e.g. "42", "2023") and mixed
312                // alphanumeric tokens with digits (e.g. "3d", "v2") as long as
313                // they have at least 1 character. This prevents numbers from
314                // being silently dropped by the min_word_len filter.
315                let has_digit = trimmed.chars().any(|c| c.is_ascii_digit());
316                if has_digit && !trimmed.is_empty() {
317                    tokens.insert(trimmed.to_lowercase());
318                    continue;
319                }
320                // For space-separated short all-uppercase acronyms (no punctuation
321                // in the raw token, e.g. the "CI" in "CI CD") include the
322                // lowercased form so that "CI CD" and "CI/CD" share tokens.
323                if !has_punctuation
324                    && trimmed.len() >= 2
325                    && trimmed.chars().all(|c| c.is_ascii_uppercase())
326                {
327                    tokens.insert(trimmed.to_lowercase());
328                }
329                continue;
330            }
331            let lower = trimmed.to_lowercase();
332            if is_stopword(&lower) {
333                continue;
334            }
335            tokens.insert(simple_stem(&lower).into_owned());
336        }
337
338        if has_punctuation {
339            // Punctuated compound (e.g. "CI/CD") → emit collapsed form "cicd"
340            // so it bridges with space-separated "CI CD" queries and vice-versa.
341            let collapsed: String = raw
342                .chars()
343                .filter(|c| c.is_alphanumeric())
344                .collect::<String>()
345                .to_lowercase();
346            if collapsed.len() >= min_word_len {
347                tokens.insert(simple_stem(&collapsed).into_owned());
348            }
349        } else {
350            // For a standalone short all-uppercase token (no punctuation), also
351            // check the next token: if both are short all-uppercase acronyms,
352            // emit their concatenation ("CI" + "CD" → "cicd") so that "CI CD"
353            // matches candidates containing "CI/CD".
354            let is_short_upper = raw.len() < min_word_len
355                && raw.len() >= 2
356                && raw.chars().all(|c| c.is_ascii_uppercase());
357            if is_short_upper && let Some(&next) = raws.get(i + 1) {
358                let next_is_short_upper = next.len() < min_word_len
359                    && next.len() >= 2
360                    && next.chars().all(|c| c.is_ascii_uppercase());
361                if next_is_short_upper {
362                    let concat = format!("{}{}", raw.to_lowercase(), next.to_lowercase());
363                    if concat.len() >= min_word_len {
364                        tokens.insert(concat);
365                    }
366                }
367            }
368        }
369
370        i += 1;
371    }
372
373    tokens
374}
375
/// Simple suffix stemmer for English words.
///
/// Strips common suffixes to normalize inflected forms so that e.g.
/// "threading" matches "threads" (both stem to "thread").
///
/// Rules are tried in a fixed priority order. `-ies` → `-y` comes first, then
/// compound suffixes before their single-suffix components, then `-s`. A rule
/// only fires if the stem it leaves is long enough; otherwise the next rule is
/// tried.
///
/// Properties:
/// - Never strips a word below 4 bytes. Inputs shorter than 4 bytes are
///   returned unchanged.
/// - NOT idempotent in general: "conditioners" → "condition" → "condi".
///   Stem each token exactly once and compare stems produced the same way.
/// - Borrows from `word` except for the `-ies` → `-y` rewrite, which allocates.
/// - Works on bytes. Suffixes are ASCII, so slicing always lands on a char
///   boundary.
/// - No external crates — pure string operations.
pub(crate) fn simple_stem(word: &str) -> Cow<'_, str> {
    /// `(suffix, minimum stem length left after stripping)` in priority order.
    const SUFFIX_RULES: &[(&str, usize)] = &[
        // Compound suffixes, checked before their single-suffix components.
        ("tions", 4), // "connections" → "connec"
        ("ments", 4), // "deployments" → "deploy"
        ("ings", 5),  // "buildings"   → "build" ("settings" is too short → "-s" rule)
        ("ers", 4),   // "workers"     → "work"
        // Single suffixes.
        ("tion", 4), // "connection" → "connec"
        ("ment", 4), // "deployment" → "deploy"
        ("ness", 4), // "darkness"   → "dark"
        ("able", 4), // "readable"   → "read"
        ("ible", 4),
        ("ing", 5), // "threading" → "thread", but not "ring"/"bring"/"string"
        ("est", 4), // "fastest"   → "fast", but not "best"
        ("ed", 4),  // "created"   → "creat", but not "shed"/"used"
        ("er", 4),  // "worker"    → "work"
        ("ly", 4),  // "quickly"   → "quick"
    ];

    // Short words are returned as-is (nothing to strip safely).
    if word.len() < 4 {
        return Cow::Borrowed(word);
    }

    // -ies → -y ("memories" → "memory"); checked before -s to avoid "memorie".
    if let Some(base) = word.strip_suffix("ies") {
        if base.len() >= 3 {
            return Cow::Owned(format!("{base}y"));
        }
    }

    for &(suffix, min_stem_len) in SUFFIX_RULES {
        if let Some(stem) = word.strip_suffix(suffix) {
            if stem.len() >= min_stem_len {
                return Cow::Borrowed(stem);
            }
        }
    }

    // -s ("threads" → "thread"), but not -ss ("glass") and never below 4 bytes.
    if !word.ends_with("ss") {
        if let Some(stem) = word.strip_suffix('s') {
            if stem.len() >= 4 {
                return Cow::Borrowed(stem);
            }
        }
    }

    Cow::Borrowed(word)
}
479
/// Fraction of `query_tokens` that also occur in `text_tokens`.
///
/// Both sets are expected to come from the same tokenizer (e.g. `token_set`).
/// Returns 0.0 for an empty query.
pub fn word_overlap_pre(query_tokens: &HashSet<String>, text_tokens: &HashSet<String>) -> f64 {
    let query_len = query_tokens.len();
    if query_len == 0 {
        return 0.0;
    }

    let shared = query_tokens.intersection(text_tokens).count();

    #[allow(clippy::cast_precision_loss)]
    let fraction = shared as f64 / query_len as f64;
    fraction
}
495
496/// Extra bonus for candidates that cover most query terms.
497///
498/// This separates "covers nearly the whole query" from "matches a couple of
499/// broad topic words", which helps task-completion memories compete with more
500/// generic architectural decisions during multi-session retrieval.
501pub fn query_coverage_boost(overlap: f64, scoring_params: &ScoringParams) -> f64 {
502    if overlap <= 0.0 {
503        return 1.0;
504    }
505    1.0 + overlap.powi(2) * scoring_params.query_coverage_weight
506}
507
/// Jaccard similarity `|A ∩ B| / |A ∪ B|` of two pre-computed token sets.
///
/// Returns a value in `[0.0, 1.0]`; `0.0` when both sets are empty (the union is empty).
pub fn jaccard_pre(a: &HashSet<String>, b: &HashSet<String>) -> f64 {
    let intersection = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|: avoids a second walk over both sets.
    let union = a.len() + b.len() - intersection;
    if union == 0 {
        return 0.0;
    }

    #[allow(clippy::cast_precision_loss)]
    let similarity = intersection as f64 / union as f64;
    similarity
}
525
/// Parses a UTC timestamp `YYYY-MM-DDTHH:MM:SS[.fff…]Z` into Unix seconds.
///
/// Returns `None` unless the input has exactly that shape:
/// - Every numeric field is fixed-width ASCII digits. Signs such as "+1", which
///   `str::parse` alone would accept, are rejected.
/// - Separators are `-`, `-`, `T`, `:`, `:`.
/// - Between the seconds field and the trailing `Z`, only an optional `.`
///   followed by one or more digits is allowed.
/// - Month is 1–12 and the day exists in that month (leap years honoured).
/// - Hour ≤ 23, minute ≤ 59, second ≤ 60 (a leap second is tolerated).
///
/// Offsets other than `Z` (e.g. `+02:00`) are not supported and yield `None`.
fn parse_iso8601_to_unix_seconds(value: &str) -> Option<f64> {
    /// Parses `value[start..end]`, requiring every byte to be an ASCII digit.
    fn fixed_digits(value: &str, start: usize, end: usize) -> Option<u32> {
        let field = value.get(start..end)?;
        if field.bytes().all(|b| b.is_ascii_digit()) {
            field.parse().ok()
        } else {
            None
        }
    }

    /// Number of days in `month` (1–12) of proleptic-Gregorian `year`.
    fn days_in_month(year: u32, month: u32) -> u32 {
        match month {
            2 if (year % 4 == 0 && year % 100 != 0) || year % 400 == 0 => 29,
            2 => 28,
            4 | 6 | 9 | 11 => 30,
            _ => 31,
        }
    }

    let bytes = value.as_bytes();
    if bytes.len() < 20 || bytes.last() != Some(&b'Z') {
        return None;
    }
    let separators_ok = bytes[4] == b'-'
        && bytes[7] == b'-'
        && bytes[10] == b'T'
        && bytes[13] == b':'
        && bytes[16] == b':';
    if !separators_ok {
        return None;
    }

    let year = fixed_digits(value, 0, 4)?;
    let month = fixed_digits(value, 5, 7)?;
    let day = fixed_digits(value, 8, 10)?;
    let hour = fixed_digits(value, 11, 13)?;
    let minute = fixed_digits(value, 14, 16)?;
    let second = fixed_digits(value, 17, 19)?;

    if !(1..=12).contains(&month)
        || day == 0
        || day > days_in_month(year, month)
        || hour > 23
        || minute > 59
        || second > 60
    {
        return None;
    }

    // Bytes 0..19 are now known to be ASCII and the last byte is 'Z', so both
    // slice ends are char boundaries.
    let tail = value.get(19..value.len() - 1)?;
    let fraction = if tail.is_empty() {
        0.0
    } else {
        let digits = tail.strip_prefix('.')?;
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        format!("0.{digits}").parse::<f64>().ok()?
    };

    let days = days_from_civil(
        i32::try_from(year).ok()?,
        i32::try_from(month).ok()?,
        i32::try_from(day).ok()?,
    );
    // Bounded by 23·3600 + 59·60 + 60 = 86_400, so u32 → f64 is exact.
    let day_seconds = f64::from(hour * 3600 + minute * 60 + second);
    #[allow(clippy::cast_precision_loss)]
    let result = days as f64 * 86_400.0 + day_seconds + fraction;
    Some(result)
}

/// Days since 1970-01-01 for a proleptic-Gregorian date (`month` 1–12).
///
/// Howard Hinnant's `days_from_civil`. All arithmetic is done in `i64`, so any
/// `i32` year is handled without overflow.
fn days_from_civil(year: i32, month: i32, day: i32) -> i64 {
    let month = i64::from(month);
    let day = i64::from(day);
    // Count January/February as months of the previous year so the leap day ends the year.
    let year = i64::from(year) - if month <= 2 { 1 } else { 0 };
    let era = if year >= 0 { year } else { year - 399 } / 400;
    let yoe = year - era * 400; // year of era, [0, 399]
    let march_based_month = month + if month > 2 { -3 } else { 9 }; // March = 0 … February = 11
    let doy = (153 * march_based_month + 2) / 5 + day - 1; // day of year, [0, 365]
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // day of era, [0, 146_096]
    era * 146_097 + doe - 719_468
}
591
592#[cfg(test)]
593mod tests {
594    use super::*;
595
/// Formats "now minus `days_ago` days" as `YYYY-MM-DDTHH:MM:SSZ`
/// (falls back to the epoch if the clock reads before 1970).
fn iso_string_days_ago(days_ago: f64) -> String {
    let now = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    };
    unix_to_iso8601(now - days_ago * 86_400.0)
}

/// Formats Unix seconds (fraction floored away) as `YYYY-MM-DDTHH:MM:SSZ`.
fn unix_to_iso8601(timestamp: f64) -> String {
    #[allow(clippy::cast_possible_truncation)]
    let total_seconds = timestamp.floor() as i64;
    let (year, month, day_of_month) = civil_from_days(total_seconds.div_euclid(86_400));
    let second_of_day = total_seconds.rem_euclid(86_400);
    let hour = second_of_day / 3600;
    let minute = second_of_day / 60 % 60;
    let second = second_of_day % 60;

    format!("{year:04}-{month:02}-{day_of_month:02}T{hour:02}:{minute:02}:{second:02}Z")
}

/// Inverse of `days_from_civil` (Howard Hinnant's `civil_from_days`):
/// days since 1970-01-01 → `(year, month, day)` in the proleptic Gregorian calendar.
fn civil_from_days(days: i64) -> (i32, i32, i32) {
    let shifted = days + 719_468; // re-base so day 0 is 0000-03-01
    let era = if shifted >= 0 { shifted } else { shifted - 146_096 } / 146_097;
    let day_of_era = shifted - era * 146_097;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    let march_month = (5 * day_of_year + 2) / 153; // 0 = March … 11 = February
    #[allow(clippy::cast_possible_truncation)]
    let day = (day_of_year - (153 * march_month + 2) / 5 + 1) as i32;
    #[allow(clippy::cast_possible_truncation)]
    let month = (if march_month < 10 { march_month + 3 } else { march_month - 9 }) as i32;
    #[allow(clippy::cast_possible_truncation)]
    let year = year_of_era as i32 + era as i32 * 400 + i32::from(month <= 2);
    (year, month, day)
}
635
#[test]
fn test_type_weight_known() {
    let weight = type_weight("reminder");
    assert_eq!(weight, 3.0);
}

#[test]
fn test_type_weight_unknown() {
    // An unrecognised event type must map to the neutral weight.
    let weight = type_weight("totally_unknown");
    assert_eq!(weight, 1.0);
}

#[test]
fn test_priority_factor() {
    let params = ScoringParams::default();
    // Defaults give 0.7 + priority × 0.08.
    for (priority, expected) in [(1_u8, 0.78), (5_u8, 1.10)] {
        let factor = priority_factor(priority, &params);
        assert!((factor - expected).abs() < 1e-9);
    }
}
652
#[test]
fn test_time_decay_recent() {
    let defaults = ScoringParams::default();
    let created_now = iso_string_days_ago(0.0);
    let decay = time_decay(&created_now, "session_summary", &defaults);
    // Default time_decay_days=0.0 disables decay → exactly 1.0.
    assert!((decay - 1.0).abs() < 1e-9);
}

#[test]
fn test_time_decay_default_disables_decay() {
    let defaults = ScoringParams::default();
    let year_old = iso_string_days_ago(365.0);
    let decay = time_decay(&year_old, "task_completion", &defaults);
    // Even a year-old episodic memory is untouched while decay is disabled.
    assert!((decay - 1.0).abs() < 1e-9);
}

#[test]
fn test_time_decay_old() {
    let thirty_day_window = ScoringParams {
        time_decay_days: 30.0,
        ..ScoringParams::default()
    };
    let month_old = iso_string_days_ago(30.0);
    let decay = time_decay(&month_old, "task_completion", &thirty_day_window);
    // Age equal to the window → 1 / (1 + 1) = 0.5.
    assert!((decay - 0.5).abs() < 0.03);
}

#[test]
fn test_time_decay_semantic_type_has_zero_decay() {
    let thirty_day_window = ScoringParams {
        time_decay_days: 30.0,
        ..ScoringParams::default()
    };
    let decade_old = iso_string_days_ago(3650.0);
    // Semantic memories ("decision") never decay, however old.
    let decay = time_decay(&decade_old, "decision", &thirty_day_window);
    assert!((decay - 1.0).abs() < 1e-9);
}

#[test]
fn test_time_decay_unknown_type_defaults_to_episodic() {
    let thirty_day_window = ScoringParams {
        time_decay_days: 30.0,
        ..ScoringParams::default()
    };
    let month_old = iso_string_days_ago(30.0);
    // Unknown types decay like episodic ones.
    let decay = time_decay(&month_old, "totally_unknown", &thirty_day_window);
    assert!((decay - 0.5).abs() < 0.03);
}
703
#[test]
fn test_word_overlap() {
    const TEXT: &str = "Rust-based memory system with tags";

    // "an" is dropped (len <= 2); "rust" and "memory" both match → 2/2 = 1.0
    let full = word_overlap(&["rust", "memory", "an"], TEXT);
    assert!((full - 1.0).abs() < 1e-9);

    // only "rust" matches, "python" does not → 1/2 = 0.5
    let partial = word_overlap(&["rust", "python"], TEXT);
    assert!((partial - 0.5).abs() < 1e-9);
}

#[test]
fn test_jaccard_similarity() {
    // {beta, gamma} shared out of {alpha, beta, gamma, delta} → 2/4
    let score = jaccard_similarity("alpha beta gamma", "beta gamma delta", 2);
    assert!((score - 0.5).abs() < 1e-9);
}
722
#[test]
fn test_feedback_factor_neutral() {
    let factor = feedback_factor(0, &ScoringParams::default());
    assert!((factor - 1.0).abs() < 1e-9);
}

#[test]
fn test_feedback_factor_strong_suppress() {
    let params = ScoringParams::default();
    // Mild negatives (above the heavy threshold) → 0.3 strong suppression.
    for score in [-1, -2] {
        assert!((feedback_factor(score, &params) - 0.3).abs() < 1e-9);
    }
}

#[test]
fn test_feedback_factor_heavy_suppress() {
    let params = ScoringParams::default();
    // At or below the heavy threshold (-3) → 0.1 near-total suppression.
    for score in [-3, -100] {
        assert!((feedback_factor(score, &params) - 0.1).abs() < 1e-9);
    }
}

#[test]
fn test_feedback_factor_positive_boost() {
    let params = ScoringParams::default();
    // 1 + 0.05 per point, capped at 1.3.
    for (score, expected) in [(1, 1.05), (2, 1.1), (6, 1.3), (100, 1.3)] {
        assert!((feedback_factor(score, &params) - expected).abs() < 1e-9);
    }
}

#[test]
fn test_priority_factor_custom_params() {
    let custom = ScoringParams {
        priority_base: 1.0,
        priority_scale: 0.2,
        ..ScoringParams::default()
    };
    // 1.0 + 5 × 0.2 = 2.0
    assert!((priority_factor(5, &custom) - 2.0).abs() < 1e-9);
}
764
#[test]
fn test_time_decay_custom_window() {
    let two_months_old = iso_string_days_ago(60.0);
    let sixty_day_window = ScoringParams {
        time_decay_days: 60.0,
        ..ScoringParams::default()
    };
    // Age equal to the window → 0.5.
    let decay = time_decay(&two_months_old, "task_completion", &sixty_day_window);
    assert!((decay - 0.5).abs() < 0.03);
}

#[test]
fn test_time_decay_zero_days_returns_one() {
    let month_old = iso_string_days_ago(30.0);
    let disabled = ScoringParams {
        time_decay_days: 0.0,
        ..ScoringParams::default()
    };
    let decay = time_decay(&month_old, "task_completion", &disabled);
    assert!((decay - 1.0).abs() < 1e-9);
}
785
#[test]
fn test_word_overlap_pre() {
    let query_tokens = token_set("rust memory an", 3);
    let text_tokens = token_set("Rust-based memory system with tags", 3);
    // "an" is shorter than 3 and dropped; "rust" and "memory" both appear → 2/2 = 1.0
    let ratio = word_overlap_pre(&query_tokens, &text_tokens);
    assert!((ratio - 1.0).abs() < 1e-9);
}

#[test]
fn test_query_coverage_boost_prefers_higher_overlap() {
    let params = ScoringParams::default();
    let [high, medium, none] =
        [0.75, 0.5, 0.0].map(|overlap| query_coverage_boost(overlap, &params));

    assert!(high > medium);
    assert!(medium > 1.0);
    assert_eq!(none, 1.0);
}

#[test]
fn test_token_set_preserves_compound_acronyms() {
    let tokens = token_set("database migration in CI/CD", 3);
    for expected in ["database", "cicd"] {
        assert!(tokens.contains(expected));
    }
}

#[test]
fn test_token_set_acronym_spacing_match() {
    // The punctuated ("CI/CD") and space-separated ("CI CD") spellings must
    // share at least one token so word_overlap_pre is non-zero between them.
    let punctuated = token_set("CI/CD", 3);
    let spaced = token_set("CI CD", 3);
    assert!(
        !punctuated.is_disjoint(&spaced),
        "CI/CD and CI CD must share at least one token; got {punctuated:?} vs {spaced:?}"
    );

    let overlap = word_overlap_pre(&punctuated, &spaced);
    assert!(
        overlap > 0.0,
        "word_overlap_pre(CI/CD, CI CD) must be > 0, got {overlap}"
    );
}

#[test]
fn test_jaccard_pre() {
    let left = token_set("alpha beta gamma", 2);
    let right = token_set("beta gamma delta", 2);
    // |{beta, gamma}| / |{alpha, beta, gamma, delta}| = 2/4
    assert!((jaccard_pre(&left, &right) - 0.5).abs() < 1e-9);
}
840
// ── simple_stem tests ──────────────────────────────────────────────

#[test]
fn test_stem_ing() {
    for (word, stem) in [
        ("threading", "thread"),
        ("processing", "process"),
        ("computing", "comput"),
    ] {
        assert_eq!(simple_stem(word), stem);
    }
}

#[test]
fn test_stem_ing_short_words_preserved() {
    // -ing only strips when at least 5 characters remain; these would leave
    // 1 ("ring", "king"), 2 ("bring") and 3 ("string") characters.
    for word in ["ring", "king", "bring", "string"] {
        assert_eq!(simple_stem(word), word);
    }
}

#[test]
fn test_stem_ed() {
    for (word, stem) in [
        ("created", "creat"),
        ("processed", "process"),
        ("stored", "stor"),
    ] {
        assert_eq!(simple_stem(word), stem);
    }
}

#[test]
fn test_stem_ed_short_words_preserved() {
    // "red"/"bed" are below the 4-byte minimum; "shed"/"used" would leave a
    // 2-character base, but -ed needs 4+.
    for word in ["red", "bed", "shed", "used"] {
        assert_eq!(simple_stem(word), word);
    }
}

#[test]
fn test_stem_s() {
    for (word, stem) in [
        ("threads", "thread"),
        ("systems", "system"),
        ("memories", "memory"), // the -ies rule fires before -s
    ] {
        assert_eq!(simple_stem(word), stem);
    }
}

#[test]
fn test_stem_s_guards() {
    // Too short ("is", "as") or ending in -ss ("glass", "class", "moss"): unchanged.
    for word in ["is", "as", "glass", "class", "moss"] {
        assert_eq!(simple_stem(word), word);
    }
}
898
899    #[test]
900    fn test_stem_tion() {
901        assert_eq!(simple_stem("connection"), "connec");
902        assert_eq!(simple_stem("collection"), "collec");
903        assert_eq!(simple_stem("abstention"), "absten");
904    }
905
906    #[test]
907    fn test_stem_ment() {
908        assert_eq!(simple_stem("deployment"), "deploy");
909        assert_eq!(simple_stem("management"), "manage");
910        assert_eq!(simple_stem("environment"), "environ");
911    }
912
913    #[test]
914    fn test_stem_ness() {
915        assert_eq!(simple_stem("darkness"), "dark");
916        assert_eq!(simple_stem("happiness"), "happi");
917        assert_eq!(simple_stem("awareness"), "aware");
918    }
919
920    #[test]
921    fn test_stem_ly() {
922        assert_eq!(simple_stem("quickly"), "quick");
923        assert_eq!(simple_stem("slowly"), "slow");
924    }
925
926    #[test]
927    fn test_stem_ly_short_preserved() {
928        // "fly" → too short
929        assert_eq!(simple_stem("fly"), "fly");
930        // "holy" → base "ho" is only 2 chars, need 4+
931        assert_eq!(simple_stem("holy"), "holy");
932    }
933
934    #[test]
935    fn test_stem_er() {
936        assert_eq!(simple_stem("worker"), "work");
937        assert_eq!(simple_stem("builder"), "build");
938        assert_eq!(simple_stem("handler"), "handl");
939    }
940
941    #[test]
942    fn test_stem_er_short_preserved() {
943        // "her" → too short
944        assert_eq!(simple_stem("her"), "her");
945    }
946
947    #[test]
948    fn test_stem_est() {
949        assert_eq!(simple_stem("fastest"), "fast");
950        assert_eq!(simple_stem("largest"), "larg");
951    }
952
953    #[test]
954    fn test_stem_est_short_preserved() {
955        // "best" → base "b" only 1 char, need 4+
956        assert_eq!(simple_stem("best"), "best");
957        // "rest" → base "r" only 1 char
958        assert_eq!(simple_stem("rest"), "rest");
959    }
960
961    #[test]
962    fn test_stem_ies() {
963        assert_eq!(simple_stem("memories"), "memory");
964        assert_eq!(simple_stem("queries"), "query");
965        assert_eq!(simple_stem("entries"), "entry");
966    }
967
968    #[test]
969    fn test_stem_able_ible() {
970        assert_eq!(simple_stem("readable"), "read");
971        assert_eq!(simple_stem("searchable"), "search");
972        assert_eq!(simple_stem("flexible"), "flex");
973        assert_eq!(simple_stem("convertible"), "convert");
974    }
975
976    // ── Compound suffix tests ──────────────────────────────────────────
977
978    #[test]
979    fn test_stem_compound_ers() {
980        // "workers" → "work" (same as "worker" → "work")
981        assert_eq!(simple_stem("workers"), "work");
982        assert_eq!(simple_stem("builders"), "build");
983        assert_eq!(simple_stem("handlers"), "handl");
984    }
985
986    #[test]
987    fn test_stem_compound_ings() {
988        // "settings" base is only 4 chars, so -ings doesn't fire; -s strips to "setting"
989        assert_eq!(simple_stem("settings"), "setting");
990        // "buildings" base is 5 chars, so -ings fires → "build"
991        assert_eq!(simple_stem("buildings"), "build");
992        // "proceedings" base is 7 chars → "proceed"
993        assert_eq!(simple_stem("proceedings"), "proceed");
994    }
995
996    #[test]
997    fn test_stem_compound_tions() {
998        assert_eq!(simple_stem("connections"), "connec");
999        assert_eq!(simple_stem("collections"), "collec");
1000    }
1001
1002    #[test]
1003    fn test_stem_compound_ments() {
1004        assert_eq!(simple_stem("deployments"), "deploy");
1005        assert_eq!(simple_stem("environments"), "environ");
1006    }
1007
1008    #[test]
1009    fn test_stem_idempotent() {
1010        // Stemming an already-stemmed word should return the same result
1011        let words = [
1012            "thread", "process", "deploy", "dark", "quick", "work", "fast", "memory", "read",
1013            "search", "flex",
1014        ];
1015        for word in &words {
1016            let once = simple_stem(word);
1017            let twice = simple_stem(&once);
1018            assert_eq!(
1019                once, twice,
1020                "stem('{}') = '{}' but stem('{}') = '{}'",
1021                word, once, once, twice
1022            );
1023        }
1024    }
1025
1026    #[test]
1027    fn test_stem_never_below_3_chars() {
1028        // Verify we never produce a result shorter than 3 characters
1029        // for any input that is 3+ characters.
1030        let words = [
1031            "the", "ing", "bed", "red", "ant", "are", "ate", "use", "ring", "king", "sing", "dies",
1032            "ties",
1033        ];
1034        for word in &words {
1035            let stemmed = simple_stem(word);
1036            assert!(
1037                stemmed.len() >= word.len().min(3),
1038                "stem('{}') = '{}' is too short",
1039                word,
1040                stemmed
1041            );
1042        }
1043    }
1044
1045    // ── token_set stemming integration tests ───────────────────────────
1046
1047    #[test]
1048    fn test_token_set_stems_inflections() {
1049        // "threading" and "threads" should both stem to "thread"
1050        let a = token_set("threading issues", 3);
1051        let b = token_set("thread issues", 3);
1052        assert!(a.contains("thread"), "expected 'thread' in {:?}", a);
1053        assert!(b.contains("thread"), "expected 'thread' in {:?}", b);
1054    }
1055
1056    #[test]
1057    fn test_token_set_stemming_improves_overlap() {
1058        // Before stemming these wouldn't match; now they should
1059        let query = token_set("threading", 3);
1060        let text = token_set("threads are useful", 3);
1061        let overlap = word_overlap_pre(&query, &text);
1062        assert!(
1063            (overlap - 1.0).abs() < 1e-9,
1064            "expected overlap 1.0, got {}",
1065            overlap
1066        );
1067    }
1068
1069    #[test]
1070    fn test_token_set_stemming_jaccard() {
1071        // "deploying workers quickly" vs "deployment worker quick"
1072        // all three content words should match after stemming
1073        let a = token_set("deploying workers quickly", 3);
1074        let b = token_set("deployment worker quick", 3);
1075        let j = jaccard_pre(&a, &b);
1076        assert!(
1077            (j - 1.0).abs() < 1e-9,
1078            "expected Jaccard 1.0, got {} (a={:?}, b={:?})",
1079            j,
1080            a,
1081            b,
1082        );
1083    }
1084
1085    #[test]
1086    fn test_dual_match_boost_default() {
1087        let params = ScoringParams::default();
1088        assert!((params.dual_match_boost - 1.5).abs() < 1e-9);
1089    }
1090
1091    #[test]
1092    fn test_query_coverage_weight_default() {
1093        let params = ScoringParams::default();
1094        assert!((params.query_coverage_weight - 0.35).abs() < 1e-9);
1095    }
1096
1097    #[test]
1098    fn test_query_coverage_boost_quadratic_separation() {
1099        let params = ScoringParams::default();
1100        // high: 1.0 + 1.0^2 * 0.35 = 1.35
1101        // low:  1.0 + 0.3^2 * 0.35 = 1.0315
1102        let high = query_coverage_boost(1.0, &params);
1103        let low = query_coverage_boost(0.3, &params);
1104        assert!((high - 1.35).abs() < 1e-9);
1105        assert!((low - 1.0315).abs() < 1e-6);
1106        // High-coverage gets meaningfully more boost than low-coverage
1107        assert!(high > low);
1108        // The quadratic excess (above 1.0) scales as overlap^2:
1109        // excess_high / excess_low == 1.0^2 / 0.3^2 == 1/0.09 ≈ 11.1
1110        let excess_high = high - 1.0;
1111        let excess_low = low - 1.0;
1112        assert!(excess_high / excess_low > 10.0);
1113    }
1114
1115    #[test]
1116    fn test_query_coverage_boost_disabled() {
1117        let params = ScoringParams {
1118            query_coverage_weight: 0.0,
1119            ..ScoringParams::default()
1120        };
1121        // With weight=0, boost should be 1.0 regardless of overlap (no-op multiplier)
1122        assert!((query_coverage_boost(1.0, &params) - 1.0).abs() < 1e-9);
1123        assert!((query_coverage_boost(0.5, &params) - 1.0).abs() < 1e-9);
1124    }
1125
1126    // ── token_set numeric preservation tests ──────────────────────────
1127
1128    #[test]
1129    fn test_token_set_preserves_pure_numbers() {
1130        let tokens = token_set("version 42 was released", 3);
1131        assert!(tokens.contains("42"), "expected '42' in {:?}", tokens);
1132    }
1133
1134    #[test]
1135    fn test_token_set_preserves_year_numbers() {
1136        let tokens = token_set("deployed in 2023", 3);
1137        assert!(tokens.contains("2023"), "expected '2023' in {:?}", tokens);
1138    }
1139
1140    #[test]
1141    fn test_token_set_preserves_short_alphanumeric_with_digits() {
1142        // "v2", "3d" should be preserved even though len < 3
1143        let tokens = token_set("using v2 and 3d models", 3);
1144        assert!(tokens.contains("v2"), "expected 'v2' in {:?}", tokens);
1145        assert!(tokens.contains("3d"), "expected '3d' in {:?}", tokens);
1146    }
1147
1148    #[test]
1149    fn test_token_set_numeric_overlap() {
1150        // Ensure numeric tokens enable word overlap matching
1151        let query_tokens = token_set("version 42", 3);
1152        let text_tokens = token_set("released version 42 of the system", 3);
1153        let overlap = word_overlap_pre(&query_tokens, &text_tokens);
1154        assert!(
1155            overlap > 0.9,
1156            "expected high overlap for numeric query, got {}",
1157            overlap
1158        );
1159    }
1160
1161    // ── stopword filtering tests ──────────────────────────────────────
1162
1163    #[test]
1164    fn test_token_set_filters_stopwords() {
1165        let tokens = token_set("the path to the database", 3);
1166        assert!(
1167            !tokens.contains("the"),
1168            "'the' should be filtered as stopword, got {:?}",
1169            tokens
1170        );
1171        assert!(tokens.contains("path"), "expected 'path' in {:?}", tokens);
1172        assert!(
1173            tokens.contains("database"),
1174            "expected 'database' in {:?}",
1175            tokens
1176        );
1177    }
1178
    // ── boost constant & abstention threshold tests ───────────────────
1180
1181    #[test]
1182    fn test_entity_expansion_boost_value() {
1183        assert!((ENTITY_EXPANSION_BOOST - 1.15).abs() < 1e-9);
1184    }
1185
1186    #[test]
1187    fn test_abstention_threshold_lowered() {
1188        assert!(
1189            (ABSTENTION_MIN_TEXT - 0.15).abs() < 1e-9,
1190            "ABSTENTION_MIN_TEXT should be 0.15, got {}",
1191            ABSTENTION_MIN_TEXT
1192        );
1193    }
1194
1195    #[test]
1196    fn test_abstention_default_params_uses_lowered_threshold() {
1197        let params = ScoringParams::default();
1198        assert!(
1199            (params.abstention_min_text - 0.15).abs() < 1e-9,
1200            "default abstention_min_text should be 0.15, got {}",
1201            params.abstention_min_text
1202        );
1203    }
1204
1205    #[test]
1206    fn test_preceded_by_boost_default() {
1207        let params = ScoringParams::default();
1208        assert!((params.preceded_by_boost - 1.5).abs() < 1e-9);
1209    }
1210
1211    #[test]
1212    fn test_entity_relation_boost_default() {
1213        let params = ScoringParams::default();
1214        assert!((params.entity_relation_boost - 1.3).abs() < 1e-9);
1215    }
1216}