1use std::future::Future;
17use std::pin::Pin;
18
19use futures::StreamExt as _;
20use futures::stream::FuturesUnordered;
21
22use zeph_common::memory::{AsyncMemoryRouter, CompressionLevel, GraphRecallParams, TokenCounting};
23use zeph_llm::provider::{Message, MessageMetadata, MessagePart, Role};
24
25use crate::error::AssemblerError;
26use crate::input::ContextAssemblyInput;
27use crate::slot::ContextSlot;
28
29pub(crate) fn levels_to_flags(levels: &[CompressionLevel]) -> (bool, bool, bool) {
37 if levels.is_empty() {
38 return (true, true, true);
39 }
40 let episodic = levels.contains(&CompressionLevel::Episodic);
41 let procedural = levels.contains(&CompressionLevel::Procedural);
42 let declarative = levels.contains(&CompressionLevel::Declarative);
43 (episodic, procedural, declarative)
44}
45
/// Header line that opens the conversation-summaries context block.
///
/// `fetch_summaries` starts its text with this and also uses it to detect an empty block.
pub const SUMMARY_PREFIX: &str = "[conversation summaries]\n";
/// Header line that opens the cross-session context block built by `fetch_cross_session`.
pub const CROSS_SESSION_PREFIX: &str = "[cross-session context]\n";
/// Header line that opens the semantic-recall context block built by `fetch_semantic_recall`.
pub const RECALL_PREFIX: &str = "[semantic recall]\n";
/// Header line that opens the past-corrections context block built by `fetch_corrections`.
pub const CORRECTIONS_PREFIX: &str = "[past corrections]\n";
/// Markdown heading that opens the document-RAG context block built by `fetch_document_rag`.
pub const DOCUMENT_RAG_PREFIX: &str = "## Relevant documents\n";
/// Header line that opens the knowledge-graph facts block built by `fetch_graph_facts`.
pub const GRAPH_FACTS_PREFIX: &str = "[known facts]\n";
58
/// Context gathered by [`ContextAssembler::gather`] for a single turn.
///
/// Every optional slot is `None` when its source was not scheduled (tier inactive, zero
/// budget, feature disabled), produced nothing, or nothing fit within its token budget.
/// When the context manager has no budget configured, every slot is empty.
pub struct PreparedContext {
    /// Knowledge-graph facts (`[known facts]` block).
    pub graph_facts: Option<Message>,
    /// Chunks retrieved from the configured document collection.
    pub doc_rag: Option<Message>,
    /// Past user corrections relevant to the current query.
    pub corrections: Option<Message>,
    /// Semantic-recall results for the current query.
    pub recall: Option<Message>,
    /// Score reported alongside `recall` by the semantic-recall fetcher.
    pub recall_confidence: Option<f32>,
    /// Summaries of other sessions relevant to the current query.
    pub cross_session: Option<Message>,
    /// Summaries of the current conversation.
    pub summaries: Option<Message>,
    /// Code retrieved through the optional code index.
    pub code_context: Option<String>,
    /// Persona facts about the user.
    pub persona_facts: Option<Message>,
    /// Hints derived from past trajectories.
    pub trajectory_hints: Option<Message>,
    /// Nodes retrieved from tree memory.
    pub tree_memory: Option<Message>,
    /// Reasoning strategies retrieved for the current query.
    pub reasoning_hints: Option<Message>,
    /// `true` when the effective context strategy resolved to `MemoryFirst`.
    pub memory_first: bool,
    /// Token budget allocated to recent conversation history.
    pub recent_history_budget: usize,
}
93
/// Stateless entry point for assembling per-turn context; see [`ContextAssembler::gather`].
pub struct ContextAssembler;
98
/// Boxed, `Send` future produced by each scheduled context fetcher.
///
/// It resolves to the [`ContextSlot`] that the fetcher fills, or to the error that aborts assembly.
type CtxFuture<'a> = Pin<Box<dyn Future<Output = Result<ContextSlot, AssemblerError>> + Send + 'a>>;
100
101fn empty_prepared_context() -> PreparedContext {
102 PreparedContext {
103 graph_facts: None,
104 doc_rag: None,
105 corrections: None,
106 recall: None,
107 recall_confidence: None,
108 cross_session: None,
109 summaries: None,
110 code_context: None,
111 persona_facts: None,
112 trajectory_hints: None,
113 tree_memory: None,
114 reasoning_hints: None,
115 memory_first: false,
116 recent_history_budget: 0,
117 }
118}
119
120fn resolve_effective_strategy(
123 memory: &crate::input::ContextMemoryView,
124 sidequest_turn_counter: u64,
125) -> zeph_config::ContextStrategy {
126 match memory.context_strategy {
127 zeph_config::ContextStrategy::FullHistory => zeph_config::ContextStrategy::FullHistory,
128 zeph_config::ContextStrategy::MemoryFirst => zeph_config::ContextStrategy::MemoryFirst,
129 zeph_config::ContextStrategy::Adaptive => {
130 if sidequest_turn_counter >= u64::from(memory.crossover_turn_threshold) {
131 zeph_config::ContextStrategy::MemoryFirst
132 } else {
133 zeph_config::ContextStrategy::FullHistory
134 }
135 }
136 }
137}
138
139fn correction_params(cfg: Option<&crate::input::CorrectionConfig>) -> (usize, f32) {
140 cfg.filter(|c| c.correction_detection)
141 .map_or((3, 0.75), |c| {
142 (
143 c.correction_recall_limit as usize,
144 c.correction_min_similarity,
145 )
146 })
147}
148
149#[allow(clippy::too_many_arguments)]
156fn schedule_context_fetchers<'r>(
157 memory: &'r crate::input::ContextMemoryView,
158 tc: &'r dyn TokenCounting,
159 query: &'r str,
160 scrub: fn(&str) -> std::borrow::Cow<'_, str>,
161 index: Option<&'r dyn crate::input::IndexAccess>,
162 router_ref: &'r dyn AsyncMemoryRouter,
163 summaries_budget: usize,
164 cross_session_budget: usize,
165 semantic_recall_budget: usize,
166 code_context_budget: usize,
167 graph_facts_budget: usize,
168 recall_limit: usize,
169 min_sim: f32,
170 active_levels: &[CompressionLevel],
171) -> FuturesUnordered<CtxFuture<'r>> {
172 let (episodic_active, procedural_active, declarative_active) = levels_to_flags(active_levels);
176
177 let fetchers: FuturesUnordered<CtxFuture<'r>> = FuturesUnordered::new();
178
179 if episodic_active && summaries_budget > 0 {
180 fetchers.push(Box::pin(async move {
181 fetch_summaries(memory, summaries_budget, tc)
182 .await
183 .map(ContextSlot::Summaries)
184 }));
185 }
186 if episodic_active && cross_session_budget > 0 {
187 fetchers.push(Box::pin(async move {
188 fetch_cross_session(memory, query, cross_session_budget, tc)
189 .await
190 .map(ContextSlot::CrossSession)
191 }));
192 }
193 if episodic_active && semantic_recall_budget > 0 {
194 fetchers.push(Box::pin(async move {
195 fetch_semantic_recall(memory, query, semantic_recall_budget, tc, Some(router_ref))
196 .await
197 .map(|(msg, score)| ContextSlot::SemanticRecall(msg, score))
198 }));
199 fetchers.push(Box::pin(async move {
200 fetch_document_rag(memory, query, semantic_recall_budget, tc)
201 .await
202 .map(ContextSlot::DocumentRag)
203 }));
204 }
205 fetchers.push(Box::pin(async move {
207 fetch_corrections(memory, query, recall_limit, min_sim, scrub)
208 .await
209 .map(ContextSlot::Corrections)
210 }));
211 if code_context_budget > 0
213 && let Some(idx) = index
214 {
215 fetchers.push(Box::pin(async move {
216 let result: Result<Option<String>, AssemblerError> =
217 idx.fetch_code_rag(query, code_context_budget).await;
218 result.map(ContextSlot::CodeContext)
219 }));
220 }
221 if declarative_active && graph_facts_budget > 0 {
222 fetchers.push(Box::pin(async move {
223 fetch_graph_facts(memory, query, graph_facts_budget, tc)
224 .await
225 .map(ContextSlot::GraphFacts)
226 }));
227 }
228 if declarative_active && memory.persona_config.context_budget_tokens > 0 {
229 fetchers.push(Box::pin(async move {
230 let persona_budget = memory.persona_config.context_budget_tokens;
231 fetch_persona_facts(memory, persona_budget, tc)
232 .await
233 .map(ContextSlot::PersonaFacts)
234 }));
235 }
236 if procedural_active && memory.trajectory_config.context_budget_tokens > 0 {
237 fetchers.push(Box::pin(async move {
238 let tbudget = memory.trajectory_config.context_budget_tokens;
239 fetch_trajectory_hints(memory, tbudget, tc)
240 .await
241 .map(ContextSlot::TrajectoryHints)
242 }));
243 }
244 if declarative_active && memory.tree_config.context_budget_tokens > 0 {
245 fetchers.push(Box::pin(async move {
246 let tbudget = memory.tree_config.context_budget_tokens;
247 fetch_tree_memory(memory, tbudget, tc)
248 .await
249 .map(ContextSlot::TreeMemory)
250 }));
251 }
252 if procedural_active
253 && memory.reasoning_config.enabled
254 && memory.reasoning_config.context_budget_tokens > 0
255 {
256 fetchers.push(Box::pin(async move {
257 let rbudget = memory.reasoning_config.context_budget_tokens;
258 let top_k = memory.reasoning_config.top_k;
259 fetch_reasoning_strategies(memory, query, rbudget, top_k, tc)
260 .await
261 .map(ContextSlot::ReasoningStrategies)
262 }));
263 }
264
265 fetchers
266}
267
268async fn drive_fetchers(
269 mut fetchers: FuturesUnordered<CtxFuture<'_>>,
270 prepared: &mut PreparedContext,
271) -> Result<(), AssemblerError> {
272 while let Some(result) = fetchers.next().await {
273 match result {
274 Ok(slot) => match slot {
275 ContextSlot::Summaries(msg) => prepared.summaries = msg,
276 ContextSlot::CrossSession(msg) => prepared.cross_session = msg,
277 ContextSlot::SemanticRecall(msg, score) => {
278 prepared.recall = msg;
279 prepared.recall_confidence = score;
280 }
281 ContextSlot::DocumentRag(msg) => prepared.doc_rag = msg,
282 ContextSlot::Corrections(msg) => prepared.corrections = msg,
283 ContextSlot::CodeContext(text) => prepared.code_context = text,
284 ContextSlot::GraphFacts(msg) => prepared.graph_facts = msg,
285 ContextSlot::PersonaFacts(msg) => prepared.persona_facts = msg,
286 ContextSlot::TrajectoryHints(msg) => prepared.trajectory_hints = msg,
287 ContextSlot::TreeMemory(msg) => prepared.tree_memory = msg,
288 ContextSlot::ReasoningStrategies(msg) => prepared.reasoning_hints = msg,
289 },
290 Err(e) => return Err(e),
291 }
292 }
293 Ok(())
294}
295
296impl ContextAssembler {
297 pub async fn gather(
305 input: &ContextAssemblyInput<'_>,
306 ) -> Result<PreparedContext, AssemblerError> {
307 let Some(ref budget) = input.context_manager.budget else {
308 return Ok(empty_prepared_context());
309 };
310
311 let memory = input.memory;
312 let tc = input.token_counter;
313
314 let effective_strategy = resolve_effective_strategy(memory, input.sidequest_turn_counter);
315 let memory_first = effective_strategy == zeph_config::ContextStrategy::MemoryFirst;
316
317 let system_prompt = input
318 .messages
319 .first()
320 .filter(|m| m.role == Role::System)
321 .map_or("", |m| m.content.as_str());
322
323 let digest_tokens = memory
324 .cached_session_digest
325 .as_ref()
326 .map_or(0, |(_, tokens)| *tokens);
327
328 let alloc = budget.allocate_with_opts(
329 system_prompt,
330 input.skills_prompt,
331 tc,
332 memory.graph_config.enabled,
333 digest_tokens,
334 memory_first,
335 );
336
337 let (recall_limit, min_sim) = correction_params(input.correction_config.as_ref());
338
339 let router_ref: &dyn AsyncMemoryRouter = input.router.as_ref();
340
341 tracing::debug!(
342 active_sources = alloc.active_sources(),
343 active_levels = ?input.active_levels,
344 "context budget allocated"
345 );
346
347 let fetchers = schedule_context_fetchers(
348 memory,
349 tc,
350 input.query,
351 input.scrub,
352 input.index,
353 router_ref,
354 alloc.summaries,
355 alloc.cross_session,
356 alloc.semantic_recall,
357 alloc.code_context,
358 alloc.graph_facts,
359 recall_limit,
360 min_sim,
361 input.active_levels,
362 );
363
364 let mut prepared = empty_prepared_context();
365 prepared.memory_first = memory_first;
366 prepared.recent_history_budget = alloc.recent_history;
367
368 drive_fetchers(fetchers, &mut prepared).await?;
369 Ok(prepared)
370 }
371}
372
373pub fn effective_recall_timeout_ms(configured: u64) -> u64 {
378 if configured == 0 {
379 tracing::warn!(
380 "recall_timeout_ms is 0, which would disable spreading activation recall; \
381 clamping to 100ms"
382 );
383 100
384 } else {
385 configured
386 }
387}
388
389use crate::input::ContextMemoryView;
390
391#[tracing::instrument(name = "context.graph_facts", skip_all)]
392#[allow(clippy::too_many_lines)] pub(crate) async fn fetch_graph_facts(
394 memory: &ContextMemoryView,
395 query: &str,
396 budget_tokens: usize,
397 tc: &dyn TokenCounting,
398) -> Result<Option<Message>, AssemblerError> {
399 use zeph_common::memory::{RecallView, SpreadingActivationParams, classify_graph_subgraph};
400
401 if budget_tokens == 0 || !memory.graph_config.enabled {
402 return Ok(None);
403 }
404 let Some(ref mem) = memory.memory else {
405 return Ok(None);
406 };
407 let recall_limit = memory.graph_config.recall_limit;
408 let temporal_decay_rate = memory.graph_config.temporal_decay_rate;
409 let sa_config = &memory.graph_config.spreading_activation;
410
411 let fused_query;
413 let effective_query = if let Some(ref state) = memory.memcot_state {
414 let max_state_chars = 2 * query.len();
415 let state_slice = if state.len() > max_state_chars {
416 let boundary = state.floor_char_boundary(max_state_chars);
417 &state[..boundary]
418 } else {
419 state.as_str()
420 };
421 fused_query = format!("[state] {state_slice}\n{query}");
422 &fused_query as &str
423 } else {
424 query
425 };
426
427 let edge_types = classify_graph_subgraph(effective_query);
428
429 let view = match memory.memcot_config.recall_view {
430 zeph_config::RecallViewConfig::Head => RecallView::Head,
431 zeph_config::RecallViewConfig::ZoomIn => RecallView::ZoomIn,
432 zeph_config::RecallViewConfig::ZoomOut => RecallView::ZoomOut,
433 };
434
435 let sa_params = if sa_config.enabled {
436 Some(SpreadingActivationParams {
437 decay_lambda: sa_config.decay_lambda,
438 max_hops: sa_config.max_hops,
439 activation_threshold: sa_config.activation_threshold,
440 inhibition_threshold: sa_config.inhibition_threshold,
441 max_activated_nodes: sa_config.max_activated_nodes,
442 temporal_decay_rate,
443 seed_structural_weight: sa_config.seed_structural_weight,
444 seed_community_cap: sa_config.seed_community_cap,
445 })
446 } else {
447 None
448 };
449
450 let timeout_ms = effective_recall_timeout_ms(sa_config.recall_timeout_ms);
451 let recall_fut = mem.recall_graph_facts(
452 effective_query,
453 GraphRecallParams {
454 limit: recall_limit,
455 view,
456 zoom_out_neighbor_cap: memory.memcot_config.zoom_out_neighbor_cap,
457 max_hops: memory.graph_config.max_hops,
458 temporal_decay_rate,
459 edge_types: &edge_types,
460 spreading_activation: sa_params,
461 },
462 );
463 let recalled = match tokio::time::timeout(
464 std::time::Duration::from_millis(timeout_ms),
465 recall_fut,
466 )
467 .await
468 {
469 Ok(Ok(facts)) => facts,
470 Ok(Err(e)) => {
471 tracing::warn!("graph recall failed: {e:#}");
472 Vec::new()
473 }
474 Err(_) => {
475 tracing::warn!("graph recall timed out ({timeout_ms}ms)");
476 Vec::new()
477 }
478 };
479
480 if recalled.is_empty() {
481 return Ok(None);
482 }
483
484 let mut body = String::from(GRAPH_FACTS_PREFIX);
485 let mut tokens_so_far = tc.count_tokens(&body);
486
487 for rf in &recalled {
488 let fact_text = rf.fact.replace(['\n', '\r', '<', '>'], " ");
489 let line = if let Some(score) = rf.activation_score {
490 format!(
491 "- {} (confidence: {:.2}, activation: {:.2})\n",
492 fact_text, rf.confidence, score
493 )
494 } else {
495 format!("- {} (confidence: {:.2})\n", fact_text, rf.confidence)
496 };
497 let line_tokens = tc.count_tokens(&line);
498 if tokens_so_far + line_tokens > budget_tokens {
499 break;
500 }
501 body.push_str(&line);
502 tokens_so_far += line_tokens;
503
504 for nb in &rf.neighbors {
506 let nb_text = nb.fact.replace(['\n', '\r', '<', '>'], " ");
507 let nb_line = format!(" ~ {} (confidence: {:.2})\n", nb_text, nb.confidence);
508 let nb_tokens = tc.count_tokens(&nb_line);
509 if tokens_so_far + nb_tokens > budget_tokens {
510 break;
511 }
512 body.push_str(&nb_line);
513 tokens_so_far += nb_tokens;
514 }
515
516 if let Some(ref snippet) = rf.provenance_snippet {
518 let snip_line = format!(
519 " [source: {}]\n",
520 snippet.replace(['\n', '\r', '<', '>'], " ")
521 );
522 let snip_tokens = tc.count_tokens(&snip_line);
523 if tokens_so_far + snip_tokens <= budget_tokens {
524 body.push_str(&snip_line);
525 tokens_so_far += snip_tokens;
526 }
527 }
528 }
529
530 if body == GRAPH_FACTS_PREFIX {
531 return Ok(None);
532 }
533
534 Ok(Some(Message::from_legacy(Role::System, body)))
535}
536
537#[tracing::instrument(name = "context.persona_facts", skip_all)]
538pub(crate) async fn fetch_persona_facts(
539 memory: &ContextMemoryView,
540 budget_tokens: usize,
541 tc: &dyn TokenCounting,
542) -> Result<Option<Message>, AssemblerError> {
543 if budget_tokens == 0 || !memory.persona_config.enabled {
544 return Ok(None);
545 }
546 let Some(ref mem) = memory.memory else {
547 return Ok(None);
548 };
549
550 let min_confidence = memory.persona_config.min_confidence;
551 let facts = mem
552 .load_persona_facts(min_confidence)
553 .await
554 .map_err(AssemblerError::Memory)?;
555
556 if facts.is_empty() {
557 return Ok(None);
558 }
559
560 let mut body = String::from(crate::slot::PERSONA_PREFIX);
561 let mut tokens_so_far = tc.count_tokens(&body);
562
563 for fact in &facts {
564 let line = format!("[{}] {}\n", fact.category, fact.content);
565 let line_tokens = tc.count_tokens(&line);
566 if tokens_so_far + line_tokens > budget_tokens {
567 break;
568 }
569 body.push_str(&line);
570 tokens_so_far += line_tokens;
571 }
572
573 if body == crate::slot::PERSONA_PREFIX {
574 return Ok(None);
575 }
576
577 Ok(Some(Message::from_legacy(Role::System, body)))
578}
579
580#[tracing::instrument(name = "context.trajectory_hints", skip_all)]
581pub(crate) async fn fetch_trajectory_hints(
582 memory: &ContextMemoryView,
583 budget_tokens: usize,
584 tc: &dyn TokenCounting,
585) -> Result<Option<Message>, AssemblerError> {
586 if budget_tokens == 0 || !memory.trajectory_config.enabled {
587 return Ok(None);
588 }
589 let Some(ref mem) = memory.memory else {
590 return Ok(None);
591 };
592
593 let top_k = memory.trajectory_config.recall_top_k;
594 let min_conf = memory.trajectory_config.min_confidence;
595 let entries = mem
599 .load_trajectory_entries(Some("procedural"), top_k)
600 .await
601 .map_err(AssemblerError::Memory)?;
602
603 if entries.is_empty() {
604 return Ok(None);
605 }
606
607 let mut body = String::from(crate::slot::TRAJECTORY_PREFIX);
608 let mut tokens_so_far = tc.count_tokens(&body);
609
610 for entry in entries
611 .iter()
612 .filter(|e| e.confidence >= min_conf)
613 .take(top_k)
614 {
615 let line = format!("- {}: {}\n", entry.intent, entry.outcome);
616 let line_tokens = tc.count_tokens(&line);
617 if tokens_so_far + line_tokens > budget_tokens {
618 break;
619 }
620 body.push_str(&line);
621 tokens_so_far += line_tokens;
622 }
623
624 if body == crate::slot::TRAJECTORY_PREFIX {
625 return Ok(None);
626 }
627
628 Ok(Some(Message::from_legacy(Role::System, body)))
629}
630
631#[tracing::instrument(name = "context.tree_memory", skip_all)]
632pub(crate) async fn fetch_tree_memory(
633 memory: &ContextMemoryView,
634 budget_tokens: usize,
635 tc: &dyn TokenCounting,
636) -> Result<Option<Message>, AssemblerError> {
637 if budget_tokens == 0 || !memory.tree_config.enabled {
638 return Ok(None);
639 }
640 let Some(ref mem) = memory.memory else {
641 return Ok(None);
642 };
643
644 let top_k = memory.tree_config.recall_top_k;
645 let nodes = mem
646 .load_tree_nodes(1, top_k)
647 .await
648 .map_err(AssemblerError::Memory)?;
649
650 if nodes.is_empty() {
651 return Ok(None);
652 }
653
654 let mut body = String::from(crate::slot::TREE_MEMORY_PREFIX);
655 let mut tokens_so_far = tc.count_tokens(&body);
656
657 for node in nodes.iter().take(top_k) {
658 let line = format!("- {}\n", node.content);
659 let line_tokens = tc.count_tokens(&line);
660 if tokens_so_far + line_tokens > budget_tokens {
661 break;
662 }
663 body.push_str(&line);
664 tokens_so_far += line_tokens;
665 }
666
667 if body == crate::slot::TREE_MEMORY_PREFIX {
668 return Ok(None);
669 }
670
671 Ok(Some(Message::from_legacy(Role::System, body)))
672}
673
674#[tracing::instrument(name = "context.reasoning_strategies", skip_all)]
675pub(crate) async fn fetch_reasoning_strategies(
676 memory: &ContextMemoryView,
677 query: &str,
678 budget_tokens: usize,
679 top_k: usize,
680 tc: &dyn TokenCounting,
681) -> Result<Option<Message>, AssemblerError> {
682 let budget_tokens = budget_tokens.min(500);
684 if budget_tokens == 0 {
685 return Ok(None);
686 }
687 let Some(ref mem) = memory.memory else {
688 return Ok(None);
689 };
690
691 let strategies = mem
692 .retrieve_reasoning_strategies(query, top_k)
693 .await
694 .map_err(AssemblerError::Memory)?;
695
696 if strategies.is_empty() {
697 return Ok(None);
698 }
699
700 let mut body = String::from(crate::slot::REASONING_PREFIX);
701 let mut tokens_so_far = tc.count_tokens(&body);
702 let mut injected_ids: Vec<String> = Vec::new();
703
704 for s in strategies.iter().take(top_k) {
705 let safe_summary = s.summary.replace(['\n', '\r', '<', '>'], " ");
708 let line = format!("- [{}] {}\n", s.outcome, safe_summary);
709 let line_tokens = tc.count_tokens(&line);
710 if tokens_so_far + line_tokens > budget_tokens {
711 break;
712 }
713 body.push_str(&line);
714 tokens_so_far += line_tokens;
715 injected_ids.push(s.id.clone());
716 }
717
718 if body == crate::slot::REASONING_PREFIX {
719 return Ok(None);
720 }
721
722 if !injected_ids.is_empty() {
725 let mem_clone = mem.clone();
726 tokio::spawn(async move {
727 if let Err(e) = mem_clone.mark_reasoning_used(&injected_ids).await {
728 tracing::warn!(error = %e, "reasoning: mark_used failed");
729 }
730 });
731 }
732
733 Ok(Some(Message::from_legacy(Role::System, body)))
734}
735
736#[tracing::instrument(name = "context.corrections", skip_all)]
737pub(crate) async fn fetch_corrections(
738 memory: &ContextMemoryView,
739 query: &str,
740 limit: usize,
741 min_score: f32,
742 scrub: fn(&str) -> std::borrow::Cow<'_, str>,
743) -> Result<Option<Message>, AssemblerError> {
744 let Some(ref mem) = memory.memory else {
745 return Ok(None);
746 };
747 let corrections = mem
748 .retrieve_corrections(query, limit, min_score)
749 .await
750 .map_err(AssemblerError::Memory)?;
751 if corrections.is_empty() {
752 return Ok(None);
753 }
754 let mut text = String::from(CORRECTIONS_PREFIX);
755 for c in &corrections {
756 text.push_str("- Past user correction: \"");
757 text.push_str(&scrub(&c.correction_text));
758 text.push_str("\"\n");
759 }
760 Ok(Some(Message::from_legacy(Role::System, text)))
761}
762
763#[tracing::instrument(name = "context.semantic_recall", skip_all)]
764pub(crate) async fn fetch_semantic_recall(
765 memory: &ContextMemoryView,
766 query: &str,
767 token_budget: usize,
768 tc: &dyn TokenCounting,
769 router: Option<&dyn AsyncMemoryRouter>,
770) -> Result<(Option<Message>, Option<f32>), AssemblerError> {
771 let Some(ref mem) = memory.memory else {
772 return Ok((None, None));
773 };
774 if memory.recall_limit == 0 || token_budget == 0 {
775 return Ok((None, None));
776 }
777
778 let recalled = mem
779 .recall(query, memory.recall_limit, router)
780 .await
781 .map_err(AssemblerError::Memory)?;
782 if recalled.is_empty() {
783 return Ok((None, None));
784 }
785
786 let top_score = recalled.first().map(|r| r.score);
787
788 let mut recall_text = String::with_capacity(token_budget * 3);
789 recall_text.push_str(RECALL_PREFIX);
790 let mut tokens_used = tc.count_tokens(&recall_text);
791
792 for item in &recalled {
793 if item.content.starts_with("[skipped]") || item.content.starts_with("[stopped]") {
794 continue;
795 }
796 let entry = format!("- [{}] {}\n", item.role, item.content);
797 let entry_tokens = tc.count_tokens(&entry);
798 if tokens_used + entry_tokens > token_budget {
799 break;
800 }
801 recall_text.push_str(&entry);
802 tokens_used += entry_tokens;
803 }
804
805 if tokens_used > tc.count_tokens(RECALL_PREFIX) {
806 Ok((
807 Some(Message::from_parts(
808 Role::System,
809 vec![MessagePart::Recall { text: recall_text }],
810 )),
811 top_score,
812 ))
813 } else {
814 Ok((None, None))
815 }
816}
817
818#[tracing::instrument(name = "context.document_rag", skip_all)]
819pub(crate) async fn fetch_document_rag(
820 memory: &ContextMemoryView,
821 query: &str,
822 token_budget: usize,
823 tc: &dyn TokenCounting,
824) -> Result<Option<Message>, AssemblerError> {
825 if !memory.document_config.rag_enabled || token_budget == 0 {
826 return Ok(None);
827 }
828 let Some(ref mem) = memory.memory else {
829 return Ok(None);
830 };
831
832 let collection = &memory.document_config.collection;
833 let top_k = memory.document_config.top_k;
834 let chunks = mem
835 .search_document_collection(collection, query, top_k)
836 .await
837 .map_err(AssemblerError::Memory)?;
838 if chunks.is_empty() {
839 return Ok(None);
840 }
841
842 let mut text = String::from(DOCUMENT_RAG_PREFIX);
843 let mut tokens_used = tc.count_tokens(&text);
844
845 for chunk in &chunks {
846 if chunk.text.is_empty() {
847 continue;
848 }
849 let entry = format!("{}\n", chunk.text);
850 let cost = tc.count_tokens(&entry);
851 if tokens_used + cost > token_budget {
852 break;
853 }
854 text.push_str(&entry);
855 tokens_used += cost;
856 }
857
858 if tokens_used > tc.count_tokens(DOCUMENT_RAG_PREFIX) {
859 Ok(Some(Message {
860 role: Role::System,
861 content: text,
862 parts: vec![],
863 metadata: MessageMetadata::default(),
864 }))
865 } else {
866 Ok(None)
867 }
868}
869
870#[tracing::instrument(name = "context.summaries", skip_all)]
871pub(crate) async fn fetch_summaries(
872 memory: &ContextMemoryView,
873 token_budget: usize,
874 tc: &dyn TokenCounting,
875) -> Result<Option<Message>, AssemblerError> {
876 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
877 return Ok(None);
878 };
879 if token_budget == 0 {
880 return Ok(None);
881 }
882
883 let summaries = mem
884 .load_summaries(cid)
885 .await
886 .map_err(AssemblerError::Memory)?;
887 if summaries.is_empty() {
888 return Ok(None);
889 }
890
891 let mut summary_text = String::from(SUMMARY_PREFIX);
892 let mut tokens_used = tc.count_tokens(&summary_text);
893
894 for summary in summaries.iter().rev() {
895 let first = summary.first_message_id.unwrap_or(0);
896 let last = summary.last_message_id.unwrap_or(0);
897 let entry = format!("- Messages {first}-{last}: {}\n", summary.content);
898 let cost = tc.count_tokens(&entry);
899 if tokens_used + cost > token_budget {
900 break;
901 }
902 summary_text.push_str(&entry);
903 tokens_used += cost;
904 }
905
906 if tokens_used > tc.count_tokens(SUMMARY_PREFIX) {
907 Ok(Some(Message::from_parts(
908 Role::System,
909 vec![MessagePart::Summary { text: summary_text }],
910 )))
911 } else {
912 Ok(None)
913 }
914}
915
916#[tracing::instrument(name = "context.cross_session", skip_all)]
917pub(crate) async fn fetch_cross_session(
918 memory: &ContextMemoryView,
919 query: &str,
920 token_budget: usize,
921 tc: &dyn TokenCounting,
922) -> Result<Option<Message>, AssemblerError> {
923 let (Some(mem), Some(cid)) = (&memory.memory, memory.conversation_id) else {
924 return Ok(None);
925 };
926 if token_budget == 0 {
927 return Ok(None);
928 }
929
930 let threshold = memory.cross_session_score_threshold;
931 let results: Vec<_> = mem
932 .search_session_summaries(query, 5, Some(cid))
933 .await
934 .map_err(AssemblerError::Memory)?
935 .into_iter()
936 .filter(|r| r.score >= threshold)
937 .collect();
938 if results.is_empty() {
939 return Ok(None);
940 }
941
942 let mut text = String::from(CROSS_SESSION_PREFIX);
943 let mut tokens_used = tc.count_tokens(&text);
944
945 for item in &results {
946 let entry = format!("- {}\n", item.summary_text);
947 let cost = tc.count_tokens(&entry);
948 if tokens_used + cost > token_budget {
949 break;
950 }
951 text.push_str(&entry);
952 tokens_used += cost;
953 }
954
955 if tokens_used > tc.count_tokens(CROSS_SESSION_PREFIX) {
956 Ok(Some(Message::from_parts(
957 Role::System,
958 vec![MessagePart::CrossSession { text }],
959 )))
960 } else {
961 Ok(None)
962 }
963}
964
/// Length at which [`memory_first_keep_tail`] stops extending the retained tail over
/// tool-result messages.
///
/// Once reached, at most one more message is added, and only if that keeps a preceding
/// assistant tool-use in the tail.
pub const MAX_KEEP_TAIL_SCAN: usize = 50;
968
969#[must_use]
977pub fn memory_first_keep_tail(messages: &[Message], history_start: usize) -> usize {
978 use zeph_llm::provider::MessagePart;
979
980 let mut keep_tail = 2usize;
981 let len = messages.len();
982 let max = len.saturating_sub(history_start);
983
984 while keep_tail < max {
985 let first_retained = &messages[len - keep_tail];
986 let is_tool_result = first_retained.role == Role::User
987 && first_retained
988 .parts
989 .iter()
990 .any(|p| matches!(p, MessagePart::ToolResult { .. }));
991
992 if is_tool_result {
993 keep_tail += 1;
994 } else {
995 break;
996 }
997
998 if keep_tail >= MAX_KEEP_TAIL_SCAN {
999 let preceding_idx = len.saturating_sub(keep_tail + 1);
1000 if preceding_idx >= history_start {
1001 let preceding = &messages[preceding_idx];
1002 let is_tool_use = preceding.role == Role::Assistant
1003 && preceding
1004 .parts
1005 .iter()
1006 .any(|p| matches!(p, MessagePart::ToolUse { .. }));
1007 if is_tool_use {
1008 keep_tail += 1;
1009 }
1010 }
1011 break;
1012 }
1013 }
1014
1015 keep_tail
1016}
1017
1018#[cfg(test)]
1019mod tests {
1020 use super::*;
1021 use crate::input::ContextMemoryView;
1022 use zeph_common::memory::CompressionLevel;
1023 use zeph_config::{
1024 ContextStrategy, DocumentConfig, GraphConfig, PersonaConfig, ReasoningConfig,
1025 TrajectoryConfig, TreeConfig,
1026 };
1027
1028 struct NaiveTokenCounter;
1029 impl zeph_common::memory::TokenCounting for NaiveTokenCounter {
1030 fn count_tokens(&self, text: &str) -> usize {
1031 text.split_whitespace().count()
1032 }
1033 fn count_tool_schema_tokens(&self, schema: &serde_json::Value) -> usize {
1034 schema.to_string().split_whitespace().count()
1035 }
1036 }
1037
1038 fn empty_view() -> ContextMemoryView {
1039 ContextMemoryView {
1040 memory: None,
1041 conversation_id: None,
1042 recall_limit: 10,
1043 cross_session_score_threshold: 0.5,
1044 context_strategy: ContextStrategy::default(),
1045 crossover_turn_threshold: 5,
1046 cached_session_digest: None,
1047 graph_config: GraphConfig::default(),
1048 document_config: DocumentConfig::default(),
1049 persona_config: PersonaConfig::default(),
1050 trajectory_config: TrajectoryConfig::default(),
1051 reasoning_config: ReasoningConfig::default(),
1052 memcot_config: zeph_config::MemCotConfig::default(),
1053 memcot_state: None,
1054 tree_config: TreeConfig::default(),
1055 }
1056 }
1057
1058 #[tokio::test]
1061 async fn fetch_graph_facts_returns_none_when_memory_is_none() {
1062 let view = empty_view();
1063 let tc = NaiveTokenCounter;
1064 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1065 assert!(result.is_none());
1066 }
1067
1068 #[tokio::test]
1069 async fn fetch_graph_facts_returns_none_when_budget_zero() {
1070 let mut view = empty_view();
1071 view.graph_config.enabled = true;
1072 let tc = NaiveTokenCounter;
1073 let result = fetch_graph_facts(&view, "test", 0, &tc).await.unwrap();
1074 assert!(result.is_none());
1075 }
1076
1077 #[tokio::test]
1078 async fn fetch_graph_facts_returns_none_when_graph_disabled() {
1079 let mut view = empty_view();
1080 view.graph_config.enabled = false;
1081 let tc = NaiveTokenCounter;
1082 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1083 assert!(result.is_none());
1084 }
1085
1086 #[tokio::test]
1089 async fn fetch_persona_facts_returns_none_when_memory_is_none() {
1090 let view = empty_view();
1091 let tc = NaiveTokenCounter;
1092 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1093 assert!(result.is_none());
1094 }
1095
1096 #[tokio::test]
1097 async fn fetch_persona_facts_returns_none_when_budget_zero() {
1098 let mut view = empty_view();
1099 view.persona_config.enabled = true;
1100 let tc = NaiveTokenCounter;
1101 let result = fetch_persona_facts(&view, 0, &tc).await.unwrap();
1102 assert!(result.is_none());
1103 }
1104
1105 #[tokio::test]
1108 async fn fetch_trajectory_hints_returns_none_when_memory_is_none() {
1109 let view = empty_view();
1110 let tc = NaiveTokenCounter;
1111 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1112 assert!(result.is_none());
1113 }
1114
1115 #[tokio::test]
1116 async fn fetch_trajectory_hints_returns_none_when_budget_zero() {
1117 let mut view = empty_view();
1118 view.trajectory_config.enabled = true;
1119 let tc = NaiveTokenCounter;
1120 let result = fetch_trajectory_hints(&view, 0, &tc).await.unwrap();
1121 assert!(result.is_none());
1122 }
1123
1124 #[tokio::test]
1127 async fn fetch_tree_memory_returns_none_when_memory_is_none() {
1128 let view = empty_view();
1129 let tc = NaiveTokenCounter;
1130 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1131 assert!(result.is_none());
1132 }
1133
1134 #[tokio::test]
1135 async fn fetch_tree_memory_returns_none_when_budget_zero() {
1136 let mut view = empty_view();
1137 view.tree_config.enabled = true;
1138 let tc = NaiveTokenCounter;
1139 let result = fetch_tree_memory(&view, 0, &tc).await.unwrap();
1140 assert!(result.is_none());
1141 }
1142
1143 #[tokio::test]
1146 async fn fetch_corrections_returns_none_when_memory_is_none() {
1147 let view = empty_view();
1148 let result = fetch_corrections(&view, "test", 10, 0.5, |s| s.into())
1149 .await
1150 .unwrap();
1151 assert!(result.is_none());
1152 }
1153
1154 #[tokio::test]
1157 async fn fetch_semantic_recall_returns_none_when_memory_is_none() {
1158 let view = empty_view();
1159 let tc = NaiveTokenCounter;
1160 let result = fetch_semantic_recall(&view, "test", 1000, &tc, None)
1161 .await
1162 .unwrap();
1163 assert!(result.0.is_none() && result.1.is_none());
1164 }
1165
1166 #[tokio::test]
1167 async fn fetch_semantic_recall_returns_none_when_budget_zero() {
1168 let view = empty_view();
1169 let tc = NaiveTokenCounter;
1170 let result = fetch_semantic_recall(&view, "test", 0, &tc, None)
1171 .await
1172 .unwrap();
1173 assert!(result.0.is_none() && result.1.is_none());
1174 }
1175
1176 #[tokio::test]
1179 async fn fetch_document_rag_returns_none_when_memory_is_none() {
1180 let mut view = empty_view();
1181 view.document_config.rag_enabled = true;
1182 let tc = NaiveTokenCounter;
1183 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1184 assert!(result.is_none());
1185 }
1186
1187 #[tokio::test]
1188 async fn fetch_document_rag_returns_none_when_rag_disabled() {
1189 let view = empty_view();
1190 let tc = NaiveTokenCounter;
1191 let result = fetch_document_rag(&view, "test", 1000, &tc).await.unwrap();
1192 assert!(result.is_none());
1193 }
1194
1195 #[tokio::test]
1198 async fn fetch_summaries_returns_none_when_memory_is_none() {
1199 let view = empty_view();
1200 let tc = NaiveTokenCounter;
1201 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1202 assert!(result.is_none());
1203 }
1204
1205 #[tokio::test]
1208 async fn fetch_cross_session_returns_none_when_memory_is_none() {
1209 let view = empty_view();
1210 let tc = NaiveTokenCounter;
1211 let result = fetch_cross_session(&view, "test", 1000, &tc).await.unwrap();
1212 assert!(result.is_none());
1213 }
1214
/// An empty level list is treated as "no restriction": every tier is enabled.
#[test]
fn levels_to_flags_empty_slice_enables_all_tiers() {
    let (episodic, procedural, declarative) = levels_to_flags(&[]);
    assert!(episodic, "episodic should be active for empty slice");
    assert!(procedural, "procedural should be active for empty slice");
    assert!(declarative, "declarative should be active for empty slice");
}
1224
/// Listing every compression level explicitly enables all three tiers.
#[test]
fn levels_to_flags_full_set_enables_all_tiers() {
    let flags = levels_to_flags(&[
        CompressionLevel::Episodic,
        CompressionLevel::Procedural,
        CompressionLevel::Declarative,
    ]);
    assert_eq!(flags, (true, true, true));
}
1237
/// Only the listed tier is enabled when a single level is given.
#[test]
fn levels_to_flags_episodic_only() {
    let (episodic, procedural, declarative) = levels_to_flags(&[CompressionLevel::Episodic]);
    assert!(episodic);
    assert!(!procedural, "procedural should be inactive");
    assert!(!declarative, "declarative should be inactive");
}
1245
/// Two listed levels enable exactly those two tiers.
#[test]
fn levels_to_flags_episodic_and_procedural() {
    let levels = [CompressionLevel::Episodic, CompressionLevel::Procedural];
    let (episodic, procedural, declarative) = levels_to_flags(&levels);
    assert!(episodic);
    assert!(procedural);
    assert!(!declarative, "declarative should be inactive");
}
1254
/// Declarative alone leaves the episodic and procedural tiers off.
#[test]
fn levels_to_flags_declarative_only() {
    let (episodic, procedural, declarative) =
        levels_to_flags(&[CompressionLevel::Declarative]);
    assert!(!episodic, "episodic should be inactive");
    assert!(!procedural, "procedural should be inactive");
    assert!(declarative);
}
1262
1263 #[tokio::test]
1266 async fn fetch_reasoning_strategies_returns_none_when_memory_is_none() {
1267 let mut view = empty_view();
1268 view.reasoning_config.enabled = true;
1269 let tc = NaiveTokenCounter;
1270 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
1271 .await
1272 .unwrap();
1273 assert!(result.is_none());
1274 }
1275
1276 #[tokio::test]
1277 async fn fetch_reasoning_strategies_returns_none_when_budget_zero() {
1278 let mut view = empty_view();
1279 view.reasoning_config.enabled = true;
1280 let tc = NaiveTokenCounter;
1281 let result = fetch_reasoning_strategies(&view, "query", 0, 3, &tc)
1282 .await
1283 .unwrap();
1284 assert!(result.is_none());
1285 }
1286
1287 use std::sync::{Arc, Mutex};
1290 use zeph_common::memory::{
1291 ContextMemoryBackend, GraphRecallParams, MemCorrection, MemDocumentChunk, MemGraphFact,
1292 MemPersonaFact, MemReasoningStrategy, MemRecalledMessage, MemSessionSummary, MemSummary,
1293 MemTrajectoryEntry, MemTreeNode,
1294 };
1295
/// Every `ContextMemoryBackend` method name that `MockMemoryBackend::with_fail_on` accepts,
/// kept in alphabetical order.
const KNOWN_FAIL_ON: &[&str] = &[
    "load_persona_facts",
    "load_summaries",
    "load_trajectory_entries",
    "load_tree_nodes",
    "mark_reasoning_used",
    "recall",
    "recall_graph_facts",
    "retrieve_corrections",
    "retrieve_reasoning_strategies",
    "search_document_collection",
    "search_session_summaries",
];
1310
1311 #[derive(Default)]
1312 struct MockMemoryBackend {
1313 persona_facts: Vec<MemPersonaFact>,
1314 trajectory_entries: Vec<MemTrajectoryEntry>,
1315 tree_nodes: Vec<MemTreeNode>,
1316 summaries: Vec<MemSummary>,
1317 reasoning_strategies: Vec<MemReasoningStrategy>,
1318 corrections: Vec<MemCorrection>,
1319 recalled: Vec<MemRecalledMessage>,
1320 graph_facts: Vec<MemGraphFact>,
1321 session_summaries: Vec<MemSessionSummary>,
1322 document_chunks: Vec<MemDocumentChunk>,
1323 fail_on: Option<&'static str>,
1325 marked_ids: Mutex<Vec<String>>,
1327 }
1328
1329 impl MockMemoryBackend {
1330 fn with_fail_on(method: &'static str) -> Self {
1331 debug_assert!(
1332 KNOWN_FAIL_ON.contains(&method),
1333 "unknown fail_on method name: {method}"
1334 );
1335 Self {
1336 fail_on: Some(method),
1337 ..Default::default()
1338 }
1339 }
1340
1341 fn fail_err(method: &str) -> Box<dyn std::error::Error + Send + Sync> {
1342 format!("mock error in {method}").into()
1343 }
1344 }
1345
1346 impl ContextMemoryBackend for MockMemoryBackend {
1347 fn load_persona_facts<'a>(
1348 &'a self,
1349 _min_confidence: f64,
1350 ) -> std::pin::Pin<
1351 Box<
1352 dyn std::future::Future<
1353 Output = Result<
1354 Vec<MemPersonaFact>,
1355 Box<dyn std::error::Error + Send + Sync>,
1356 >,
1357 > + Send
1358 + 'a,
1359 >,
1360 > {
1361 let result = if self.fail_on == Some("load_persona_facts") {
1362 Err(Self::fail_err("load_persona_facts"))
1363 } else {
1364 Ok(self.persona_facts.clone())
1365 };
1366 Box::pin(async move { result })
1367 }
1368
1369 fn load_trajectory_entries<'a>(
1370 &'a self,
1371 _tier: Option<&'a str>,
1372 _top_k: usize,
1373 ) -> std::pin::Pin<
1374 Box<
1375 dyn std::future::Future<
1376 Output = Result<
1377 Vec<MemTrajectoryEntry>,
1378 Box<dyn std::error::Error + Send + Sync>,
1379 >,
1380 > + Send
1381 + 'a,
1382 >,
1383 > {
1384 let result = if self.fail_on == Some("load_trajectory_entries") {
1385 Err(Self::fail_err("load_trajectory_entries"))
1386 } else {
1387 Ok(self.trajectory_entries.clone())
1388 };
1389 Box::pin(async move { result })
1390 }
1391
1392 fn load_tree_nodes<'a>(
1393 &'a self,
1394 _level: u32,
1395 _top_k: usize,
1396 ) -> std::pin::Pin<
1397 Box<
1398 dyn std::future::Future<
1399 Output = Result<Vec<MemTreeNode>, Box<dyn std::error::Error + Send + Sync>>,
1400 > + Send
1401 + 'a,
1402 >,
1403 > {
1404 let result = if self.fail_on == Some("load_tree_nodes") {
1405 Err(Self::fail_err("load_tree_nodes"))
1406 } else {
1407 Ok(self.tree_nodes.clone())
1408 };
1409 Box::pin(async move { result })
1410 }
1411
1412 fn load_summaries<'a>(
1413 &'a self,
1414 _conversation_id: i64,
1415 ) -> std::pin::Pin<
1416 Box<
1417 dyn std::future::Future<
1418 Output = Result<Vec<MemSummary>, Box<dyn std::error::Error + Send + Sync>>,
1419 > + Send
1420 + 'a,
1421 >,
1422 > {
1423 let result = if self.fail_on == Some("load_summaries") {
1424 Err(Self::fail_err("load_summaries"))
1425 } else {
1426 Ok(self.summaries.clone())
1427 };
1428 Box::pin(async move { result })
1429 }
1430
1431 fn retrieve_reasoning_strategies<'a>(
1432 &'a self,
1433 _query: &'a str,
1434 _top_k: usize,
1435 ) -> std::pin::Pin<
1436 Box<
1437 dyn std::future::Future<
1438 Output = Result<
1439 Vec<MemReasoningStrategy>,
1440 Box<dyn std::error::Error + Send + Sync>,
1441 >,
1442 > + Send
1443 + 'a,
1444 >,
1445 > {
1446 let result = if self.fail_on == Some("retrieve_reasoning_strategies") {
1447 Err(Self::fail_err("retrieve_reasoning_strategies"))
1448 } else {
1449 Ok(self.reasoning_strategies.clone())
1450 };
1451 Box::pin(async move { result })
1452 }
1453
1454 fn mark_reasoning_used<'a>(
1455 &'a self,
1456 ids: &'a [String],
1457 ) -> std::pin::Pin<
1458 Box<
1459 dyn std::future::Future<
1460 Output = Result<(), Box<dyn std::error::Error + Send + Sync>>,
1461 > + Send
1462 + 'a,
1463 >,
1464 > {
1465 if self.fail_on == Some("mark_reasoning_used") {
1466 return Box::pin(async move { Err(Self::fail_err("mark_reasoning_used")) });
1467 }
1468 let mut guard = self.marked_ids.lock().expect("marked_ids poisoned");
1469 guard.extend_from_slice(ids);
1470 Box::pin(async move { Ok(()) })
1471 }
1472
1473 fn retrieve_corrections<'a>(
1474 &'a self,
1475 _query: &'a str,
1476 _limit: usize,
1477 _min_score: f32,
1478 ) -> std::pin::Pin<
1479 Box<
1480 dyn std::future::Future<
1481 Output = Result<
1482 Vec<MemCorrection>,
1483 Box<dyn std::error::Error + Send + Sync>,
1484 >,
1485 > + Send
1486 + 'a,
1487 >,
1488 > {
1489 let result = if self.fail_on == Some("retrieve_corrections") {
1490 Err(Self::fail_err("retrieve_corrections"))
1491 } else {
1492 Ok(self.corrections.clone())
1493 };
1494 Box::pin(async move { result })
1495 }
1496
1497 fn recall<'a>(
1498 &'a self,
1499 _query: &'a str,
1500 _limit: usize,
1501 _router: Option<&'a dyn zeph_common::memory::AsyncMemoryRouter>,
1502 ) -> std::pin::Pin<
1503 Box<
1504 dyn std::future::Future<
1505 Output = Result<
1506 Vec<MemRecalledMessage>,
1507 Box<dyn std::error::Error + Send + Sync>,
1508 >,
1509 > + Send
1510 + 'a,
1511 >,
1512 > {
1513 let result = if self.fail_on == Some("recall") {
1514 Err(Self::fail_err("recall"))
1515 } else {
1516 Ok(self.recalled.clone())
1517 };
1518 Box::pin(async move { result })
1519 }
1520
1521 fn recall_graph_facts<'a>(
1522 &'a self,
1523 _query: &'a str,
1524 _params: GraphRecallParams<'a>,
1525 ) -> std::pin::Pin<
1526 Box<
1527 dyn std::future::Future<
1528 Output = Result<
1529 Vec<MemGraphFact>,
1530 Box<dyn std::error::Error + Send + Sync>,
1531 >,
1532 > + Send
1533 + 'a,
1534 >,
1535 > {
1536 let result = if self.fail_on == Some("recall_graph_facts") {
1537 Err(Self::fail_err("recall_graph_facts"))
1538 } else {
1539 Ok(self.graph_facts.clone())
1540 };
1541 Box::pin(async move { result })
1542 }
1543
1544 fn search_session_summaries<'a>(
1545 &'a self,
1546 _query: &'a str,
1547 _limit: usize,
1548 _current_conversation_id: Option<i64>,
1549 ) -> std::pin::Pin<
1550 Box<
1551 dyn std::future::Future<
1552 Output = Result<
1553 Vec<MemSessionSummary>,
1554 Box<dyn std::error::Error + Send + Sync>,
1555 >,
1556 > + Send
1557 + 'a,
1558 >,
1559 > {
1560 let result = if self.fail_on == Some("search_session_summaries") {
1561 Err(Self::fail_err("search_session_summaries"))
1562 } else {
1563 Ok(self.session_summaries.clone())
1564 };
1565 Box::pin(async move { result })
1566 }
1567
1568 fn search_document_collection<'a>(
1569 &'a self,
1570 _collection: &'a str,
1571 _query: &'a str,
1572 _top_k: usize,
1573 ) -> std::pin::Pin<
1574 Box<
1575 dyn std::future::Future<
1576 Output = Result<
1577 Vec<MemDocumentChunk>,
1578 Box<dyn std::error::Error + Send + Sync>,
1579 >,
1580 > + Send
1581 + 'a,
1582 >,
1583 > {
1584 let result = if self.fail_on == Some("search_document_collection") {
1585 Err(Self::fail_err("search_document_collection"))
1586 } else {
1587 Ok(self.document_chunks.clone())
1588 };
1589 Box::pin(async move { result })
1590 }
1591 }
1592
1593 fn mock_view(mock: MockMemoryBackend) -> ContextMemoryView {
1594 let mut v = empty_view();
1595 v.memory = Some(Arc::new(mock));
1596 v
1597 }
1598
1599 #[tokio::test]
1602 async fn fetch_graph_facts_returns_message_when_memory_present() {
1603 let mock = MockMemoryBackend {
1604 graph_facts: vec![zeph_common::memory::MemGraphFact {
1605 fact: "Rust is fast".to_string(),
1606 confidence: 0.9,
1607 activation_score: None,
1608 neighbors: vec![],
1609 provenance_snippet: None,
1610 }],
1611 ..Default::default()
1612 };
1613 let mut view = mock_view(mock);
1614 view.graph_config.enabled = true;
1615 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1617 let tc = NaiveTokenCounter;
1618 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1619 assert!(result.is_some(), "expected Some message");
1620 let msg = result.unwrap();
1621 assert!(
1622 msg.content.contains("Rust is fast"),
1623 "expected fact text in output, got: {}",
1624 msg.content
1625 );
1626 assert!(
1627 msg.content.starts_with(GRAPH_FACTS_PREFIX),
1628 "expected GRAPH_FACTS_PREFIX"
1629 );
1630 }
1631
1632 #[tokio::test]
1633 async fn fetch_graph_facts_swallows_error_and_returns_none() {
1634 let mock = MockMemoryBackend::with_fail_on("recall_graph_facts");
1635 let mut view = mock_view(mock);
1636 view.graph_config.enabled = true;
1637 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1638 let tc = NaiveTokenCounter;
1639 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1641 assert!(
1642 result.is_none(),
1643 "expected None when recall_graph_facts errors"
1644 );
1645 }
1646
1647 #[tokio::test]
1648 async fn fetch_graph_facts_returns_none_when_facts_empty() {
1649 let mock = MockMemoryBackend::default(); let mut view = mock_view(mock);
1651 view.graph_config.enabled = true;
1652 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
1653 let tc = NaiveTokenCounter;
1654 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
1655 assert!(result.is_none());
1656 }
1657
1658 #[tokio::test]
1661 async fn fetch_persona_facts_returns_message_when_persona_enabled() {
1662 let mock = MockMemoryBackend {
1663 persona_facts: vec![MemPersonaFact {
1664 category: "preference".to_string(),
1665 content: "prefers concise answers".to_string(),
1666 }],
1667 ..Default::default()
1668 };
1669 let mut view = mock_view(mock);
1670 view.persona_config.enabled = true;
1671 view.persona_config.context_budget_tokens = 1000;
1672 let tc = NaiveTokenCounter;
1673 let result = fetch_persona_facts(&view, 1000, &tc).await.unwrap();
1674 assert!(result.is_some());
1675 let msg = result.unwrap();
1676 assert!(msg.content.contains("preference"));
1677 assert!(msg.content.contains("prefers concise answers"));
1678 assert!(msg.content.starts_with(crate::slot::PERSONA_PREFIX));
1679 }
1680
1681 #[tokio::test]
1682 async fn fetch_persona_facts_propagates_error() {
1683 let mock = MockMemoryBackend::with_fail_on("load_persona_facts");
1684 let mut view = mock_view(mock);
1685 view.persona_config.enabled = true;
1686 let tc = NaiveTokenCounter;
1687 let result = fetch_persona_facts(&view, 1000, &tc).await;
1688 assert!(
1689 result.is_err(),
1690 "expected Err from load_persona_facts failure"
1691 );
1692 }
1693
1694 #[tokio::test]
1697 async fn fetch_trajectory_hints_returns_message_when_trajectory_enabled() {
1698 let mock = MockMemoryBackend {
1699 trajectory_entries: vec![MemTrajectoryEntry {
1700 intent: "summarize code".to_string(),
1701 outcome: "produced concise summary".to_string(),
1702 confidence: 0.9,
1703 }],
1704 ..Default::default()
1705 };
1706 let mut view = mock_view(mock);
1707 view.trajectory_config.enabled = true;
1708 view.trajectory_config.context_budget_tokens = 1000;
1709 view.trajectory_config.min_confidence = 0.5;
1710 let tc = NaiveTokenCounter;
1711 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1712 assert!(result.is_some());
1713 let msg = result.unwrap();
1714 assert!(msg.content.contains("summarize code"));
1715 assert!(msg.content.starts_with(crate::slot::TRAJECTORY_PREFIX));
1716 }
1717
1718 #[tokio::test]
1719 async fn fetch_trajectory_hints_passes_tier_filter() {
1720 let mock = MockMemoryBackend {
1723 trajectory_entries: vec![
1724 MemTrajectoryEntry {
1725 intent: "debug async code".to_string(),
1726 outcome: "fixed deadlock".to_string(),
1727 confidence: 0.85,
1728 },
1729 MemTrajectoryEntry {
1730 intent: "low confidence task".to_string(),
1731 outcome: "irrelevant".to_string(),
1732 confidence: 0.3,
1733 },
1734 ],
1735 ..Default::default()
1736 };
1737 let mut view = mock_view(mock);
1738 view.trajectory_config.enabled = true;
1739 view.trajectory_config.context_budget_tokens = 1000;
1740 view.trajectory_config.min_confidence = 0.5;
1741 let tc = NaiveTokenCounter;
1742 let result = fetch_trajectory_hints(&view, 1000, &tc).await.unwrap();
1743 assert!(result.is_some(), "expected Some message");
1744 let msg = result.unwrap();
1745 assert!(
1746 msg.content.contains("debug async code"),
1747 "high-confidence entry must be included"
1748 );
1749 assert!(
1750 !msg.content.contains("low confidence task"),
1751 "entry below min_confidence must be filtered out"
1752 );
1753 }
1754
1755 #[tokio::test]
1756 async fn fetch_trajectory_hints_propagates_error() {
1757 let mock = MockMemoryBackend::with_fail_on("load_trajectory_entries");
1758 let mut view = mock_view(mock);
1759 view.trajectory_config.enabled = true;
1760 let tc = NaiveTokenCounter;
1761 let result = fetch_trajectory_hints(&view, 1000, &tc).await;
1762 assert!(result.is_err());
1763 }
1764
1765 #[tokio::test]
1768 async fn fetch_tree_memory_returns_message_when_tree_enabled() {
1769 let mock = MockMemoryBackend {
1770 tree_nodes: vec![MemTreeNode {
1771 content: "Topic: async Rust patterns".to_string(),
1772 }],
1773 ..Default::default()
1774 };
1775 let mut view = mock_view(mock);
1776 view.tree_config.enabled = true;
1777 view.tree_config.context_budget_tokens = 1000;
1778 let tc = NaiveTokenCounter;
1779 let result = fetch_tree_memory(&view, 1000, &tc).await.unwrap();
1780 assert!(result.is_some());
1781 let msg = result.unwrap();
1782 assert!(msg.content.contains("async Rust patterns"));
1783 assert!(msg.content.starts_with(crate::slot::TREE_MEMORY_PREFIX));
1784 }
1785
1786 #[tokio::test]
1787 async fn fetch_tree_memory_propagates_error() {
1788 let mock = MockMemoryBackend::with_fail_on("load_tree_nodes");
1789 let mut view = mock_view(mock);
1790 view.tree_config.enabled = true;
1791 let tc = NaiveTokenCounter;
1792 let result = fetch_tree_memory(&view, 1000, &tc).await;
1793 assert!(result.is_err());
1794 }
1795
1796 #[tokio::test]
1799 async fn fetch_corrections_returns_message_when_corrections_present() {
1800 let mock = MockMemoryBackend {
1801 corrections: vec![MemCorrection {
1802 correction_text: "use snake_case not camelCase".to_string(),
1803 }],
1804 ..Default::default()
1805 };
1806 let view = mock_view(mock);
1807 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into())
1808 .await
1809 .unwrap();
1810 assert!(result.is_some());
1811 let msg = result.unwrap();
1812 assert!(msg.content.contains("snake_case"));
1813 assert!(msg.content.starts_with(CORRECTIONS_PREFIX));
1814 }
1815
1816 #[tokio::test]
1817 async fn fetch_corrections_propagates_error() {
1818 let mock = MockMemoryBackend::with_fail_on("retrieve_corrections");
1821 let view = mock_view(mock);
1822 let result = fetch_corrections(&view, "query", 10, 0.5, |s| s.into()).await;
1823 assert!(result.is_err(), "expected Err, got {result:?}");
1824 }
1825
1826 #[tokio::test]
1829 async fn fetch_semantic_recall_returns_message_with_content() {
1830 let mock = MockMemoryBackend {
1831 recalled: vec![
1832 MemRecalledMessage {
1833 role: "user".to_string(),
1834 content: "how does tokio work".to_string(),
1835 score: 0.95,
1836 },
1837 MemRecalledMessage {
1838 role: "assistant".to_string(),
1839 content: "tokio is an async runtime".to_string(),
1840 score: 0.88,
1841 },
1842 ],
1843 ..Default::default()
1844 };
1845 let mut view = mock_view(mock);
1846 view.recall_limit = 10;
1847 let tc = NaiveTokenCounter;
1848 let (msg, score) = fetch_semantic_recall(&view, "tokio", 1000, &tc, None)
1849 .await
1850 .unwrap();
1851 assert!(msg.is_some(), "expected Some message");
1852 assert!(score.is_some_and(|s| (s - 0.95_f32).abs() < f32::EPSILON));
1854 let msg = msg.unwrap();
1855 let has_recall_part = msg.parts.iter().any(|p| {
1857 if let zeph_llm::provider::MessagePart::Recall { text } = p {
1858 text.contains("how does tokio work")
1859 } else {
1860 false
1861 }
1862 });
1863 assert!(has_recall_part, "expected recalled content in Recall part");
1864 }
1865
1866 #[tokio::test]
1867 async fn fetch_semantic_recall_returns_none_when_recalled_empty() {
1868 let mock = MockMemoryBackend::default();
1869 let mut view = mock_view(mock);
1870 view.recall_limit = 10;
1871 let tc = NaiveTokenCounter;
1872 let (msg, score) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
1873 .await
1874 .unwrap();
1875 assert!(msg.is_none());
1876 assert!(score.is_none());
1877 }
1878
1879 #[tokio::test]
1880 async fn fetch_semantic_recall_propagates_error() {
1881 let mock = MockMemoryBackend::with_fail_on("recall");
1882 let mut view = mock_view(mock);
1883 view.recall_limit = 10;
1884 let tc = NaiveTokenCounter;
1885 let result = fetch_semantic_recall(&view, "query", 1000, &tc, None).await;
1886 assert!(result.is_err());
1887 }
1888
1889 #[tokio::test]
1892 async fn fetch_document_rag_returns_message_when_rag_enabled() {
1893 let mock = MockMemoryBackend {
1894 document_chunks: vec![MemDocumentChunk {
1895 text: "Rust ownership rules prevent data races".to_string(),
1896 }],
1897 ..Default::default()
1898 };
1899 let mut view = mock_view(mock);
1900 view.document_config.rag_enabled = true;
1901 let tc = NaiveTokenCounter;
1902 let result = fetch_document_rag(&view, "ownership", 1000, &tc)
1903 .await
1904 .unwrap();
1905 assert!(result.is_some());
1906 let msg = result.unwrap();
1907 assert!(msg.content.contains("ownership rules"));
1908 assert!(msg.content.starts_with(DOCUMENT_RAG_PREFIX));
1909 }
1910
1911 #[tokio::test]
1912 async fn fetch_document_rag_propagates_error() {
1913 let mock = MockMemoryBackend::with_fail_on("search_document_collection");
1914 let mut view = mock_view(mock);
1915 view.document_config.rag_enabled = true;
1916 let tc = NaiveTokenCounter;
1917 let result = fetch_document_rag(&view, "query", 1000, &tc).await;
1918 assert!(result.is_err());
1919 }
1920
1921 #[tokio::test]
1924 async fn fetch_summaries_returns_message_when_summaries_present() {
1925 let mock = MockMemoryBackend {
1926 summaries: vec![MemSummary {
1927 first_message_id: Some(1),
1928 last_message_id: Some(5),
1929 content: "User asked about async Rust".to_string(),
1930 }],
1931 ..Default::default()
1932 };
1933 let mut view = mock_view(mock);
1934 view.conversation_id = Some(42);
1935 let tc = NaiveTokenCounter;
1936 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1937 assert!(result.is_some());
1938 let msg = result.unwrap();
1939 let has_summary_part = msg.parts.iter().any(|p| {
1940 if let zeph_llm::provider::MessagePart::Summary { text } = p {
1941 text.contains("Messages 1-5") && text.contains("async Rust")
1942 } else {
1943 false
1944 }
1945 });
1946 assert!(
1947 has_summary_part,
1948 "expected Summary part with messages range"
1949 );
1950 }
1951
1952 #[tokio::test]
1953 async fn fetch_summaries_returns_none_without_conversation_id() {
1954 let mock = MockMemoryBackend {
1955 summaries: vec![MemSummary {
1956 first_message_id: Some(1),
1957 last_message_id: Some(5),
1958 content: "some content".to_string(),
1959 }],
1960 ..Default::default()
1961 };
1962 let mut view = mock_view(mock);
1963 view.conversation_id = None; let tc = NaiveTokenCounter;
1965 let result = fetch_summaries(&view, 1000, &tc).await.unwrap();
1966 assert!(result.is_none());
1967 }
1968
1969 #[tokio::test]
1970 async fn fetch_summaries_propagates_error() {
1971 let mock = MockMemoryBackend::with_fail_on("load_summaries");
1972 let mut view = mock_view(mock);
1973 view.conversation_id = Some(42);
1974 let tc = NaiveTokenCounter;
1975 let result = fetch_summaries(&view, 1000, &tc).await;
1976 assert!(result.is_err());
1977 }
1978
1979 #[tokio::test]
1982 async fn fetch_cross_session_returns_message_when_results_present() {
1983 let mock = MockMemoryBackend {
1984 session_summaries: vec![MemSessionSummary {
1985 summary_text: "Previous session: debugging tokio deadlock".to_string(),
1986 score: 0.9,
1987 }],
1988 ..Default::default()
1989 };
1990 let mut view = mock_view(mock);
1991 view.conversation_id = Some(1);
1992 view.cross_session_score_threshold = 0.5;
1993 let tc = NaiveTokenCounter;
1994 let result = fetch_cross_session(&view, "async", 1000, &tc)
1995 .await
1996 .unwrap();
1997 assert!(result.is_some());
1998 let msg = result.unwrap();
1999 let has_cross_session_part = msg.parts.iter().any(|p| {
2000 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2001 text.contains("tokio deadlock")
2002 } else {
2003 false
2004 }
2005 });
2006 assert!(has_cross_session_part);
2007 }
2008
2009 #[tokio::test]
2010 async fn fetch_cross_session_propagates_error() {
2011 let mock = MockMemoryBackend::with_fail_on("search_session_summaries");
2012 let mut view = mock_view(mock);
2013 view.conversation_id = Some(1);
2014 let tc = NaiveTokenCounter;
2015 let result = fetch_cross_session(&view, "query", 1000, &tc).await;
2016 assert!(result.is_err());
2017 }
2018
2019 #[tokio::test]
2022 async fn fetch_reasoning_strategies_returns_message_and_marks_used() {
2023 let mock = Arc::new(MockMemoryBackend {
2024 reasoning_strategies: vec![
2025 MemReasoningStrategy {
2026 id: "strat-1".to_string(),
2027 outcome: "success".to_string(),
2028 summary: "break the problem into small steps".to_string(),
2029 },
2030 MemReasoningStrategy {
2031 id: "strat-2".to_string(),
2032 outcome: "success".to_string(),
2033 summary: "use tracing spans for debugging".to_string(),
2034 },
2035 ],
2036 ..Default::default()
2037 });
2038 let marked_ids = Arc::clone(&mock);
2039 let mut view = empty_view();
2040 view.memory = Some(mock);
2041 view.reasoning_config.enabled = true;
2042 view.reasoning_config.context_budget_tokens = 1000;
2043 let tc = NaiveTokenCounter;
2044 let result = fetch_reasoning_strategies(&view, "debug", 1000, 5, &tc)
2045 .await
2046 .unwrap();
2047 assert!(result.is_some());
2048 let msg = result.unwrap();
2049 assert!(msg.content.starts_with(crate::slot::REASONING_PREFIX));
2050 assert!(msg.content.contains("break the problem"));
2051
2052 tokio::task::yield_now().await;
2056 tokio::task::yield_now().await;
2057
2058 let ids = marked_ids.marked_ids.lock().expect("marked_ids poisoned");
2059 assert!(
2060 ids.contains(&"strat-1".to_string()),
2061 "expected strat-1 marked"
2062 );
2063 assert!(
2064 ids.contains(&"strat-2".to_string()),
2065 "expected strat-2 marked"
2066 );
2067 }
2068
2069 #[tokio::test]
2070 async fn fetch_reasoning_strategies_propagates_error() {
2071 let mock = MockMemoryBackend::with_fail_on("retrieve_reasoning_strategies");
2072 let mut view = mock_view(mock);
2073 view.reasoning_config.enabled = true;
2074 let tc = NaiveTokenCounter;
2075 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc).await;
2076 assert!(result.is_err());
2077 }
2078
2079 #[tokio::test]
2082 async fn fetch_semantic_recall_skips_skipped_and_stopped_messages() {
2083 let mock = MockMemoryBackend {
2084 recalled: vec![
2085 MemRecalledMessage {
2086 role: "user".to_string(),
2087 content: "[skipped] some content".to_string(),
2088 score: 0.95,
2089 },
2090 MemRecalledMessage {
2091 role: "user".to_string(),
2092 content: "[stopped] other content".to_string(),
2093 score: 0.90,
2094 },
2095 MemRecalledMessage {
2096 role: "user".to_string(),
2097 content: "valid content to recall".to_string(),
2098 score: 0.85,
2099 },
2100 ],
2101 ..Default::default()
2102 };
2103 let mut view = mock_view(mock);
2104 view.recall_limit = 10;
2105 let tc = NaiveTokenCounter;
2106 let (msg, _) = fetch_semantic_recall(&view, "query", 1000, &tc, None)
2107 .await
2108 .unwrap();
2109 assert!(msg.is_some());
2110 let msg = msg.unwrap();
2111 let full_text = msg.parts.iter().find_map(|p| {
2112 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2113 Some(text.clone())
2114 } else {
2115 None
2116 }
2117 });
2118 let text = full_text.unwrap_or_default();
2119 assert!(
2120 !text.contains("[skipped]"),
2121 "skipped messages must be excluded"
2122 );
2123 assert!(
2124 !text.contains("[stopped]"),
2125 "stopped messages must be excluded"
2126 );
2127 assert!(
2128 text.contains("valid content to recall"),
2129 "valid messages must be included"
2130 );
2131 }
2132
2133 #[tokio::test]
2134 async fn fetch_cross_session_filters_below_threshold() {
2135 let mock = MockMemoryBackend {
2136 session_summaries: vec![
2137 MemSessionSummary {
2138 summary_text: "high relevance session".to_string(),
2139 score: 0.9,
2140 },
2141 MemSessionSummary {
2142 summary_text: "low relevance session".to_string(),
2143 score: 0.2,
2144 },
2145 ],
2146 ..Default::default()
2147 };
2148 let mut view = mock_view(mock);
2149 view.conversation_id = Some(1);
2150 view.cross_session_score_threshold = 0.5;
2151 let tc = NaiveTokenCounter;
2152 let result = fetch_cross_session(&view, "query", 1000, &tc)
2153 .await
2154 .unwrap();
2155 assert!(result.is_some());
2156 let msg = result.unwrap();
2157 let text = msg
2158 .parts
2159 .iter()
2160 .find_map(|p| {
2161 if let zeph_llm::provider::MessagePart::CrossSession { text } = p {
2162 Some(text.clone())
2163 } else {
2164 None
2165 }
2166 })
2167 .unwrap_or_default();
2168 assert!(
2169 text.contains("high relevance"),
2170 "high score must be included"
2171 );
2172 assert!(
2173 !text.contains("low relevance"),
2174 "low score must be filtered out"
2175 );
2176 }
2177
2178 #[tokio::test]
2179 async fn fetch_document_rag_skips_empty_chunks() {
2180 let mock = MockMemoryBackend {
2181 document_chunks: vec![
2182 MemDocumentChunk {
2183 text: String::new(),
2184 }, MemDocumentChunk {
2186 text: "real content here".to_string(),
2187 },
2188 ],
2189 ..Default::default()
2190 };
2191 let mut view = mock_view(mock);
2192 view.document_config.rag_enabled = true;
2193 let tc = NaiveTokenCounter;
2194 let result = fetch_document_rag(&view, "query", 1000, &tc).await.unwrap();
2195 assert!(result.is_some());
2196 let msg = result.unwrap();
2197 assert!(msg.content.contains("real content here"));
2198 assert!(!msg.content.contains("\n\n\n"));
2200 }
2201
2202 #[tokio::test]
2203 async fn fetch_graph_facts_sanitizes_injection_payloads() {
2204 let mock = MockMemoryBackend {
2206 graph_facts: vec![zeph_common::memory::MemGraphFact {
2207 fact: "fact with <script>alert(1)</script> and\nnewline".to_string(),
2208 confidence: 0.8,
2209 activation_score: None,
2210 neighbors: vec![],
2211 provenance_snippet: None,
2212 }],
2213 ..Default::default()
2214 };
2215 let mut view = mock_view(mock);
2216 view.graph_config.enabled = true;
2217 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2218 let tc = NaiveTokenCounter;
2219 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2220 assert!(result.is_some());
2221 let msg = result.unwrap();
2222 assert!(
2223 !msg.content.contains('<'),
2224 "angle brackets must be sanitized"
2225 );
2226 assert!(
2229 !msg.content.contains("\n\n"),
2230 "embedded newlines must be sanitized, no double-newline sequences expected"
2231 );
2232 }
2233
2234 #[tokio::test]
2235 async fn fetch_reasoning_strategies_sanitizes_injection_payloads() {
2236 let mock = MockMemoryBackend {
2238 reasoning_strategies: vec![MemReasoningStrategy {
2239 id: "s1".to_string(),
2240 outcome: "success".to_string(),
2241 summary: "strategy with <b>bold</b> and\nnewline".to_string(),
2242 }],
2243 ..Default::default()
2244 };
2245 let mut view = mock_view(mock);
2246 view.reasoning_config.enabled = true;
2247 let tc = NaiveTokenCounter;
2248 let result = fetch_reasoning_strategies(&view, "query", 1000, 3, &tc)
2249 .await
2250 .unwrap();
2251 assert!(result.is_some());
2252 let msg = result.unwrap();
2253 assert!(
2254 !msg.content.contains('<'),
2255 "angle brackets must be sanitized in strategy summaries"
2256 );
2257 }
2258
2259 #[tokio::test]
2262 async fn fetch_persona_facts_truncates_at_budget() {
2263 let tc = NaiveTokenCounter;
2264 let first_line = "[pref] brief\n";
2266 let budget = tc.count_tokens(crate::slot::PERSONA_PREFIX) + tc.count_tokens(first_line);
2267 let mock = MockMemoryBackend {
2268 persona_facts: vec![
2269 MemPersonaFact {
2270 category: "pref".to_string(),
2271 content: "brief".to_string(),
2272 },
2273 MemPersonaFact {
2274 category: "lang".to_string(),
2275 content: "english".to_string(),
2276 },
2277 ],
2278 ..Default::default()
2279 };
2280 let mut view = mock_view(mock);
2281 view.persona_config.enabled = true;
2282 let result = fetch_persona_facts(&view, budget, &tc).await.unwrap();
2283 let msg = result.unwrap();
2284 assert!(msg.content.contains("brief"), "first fact must be included");
2285 assert!(
2286 !msg.content.contains("english"),
2287 "second fact must be truncated by budget"
2288 );
2289 }
2290
2291 #[tokio::test]
2292 async fn fetch_semantic_recall_truncates_at_budget() {
2293 let tc = NaiveTokenCounter;
2294 let first_entry = "- [user] first message\n";
2296 let budget = tc.count_tokens(RECALL_PREFIX) + tc.count_tokens(first_entry);
2297 let mock = MockMemoryBackend {
2298 recalled: vec![
2299 MemRecalledMessage {
2300 role: "user".to_string(),
2301 content: "first message".to_string(),
2302 score: 0.95,
2303 },
2304 MemRecalledMessage {
2305 role: "user".to_string(),
2306 content: "second message that should be truncated".to_string(),
2307 score: 0.80,
2308 },
2309 ],
2310 ..Default::default()
2311 };
2312 let mut view = mock_view(mock);
2313 view.recall_limit = 10;
2314 let (msg, _) = fetch_semantic_recall(&view, "query", budget, &tc, None)
2315 .await
2316 .unwrap();
2317 assert!(msg.is_some());
2318 let text = msg
2319 .unwrap()
2320 .parts
2321 .iter()
2322 .find_map(|p| {
2323 if let zeph_llm::provider::MessagePart::Recall { text } = p {
2324 Some(text.clone())
2325 } else {
2326 None
2327 }
2328 })
2329 .unwrap_or_default();
2330 assert!(
2331 text.contains("first message"),
2332 "first entry must be included"
2333 );
2334 assert!(
2335 !text.contains("second message"),
2336 "second entry must be truncated by budget"
2337 );
2338 }
2339
2340 #[tokio::test]
2343 async fn fetch_graph_facts_sanitizes_provenance_snippet() {
2344 use zeph_common::memory::MemGraphNeighbor;
2345 let mock = MockMemoryBackend {
2346 graph_facts: vec![zeph_common::memory::MemGraphFact {
2347 fact: "safe fact".to_string(),
2348 confidence: 0.9,
2349 activation_score: None,
2350 neighbors: vec![MemGraphNeighbor {
2351 fact: "neighbor".to_string(),
2352 confidence: 0.7,
2353 }],
2354 provenance_snippet: Some("source with <injection>\nand newline".to_string()),
2355 }],
2356 ..Default::default()
2357 };
2358 let mut view = mock_view(mock);
2359 view.graph_config.enabled = true;
2360 view.graph_config.spreading_activation.recall_timeout_ms = 5000;
2361 let tc = NaiveTokenCounter;
2362 let result = fetch_graph_facts(&view, "test", 1000, &tc).await.unwrap();
2363 assert!(result.is_some());
2364 let msg = result.unwrap();
2365 assert!(
2366 !msg.content.contains('<'),
2367 "angle brackets in provenance_snippet must be sanitized"
2368 );
2369 assert!(
2370 !msg.content.contains("\n\n"),
2371 "newlines in provenance_snippet must be sanitized"
2372 );
2373 assert!(
2374 msg.content.contains("[source:"),
2375 "provenance snippet must be rendered"
2376 );
2377 }
2378}