Skip to main content

rig_memvid/
hook.rs

1//! [`MemvidPersistHook`]: a [`PromptHook`] that persists every turn of an
2//! agent conversation into a [`MemvidStore`].
3
4use std::marker::PhantomData;
5use std::sync::Arc;
6
7use memvid_core::{MemoryCard, MemoryCardBuilder, PutOptions};
8use rig::{
9    agent::{HookAction, PromptHook},
10    completion::{CompletionModel, CompletionResponse, Message},
11};
12
13use crate::store::MemvidStore;
14
15/// A function that decides what (if anything) to persist for a single
16/// message. Returning `None` skips the message.
17///
18/// Returning `Some("")` is treated identically to `None`: empty payloads
19/// are never written to the archive.
20pub type WriteTransform = Arc<dyn Fn(&Message) -> Option<String> + Send + Sync + 'static>;
21
22/// Strategy for what to write into the memvid archive on each turn.
23#[derive(Clone, Default)]
24pub enum WritePolicy {
25    /// Do not persist anything. The hook becomes a no-op (useful for toggling
26    /// memory at runtime without removing the hook).
27    Disabled,
28    /// Persist the verbatim text of every user prompt and assistant response.
29    #[default]
30    Raw,
31    /// Apply the supplied transform to each message and persist its result
32    /// (or nothing, if the transform returns `None`).
33    ///
34    /// This is the extension point for caller-defined summarisation, PII
35    /// redaction, or selective filtering.
36    Custom(WriteTransform),
37}
38
39impl std::fmt::Debug for WritePolicy {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            Self::Disabled => f.write_str("WritePolicy::Disabled"),
43            Self::Raw => f.write_str("WritePolicy::Raw"),
44            Self::Custom(_) => f.write_str("WritePolicy::Custom(<fn>)"),
45        }
46    }
47}
48
49/// Callback type for [`WriteFailure::Custom`]. Receives the failing
50/// phase plus the underlying [`crate::MemvidError`] and returns a
51/// [`WriteFailureAction`] telling the hook whether to keep going or
52/// halt the agent.
53pub type WriteFailureCallback =
54    Arc<dyn Fn(WriteFailurePhase, &crate::MemvidError) -> WriteFailureAction + Send + Sync>;
55
56/// What the [`MemvidPersistHook`] does when a frame fails to write.
57///
58/// Defaults to [`WriteFailure::Warn`] so existing behaviour is preserved:
59/// the failure is logged via `tracing::warn!` and the turn continues. Operators
60/// who would rather halt the agent than silently lose memory writes can
61/// switch to [`WriteFailure::Halt`]; advanced callers can install a
62/// [`WriteFailure::Custom`] callback (e.g. to update a counter, page on
63/// an SLO breach, or fall back to a sidecar log).
64#[derive(Clone, Default)]
65pub enum WriteFailure {
66    /// Log the failure at `WARN` and continue the turn. Default — matches
67    /// pre-0.2 behaviour.
68    #[default]
69    Warn,
70    /// Log the failure at `ERROR`, emit a `tracing` event, and signal the
71    /// agent to stop via `HookAction::Terminate` on the next return path.
72    /// Useful when the agent's value is bounded by durable memory.
73    Halt,
74    /// Run a caller-provided callback with the failure phase and error.
75    /// Returning [`WriteFailureAction::Continue`] keeps the turn alive;
76    /// [`WriteFailureAction::Halt`] stops the agent.
77    Custom(WriteFailureCallback),
78}
79
80impl std::fmt::Debug for WriteFailure {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            Self::Warn => f.write_str("WriteFailure::Warn"),
84            Self::Halt => f.write_str("WriteFailure::Halt"),
85            Self::Custom(_) => f.write_str("WriteFailure::Custom(<fn>)"),
86        }
87    }
88}
89
/// The persistence step at which a write failed.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteFailurePhase {
    /// Appending the frame text (`put_text_uncommitted`) failed, so the frame
    /// was not written.
    Put,
    /// Appending a structured card (`put_memory_card`) failed, so that card
    /// was not written.
    PutCard,
    /// `commit` failed. Whether frames appended earlier were flushed depends
    /// on the specific memvid error.
    Commit,
}
102
/// What a [`WriteFailure::Custom`] callback wants the hook to do next.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteFailureAction {
    /// Carry on with the turn, like the default warn-and-continue policy.
    Continue,
    /// Stop the agent at the next return point.
    Halt,
}
112
113/// Configuration for [`MemvidPersistHook`].
114///
115/// `MemoryConfig` is `#[non_exhaustive]` so new fields can be added without a
116/// SemVer-major bump; build instances through [`MemoryConfig::default`] +
117/// field updates, or through [`MemoryConfig::builder`] for a fluent shape.
118#[derive(Clone, Debug)]
119#[non_exhaustive]
120pub struct MemoryConfig {
121    /// What to persist on each turn.
122    pub policy: WritePolicy,
123    /// If `true`, call `commit()` after every turn so the new frames are
124    /// immediately searchable. If `false`, the caller is responsible for
125    /// committing periodically.
126    pub commit_each_turn: bool,
127    /// Tags applied to every persisted frame, useful for later filtering.
128    pub default_tags: Vec<String>,
129    /// Logical scope written into the frame's URI prefix. When set, every
130    /// frame produced by this hook is stored with `PutOptions.uri = Some(scope)`,
131    /// which makes `MemvidFilter::eq("scope", scope)` match those
132    /// frames at query time (memvid's `scope` is a URI prefix filter).
133    pub scope: Option<String>,
134    /// Stable identity for the human side of the conversation.
135    ///
136    /// When set, user turns are lightly rewritten before memvid sees
137    /// them so first-person pronouns resolve to this principal. For
138    /// example, with `principal = Some("Alice".into())`, `I like
139    /// espresso` is persisted as `Alice likes espresso`. This improves
140    /// memvid's entity / slot / value extraction without requiring an
141    /// LLM or a new runtime dependency.
142    pub principal: Option<String>,
143    /// If `true`, persist assistant responses as well as user turns.
144    ///
145    /// Defaults to `true` to preserve the full conversation transcript.
146    /// Set to `false` when the archive is primarily used for user profile
147    /// memory and assistant paraphrases would add noisy duplicate cards.
148    pub persist_assistant: bool,
149    /// Add small deterministic cards for principal-bound user turns when
150    /// memvid's built-in triplet extractor misses common user-profile or
151    /// relationship facts.
152    ///
153    /// Currently covers allergy / avoidance statements and simple
154    /// manager / reporting statements after [`Self::principal`] has
155    /// bound first-person pronouns to the stable entity. Defaults to
156    /// `true`; it is a no-op when `principal` is `None`.
157    pub supplemental_profile_cards: bool,
158    /// Run memvid's auto-tagger over each persisted frame to attach
159    /// extracted entity / topic tags. Defaults to `true`, mirroring
160    /// [`memvid_core::PutOptions::default`].
161    pub auto_tag: bool,
162    /// Run memvid's date extractor over each persisted frame so dates
163    /// mentioned in conversation become queryable. Defaults to `true`.
164    pub extract_dates: bool,
165    /// Extract Subject-Predicate-Object triplets from each persisted
166    /// frame and store them as [`memvid_core::MemoryCard`]s on the
167    /// memories track. Cards become queryable through
168    /// [`crate::MemvidStore::entity_memories`],
169    /// [`crate::MemvidStore::current_memory`],
170    /// [`crate::MemvidStore::entity_preferences`], and the rest of the
171    /// memory-card surface. Defaults to `true`.
172    pub extract_triplets: bool,
173    /// Conversation ID stamped on `rig_tap` events emitted by this
174    /// hook (`memory.frame_written`). When `None`, the hook falls back to
175    /// [`Self::scope`] and finally to `"default"` so existing consumers
176    /// keep working, but explicitly setting this field is preferred: it
177    /// decouples telemetry correlation from memvid's URI-prefix scope.
178    /// No effect when the `observe` feature is off.
179    pub observe_conversation_id: Option<String>,
180    /// What to do when a frame fails to write. Defaults to
181    /// [`WriteFailure::Warn`] (log + continue) to preserve pre-0.2
182    /// behaviour. Switch to [`WriteFailure::Halt`] when durable memory is a
183    /// hard requirement, or supply a [`WriteFailure::Custom`] callback to
184    /// route the failure into your own metrics / alerting.
185    pub on_write_failure: WriteFailure,
186    /// Whether to rewrite first-person pronouns in user turns into the
187    /// configured [`Self::principal`] (English-only heuristic). Defaults
188    /// to `true`, but the rewrite is also a no-op when `principal` is
189    /// `None`. Set to `false` to disable the heuristic entirely (e.g.
190    /// for non-English transcripts or when callers already canonicalise
191    /// text upstream).
192    ///
193    /// The rewrite also short-circuits on a per-text basis when the
194    /// turn contains a triple-backtick code fence or a balanced
195    /// double-quoted span containing the standalone token `I`; both are
196    /// strong signals that the literal `I` is quoted speech or code and
197    /// must not be rewritten.
198    pub rewrite_principal_pronouns: bool,
199}
200
201impl Default for MemoryConfig {
202    fn default() -> Self {
203        Self {
204            policy: WritePolicy::default(),
205            commit_each_turn: true,
206            default_tags: Vec::new(),
207            scope: None,
208            principal: None,
209            persist_assistant: true,
210            supplemental_profile_cards: true,
211            auto_tag: true,
212            extract_dates: true,
213            extract_triplets: true,
214            observe_conversation_id: None,
215            on_write_failure: WriteFailure::default(),
216            rewrite_principal_pronouns: true,
217        }
218    }
219}
220
221impl MemoryConfig {
222    /// Start a [`MemoryConfigBuilder`] for fluent construction. Equivalent
223    /// to `MemoryConfigBuilder::default()` but easier to discover.
224    pub fn builder() -> MemoryConfigBuilder {
225        MemoryConfigBuilder::default()
226    }
227}
228
229/// Fluent builder for [`MemoryConfig`].
230///
231/// `MemoryConfig` is `#[non_exhaustive]` to keep adding fields
232/// SemVer-additive. Prefer this builder over struct-literal construction
233/// so future fields land transparently.
234///
235/// ```rust,no_run
236/// use rig_memvid::{MemoryConfig, WritePolicy};
237/// let config = MemoryConfig::builder()
238///     .policy(WritePolicy::Raw)
239///     .commit_each_turn(false)
240///     .principal(Some("Alice".into()))
241///     .persist_assistant(false)
242///     .build();
243/// # let _ = config;
244/// ```
245#[derive(Clone, Debug, Default)]
246pub struct MemoryConfigBuilder {
247    config: MemoryConfig,
248}
249
250impl MemoryConfigBuilder {
251    /// Set the [`WritePolicy`].
252    pub fn policy(mut self, policy: WritePolicy) -> Self {
253        self.config.policy = policy;
254        self
255    }
256    /// Whether to `commit()` after every turn.
257    pub fn commit_each_turn(mut self, commit_each_turn: bool) -> Self {
258        self.config.commit_each_turn = commit_each_turn;
259        self
260    }
261    /// Tags applied to every persisted frame.
262    pub fn default_tags(mut self, tags: Vec<String>) -> Self {
263        self.config.default_tags = tags;
264        self
265    }
266    /// Logical scope (URI prefix) for every persisted frame.
267    pub fn scope(mut self, scope: Option<String>) -> Self {
268        self.config.scope = scope;
269        self
270    }
271    /// Stable identity for the human side of the conversation.
272    pub fn principal(mut self, principal: Option<String>) -> Self {
273        self.config.principal = principal;
274        self
275    }
276    /// Whether to persist assistant turns.
277    pub fn persist_assistant(mut self, persist_assistant: bool) -> Self {
278        self.config.persist_assistant = persist_assistant;
279        self
280    }
281    /// Whether to add deterministic supplemental profile / relationship cards.
282    pub fn supplemental_profile_cards(mut self, on: bool) -> Self {
283        self.config.supplemental_profile_cards = on;
284        self
285    }
286    /// Run memvid's auto-tagger over each persisted frame.
287    pub fn auto_tag(mut self, on: bool) -> Self {
288        self.config.auto_tag = on;
289        self
290    }
291    /// Run memvid's date extractor over each persisted frame.
292    pub fn extract_dates(mut self, on: bool) -> Self {
293        self.config.extract_dates = on;
294        self
295    }
296    /// Run memvid's triplet extractor over each persisted frame.
297    pub fn extract_triplets(mut self, on: bool) -> Self {
298        self.config.extract_triplets = on;
299        self
300    }
301    /// Telemetry conversation ID for the `observe` feature.
302    pub fn observe_conversation_id(mut self, id: Option<String>) -> Self {
303        self.config.observe_conversation_id = id;
304        self
305    }
306    /// Policy for frame-write failures.
307    pub fn on_write_failure(mut self, policy: WriteFailure) -> Self {
308        self.config.on_write_failure = policy;
309        self
310    }
311    /// Enable or disable the English-only principal-pronoun rewrite.
312    pub fn rewrite_principal_pronouns(mut self, on: bool) -> Self {
313        self.config.rewrite_principal_pronouns = on;
314        self
315    }
316    /// Finalise the builder and return a [`MemoryConfig`].
317    pub fn build(self) -> MemoryConfig {
318        self.config
319    }
320}
321
322/// Hook that records every user prompt and assistant response into a
323/// [`MemvidStore`].
324///
325/// The hook is generic over the [`CompletionModel`] so the same store can be
326/// shared between agents that use different providers.
327pub struct MemvidPersistHook<M> {
328    store: MemvidStore,
329    config: MemoryConfig,
330    /// Set by [`MemvidPersistHook::write`] when [`WriteFailure::Halt`]
331    /// (or a [`WriteFailure::Custom`] callback returning
332    /// [`WriteFailureAction::Halt`]) fires. The next call into a
333    /// `PromptHook` method observes this flag and returns
334    /// [`HookAction::terminate`] so the agent loop terminates with the
335    /// failure surfaced through `tracing::error!`.
336    halt: Arc<std::sync::atomic::AtomicBool>,
337    _model: PhantomData<fn() -> M>,
338}
339
340impl<M> Clone for MemvidPersistHook<M> {
341    fn clone(&self) -> Self {
342        Self {
343            store: self.store.clone(),
344            config: self.config.clone(),
345            halt: self.halt.clone(),
346            _model: PhantomData,
347        }
348    }
349}
350
351impl<M> std::fmt::Debug for MemvidPersistHook<M> {
352    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
353        f.debug_struct("MemvidPersistHook")
354            .field("config", &self.config)
355            .finish_non_exhaustive()
356    }
357}
358
359impl<M> MemvidPersistHook<M> {
360    /// Create a new hook persisting into `store` according to `config`.
361    pub fn new(store: MemvidStore, config: MemoryConfig) -> Self {
362        Self {
363            store,
364            config,
365            halt: Arc::new(std::sync::atomic::AtomicBool::new(false)),
366            _model: PhantomData,
367        }
368    }
369
370    /// Convenience: build a hook with the default [`MemoryConfig`]
371    /// ([`WritePolicy::Raw`], `commit_each_turn = true`).
372    pub fn with_defaults(store: MemvidStore) -> Self {
373        Self::new(store, MemoryConfig::default())
374    }
375
376    fn render(&self, msg: &Message) -> Option<String> {
377        match &self.config.policy {
378            WritePolicy::Disabled => None,
379            WritePolicy::Raw => render_message_text(msg),
380            WritePolicy::Custom(f) => f(msg),
381        }
382    }
383
384    fn put_options(&self, chat_role: &str) -> PutOptions {
385        let mut opts = PutOptions {
386            tags: self.config.default_tags.clone(),
387            auto_tag: self.config.auto_tag,
388            extract_dates: self.config.extract_dates,
389            extract_triplets: self.config.extract_triplets,
390            ..PutOptions::default()
391        };
392        opts.extra_metadata
393            .insert("chat_role".into(), chat_role.into());
394        if let Some(scope) = self.config.scope.as_deref() {
395            // Memvid's `scope` search filter matches against frame URIs by
396            // prefix, so attach the scope as the URI. Also stash it under
397            // `extra_metadata["scope"]` for ergonomic introspection by
398            // tools that walk frames directly.
399            opts.uri = Some(scope.to_string());
400            opts.extra_metadata.insert("scope".into(), scope.into());
401        }
402        opts
403    }
404
405    fn write(&self, text: &str, chat_role: &str) {
406        if text.is_empty() {
407            return;
408        }
409        let text = if chat_role == "user" && self.config.rewrite_principal_pronouns {
410            self.config
411                .principal
412                .as_deref()
413                .map(|principal| bind_principal(text, principal))
414                .unwrap_or_else(|| text.to_string())
415        } else {
416            text.to_string()
417        };
418        let opts = self.put_options(chat_role);
419        let scope = self.config.scope.clone();
420        let frame_id = match self.store.put_text_uncommitted(&text, opts) {
421            Ok(frame_id) => frame_id,
422            Err(err) => {
423                self.handle_write_failure(WriteFailurePhase::Put, chat_role, &err);
424                return;
425            }
426        };
427        #[cfg(feature = "observe")]
428        rig_tap::emit_kind(
429            self.config
430                .observe_conversation_id
431                .as_deref()
432                .or(scope.as_deref())
433                .unwrap_or("default"),
434            rig_tap::EventKind::MemoryFrameWritten {
435                frame_kind: "turn".to_string(),
436                // memvid does not expose a cheap cumulative frame count
437                // from this hot path.
438                frame_count_after: None,
439                bytes_written: text.len(),
440            },
441        );
442
443        if chat_role == "user"
444            && self.config.supplemental_profile_cards
445            && let Some(principal) = self.config.principal.as_deref()
446        {
447            for card in supplemental_memory_cards(&text, principal, frame_id, scope.clone()) {
448                if let Err(err) = self.store.put_memory_card(card) {
449                    self.handle_write_failure(WriteFailurePhase::PutCard, chat_role, &err);
450                }
451            }
452        }
453
454        if self.config.commit_each_turn
455            && let Err(err) = self.store.commit()
456        {
457            self.handle_write_failure(WriteFailurePhase::Commit, chat_role, &err);
458        }
459    }
460
461    /// Apply the configured [`WriteFailure`] policy to a single failure.
462    /// Always logs through `tracing`; may also flip `self.halt` so the next
463    /// trait method returns [`HookAction::terminate`].
464    fn handle_write_failure(
465        &self,
466        phase: WriteFailurePhase,
467        chat_role: &str,
468        err: &crate::MemvidError,
469    ) {
470        let phase_str = match phase {
471            WriteFailurePhase::Put => "put",
472            WriteFailurePhase::PutCard => "put_card",
473            WriteFailurePhase::Commit => "commit",
474        };
475        match &self.config.on_write_failure {
476            WriteFailure::Warn => {
477                tracing::warn!(
478                    target: "rig_memvid::hook",
479                    error = %err,
480                    role = chat_role,
481                    phase = phase_str,
482                    "failed to persist into memvid",
483                );
484            }
485            WriteFailure::Halt => {
486                tracing::error!(
487                    target: "rig_memvid::hook",
488                    error = %err,
489                    role = chat_role,
490                    phase = phase_str,
491                    "failed to persist into memvid; halting agent per WriteFailure::Halt",
492                );
493                self.halt.store(true, std::sync::atomic::Ordering::SeqCst);
494            }
495            WriteFailure::Custom(callback) => {
496                let action = (callback)(phase, err);
497                if matches!(action, WriteFailureAction::Halt) {
498                    tracing::error!(
499                        target: "rig_memvid::hook",
500                        error = %err,
501                        role = chat_role,
502                        phase = phase_str,
503                        "failed to persist into memvid; halting agent per WriteFailure::Custom",
504                    );
505                    self.halt.store(true, std::sync::atomic::Ordering::SeqCst);
506                } else {
507                    tracing::warn!(
508                        target: "rig_memvid::hook",
509                        error = %err,
510                        role = chat_role,
511                        phase = phase_str,
512                        "failed to persist into memvid (Custom policy: continue)",
513                    );
514                }
515            }
516        }
517    }
518
519    /// Whether the hook has been asked to halt the agent (latched after a
520    /// write failure under [`WriteFailure::Halt`] or a Custom callback
521    /// returning [`WriteFailureAction::Halt`]). Visible to tests and to the
522    /// `PromptHook` impl below.
523    fn should_halt(&self) -> bool {
524        self.halt.load(std::sync::atomic::Ordering::SeqCst)
525    }
526}
527
528fn supplemental_memory_cards(
529    text: &str,
530    principal: &str,
531    frame_id: u64,
532    source_uri: Option<String>,
533) -> Vec<MemoryCard> {
534    let mut cards = Vec::new();
535    if let Some(value) = allergy_value(text)
536        && let Some(card) = profile_card(
537            &principal.to_lowercase(),
538            "allergy",
539            &value,
540            frame_id,
541            source_uri.clone(),
542        )
543    {
544        cards.push(card);
545    }
546    cards.extend(relationship_cards(text, principal, frame_id, source_uri));
547    cards
548}
549
550fn profile_card(
551    entity: &str,
552    slot: &str,
553    value: &str,
554    frame_id: u64,
555    source_uri: Option<String>,
556) -> Option<MemoryCard> {
557    MemoryCardBuilder::new()
558        .profile()
559        .entity(normalize_entity(entity))
560        .slot(slot)
561        .value(value.trim())
562        .source(frame_id, source_uri)
563        .engine("rig-memvid:principal-rules", "2")
564        .confidence(1.0)
565        .build(0)
566        .ok()
567}
568
569fn relationship_card(
570    entity: &str,
571    slot: &str,
572    value: &str,
573    frame_id: u64,
574    source_uri: Option<String>,
575) -> Option<MemoryCard> {
576    MemoryCardBuilder::new()
577        .relationship()
578        .entity(normalize_entity(entity))
579        .slot(slot)
580        .value(value.trim())
581        .source(frame_id, source_uri)
582        .engine("rig-memvid:principal-rules", "2")
583        .confidence(1.0)
584        .build(0)
585        .ok()
586}
587
588fn fact_card(
589    entity: &str,
590    slot: &str,
591    value: &str,
592    frame_id: u64,
593    source_uri: Option<String>,
594) -> Option<MemoryCard> {
595    MemoryCardBuilder::new()
596        .fact()
597        .entity(normalize_entity(entity))
598        .slot(slot)
599        .value(value.trim())
600        .source(frame_id, source_uri)
601        .engine("rig-memvid:principal-rules", "2")
602        .confidence(1.0)
603        .build(0)
604        .ok()
605}
606
607fn relationship_cards(
608    text: &str,
609    principal: &str,
610    frame_id: u64,
611    source_uri: Option<String>,
612) -> Vec<MemoryCard> {
613    let mut cards = Vec::new();
614    let Some(manager) = manager_subject(text, principal) else {
615        return cards;
616    };
617
618    if let Some(card) =
619        relationship_card(principal, "manager", &manager, frame_id, source_uri.clone())
620    {
621        cards.push(card);
622    }
623
624    if let Some(employer) = manager_employer(text, principal)
625        && let Some(card) = fact_card(
626            &manager,
627            "employer",
628            &employer,
629            frame_id,
630            source_uri.clone(),
631        )
632    {
633        cards.push(card);
634    }
635
636    if let Some(report) = reports_to(text, &manager) {
637        if let Some(card) = relationship_card(
638            &manager,
639            "reports_to",
640            &report.manager,
641            frame_id,
642            source_uri.clone(),
643        ) {
644            cards.push(card);
645        }
646        if let Some(title) = report.manager_title
647            && let Some(card) = profile_card(
648                &report.manager,
649                "title",
650                &title,
651                frame_id,
652                source_uri.clone(),
653            )
654        {
655            cards.push(card);
656        }
657    }
658
659    cards
660}
661
662fn manager_subject(text: &str, principal: &str) -> Option<String> {
663    let lower = text.to_lowercase();
664    let marker = format!(" is {}'s manager", principal.to_lowercase());
665    let idx = lower.find(&marker)?;
666    let before = text.get(..idx)?.trim();
667    last_name(before)
668}
669
670fn manager_employer(text: &str, principal: &str) -> Option<String> {
671    let lower = text.to_lowercase();
672    let marker = format!(" is {}'s manager at ", principal.to_lowercase());
673    let idx = lower.find(&marker)? + marker.len();
674    let raw = text.get(idx..)?;
675    clean_clause(raw, &['.', '!', '?', ';', ',', '\n'])
676}
677
/// Result of [`reports_to`]: who the subject reports to and, when the
/// sentence supplies one, that person's title.
struct ReportsTo {
    /// Name of the person being reported to.
    manager: String,
    /// Title given after the name, following a comma, with any leading
    /// `the ` removed (e.g. `CTO` from `Carol, the CTO`).
    manager_title: Option<String>,
}
682
683fn reports_to(text: &str, subject: &str) -> Option<ReportsTo> {
684    let lower = text.to_lowercase();
685    let subject_marker = format!("{} reports to ", subject.to_lowercase());
686    let start = if let Some(idx) = lower.find(&subject_marker) {
687        idx + subject_marker.len()
688    } else if let Some(idx) = lower.find(" he reports to ") {
689        idx + " he reports to ".len()
690    } else if let Some(idx) = lower.find(" she reports to ") {
691        idx + " she reports to ".len()
692    } else {
693        return None;
694    };
695    let raw = text.get(start..)?;
696    let sentence = clean_clause(raw, &['.', '!', '?', ';', '\n'])?;
697    let mut parts = sentence.splitn(2, ',');
698    let manager = clean_name(parts.next()?)?;
699    let manager_title = parts.next().and_then(clean_title);
700    Some(ReportsTo {
701        manager,
702        manager_title,
703    })
704}
705
706fn last_name(text: &str) -> Option<String> {
707    text.split_whitespace().rev().find_map(clean_name)
708}
709
/// Trim `text` down to a name-like core.
///
/// Strips surrounding whitespace and any leading/trailing characters that
/// are not alphanumeric, `_`, `-` or `'`. Returns `None` when nothing is
/// left.
fn clean_name(text: &str) -> Option<String> {
    let is_name_char = |c: char| c.is_alphanumeric() || matches!(c, '_' | '-' | '\'');
    let core = text.trim().trim_matches(|c: char| !is_name_char(c)).trim();
    if core.is_empty() {
        None
    } else {
        Some(core.to_owned())
    }
}
717
/// Clean a job-title fragment such as `" the VP of Engineering"`.
///
/// Removes a leading article `the ` in any ASCII case (`the`, `The`, `THE`).
/// The original only stripped lowercase `the `, so `"The CTO"` kept the
/// article. It then trims surrounding characters other than alphanumerics,
/// spaces, `_` and `-`. Returns `None` when nothing is left.
fn clean_title(text: &str) -> Option<String> {
    let text = text.trim();
    // `get(..4)` returns `None` unless byte 4 is a char boundary, so the
    // slice below is always valid.
    let text = match text.get(..4) {
        Some(article) if article.eq_ignore_ascii_case("the ") => &text[4..],
        _ => text,
    };
    let value = text
        .trim()
        .trim_matches(|c: char| !c.is_alphanumeric() && c != ' ' && c != '_' && c != '-')
        .trim();
    (!value.is_empty()).then(|| value.to_string())
}
728
/// Take the first clause of `text` (up to any char in `delimiters`), trim
/// it, and drop one trailing corporate-entity suffix.
///
/// Surrounding characters other than alphanumerics, spaces, `_` and `-` are
/// trimmed first. Then, if the clause ends in a common corporate suffix such
/// as ` Corp` or ` LLC` (ASCII case-insensitive), the suffix is removed
/// along with any trailing comma or whitespace before it. So `Acme Corp.`
/// and `Initech, LLC` become `Acme` / `Initech`, which is what downstream
/// relationship queries (and humans) look for. Returns `None` when nothing
/// is left.
///
/// The suffix is matched directly on the clause, without reusing byte
/// offsets from a separately lowercased copy. Lowercasing can change byte
/// lengths for non-ASCII text (e.g. `İ` → `i̇`), which previously produced
/// truncated results like `İİ Acme I`.
fn clean_clause(text: &str, delimiters: &[char]) -> Option<String> {
    // Longest suffix first, so `Corporation` is tried before `Corp`.
    const CORP_SUFFIXES: &[&str] = &[
        " incorporated",
        " corporation",
        " company",
        " limited",
        " inc",
        " corp",
        " llc",
        " ltd",
        " co",
    ];

    /// `s` without a trailing `suffix` compared ASCII case-insensitively, or
    /// `None` when `s` does not end with it.
    fn strip_suffix_ignore_ascii_case<'a>(s: &'a str, suffix: &str) -> Option<&'a str> {
        let split = s.len().checked_sub(suffix.len())?;
        // `get` fails on a non-boundary split; then the tail contains
        // non-ASCII and cannot equal the ASCII suffix anyway.
        let head = s.get(..split)?;
        s[split..].eq_ignore_ascii_case(suffix).then_some(head)
    }

    let value = text
        .split(|c| delimiters.contains(&c))
        .next()?
        .trim()
        .trim_matches(|c: char| !c.is_alphanumeric() && c != ' ' && c != '_' && c != '-')
        .trim();
    let stripped = CORP_SUFFIXES
        .iter()
        .find_map(|suffix| strip_suffix_ignore_ascii_case(value, suffix))
        .map(|head| head.trim_end_matches(|c: char| c == ',' || c.is_whitespace()))
        .unwrap_or(value);
    (!stripped.is_empty()).then(|| stripped.to_string())
}
761
/// Canonical entity key for memory cards: surrounding whitespace removed,
/// then Unicode-lowercased.
fn normalize_entity(entity: &str) -> String {
    let trimmed = entity.trim();
    trimmed.to_lowercase()
}
765
/// Extract the substance from an allergy / avoidance statement such as
/// `"Alice is allergic to peanuts."` or `"Alice can't have dairy"`.
///
/// Markers are tried in priority order (` allergic to `, ` allergy to `,
/// ` cannot have `, ` can't have `), each at every occurrence. The value runs
/// to the next `.`, `!`, `?`, `;`, `,` or newline and is trimmed of
/// punctuation and spaces. The first non-empty value wins.
///
/// An occurrence is skipped when the word right before the marker is a
/// negation (`not`, `no`, `never`, `isn't`, ...), so `"Alice is not allergic
/// to peanuts"` no longer yields an allergy to peanuts.
///
/// All markers are ASCII, so the text is lowercased with
/// `to_ascii_lowercase`, which keeps byte offsets identical to `text`. Full
/// Unicode lowercasing can change byte lengths (e.g. `İ` → `i̇`) and
/// previously shifted the extracted value.
fn allergy_value(text: &str) -> Option<String> {
    const MARKERS: [&str; 4] = [" allergic to ", " allergy to ", " cannot have ", " can't have "];
    const NEGATIONS: [&str; 7] = ["not", "no", "never", "isn't", "aren't", "wasn't", "weren't"];

    let lower = text.to_ascii_lowercase();
    MARKERS.into_iter().find_map(|marker| {
        lower.match_indices(marker).find_map(|(idx, _)| {
            let negated = lower[..idx]
                .split_whitespace()
                .next_back()
                .is_some_and(|word| NEGATIONS.contains(&word));
            if negated {
                return None;
            }
            let value = text
                .get(idx + marker.len()..)?
                .split(['.', '!', '?', ';', ',', '\n'])
                .next()?
                .trim()
                .trim_matches(|c: char| matches!(c, '.' | '!' | '?' | ';' | ',' | ':' | ' '));
            (!value.is_empty()).then(|| value.to_string())
        })
    })
}
787
/// Whether `text` has a closed ASCII double-quoted span (`"..."`) containing
/// the standalone token `I`. Tokens are whitespace-separated and trimmed of
/// characters other than alphanumerics and `'`, so `I.` counts but `I'm`
/// does not. A final opening quote that is never closed is ignored.
/// `bind_principal` uses this to avoid rewriting quoted speech.
fn quoted_span_contains_first_person(text: &str) -> bool {
    let pieces: Vec<&str> = text.split('"').collect();
    // `split('"')` alternates outside / inside. Odd-indexed pieces sit between
    // an opening and a closing quote, except a last odd piece whose opening
    // quote is never closed; excluding the final piece drops that case.
    let closed_len = pieces.len() - 1;
    pieces[..closed_len]
        .iter()
        .skip(1)
        .step_by(2)
        .flat_map(|span| span.split_whitespace())
        .any(|token| token.trim_matches(|c: char| !c.is_alphanumeric() && c != '\'') == "I")
}
815
816fn bind_principal(text: &str, principal: &str) -> String {
817    let principal = principal.trim();
818    if principal.is_empty() {
819        return text.to_string();
820    }
821
822    // Safety short-circuits: if the turn contains a triple-backtick code
823    // fence, or a balanced double-quoted span containing the standalone
824    // token `I`, refuse to rewrite. Quoted speech / code routinely
825    // contains literal `I` that does not refer to the speaker.
826    if text.contains("```") || quoted_span_contains_first_person(text) {
827        return text.to_string();
828    }
829
830    let lower = text.to_lowercase();
831    let name_prefix = format!("my name is {} and i ", principal.to_lowercase());
832    if lower.starts_with(&name_prefix)
833        && let Some(rest) = text.get(name_prefix.len() - "i ".len()..)
834    {
835        return bind_principal(rest, principal);
836    }
837
838    let mut output = Vec::new();
839    let mut tokens = text.split_whitespace().peekable();
840    while let Some(token) = tokens.next() {
841        let core = token_core_lower(token);
842        if core != "i" {
843            output.push(bind_token(token, principal));
844            continue;
845        }
846
847        if let Some(next) = tokens.peek() {
848            let next_core = token_core_lower(next);
849            if next_core == "really" {
850                let really = tokens.next();
851                if let (Some(really_token), Some(verb_token)) = (really, tokens.peek()) {
852                    let verb_core = token_core_lower(verb_token);
853                    if let Some(verb) = principal_verb(&verb_core) {
854                        let suffix = token_suffix(verb_token);
855                        let _ = tokens.next();
856                        output.push(format!("{principal} {really_token} {verb}{suffix}"));
857                        continue;
858                    }
859                }
860                output.push(principal.to_string());
861                if let Some(really_token) = really {
862                    output.push(really_token.to_string());
863                }
864                continue;
865            }
866            if let Some(verb) = principal_verb(&next_core) {
867                let suffix = token_suffix(next);
868                let _ = tokens.next();
869                output.push(format!("{principal} {verb}{suffix}"));
870                continue;
871            }
872        }
873        // Bare standalone `I` with no recognised verb follower: leave it
874        // alone. This avoids misrewriting Roman numerals (`World War I`),
875        // section headers, or any other context where `I` is not a
876        // first-person pronoun. Recognised pronoun forms (`my`, `me`,
877        // `I'm`, etc.) are handled by the `core != "i"` branch above.
878        output.push(token.to_string());
879    }
880    output.join(" ")
881}
882
883fn token_core_lower(token: &str) -> String {
884    token
885        .trim_matches(|c: char| !c.is_alphanumeric() && c != '\'')
886        .to_lowercase()
887}
888
889fn token_suffix(token: &str) -> String {
890    token
891        .chars()
892        .rev()
893        .take_while(|c| !c.is_alphanumeric() && *c != '\'')
894        .collect::<Vec<_>>()
895        .into_iter()
896        .rev()
897        .collect()
898}
899
900fn principal_verb(core: &str) -> Option<&'static str> {
901    match core {
902        "like" => Some("likes"),
903        "dislike" => Some("dislikes"),
904        "live" => Some("lives"),
905        "work" => Some("works"),
906        "grew" => Some("grew"),
907        "prefer" => Some("prefers"),
908        "love" => Some("loves"),
909        "hate" => Some("hates"),
910        "want" => Some("wants"),
911        "need" => Some("needs"),
912        "am" => Some("is"),
913        "have" => Some("has"),
914        _ => None,
915    }
916}
917
918fn bind_token(token: &str, principal: &str) -> String {
919    let suffix = token_suffix(token);
920    let core = token_core_lower(token);
921    let replacement = match core.as_str() {
922        "i" => Some(principal.to_string()),
923        "me" | "myself" => Some(principal.to_string()),
924        "my" | "mine" => Some(format!("{principal}'s")),
925        "i'm" | "im" => Some(format!("{principal} is")),
926        "i've" | "ive" => Some(format!("{principal} has")),
927        "i'd" | "id" => Some(format!("{principal} would")),
928        "i'll" | "ill" => Some(format!("{principal} will")),
929        _ => None,
930    };
931    match replacement {
932        Some(mut value) => {
933            value.push_str(&suffix);
934            value
935        }
936        None => token.to_string(),
937    }
938}
939
940/// Pull a textual representation out of a [`Message`].
941///
942/// `Message::rag_text` is `pub(crate)` in rig-core, so we re-implement the
943/// equivalent walk here over the public content enums.
944pub(crate) fn render_message_text(msg: &Message) -> Option<String> {
945    use rig::completion::message::{
946        AssistantContent, Message as Msg, ReasoningContent, UserContent,
947    };
948
949    match msg {
950        Msg::System { content } => Some(content.clone()),
951        Msg::User { content } => {
952            let mut buf = String::new();
953            for item in content.iter() {
954                if let UserContent::Text(text) = item {
955                    if !buf.is_empty() {
956                        buf.push('\n');
957                    }
958                    buf.push_str(&text.text);
959                }
960            }
961            (!buf.is_empty()).then_some(buf)
962        }
963        Msg::Assistant { content, .. } => {
964            let mut buf = String::new();
965            for item in content.iter() {
966                match item {
967                    AssistantContent::Text(text) => {
968                        if !buf.is_empty() {
969                            buf.push('\n');
970                        }
971                        buf.push_str(&text.text);
972                    }
973                    AssistantContent::Reasoning(reasoning) => {
974                        for entry in reasoning.content.iter() {
975                            if let ReasoningContent::Text { text, .. } = entry {
976                                if !buf.is_empty() {
977                                    buf.push('\n');
978                                }
979                                buf.push_str(text);
980                            }
981                        }
982                    }
983                    AssistantContent::ToolCall(_) | AssistantContent::Image(_) => {}
984                }
985            }
986            (!buf.is_empty()).then_some(buf)
987        }
988    }
989}
990
991impl<M> PromptHook<M> for MemvidPersistHook<M>
992where
993    M: CompletionModel,
994{
995    async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) -> HookAction {
996        if let Some(text) = self.render(prompt) {
997            self.write(&text, "user");
998        }
999        if self.should_halt() {
1000            return HookAction::terminate(
1001                "rig-memvid: persistence failed under WriteFailure::Halt",
1002            );
1003        }
1004        HookAction::cont()
1005    }
1006
1007    async fn on_completion_response(
1008        &self,
1009        _prompt: &Message,
1010        response: &CompletionResponse<M::Response>,
1011    ) -> HookAction {
1012        if !self.config.persist_assistant {
1013            return HookAction::cont();
1014        }
1015        for content in response.choice.iter() {
1016            let synthetic = Message::Assistant {
1017                id: None,
1018                content: rig::OneOrMany::one(content.clone()),
1019            };
1020            if let Some(text) = self.render(&synthetic) {
1021                self.write(&text, "assistant");
1022            }
1023        }
1024        if self.should_halt() {
1025            return HookAction::terminate(
1026                "rig-memvid: persistence failed under WriteFailure::Halt",
1027            );
1028        }
1029        HookAction::cont()
1030    }
1031}
1032
// Unit tests for the pure helpers in this module (`bind_principal`,
// `allergy_value`, `supplemental_memory_cards`) and for the `MemoryConfig`
// builder. None of these tests construct a `MemvidStore`.
#[cfg(test)]
mod tests {
    use super::{allergy_value, bind_principal, supplemental_memory_cards};

    // `my`, `I'm`, and subject `I` + verb are rewritten to third person;
    // the non-pronoun word `Alice.` is left untouched.
    #[test]
    fn bind_principal_rewrites_first_person_tokens() {
        let rewritten = bind_principal(
            "My name is Alice. I'm allergic to peanuts, and I like espresso.",
            "Alice",
        );
        assert_eq!(
            rewritten,
            "Alice's name is Alice. Alice is allergic to peanuts, and Alice likes espresso."
        );
    }

    // `really` between `I` and the verb is kept and the verb is still
    // conjugated.
    #[test]
    fn bind_principal_rewrites_common_verbs_after_adverbs() {
        assert_eq!(
            bind_principal("I really dislike instant coffee.", "Alice"),
            "Alice really dislikes instant coffee."
        );
    }

    // A leading "My name is <principal> and I ..." introduction is dropped.
    #[test]
    fn bind_principal_collapses_name_intro_before_verbs() {
        assert_eq!(
            bind_principal(
                "My name is Alice and I work at Acme as a staff engineer.",
                "Alice",
            ),
            "Alice works at Acme as a staff engineer."
        );
    }

    // A whitespace-only principal disables rewriting.
    #[test]
    fn bind_principal_ignores_empty_principal() {
        assert_eq!(bind_principal("I like rust", "  "), "I like rust");
    }

    // Rewriting already-bound text must not change it again.
    #[test]
    fn bind_principal_is_idempotent() {
        let once = bind_principal("I like espresso and I dislike tea.", "Alice");
        let twice = bind_principal(&once, "Alice");
        assert_eq!(once, twice);
    }

    #[test]
    fn bind_principal_skips_quoted_speech() {
        // Quoted `I` belongs to the speaker being quoted, not the principal.
        let input = "Bob said \"I love hiking\" yesterday.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    #[test]
    fn bind_principal_skips_code_fences() {
        // Code fences may contain literal `I` identifiers that must not be
        // rewritten.
        let input = "Try this:\n```\nlet I = 1;\n```\nthen rerun.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    #[test]
    fn bind_principal_leaves_roman_numeral_alone() {
        // `I` here is a Roman numeral, not a pronoun; no verb in
        // `principal_verb` follows, so the rewrite must be a no-op.
        let input = "World War I ended in 1918.";
        assert_eq!(bind_principal(input, "Alice"), input);
    }

    // Both the "allergic to" and "cannot have" phrasings yield the bare
    // allergen, without trailing punctuation or following clauses.
    #[test]
    fn allergy_value_extracts_common_forms() {
        assert_eq!(
            allergy_value("Alice is allergic to peanuts."),
            Some("peanuts".to_string())
        );
        assert_eq!(
            allergy_value("Alice cannot have shellfish, thanks"),
            Some("shellfish".to_string())
        );
    }

    // An allergy sentence yields exactly one Profile card whose entity is
    // the lower-cased principal and whose source frame id is passed through.
    #[test]
    fn supplemental_cards_build_allergy_profile() {
        let cards = supplemental_memory_cards(
            "Alice is allergic to peanuts.",
            "Alice",
            42,
            Some("scope".to_string()),
        );
        assert_eq!(cards.len(), 1);
        for card in &cards {
            assert_eq!(card.kind, memvid_core::MemoryKind::Profile);
            assert_eq!(card.entity, "alice");
            assert_eq!(card.slot, "allergy");
            assert_eq!(card.value, "peanuts");
            assert_eq!(card.source_frame_id, 42);
        }
    }

    // A management sentence yields relationship, employer, and title cards
    // for each person it names, not only for the principal.
    #[test]
    fn supplemental_cards_build_manager_relationships() {
        let cards = supplemental_memory_cards(
            "Bob is Alice's manager at Acme. He reports to Carol, the VP.",
            "Alice",
            42,
            Some("scope".to_string()),
        );
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Relationship
                && card.entity == "alice"
                && card.slot == "manager"
                && card.value == "Bob"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Relationship
                && card.entity == "bob"
                && card.slot == "reports_to"
                && card.value == "Carol"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Fact
                && card.entity == "bob"
                && card.slot == "employer"
                && card.value == "Acme"
        }));
        assert!(cards.iter().any(|card| {
            card.kind == memvid_core::MemoryKind::Profile
                && card.entity == "carol"
                && card.slot == "title"
                && card.value == "VP"
        }));
    }

    // `MemoryConfig::builder().build()` must agree with
    // `MemoryConfig::default()` field by field.
    #[test]
    fn builder_matches_default() {
        let from_default = super::MemoryConfig::default();
        let from_builder = super::MemoryConfig::builder().build();
        // Spot-check every field the builder is supposed to default.
        assert_eq!(from_builder.commit_each_turn, from_default.commit_each_turn);
        assert_eq!(from_builder.default_tags, from_default.default_tags);
        assert_eq!(from_builder.scope, from_default.scope);
        assert_eq!(from_builder.principal, from_default.principal);
        assert_eq!(
            from_builder.persist_assistant,
            from_default.persist_assistant
        );
        assert_eq!(
            from_builder.supplemental_profile_cards,
            from_default.supplemental_profile_cards
        );
        assert_eq!(from_builder.auto_tag, from_default.auto_tag);
        assert_eq!(from_builder.extract_dates, from_default.extract_dates);
        assert_eq!(from_builder.extract_triplets, from_default.extract_triplets);
        assert_eq!(
            from_builder.observe_conversation_id,
            from_default.observe_conversation_id
        );
        assert!(matches!(
            from_builder.on_write_failure,
            super::WriteFailure::Warn
        ));
        assert_eq!(
            from_builder.rewrite_principal_pronouns,
            from_default.rewrite_principal_pronouns
        );
        assert!(from_default.rewrite_principal_pronouns);
    }

    // Every builder setter must override its corresponding field.
    #[test]
    fn builder_overrides_each_field() {
        let cfg = super::MemoryConfig::builder()
            .commit_each_turn(false)
            .default_tags(vec!["t1".into()])
            .scope(Some("scope".into()))
            .principal(Some("Alice".into()))
            .persist_assistant(false)
            .supplemental_profile_cards(false)
            .auto_tag(false)
            .extract_dates(false)
            .extract_triplets(false)
            .observe_conversation_id(Some("conv-1".into()))
            .on_write_failure(super::WriteFailure::Halt)
            .rewrite_principal_pronouns(false)
            .build();
        assert!(!cfg.commit_each_turn);
        assert_eq!(cfg.default_tags, vec!["t1".to_string()]);
        assert_eq!(cfg.scope.as_deref(), Some("scope"));
        assert_eq!(cfg.principal.as_deref(), Some("Alice"));
        assert!(!cfg.persist_assistant);
        assert!(!cfg.supplemental_profile_cards);
        assert!(!cfg.auto_tag);
        assert!(!cfg.extract_dates);
        assert!(!cfg.extract_triplets);
        assert_eq!(cfg.observe_conversation_id.as_deref(), Some("conv-1"));
        assert!(matches!(cfg.on_write_failure, super::WriteFailure::Halt));
        assert!(!cfg.rewrite_principal_pronouns);
    }
}