1use std::cmp::Ordering;
2use std::collections::HashSet;
3use std::io;
4use std::sync::Arc;
5
6use futures::StreamExt;
7use serde::Deserialize;
8
9use bamboo_agent_core::Message;
10use bamboo_domain::ReasoningEffort;
11use bamboo_infrastructure::{LLMChunk, LLMProvider, LLMRequestOptions};
12
13use super::{
14 extract_keywords, parse_rfc3339, DurableMemoryStatus, LexicalIndexItem, MemoryScope,
15 MemoryStore,
16};
17
/// A durable-memory entry that matched a recall query.
///
/// Candidates are built from lexical index entries: every field except `score` is copied
/// verbatim from the index entry.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRecallCandidate {
    /// Memory identifier; the rerank model is asked to answer with these ids.
    pub id: String,
    /// Human-readable memory title.
    pub title: String,
    /// Lexical relevance score (higher is better), rounded to two decimal places.
    pub score: f64,
    /// Scope the memory was recalled from.
    pub scope: MemoryScope,
    /// Project the memory belongs to, if any.
    pub project_key: Option<String>,
    /// Lifecycle status; only active and stale memories are ever shortlisted.
    pub status: DurableMemoryStatus,
    /// Last-update timestamp, parsed as RFC 3339 for tie-breaking; unparseable values sort
    /// as the oldest.
    pub updated_at: String,
    /// Short summary of the memory.
    pub summary: String,
}
29
/// Tuning knobs for durable-memory recall.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryRecallOptions {
    /// Maximum number of memories returned; values below 1 are treated as 1.
    pub shortlist_limit: usize,
    /// Search the global scope when no usable project key is given or the project scope
    /// produces no hits.
    pub include_global_fallback: bool,
    /// Cap on lexical hits kept from a single scope (raised to `shortlist_limit` if smaller).
    /// This pool is what the optional model rerank chooses from.
    pub max_candidates_per_scope: usize,
}

impl Default for MemoryRecallOptions {
    /// Three results, global fallback enabled, at most twenty hits kept per scope.
    fn default() -> Self {
        const SHORTLIST_LIMIT: usize = 3;
        const INCLUDE_GLOBAL_FALLBACK: bool = true;
        const MAX_CANDIDATES_PER_SCOPE: usize = 20;

        Self {
            max_candidates_per_scope: MAX_CANDIDATES_PER_SCOPE,
            include_global_fallback: INCLUDE_GLOBAL_FALLBACK,
            shortlist_limit: SHORTLIST_LIMIT,
        }
    }
}
46
/// Records which ranking path produced a memory selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryRecallStrategy {
    /// Pure lexical ranking; no model rerank was attempted.
    Lexical,
    /// Order chosen by the rerank model.
    Reranked,
    /// A rerank was attempted but failed or was unusable, so lexical order was kept.
    RerankFallback,
}

impl MemoryRecallStrategy {
    /// Stable snake_case label for this strategy.
    pub fn as_str(self) -> &'static str {
        match self {
            MemoryRecallStrategy::Lexical => "lexical",
            MemoryRecallStrategy::RerankFallback => "rerank_fallback",
            MemoryRecallStrategy::Reranked => "reranked",
        }
    }
}
63
/// Result of `select_relevant_memories`: the chosen memories and how they were ranked.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRecallSelection {
    /// Selected memories, best first; never more than `shortlist_limit` (minimum 1).
    pub candidates: Vec<MemoryRecallCandidate>,
    /// Ranking path that produced `candidates`.
    pub strategy: MemoryRecallStrategy,
}
69
/// Everything needed to ask an LLM to rerank the lexical shortlist.
#[derive(Clone)]
pub struct MemoryRecallRerankContext {
    /// Provider used for the rerank call.
    pub llm: Arc<dyn LLMProvider>,
    /// Model name passed to the provider (trimmed first); a blank name makes the rerank
    /// fail, and selection then keeps the lexical order.
    pub model: String,
    /// Session id forwarded in the rerank request options, if any.
    pub session_id: Option<String>,
}
76
/// JSON shape requested from the rerank model: `{"ids": [...]}`.
#[derive(Debug, Deserialize)]
struct MemoryRecallRerankEnvelope {
    /// Candidate ids, requested best-to-worst; a missing field deserializes as empty.
    #[serde(default)]
    ids: Vec<String>,
}
82
83pub async fn shortlist_relevant_memories(
84 store: &MemoryStore,
85 project_key: Option<&str>,
86 query: &str,
87 options: &MemoryRecallOptions,
88) -> io::Result<Vec<MemoryRecallCandidate>> {
89 let limit = options.shortlist_limit.max(1);
90 let mut candidates =
91 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
92 candidates.truncate(limit);
93 Ok(candidates)
94}
95
96pub async fn select_relevant_memories(
97 store: &MemoryStore,
98 project_key: Option<&str>,
99 query: &str,
100 options: &MemoryRecallOptions,
101 rerank_context: Option<&MemoryRecallRerankContext>,
102) -> io::Result<MemoryRecallSelection> {
103 let query = query.trim();
104 if query.is_empty() {
105 return Ok(MemoryRecallSelection {
106 candidates: Vec::new(),
107 strategy: MemoryRecallStrategy::Lexical,
108 });
109 }
110
111 let limit = options.shortlist_limit.max(1);
112 let mut shortlist =
113 lexical_shortlist_relevant_memories(store, project_key, query, options).await?;
114 if shortlist.is_empty() {
115 return Ok(MemoryRecallSelection {
116 candidates: shortlist,
117 strategy: MemoryRecallStrategy::Lexical,
118 });
119 }
120
121 let Some(rerank_context) = rerank_context else {
122 shortlist.truncate(limit);
123 return Ok(MemoryRecallSelection {
124 candidates: shortlist,
125 strategy: MemoryRecallStrategy::Lexical,
126 });
127 };
128
129 if shortlist.len() <= 1 {
130 shortlist.truncate(limit);
131 return Ok(MemoryRecallSelection {
132 candidates: shortlist,
133 strategy: MemoryRecallStrategy::Lexical,
134 });
135 }
136
137 match rerank_candidate_ids(query, &shortlist, limit, rerank_context).await {
138 Ok(ids) => {
139 let reranked = reorder_candidates_by_ids(&shortlist, &ids, limit);
140 if reranked.is_empty() {
141 let mut lexical = shortlist;
142 lexical.truncate(limit);
143 return Ok(MemoryRecallSelection {
144 candidates: lexical,
145 strategy: MemoryRecallStrategy::RerankFallback,
146 });
147 }
148 Ok(MemoryRecallSelection {
149 candidates: reranked,
150 strategy: MemoryRecallStrategy::Reranked,
151 })
152 }
153 Err(error) => {
154 tracing::warn!(
155 "Relevant memory rerank failed for model '{}': {}. Falling back to lexical shortlist.",
156 rerank_context.model,
157 error
158 );
159 shortlist.truncate(limit);
160 Ok(MemoryRecallSelection {
161 candidates: shortlist,
162 strategy: MemoryRecallStrategy::RerankFallback,
163 })
164 }
165 }
166}
167
168async fn lexical_shortlist_relevant_memories(
169 store: &MemoryStore,
170 project_key: Option<&str>,
171 query: &str,
172 options: &MemoryRecallOptions,
173) -> io::Result<Vec<MemoryRecallCandidate>> {
174 let query = query.trim();
175 if query.is_empty() {
176 return Ok(Vec::new());
177 }
178
179 let limit = options.shortlist_limit.max(1);
180 let per_scope_limit = options.max_candidates_per_scope.max(limit);
181
182 if let Some(project_key) = project_key.map(str::trim).filter(|value| !value.is_empty()) {
183 let mut project_hits =
184 shortlist_scope(store, MemoryScope::Project, Some(project_key), query).await?;
185 project_hits.truncate(per_scope_limit);
186 if !project_hits.is_empty() {
187 return Ok(project_hits);
188 }
189 }
190
191 if options.include_global_fallback {
192 let mut global_hits = shortlist_scope(store, MemoryScope::Global, None, query).await?;
193 global_hits.truncate(per_scope_limit);
194 return Ok(global_hits);
195 }
196
197 Ok(Vec::new())
198}
199
200async fn shortlist_scope(
201 store: &MemoryStore,
202 scope: MemoryScope,
203 project_key: Option<&str>,
204 query: &str,
205) -> io::Result<Vec<MemoryRecallCandidate>> {
206 let Some(index) = store.read_lexical_index(scope, project_key).await? else {
207 return Ok(Vec::new());
208 };
209
210 let query_tokens = extract_keywords(query, "", &[]);
211 if query_tokens.is_empty() {
212 return Ok(Vec::new());
213 }
214
215 let mut candidates = index
216 .items
217 .iter()
218 .filter_map(|item| score_lexical_index_item(item, &query_tokens).map(|score| (item, score)))
219 .map(|(item, score)| MemoryRecallCandidate {
220 id: item.id.clone(),
221 title: item.title.clone(),
222 score,
223 scope: item.scope,
224 project_key: item.project_key.clone(),
225 status: item.status,
226 updated_at: item.updated_at.clone(),
227 summary: item.summary.clone(),
228 })
229 .collect::<Vec<_>>();
230
231 sort_recall_candidates(&mut candidates);
232 Ok(candidates)
233}
234
235fn score_lexical_index_item(item: &LexicalIndexItem, query_tokens: &[String]) -> Option<f64> {
236 match item.status {
237 DurableMemoryStatus::Superseded
238 | DurableMemoryStatus::Contradicted
239 | DurableMemoryStatus::Archived => return None,
240 DurableMemoryStatus::Active | DurableMemoryStatus::Stale => {}
241 }
242
243 let title = item.title.to_ascii_lowercase();
244 let summary = item.summary.to_ascii_lowercase();
245
246 let mut score = 0.0;
247 let mut matched_any = false;
248
249 for token in query_tokens {
250 let mut token_score = 0.0;
251 if title.contains(token) {
252 token_score += 3.0;
253 }
254 if item
255 .keywords
256 .iter()
257 .any(|value| value.eq_ignore_ascii_case(token))
258 {
259 token_score += 2.5;
260 }
261 if item
262 .tags
263 .iter()
264 .any(|value| value.eq_ignore_ascii_case(token))
265 {
266 token_score += 2.0;
267 }
268 if item
269 .entities
270 .iter()
271 .any(|value| value.eq_ignore_ascii_case(token))
272 {
273 token_score += 1.5;
274 }
275 if summary.contains(token) {
276 token_score += 1.0;
277 }
278 if token_score > 0.0 {
279 matched_any = true;
280 score += token_score;
281 }
282 }
283
284 if !matched_any {
285 return None;
286 }
287
288 score += lexical_status_adjustment(item.status);
289 Some((score / query_tokens.len() as f64 * 100.0).round() / 100.0)
290}
291
292fn lexical_status_adjustment(status: DurableMemoryStatus) -> f64 {
293 match status {
294 DurableMemoryStatus::Active => 0.0,
295 DurableMemoryStatus::Stale => -0.75,
296 DurableMemoryStatus::Superseded
297 | DurableMemoryStatus::Contradicted
298 | DurableMemoryStatus::Archived => -10.0,
299 }
300}
301
302fn sort_recall_candidates(candidates: &mut [MemoryRecallCandidate]) {
303 candidates.sort_by(|left, right| {
304 right
305 .score
306 .partial_cmp(&left.score)
307 .unwrap_or(Ordering::Equal)
308 .then_with(|| {
309 let left_dt = parse_rfc3339(&left.updated_at)
310 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
311 let right_dt = parse_rfc3339(&right.updated_at)
312 .unwrap_or(chrono::DateTime::<chrono::Utc>::MIN_UTC);
313 right_dt.cmp(&left_dt)
314 })
315 .then_with(|| left.title.cmp(&right.title))
316 });
317}
318
319fn build_rerank_prompt(query: &str, candidates: &[MemoryRecallCandidate], limit: usize) -> String {
320 let mut prompt = String::from("# Bamboo Relevant Memory Recall Rerank\n\n");
321 prompt.push_str(
322 "Select the durable memory candidates that are most relevant to the user query.\n",
323 );
324 prompt.push_str("Return JSON only in the form {\"ids\":[\"candidate-id\", ...]}.\n");
325 prompt
326 .push_str("Do not include commentary, markdown fences, explanations, or unknown ids.\n\n");
327 prompt.push_str("## User query\n");
328 prompt.push_str(query.trim());
329 prompt.push_str("\n\n## Candidate memories\n");
330
331 for (index, candidate) in candidates.iter().enumerate() {
332 prompt.push_str(&format!(
333 "{}. id={}\n title: {}\n scope: {}\n status: {}\n updated_at: {}\n lexical_score: {:.2}\n summary: {}\n",
334 index + 1,
335 candidate.id,
336 candidate.title,
337 candidate.scope.as_str(),
338 candidate.status.as_str(),
339 candidate.updated_at,
340 candidate.score,
341 candidate.summary.replace('\n', " "),
342 ));
343 }
344
345 prompt.push_str(&format!(
346 "\n## Selection rules\n- Return at most {limit} ids.\n- Use only ids from the candidate list above.\n- Prefer candidates that best answer the user query or encode active preferences/constraints relevant to it.\n- Prefer active memories over stale ones when relevance is otherwise similar.\n- Keep the ids ordered best-to-worst.\n"
347 ));
348 prompt
349}
350
351async fn rerank_candidate_ids(
352 query: &str,
353 candidates: &[MemoryRecallCandidate],
354 limit: usize,
355 context: &MemoryRecallRerankContext,
356) -> Result<Vec<String>, String> {
357 let model = context.model.trim();
358 if model.is_empty() {
359 return Err("rerank model is empty".to_string());
360 }
361
362 let messages = vec![
363 Message::system(
364 "You rerank Bamboo durable-memory recall candidates. Return strict JSON only in the form {\"ids\":[...]} using only candidate ids from the prompt.",
365 ),
366 Message::user(build_rerank_prompt(query, candidates, limit)),
367 ];
368 let options = LLMRequestOptions {
369 session_id: context.session_id.clone(),
370 reasoning_effort: Some(ReasoningEffort::High),
371 parallel_tool_calls: None,
372 responses: None,
373 request_purpose: Some("memory_rerank".to_string()),
374 cache: None,
375 };
376
377 let mut stream = context
378 .llm
379 .chat_stream_with_options(&messages, &[], Some(8192), model, Some(&options))
380 .await
381 .map_err(|error| format!("rerank provider call failed: {error}"))?;
382
383 let content = tokio::time::timeout(std::time::Duration::from_secs(30), async {
384 let mut content = String::new();
385 while let Some(chunk_result) = stream.next().await {
386 match chunk_result {
387 Ok(LLMChunk::Token(text)) => content.push_str(&text),
388 Ok(LLMChunk::Done) => break,
389 Ok(_) => {}
390 Err(error) => {
391 if !content.trim().is_empty() {
392 break;
393 }
394 return Err(format!("rerank stream failed: {error}"));
395 }
396 }
397 }
398 Ok(content)
399 })
400 .await
401 .unwrap_or_else(|_| Err("rerank timed out after 30s".to_string()))?;
402
403 parse_reranked_ids(&content, candidates)
404 .ok_or_else(|| format!("failed to parse rerank response: {}", content.trim()))
405}
406
407fn reorder_candidates_by_ids(
408 lexical_candidates: &[MemoryRecallCandidate],
409 preferred_ids: &[String],
410 limit: usize,
411) -> Vec<MemoryRecallCandidate> {
412 if lexical_candidates.is_empty() || limit == 0 {
413 return Vec::new();
414 }
415
416 let allowed = lexical_candidates
417 .iter()
418 .map(|candidate| candidate.id.as_str())
419 .collect::<HashSet<_>>();
420 let mut seen = HashSet::new();
421 let mut ordered = Vec::new();
422
423 for id in preferred_ids {
424 let trimmed = id.trim();
425 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
426 continue;
427 }
428 if let Some(candidate) = lexical_candidates
429 .iter()
430 .find(|candidate| candidate.id == trimmed)
431 .cloned()
432 {
433 ordered.push(candidate);
434 if ordered.len() >= limit {
435 return ordered;
436 }
437 }
438 }
439
440 for candidate in lexical_candidates {
441 if seen.insert(candidate.id.clone()) {
442 ordered.push(candidate.clone());
443 if ordered.len() >= limit {
444 break;
445 }
446 }
447 }
448
449 ordered
450}
451
452fn parse_reranked_ids(raw: &str, candidates: &[MemoryRecallCandidate]) -> Option<Vec<String>> {
453 let stripped = strip_markdown_fence(raw);
454 let fragment = extract_json_fragment(&stripped).unwrap_or(stripped.trim());
455 let ids = serde_json::from_str::<MemoryRecallRerankEnvelope>(fragment)
456 .map(|value| value.ids)
457 .or_else(|_| serde_json::from_str::<Vec<String>>(fragment))
458 .ok()?;
459
460 let allowed = candidates
461 .iter()
462 .map(|candidate| candidate.id.as_str())
463 .collect::<HashSet<_>>();
464 let mut seen = HashSet::new();
465 let mut out = Vec::new();
466
467 for id in ids {
468 let trimmed = id.trim();
469 if trimmed.is_empty() || !allowed.contains(trimmed) || !seen.insert(trimmed.to_string()) {
470 continue;
471 }
472 out.push(trimmed.to_string());
473 }
474
475 (!out.is_empty()).then_some(out)
476}
477
/// Removes a surrounding Markdown code fence (four or three backticks) from `raw`.
///
/// The opening fence line, including any info string such as `json`, is dropped along with
/// everything from the last matching closing fence onwards. The inner body is returned
/// trimmed. Input without a complete fence (no newline after the opening fence, or no
/// closing fence) is returned trimmed but otherwise unchanged.
fn strip_markdown_fence(raw: &str) -> String {
    let text = raw.trim();
    ["````", "```"]
        .iter()
        .copied()
        .find_map(|fence: &str| {
            let after_open = text.strip_prefix(fence)?;
            let first_newline = after_open.find('\n')?;
            let body = &after_open[first_newline + 1..];
            let close = body.rfind(fence)?;
            Some(body[..close].trim().to_string())
        })
        .unwrap_or_else(|| text.to_string())
}
493
/// Finds the outermost JSON-looking span in `raw`.
///
/// Tries the span from the first `{` to the last `}` first, then from the first `[` to the
/// last `]`. A span is used only when its opener precedes its closer. Returns `None` for
/// blank input or when neither span exists.
fn extract_json_fragment(raw: &str) -> Option<&str> {
    // Slice from the first `open` to the last `close`, inclusive, if correctly ordered.
    fn delimited(text: &str, open: char, close: char) -> Option<&str> {
        let start = text.find(open)?;
        let end = text.rfind(close)?;
        (start <= end).then(|| text[start..=end].trim())
    }

    let text = raw.trim();
    if text.is_empty() {
        return None;
    }
    delimited(text, '{', '}').or_else(|| delimited(text, '[', ']'))
}
514
// Unit tests for lexical scoring, rerank-response parsing and reordering, project/global
// scope selection, and the model-rerank path (driven by in-process fake LLM providers).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_store::DurableMemoryType;
    use async_trait::async_trait;
    use bamboo_domain::ReasoningEffort;
    use bamboo_infrastructure::llm::provider::LLMRequestOptions;
    use bamboo_infrastructure::{LLMChunk, LLMError, LLMProvider, LLMStream};
    use futures::stream;
    use std::sync::Mutex;
    use tempfile::tempdir;

    /// Builds a project-scoped (`proj-1`) lexical index item for scoring tests.
    /// `created_at` is set equal to `updated_at`.
    // The fixture mirrors every scored field, so the argument count is intentional.
    #[allow(clippy::too_many_arguments)]
    fn item(
        id: &str,
        title: &str,
        status: DurableMemoryStatus,
        updated_at: &str,
        keywords: &[&str],
        tags: &[&str],
        entities: &[&str],
        summary: &str,
    ) -> LexicalIndexItem {
        LexicalIndexItem {
            id: id.to_string(),
            title: title.to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            r#type: DurableMemoryType::Project,
            status,
            tags: tags.iter().map(|v| v.to_string()).collect(),
            keywords: keywords.iter().map(|v| v.to_string()).collect(),
            entities: entities.iter().map(|v| v.to_string()).collect(),
            updated_at: updated_at.to_string(),
            created_at: updated_at.to_string(),
            summary: summary.to_string(),
        }
    }

    /// Fake provider that streams one fixed token followed by `Done`, and records every
    /// model name it was asked to use.
    #[derive(Clone)]
    struct StaticResponseProvider {
        response: String,
        requested_models: Arc<Mutex<Vec<String>>>,
    }

    impl StaticResponseProvider {
        /// Creates a provider that always answers with `response`.
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
                requested_models: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl LLMProvider for StaticResponseProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            model: &str,
        ) -> Result<LLMStream, LLMError> {
            // Record which model the caller requested so tests can assert on it.
            self.requested_models
                .lock()
                .expect("lock poisoned")
                .push(model.to_string());
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token(self.response.clone())),
                Ok(LLMChunk::Done),
            ])))
        }
    }

    // A title hit (3.0 per token) must beat a keyword-only hit (2.5 per token).
    #[test]
    fn title_matches_outrank_keyword_only_matches() {
        let query_tokens = vec!["release".to_string(), "freeze".to_string()];
        let title_item = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let keyword_item = item(
            "b",
            "Deployment decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &["release", "freeze"],
            &[],
            &[],
            "summary",
        );

        let title_score = score_lexical_index_item(&title_item, &query_tokens).unwrap();
        let keyword_score = score_lexical_index_item(&keyword_item, &query_tokens).unwrap();
        assert!(title_score > keyword_score);
    }

    // The stale penalty outweighs the stale item's newer timestamp.
    #[test]
    fn active_items_outrank_stale_items() {
        let query_tokens = vec!["release".to_string()];
        let active = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Active,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let stale = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Stale,
            "2026-04-10T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        let active_score = score_lexical_index_item(&active, &query_tokens).unwrap();
        let stale_score = score_lexical_index_item(&stale, &query_tokens).unwrap();
        assert!(active_score > stale_score);
    }

    // Retired statuses are never scored, even on a perfect title match.
    #[test]
    fn contradicted_and_archived_items_are_filtered_out() {
        let query_tokens = vec!["release".to_string()];
        let contradicted = item(
            "a",
            "Release freeze decision",
            DurableMemoryStatus::Contradicted,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );
        let archived = item(
            "b",
            "Release freeze decision",
            DurableMemoryStatus::Archived,
            "2026-04-09T00:00:00Z",
            &[],
            &[],
            &[],
            "summary",
        );

        assert!(score_lexical_index_item(&contradicted, &query_tokens).is_none());
        assert!(score_lexical_index_item(&archived, &query_tokens).is_none());
    }

    // Fenced JSON is unwrapped; unknown and duplicate ids are dropped, order is preserved.
    #[test]
    fn parse_reranked_ids_accepts_fenced_json_and_filters_unknown_ids() {
        let candidates = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
            },
        ];

        let parsed = parse_reranked_ids(
            "```json\n{\"ids\":[\"mem-b\",\"unknown\",\"mem-a\",\"mem-b\"]}\n```",
            &candidates,
        )
        .expect("reranked ids should parse");

        assert_eq!(parsed, vec!["mem-b".to_string(), "mem-a".to_string()]);
    }

    // Ids the model omitted are appended afterwards in lexical order.
    #[test]
    fn reorder_candidates_by_ids_appends_remaining_lexical_candidates() {
        let lexical = vec![
            MemoryRecallCandidate {
                id: "mem-a".to_string(),
                title: "A".to_string(),
                score: 10.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary a".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-b".to_string(),
                title: "B".to_string(),
                score: 9.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary b".to_string(),
            },
            MemoryRecallCandidate {
                id: "mem-c".to_string(),
                title: "C".to_string(),
                score: 8.0,
                scope: MemoryScope::Project,
                project_key: Some("proj-1".to_string()),
                status: DurableMemoryStatus::Active,
                updated_at: "2026-04-09T00:00:00Z".to_string(),
                summary: "summary c".to_string(),
            },
        ];

        let reordered =
            reorder_candidates_by_ids(&lexical, &["mem-c".to_string(), "mem-a".to_string()], 3);

        assert_eq!(reordered[0].id, "mem-c");
        assert_eq!(reordered[1].id, "mem-a");
        assert_eq!(reordered[2].id, "mem-b");
    }

    // When the project scope has hits, global memories are not mixed in.
    #[tokio::test]
    async fn project_scope_shortlist_excludes_global_when_project_hits_exist() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze decision",
                "Project-specific release freeze note.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Global note that should not be used when project hits exist.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Project));
    }

    // A project with no matching memories falls back to the global scope.
    #[tokio::test]
    async fn global_fallback_triggers_only_when_project_hits_are_absent() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        store
            .write_memory(
                MemoryScope::Global,
                None,
                DurableMemoryType::Reference,
                "Global release guidance",
                "Fallback note for release work.",
                &["release".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let candidates = shortlist_relevant_memories(
            &store,
            Some("proj-missing"),
            "release guidance",
            &MemoryRecallOptions::default(),
        )
        .await
        .unwrap();

        assert!(!candidates.is_empty());
        assert!(candidates
            .iter()
            .all(|candidate| candidate.scope == MemoryScope::Global));
    }

    // A valid model answer overrides lexical order, and the rerank model name is used.
    #[tokio::test]
    async fn model_rerank_reorders_lexical_shortlist_when_enabled() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        let reranked_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app and should be preferred for mobile freeze requests.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        // The fake model prefers the memory that ranks second lexically.
        let provider = StaticResponseProvider::new(format!(
            "{{\"ids\":[\"{}\",\"{}\"]}}",
            reranked_first.frontmatter.id, lexical_first.frontmatter.id
        ));
        let requested_models = provider.requested_models.clone();
        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(provider),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::Reranked);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, reranked_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_first.frontmatter.id);
        assert_eq!(
            requested_models.lock().expect("lock poisoned").as_slice(),
            ["rerank-fast-model"]
        );
    }

    // An unparseable model answer keeps lexical order and reports RerankFallback.
    #[tokio::test]
    async fn invalid_model_rerank_response_falls_back_to_lexical_order() {
        let dir = tempdir().unwrap();
        let store = MemoryStore::new(dir.path());

        let lexical_first = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Release freeze checklist",
                "Generic release freeze checklist for shipping work.",
                &["release".to_string(), "freeze".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();
        let lexical_second = store
            .write_memory(
                MemoryScope::Project,
                Some("proj-1"),
                DurableMemoryType::Project,
                "Mobile launch blocker",
                "This durable note captures the release freeze decision for the mobile app.",
                &["mobile".to_string(), "launch".to_string()],
                Some("session-1"),
                "main-model",
                false,
            )
            .await
            .unwrap();

        let selection = select_relevant_memories(
            &store,
            Some("proj-1"),
            "release freeze for mobile",
            &MemoryRecallOptions {
                shortlist_limit: 2,
                include_global_fallback: false,
                max_candidates_per_scope: 12,
            },
            Some(&MemoryRecallRerankContext {
                llm: Arc::new(StaticResponseProvider::new("not valid json")),
                model: "rerank-fast-model".to_string(),
                session_id: Some("session-1".to_string()),
            }),
        )
        .await
        .unwrap();

        assert_eq!(selection.strategy, MemoryRecallStrategy::RerankFallback);
        assert_eq!(selection.candidates.len(), 2);
        assert_eq!(selection.candidates[0].id, lexical_first.frontmatter.id);
        assert_eq!(selection.candidates[1].id, lexical_second.frontmatter.id);
    }

    /// Fake provider that records the `max_output_tokens` and reasoning effort of each
    /// `chat_stream_with_options` call, then answers with an empty id list.
    #[derive(Default)]
    struct RequestOptionsCaptureProvider {
        captured_max_tokens: Mutex<Vec<Option<u32>>>,
        captured_reasoning: Mutex<Vec<Option<ReasoningEffort>>>,
    }

    #[async_trait]
    impl LLMProvider for RequestOptionsCaptureProvider {
        async fn chat_stream(
            &self,
            _messages: &[Message],
            _tools: &[bamboo_agent_core::ToolSchema],
            _max_output_tokens: Option<u32>,
            _model: &str,
        ) -> Result<LLMStream, LLMError> {
            Ok(Box::pin(stream::iter(vec![
                Ok(LLMChunk::Token("{\"ids\":[]}".to_string())),
                Ok(LLMChunk::Done),
            ])))
        }

        async fn chat_stream_with_options(
            &self,
            messages: &[Message],
            tools: &[bamboo_agent_core::ToolSchema],
            max_output_tokens: Option<u32>,
            model: &str,
            options: Option<&LLMRequestOptions>,
        ) -> Result<LLMStream, LLMError> {
            // Capture the request settings, then delegate to the canned stream.
            self.captured_max_tokens
                .lock()
                .expect("lock should not be poisoned")
                .push(max_output_tokens);
            self.captured_reasoning
                .lock()
                .expect("lock should not be poisoned")
                .push(options.and_then(|o| o.reasoning_effort));
            self.chat_stream(messages, tools, max_output_tokens, model)
                .await
        }
    }

    // The rerank asks for High reasoning, so its output budget must exceed the
    // 4096-token thinking budget this test assumes. The rerank result is ignored here.
    #[tokio::test]
    async fn rerank_sufficient_max_tokens_for_high_reasoning() {
        let provider = Arc::new(RequestOptionsCaptureProvider::default());
        let candidates = vec![MemoryRecallCandidate {
            id: "mem-1".to_string(),
            score: 0.9,
            title: "Test memory".to_string(),
            scope: MemoryScope::Project,
            project_key: Some("proj-1".to_string()),
            status: DurableMemoryStatus::Active,
            updated_at: "2026-05-08T00:00:00Z".to_string(),
            summary: "A test durable memory entry".to_string(),
        }];
        let context = MemoryRecallRerankContext {
            llm: provider.clone(),
            model: "deepseek-v4-pro".to_string(),
            session_id: Some("test-session".to_string()),
        };

        let _ = rerank_candidate_ids("test query", &candidates, 5, &context).await;

        let captured_reasoning = provider
            .captured_reasoning
            .lock()
            .expect("lock should not be poisoned");
        let captured_max_tokens = provider
            .captured_max_tokens
            .lock()
            .expect("lock should not be poisoned");
        assert_eq!(
            captured_reasoning.as_slice(),
            [Some(ReasoningEffort::High)],
            "rerank should request High reasoning"
        );
        let max_tokens = captured_max_tokens[0].expect("max_output_tokens should be set");
        assert!(
            max_tokens > 4096,
            "max_output_tokens ({}) must exceed thinking budget (4096) to avoid truncation",
            max_tokens
        );
    }
}