1use memvid_core::{MemoryCard, MemoryKind, Polarity};
29use rig::vector_store::{VectorSearchRequest, VectorStoreError, VectorStoreIndex};
30use rig::wasm_compat::WasmCompatSend;
31use serde::{Deserialize, Serialize};
32
33use crate::error::MemvidError;
34use crate::memory_graph::MemoryGraph;
35use crate::store::{MemvidFilter, MemvidStore};
36
/// Strategy a [`MemoryCardContext`] uses to pick candidate memory cards.
#[derive(Clone, Debug, Default)]
pub enum CardSelection {
    /// Cards the graph associates with the query text (the default).
    #[default]
    EntityMentions,
    /// Every card in the graph, newest first; the query is ignored.
    RecentCards,
    /// Cards about the named principal plus the entities its relationship
    /// cards point at; the query is ignored.
    ForPrincipal(String),
    /// Preference cards for each listed entity; the query is ignored.
    PreferencesFor(Vec<String>),
}
64
65#[derive(Debug, Clone)]
75pub struct MemoryCardContext<G = MemvidStore>
76where
77 G: MemoryGraph,
78{
79 graph: G,
80 strategy: CardSelection,
81 max_cards: usize,
82}
83
84impl<G> MemoryCardContext<G>
85where
86 G: MemoryGraph,
87{
88 pub const DEFAULT_MAX_CARDS: usize = 8;
92
93 pub fn new(graph: G, strategy: CardSelection) -> Self {
95 Self {
96 graph,
97 strategy,
98 max_cards: Self::DEFAULT_MAX_CARDS,
99 }
100 }
101
102 #[must_use]
105 pub fn with_max_cards(mut self, max_cards: usize) -> Self {
106 self.max_cards = max_cards;
107 self
108 }
109
110 #[must_use]
112 pub fn graph(&self) -> &G {
113 &self.graph
114 }
115
116 #[must_use]
118 pub fn strategy(&self) -> &CardSelection {
119 &self.strategy
120 }
121
122 pub fn select(&self, query: &str) -> Result<Vec<MemoryCard>, G::Error> {
126 match &self.strategy {
138 CardSelection::EntityMentions => self.select_entity_mentions(query),
139 CardSelection::RecentCards => self.select_recent(),
140 CardSelection::ForPrincipal(principal) => self.select_for_principal(principal),
141 CardSelection::PreferencesFor(entities) => self.select_preferences(entities),
142 }
143 }
144
145 fn select_entity_mentions(&self, query: &str) -> Result<Vec<MemoryCard>, G::Error> {
146 let mut hits = self.graph.cards_for_query(query)?;
151 hits.sort_by_key(|c| std::cmp::Reverse(c.created_at));
152 Ok(hits)
153 }
154
155 fn select_recent(&self) -> Result<Vec<MemoryCard>, G::Error> {
156 let mut all = self.graph.all_memory_cards()?;
161 all.sort_by_key(|c| std::cmp::Reverse(c.created_at));
162 Ok(all)
163 }
164
165 fn select_for_principal(&self, principal: &str) -> Result<Vec<MemoryCard>, G::Error> {
166 let mut hits = self.graph.entity_memories(principal)?;
167 let lower = principal.to_lowercase();
168 if hits.is_empty() && lower != principal {
169 hits = self.graph.entity_memories(&lower)?;
170 }
171 for card in self.graph.all_memory_cards()? {
172 if hits.iter().any(|existing| same_card(existing, &card)) {
173 continue;
174 }
175 let entity = card.entity.to_lowercase();
176 if contains_word(&entity, &lower) {
177 hits.push(card);
178 }
179 }
180
181 let related_entities: Vec<String> = hits
182 .iter()
183 .filter(|card| card.kind == MemoryKind::Relationship)
184 .map(|card| card.value.clone())
185 .collect();
186 for entity in related_entities {
187 for card in self.related_entity_memories(&entity)? {
188 if hits.iter().any(|existing| same_card(existing, &card)) {
189 continue;
190 }
191 hits.push(card);
192 }
193 }
194
195 hits.sort_by_key(|c| std::cmp::Reverse(c.created_at));
196 Ok(hits)
197 }
198
199 fn related_entity_memories(&self, entity: &str) -> Result<Vec<MemoryCard>, G::Error> {
200 let mut hits = self.graph.entity_memories(entity)?;
201 let lower = entity.to_lowercase();
202 if hits.is_empty() && lower != entity {
203 hits = self.graph.entity_memories(&lower)?;
204 }
205 Ok(hits)
206 }
207
208 fn select_preferences(&self, entities: &[String]) -> Result<Vec<MemoryCard>, G::Error> {
209 let mut hits = Vec::new();
210 for ent in entities {
211 hits.extend(self.graph.entity_preferences(ent)?);
212 }
213 hits.sort_by_key(|c| std::cmp::Reverse(c.created_at));
214 Ok(hits)
215 }
216}
217fn same_card(left: &MemoryCard, right: &MemoryCard) -> bool {
218 left.entity == right.entity
219 && left.slot == right.slot
220 && left.value == right.value
221 && left.source_frame_id == right.source_frame_id
222}
223
/// Whether `needle` occurs in `haystack` as a whole word.
///
/// Matching is byte-exact (no case folding). A match counts only when the
/// bytes on both sides are absent or are not word bytes (see
/// [`is_word_byte`]). Every occurrence is tried, including overlapping ones.
/// An empty `needle` never matches.
///
/// NOTE(review): the boundary test is ASCII-only, so every byte of a
/// non-ASCII character acts as a boundary. For example, `"jos"` matches
/// inside `"josé"`, and unsegmented CJK text still matches on substrings.
pub(crate) fn contains_word(haystack: &str, needle: &str) -> bool {
    let hay = haystack.as_bytes();
    let pat = needle.as_bytes();
    if pat.is_empty() || hay.len() < pat.len() {
        return false;
    }
    let is_boundary = |neighbour: Option<&u8>| !matches!(neighbour, Some(&b) if is_word_byte(b));
    hay.windows(pat.len()).enumerate().any(|(start, window)| {
        window == pat
            && is_boundary(start.checked_sub(1).and_then(|prev| hay.get(prev)))
            && is_boundary(hay.get(start + pat.len()))
    })
}

/// ASCII letters, digits and `_` count as word bytes; every other byte is a
/// word boundary.
fn is_word_byte(b: u8) -> bool {
    matches!(b, b'_') || b.is_ascii_alphanumeric()
}
260
261pub(crate) fn format_card(card: &MemoryCard) -> String {
267 let polarity = match card.polarity {
268 Some(Polarity::Positive) => " (+)",
269 Some(Polarity::Negative) => " (-)",
270 Some(Polarity::Neutral) | None => "",
271 };
272 if card.kind == MemoryKind::Relationship {
273 if card.slot == "reports_to" {
274 return format!(
275 "rel {entity}'s manager = {value}",
276 entity = card.entity,
277 value = card.value
278 );
279 }
280 if card.slot == "manager" {
281 return format!(
282 "rel {entity}'s manager = {value}",
283 entity = card.entity,
284 value = card.value
285 );
286 }
287 }
288 if card.kind == MemoryKind::Fact && card.slot == "location" {
289 return format!(
290 "fact {entity} lives in {value}",
291 entity = card.entity,
292 value = card.value,
293 );
294 }
295 if card.kind == MemoryKind::Fact && card.slot == "employer" {
296 return format!(
297 "fact {entity} works at {value}",
298 entity = card.entity,
299 value = card.value,
300 );
301 }
302 if card.kind == MemoryKind::Profile && card.slot == "allergy" {
303 return format!(
304 "profile {entity} is allergic to {value}",
305 entity = card.entity,
306 value = card.value,
307 );
308 }
309 if card.kind == MemoryKind::Preference {
310 if card.polarity == Some(Polarity::Negative) {
311 return format!(
312 "pref {entity} dislikes {value}",
313 entity = card.entity,
314 value = card.value,
315 );
316 }
317 if card.polarity == Some(Polarity::Positive) {
318 return format!(
319 "pref {entity} likes {value}",
320 entity = card.entity,
321 value = card.value,
322 );
323 }
324 }
325 if matches!(card.kind, MemoryKind::Fact | MemoryKind::Profile) {
326 return format!(
327 "{kind} {entity}'s {slot} = {value}",
328 kind = kind_str(card.kind),
329 entity = card.entity,
330 slot = card.slot,
331 value = card.value,
332 );
333 }
334 format!(
335 "{kind} {entity}/{slot} = {value}{polarity}",
336 kind = kind_str(card.kind),
337 entity = card.entity,
338 slot = card.slot,
339 value = card.value,
340 polarity = polarity,
341 )
342}
343
344pub(crate) fn kind_str(kind: MemoryKind) -> &'static str {
345 match kind {
346 MemoryKind::Fact => "fact",
347 MemoryKind::Preference => "pref",
348 MemoryKind::Event => "event",
349 MemoryKind::Profile => "profile",
350 MemoryKind::Relationship => "rel",
351 MemoryKind::Goal => "goal",
352 MemoryKind::Other => "other",
353 }
354}
355
356fn recency_scores(cards: &[MemoryCard]) -> Vec<f64> {
360 let n = cards.len();
361 if n <= 1 {
362 return vec![1.0; n];
363 }
364 let denom = (n - 1) as f64;
365 (0..n).map(|i| 1.0 - (i as f64 / denom)).collect()
366}
367
368fn rank_cards(query: &str, cards: Vec<MemoryCard>) -> Vec<(f64, MemoryCard)> {
369 let query = query.to_lowercase();
370 let recency = recency_scores(&cards);
371 let mut ranked: Vec<(f64, MemoryCard)> = cards
372 .into_iter()
373 .zip(recency)
374 .map(|(card, recency_score)| {
375 let score = card_relevance_score(&query, &card) + recency_score * 0.01;
376 (score, card)
377 })
378 .collect();
379 ranked.sort_by(|left, right| {
380 right
381 .0
382 .total_cmp(&left.0)
383 .then_with(|| right.1.created_at.cmp(&left.1.created_at))
384 });
385 ranked
386}
387
388fn card_relevance_score(query: &str, card: &MemoryCard) -> f64 {
389 let mut score = 0.0;
390 let entity = card.entity.to_lowercase();
391 let slot = card.slot.to_lowercase();
392 let value = card.value.to_lowercase();
393
394 let entity_matches = !entity.is_empty() && contains_word(query, &entity);
395 let slot_query_match = !slot.is_empty() && contains_word(query, &slot);
396 let value_query_match = !value.is_empty() && contains_word(query, &value);
397 if entity_matches {
398 score += 5.0;
399 if card.kind == MemoryKind::Relationship && query_matches(query, RELATIONSHIP_INTENT_TERMS)
400 {
401 score += 4.0;
402 }
403 }
404 if slot_query_match {
405 score += 4.0;
406 }
407 if value_query_match {
408 score += 2.0;
409 }
410
411 score += slot_intent_score(query, &slot);
412 score += kind_intent_score(
413 query,
414 card.kind,
415 entity_matches || slot_query_match || value_query_match,
416 );
417
418 if query_terms_match(query, &slot) {
419 score += 1.0;
420 }
421 if query_terms_match(query, &value) {
422 score += 1.0;
423 }
424
425 score
426}
427
428fn slot_intent_score(query: &str, slot: &str) -> f64 {
429 if slot_matches(slot, &["location", "city", "home", "address"])
430 && query_matches(
431 query,
432 &[
433 "where", "live", "lives", "located", "location", "city", "reside", "resides",
434 "from", "grew",
435 ],
436 )
437 {
438 return 6.0;
439 }
440 if slot_matches(slot, &["allergy", "allergic", "avoidance"])
441 && query_matches(
442 query,
443 &[
444 "avoid", "serve", "food", "allergic", "allergy", "eat", "cannot", "can't", "safe",
445 ],
446 )
447 {
448 return 6.0;
449 }
450 if slot_matches(slot, &["preference", "drink", "food", "coffee"])
451 && query_matches(
452 query,
453 &[
454 "like",
455 "likes",
456 "prefer",
457 "prefers",
458 "preference",
459 "preferences",
460 "drink",
461 "coffee",
462 "dislike",
463 "dislikes",
464 ],
465 )
466 {
467 return 6.0;
468 }
469 if slot_matches(slot, &["manager", "reports_to", "reports", "boss"])
470 && query_matches(
471 query,
472 &["manager", "boss", "reports", "report", "supervisor"],
473 )
474 {
475 return 6.0;
476 }
477 if slot_matches(slot, &["employer", "company", "work"])
478 && query_matches(
479 query,
480 &["work", "works", "employer", "company", "job", "role"],
481 )
482 {
483 return 6.0;
484 }
485 0.0
486}
487
/// Query words signalling interest in likes and dislikes. `kind_intent_score`
/// uses them to boost matched preference cards.
const PREFERENCE_INTENT_TERMS: &[&str] = &[
    "like", "likes", "prefer", "prefers", "preference", "preferences", "dislike", "dislikes",
];

/// Query words signalling interest in profile details such as allergies.
/// `kind_intent_score` uses them to boost matched profile cards.
const PROFILE_INTENT_TERMS: &[&str] = &[
    "allergic", "allergy", "avoid", "serve", "food", "profile", "about",
];

/// Query words signalling interest in relationships. `kind_intent_score` and
/// `card_relevance_score` use them to boost matched relationship cards.
const RELATIONSHIP_INTENT_TERMS: &[&str] = &[
    "manager", "boss", "reports", "report", "relationship",
];
511
512fn kind_intent_score(query: &str, kind: MemoryKind, card_matched_any: bool) -> f64 {
523 if !card_matched_any {
524 return match kind {
525 MemoryKind::Fact => 0.5,
526 _ => 0.0,
527 };
528 }
529 match kind {
530 MemoryKind::Preference => {
531 if query_matches(query, PREFERENCE_INTENT_TERMS) {
532 2.0
533 } else {
534 0.0
535 }
536 }
537 MemoryKind::Profile => {
538 if query_matches(query, PROFILE_INTENT_TERMS) {
539 2.0
540 } else {
541 0.0
542 }
543 }
544 MemoryKind::Relationship => {
545 if query_matches(query, RELATIONSHIP_INTENT_TERMS) {
546 2.0
547 } else {
548 0.0
549 }
550 }
551 MemoryKind::Fact => 0.5,
552 MemoryKind::Event | MemoryKind::Goal | MemoryKind::Other => 0.0,
553 }
554}
555
556fn slot_matches(slot: &str, needles: &[&str]) -> bool {
557 needles.iter().any(|needle| contains_word(slot, needle))
558}
559
560fn query_matches(query: &str, needles: &[&str]) -> bool {
561 needles.iter().any(|needle| contains_word(query, needle))
562}
563
564fn query_terms_match(query: &str, text: &str) -> bool {
565 text.split(|c: char| !c.is_alphanumeric() && c != '_')
566 .filter(|term| term.len() > 2)
567 .any(|term| contains_word(query, term))
568}
569
570impl<G> VectorStoreIndex for MemoryCardContext<G>
571where
572 G: MemoryGraph + WasmCompatSend + Sync,
573{
574 type Filter = MemvidFilter;
579
580 async fn top_n<T>(
581 &self,
582 req: VectorSearchRequest<Self::Filter>,
583 ) -> Result<Vec<(f64, String, T)>, VectorStoreError>
584 where
585 T: for<'a> Deserialize<'a> + WasmCompatSend,
586 {
587 let query = req.query().to_owned();
588 let limit = std::cmp::min(self.max_cards, req.samples() as usize);
589
590 let mut ranked = rank_cards(&query, self.select(&query).map_err(Into::into)?);
591 if ranked.len() > limit {
592 ranked.truncate(limit);
593 }
594
595 let mut out = Vec::with_capacity(ranked.len());
596 let mut byte_size = 0usize;
597 for (score, card) in ranked {
598 let id = card.id.to_string();
599 let text = format_card(&card);
600 byte_size = byte_size.saturating_add(text.len());
601 let payload = CardDoc {
602 text,
603 kind: kind_str(card.kind).to_string(),
604 entity: card.entity,
605 slot: card.slot,
606 value: card.value,
607 polarity: card.polarity.map(polarity_str).map(str::to_owned),
608 source_frame_id: card.source_frame_id,
609 confidence: card.confidence,
610 };
611 let value = serde_json::to_value(&payload).map_err(MemvidError::from)?;
612 let doc: T = serde_json::from_value(value).map_err(MemvidError::from)?;
613 out.push((score, id, doc));
614 }
615 emit_card_context_sample(out.len(), byte_size);
616 Ok(out)
617 }
618
619 async fn top_n_ids(
620 &self,
621 req: VectorSearchRequest<Self::Filter>,
622 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
623 let query = req.query().to_owned();
624 let limit = std::cmp::min(self.max_cards, req.samples() as usize);
625
626 let mut ranked = rank_cards(&query, self.select(&query).map_err(Into::into)?);
627 if ranked.len() > limit {
628 ranked.truncate(limit);
629 }
630 let byte_size = ranked
631 .iter()
632 .map(|(_, card)| card.entity.len() + card.slot.len() + card.value.len())
633 .sum();
634 emit_card_context_sample(ranked.len(), byte_size);
635 Ok(ranked
636 .into_iter()
637 .map(|(score, card)| (score, card.id.to_string()))
638 .collect())
639 }
640}
641
/// Reports a `ContextSampled` event under the "memory-card-context" label,
/// with how many cards were surfaced and their byte size.
///
/// The event goes to `rig_tap` only when the `observe` feature is enabled;
/// otherwise this is a no-op. The arguments keep their `_` prefix so builds
/// without the feature do not warn about unused parameters.
fn emit_card_context_sample(_message_count: usize, _byte_size: usize) {
    #[cfg(feature = "observe")]
    {
        let sample = rig_tap::EventKind::ContextSampled {
            message_count: _message_count,
            byte_size: _byte_size,
            token_estimate: None,
        };
        rig_tap::emit_kind("memory-card-context", sample);
    }
}
653
654pub(crate) fn polarity_str(p: Polarity) -> &'static str {
655 match p {
656 Polarity::Positive => "positive",
657 Polarity::Negative => "negative",
658 Polarity::Neutral => "neutral",
659 }
660}
661
662#[derive(Debug, Clone, Serialize, Deserialize)]
669pub struct CardDoc {
670 pub text: String,
672 pub kind: String,
674 pub entity: String,
676 pub slot: String,
678 pub value: String,
680 pub polarity: Option<String>,
682 pub source_frame_id: u64,
684 pub confidence: Option<f32>,
686}
687
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;

    #[test]
    fn word_boundary_matches() {
        assert!(contains_word("alice loves rust", "alice"));
        assert!(contains_word("hi alice!", "alice"));
        assert!(contains_word("alice", "alice"));
        assert!(!contains_word("smart cookie", "art"));
        assert!(!contains_word("alicemarie", "alice"));
        assert!(!contains_word("", "alice"));
        assert!(!contains_word("alice", ""));
    }

    #[test]
    fn recency_scores_handles_edge_cases() {
        assert_eq!(recency_scores(&[]), Vec::<f64>::new());
        assert_eq!(recency_scores(&[stub_card("a")]), vec![1.0]);
        let two = recency_scores(&[stub_card("a"), stub_card("b")]);
        assert_eq!(two, vec![1.0, 0.0]);
    }

    /// Minimal `Fact` card with slot `"s"` and value `"v"`.
    fn stub_card(entity: &str) -> MemoryCard {
        MemoryCard {
            id: 0,
            kind: MemoryKind::Fact,
            entity: entity.into(),
            slot: "s".into(),
            value: "v".into(),
            polarity: None,
            event_date: None,
            document_date: None,
            version_key: None,
            version_relation: memvid_core::VersionRelation::default(),
            source_frame_id: 0,
            source_uri: None,
            source_offset: None,
            engine: "t".into(),
            engine_version: "0".into(),
            confidence: None,
            created_at: 0,
        }
    }

    /// Preference card: `<entity>` drink = espresso, no polarity.
    fn pref_card(entity: &str) -> MemoryCard {
        let mut card = stub_card(entity);
        card.kind = MemoryKind::Preference;
        card.slot = "drink".into();
        card.value = "espresso".into();
        card
    }

    #[test]
    fn kind_intent_score_requires_card_match() {
        let query = "what does alice prefer?";
        let alice_card = pref_card("alice");
        let mut unrelated = pref_card("bob");
        unrelated.slot = "music_genre".into();
        unrelated.value = "jazz".into();
        let alice_score = card_relevance_score(query, &alice_card);
        let unrelated_score = card_relevance_score(query, &unrelated);
        assert!(
            alice_score > unrelated_score,
            "matched alice card {alice_score} must beat unrelated card {unrelated_score}"
        );
        assert_eq!(
            super::kind_intent_score(query, MemoryKind::Preference, false),
            0.0
        );
        assert_eq!(
            super::kind_intent_score(query, MemoryKind::Preference, true),
            2.0
        );
        assert_eq!(
            super::kind_intent_score(query, MemoryKind::Fact, false),
            0.5
        );
    }

    #[test]
    fn t2_old_relevant_card_beats_recent_noise() {
        let mut relevant = pref_card("alice");
        relevant.created_at = 0;
        let mut cards = vec![relevant];
        for i in 1..=10 {
            let mut noise = pref_card("bob");
            noise.value = format!("noise-{i}");
            noise.created_at = i;
            cards.push(noise);
        }
        let ranked = rank_cards("what does alice prefer?", cards);
        let top = ranked.first().expect("at least one ranked card");
        assert_eq!(
            top.1.entity, "alice",
            "expected alice card on top, got {:?}",
            top.1
        );
    }

    #[test]
    fn t3_rank_cards_with_no_match_returns_low_scores() {
        let mut bob = pref_card("bob");
        bob.slot = "music_genre".into();
        bob.value = "jazz".into();
        let mut carol = pref_card("carol");
        carol.slot = "music_genre".into();
        carol.value = "rock".into();
        let cards = vec![bob, carol];
        let ranked = rank_cards("how is the weather today?", cards);
        assert_eq!(ranked.len(), 2);
        for (score, card) in &ranked {
            assert!(
                *score <= 1.0 + f64::EPSILON,
                "unmatched {entity} scored {score}; expected <= 1.0",
                entity = card.entity
            );
        }
    }

    #[test]
    fn format_card_renders_known_slots_as_sentences() {
        let mut manager = stub_card("alice");
        manager.kind = MemoryKind::Relationship;
        manager.slot = "manager".into();
        manager.value = "bob".into();
        assert_eq!(format_card(&manager), "rel alice's manager = bob");
        manager.slot = "reports_to".into();
        assert_eq!(format_card(&manager), "rel alice's manager = bob");

        let mut home = stub_card("alice");
        home.slot = "location".into();
        home.value = "paris".into();
        assert_eq!(format_card(&home), "fact alice lives in paris");

        assert_eq!(format_card(&stub_card("alice")), "fact alice's s = v");
    }

    #[test]
    fn format_card_preferences_and_fallback_use_polarity() {
        let mut pref = pref_card("alice");
        pref.polarity = Some(Polarity::Positive);
        assert_eq!(format_card(&pref), "pref alice likes espresso");
        pref.polarity = Some(Polarity::Negative);
        assert_eq!(format_card(&pref), "pref alice dislikes espresso");
        pref.polarity = None;
        assert_eq!(format_card(&pref), "pref alice/drink = espresso");

        let mut event = stub_card("alice");
        event.kind = MemoryKind::Event;
        event.polarity = Some(Polarity::Positive);
        assert_eq!(format_card(&event), "event alice/s = v (+)");
    }

    #[test]
    fn same_card_ignores_non_identity_fields() {
        let base = stub_card("alice");
        let mut later = base.clone();
        later.created_at = 42;
        assert!(same_card(&base, &later), "created_at is not part of identity");
        let mut other_value = base.clone();
        other_value.value = "w".into();
        assert!(!same_card(&base, &other_value));
    }
}