Skip to main content

bamboo_memory/memory_store/
recall.rs

1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use futures::StreamExt;
7use serde::Deserialize;
8
9use bamboo_agent_core::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_llm::{LLMChunk, LLMProvider, LLMRequestOptions};
12
13use super::{
14    extract_keywords, parse_rfc3339, DurableMemoryStatus, LexicalIndexItem, MemoryScope,
15    MemoryStore, TemporalGranularity,
16};
17
18#[derive(Debug, Clone, PartialEq)]
19pub struct MemoryRecallCandidate {
20    pub id: String,
21    pub title: String,
22    pub score: f64,
23    pub scope: MemoryScope,
24    pub project_key: Option<String>,
25    pub status: DurableMemoryStatus,
26    pub updated_at: String,
27    pub summary: String,
28    /// Optional temporal granularity, used as a stable tie-breaker in
29    /// [`sort_recall_candidates`]: among equally-relevant candidates, coarser
30    /// (more cache-stable) memories sort first. `None` is treated as most stable.
31    pub granularity: Option<TemporalGranularity>,
32}
33
/// Tuning knobs for memory recall.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryRecallOptions {
    /// Maximum number of candidates handed back to the caller. Values below 1
    /// are treated as 1 by the recall functions.
    pub shortlist_limit: usize,
    /// When `true`, the global scope is searched whenever the project scope
    /// yields no hits (or no usable project key was supplied).
    pub include_global_fallback: bool,
    /// Cap on candidates kept per scope before the final truncation. The recall
    /// functions raise it to `shortlist_limit` when it is smaller.
    pub max_candidates_per_scope: usize,
}

impl Default for MemoryRecallOptions {
    /// Returns the defaults: a shortlist of 3, global fallback enabled, and up
    /// to 20 candidates per scope.
    fn default() -> Self {
        MemoryRecallOptions {
            max_candidates_per_scope: 20,
            include_global_fallback: true,
            shortlist_limit: 3,
        }
    }
}
50
/// How the candidates in a [`MemoryRecallSelection`] were chosen.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryRecallStrategy {
    /// Lexical ordering was used and no rerank was attempted.
    Lexical,
    /// The LLM reranker produced the ordering.
    Reranked,
    /// A rerank was attempted but failed or produced nothing usable, so the
    /// lexical ordering was used instead.
    RerankFallback,
}

impl MemoryRecallStrategy {
    /// Returns the stable snake_case label for this strategy.
    pub fn as_str(self) -> &'static str {
        match self {
            MemoryRecallStrategy::RerankFallback => "rerank_fallback",
            MemoryRecallStrategy::Reranked => "reranked",
            MemoryRecallStrategy::Lexical => "lexical",
        }
    }
}
67
68#[derive(Debug, Clone, PartialEq)]
69pub struct MemoryRecallSelection {
70    pub candidates: Vec<MemoryRecallCandidate>,
71    pub strategy: MemoryRecallStrategy,
72}
73
74#[derive(Clone)]
75pub struct MemoryRecallRerankContext {
76    pub llm: Arc<dyn LLMProvider>,
77    pub model: String,
78    pub session_id: Option<String>,
79}
80
81#[derive(Debug, Deserialize)]
82struct MemoryRecallRerankEnvelope {
83    #[serde(default)]
84    ids: Vec<String>,
85}
86
87pub async fn shortlist_relevant_memories(
88    store: &MemoryStore,
89    project_key: Option<&str>,
90    query: &str,
91    options: &MemoryRecallOptions,
92) -> io::Result<Vec<MemoryRecallCandidate>> {
93    let limit = options.shortlist_limit.max(1);
94    let mut candidates =
95        lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
96    candidates.truncate(limit);
97    Ok(candidates)
98}
99
100pub async fn select_relevant_memories(
101    store: &MemoryStore,
102    project_key: Option<&str>,
103    query: &str,
104    options: &MemoryRecallOptions,
105    rerank_context: Option<&MemoryRecallRerankContext>,
106) -> io::Result<MemoryRecallSelection> {
107    let query = query.trim();
108    if query.is_empty() {
109        return Ok(MemoryRecallSelection {
110            candidates: Vec::new(),
111            strategy: MemoryRecallStrategy::Lexical,
112        });
113    }
114
115    let limit = options.shortlist_limit.max(1);
116    let mut shortlist =
117        lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
118    if shortlist.is_empty() {
119        return Ok(MemoryRecallSelection {
120            candidates: shortlist,
121            strategy: MemoryRecallStrategy::Lexical,
122        });
123    }
124
125    let Some(rerank_context) = rerank_context else {
126        shortlist.truncate(limit);
127        return Ok(MemoryRecallSelection {
128            candidates: shortlist,
129            strategy: MemoryRecallStrategy::Lexical,
130        });
131    };
132
133    if shortlist.len() <= 1 {
134        shortlist.truncate(limit);
135        return Ok(MemoryRecallSelection {
136            candidates: shortlist,
137            strategy: MemoryRecallStrategy::Lexical,
138        });
139    }
140
141    match rerank_candidate_ids(query, &shortlist, limit, rerank_context).await {
142        Ok(ids) => {
143            let reranked = reorder_candidates_by_ids(&shortlist, &ids, limit);
144            if reranked.is_empty() {
145                let mut lexical = shortlist;
146                lexical.truncate(limit);
147                return Ok(MemoryRecallSelection {
148                    candidates: lexical,
149                    strategy: MemoryRecallStrategy::RerankFallback,
150                });
151            }
152            Ok(MemoryRecallSelection {
153                candidates: reranked,
154                strategy: MemoryRecallStrategy::Reranked,
155            })
156        }
157        Err(error) => {
158            tracing::warn!(
159                "Relevant memory rerank failed for model '{}': {}. Falling back to lexical shortlist.",
160                rerank_context.model,
161                error
162            );
163            shortlist.truncate(limit);
164            Ok(MemoryRecallSelection {
165                candidates: shortlist,
166                strategy: MemoryRecallStrategy::RerankFallback,
167            })
168        }
169    }
170}
171
172async fn lexical_shortlist_relevant_memories(
173    store: &MemoryStore,
174    project_key: Option<&str>,
175    query: &str,
176    options: &MemoryRecallOptions,
177) -> io::Result<Vec<MemoryRecallCandidate>> {
178    let query = query.trim();
179    if query.is_empty() {
180        return Ok(Vec::new());
181    }
182
183    let limit = options.shortlist_limit.max(1);
184    let per_scope_limit = options.max_candidates_per_scope.max(limit);
185
186    if let Some(project_key) = project_key.map(str::trim).filter(|value| !value.is_empty()) {
187        let mut project_hits =
188            shortlist_scope(store, MemoryScope::Project, Some(project_key), query).await?;
189        project_hits.truncate(per_scope_limit);
190        if !project_hits.is_empty() {
191            return Ok(project_hits);
192        }
193    }
194
195    if options.include_global_fallback {
196        let mut global_hits = shortlist_scope(store, MemoryScope::Global, None, query).await?;
197        global_hits.truncate(per_scope_limit);
198        return Ok(global_hits);
199    }
200
201    Ok(Vec::new())
202}
203
204async fn shortlist_scope(
205    store: &MemoryStore,
206    scope: MemoryScope,
207    project_key: Option<&str>,
208    query: &str,
209) -> io::Result<Vec<MemoryRecallCandidate>> {
210    let Some(index) = store.read_lexical_index(scope, project_key).await? else {
211        return Ok(Vec::new());
212    };
213
214    let query_tokens = extract_keywords(query, "", &[]);
215    if query_tokens.is_empty() {
216        return Ok(Vec::new());
217    }
218
219    let mut candidates = index
220        .items
221        .iter()
222        .filter_map(|item| score_lexical_index_item(item, &query_tokens).map(|score| (item, score)))
223        .map(|(item, score)| MemoryRecallCandidate {
224            id: item.id.clone(),
225            title: item.title.clone(),
226            score,
227            scope: item.scope,
228            project_key: item.project_key.clone(),
229            status: item.status,
230            updated_at: item.updated_at.clone(),
231            summary: item.summary.clone(),
232            granularity: item.granularity,
233        })
234        .collect::<Vec<_>>();
235
236    sort_recall_candidates(&mut candidates);
237    Ok(candidates)
238}
239
240fn score_lexical_index_item(item: &LexicalIndexItem, query_tokens: &[String]) -> Option<f64> {
241    match item.status {
242        DurableMemoryStatus::Superseded
243        | DurableMemoryStatus::Contradicted
244        | DurableMemoryStatus::Archived => return None,
245        DurableMemoryStatus::Active | DurableMemoryStatus::Stale => {}
246    }
247
248    let title = item.title.to_ascii_lowercase();
249    let summary = item.summary.to_ascii_lowercase();
250
251    let mut score = 0.0;
252    let mut matched_any = false;
253
254    for token in query_tokens {
255        let mut token_score = 0.0;
256        if title.contains(token) {
257            token_score += 3.0;
258        }
259        if item
260            .keywords
261            .iter()
262            .any(|value| value.eq_ignore_ascii_case(token))
263        {
264            token_score += 2.5;
265        }
266        if item
267            .tags
268            .iter()
269            .any(|value| value.eq_ignore_ascii_case(token))
270        {
271            token_score += 2.0;
272        }
273        if item
274            .entities
275            .iter()
276            .any(|value| value.eq_ignore_ascii_case(token))
277        {
278            token_score += 1.5;
279        }
280        if summary.contains(token) {
281            token_score += 1.0;
282        }
283        if token_score > 0.0 {
284            matched_any = true;
285            score += token_score;
286        }
287    }
288
289    if !matched_any {
290        return None;
291    }
292
293    score += lexical_status_adjustment(item.status);
294    Some((score / query_tokens.len() as f64 * 100.0).round() / 100.0)
295}
296
297fn lexical_status_adjustment(status: DurableMemoryStatus) -> f64 {
298    match status {
299        DurableMemoryStatus::Active => 0.0,
300        DurableMemoryStatus::Stale => -0.75,
301        DurableMemoryStatus::Superseded
302        | DurableMemoryStatus::Contradicted
303        | DurableMemoryStatus::Archived => -10.0,
304    }
305}
306
307fn sort_recall_candidates(candidates: &mut [MemoryRecallCandidate]) {
308    candidates.sort_by(|left, right| {
309        right
310            .score
311            .partial_cmp(&left.score)
312            .unwrap_or(Ordering::Equal)
313            // Among equally-relevant candidates, prefer coarser (more prefix-cache
314            // friendly) granularity first so the recalled block is stable across
315            // calls and does not churn the LLM prompt prefix (issue #61). Lower
316            // cache_stability_rank = coarser = earlier.
317            .then_with(|| {
318                TemporalGranularity::cache_stability_rank(left.granularity).cmp(
319                    &TemporalGranularity::cache_stability_rank(right.granularity),
320                )
321            })
322            .then_with(|| {
323                let left_dt = parse_rfc3339(&left.updated_at)
324                    .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
325                let right_dt = parse_rfc3339(&right.updated_at)
326                    .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
327                right_dt.cmp(&left_dt)
328            })
329            .then_with(|| left.title.cmp(&right.title))
330    });
331}
332
333fn build_rerank_prompt(query: &str, candidates: &[MemoryRecallCandidate], limit: usize) -> String {
334    let mut prompt = String::from("# Bamboo Relevant Memory Recall Rerank\n\n");
335    prompt.push_str(
336        "Select the durable memory candidates that are most relevant to the user query.\n",
337    );
338    prompt.push_str("Return JSON only in the form {\"ids\":[\"candidate-id\", ...]}.\n");
339    prompt
340        .push_str("Do not include commentary, markdown fences, explanations, or unknown ids.\n\n");
341    prompt.push_str("## User query\n");
342    prompt.push_str(query.trim());
343    prompt.push_str("\n\n## Candidate memories\n");
344
345    for (index, candidate) in candidates.iter().enumerate() {
346        prompt.push_str(&format!(
347            "{}. id={}\n   title: {}\n   scope: {}\n   status: {}\n   updated_at: {}\n   lexical_score: {:.2}\n   summary: {}\n",
348            index + 1,
349            candidate.id,
350            candidate.title,
351            candidate.scope.as_str(),
352            candidate.status.as_str(),
353            candidate.updated_at,
354            candidate.score,
355            candidate.summary.replace('\n', " "),
356        ));
357    }
358
359    prompt.push_str(&format!(
360        "\n## Selection rules\n- Return at most {limit} ids.\n- Use only ids from the candidate list above.\n- Prefer candidates that best answer the user query or encode active preferences/constraints relevant to it.\n- Prefer active memories over stale ones when relevance is otherwise similar.\n- Keep the ids ordered best-to-worst.\n"
361    ));
362    prompt
363}
364
365async fn rerank_candidate_ids(
366    query: &str,
367    candidates: &[MemoryRecallCandidate],
368    limit: usize,
369    context: &MemoryRecallRerankContext,
370) -> Result<Vec<String>, String> {
371    let model = context.model.trim();
372    if model.is_empty() {
373        return Err("rerank model is empty".to_string());
374    }
375
376    let messages = vec![
377        Message::system(
378            "You rerank Bamboo durable-memory recall candidates. Return strict JSON only in the form {\"ids\":[...]} using only candidate ids from the prompt.",
379        ),
380        Message::user(build_rerank_prompt(query, candidates, limit)),
381    ];
382    let options = LLMRequestOptions {
383        session_id: context.session_id.clone(),
384        reasoning_effort: Some(ReasoningEffort::High),
385        parallel_tool_calls: None,
386        responses: None,
387        request_purpose: Some("memory_rerank".to_string()),
388        cache: None,
389    };
390
391    let mut stream = context
392        .llm
393        .chat_stream_with_options(&messages, &[], Some(8192), model, Some(&options))
394        .await
395        .map_err(|error| format!("rerank provider call failed: {error}"))?;
396
397    let content = tokio::time::timeout(std::time::Duration::from_secs(30), async {
398        let mut content = String::new();
399        while let Some(chunk_result) = stream.next().await {
400            match chunk_result {
401                Ok(LLMChunk::Token(text)) => content.push_str(&text),
402                Ok(LLMChunk::Done) => break,
403                Ok(_) => {}
404                Err(error) => {
405                    if !content.trim().is_empty() {
406                        break;
407                    }
408                    return Err(format!("rerank stream failed: {error}"));
409                }
410            }
411        }
412        Ok(content)
413    })
414    .await
415    .unwrap_or_else(|_| Err("rerank timed out after 30s".to_string()))?;
416
417    parse_reranked_ids(&content, candidates)
418        .ok_or_else(|| format!("failed to parse rerank response: {}", content.trim()))
419}
420
421fn reorder_candidates_by_ids(
422    lexical_candidates: &[MemoryRecallCandidate],
423    preferred_ids: &[String],
424    limit: usize,
425) -> Vec<MemoryRecallCandidate> {
426    if lexical_candidates.is_empty() || limit == 0 {
427        return Vec::new();
428    }
429
430    let allowed = lexical_candidates
431        .iter()
432        .map(|candidate| candidate.id.as_str())
433        .collect::<HashSet<_>>();
434    let mut seen = HashSet::new();
435    let mut ordered = Vec::new();
436
437    for id in preferred_ids {
438        let trimmed = id.trim();
439        if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
440            continue;
441        }
442        if let Some(candidate) = lexical_candidates
443            .iter()
444            .find(|candidate| candidate.id == trimmed)
445            .cloned()
446        {
447            ordered.push(candidate);
448            if ordered.len() >= limit {
449                return ordered;
450            }
451        }
452    }
453
454    for candidate in lexical_candidates {
455        if seen.insert(candidate.id.clone()) {
456            ordered.push(candidate.clone());
457            if ordered.len() >= limit {
458                break;
459            }
460        }
461    }
462
463    ordered
464}
465
466fn parse_reranked_ids(raw: &str, candidates: &[MemoryRecallCandidate]) -> Option<Vec<String>> {
467    let stripped = strip_markdown_fence(raw);
468    let fragment = extract_json_fragment(&stripped).unwrap_or(stripped.trim());
469    let ids = serde_json::from_str::<MemoryRecallRerankEnvelope>(fragment)
470        .map(|value| value.ids)
471        .or_else(|_| serde_json::from_str::<Vec<String>>(fragment))
472        .ok()?;
473
474    let allowed = candidates
475        .iter()
476        .map(|candidate| candidate.id.as_str())
477        .collect::<HashSet<_>>();
478    let mut seen = HashSet::new();
479    let mut out = Vec::new();
480
481    for id in ids {
482        let trimmed = id.trim();
483        if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
484            continue;
485        }
486        out.push(trimmed.to_string());
487    }
488
489    (!out.is_empty()).then_some(out)
490}
491
/// Removes a surrounding markdown code fence from `raw`, if there is one.
///
/// The trimmed input must start with a four- or three-backtick fence whose
/// opening line ends in a newline, and the same fence must occur again later.
/// The text between the opening line and the last occurrence of the fence is
/// returned, trimmed. In every other case the trimmed input is returned as is.
fn strip_markdown_fence(raw: &str) -> String {
    let text = raw.trim();
    // Try the longer fence first so four backticks are not read as three plus one.
    ["````", "```"]
        .iter()
        .find_map(|&fence| {
            let after_open = text.strip_prefix(fence)?;
            // Drop the rest of the opening line (e.g. the `json` info string).
            let (_info, body) = after_open.split_once('\n')?;
            let close_at = body.rfind(fence)?;
            Some(body[..close_at].trim())
        })
        .unwrap_or(text)
        .to_string()
}
507
/// Finds the outermost JSON-looking span in `raw`.
///
/// Looks first for the span from the first `{` to the last `}`, then for the
/// span from the first `[` to the last `]`. A pair is used only when the
/// opener does not come after the closer. Returns `None` when neither pair
/// qualifies, including for blank input.
fn extract_json_fragment(raw: &str) -> Option<&str> {
    let text = raw.trim();
    [('{', '}'), ('[', ']')].iter().find_map(|&(open, close)| {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (start <= end).then(|| text[start..=end].trim())
    })
}
528
529#[cfg(test)]
530mod tests {
531    use super::*;
532    use crate::memory_store::DurableMemoryType;
533    use async_trait::async_trait;
534    use bamboo_domain::ReasoningEffort;
535    use bamboo_llm::provider::LLMRequestOptions;
536    use bamboo_llm::{LLMChunk, LLMError, LLMProvider, LLMStream};
537    use futures::stream;
538    use std::sync::Mutex;
539    use tempfile::tempdir;
540
541    #[allow(clippy::too_many_arguments)]
542    fn item(
543        id: &str,
544        title: &str,
545        status: DurableMemoryStatus,
546        updated_at: &str,
547        keywords: &[&str],
548        tags: &[&str],
549        entities: &[&str],
550        summary: &str,
551    ) -> LexicalIndexItem {
552        LexicalIndexItem {
553            id: id.to_string(),
554            title: title.to_string(),
555            scope: MemoryScope::Project,
556            project_key: Some("proj-1".to_string()),
557            r#type: DurableMemoryType::Project,
558            status,
559            tags: tags.iter().map(|v| v.to_string()).collect(),
560            keywords: keywords.iter().map(|v| v.to_string()).collect(),
561            entities: entities.iter().map(|v| v.to_string()).collect(),
562            updated_at: updated_at.to_string(),
563            created_at: updated_at.to_string(),
564            summary: summary.to_string(),
565            granularity: None,
566        }
567    }
568
569    #[derive(Clone)]
570    struct StaticResponseProvider {
571        response: String,
572        requested_models: Arc<Mutex<Vec<String>>>,
573    }
574
575    impl StaticResponseProvider {
576        fn new(response: impl Into<String>) -> Self {
577            Self {
578                response: response.into(),
579                requested_models: Arc::new(Mutex::new(Vec::new())),
580            }
581        }
582    }
583
584    #[async_trait]
585    impl LLMProvider for StaticResponseProvider {
586        async fn chat_stream(
587            &self,
588            _messages: &[Message],
589            _tools: &[bamboo_agent_core::ToolSchema],
590            _max_output_tokens: Option<u32>,
591            model: &str,
592        ) -> Result<LLMStream, LLMError> {
593            self.requested_models
594                .lock()
595                .expect("lock poisoned")
596                .push(model.to_string());
597            Ok(Box::pin(stream::iter(vec![
598                Ok(LLMChunk::Token(self.response.clone())),
599                Ok(LLMChunk::Done),
600            ])))
601        }
602    }
603
604    fn candidate(
605        id: &str,
606        score: f64,
607        granularity: Option<TemporalGranularity>,
608    ) -> MemoryRecallCandidate {
609        MemoryRecallCandidate {
610            id: id.to_string(),
611            title: id.to_string(),
612            score,
613            scope: MemoryScope::Project,
614            project_key: Some("proj-1".to_string()),
615            status: DurableMemoryStatus::Active,
616            // Same timestamp for all so granularity is the deciding tie-breaker.
617            updated_at: "2026-04-09T00:00:00Z".to_string(),
618            summary: "summary".to_string(),
619            granularity,
620        }
621    }
622
623    #[test]
624    fn equal_score_candidates_sort_coarse_granularity_first_for_cache_stability() {
625        // All same score + same timestamp → granularity decides ordering. Coarser
626        // (year) is more prefix-cache friendly and must sort ahead of finer (day).
627        let mut candidates = vec![
628            candidate("day", 5.0, Some(TemporalGranularity::Day)),
629            candidate("year", 5.0, Some(TemporalGranularity::Year)),
630            candidate("none", 5.0, None),
631            candidate("month", 5.0, Some(TemporalGranularity::Month)),
632        ];
633        sort_recall_candidates(&mut candidates);
634        let order: Vec<&str> = candidates.iter().map(|c| c.id.as_str()).collect();
635        // None is treated as most stable (rank 0), then year, month, day.
636        assert_eq!(order, vec!["none", "year", "month", "day"]);
637    }
638
639    #[test]
640    fn higher_score_still_wins_over_cache_stable_granularity() {
641        // Granularity is only a tie-breaker: a more relevant (higher score) day-level
642        // memory must still outrank a less relevant year-level one.
643        let mut candidates = vec![
644            candidate("year-low", 1.0, Some(TemporalGranularity::Year)),
645            candidate("day-high", 9.0, Some(TemporalGranularity::Day)),
646        ];
647        sort_recall_candidates(&mut candidates);
648        assert_eq!(candidates[0].id, "day-high");
649    }
650
651    #[test]
652    fn title_matches_outrank_keyword_only_matches() {
653        let query_tokens = vec!["release".to_string(), "freeze".to_string()];
654        let title_item = item(
655            "a",
656            "Release freeze decision",
657            DurableMemoryStatus::Active,
658            "2026-04-09T00:00:00Z",
659            &[],
660            &[],
661            &[],
662            "summary",
663        );
664        let keyword_item = item(
665            "b",
666            "Deployment decision",
667            DurableMemoryStatus::Active,
668            "2026-04-09T00:00:00Z",
669            &["release", "freeze"],
670            &[],
671            &[],
672            "summary",
673        );
674
675        let title_score = score_lexical_index_item(&title_item, &query_tokens).unwrap();
676        let keyword_score = score_lexical_index_item(&keyword_item, &query_tokens).unwrap();
677        assert!(title_score > keyword_score);
678    }
679
680    #[test]
681    fn active_items_outrank_stale_items() {
682        let query_tokens = vec!["release".to_string()];
683        let active = item(
684            "a",
685            "Release freeze decision",
686            DurableMemoryStatus::Active,
687            "2026-04-09T00:00:00Z",
688            &[],
689            &[],
690            &[],
691            "summary",
692        );
693        let stale = item(
694            "b",
695            "Release freeze decision",
696            DurableMemoryStatus::Stale,
697            "2026-04-10T00:00:00Z",
698            &[],
699            &[],
700            &[],
701            "summary",
702        );
703
704        let active_score = score_lexical_index_item(&active, &query_tokens).unwrap();
705        let stale_score = score_lexical_index_item(&stale, &query_tokens).unwrap();
706        assert!(active_score > stale_score);
707    }
708
709    #[test]
710    fn contradicted_and_archived_items_are_filtered_out() {
711        let query_tokens = vec!["release".to_string()];
712        let contradicted = item(
713            "a",
714            "Release freeze decision",
715            DurableMemoryStatus::Contradicted,
716            "2026-04-09T00:00:00Z",
717            &[],
718            &[],
719            &[],
720            "summary",
721        );
722        let archived = item(
723            "b",
724            "Release freeze decision",
725            DurableMemoryStatus::Archived,
726            "2026-04-09T00:00:00Z",
727            &[],
728            &[],
729            &[],
730            "summary",
731        );
732
733        assert!(score_lexical_index_item(&contradicted, &query_tokens).is_none());
734        assert!(score_lexical_index_item(&archived, &query_tokens).is_none());
735    }
736
737    #[test]
738    fn parse_reranked_ids_accepts_fenced_json_and_filters_unknown_ids() {
739        let candidates = vec![
740            MemoryRecallCandidate {
741                id: "mem-a".to_string(),
742                title: "A".to_string(),
743                score: 10.0,
744                scope: MemoryScope::Project,
745                project_key: Some("proj-1".to_string()),
746                status: DurableMemoryStatus::Active,
747                updated_at: "2026-04-09T00:00:00Z".to_string(),
748                summary: "summary a".to_string(),
749                granularity: None,
750            },
751            MemoryRecallCandidate {
752                id: "mem-b".to_string(),
753                title: "B".to_string(),
754                score: 9.0,
755                scope: MemoryScope::Project,
756                project_key: Some("proj-1".to_string()),
757                status: DurableMemoryStatus::Active,
758                updated_at: "2026-04-09T00:00:00Z".to_string(),
759                summary: "summary b".to_string(),
760                granularity: None,
761            },
762        ];
763
764        let parsed = parse_reranked_ids(
765            "```json\n{\"ids\":[\"mem-b\",\"unknown\",\"mem-a\",\"mem-b\"]}\n```",
766            &candidates,
767        )
768        .expect("reranked ids should parse");
769
770        assert_eq!(parsed, vec!["mem-b".to_string(), "mem-a".to_string()]);
771    }
772
773    #[test]
774    fn reorder_candidates_by_ids_appends_remaining_lexical_candidates() {
775        let lexical = vec![
776            MemoryRecallCandidate {
777                id: "mem-a".to_string(),
778                title: "A".to_string(),
779                score: 10.0,
780                scope: MemoryScope::Project,
781                project_key: Some("proj-1".to_string()),
782                status: DurableMemoryStatus::Active,
783                updated_at: "2026-04-09T00:00:00Z".to_string(),
784                summary: "summary a".to_string(),
785                granularity: None,
786            },
787            MemoryRecallCandidate {
788                id: "mem-b".to_string(),
789                title: "B".to_string(),
790                score: 9.0,
791                scope: MemoryScope::Project,
792                project_key: Some("proj-1".to_string()),
793                status: DurableMemoryStatus::Active,
794                updated_at: "2026-04-09T00:00:00Z".to_string(),
795                summary: "summary b".to_string(),
796                granularity: None,
797            },
798            MemoryRecallCandidate {
799                id: "mem-c".to_string(),
800                title: "C".to_string(),
801                score: 8.0,
802                scope: MemoryScope::Project,
803                project_key: Some("proj-1".to_string()),
804                status: DurableMemoryStatus::Active,
805                updated_at: "2026-04-09T00:00:00Z".to_string(),
806                summary: "summary c".to_string(),
807                granularity: None,
808            },
809        ];
810
811        let reordered =
812            reorder_candidates_by_ids(&lexical, &["mem-c".to_string(), "mem-a".to_string()], 3);
813
814        assert_eq!(reordered[0].id, "mem-c");
815        assert_eq!(reordered[1].id, "mem-a");
816        assert_eq!(reordered[2].id, "mem-b");
817    }
818
819    #[tokio::test]
820    async fn project_scope_shortlist_excludes_global_when_project_hits_exist() {
821        let dir = tempdir().unwrap();
822        let store = MemoryStore::new(dir.path());
823
824        store
825            .write_memory(
826                MemoryScope::Project,
827                Some("proj-1"),
828                DurableMemoryType::Project,
829                "Release freeze decision",
830                "Project-specific release freeze note.",
831                &["release".to_string()],
832                Some("session-1"),
833                "main-model",
834                false,
835                None,
836            )
837            .await
838            .unwrap();
839        store
840            .write_memory(
841                MemoryScope::Global,
842                None,
843                DurableMemoryType::Reference,
844                "Global release guidance",
845                "Global note that should not be used when project hits exist.",
846                &["release".to_string()],
847                Some("session-1"),
848                "main-model",
849                false,
850                None,
851            )
852            .await
853            .unwrap();
854
855        let candidates = shortlist_relevant_memories(
856            &store,
857            Some("proj-1"),
858            "release freeze",
859            &MemoryRecallOptions::default(),
860        )
861        .await
862        .unwrap();
863
864        assert!(!candidates.is_empty());
865        assert!(candidates
866            .iter()
867            .all(|candidate| candidate.scope == MemoryScope::Project));
868    }
869
870    #[tokio::test]
871    async fn global_fallback_triggers_only_when_project_hits_are_absent() {
872        let dir = tempdir().unwrap();
873        let store = MemoryStore::new(dir.path());
874
875        store
876            .write_memory(
877                MemoryScope::Global,
878                None,
879                DurableMemoryType::Reference,
880                "Global release guidance",
881                "Fallback note for release work.",
882                &["release".to_string()],
883                Some("session-1"),
884                "main-model",
885                false,
886                None,
887            )
888            .await
889            .unwrap();
890
891        let candidates = shortlist_relevant_memories(
892            &store,
893            Some("proj-missing"),
894            "release guidance",
895            &MemoryRecallOptions::default(),
896        )
897        .await
898        .unwrap();
899
900        assert!(!candidates.is_empty());
901        assert!(candidates
902            .iter()
903            .all(|candidate| candidate.scope == MemoryScope::Global));
904    }
905
906    #[tokio::test]
907    async fn model_rerank_reorders_lexical_shortlist_when_enabled() {
908        let dir = tempdir().unwrap();
909        let store = MemoryStore::new(dir.path());
910
911        let lexical_first = store
912            .write_memory(
913                MemoryScope::Project,
914                Some("proj-1"),
915                DurableMemoryType::Project,
916                "Release freeze checklist",
917                "Generic release freeze checklist for shipping work.",
918                &["release".to_string(), "freeze".to_string()],
919                Some("session-1"),
920                "main-model",
921                false,
922                None,
923            )
924            .await
925            .unwrap();
926        let reranked_first = store
927            .write_memory(
928                MemoryScope::Project,
929                Some("proj-1"),
930                DurableMemoryType::Project,
931                "Mobile launch blocker",
932                "This durable note captures the release freeze decision for the mobile app and should be preferred for mobile freeze requests.",
933                &["mobile".to_string(), "launch".to_string()],
934                Some("session-1"),
935                "main-model",
936                false,
937                None,
938            )
939            .await
940            .unwrap();
941
942        let provider = StaticResponseProvider::new(format!(
943            "{{\"ids\":[\"{}\",\"{}\"]}}",
944            reranked_first.frontmatter.id, lexical_first.frontmatter.id
945        ));
946        let requested_models = provider.requested_models.clone();
947        let selection = select_relevant_memories(
948            &store,
949            Some("proj-1"),
950            "release freeze for mobile",
951            &MemoryRecallOptions {
952                shortlist_limit: 2,
953                include_global_fallback: false,
954                max_candidates_per_scope: 12,
955            },
956            Some(&MemoryRecallRerankContext {
957                llm: Arc::new(provider),
958                model: "rerank-fast-model".to_string(),
959                session_id: Some("session-1".to_string()),
960            }),
961        )
962        .await
963        .unwrap();
964
965        assert_eq!(selection.strategy, MemoryRecallStrategy::Reranked);
966        assert_eq!(selection.candidates.len(), 2);
967        assert_eq!(selection.candidates[0].id, reranked_first.frontmatter.id);
968        assert_eq!(selection.candidates[1].id, lexical_first.frontmatter.id);
969        assert_eq!(
970            requested_models.lock().expect("lock poisoned").as_slice(),
971            ["rerank-fast-model"]
972        );
973    }
974
975    #[tokio::test]
976    async fn invalid_model_rerank_response_falls_back_to_lexical_order() {
977        let dir = tempdir().unwrap();
978        let store = MemoryStore::new(dir.path());
979
980        let lexical_first = store
981            .write_memory(
982                MemoryScope::Project,
983                Some("proj-1"),
984                DurableMemoryType::Project,
985                "Release freeze checklist",
986                "Generic release freeze checklist for shipping work.",
987                &["release".to_string(), "freeze".to_string()],
988                Some("session-1"),
989                "main-model",
990                false,
991                None,
992            )
993            .await
994            .unwrap();
995        let lexical_second = store
996            .write_memory(
997                MemoryScope::Project,
998                Some("proj-1"),
999                DurableMemoryType::Project,
1000                "Mobile launch blocker",
1001                "This durable note captures the release freeze decision for the mobile app.",
1002                &["mobile".to_string(), "launch".to_string()],
1003                Some("session-1"),
1004                "main-model",
1005                false,
1006                None,
1007            )
1008            .await
1009            .unwrap();
1010
1011        let selection = select_relevant_memories(
1012            &store,
1013            Some("proj-1"),
1014            "release freeze for mobile",
1015            &MemoryRecallOptions {
1016                shortlist_limit: 2,
1017                include_global_fallback: false,
1018                max_candidates_per_scope: 12,
1019            },
1020            Some(&MemoryRecallRerankContext {
1021                llm: Arc::new(StaticResponseProvider::new("not valid json")),
1022                model: "rerank-fast-model".to_string(),
1023                session_id: Some("session-1".to_string()),
1024            }),
1025        )
1026        .await
1027        .unwrap();
1028
1029        assert_eq!(selection.strategy, MemoryRecallStrategy::RerankFallback);
1030        assert_eq!(selection.candidates.len(), 2);
1031        assert_eq!(selection.candidates[0].id, lexical_first.frontmatter.id);
1032        assert_eq!(selection.candidates[1].id, lexical_second.frontmatter.id);
1033    }
1034
1035    /// Provider that captures `max_output_tokens` and `reasoning_effort` from `chat_stream_with_options`.
1036    #[derive(Default)]
1037    struct RequestOptionsCaptureProvider {
1038        captured_max_tokens: Mutex<Vec<Option<u32>>>,
1039        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
1040    }
1041
1042    #[async_trait]
1043    impl LLMProvider for RequestOptionsCaptureProvider {
1044        async fn chat_stream(
1045            &self,
1046            _messages: &[Message],
1047            _tools: &[bamboo_agent_core::ToolSchema],
1048            _max_output_tokens: Option<u32>,
1049            _model: &str,
1050        ) -> Result<LLMStream, LLMError> {
1051            Ok(Box::pin(stream::iter(vec![
1052                Ok(LLMChunk::Token("{\"ids\":[]}".to_string())),
1053                Ok(LLMChunk::Done),
1054            ])))
1055        }
1056
1057        async fn chat_stream_with_options(
1058            &self,
1059            messages: &[Message],
1060            tools: &[bamboo_agent_core::ToolSchema],
1061            max_output_tokens: Option<u32>,
1062            model: &str,
1063            options: Option<&LLMRequestOptions>,
1064        ) -> Result<LLMStream, LLMError> {
1065            self.captured_max_tokens
1066                .lock()
1067                .expect("lock should not be poisoned")
1068                .push(max_output_tokens);
1069            self.captured_reasoning
1070                .lock()
1071                .expect("lock should not be poisoned")
1072                .push(options.and_then(|o| o.reasoning_effort));
1073            self.chat_stream(messages, tools, max_output_tokens, model)
1074                .await
1075        }
1076    }
1077
1078    #[tokio::test]
1079    async fn rerank_sufficient_max_tokens_for_high_reasoning() {
1080        let provider = Arc::new(RequestOptionsCaptureProvider::default());
1081        let candidates = vec![MemoryRecallCandidate {
1082            id: "mem-1".to_string(),
1083            score: 0.9,
1084            title: "Test memory".to_string(),
1085            scope: MemoryScope::Project,
1086            project_key: Some("proj-1".to_string()),
1087            status: DurableMemoryStatus::Active,
1088            updated_at: "2026-05-08T00:00:00Z".to_string(),
1089            summary: "A test durable memory entry".to_string(),
1090            granularity: None,
1091        }];
1092        let context = MemoryRecallRerankContext {
1093            llm: provider.clone(),
1094            model: "deepseek-v4-pro".to_string(),
1095            session_id: Some("test-session".to_string()),
1096        };
1097
1098        let _ = rerank_candidate_ids("test query", &candidates, 5, &context).await;
1099
1100        let captured_reasoning = provider
1101            .captured_reasoning
1102            .lock()
1103            .expect("lock should not be poisoned");
1104        let captured_max_tokens = provider
1105            .captured_max_tokens
1106            .lock()
1107            .expect("lock should not be poisoned");
1108        assert_eq!(
1109            captured_reasoning.as_slice(),
1110            [Some(ReasoningEffort::High)],
1111            "rerank should request High reasoning"
1112        );
1113        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
1114        assert!(
1115            max_tokens > 4096,
1116            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
1117            max_tokens
1118        );
1119    }
1120}