1use std::marker::PhantomData;
5use std::sync::Arc;
6
7use memvid_core::{MemoryCard, MemoryCardBuilder, PutOptions};
8use rig::{
9 agent::{HookAction, PromptHook},
10 completion::{CompletionModel, CompletionResponse, Message},
11};
12
13use crate::store::MemvidStore;
14
/// User-supplied rendering function used by [`WritePolicy::Custom`].
///
/// It receives each message the hook is about to persist and returns the text
/// to store, or `None` to skip persisting that message entirely.
pub type WriteTransform = Arc<dyn Fn(&Message) -> Option<String> + Send + Sync + 'static>;
21
/// Decides what text, if any, is persisted for each chat message.
#[derive(Clone, Default)]
pub enum WritePolicy {
    /// Persist nothing: rendering always yields `None`.
    Disabled,
    /// Persist the message's plain text as produced by `render_message_text`
    /// (user/assistant text parts, assistant reasoning text, system content).
    #[default]
    Raw,
    /// Persist whatever the callback returns; `None` skips the message.
    Custom(WriteTransform),
}
38
39impl std::fmt::Debug for WritePolicy {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::Disabled => f.write_str("WritePolicy::Disabled"),
43 Self::Raw => f.write_str("WritePolicy::Raw"),
44 Self::Custom(_) => f.write_str("WritePolicy::Custom(<fn>)"),
45 }
46 }
47}
48
/// Callback used by [`WriteFailure::Custom`].
///
/// It is invoked with the failing phase and the store error. Returning
/// [`WriteFailureAction::Halt`] sets the hook's halt flag; `Continue` only logs
/// a warning.
pub type WriteFailureCallback =
    Arc<dyn Fn(WriteFailurePhase, &crate::MemvidError) -> WriteFailureAction + Send + Sync>;
55
/// What the hook does when a memvid write (put, card put, or commit) fails.
#[derive(Clone, Default)]
pub enum WriteFailure {
    /// Log a `tracing` warning and keep the agent running.
    #[default]
    Warn,
    /// Log an error and set the halt flag, so the hook returns
    /// `HookAction::terminate` at the end of the current hook call.
    Halt,
    /// Delegate the decision to a user callback (see [`WriteFailureCallback`]).
    Custom(WriteFailureCallback),
}
79
80impl std::fmt::Debug for WriteFailure {
81 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 Self::Warn => f.write_str("WriteFailure::Warn"),
84 Self::Halt => f.write_str("WriteFailure::Halt"),
85 Self::Custom(_) => f.write_str("WriteFailure::Custom(<fn>)"),
86 }
87 }
88}
89
/// Which store operation failed, as reported to the failure policy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum WriteFailurePhase {
    /// Writing the message text as an (uncommitted) frame.
    Put,
    /// Writing one of the supplemental memory cards derived from the message.
    PutCard,
    /// Committing pending writes (only attempted when `commit_each_turn` is on).
    Commit,
}
102
/// Decision returned by a [`WriteFailureCallback`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum WriteFailureAction {
    /// Log a warning and keep the agent running.
    Continue,
    /// Log an error and stop the agent via the hook's halt flag.
    Halt,
}
112
/// Configuration for [`MemvidPersistHook`]. Build it with [`MemoryConfig::builder`]
/// or start from [`MemoryConfig::default`].
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct MemoryConfig {
    /// How each message is rendered into stored text.
    pub policy: WritePolicy,
    /// Commit the store after every persisted message.
    pub commit_each_turn: bool,
    /// Tags attached to every stored frame.
    pub default_tags: Vec<String>,
    /// Optional scope: used as the frame URI, stored as `scope` metadata,
    /// passed as the source URI of supplemental cards, and used as the
    /// fallback observe conversation id.
    pub scope: Option<String>,
    /// Name of the human the conversation is with. It is used to rewrite
    /// first-person pronouns in user messages and as the entity for
    /// supplemental cards; both are skipped when this is `None`.
    pub principal: Option<String>,
    /// Also persist assistant responses (not just user prompts).
    pub persist_assistant: bool,
    /// Derive rule-based profile/relationship cards (allergies, manager
    /// chains) from user messages. Requires `principal`.
    pub supplemental_profile_cards: bool,
    /// Forwarded to `PutOptions::auto_tag`.
    pub auto_tag: bool,
    /// Forwarded to `PutOptions::extract_dates`.
    pub extract_dates: bool,
    /// Forwarded to `PutOptions::extract_triplets`.
    pub extract_triplets: bool,
    /// Conversation id for `observe`-feature events. It falls back to `scope`,
    /// then to `"default"`.
    pub observe_conversation_id: Option<String>,
    /// Policy applied when a store operation fails.
    pub on_write_failure: WriteFailure,
    /// Rewrite "I"/"my"/"me"… in user messages to the principal's name
    /// before storing. Has no effect without `principal`.
    pub rewrite_principal_pronouns: bool,
}
200
201impl Default for MemoryConfig {
202 fn default() -> Self {
203 Self {
204 policy: WritePolicy::default(),
205 commit_each_turn: true,
206 default_tags: Vec::new(),
207 scope: None,
208 principal: None,
209 persist_assistant: true,
210 supplemental_profile_cards: true,
211 auto_tag: true,
212 extract_dates: true,
213 extract_triplets: true,
214 observe_conversation_id: None,
215 on_write_failure: WriteFailure::default(),
216 rewrite_principal_pronouns: true,
217 }
218 }
219}
220
221impl MemoryConfig {
222 pub fn builder() -> MemoryConfigBuilder {
225 MemoryConfigBuilder::default()
226 }
227}
228
/// Fluent builder for [`MemoryConfig`]. It starts from the default config and
/// each setter overwrites exactly one field.
#[derive(Clone, Debug, Default)]
pub struct MemoryConfigBuilder {
    // The config being assembled; returned as-is by `build`.
    config: MemoryConfig,
}
249
250impl MemoryConfigBuilder {
251 pub fn policy(mut self, policy: WritePolicy) -> Self {
253 self.config.policy = policy;
254 self
255 }
256 pub fn commit_each_turn(mut self, commit_each_turn: bool) -> Self {
258 self.config.commit_each_turn = commit_each_turn;
259 self
260 }
261 pub fn default_tags(mut self, tags: Vec<String>) -> Self {
263 self.config.default_tags = tags;
264 self
265 }
266 pub fn scope(mut self, scope: Option<String>) -> Self {
268 self.config.scope = scope;
269 self
270 }
271 pub fn principal(mut self, principal: Option<String>) -> Self {
273 self.config.principal = principal;
274 self
275 }
276 pub fn persist_assistant(mut self, persist_assistant: bool) -> Self {
278 self.config.persist_assistant = persist_assistant;
279 self
280 }
281 pub fn supplemental_profile_cards(mut self, on: bool) -> Self {
283 self.config.supplemental_profile_cards = on;
284 self
285 }
286 pub fn auto_tag(mut self, on: bool) -> Self {
288 self.config.auto_tag = on;
289 self
290 }
291 pub fn extract_dates(mut self, on: bool) -> Self {
293 self.config.extract_dates = on;
294 self
295 }
296 pub fn extract_triplets(mut self, on: bool) -> Self {
298 self.config.extract_triplets = on;
299 self
300 }
301 pub fn observe_conversation_id(mut self, id: Option<String>) -> Self {
303 self.config.observe_conversation_id = id;
304 self
305 }
306 pub fn on_write_failure(mut self, policy: WriteFailure) -> Self {
308 self.config.on_write_failure = policy;
309 self
310 }
311 pub fn rewrite_principal_pronouns(mut self, on: bool) -> Self {
313 self.config.rewrite_principal_pronouns = on;
314 self
315 }
316 pub fn build(self) -> MemoryConfig {
318 self.config
319 }
320}
321
/// rig [`PromptHook`] that persists user prompts, and optionally assistant
/// responses, into a [`MemvidStore`].
pub struct MemvidPersistHook<M> {
    // Destination for frames and memory cards.
    store: MemvidStore,
    // Behaviour switches; see `MemoryConfig`.
    config: MemoryConfig,
    // Set when a write failure should stop the agent. It is shared through
    // `Arc` by all clones of this hook.
    halt: Arc<std::sync::atomic::AtomicBool>,
    // Ties the hook to a completion model type without storing one.
    // `fn() -> M` keeps the hook Send/Sync regardless of `M`.
    _model: PhantomData<fn() -> M>,
}
339
340impl<M> Clone for MemvidPersistHook<M> {
341 fn clone(&self) -> Self {
342 Self {
343 store: self.store.clone(),
344 config: self.config.clone(),
345 halt: self.halt.clone(),
346 _model: PhantomData,
347 }
348 }
349}
350
351impl<M> std::fmt::Debug for MemvidPersistHook<M> {
352 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353 f.debug_struct("MemvidPersistHook")
354 .field("config", &self.config)
355 .finish_non_exhaustive()
356 }
357}
358
impl<M> MemvidPersistHook<M> {
    /// Creates a hook that persists into `store` according to `config`.
    /// The halt flag starts cleared.
    pub fn new(store: MemvidStore, config: MemoryConfig) -> Self {
        Self {
            store,
            config,
            halt: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            _model: PhantomData,
        }
    }

    /// Creates a hook using `MemoryConfig::default()`.
    pub fn with_defaults(store: MemvidStore) -> Self {
        Self::new(store, MemoryConfig::default())
    }

    /// Renders `msg` to the text that should be stored, per the configured
    /// `WritePolicy`. `None` means nothing is written for this message.
    fn render(&self, msg: &Message) -> Option<String> {
        match &self.config.policy {
            WritePolicy::Disabled => None,
            WritePolicy::Raw => render_message_text(msg),
            WritePolicy::Custom(f) => f(msg),
        }
    }

    /// Builds the `PutOptions` for one frame.
    ///
    /// The options carry the default tags and the auto-tag / date / triplet
    /// switches, plus `chat_role` metadata. When a scope is configured it
    /// becomes the frame URI and is also recorded as `scope` metadata.
    fn put_options(&self, chat_role: &str) -> PutOptions {
        let mut opts = PutOptions {
            tags: self.config.default_tags.clone(),
            auto_tag: self.config.auto_tag,
            extract_dates: self.config.extract_dates,
            extract_triplets: self.config.extract_triplets,
            ..PutOptions::default()
        };
        opts.extra_metadata
            .insert("chat_role".into(), chat_role.into());
        if let Some(scope) = self.config.scope.as_deref() {
            opts.uri = Some(scope.to_string());
            opts.extra_metadata.insert("scope".into(), scope.into());
        }
        opts
    }

    /// Persists one rendered message. `chat_role` is `"user"` or `"assistant"`
    /// at the call sites in this file.
    ///
    /// Steps, in order:
    /// 1. Skip empty text.
    /// 2. For user text with pronoun rewriting enabled and a principal set,
    ///    bind first-person pronouns to the principal.
    /// 3. Put the text as an uncommitted frame. On failure, report
    ///    `WriteFailurePhase::Put` and stop here.
    /// 4. (`observe` feature) emit a `MemoryFrameWritten` event.
    /// 5. For user text with supplemental cards enabled and a principal set,
    ///    put each derived card. Each failure is reported as `PutCard`, and the
    ///    remaining cards are still attempted.
    /// 6. If `commit_each_turn` is set, commit. A failure is reported as
    ///    `Commit`.
    ///
    /// NOTE(review): the store calls are synchronous and are reached from the
    /// async hook methods; if they do blocking disk I/O they block the executor
    /// thread. Confirm against `MemvidStore`.
    fn write(&self, text: &str, chat_role: &str) {
        if text.is_empty() {
            return;
        }
        let text = if chat_role == "user" && self.config.rewrite_principal_pronouns {
            self.config
                .principal
                .as_deref()
                .map(|principal| bind_principal(text, principal))
                .unwrap_or_else(|| text.to_string())
        } else {
            text.to_string()
        };
        let opts = self.put_options(chat_role);
        // Cloned up front: used for the observe conversation id and as the
        // source URI of supplemental cards.
        let scope = self.config.scope.clone();
        let frame_id = match self.store.put_text_uncommitted(&text, opts) {
            Ok(frame_id) => frame_id,
            Err(err) => {
                self.handle_write_failure(WriteFailurePhase::Put, chat_role, &err);
                return;
            }
        };
        #[cfg(feature = "observe")]
        rig_tap::emit_kind(
            self.config
                .observe_conversation_id
                .as_deref()
                .or(scope.as_deref())
                .unwrap_or("default"),
            rig_tap::EventKind::MemoryFrameWritten {
                frame_kind: "turn".to_string(),
                frame_count_after: None,
                bytes_written: text.len(),
            },
        );

        if chat_role == "user"
            && self.config.supplemental_profile_cards
            && let Some(principal) = self.config.principal.as_deref()
        {
            for card in supplemental_memory_cards(&text, principal, frame_id, scope.clone()) {
                if let Err(err) = self.store.put_memory_card(card) {
                    self.handle_write_failure(WriteFailurePhase::PutCard, chat_role, &err);
                }
            }
        }

        if self.config.commit_each_turn
            && let Err(err) = self.store.commit()
        {
            self.handle_write_failure(WriteFailurePhase::Commit, chat_role, &err);
        }
    }

    /// Applies `on_write_failure` to a failed store operation.
    ///
    /// `Warn` logs at warn level. `Halt` logs at error level and sets the
    /// shared halt flag. `Custom` asks the callback and then behaves like
    /// `Halt` or `Warn` depending on the returned action. All events use the
    /// `rig_memvid::hook` target.
    fn handle_write_failure(
        &self,
        phase: WriteFailurePhase,
        chat_role: &str,
        err: &crate::MemvidError,
    ) {
        let phase_str = match phase {
            WriteFailurePhase::Put => "put",
            WriteFailurePhase::PutCard => "put_card",
            WriteFailurePhase::Commit => "commit",
        };
        match &self.config.on_write_failure {
            WriteFailure::Warn => {
                tracing::warn!(
                    target: "rig_memvid::hook",
                    error = %err,
                    role = chat_role,
                    phase = phase_str,
                    "failed to persist into memvid",
                );
            }
            WriteFailure::Halt => {
                tracing::error!(
                    target: "rig_memvid::hook",
                    error = %err,
                    role = chat_role,
                    phase = phase_str,
                    "failed to persist into memvid; halting agent per WriteFailure::Halt",
                );
                self.halt.store(true, std::sync::atomic::Ordering::SeqCst);
            }
            WriteFailure::Custom(callback) => {
                let action = (callback)(phase, err);
                if matches!(action, WriteFailureAction::Halt) {
                    tracing::error!(
                        target: "rig_memvid::hook",
                        error = %err,
                        role = chat_role,
                        phase = phase_str,
                        "failed to persist into memvid; halting agent per WriteFailure::Custom",
                    );
                    self.halt.store(true, std::sync::atomic::Ordering::SeqCst);
                } else {
                    tracing::warn!(
                        target: "rig_memvid::hook",
                        error = %err,
                        role = chat_role,
                        phase = phase_str,
                        "failed to persist into memvid (Custom policy: continue)",
                    );
                }
            }
        }
    }

    /// Whether a write failure has requested that the agent stop.
    ///
    /// Nothing in this file ever clears the flag, and clones share it, so once
    /// set it stays set for every clone of this hook.
    fn should_halt(&self) -> bool {
        self.halt.load(std::sync::atomic::Ordering::SeqCst)
    }
}
527
528fn supplemental_memory_cards(
529 text: &str,
530 principal: &str,
531 frame_id: u64,
532 source_uri: Option<String>,
533) -> Vec<MemoryCard> {
534 let mut cards = Vec::new();
535 if let Some(value) = allergy_value(text)
536 && let Some(card) = profile_card(
537 &principal.to_lowercase(),
538 "allergy",
539 &value,
540 frame_id,
541 source_uri.clone(),
542 )
543 {
544 cards.push(card);
545 }
546 cards.extend(relationship_cards(text, principal, frame_id, source_uri));
547 cards
548}
549
550fn profile_card(
551 entity: &str,
552 slot: &str,
553 value: &str,
554 frame_id: u64,
555 source_uri: Option<String>,
556) -> Option<MemoryCard> {
557 MemoryCardBuilder::new()
558 .profile()
559 .entity(normalize_entity(entity))
560 .slot(slot)
561 .value(value.trim())
562 .source(frame_id, source_uri)
563 .engine("rig-memvid:principal-rules", "2")
564 .confidence(1.0)
565 .build(0)
566 .ok()
567}
568
569fn relationship_card(
570 entity: &str,
571 slot: &str,
572 value: &str,
573 frame_id: u64,
574 source_uri: Option<String>,
575) -> Option<MemoryCard> {
576 MemoryCardBuilder::new()
577 .relationship()
578 .entity(normalize_entity(entity))
579 .slot(slot)
580 .value(value.trim())
581 .source(frame_id, source_uri)
582 .engine("rig-memvid:principal-rules", "2")
583 .confidence(1.0)
584 .build(0)
585 .ok()
586}
587
588fn fact_card(
589 entity: &str,
590 slot: &str,
591 value: &str,
592 frame_id: u64,
593 source_uri: Option<String>,
594) -> Option<MemoryCard> {
595 MemoryCardBuilder::new()
596 .fact()
597 .entity(normalize_entity(entity))
598 .slot(slot)
599 .value(value.trim())
600 .source(frame_id, source_uri)
601 .engine("rig-memvid:principal-rules", "2")
602 .confidence(1.0)
603 .build(0)
604 .ok()
605}
606
607fn relationship_cards(
608 text: &str,
609 principal: &str,
610 frame_id: u64,
611 source_uri: Option<String>,
612) -> Vec<MemoryCard> {
613 let mut cards = Vec::new();
614 let Some(manager) = manager_subject(text, principal) else {
615 return cards;
616 };
617
618 if let Some(card) =
619 relationship_card(principal, "manager", &manager, frame_id, source_uri.clone())
620 {
621 cards.push(card);
622 }
623
624 if let Some(employer) = manager_employer(text, principal)
625 && let Some(card) = fact_card(
626 &manager,
627 "employer",
628 &employer,
629 frame_id,
630 source_uri.clone(),
631 )
632 {
633 cards.push(card);
634 }
635
636 if let Some(report) = reports_to(text, &manager) {
637 if let Some(card) = relationship_card(
638 &manager,
639 "reports_to",
640 &report.manager,
641 frame_id,
642 source_uri.clone(),
643 ) {
644 cards.push(card);
645 }
646 if let Some(title) = report.manager_title
647 && let Some(card) = profile_card(
648 &report.manager,
649 "title",
650 &title,
651 frame_id,
652 source_uri.clone(),
653 )
654 {
655 cards.push(card);
656 }
657 }
658
659 cards
660}
661
662fn manager_subject(text: &str, principal: &str) -> Option<String> {
663 let lower = text.to_lowercase();
664 let marker = format!(" is {}'s manager", principal.to_lowercase());
665 let idx = lower.find(&marker)?;
666 let before = text.get(..idx)?.trim();
667 last_name(before)
668}
669
670fn manager_employer(text: &str, principal: &str) -> Option<String> {
671 let lower = text.to_lowercase();
672 let marker = format!(" is {}'s manager at ", principal.to_lowercase());
673 let idx = lower.find(&marker)? + marker.len();
674 let raw = text.get(idx..)?;
675 clean_clause(raw, &['.', '!', '?', ';', ',', '\n'])
676}
677
/// Result of parsing "<subject> reports to <Name>[, <title>]".
struct ReportsTo {
    // The superior's name as written in the text (cleaned by `clean_name`).
    manager: String,
    // Optional title that followed the name after a comma, with a leading
    // "the " removed by `clean_title`.
    manager_title: Option<String>,
}
682
683fn reports_to(text: &str, subject: &str) -> Option<ReportsTo> {
684 let lower = text.to_lowercase();
685 let subject_marker = format!("{} reports to ", subject.to_lowercase());
686 let start = if let Some(idx) = lower.find(&subject_marker) {
687 idx + subject_marker.len()
688 } else if let Some(idx) = lower.find(" he reports to ") {
689 idx + " he reports to ".len()
690 } else if let Some(idx) = lower.find(" she reports to ") {
691 idx + " she reports to ".len()
692 } else {
693 return None;
694 };
695 let raw = text.get(start..)?;
696 let sentence = clean_clause(raw, &['.', '!', '?', ';', '\n'])?;
697 let mut parts = sentence.splitn(2, ',');
698 let manager = clean_name(parts.next()?)?;
699 let manager_title = parts.next().and_then(clean_title);
700 Some(ReportsTo {
701 manager,
702 manager_title,
703 })
704}
705
706fn last_name(text: &str) -> Option<String> {
707 text.split_whitespace().rev().find_map(clean_name)
708}
709
/// Trims whitespace and any surrounding characters that cannot be part of a
/// name. Letters, digits, `_`, `-` and `'` are kept. Returns `None` if
/// nothing is left.
fn clean_name(text: &str) -> Option<String> {
    let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | '-' | '\'');
    let name = text.trim().trim_matches(|c: char| !is_name_char(c)).trim();
    if name.is_empty() {
        None
    } else {
        Some(name.to_owned())
    }
}
717
/// Normalises a job title. It drops a leading lowercase "the ", then strips
/// surrounding characters other than letters, digits, spaces, `_` and `-`.
/// Returns `None` if nothing is left.
fn clean_title(text: &str) -> Option<String> {
    let trimmed = text.trim();
    let without_article = match trimmed.strip_prefix("the ") {
        Some(rest) => rest,
        None => trimmed,
    };
    let is_title_char = |c: char| c.is_alphanumeric() || matches!(c, ' ' | '_' | '-');
    let title = without_article
        .trim()
        .trim_matches(|c: char| !is_title_char(c))
        .trim();
    if title.is_empty() {
        None
    } else {
        Some(title.to_owned())
    }
}
728
/// Returns the text up to the first of `delimiters`, trimmed of surrounding
/// punctuation, with one trailing corporate suffix (" Inc", " Corp", " LLC",
/// …) removed case-insensitively. Returns `None` if nothing is left.
///
/// The suffix match uses ASCII-only lowercasing. All suffixes are ASCII, and
/// `to_ascii_lowercase` keeps byte offsets identical to `value`, so the head
/// length can safely slice the original. `to_lowercase` can change byte
/// lengths (e.g. U+212A → "k"), which previously truncated names.
fn clean_clause(text: &str, delimiters: &[char]) -> Option<String> {
    const CORP_SUFFIXES: &[&str] = &[
        // Longer forms come first so " inc" doesn't pre-empt " incorporated".
        " incorporated",
        " corporation",
        " company",
        " limited",
        " inc",
        " corp",
        " llc",
        " ltd",
        " co",
    ];
    let value = text
        .split(|c: char| delimiters.contains(&c))
        .next()?
        .trim()
        .trim_matches(|c: char| !c.is_alphanumeric() && c != ' ' && c != '_' && c != '-')
        .trim();
    let lowered = value.to_ascii_lowercase();
    let stripped = CORP_SUFFIXES
        .iter()
        .find_map(|suffix| lowered.strip_suffix(suffix).map(str::len))
        .and_then(|head_len| value.get(..head_len))
        .map(str::trim)
        .unwrap_or(value);
    (!stripped.is_empty()).then(|| stripped.to_string())
}
761
/// Canonical key for a card entity: surrounding whitespace removed and
/// Unicode-lowercased.
fn normalize_entity(entity: &str) -> String {
    str::to_lowercase(entity.trim())
}
765
/// Extracts what the speaker is allergic to from phrases such as
/// "allergic to X", "allergy to X", "cannot have X" or "can't have X"
/// (case-insensitive, each preceded by a space). The value runs to the next
/// sentence or clause delimiter. Markers are tried in that priority order.
///
/// The haystack is lowercased with `to_ascii_lowercase`. All markers are ASCII,
/// and this keeps byte offsets identical to `text`. `to_lowercase` can grow or
/// shrink the string (e.g. "İ" → "i̇"), which shifted the slice and produced
/// values like "eanuts".
fn allergy_value(text: &str) -> Option<String> {
    let lower = text.to_ascii_lowercase();
    let start = [" allergic to ", " allergy to ", " cannot have ", " can't have "]
        .into_iter()
        .find_map(|marker| lower.find(marker).map(|idx| idx + marker.len()))?;
    let raw = text.get(start..)?;
    let value = raw
        .split(['.', '!', '?', ';', ',', '\n'])
        .next()?
        .trim()
        .trim_matches(|c: char| matches!(c, '.' | '!' | '?' | ';' | ',' | ':' | ' '));
    (!value.is_empty()).then(|| value.to_string())
}
787
/// Returns `true` if any double-quoted span of `text` contains a first-person
/// word. `bind_principal` uses this to leave reported speech untouched.
///
/// Straight quotes (`"`) and typographic quotes (`“`, `”`, as inserted by
/// smart-punctuation keyboards) all delimit spans. A word counts as first
/// person if its lowercase core is one of the forms the pronoun rewriter
/// changes ("I", "I'm", "my", "me", …). Typographic apostrophes are
/// normalised. Previously only a bare capital "I" in straight quotes was
/// detected, so quoted "I'm …" or "my …" was rewritten and corrupted.
fn quoted_span_contains_first_person(text: &str) -> bool {
    const FIRST_PERSON: &[&str] = &[
        "i", "i'm", "im", "i've", "ive", "i'd", "i'll", "me", "my", "mine", "myself",
    ];

    let mentions_speaker = |span: &str| {
        span.split_whitespace().any(|tok| {
            let core = tok
                .trim_matches(|c: char| !c.is_alphanumeric() && c != '\'')
                .to_lowercase()
                .replace('\u{2019}', "'");
            FIRST_PERSON.contains(&core.as_str())
        })
    };

    let mut in_quote = false;
    let mut span_start = 0;
    for (idx, ch) in text.char_indices() {
        if !matches!(ch, '"' | '\u{201C}' | '\u{201D}') {
            continue;
        }
        if in_quote {
            in_quote = false;
            // Both ends come from `char_indices`, so they are char boundaries.
            if mentions_speaker(&text[span_start..idx]) {
                return true;
            }
        } else {
            in_quote = true;
            span_start = idx + ch.len_utf8();
        }
    }
    false
}
815
816fn bind_principal(text: &str, principal: &str) -> String {
817 let principal = principal.trim();
818 if principal.is_empty() {
819 return text.to_string();
820 }
821
822 if text.contains("```") || quoted_span_contains_first_person(text) {
827 return text.to_string();
828 }
829
830 let lower = text.to_lowercase();
831 let name_prefix = format!("my name is {} and i ", principal.to_lowercase());
832 if lower.starts_with(&name_prefix)
833 && let Some(rest) = text.get(name_prefix.len() - "i ".len()..)
834 {
835 return bind_principal(rest, principal);
836 }
837
838 let mut output = Vec::new();
839 let mut tokens = text.split_whitespace().peekable();
840 while let Some(token) = tokens.next() {
841 let core = token_core_lower(token);
842 if core != "i" {
843 output.push(bind_token(token, principal));
844 continue;
845 }
846
847 if let Some(next) = tokens.peek() {
848 let next_core = token_core_lower(next);
849 if next_core == "really" {
850 let really = tokens.next();
851 if let (Some(really_token), Some(verb_token)) = (really, tokens.peek()) {
852 let verb_core = token_core_lower(verb_token);
853 if let Some(verb) = principal_verb(&verb_core) {
854 let suffix = token_suffix(verb_token);
855 let _ = tokens.next();
856 output.push(format!("{principal} {really_token} {verb}{suffix}"));
857 continue;
858 }
859 }
860 output.push(principal.to_string());
861 if let Some(really_token) = really {
862 output.push(really_token.to_string());
863 }
864 continue;
865 }
866 if let Some(verb) = principal_verb(&next_core) {
867 let suffix = token_suffix(next);
868 let _ = tokens.next();
869 output.push(format!("{principal} {verb}{suffix}"));
870 continue;
871 }
872 }
873 output.push(token.to_string());
879 }
880 output.join(" ")
881}
882
/// Lowercased core of a token: leading and trailing characters that are
/// neither alphanumeric nor an ASCII apostrophe are stripped first.
fn token_core_lower(token: &str) -> String {
    let is_edge_noise = |c: char| !(c.is_alphanumeric() || c == '\'');
    let core = token.trim_matches(is_edge_noise);
    core.to_lowercase()
}
888
/// Trailing punctuation of a token: the maximal run of trailing characters
/// that are neither alphanumeric nor an ASCII apostrophe.
fn token_suffix(token: &str) -> String {
    let core_end = token
        .trim_end_matches(|c: char| !(c.is_alphanumeric() || c == '\''))
        .len();
    // `trim_end_matches` returns a prefix, so `core_end` is a char boundary.
    token[core_end..].to_owned()
}
899
/// Third-person-singular form of a lowercase first-person verb, for the small
/// set of verbs the pronoun rewriter understands. Returns `None` otherwise.
fn principal_verb(core: &str) -> Option<&'static str> {
    const CONJUGATIONS: &[(&str, &str)] = &[
        ("like", "likes"),
        ("dislike", "dislikes"),
        ("live", "lives"),
        ("work", "works"),
        ("grew", "grew"),
        ("prefer", "prefers"),
        ("love", "loves"),
        ("hate", "hates"),
        ("want", "wants"),
        ("need", "needs"),
        ("am", "is"),
        ("have", "has"),
    ];
    CONJUGATIONS
        .iter()
        .find(|(base, _)| *base == core)
        .map(|&(_, third_person)| third_person)
}
917
918fn bind_token(token: &str, principal: &str) -> String {
919 let suffix = token_suffix(token);
920 let core = token_core_lower(token);
921 let replacement = match core.as_str() {
922 "i" => Some(principal.to_string()),
923 "me" | "myself" => Some(principal.to_string()),
924 "my" | "mine" => Some(format!("{principal}'s")),
925 "i'm" | "im" => Some(format!("{principal} is")),
926 "i've" | "ive" => Some(format!("{principal} has")),
927 "i'd" | "id" => Some(format!("{principal} would")),
928 "i'll" | "ill" => Some(format!("{principal} will")),
929 _ => None,
930 };
931 match replacement {
932 Some(mut value) => {
933 value.push_str(&suffix);
934 value
935 }
936 None => token.to_string(),
937 }
938}
939
940pub(crate) fn render_message_text(msg: &Message) -> Option<String> {
945 use rig::completion::message::{
946 AssistantContent, Message as Msg, ReasoningContent, UserContent,
947 };
948
949 match msg {
950 Msg::System { content } => Some(content.clone()),
951 Msg::User { content } => {
952 let mut buf = String::new();
953 for item in content.iter() {
954 if let UserContent::Text(text) = item {
955 if !buf.is_empty() {
956 buf.push('\n');
957 }
958 buf.push_str(&text.text);
959 }
960 }
961 (!buf.is_empty()).then_some(buf)
962 }
963 Msg::Assistant { content, .. } => {
964 let mut buf = String::new();
965 for item in content.iter() {
966 match item {
967 AssistantContent::Text(text) => {
968 if !buf.is_empty() {
969 buf.push('\n');
970 }
971 buf.push_str(&text.text);
972 }
973 AssistantContent::Reasoning(reasoning) => {
974 for entry in reasoning.content.iter() {
975 if let ReasoningContent::Text { text, .. } = entry {
976 if !buf.is_empty() {
977 buf.push('\n');
978 }
979 buf.push_str(text);
980 }
981 }
982 }
983 AssistantContent::ToolCall(_) | AssistantContent::Image(_) => {}
984 }
985 }
986 (!buf.is_empty()).then_some(buf)
987 }
988 }
989}
990
991impl<M> PromptHook<M> for MemvidPersistHook<M>
992where
993 M: CompletionModel,
994{
995 async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) -> HookAction {
996 if let Some(text) = self.render(prompt) {
997 self.write(&text, "user");
998 }
999 if self.should_halt() {
1000 return HookAction::terminate(
1001 "rig-memvid: persistence failed under WriteFailure::Halt",
1002 );
1003 }
1004 HookAction::cont()
1005 }
1006
1007 async fn on_completion_response(
1008 &self,
1009 _prompt: &Message,
1010 response: &CompletionResponse<M::Response>,
1011 ) -> HookAction {
1012 if !self.config.persist_assistant {
1013 return HookAction::cont();
1014 }
1015 for content in response.choice.iter() {
1016 let synthetic = Message::Assistant {
1017 id: None,
1018 content: rig::OneOrMany::one(content.clone()),
1019 };
1020 if let Some(text) = self.render(&synthetic) {
1021 self.write(&text, "assistant");
1022 }
1023 }
1024 if self.should_halt() {
1025 return HookAction::terminate(
1026 "rig-memvid: persistence failed under WriteFailure::Halt",
1027 );
1028 }
1029 HookAction::cont()
1030 }
1031}
1032
// Unit tests for the pronoun rewriter, the rule-based card extractors, and
// the config builder.
#[cfg(test)]
mod tests {
    use super::{allergy_value, bind_principal, supplemental_memory_cards};

    // Contractions, possessives and "I <verb>" are all bound to the principal.
    #[test]
    fn bind_principal_rewrites_first_person_tokens() {
        let rewritten = bind_principal(
            "My name is Alice. I'm allergic to peanuts, and I like espresso.",
            "Alice",
        );
        assert_eq!(
            rewritten,
            "Alice's name is Alice. Alice is allergic to peanuts, and Alice likes espresso."
        );
    }

    // "I really <verb>" keeps the adverb and still conjugates the verb.
    #[test]
    fn bind_principal_rewrites_common_verbs_after_adverbs() {
        assert_eq!(
            bind_principal("I really dislike instant coffee.", "Alice"),
            "Alice really dislikes instant coffee."
        );
    }

    // "My name is <principal> and I …" collapses to "<principal> …".
    #[test]
    fn bind_principal_collapses_name_intro_before_verbs() {
        assert_eq!(
            bind_principal(
                "My name is Alice and I work at Acme as a staff engineer.",
                "Alice",
            ),
            "Alice works at Acme as a staff engineer."
        );
    }

    // A whitespace-only principal disables rewriting.
    #[test]
    fn bind_principal_ignores_empty_principal() {
        assert_eq!(bind_principal("I like rust", " "), "I like rust");
    }

    // Rewriting already-rewritten text must not change it again.
    #[test]
    fn bind_principal_is_idempotent() {
        let once = bind_principal("I like espresso and I dislike tea.", "Alice");
        let twice = bind_principal(&once, "Alice");
        assert_eq!(once, twice);
    }

    // First person inside quotes belongs to someone else; leave it alone.
    #[test]
    fn bind_principal_skips_quoted_speech() {
        let input = "Bob said \"I love hiking\" yesterday.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    // Anything containing a code fence is stored verbatim.
    #[test]
    fn bind_principal_skips_code_fences() {
        let input = "Try this:\n```\nlet I = 1;\n```\nthen rerun.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    // A bare "I" not followed by a known verb is not treated as a pronoun.
    #[test]
    fn bind_principal_leaves_roman_numeral_alone() {
        let input = "World War I ended in 1918.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    // The allergy value stops at the first sentence/clause delimiter.
    #[test]
    fn allergy_value_extracts_common_forms() {
        assert_eq!(
            allergy_value("Alice is allergic to peanuts."),
            Some("peanuts".to_string())
        );
        assert_eq!(
            allergy_value("Alice cannot have shellfish, thanks"),
            Some("shellfish".to_string())
        );
    }

    // An allergy statement yields exactly one profile card keyed by the
    // lowercased principal.
    #[test]
    fn supplemental_cards_build_allergy_profile() {
        let cards = supplemental_memory_cards(
            "Alice is allergic to peanuts.",
            "Alice",
            42,
            Some("scope".to_string()),
        );
        assert_eq!(cards.len(), 1);
        for card in &cards {
            assert_eq!(card.kind, memvid_core::MemoryKind::Profile);
            assert_eq!(card.entity, "alice");
            assert_eq!(card.slot, "allergy");
            assert_eq!(card.value, "peanuts");
            assert_eq!(card.source_frame_id, 42);
        }
    }

    // The manager chain yields manager, employer, reports_to and title cards.
    #[test]
    fn supplemental_cards_build_manager_relationships() {
        let cards = supplemental_memory_cards(
            "Bob is Alice's manager at Acme. He reports to Carol, the VP.",
            "Alice",
            42,
            Some("scope".to_string()),
        );
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Relationship
                && card.entity == "alice"
                && card.slot == "manager"
                && card.value == "Bob"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Relationship
                && card.entity == "bob"
                && card.slot == "reports_to"
                && card.value == "Carol"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Fact
                && card.entity == "bob"
                && card.slot == "employer"
                && card.value == "Acme"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Profile
                && card.entity == "carol"
                && card.slot == "title"
                && card.value == "VP"
        }));
    }

    // An untouched builder must produce the same config as `Default`.
    #[test]
    fn builder_matches_default() {
        let from_default = super::MemoryConfig::default();
        let from_builder = super::MemoryConfig::builder().build();
        assert_eq!(from_builder.commit_each_turn, from_default.commit_each_turn);
        assert_eq!(from_builder.default_tags, from_default.default_tags);
        assert_eq!(from_builder.scope, from_default.scope);
        assert_eq!(from_builder.principal, from_default.principal);
        assert_eq!(
            from_builder.persist_assistant,
            from_default.persist_assistant
        );
        assert_eq!(
            from_builder.supplemental_profile_cards,
            from_default.supplemental_profile_cards
        );
        assert_eq!(from_builder.auto_tag, from_default.auto_tag);
        assert_eq!(from_builder.extract_dates, from_default.extract_dates);
        assert_eq!(from_builder.extract_triplets, from_default.extract_triplets);
        assert_eq!(
            from_builder.observe_conversation_id,
            from_default.observe_conversation_id
        );
        assert!(matches!(
            from_builder.on_write_failure,
            super::WriteFailure::Warn
        ));
        assert_eq!(
            from_builder.rewrite_principal_pronouns,
            from_default.rewrite_principal_pronouns
        );
        assert!(from_default.rewrite_principal_pronouns);
    }

    // Every builder setter writes its own field.
    #[test]
    fn builder_overrides_each_field() {
        let cfg = super::MemoryConfig::builder()
            .commit_each_turn(false)
            .default_tags(vec!["t1".into()])
            .scope(Some("scope".into()))
            .principal(Some("Alice".into()))
            .persist_assistant(false)
            .supplemental_profile_cards(false)
            .auto_tag(false)
            .extract_dates(false)
            .extract_triplets(false)
            .observe_conversation_id(Some("conv-1".into()))
            .on_write_failure(super::WriteFailure::Halt)
            .rewrite_principal_pronouns(false)
            .build();
        assert!(!cfg.commit_each_turn);
        assert_eq!(cfg.default_tags, vec!["t1".to_string()]);
        assert_eq!(cfg.scope.as_deref(), Some("scope"));
        assert_eq!(cfg.principal.as_deref(), Some("Alice"));
        assert!(!cfg.persist_assistant);
        assert!(!cfg.supplemental_profile_cards);
        assert!(!cfg.auto_tag);
        assert!(!cfg.extract_dates);
        assert!(!cfg.extract_triplets);
        assert_eq!(cfg.observe_conversation_id.as_deref(), Some("conv-1"));
        assert!(matches!(cfg.on_write_failure, super::WriteFailure::Halt));
        assert!(!cfg.rewrite_principal_pronouns);
    }
}