1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use futures::StreamExt;
7use serde::Deserialize;
8
9use bamboo_agent_core::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_infrastructure::{LLMChunk, LLMProvider, LLMRequestOptions};
12
13use super::{
14 extract_keywords, parse_rfc3339, DurableMemoryStatus, LexicalIndexItem, MemoryScope,
15 MemoryStore,
16};
17
/// A durable-memory entry surfaced as a recall candidate for a query.
///
/// Within this module, candidates are built in `shortlist_scope` by copying
/// fields from a lexical index item and attaching the item's lexical score.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRecallCandidate {
    /// Identifier of the memory; reranker output is matched against this value.
    pub id: String,
    /// Title of the memory.
    pub title: String,
    /// Relevance score; higher values sort first in `sort_recall_candidates`.
    pub score: f64,
    /// Scope copied from the index item the candidate was built from.
    pub scope: MemoryScope,
    /// Project key copied from the index item, if it had one.
    pub project_key: Option<String>,
    /// Status copied from the index item.
    pub status: DurableMemoryStatus,
    /// Last-update timestamp as a string; parsed with `parse_rfc3339` when
    /// breaking score ties during sorting.
    pub updated_at: String,
    /// Summary text of the memory.
    pub summary: String,
}
29
/// Tuning knobs for memory recall.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryRecallOptions {
    /// Maximum number of candidates handed back to the caller. The recall
    /// functions in this module treat values below 1 as 1.
    pub shortlist_limit: usize,
    /// Whether the global scope is searched when the project scope produced
    /// no hits (or no project key was supplied).
    pub include_global_fallback: bool,
    /// Upper bound on candidates kept per scope before any reranking. The
    /// recall functions raise it to at least the effective `shortlist_limit`.
    pub max_candidates_per_scope: usize,
}

impl Default for MemoryRecallOptions {
    /// Defaults: a shortlist of 3, global fallback enabled, and at most 20
    /// candidates per scope.
    fn default() -> Self {
        let shortlist_limit = 3;
        let max_candidates_per_scope = 20;
        MemoryRecallOptions {
            shortlist_limit,
            include_global_fallback: true,
            max_candidates_per_scope,
        }
    }
}
46
/// Describes how a recall selection's final ordering was produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryRecallStrategy {
    /// Lexical ordering only; no rerank was attempted.
    Lexical,
    /// The reranker's ordering was applied.
    Reranked,
    /// A rerank was attempted but the lexical ordering was used instead.
    RerankFallback,
}

impl MemoryRecallStrategy {
    /// Returns the snake_case label for this strategy.
    pub fn as_str(self) -> &'static str {
        match self {
            MemoryRecallStrategy::Lexical => "lexical",
            MemoryRecallStrategy::Reranked => "reranked",
            MemoryRecallStrategy::RerankFallback => "rerank_fallback",
        }
    }
}
63
/// Result of `select_relevant_memories`: the chosen candidates plus the
/// strategy that produced their order.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRecallSelection {
    /// Selected candidates, best first.
    pub candidates: Vec<MemoryRecallCandidate>,
    /// How the ordering of `candidates` was obtained.
    pub strategy: MemoryRecallStrategy,
}
69
/// Everything needed to ask an LLM to rerank a lexical shortlist.
#[derive(Clone)]
pub struct MemoryRecallRerankContext {
    /// Provider used to issue the rerank request.
    pub llm: Arc<dyn LLMProvider>,
    /// Model name passed to the provider; it is trimmed before use and a
    /// blank value makes the rerank fail with "rerank model is empty".
    pub model: String,
    /// Session id forwarded in the request options, if any.
    pub session_id: Option<String>,
}
76
/// JSON envelope the reranker is asked to return: `{"ids": [...]}`.
#[derive(Debug, Deserialize)]
struct MemoryRecallRerankEnvelope {
    /// Candidate ids in the reranker's preferred order. Defaults to an empty
    /// list when the `ids` key is absent from the object.
    #[serde(default)]
    ids: Vec<String>,
}
82
83pub async fn shortlist_relevant_memories(
84 store: &MemoryStore,
85 project_key: Option<&str>,
86 query: &str,
87 options: &MemoryRecallOptions,
88) -> io::Result<Vec<MemoryRecallCandidate>> {
89 let limit = options.shortlist_limit.max(1);
90 let mut candidates =
91 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
92 candidates.truncate(limit);
93 Ok(candidates)
94}
95
96pub async fn select_relevant_memories(
97 store: &MemoryStore,
98 project_key: Option<&str>,
99 query: &str,
100 options: &MemoryRecallOptions,
101 rerank_context: Option<&MemoryRecallRerankContext>,
102) -> io::Result<MemoryRecallSelection> {
103 let query = query.trim();
104 if query.is_empty() {
105 return Ok(MemoryRecallSelection {
106 candidates: Vec::new(),
107 strategy: MemoryRecallStrategy::Lexical,
108 });
109 }
110
111 let limit = options.shortlist_limit.max(1);
112 let mut shortlist =
113 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
114 if shortlist.is_empty() {
115 return Ok(MemoryRecallSelection {
116 candidates: shortlist,
117 strategy: MemoryRecallStrategy::Lexical,
118 });
119 }
120
121 let Some(rerank_context) = rerank_context else {
122 shortlist.truncate(limit);
123 return Ok(MemoryRecallSelection {
124 candidates: shortlist,
125 strategy: MemoryRecallStrategy::Lexical,
126 });
127 };
128
129 if shortlist.len() <= 1 {
130 shortlist.truncate(limit);
131 return Ok(MemoryRecallSelection {
132 candidates: shortlist,
133 strategy: MemoryRecallStrategy::Lexical,
134 });
135 }
136
137 match rerank_candidate_ids(query, &shortlist, limit, rerank_context).await {
138 Ok(ids) => {
139 let reranked = reorder_candidates_by_ids(&shortlist, &ids, limit);
140 if reranked.is_empty() {
141 let mut lexical = shortlist;
142 lexical.truncate(limit);
143 return Ok(MemoryRecallSelection {
144 candidates: lexical,
145 strategy: MemoryRecallStrategy::RerankFallback,
146 });
147 }
148 Ok(MemoryRecallSelection {
149 candidates: reranked,
150 strategy: MemoryRecallStrategy::Reranked,
151 })
152 }
153 Err(error) => {
154 tracing::warn!(
155 "Relevant memory rerank failed for model '{}': {}. Falling back to lexical shortlist.",
156 rerank_context.model,
157 error
158 );
159 shortlist.truncate(limit);
160 Ok(MemoryRecallSelection {
161 candidates: shortlist,
162 strategy: MemoryRecallStrategy::RerankFallback,
163 })
164 }
165 }
166}
167
168async fn lexical_shortlist_relevant_memories(
169 store: &MemoryStore,
170 project_key: Option<&str>,
171 query: &str,
172 options: &MemoryRecallOptions,
173) -> io::Result<Vec<MemoryRecallCandidate>> {
174 let query = query.trim();
175 if query.is_empty() {
176 return Ok(Vec::new());
177 }
178
179 let limit = options.shortlist_limit.max(1);
180 let per_scope_limit = options.max_candidates_per_scope.max(limit);
181
182 if let Some(project_key) = project_key.map(str::trim).filter(|value| !value.is_empty()) {
183 let mut project_hits =
184 shortlist_scope(store, MemoryScope::Project, Some(project_key), query).await?;
185 project_hits.truncate(per_scope_limit);
186 if !project_hits.is_empty() {
187 return Ok(project_hits);
188 }
189 }
190
191 if options.include_global_fallback {
192 let mut global_hits = shortlist_scope(store, MemoryScope::Global, None, query).await?;
193 global_hits.truncate(per_scope_limit);
194 return Ok(global_hits);
195 }
196
197 Ok(Vec::new())
198}
199
200async fn shortlist_scope(
201 store: &MemoryStore,
202 scope: MemoryScope,
203 project_key: Option<&str>,
204 query: &str,
205) -> io::Result<Vec<MemoryRecallCandidate>> {
206 let Some(index) = store.read_lexical_index(scope, project_key).await? else {
207 return Ok(Vec::new());
208 };
209
210 let query_tokens = extract_keywords(query, "", &[]);
211 if query_tokens.is_empty() {
212 return Ok(Vec::new());
213 }
214
215 let mut candidates = index
216 .items
217 .iter()
218 .filter_map(|item| score_lexical_index_item(item, &query_tokens).map(|score| (item, score)))
219 .map(|(item, score)| MemoryRecallCandidate {
220 id: item.id.clone(),
221 title: item.title.clone(),
222 score,
223 scope: item.scope,
224 project_key: item.project_key.clone(),
225 status: item.status,
226 updated_at: item.updated_at.clone(),
227 summary: item.summary.clone(),
228 })
229 .collect::<Vec<_>>();
230
231 sort_recall_candidates(&mut candidates);
232 Ok(candidates)
233}
234
235fn score_lexical_index_item(item: &LexicalIndexItem, query_tokens: &[String]) -> Option<f64> {
236 match item.status {
237 DurableMemoryStatus::Superseded
238 | DurableMemoryStatus::Contradicted
239 | DurableMemoryStatus::Archived => return None,
240 DurableMemoryStatus::Active | DurableMemoryStatus::Stale => {}
241 }
242
243 let title = item.title.to_ascii_lowercase();
244 let summary = item.summary.to_ascii_lowercase();
245
246 let mut score = 0.0;
247 let mut matched_any = false;
248
249 for token in query_tokens {
250 let mut token_score = 0.0;
251 if title.contains(token) {
252 token_score += 3.0;
253 }
254 if item
255 .keywords
256 .iter()
257 .any(|value| value.eq_ignore_ascii_case(token))
258 {
259 token_score += 2.5;
260 }
261 if item
262 .tags
263 .iter()
264 .any(|value| value.eq_ignore_ascii_case(token))
265 {
266 token_score += 2.0;
267 }
268 if item
269 .entities
270 .iter()
271 .any(|value| value.eq_ignore_ascii_case(token))
272 {
273 token_score += 1.5;
274 }
275 if summary.contains(token) {
276 token_score += 1.0;
277 }
278 if token_score > 0.0 {
279 matched_any = true;
280 score += token_score;
281 }
282 }
283
284 if !matched_any {
285 return None;
286 }
287
288 score += lexical_status_adjustment(item.status);
289 Some((score / query_tokens.len() as f64 * 100.0).round() / 100.0)
290}
291
292fn lexical_status_adjustment(status: DurableMemoryStatus) -> f64 {
293 match status {
294 DurableMemoryStatus::Active => 0.0,
295 DurableMemoryStatus::Stale => -0.75,
296 DurableMemoryStatus::Superseded
297 | DurableMemoryStatus::Contradicted
298 | DurableMemoryStatus::Archived => -10.0,
299 }
300}
301
302fn sort_recall_candidates(candidates: &mut [MemoryRecallCandidate]) {
303 candidates.sort_by(|left, right| {
304 right
305 .score
306 .partial_cmp(&left.score)
307 .unwrap_or(Ordering::Equal)
308 .then_with(|| {
309 let left_dt = parse_rfc3339(&left.updated_at)
310 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
311 let right_dt = parse_rfc3339(&right.updated_at)
312 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
313 right_dt.cmp(&left_dt)
314 })
315 .then_with(|| left.title.cmp(&right.title))
316 });
317}
318
319fn build_rerank_prompt(query: &str, candidates: &[MemoryRecallCandidate], limit: usize) -> String {
320 let mut prompt = String::from("# Bamboo Relevant Memory Recall Rerank\n\n");
321 prompt.push_str(
322 "Select the durable memory candidates that are most relevant to the user query.\n",
323 );
324 prompt.push_str("Return JSON only in the form {\"ids\":[\"candidate-id\", ...]}.\n");
325 prompt
326 .push_str("Do not include commentary, markdown fences, explanations, or unknown ids.\n\n");
327 prompt.push_str("## User query\n");
328 prompt.push_str(query.trim());
329 prompt.push_str("\n\n## Candidate memories\n");
330
331 for (index, candidate) in candidates.iter().enumerate() {
332 prompt.push_str(&format!(
333 "{}. id={}\n title: {}\n scope: {}\n status: {}\n updated_at: {}\n lexical_score: {:.2}\n summary: {}\n",
334 index + 1,
335 candidate.id,
336 candidate.title,
337 candidate.scope.as_str(),
338 candidate.status.as_str(),
339 candidate.updated_at,
340 candidate.score,
341 candidate.summary.replace('\n', " "),
342 ));
343 }
344
345 prompt.push_str(&format!(
346 "\n## Selection rules\n- Return at most {limit} ids.\n- Use only ids from the candidate list above.\n- Prefer candidates that best answer the user query or encode active preferences/constraints relevant to it.\n- Prefer active memories over stale ones when relevance is otherwise similar.\n- Keep the ids ordered best-to-worst.\n"
347 ));
348 prompt
349}
350
351async fn rerank_candidate_ids(
352 query: &str,
353 candidates: &[MemoryRecallCandidate],
354 limit: usize,
355 context: &MemoryRecallRerankContext,
356) -> Result<Vec<String>, String> {
357 let model = context.model.trim();
358 if model.is_empty() {
359 return Err("rerank model is empty".to_string());
360 }
361
362 let messages = vec![
363 Message::system(
364 "You rerank Bamboo durable-memory recall candidates. Return strict JSON only in the form {\"ids\":[...]} using only candidate ids from the prompt.",
365 ),
366 Message::user(build_rerank_prompt(query, candidates, limit)),
367 ];
368 let options = LLMRequestOptions {
369 session_id: context.session_id.clone(),
370 reasoning_effort: Some(ReasoningEffort::High),
371 parallel_tool_calls: None,
372 responses: None,
373 request_purpose: Some("memory_rerank".to_string()),
374 };
375
376 let mut stream = context
377 .llm
378 .chat_stream_with_options(&messages, &[], Some(8192), model, Some(&options))
379 .await
380 .map_err(|error| format!("rerank provider call failed: {error}"))?;
381
382 let content = tokio::time::timeout(std::time::Duration::from_secs(30), async {
383 let mut content = String::new();
384 while let Some(chunk_result) = stream.next().await {
385 match chunk_result {
386 Ok(LLMChunk::Token(text)) => content.push_str(&text),
387 Ok(LLMChunk::Done) => break,
388 Ok(_) => {}
389 Err(error) => {
390 if !content.trim().is_empty() {
391 break;
392 }
393 return Err(format!("rerank stream failed: {error}"));
394 }
395 }
396 }
397 Ok(content)
398 })
399 .await
400 .unwrap_or_else(|_| Err("rerank timed out after 30s".to_string()))?;
401
402 parse_reranked_ids(&content, candidates)
403 .ok_or_else(|| format!("failed to parse rerank response: {}", content.trim()))
404}
405
406fn reorder_candidates_by_ids(
407 lexical_candidates: &[MemoryRecallCandidate],
408 preferred_ids: &[String],
409 limit: usize,
410) -> Vec<MemoryRecallCandidate> {
411 if lexical_candidates.is_empty() || limit == 0 {
412 return Vec::new();
413 }
414
415 let allowed = lexical_candidates
416 .iter()
417 .map(|candidate| candidate.id.as_str())
418 .collect::<HashSet<_>>();
419 let mut seen = HashSet::new();
420 let mut ordered = Vec::new();
421
422 for id in preferred_ids {
423 let trimmed = id.trim();
424 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
425 continue;
426 }
427 if let Some(candidate) = lexical_candidates
428 .iter()
429 .find(|candidate| candidate.id == trimmed)
430 .cloned()
431 {
432 ordered.push(candidate);
433 if ordered.len() >= limit {
434 return ordered;
435 }
436 }
437 }
438
439 for candidate in lexical_candidates {
440 if seen.insert(candidate.id.clone()) {
441 ordered.push(candidate.clone());
442 if ordered.len() >= limit {
443 break;
444 }
445 }
446 }
447
448 ordered
449}
450
451fn parse_reranked_ids(raw: &str, candidates: &[MemoryRecallCandidate]) -> Option<Vec<String>> {
452 let stripped = strip_markdown_fence(raw);
453 let fragment = extract_json_fragment(&stripped).unwrap_or(stripped.trim());
454 let ids = serde_json::from_str::<MemoryRecallRerankEnvelope>(fragment)
455 .map(|value| value.ids)
456 .or_else(|_| serde_json::from_str::<Vec<String>>(fragment))
457 .ok()?;
458
459 let allowed = candidates
460 .iter()
461 .map(|candidate| candidate.id.as_str())
462 .collect::<HashSet<_>>();
463 let mut seen = HashSet::new();
464 let mut out = Vec::new();
465
466 for id in ids {
467 let trimmed = id.trim();
468 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
469 continue;
470 }
471 out.push(trimmed.to_string());
472 }
473
474 (!out.is_empty()).then_some(out)
475}
476
/// Removes a surrounding markdown code fence from `raw`, if present.
///
/// Four-backtick fences are tried before three-backtick ones. For a fence to
/// be stripped the trimmed text must start with it, the opening line (fence
/// plus any info string) must end with a newline, and the same fence must
/// occur again in the remainder; the text between the opening line and the
/// last such occurrence is returned trimmed. Otherwise the trimmed input is
/// returned unchanged.
fn strip_markdown_fence(raw: &str) -> String {
    let text = raw.trim();
    ["````", "```"]
        .iter()
        .find_map(|&marker| {
            let after_marker = text.strip_prefix(marker)?;
            // Discard the rest of the opening line, e.g. a `json` info string.
            let (_, body) = after_marker.split_once('\n')?;
            let closing = body.rfind(marker)?;
            Some(body[..closing].trim().to_string())
        })
        .unwrap_or_else(|| text.to_string())
}
492
/// Locates the outermost JSON-looking fragment inside `raw`.
///
/// Looks first for the span from the first `{` to the last `}`, then for the
/// span from the first `[` to the last `]`; the first pair whose opener
/// precedes its closer wins and the trimmed span is returned. Returns `None`
/// for blank input or when neither pair is found in that order. The fragment
/// is not validated as JSON.
fn extract_json_fragment(raw: &str) -> Option<&str> {
    let text = raw.trim();
    if text.is_empty() {
        return None;
    }

    [('{', '}'), ('[', ']')].iter().find_map(|&(open, close)| {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        if start <= end {
            Some(text[start..=end].trim())
        } else {
            None
        }
    })
}
513
514#[cfg(test)]
515mod tests {
516 use super::*;
517 use crate::memory_store::DurableMemoryType;
518 use async_trait::async_trait;
519 use bamboo_domain::ReasoningEffort;
520 use bamboo_infrastructure::llm::provider::LLMRequestOptions;
521 use bamboo_infrastructure::{LLMChunk, LLMError, LLMProvider, LLMStream};
522 use futures::stream;
523 use std::sync::Mutex;
524 use tempfile::tempdir;
525
    /// Builds a project-scoped `LexicalIndexItem` fixture for project key
    /// `proj-1`. `created_at` is set to the same value as `updated_at`.
    fn item(
        id: &str,
        title: &str,
        status: DurableMemoryStatus,
        updated_at: &str,
        keywords: &[&str],
        tags: &[&str],
        entities: &[&str],
        summary: &str,
    ) -> LexicalIndexItem {
        LexicalIndexItem {
            id: id.to_string(),
            title: title.to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            r#type: DurableMemoryType::Project,
            status,
            tags: tags.iter().map(|v| v.to_string()).collect(),
            keywords: keywords.iter().map(|v| v.to_string()).collect(),
            entities: entities.iter().map(|v| v.to_string()).collect(),
            updated_at: updated_at.to_string(),
            created_at: updated_at.to_string(),
            summary: summary.to_string(),
        }
    }
551
    /// Test provider that always streams the same canned response and records
    /// the model name of every request it receives.
    #[derive(Clone)]
    struct StaticResponseProvider {
        /// Text emitted as the single token of every response.
        response: String,
        /// Model names seen so far; shared so a test can inspect them after
        /// the provider has been moved into an `Arc`.
        requested_models: Arc<Mutex<Vec<String>>>,
    }

    impl StaticResponseProvider {
        /// Creates a provider that answers every request with `response`.
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
                requested_models: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }
566
    #[async_trait]
    impl LLMProvider for StaticResponseProvider {
        /// Records `model`, then returns a stream with the canned response as
        /// one token followed by `Done`. Messages, tools, and the token limit
        /// are ignored.
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream, LLMError> {
            self.requested_models
                .lock()
                .expect("lock poisoned")
                .push(model.to_string());
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token(self.response.clone())),
                Ok(LLMChunk::Done),
            ])))
        }
    }
586
    /// An item whose title contains both query tokens must score higher than
    /// an item that only lists the same tokens as keywords.
    #[test]
    fn title_matches_outrank_keyword_only_matches() {
        let query_tokens = vec!["release".to_string(), "freeze".to_string()];
        let title_item = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let keyword_item = item(
            "b",
            "Deployment decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &["release", "freeze"],
            &[],
            &[],
            "summary",
        );

        let title_score = score_lexical_index_item(&title_item, &query_tokens).unwrap();
        let keyword_score = score_lexical_index_item(&keyword_item, &query_tokens).unwrap();
        assert!(title_score > keyword_score);
    }
615
    /// Two items with the same title but different statuses: the active one
    /// must score higher than the stale one, even though the stale item has
    /// the newer `updated_at`.
    #[test]
    fn active_items_outrank_stale_items() {
        let query_tokens = vec!["release".to_string()];
        let active = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let stale = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Stale,
            "2026-04-10T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        let active_score = score_lexical_index_item(&active, &query_tokens).unwrap();
        let stale_score = score_lexical_index_item(&stale, &query_tokens).unwrap();
        assert!(active_score > stale_score);
    }
644
    /// Contradicted and archived items must produce no score at all, even
    /// when their title matches the query token.
    #[test]
    fn contradicted_and_archived_items_are_filtered_out() {
        let query_tokens = vec!["release".to_string()];
        let contradicted = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Contradicted,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let archived = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Archived,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        assert!(score_lexical_index_item(&contradicted, &query_tokens).is_none());
        assert!(score_lexical_index_item(&archived, &query_tokens).is_none());
    }
672
    /// A fenced JSON envelope is accepted; the unknown id and the repeated id
    /// are dropped while the model's order is kept.
    #[test]
    fn parse_reranked_ids_accepts_fenced_json_and_filters_unknown_ids() {
        let candidates = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
            },
        ];

        let parsed = parse_reranked_ids(
            "```json\n{\"ids\":[\"mem-b\",\"unknown\",\"mem-a\",\"mem-b\"]}\n```",
            &candidates,
        )
        .expect("reranked ids should parse");

        assert_eq!(parsed, vec!["mem-b".to_string(), "mem-a".to_string()]);
    }
706
    /// Preferred ids come first in the given order; the candidate that was
    /// not mentioned (`mem-b`) is appended afterwards.
    #[test]
    fn reorder_candidates_by_ids_appends_remaining_lexical_candidates() {
        let lexical = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-c".to_string(),
                title: "C".to_string(),
                score: 8.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary c".to_string(),
            },
        ];

        let reordered =
            reorder_candidates_by_ids(&lexical, &["mem-c".to_string(), "mem-a".to_string()], 3);

        assert_eq!(reordered[0].id, "mem-c");
        assert_eq!(reordered[1].id, "mem-a");
        assert_eq!(reordered[2].id, "mem-b");
    }
749
    /// With one matching memory written to the project scope and one to the
    /// global scope, the shortlist for that project must be non-empty and
    /// contain only project-scoped candidates.
    #[tokio::test]
    async fn project_scope_shortlist_excludes_global_when_project_hits_exist() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze decision",
                "Project-specific release freeze note.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Global note that should not be used when project hits exist.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Project));
    }
798
    /// Only a global memory exists and the query names a project key nothing
    /// was written under; with default options the shortlist must fall back
    /// to the global scope.
    #[tokio::test]
    async fn global_fallback_triggers_only_when_project_hits_are_absent() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Fallback note for release work.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-missing"),
            "release guidance",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Global));
    }
833
    /// The canned model response lists the second-written memory first; the
    /// selection must follow that order, report `Reranked`, and the provider
    /// must have been called exactly once with the configured rerank model.
    #[tokio::test]
    async fn model_rerank_reorders_lexical_shortlist_when_enabled() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        let reranked_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app and should be preferred for mobile freeze requests.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let provider = StaticResponseProvider::new(format!(
            "{{\"ids\":[\"{}\",\"{}\"]}}",
            reranked_first.frontmatter.id, lexical_first.frontmatter.id
        ));
        // Keep a handle on the recorder before the provider moves into the Arc.
        let requested_models = provider.requested_models.clone();
        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(provider),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::Reranked);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, reranked_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_first.frontmatter.id);
        assert_eq!(
            requested_models.lock().expect("lock poisoned").as_slice(),
            ["rerank-fast-model"]
        );
    }
900
    /// When the model answers with text that is not JSON, the selection must
    /// keep the lexical order and report `RerankFallback`.
    #[tokio::test]
    async fn invalid_model_rerank_response_falls_back_to_lexical_order() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        let lexical_second = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(StaticResponseProvider::new("not valid json")),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::RerankFallback);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, lexical_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_second.frontmatter.id);
    }
958
    /// Test provider that records the request parameters passed to
    /// `chat_stream_with_options`.
    #[derive(Default)]
    struct RequestOptionsCaptureProvider {
        /// `max_output_tokens` value of each captured request.
        captured_max_tokens: Mutex<Vec<Option<u32>>>,
        /// Reasoning effort of each captured request (`None` when no options
        /// were passed or no effort was set).
        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
    }
965
    #[async_trait]
    impl LLMProvider for RequestOptionsCaptureProvider {
        /// Returns a stream with an empty-ids JSON envelope as one token
        /// followed by `Done`; all arguments are ignored.
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token("{\"ids\":[]}".to_string())),
                Ok(LLMChunk::Done),
            ])))
        }

        /// Records the token limit and reasoning effort of the request, then
        /// delegates to `chat_stream` for the response.
        async fn chat_stream_with_options(
            &self,
            messages: &[Message],
            tools: &[bamboo_agent_core::ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
            options: Option<&LLMRequestOptions>,
        ) -> Result<LLMStream, LLMError> {
            self.captured_max_tokens
                .lock()
                .expect("lock should not be poisoned")
                .push(max_output_tokens);
            self.captured_reasoning
                .lock()
                .expect("lock should not be poisoned")
                .push(options.and_then(|o| o.reasoning_effort));
            self.chat_stream(messages, tools, max_output_tokens, model)
                .await
        }
    }
1001
    /// The rerank request must ask for high reasoning effort and set an
    /// output-token limit above 4096. Only the captured request parameters
    /// are checked; the rerank result itself is discarded.
    #[tokio::test]
    async fn rerank_sufficient_max_tokens_for_high_reasoning() {
        let provider = Arc::new(RequestOptionsCaptureProvider::default());
        let candidates = vec![MemoryRecallCandidate {
            id: "mem-1".to_string(),
            score: 0.9,
            title: "Test memory".to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            status: DurableMemoryStatus::Active,
            updated_at: "2026-05-08T00:00:00Z".to_string(),
            summary: "A test durable memory entry".to_string(),
        }];
        let context = MemoryRecallRerankContext {
            llm: provider.clone(),
            model: "deepseek-v4-pro".to_string(),
            session_id: Some("test-session".to_string()),
        };

        // The canned response carries no ids, so the outcome is ignored.
        let _ = rerank_candidate_ids("test query", &candidates, 5, &context).await;

        let captured_reasoning = provider
            .captured_reasoning
            .lock()
            .expect("lock should not be poisoned");
        let captured_max_tokens = provider
            .captured_max_tokens
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(
            captured_reasoning.as_slice(),
            [Some(ReasoningEffort::High)],
            "rerank should request High reasoning"
        );
        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
        assert!(
            max_tokens > 4096,
            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
            max_tokens
        );
    }
1043}