1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use futures::StreamExt;
7use serde::Deserialize;
8
9use bamboo_agent_core::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_llm::{LLMChunk, LLMProvider, LLMRequestOptions};
12
13use super::{
14 extract_keywords, parse_rfc3339, DurableMemoryStatus, LexicalIndexItem, MemoryScope,
15 MemoryStore, TemporalGranularity,
16};
17
18#[derive(Debug, Clone, PartialEq)]
19pub struct MemoryRecallCandidate {
20 pub id: String,
21 pub title: String,
22 pub score: f64,
23 pub scope: MemoryScope,
24 pub project_key: Option<String>,
25 pub status: DurableMemoryStatus,
26 pub updated_at: String,
27 pub summary: String,
28 pub granularity: Option<TemporalGranularity>,
32}
33
/// Tuning knobs for lexical memory recall.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryRecallOptions {
    /// Maximum number of memories handed back to the caller; values below 1 act as 1.
    pub shortlist_limit: usize,
    /// Whether global memories are searched when no usable project key is given or the
    /// project scope produces no hits.
    pub include_global_fallback: bool,
    /// Upper bound on lexical hits kept per scope (the pool a rerank model chooses from);
    /// raised to the effective `shortlist_limit` when smaller.
    pub max_candidates_per_scope: usize,
}

impl Default for MemoryRecallOptions {
    /// Three results drawn from at most twenty hits per scope, with global fallback on.
    fn default() -> Self {
        MemoryRecallOptions {
            include_global_fallback: true,
            max_candidates_per_scope: 20,
            shortlist_limit: 3,
        }
    }
}
50
/// Describes how a recall selection was produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryRecallStrategy {
    /// Lexical ranking only; no model rerank was attempted.
    Lexical,
    /// The lexical shortlist was reordered by the rerank model.
    Reranked,
    /// A rerank was attempted but could not be used, so lexical order was kept.
    RerankFallback,
}

impl MemoryRecallStrategy {
    /// Stable snake_case label for this strategy.
    pub fn as_str(self) -> &'static str {
        match self {
            MemoryRecallStrategy::Lexical => "lexical",
            MemoryRecallStrategy::Reranked => "reranked",
            MemoryRecallStrategy::RerankFallback => "rerank_fallback",
        }
    }
}
67
68#[derive(Debug, Clone, PartialEq)]
69pub struct MemoryRecallSelection {
70 pub candidates: Vec<MemoryRecallCandidate>,
71 pub strategy: MemoryRecallStrategy,
72}
73
74#[derive(Clone)]
75pub struct MemoryRecallRerankContext {
76 pub llm: Arc<dyn LLMProvider>,
77 pub model: String,
78 pub session_id: Option<String>,
79}
80
81#[derive(Debug, Deserialize)]
82struct MemoryRecallRerankEnvelope {
83 #[serde(default)]
84 ids: Vec<String>,
85}
86
87pub async fn shortlist_relevant_memories(
88 store: &MemoryStore,
89 project_key: Option<&str>,
90 query: &str,
91 options: &MemoryRecallOptions,
92) -> io::Result<Vec<MemoryRecallCandidate>> {
93 let limit = options.shortlist_limit.max(1);
94 let mut candidates =
95 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
96 candidates.truncate(limit);
97 Ok(candidates)
98}
99
100pub async fn select_relevant_memories(
101 store: &MemoryStore,
102 project_key: Option<&str>,
103 query: &str,
104 options: &MemoryRecallOptions,
105 rerank_context: Option<&MemoryRecallRerankContext>,
106) -> io::Result<MemoryRecallSelection> {
107 let query = query.trim();
108 if query.is_empty() {
109 return Ok(MemoryRecallSelection {
110 candidates: Vec::new(),
111 strategy: MemoryRecallStrategy::Lexical,
112 });
113 }
114
115 let limit = options.shortlist_limit.max(1);
116 let mut shortlist =
117 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
118 if shortlist.is_empty() {
119 return Ok(MemoryRecallSelection {
120 candidates: shortlist,
121 strategy: MemoryRecallStrategy::Lexical,
122 });
123 }
124
125 let Some(rerank_context) = rerank_context else {
126 shortlist.truncate(limit);
127 return Ok(MemoryRecallSelection {
128 candidates: shortlist,
129 strategy: MemoryRecallStrategy::Lexical,
130 });
131 };
132
133 if shortlist.len() <= 1 {
134 shortlist.truncate(limit);
135 return Ok(MemoryRecallSelection {
136 candidates: shortlist,
137 strategy: MemoryRecallStrategy::Lexical,
138 });
139 }
140
141 match rerank_candidate_ids(query, &shortlist, limit, rerank_context).await {
142 Ok(ids) => {
143 let reranked = reorder_candidates_by_ids(&shortlist, &ids, limit);
144 if reranked.is_empty() {
145 let mut lexical = shortlist;
146 lexical.truncate(limit);
147 return Ok(MemoryRecallSelection {
148 candidates: lexical,
149 strategy: MemoryRecallStrategy::RerankFallback,
150 });
151 }
152 Ok(MemoryRecallSelection {
153 candidates: reranked,
154 strategy: MemoryRecallStrategy::Reranked,
155 })
156 }
157 Err(error) => {
158 tracing::warn!(
159 "Relevant memory rerank failed for model '{}': {}. Falling back to lexical shortlist.",
160 rerank_context.model,
161 error
162 );
163 shortlist.truncate(limit);
164 Ok(MemoryRecallSelection {
165 candidates: shortlist,
166 strategy: MemoryRecallStrategy::RerankFallback,
167 })
168 }
169 }
170}
171
172async fn lexical_shortlist_relevant_memories(
173 store: &MemoryStore,
174 project_key: Option<&str>,
175 query: &str,
176 options: &MemoryRecallOptions,
177) -> io::Result<Vec<MemoryRecallCandidate>> {
178 let query = query.trim();
179 if query.is_empty() {
180 return Ok(Vec::new());
181 }
182
183 let limit = options.shortlist_limit.max(1);
184 let per_scope_limit = options.max_candidates_per_scope.max(limit);
185
186 if let Some(project_key) = project_key.map(str::trim).filter(|value| !value.is_empty()) {
187 let mut project_hits =
188 shortlist_scope(store, MemoryScope::Project, Some(project_key), query).await?;
189 project_hits.truncate(per_scope_limit);
190 if !project_hits.is_empty() {
191 return Ok(project_hits);
192 }
193 }
194
195 if options.include_global_fallback {
196 let mut global_hits = shortlist_scope(store, MemoryScope::Global, None, query).await?;
197 global_hits.truncate(per_scope_limit);
198 return Ok(global_hits);
199 }
200
201 Ok(Vec::new())
202}
203
204async fn shortlist_scope(
205 store: &MemoryStore,
206 scope: MemoryScope,
207 project_key: Option<&str>,
208 query: &str,
209) -> io::Result<Vec<MemoryRecallCandidate>> {
210 let Some(index) = store.read_lexical_index(scope, project_key).await? else {
211 return Ok(Vec::new());
212 };
213
214 let query_tokens = extract_keywords(query, "", &[]);
215 if query_tokens.is_empty() {
216 return Ok(Vec::new());
217 }
218
219 let mut candidates = index
220 .items
221 .iter()
222 .filter_map(|item| score_lexical_index_item(item, &query_tokens).map(|score| (item, score)))
223 .map(|(item, score)| MemoryRecallCandidate {
224 id: item.id.clone(),
225 title: item.title.clone(),
226 score,
227 scope: item.scope,
228 project_key: item.project_key.clone(),
229 status: item.status,
230 updated_at: item.updated_at.clone(),
231 summary: item.summary.clone(),
232 granularity: item.granularity,
233 })
234 .collect::<Vec<_>>();
235
236 sort_recall_candidates(&mut candidates);
237 Ok(candidates)
238}
239
240fn score_lexical_index_item(item: &LexicalIndexItem, query_tokens: &[String]) -> Option<f64> {
241 match item.status {
242 DurableMemoryStatus::Superseded
243 | DurableMemoryStatus::Contradicted
244 | DurableMemoryStatus::Archived => return None,
245 DurableMemoryStatus::Active | DurableMemoryStatus::Stale => {}
246 }
247
248 let title = item.title.to_ascii_lowercase();
249 let summary = item.summary.to_ascii_lowercase();
250
251 let mut score = 0.0;
252 let mut matched_any = false;
253
254 for token in query_tokens {
255 let mut token_score = 0.0;
256 if title.contains(token) {
257 token_score += 3.0;
258 }
259 if item
260 .keywords
261 .iter()
262 .any(|value| value.eq_ignore_ascii_case(token))
263 {
264 token_score += 2.5;
265 }
266 if item
267 .tags
268 .iter()
269 .any(|value| value.eq_ignore_ascii_case(token))
270 {
271 token_score += 2.0;
272 }
273 if item
274 .entities
275 .iter()
276 .any(|value| value.eq_ignore_ascii_case(token))
277 {
278 token_score += 1.5;
279 }
280 if summary.contains(token) {
281 token_score += 1.0;
282 }
283 if token_score > 0.0 {
284 matched_any = true;
285 score += token_score;
286 }
287 }
288
289 if !matched_any {
290 return None;
291 }
292
293 score += lexical_status_adjustment(item.status);
294 Some((score / query_tokens.len() as f64 * 100.0).round() / 100.0)
295}
296
297fn lexical_status_adjustment(status: DurableMemoryStatus) -> f64 {
298 match status {
299 DurableMemoryStatus::Active => 0.0,
300 DurableMemoryStatus::Stale => -0.75,
301 DurableMemoryStatus::Superseded
302 | DurableMemoryStatus::Contradicted
303 | DurableMemoryStatus::Archived => -10.0,
304 }
305}
306
307fn sort_recall_candidates(candidates: &mut [MemoryRecallCandidate]) {
308 candidates.sort_by(|left, right| {
309 right
310 .score
311 .partial_cmp(&left.score)
312 .unwrap_or(Ordering::Equal)
313 .then_with(|| {
318 TemporalGranularity::cache_stability_rank(left.granularity).cmp(
319 &TemporalGranularity::cache_stability_rank(right.granularity),
320 )
321 })
322 .then_with(|| {
323 let left_dt = parse_rfc3339(&left.updated_at)
324 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
325 let right_dt = parse_rfc3339(&right.updated_at)
326 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
327 right_dt.cmp(&left_dt)
328 })
329 .then_with(|| left.title.cmp(&right.title))
330 });
331}
332
333fn build_rerank_prompt(query: &str, candidates: &[MemoryRecallCandidate], limit: usize) -> String {
334 let mut prompt = String::from("# Bamboo Relevant Memory Recall Rerank\n\n");
335 prompt.push_str(
336 "Select the durable memory candidates that are most relevant to the user query.\n",
337 );
338 prompt.push_str("Return JSON only in the form {\"ids\":[\"candidate-id\", ...]}.\n");
339 prompt
340 .push_str("Do not include commentary, markdown fences, explanations, or unknown ids.\n\n");
341 prompt.push_str("## User query\n");
342 prompt.push_str(query.trim());
343 prompt.push_str("\n\n## Candidate memories\n");
344
345 for (index, candidate) in candidates.iter().enumerate() {
346 prompt.push_str(&format!(
347 "{}. id={}\n title: {}\n scope: {}\n status: {}\n updated_at: {}\n lexical_score: {:.2}\n summary: {}\n",
348 index + 1,
349 candidate.id,
350 candidate.title,
351 candidate.scope.as_str(),
352 candidate.status.as_str(),
353 candidate.updated_at,
354 candidate.score,
355 candidate.summary.replace('\n', " "),
356 ));
357 }
358
359 prompt.push_str(&format!(
360 "\n## Selection rules\n- Return at most {limit} ids.\n- Use only ids from the candidate list above.\n- Prefer candidates that best answer the user query or encode active preferences/constraints relevant to it.\n- Prefer active memories over stale ones when relevance is otherwise similar.\n- Keep the ids ordered best-to-worst.\n"
361 ));
362 prompt
363}
364
365async fn rerank_candidate_ids(
366 query: &str,
367 candidates: &[MemoryRecallCandidate],
368 limit: usize,
369 context: &MemoryRecallRerankContext,
370) -> Result<Vec<String>, String> {
371 let model = context.model.trim();
372 if model.is_empty() {
373 return Err("rerank model is empty".to_string());
374 }
375
376 let messages = vec![
377 Message::system(
378 "You rerank Bamboo durable-memory recall candidates. Return strict JSON only in the form {\"ids\":[...]} using only candidate ids from the prompt.",
379 ),
380 Message::user(build_rerank_prompt(query, candidates, limit)),
381 ];
382 let options = LLMRequestOptions {
383 session_id: context.session_id.clone(),
384 reasoning_effort: Some(ReasoningEffort::High),
385 parallel_tool_calls: None,
386 responses: None,
387 request_purpose: Some("memory_rerank".to_string()),
388 cache: None,
389 };
390
391 let mut stream = context
392 .llm
393 .chat_stream_with_options(&messages, &[], Some(8192), model, Some(&options))
394 .await
395 .map_err(|error| format!("rerank provider call failed: {error}"))?;
396
397 let content = tokio::time::timeout(std::time::Duration::from_secs(30), async {
398 let mut content = String::new();
399 while let Some(chunk_result) = stream.next().await {
400 match chunk_result {
401 Ok(LLMChunk::Token(text)) => content.push_str(&text),
402 Ok(LLMChunk::Done) => break,
403 Ok(_) => {}
404 Err(error) => {
405 if !content.trim().is_empty() {
406 break;
407 }
408 return Err(format!("rerank stream failed: {error}"));
409 }
410 }
411 }
412 Ok(content)
413 })
414 .await
415 .unwrap_or_else(|_| Err("rerank timed out after 30s".to_string()))?;
416
417 parse_reranked_ids(&content, candidates)
418 .ok_or_else(|| format!("failed to parse rerank response: {}", content.trim()))
419}
420
421fn reorder_candidates_by_ids(
422 lexical_candidates: &[MemoryRecallCandidate],
423 preferred_ids: &[String],
424 limit: usize,
425) -> Vec<MemoryRecallCandidate> {
426 if lexical_candidates.is_empty() || limit == 0 {
427 return Vec::new();
428 }
429
430 let allowed = lexical_candidates
431 .iter()
432 .map(|candidate| candidate.id.as_str())
433 .collect::<HashSet<_>>();
434 let mut seen = HashSet::new();
435 let mut ordered = Vec::new();
436
437 for id in preferred_ids {
438 let trimmed = id.trim();
439 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
440 continue;
441 }
442 if let Some(candidate) = lexical_candidates
443 .iter()
444 .find(|candidate| candidate.id == trimmed)
445 .cloned()
446 {
447 ordered.push(candidate);
448 if ordered.len() >= limit {
449 return ordered;
450 }
451 }
452 }
453
454 for candidate in lexical_candidates {
455 if seen.insert(candidate.id.clone()) {
456 ordered.push(candidate.clone());
457 if ordered.len() >= limit {
458 break;
459 }
460 }
461 }
462
463 ordered
464}
465
466fn parse_reranked_ids(raw: &str, candidates: &[MemoryRecallCandidate]) -> Option<Vec<String>> {
467 let stripped = strip_markdown_fence(raw);
468 let fragment = extract_json_fragment(&stripped).unwrap_or(stripped.trim());
469 let ids = serde_json::from_str::<MemoryRecallRerankEnvelope>(fragment)
470 .map(|value| value.ids)
471 .or_else(|_| serde_json::from_str::<Vec<String>>(fragment))
472 .ok()?;
473
474 let allowed = candidates
475 .iter()
476 .map(|candidate| candidate.id.as_str())
477 .collect::<HashSet<_>>();
478 let mut seen = HashSet::new();
479 let mut out = Vec::new();
480
481 for id in ids {
482 let trimmed = id.trim();
483 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
484 continue;
485 }
486 out.push(trimmed.to_string());
487 }
488
489 (!out.is_empty()).then_some(out)
490}
491
/// Removes a surrounding Markdown code fence (four or three backticks) from `raw`.
///
/// The opening fence line, including any info string such as `json`, is dropped, and the
/// content up to the last matching closing fence is returned, trimmed. If no complete
/// fence is found, the trimmed input is returned unchanged.
fn strip_markdown_fence(raw: &str) -> String {
    let text = raw.trim();
    let unfenced = ["````", "```"].iter().copied().find_map(|fence| {
        let after_open = text.strip_prefix(fence)?;
        let (_info_line, body) = after_open.split_once('\n')?;
        let close_at = body.rfind(fence)?;
        Some(body[..close_at].trim())
    });
    unfenced.unwrap_or(text).to_string()
}
507
/// Returns the outermost JSON-looking span inside `raw`, if any.
///
/// An object span (first `{` through last `}`) is preferred. Failing that, an array span
/// (first `[` through last `]`) is tried. The slice is trimmed before being returned.
/// Yields `None` for blank input or when neither delimiter pair appears in order.
fn extract_json_fragment(raw: &str) -> Option<&str> {
    let text = raw.trim();
    [('{', '}'), ('[', ']')]
        .iter()
        .find_map(|&(open, close)| {
            let start = text.find(open)?;
            let end = text.rfind(close)?;
            if start > end {
                return None;
            }
            Some(text[start..=end].trim())
        })
}
528
#[cfg(test)]
mod tests {
    //! Unit and async tests for lexical scoring, sorting, rerank parsing, and the
    //! end-to-end recall selection flow against a temporary `MemoryStore`.

    use super::*;
    use crate::memory_store::DurableMemoryType;
    use async_trait::async_trait;
    use bamboo_domain::ReasoningEffort;
    use bamboo_llm::provider::LLMRequestOptions;
    use bamboo_llm::{LLMChunk, LLMError, LLMProvider, LLMStream};
    use futures::stream;
    use std::sync::Mutex;
    use tempfile::tempdir;

    /// Builds a project-scoped (`proj-1`) `LexicalIndexItem` whose `created_at` equals
    /// `updated_at`, for scoring tests.
    #[allow(clippy::too_many_arguments)]
    fn item(
        id: &str,
        title: &str,
        status: DurableMemoryStatus,
        updated_at: &str,
        keywords: &[&str],
        tags: &[&str],
        entities: &[&str],
        summary: &str,
    ) -> LexicalIndexItem {
        LexicalIndexItem {
            id: id.to_string(),
            title: title.to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            r#type: DurableMemoryType::Project,
            status,
            tags: tags.iter().map(|v| v.to_string()).collect(),
            keywords: keywords.iter().map(|v| v.to_string()).collect(),
            entities: entities.iter().map(|v| v.to_string()).collect(),
            updated_at: updated_at.to_string(),
            created_at: updated_at.to_string(),
            summary: summary.to_string(),
            granularity: None,
        }
    }

    /// Provider stub that replies with a fixed body and records each requested model name.
    #[derive(Clone)]
    struct StaticResponseProvider {
        response: String,
        requested_models: Arc<Mutex<Vec<String>>>,
    }

    impl StaticResponseProvider {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
                requested_models: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl LLMProvider for StaticResponseProvider {
        // Emits the whole canned response as one token followed by `Done`.
        // NOTE(review): only `chat_stream` is overridden; the rerank path calls
        // `chat_stream_with_options`, so the trait default presumably delegates here
        // (the rerank test asserts the recorded model name).
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream, LLMError> {
            self.requested_models
                .lock()
                .expect("lock poisoned")
                .push(model.to_string());
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token(self.response.clone())),
                Ok(LLMChunk::Done),
            ])))
        }
    }

    /// Builds an active, project-scoped candidate whose title equals its id.
    fn candidate(
        id: &str,
        score: f64,
        granularity: Option<TemporalGranularity>,
    ) -> MemoryRecallCandidate {
        MemoryRecallCandidate {
            id: id.to_string(),
            title: id.to_string(),
            score,
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            status: DurableMemoryStatus::Active,
            updated_at: "2026-04-09T00:00:00Z".to_string(),
            summary: "summary".to_string(),
            granularity,
        }
    }

    #[test]
    fn equal_score_candidates_sort_coarse_granularity_first_for_cache_stability() {
        // Scores and timestamps tie, so the granularity cache-stability rank decides order.
        let mut candidates = vec![
            candidate("day", 5.0, Some(TemporalGranularity::Day)),
            candidate("year", 5.0, Some(TemporalGranularity::Year)),
            candidate("none", 5.0, None),
            candidate("month", 5.0, Some(TemporalGranularity::Month)),
        ];
        sort_recall_candidates(&mut candidates);
        let order: Vec<&str> = candidates.iter().map(|c| c.id.as_str()).collect();
        assert_eq!(order, vec!["none", "year", "month", "day"]);
    }

    #[test]
    fn higher_score_still_wins_over_cache_stable_granularity() {
        // Score is the primary sort key; granularity only breaks ties.
        let mut candidates = vec![
            candidate("year-low", 1.0, Some(TemporalGranularity::Year)),
            candidate("day-high", 9.0, Some(TemporalGranularity::Day)),
        ];
        sort_recall_candidates(&mut candidates);
        assert_eq!(candidates[0].id, "day-high");
    }

    #[test]
    fn title_matches_outrank_keyword_only_matches() {
        // Title substring hits are weighted above exact keyword hits.
        let query_tokens = vec!["release".to_string(), "freeze".to_string()];
        let title_item = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let keyword_item = item(
            "b",
            "Deployment decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &["release", "freeze"],
            &[],
            &[],
            "summary",
        );

        let title_score = score_lexical_index_item(&title_item, &query_tokens).unwrap();
        let keyword_score = score_lexical_index_item(&keyword_item, &query_tokens).unwrap();
        assert!(title_score > keyword_score);
    }

    #[test]
    fn active_items_outrank_stale_items() {
        // Identical text, so only the stale-status penalty separates the scores.
        let query_tokens = vec!["release".to_string()];
        let active = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let stale = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Stale,
            "2026-04-10T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        let active_score = score_lexical_index_item(&active, &query_tokens).unwrap();
        let stale_score = score_lexical_index_item(&stale, &query_tokens).unwrap();
        assert!(active_score > stale_score);
    }

    #[test]
    fn contradicted_and_archived_items_are_filtered_out() {
        let query_tokens = vec!["release".to_string()];
        let contradicted = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Contradicted,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let archived = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Archived,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        assert!(score_lexical_index_item(&contradicted, &query_tokens).is_none());
        assert!(score_lexical_index_item(&archived, &query_tokens).is_none());
    }

    #[test]
    fn parse_reranked_ids_accepts_fenced_json_and_filters_unknown_ids() {
        let candidates = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
                granularity: None,
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
                granularity: None,
            },
        ];

        // A fenced reply containing an unknown id and a duplicate: both must be dropped.
        let parsed = parse_reranked_ids(
            "```json\n{\"ids\":[\"mem-b\",\"unknown\",\"mem-a\",\"mem-b\"]}\n```",
            &candidates,
        )
        .expect("reranked ids should parse");

        assert_eq!(parsed, vec!["mem-b".to_string(), "mem-a".to_string()]);
    }

    #[test]
    fn reorder_candidates_by_ids_appends_remaining_lexical_candidates() {
        let lexical = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
                granularity: None,
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
                granularity: None,
            },
            MemoryRecallCandidate {
                id: "mem-c".to_string(),
                title: "C".to_string(),
                score: 8.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary c".to_string(),
                granularity: None,
            },
        ];

        // The reranker names only two ids; the third slot is backfilled lexically.
        let reordered =
            reorder_candidates_by_ids(&lexical, &["mem-c".to_string(), "mem-a".to_string()], 3);

        assert_eq!(reordered[0].id, "mem-c");
        assert_eq!(reordered[1].id, "mem-a");
        assert_eq!(reordered[2].id, "mem-b");
    }

    #[tokio::test]
    async fn project_scope_shortlist_excludes_global_when_project_hits_exist() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze decision",
                "Project-specific release freeze note.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();
        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Global note that should not be used when project hits exist.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Project));
    }

    #[tokio::test]
    async fn global_fallback_triggers_only_when_project_hits_are_absent() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        // Only a global memory exists, and the queried project has no memories at all.
        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Fallback note for release work.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-missing"),
            "release guidance",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Global));
    }

    #[tokio::test]
    async fn model_rerank_reorders_lexical_shortlist_when_enabled() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();
        let reranked_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app and should be preferred for mobile freeze requests.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();

        // The stub model reply inverts the lexical order.
        let provider = StaticResponseProvider::new(format!(
            "{{\"ids\":[\"{}\",\"{}\"]}}",
            reranked_first.frontmatter.id, lexical_first.frontmatter.id
        ));
        let requested_models = provider.requested_models.clone();
        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(provider),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::Reranked);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, reranked_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_first.frontmatter.id);
        // The configured rerank model, not the main model, must be the one requested.
        assert_eq!(
            requested_models.lock().expect("lock poisoned").as_slice(),
            ["rerank-fast-model"]
        );
    }

    #[tokio::test]
    async fn invalid_model_rerank_response_falls_back_to_lexical_order() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();
        let lexical_second = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
                None,
            )
            .await
            .unwrap();

        // An unparseable reply must be absorbed and reported as `RerankFallback`.
        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(StaticResponseProvider::new("not valid json")),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::RerankFallback);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, lexical_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_second.frontmatter.id);
    }

    /// Provider stub that records the `max_output_tokens` and reasoning effort it receives
    /// through `chat_stream_with_options`, then replies with an empty id list.
    #[derive(Default)]
    struct RequestOptionsCaptureProvider {
        captured_max_tokens: Mutex<Vec<Option<u32>>>,
        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
    }

    #[async_trait]
    impl LLMProvider for RequestOptionsCaptureProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token("{\"ids\":[]}".to_string())),
                Ok(LLMChunk::Done),
            ])))
        }

        // Records the request options, then delegates to `chat_stream` for the reply.
        async fn chat_stream_with_options(
            &self,
            messages: &[Message],
            tools: &[bamboo_agent_core::ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
            options: Option<&LLMRequestOptions>,
        ) -> Result<LLMStream, LLMError> {
            self.captured_max_tokens
                .lock()
                .expect("lock should not be poisoned")
                .push(max_output_tokens);
            self.captured_reasoning
                .lock()
                .expect("lock should not be poisoned")
                .push(options.and_then(|o| o.reasoning_effort));
            self.chat_stream(messages, tools, max_output_tokens, model)
                .await
        }
    }

    #[tokio::test]
    async fn rerank_sufficient_max_tokens_for_high_reasoning() {
        let provider = Arc::new(RequestOptionsCaptureProvider::default());
        let candidates = vec![MemoryRecallCandidate {
            id: "mem-1".to_string(),
            score: 0.9,
            title: "Test memory".to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            status: DurableMemoryStatus::Active,
            updated_at: "2026-05-08T00:00:00Z".to_string(),
            summary: "A test durable memory entry".to_string(),
            granularity: None,
        }];
        let context = MemoryRecallRerankContext {
            llm: provider.clone(),
            model: "deepseek-v4-pro".to_string(),
            session_id: Some("test-session".to_string()),
        };

        // The stub replies with an empty id list, so the call returns an error; only the
        // captured request options matter here.
        let _ = rerank_candidate_ids("test query", &candidates, 5, &context).await;

        let captured_reasoning = provider
            .captured_reasoning
            .lock()
            .expect("lock should not be poisoned");
        let captured_max_tokens = provider
            .captured_max_tokens
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(
            captured_reasoning.as_slice(),
            [Some(ReasoningEffort::High)],
            "rerank should request High reasoning"
        );
        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
        assert!(
            max_tokens > 4096,
            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
            max_tokens
        );
    }
}