Skip to main content

bamboo_memory/memory_store/
recall.rs

1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use futures::StreamExt;
7use serde::Deserialize;
8
9use bamboo_agent_core::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_infrastructure::{LLMChunk, LLMProvider, LLMRequestOptions};
12
13use super::{
14    extract_keywords, parse_rfc3339, DurableMemoryStatus, LexicalIndexItem, MemoryScope,
15    MemoryStore,
16};
17
18#[derive(Debug, Clone, PartialEq)]
19pub struct MemoryRecallCandidate {
20    pub id: String,
21    pub title: String,
22    pub score: f64,
23    pub scope: MemoryScope,
24    pub project_key: Option<String>,
25    pub status: DurableMemoryStatus,
26    pub updated_at: String,
27    pub summary: String,
28}
29
/// Tuning knobs for lexical memory recall.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MemoryRecallOptions {
    /// Maximum number of memories returned to the caller (values below 1 act as 1).
    pub shortlist_limit: usize,
    /// Whether the global scope is searched when the project scope yields nothing.
    pub include_global_fallback: bool,
    /// Cap on candidates kept per scope before final selection
    /// (never lower than the effective `shortlist_limit`).
    pub max_candidates_per_scope: usize,
}

impl Default for MemoryRecallOptions {
    /// Three results, global fallback enabled, up to twenty candidates per scope.
    fn default() -> Self {
        MemoryRecallOptions {
            max_candidates_per_scope: 20,
            include_global_fallback: true,
            shortlist_limit: 3,
        }
    }
}
46
/// How the final recall selection was ordered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MemoryRecallStrategy {
    /// Lexical order only; no model rerank was applied.
    Lexical,
    /// Order chosen by the rerank model.
    Reranked,
    /// A rerank was attempted but failed or was unusable, so lexical order was kept.
    RerankFallback,
}

impl MemoryRecallStrategy {
    /// Stable snake_case label for logs and telemetry.
    pub fn as_str(self) -> &'static str {
        match self {
            MemoryRecallStrategy::RerankFallback => "rerank_fallback",
            MemoryRecallStrategy::Reranked => "reranked",
            MemoryRecallStrategy::Lexical => "lexical",
        }
    }
}
63
64#[derive(Debug, Clone, PartialEq)]
65pub struct MemoryRecallSelection {
66    pub candidates: Vec<MemoryRecallCandidate>,
67    pub strategy: MemoryRecallStrategy,
68}
69
70#[derive(Clone)]
71pub struct MemoryRecallRerankContext {
72    pub llm: Arc<dyn LLMProvider>,
73    pub model: String,
74    pub session_id: Option<String>,
75}
76
77#[derive(Debug, Deserialize)]
78struct MemoryRecallRerankEnvelope {
79    #[serde(default)]
80    ids: Vec<String>,
81}
82
83pub async fn shortlist_relevant_memories(
84    store: &MemoryStore,
85    project_key: Option<&str>,
86    query: &str,
87    options: &MemoryRecallOptions,
88) -> io::Result<Vec<MemoryRecallCandidate>> {
89    let limit = options.shortlist_limit.max(1);
90    let mut candidates =
91        lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
92    candidates.truncate(limit);
93    Ok(candidates)
94}
95
96pub async fn select_relevant_memories(
97    store: &MemoryStore,
98    project_key: Option<&str>,
99    query: &str,
100    options: &MemoryRecallOptions,
101    rerank_context: Option<&MemoryRecallRerankContext>,
102) -> io::Result<MemoryRecallSelection> {
103    let query = query.trim();
104    if query.is_empty() {
105        return Ok(MemoryRecallSelection {
106            candidates: Vec::new(),
107            strategy: MemoryRecallStrategy::Lexical,
108        });
109    }
110
111    let limit = options.shortlist_limit.max(1);
112    let mut shortlist =
113        lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
114    if shortlist.is_empty() {
115        return Ok(MemoryRecallSelection {
116            candidates: shortlist,
117            strategy: MemoryRecallStrategy::Lexical,
118        });
119    }
120
121    let Some(rerank_context) = rerank_context else {
122        shortlist.truncate(limit);
123        return Ok(MemoryRecallSelection {
124            candidates: shortlist,
125            strategy: MemoryRecallStrategy::Lexical,
126        });
127    };
128
129    if shortlist.len() <= 1 {
130        shortlist.truncate(limit);
131        return Ok(MemoryRecallSelection {
132            candidates: shortlist,
133            strategy: MemoryRecallStrategy::Lexical,
134        });
135    }
136
137    match rerank_candidate_ids(query, &shortlist, limit, rerank_context).await {
138        Ok(ids) => {
139            let reranked = reorder_candidates_by_ids(&shortlist, &ids, limit);
140            if reranked.is_empty() {
141                let mut lexical = shortlist;
142                lexical.truncate(limit);
143                return Ok(MemoryRecallSelection {
144                    candidates: lexical,
145                    strategy: MemoryRecallStrategy::RerankFallback,
146                });
147            }
148            Ok(MemoryRecallSelection {
149                candidates: reranked,
150                strategy: MemoryRecallStrategy::Reranked,
151            })
152        }
153        Err(error) => {
154            tracing::warn!(
155                "Relevant memory rerank failed for model '{}': {}. Falling back to lexical shortlist.",
156                rerank_context.model,
157                error
158            );
159            shortlist.truncate(limit);
160            Ok(MemoryRecallSelection {
161                candidates: shortlist,
162                strategy: MemoryRecallStrategy::RerankFallback,
163            })
164        }
165    }
166}
167
168async fn lexical_shortlist_relevant_memories(
169    store: &MemoryStore,
170    project_key: Option<&str>,
171    query: &str,
172    options: &MemoryRecallOptions,
173) -> io::Result<Vec<MemoryRecallCandidate>> {
174    let query = query.trim();
175    if query.is_empty() {
176        return Ok(Vec::new());
177    }
178
179    let limit = options.shortlist_limit.max(1);
180    let per_scope_limit = options.max_candidates_per_scope.max(limit);
181
182    if let Some(project_key) = project_key.map(str::trim).filter(|value| !value.is_empty()) {
183        let mut project_hits =
184            shortlist_scope(store, MemoryScope::Project, Some(project_key), query).await?;
185        project_hits.truncate(per_scope_limit);
186        if !project_hits.is_empty() {
187            return Ok(project_hits);
188        }
189    }
190
191    if options.include_global_fallback {
192        let mut global_hits = shortlist_scope(store, MemoryScope::Global, None, query).await?;
193        global_hits.truncate(per_scope_limit);
194        return Ok(global_hits);
195    }
196
197    Ok(Vec::new())
198}
199
200async fn shortlist_scope(
201    store: &MemoryStore,
202    scope: MemoryScope,
203    project_key: Option<&str>,
204    query: &str,
205) -> io::Result<Vec<MemoryRecallCandidate>> {
206    let Some(index) = store.read_lexical_index(scope, project_key).await? else {
207        return Ok(Vec::new());
208    };
209
210    let query_tokens = extract_keywords(query, "", &[]);
211    if query_tokens.is_empty() {
212        return Ok(Vec::new());
213    }
214
215    let mut candidates = index
216        .items
217        .iter()
218        .filter_map(|item| score_lexical_index_item(item, &query_tokens).map(|score| (item, score)))
219        .map(|(item, score)| MemoryRecallCandidate {
220            id: item.id.clone(),
221            title: item.title.clone(),
222            score,
223            scope: item.scope,
224            project_key: item.project_key.clone(),
225            status: item.status,
226            updated_at: item.updated_at.clone(),
227            summary: item.summary.clone(),
228        })
229        .collect::<Vec<_>>();
230
231    sort_recall_candidates(&mut candidates);
232    Ok(candidates)
233}
234
235fn score_lexical_index_item(item: &LexicalIndexItem, query_tokens: &[String]) -> Option<f64> {
236    match item.status {
237        DurableMemoryStatus::Superseded
238        | DurableMemoryStatus::Contradicted
239        | DurableMemoryStatus::Archived => return None,
240        DurableMemoryStatus::Active | DurableMemoryStatus::Stale => {}
241    }
242
243    let title = item.title.to_ascii_lowercase();
244    let summary = item.summary.to_ascii_lowercase();
245
246    let mut score = 0.0;
247    let mut matched_any = false;
248
249    for token in query_tokens {
250        let mut token_score = 0.0;
251        if title.contains(token) {
252            token_score += 3.0;
253        }
254        if item
255            .keywords
256            .iter()
257            .any(|value| value.eq_ignore_ascii_case(token))
258        {
259            token_score += 2.5;
260        }
261        if item
262            .tags
263            .iter()
264            .any(|value| value.eq_ignore_ascii_case(token))
265        {
266            token_score += 2.0;
267        }
268        if item
269            .entities
270            .iter()
271            .any(|value| value.eq_ignore_ascii_case(token))
272        {
273            token_score += 1.5;
274        }
275        if summary.contains(token) {
276            token_score += 1.0;
277        }
278        if token_score > 0.0 {
279            matched_any = true;
280            score += token_score;
281        }
282    }
283
284    if !matched_any {
285        return None;
286    }
287
288    score += lexical_status_adjustment(item.status);
289    Some((score / query_tokens.len() as f64 * 100.0).round() / 100.0)
290}
291
292fn lexical_status_adjustment(status: DurableMemoryStatus) -> f64 {
293    match status {
294        DurableMemoryStatus::Active => 0.0,
295        DurableMemoryStatus::Stale => -0.75,
296        DurableMemoryStatus::Superseded
297        | DurableMemoryStatus::Contradicted
298        | DurableMemoryStatus::Archived => -10.0,
299    }
300}
301
302fn sort_recall_candidates(candidates: &mut [MemoryRecallCandidate]) {
303    candidates.sort_by(|left, right| {
304        right
305            .score
306            .partial_cmp(&left.score)
307            .unwrap_or(Ordering::Equal)
308            .then_with(|| {
309                let left_dt = parse_rfc3339(&left.updated_at)
310                    .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
311                let right_dt = parse_rfc3339(&right.updated_at)
312                    .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
313                right_dt.cmp(&left_dt)
314            })
315            .then_with(|| left.title.cmp(&right.title))
316    });
317}
318
319fn build_rerank_prompt(query: &str, candidates: &[MemoryRecallCandidate], limit: usize) -> String {
320    let mut prompt = String::from("# Bamboo Relevant Memory Recall Rerank\n\n");
321    prompt.push_str(
322        "Select the durable memory candidates that are most relevant to the user query.\n",
323    );
324    prompt.push_str("Return JSON only in the form {\"ids\":[\"candidate-id\", ...]}.\n");
325    prompt
326        .push_str("Do not include commentary, markdown fences, explanations, or unknown ids.\n\n");
327    prompt.push_str("## User query\n");
328    prompt.push_str(query.trim());
329    prompt.push_str("\n\n## Candidate memories\n");
330
331    for (index, candidate) in candidates.iter().enumerate() {
332        prompt.push_str(&format!(
333            "{}. id={}\n   title: {}\n   scope: {}\n   status: {}\n   updated_at: {}\n   lexical_score: {:.2}\n   summary: {}\n",
334            index + 1,
335            candidate.id,
336            candidate.title,
337            candidate.scope.as_str(),
338            candidate.status.as_str(),
339            candidate.updated_at,
340            candidate.score,
341            candidate.summary.replace('\n', " "),
342        ));
343    }
344
345    prompt.push_str(&format!(
346        "\n## Selection rules\n- Return at most {limit} ids.\n- Use only ids from the candidate list above.\n- Prefer candidates that best answer the user query or encode active preferences/constraints relevant to it.\n- Prefer active memories over stale ones when relevance is otherwise similar.\n- Keep the ids ordered best-to-worst.\n"
347    ));
348    prompt
349}
350
351async fn rerank_candidate_ids(
352    query: &str,
353    candidates: &[MemoryRecallCandidate],
354    limit: usize,
355    context: &MemoryRecallRerankContext,
356) -> Result<Vec<String>, String> {
357    let model = context.model.trim();
358    if model.is_empty() {
359        return Err("rerank model is empty".to_string());
360    }
361
362    let messages = vec![
363        Message::system(
364            "You rerank Bamboo durable-memory recall candidates. Return strict JSON only in the form {\"ids\":[...]} using only candidate ids from the prompt.",
365        ),
366        Message::user(build_rerank_prompt(query, candidates, limit)),
367    ];
368    let options = LLMRequestOptions {
369        session_id: context.session_id.clone(),
370        reasoning_effort: Some(ReasoningEffort::High),
371        parallel_tool_calls: None,
372        responses: None,
373        request_purpose: Some("memory_rerank".to_string()),
374        cache: None,
375    };
376
377    let mut stream = context
378        .llm
379        .chat_stream_with_options(&messages, &[], Some(8192), model, Some(&options))
380        .await
381        .map_err(|error| format!("rerank provider call failed: {error}"))?;
382
383    let content = tokio::time::timeout(std::time::Duration::from_secs(30), async {
384        let mut content = String::new();
385        while let Some(chunk_result) = stream.next().await {
386            match chunk_result {
387                Ok(LLMChunk::Token(text)) => content.push_str(&text),
388                Ok(LLMChunk::Done) => break,
389                Ok(_) => {}
390                Err(error) => {
391                    if !content.trim().is_empty() {
392                        break;
393                    }
394                    return Err(format!("rerank stream failed: {error}"));
395                }
396            }
397        }
398        Ok(content)
399    })
400    .await
401    .unwrap_or_else(|_| Err("rerank timed out after 30s".to_string()))?;
402
403    parse_reranked_ids(&content, candidates)
404        .ok_or_else(|| format!("failed to parse rerank response: {}", content.trim()))
405}
406
407fn reorder_candidates_by_ids(
408    lexical_candidates: &[MemoryRecallCandidate],
409    preferred_ids: &[String],
410    limit: usize,
411) -> Vec<MemoryRecallCandidate> {
412    if lexical_candidates.is_empty() || limit == 0 {
413        return Vec::new();
414    }
415
416    let allowed = lexical_candidates
417        .iter()
418        .map(|candidate| candidate.id.as_str())
419        .collect::<HashSet<_>>();
420    let mut seen = HashSet::new();
421    let mut ordered = Vec::new();
422
423    for id in preferred_ids {
424        let trimmed = id.trim();
425        if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
426            continue;
427        }
428        if let Some(candidate) = lexical_candidates
429            .iter()
430            .find(|candidate| candidate.id == trimmed)
431            .cloned()
432        {
433            ordered.push(candidate);
434            if ordered.len() >= limit {
435                return ordered;
436            }
437        }
438    }
439
440    for candidate in lexical_candidates {
441        if seen.insert(candidate.id.clone()) {
442            ordered.push(candidate.clone());
443            if ordered.len() >= limit {
444                break;
445            }
446        }
447    }
448
449    ordered
450}
451
452fn parse_reranked_ids(raw: &str, candidates: &[MemoryRecallCandidate]) -> Option<Vec<String>> {
453    let stripped = strip_markdown_fence(raw);
454    let fragment = extract_json_fragment(&stripped).unwrap_or(stripped.trim());
455    let ids = serde_json::from_str::<MemoryRecallRerankEnvelope>(fragment)
456        .map(|value| value.ids)
457        .or_else(|_| serde_json::from_str::<Vec<String>>(fragment))
458        .ok()?;
459
460    let allowed = candidates
461        .iter()
462        .map(|candidate| candidate.id.as_str())
463        .collect::<HashSet<_>>();
464    let mut seen = HashSet::new();
465    let mut out = Vec::new();
466
467    for id in ids {
468        let trimmed = id.trim();
469        if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
470            continue;
471        }
472        out.push(trimmed.to_string());
473    }
474
475    (!out.is_empty()).then_some(out)
476}
477
/// Removes a surrounding Markdown code fence (four or three backticks) from `raw`.
///
/// The opening fence line (including any language tag) is dropped, along with
/// everything from the last matching closing fence onward. If no complete fence
/// is found, the trimmed input is returned unchanged.
fn strip_markdown_fence(raw: &str) -> String {
    let text = raw.trim();
    ["````", "```"]
        .iter()
        .find_map(|&fence| {
            let rest = text.strip_prefix(fence)?;
            let newline = rest.find('\n')?;
            let inner = &rest[newline + 1..];
            let close = inner.rfind(fence)?;
            Some(inner[..close].trim().to_string())
        })
        .unwrap_or_else(|| text.to_string())
}
493
/// Extracts the outermost JSON-looking fragment from `raw`.
///
/// The span from the first `{` to the last `}` is preferred. Otherwise the span
/// from the first `[` to the last `]` is used. Returns `None` for blank input or
/// when neither delimiter pair forms a forward span.
fn extract_json_fragment(raw: &str) -> Option<&str> {
    let text = raw.trim();
    if text.is_empty() {
        return None;
    }

    [('{', '}'), ('[', ']')]
        .iter()
        .find_map(|&(open, close)| {
            let start = text.find(open)?;
            let end = text.rfind(close)?;
            (start <= end).then(|| text[start..=end].trim())
        })
}
514
515#[cfg(test)]
516mod tests {
517    use super::*;
518    use crate::memory_store::DurableMemoryType;
519    use async_trait::async_trait;
520    use bamboo_domain::ReasoningEffort;
521    use bamboo_infrastructure::llm::provider::LLMRequestOptions;
522    use bamboo_infrastructure::{LLMChunk, LLMError, LLMProvider, LLMStream};
523    use futures::stream;
524    use std::sync::Mutex;
525    use tempfile::tempdir;
526
527    #[allow(clippy::too_many_arguments)]
528    fn item(
529        id: &str,
530        title: &str,
531        status: DurableMemoryStatus,
532        updated_at: &str,
533        keywords: &[&str],
534        tags: &[&str],
535        entities: &[&str],
536        summary: &str,
537    ) -> LexicalIndexItem {
538        LexicalIndexItem {
539            id: id.to_string(),
540            title: title.to_string(),
541            scope: MemoryScope::Project,
542            project_key: Some("proj-1".to_string()),
543            r#type: DurableMemoryType::Project,
544            status,
545            tags: tags.iter().map(|v| v.to_string()).collect(),
546            keywords: keywords.iter().map(|v| v.to_string()).collect(),
547            entities: entities.iter().map(|v| v.to_string()).collect(),
548            updated_at: updated_at.to_string(),
549            created_at: updated_at.to_string(),
550            summary: summary.to_string(),
551        }
552    }
553
554    #[derive(Clone)]
555    struct StaticResponseProvider {
556        response: String,
557        requested_models: Arc<Mutex<Vec<String>>>,
558    }
559
560    impl StaticResponseProvider {
561        fn new(response: impl Into<String>) -> Self {
562            Self {
563                response: response.into(),
564                requested_models: Arc::new(Mutex::new(Vec::new())),
565            }
566        }
567    }
568
569    #[async_trait]
570    impl LLMProvider for StaticResponseProvider {
571        async fn chat_stream(
572            &self,
573            _messages: &[Message],
574            _tools: &[bamboo_agent_core::ToolSchema],
575            _max_output_tokens: Option<u32>,
576            model: &str,
577        ) -> Result<LLMStream, LLMError> {
578            self.requested_models
579                .lock()
580                .expect("lock poisoned")
581                .push(model.to_string());
582            Ok(Box::pin(stream::iter(vec![
583                Ok(LLMChunk::Token(self.response.clone())),
584                Ok(LLMChunk::Done),
585            ])))
586        }
587    }
588
589    #[test]
590    fn title_matches_outrank_keyword_only_matches() {
591        let query_tokens = vec!["release".to_string(), "freeze".to_string()];
592        let title_item = item(
593            "a",
594            "Release freeze decision",
595            DurableMemoryStatus::Active,
596            "2026-04-09T00:00:00Z",
597            &[],
598            &[],
599            &[],
600            "summary",
601        );
602        let keyword_item = item(
603            "b",
604            "Deployment decision",
605            DurableMemoryStatus::Active,
606            "2026-04-09T00:00:00Z",
607            &["release", "freeze"],
608            &[],
609            &[],
610            "summary",
611        );
612
613        let title_score = score_lexical_index_item(&title_item, &query_tokens).unwrap();
614        let keyword_score = score_lexical_index_item(&keyword_item, &query_tokens).unwrap();
615        assert!(title_score > keyword_score);
616    }
617
618    #[test]
619    fn active_items_outrank_stale_items() {
620        let query_tokens = vec!["release".to_string()];
621        let active = item(
622            "a",
623            "Release freeze decision",
624            DurableMemoryStatus::Active,
625            "2026-04-09T00:00:00Z",
626            &[],
627            &[],
628            &[],
629            "summary",
630        );
631        let stale = item(
632            "b",
633            "Release freeze decision",
634            DurableMemoryStatus::Stale,
635            "2026-04-10T00:00:00Z",
636            &[],
637            &[],
638            &[],
639            "summary",
640        );
641
642        let active_score = score_lexical_index_item(&active, &query_tokens).unwrap();
643        let stale_score = score_lexical_index_item(&stale, &query_tokens).unwrap();
644        assert!(active_score > stale_score);
645    }
646
647    #[test]
648    fn contradicted_and_archived_items_are_filtered_out() {
649        let query_tokens = vec!["release".to_string()];
650        let contradicted = item(
651            "a",
652            "Release freeze decision",
653            DurableMemoryStatus::Contradicted,
654            "2026-04-09T00:00:00Z",
655            &[],
656            &[],
657            &[],
658            "summary",
659        );
660        let archived = item(
661            "b",
662            "Release freeze decision",
663            DurableMemoryStatus::Archived,
664            "2026-04-09T00:00:00Z",
665            &[],
666            &[],
667            &[],
668            "summary",
669        );
670
671        assert!(score_lexical_index_item(&contradicted, &query_tokens).is_none());
672        assert!(score_lexical_index_item(&archived, &query_tokens).is_none());
673    }
674
675    #[test]
676    fn parse_reranked_ids_accepts_fenced_json_and_filters_unknown_ids() {
677        let candidates = vec![
678            MemoryRecallCandidate {
679                id: "mem-a".to_string(),
680                title: "A".to_string(),
681                score: 10.0,
682                scope: MemoryScope::Project,
683                project_key: Some("proj-1".to_string()),
684                status: DurableMemoryStatus::Active,
685                updated_at: "2026-04-09T00:00:00Z".to_string(),
686                summary: "summary a".to_string(),
687            },
688            MemoryRecallCandidate {
689                id: "mem-b".to_string(),
690                title: "B".to_string(),
691                score: 9.0,
692                scope: MemoryScope::Project,
693                project_key: Some("proj-1".to_string()),
694                status: DurableMemoryStatus::Active,
695                updated_at: "2026-04-09T00:00:00Z".to_string(),
696                summary: "summary b".to_string(),
697            },
698        ];
699
700        let parsed = parse_reranked_ids(
701            "```json\n{\"ids\":[\"mem-b\",\"unknown\",\"mem-a\",\"mem-b\"]}\n```",
702            &candidates,
703        )
704        .expect("reranked ids should parse");
705
706        assert_eq!(parsed, vec!["mem-b".to_string(), "mem-a".to_string()]);
707    }
708
709    #[test]
710    fn reorder_candidates_by_ids_appends_remaining_lexical_candidates() {
711        let lexical = vec![
712            MemoryRecallCandidate {
713                id: "mem-a".to_string(),
714                title: "A".to_string(),
715                score: 10.0,
716                scope: MemoryScope::Project,
717                project_key: Some("proj-1".to_string()),
718                status: DurableMemoryStatus::Active,
719                updated_at: "2026-04-09T00:00:00Z".to_string(),
720                summary: "summary a".to_string(),
721            },
722            MemoryRecallCandidate {
723                id: "mem-b".to_string(),
724                title: "B".to_string(),
725                score: 9.0,
726                scope: MemoryScope::Project,
727                project_key: Some("proj-1".to_string()),
728                status: DurableMemoryStatus::Active,
729                updated_at: "2026-04-09T00:00:00Z".to_string(),
730                summary: "summary b".to_string(),
731            },
732            MemoryRecallCandidate {
733                id: "mem-c".to_string(),
734                title: "C".to_string(),
735                score: 8.0,
736                scope: MemoryScope::Project,
737                project_key: Some("proj-1".to_string()),
738                status: DurableMemoryStatus::Active,
739                updated_at: "2026-04-09T00:00:00Z".to_string(),
740                summary: "summary c".to_string(),
741            },
742        ];
743
744        let reordered =
745            reorder_candidates_by_ids(&lexical, &["mem-c".to_string(), "mem-a".to_string()], 3);
746
747        assert_eq!(reordered[0].id, "mem-c");
748        assert_eq!(reordered[1].id, "mem-a");
749        assert_eq!(reordered[2].id, "mem-b");
750    }
751
752    #[tokio::test]
753    async fn project_scope_shortlist_excludes_global_when_project_hits_exist() {
754        let dir = tempdir().unwrap();
755        let store = MemoryStore::new(dir.path());
756
757        store
758            .write_memory(
759                MemoryScope::Project,
760                Some("proj-1"),
761                DurableMemoryType::Project,
762                "Release freeze decision",
763                "Project-specific release freeze note.",
764                &["release".to_string()],
765                Some("session-1"),
766                "main-model",
767                false,
768            )
769            .await
770            .unwrap();
771        store
772            .write_memory(
773                MemoryScope::Global,
774                None,
775                DurableMemoryType::Reference,
776                "Global release guidance",
777                "Global note that should not be used when project hits exist.",
778                &["release".to_string()],
779                Some("session-1"),
780                "main-model",
781                false,
782            )
783            .await
784            .unwrap();
785
786        let candidates = shortlist_relevant_memories(
787            &store,
788            Some("proj-1"),
789            "release freeze",
790            &MemoryRecallOptions::default(),
791        )
792        .await
793        .unwrap();
794
795        assert!(!candidates.is_empty());
796        assert!(candidates
797            .iter()
798            .all(|candidate| candidate.scope == MemoryScope::Project));
799    }
800
801    #[tokio::test]
802    async fn global_fallback_triggers_only_when_project_hits_are_absent() {
803        let dir = tempdir().unwrap();
804        let store = MemoryStore::new(dir.path());
805
806        store
807            .write_memory(
808                MemoryScope::Global,
809                None,
810                DurableMemoryType::Reference,
811                "Global release guidance",
812                "Fallback note for release work.",
813                &["release".to_string()],
814                Some("session-1"),
815                "main-model",
816                false,
817            )
818            .await
819            .unwrap();
820
821        let candidates = shortlist_relevant_memories(
822            &store,
823            Some("proj-missing"),
824            "release guidance",
825            &MemoryRecallOptions::default(),
826        )
827        .await
828        .unwrap();
829
830        assert!(!candidates.is_empty());
831        assert!(candidates
832            .iter()
833            .all(|candidate| candidate.scope == MemoryScope::Global));
834    }
835
836    #[tokio::test]
837    async fn model_rerank_reorders_lexical_shortlist_when_enabled() {
838        let dir = tempdir().unwrap();
839        let store = MemoryStore::new(dir.path());
840
841        let lexical_first = store
842            .write_memory(
843                MemoryScope::Project,
844                Some("proj-1"),
845                DurableMemoryType::Project,
846                "Release freeze checklist",
847                "Generic release freeze checklist for shipping work.",
848                &["release".to_string(), "freeze".to_string()],
849                Some("session-1"),
850                "main-model",
851                false,
852            )
853            .await
854            .unwrap();
855        let reranked_first = store
856            .write_memory(
857                MemoryScope::Project,
858                Some("proj-1"),
859                DurableMemoryType::Project,
860                "Mobile launch blocker",
861                "This durable note captures the release freeze decision for the mobile app and should be preferred for mobile freeze requests.",
862                &["mobile".to_string(), "launch".to_string()],
863                Some("session-1"),
864                "main-model",
865                false,
866            )
867            .await
868            .unwrap();
869
870        let provider = StaticResponseProvider::new(format!(
871            "{{\"ids\":[\"{}\",\"{}\"]}}",
872            reranked_first.frontmatter.id, lexical_first.frontmatter.id
873        ));
874        let requested_models = provider.requested_models.clone();
875        let selection = select_relevant_memories(
876            &store,
877            Some("proj-1"),
878            "release freeze for mobile",
879            &MemoryRecallOptions {
880                shortlist_limit: 2,
881                include_global_fallback: false,
882                max_candidates_per_scope: 12,
883            },
884            Some(&MemoryRecallRerankContext {
885                llm: Arc::new(provider),
886                model: "rerank-fast-model".to_string(),
887                session_id: Some("session-1".to_string()),
888            }),
889        )
890        .await
891        .unwrap();
892
893        assert_eq!(selection.strategy, MemoryRecallStrategy::Reranked);
894        assert_eq!(selection.candidates.len(), 2);
895        assert_eq!(selection.candidates[0].id, reranked_first.frontmatter.id);
896        assert_eq!(selection.candidates[1].id, lexical_first.frontmatter.id);
897        assert_eq!(
898            requested_models.lock().expect("lock poisoned").as_slice(),
899            ["rerank-fast-model"]
900        );
901    }
902
903    #[tokio::test]
904    async fn invalid_model_rerank_response_falls_back_to_lexical_order() {
905        let dir = tempdir().unwrap();
906        let store = MemoryStore::new(dir.path());
907
908        let lexical_first = store
909            .write_memory(
910                MemoryScope::Project,
911                Some("proj-1"),
912                DurableMemoryType::Project,
913                "Release freeze checklist",
914                "Generic release freeze checklist for shipping work.",
915                &["release".to_string(), "freeze".to_string()],
916                Some("session-1"),
917                "main-model",
918                false,
919            )
920            .await
921            .unwrap();
922        let lexical_second = store
923            .write_memory(
924                MemoryScope::Project,
925                Some("proj-1"),
926                DurableMemoryType::Project,
927                "Mobile launch blocker",
928                "This durable note captures the release freeze decision for the mobile app.",
929                &["mobile".to_string(), "launch".to_string()],
930                Some("session-1"),
931                "main-model",
932                false,
933            )
934            .await
935            .unwrap();
936
937        let selection = select_relevant_memories(
938            &store,
939            Some("proj-1"),
940            "release freeze for mobile",
941            &MemoryRecallOptions {
942                shortlist_limit: 2,
943                include_global_fallback: false,
944                max_candidates_per_scope: 12,
945            },
946            Some(&MemoryRecallRerankContext {
947                llm: Arc::new(StaticResponseProvider::new("not valid json")),
948                model: "rerank-fast-model".to_string(),
949                session_id: Some("session-1".to_string()),
950            }),
951        )
952        .await
953        .unwrap();
954
955        assert_eq!(selection.strategy, MemoryRecallStrategy::RerankFallback);
956        assert_eq!(selection.candidates.len(), 2);
957        assert_eq!(selection.candidates[0].id, lexical_first.frontmatter.id);
958        assert_eq!(selection.candidates[1].id, lexical_second.frontmatter.id);
959    }
960
961    /// Provider that captures `max_output_tokens` and `reasoning_effort` from `chat_stream_with_options`.
962    #[derive(Default)]
963    struct RequestOptionsCaptureProvider {
964        captured_max_tokens: Mutex<Vec<Option<u32>>>,
965        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
966    }
967
968    #[async_trait]
969    impl LLMProvider for RequestOptionsCaptureProvider {
970        async fn chat_stream(
971            &self,
972            _messages: &[Message],
973            _tools: &[bamboo_agent_core::ToolSchema],
974            _max_output_tokens: Option<u32>,
975            _model: &str,
976        ) -> Result<LLMStream, LLMError> {
977            Ok(Box::pin(stream::iter(vec![
978                Ok(LLMChunk::Token("{\"ids\":[]}".to_string())),
979                Ok(LLMChunk::Done),
980            ])))
981        }
982
983        async fn chat_stream_with_options(
984            &self,
985            messages: &[Message],
986            tools: &[bamboo_agent_core::ToolSchema],
987            max_output_tokens: Option<u32>,
988            model: &str,
989            options: Option<&LLMRequestOptions>,
990        ) -> Result<LLMStream, LLMError> {
991            self.captured_max_tokens
992                .lock()
993                .expect("lock should not be poisoned")
994                .push(max_output_tokens);
995            self.captured_reasoning
996                .lock()
997                .expect("lock should not be poisoned")
998                .push(options.and_then(|o| o.reasoning_effort));
999            self.chat_stream(messages, tools, max_output_tokens, model)
1000                .await
1001        }
1002    }
1003
1004    #[tokio::test]
1005    async fn rerank_sufficient_max_tokens_for_high_reasoning() {
1006        let provider = Arc::new(RequestOptionsCaptureProvider::default());
1007        let candidates = vec![MemoryRecallCandidate {
1008            id: "mem-1".to_string(),
1009            score: 0.9,
1010            title: "Test memory".to_string(),
1011            scope: MemoryScope::Project,
1012            project_key: Some("proj-1".to_string()),
1013            status: DurableMemoryStatus::Active,
1014            updated_at: "2026-05-08T00:00:00Z".to_string(),
1015            summary: "A test durable memory entry".to_string(),
1016        }];
1017        let context = MemoryRecallRerankContext {
1018            llm: provider.clone(),
1019            model: "deepseek-v4-pro".to_string(),
1020            session_id: Some("test-session".to_string()),
1021        };
1022
1023        let _ = rerank_candidate_ids("test query", &candidates, 5, &context).await;
1024
1025        let captured_reasoning = provider
1026            .captured_reasoning
1027            .lock()
1028            .expect("lock should not be poisoned");
1029        let captured_max_tokens = provider
1030            .captured_max_tokens
1031            .lock()
1032            .expect("lock should not be poisoned");
1033        assert_eq!(
1034            captured_reasoning.as_slice(),
1035            [Some(ReasoningEffort::High)],
1036            "rerank should request High reasoning"
1037        );
1038        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
1039        assert!(
1040            max_tokens > 4096,
1041            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
1042            max_tokens
1043        );
1044    }
1045}