use std::marker::PhantomData;
use std::sync::Arc;
use memvid_core::PutOptions;
use rig::{
agent::{HookAction, PromptHook},
completion::{CompletionModel, CompletionResponse, Message},
};
use crate::store::MemvidStore;
/// Caller-supplied renderer used by [`WritePolicy::Custom`].
///
/// It receives each message the hook is about to persist and returns the text
/// to store, or `None` to skip that message entirely.
pub type WriteTransform = Arc<dyn Fn(&Message) -> Option<String> + Send + Sync + 'static>;

/// Decides what (if anything) gets written to the store for each message.
#[derive(Clone, Default)]
pub enum WritePolicy {
    /// Persist nothing.
    Disabled,
    /// Persist the plain text extracted from the message: text parts (plus the
    /// text of reasoning parts for assistant messages); tool calls and images
    /// are skipped.
    #[default]
    Raw,
    /// Persist whatever the supplied [`WriteTransform`] returns.
    Custom(WriteTransform),
}

impl std::fmt::Debug for WritePolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The closure held by `Custom` has no `Debug` impl, so it is shown as a
        // fixed placeholder instead of being formatted.
        let label = match self {
            WritePolicy::Disabled => "WritePolicy::Disabled",
            WritePolicy::Raw => "WritePolicy::Raw",
            WritePolicy::Custom(_) => "WritePolicy::Custom(<fn>)",
        };
        f.write_str(label)
    }
}
/// Settings controlling what a [`MemvidPersistHook`] stores and how it stores it.
#[derive(Clone, Debug)]
pub struct MemoryConfig {
    /// How each message is turned into stored text (see [`WritePolicy`]).
    pub policy: WritePolicy,
    /// When `true`, writes go through `MemvidStore::put_text`; when `false`,
    /// `MemvidStore::put_text_uncommitted` is used instead.
    pub commit_each_turn: bool,
    /// Tags attached to every stored entry.
    pub default_tags: Vec<String>,
    /// Optional value recorded under the `"scope"` metadata key of every entry.
    pub scope: Option<String>,
}

impl Default for MemoryConfig {
    /// Default write policy, commit on every write, no tags and no scope.
    fn default() -> Self {
        let policy = WritePolicy::default();
        Self {
            commit_each_turn: true,
            scope: None,
            default_tags: Vec::new(),
            policy,
        }
    }
}
/// A rig [`PromptHook`] that writes conversation messages into a [`MemvidStore`].
///
/// Prompts are stored with `chat_role = "user"` and model replies with
/// `chat_role = "assistant"`; the stored text is chosen by the
/// [`MemoryConfig::policy`].
pub struct MemvidPersistHook<M> {
    store: MemvidStore,
    config: MemoryConfig,
    // `fn() -> M` ties the hook to a model type without owning an `M`, so the
    // hook's `Send`/`Sync` status does not depend on `M`.
    _model: PhantomData<fn() -> M>,
}

// Written by hand because `#[derive(Clone)]` would add an `M: Clone` bound,
// even though no `M` value is ever stored.
impl<M> Clone for MemvidPersistHook<M> {
    fn clone(&self) -> Self {
        let store = self.store.clone();
        let config = self.config.clone();
        Self {
            store,
            config,
            _model: PhantomData,
        }
    }
}

// Only the config is printed; the store is omitted, which the
// non-exhaustive marker signals in the output.
impl<M> std::fmt::Debug for MemvidPersistHook<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MemvidPersistHook");
        builder.field("config", &self.config);
        builder.finish_non_exhaustive()
    }
}
impl<M> MemvidPersistHook<M> {
    /// Creates a hook that writes into `store` according to `config`.
    pub fn new(store: MemvidStore, config: MemoryConfig) -> Self {
        Self {
            _model: PhantomData,
            store,
            config,
        }
    }

    /// Creates a hook using [`MemoryConfig::default`].
    pub fn with_defaults(store: MemvidStore) -> Self {
        let config = MemoryConfig::default();
        Self::new(store, config)
    }

    /// Applies the configured [`WritePolicy`] to `msg`, returning the text to
    /// persist or `None` when nothing should be written.
    fn render(&self, msg: &Message) -> Option<String> {
        match self.config.policy {
            WritePolicy::Disabled => None,
            WritePolicy::Raw => render_message_text(msg),
            WritePolicy::Custom(ref transform) => transform(msg),
        }
    }

    /// Builds the per-entry options: the configured default tags, a
    /// `"chat_role"` metadata entry, and a `"scope"` entry when a scope is set.
    fn put_options(&self, chat_role: &str) -> PutOptions {
        let mut options = PutOptions {
            tags: self.config.default_tags.clone(),
            ..PutOptions::default()
        };
        let metadata = &mut options.extra_metadata;
        metadata.insert("chat_role".into(), chat_role.into());
        if let Some(scope) = &self.config.scope {
            metadata.insert("scope".into(), scope.as_str().into());
        }
        options
    }

    /// Stores `text` under `chat_role`, committing immediately when
    /// `commit_each_turn` is set. Empty text is ignored.
    ///
    /// Failures are logged at `warn` level and not propagated, so a storage
    /// error never interrupts the agent turn.
    ///
    /// NOTE(review): the store calls are synchronous and this runs inside async
    /// hook methods; if they perform blocking disk I/O, consider moving them off
    /// the executor — confirm against `MemvidStore`.
    fn write(&self, text: &str, chat_role: &str) {
        if text.is_empty() {
            return;
        }
        let options = self.put_options(chat_role);
        let persisted = if self.config.commit_each_turn {
            self.store.put_text(text, options)
        } else {
            self.store.put_text_uncommitted(text, options)
        };
        let Err(err) = persisted else {
            return;
        };
        tracing::warn!(
            target: "rig_memvid::hook",
            error = %err,
            role = chat_role,
            "failed to persist message into memvid",
        );
    }
}
/// Extracts the plain text carried by `msg`; this is the [`WritePolicy::Raw`]
/// renderer.
///
/// * System messages yield their content verbatim (even when it is empty).
/// * User messages yield their text parts; other user content is ignored.
/// * Assistant messages yield their text parts and the text entries of any
///   reasoning parts, in order; tool calls and images are ignored.
///
/// Parts are separated by `'\n'`, with no separator emitted while the output is
/// still empty. Returns `None` when a user or assistant message yields no text.
fn render_message_text(msg: &Message) -> Option<String> {
    use rig::completion::message::{
        AssistantContent, Message as Msg, ReasoningContent, UserContent,
    };

    /// Appends `piece` to `out`, newline-separated from any text already there.
    fn push_piece(out: &mut String, piece: &str) {
        if !out.is_empty() {
            out.push('\n');
        }
        out.push_str(piece);
    }

    let mut out = String::new();
    match msg {
        Msg::System { content } => return Some(content.clone()),
        Msg::User { content } => {
            for item in content.iter() {
                if let UserContent::Text(part) = item {
                    push_piece(&mut out, &part.text);
                }
            }
        }
        Msg::Assistant { content, .. } => {
            for item in content.iter() {
                match item {
                    AssistantContent::Text(part) => push_piece(&mut out, &part.text),
                    AssistantContent::Reasoning(reasoning) => {
                        for entry in reasoning.content.iter() {
                            if let ReasoningContent::Text { text, .. } = entry {
                                push_piece(&mut out, text);
                            }
                        }
                    }
                    AssistantContent::ToolCall(_) | AssistantContent::Image(_) => {}
                }
            }
        }
    }
    (!out.is_empty()).then_some(out)
}
impl<M> PromptHook<M> for MemvidPersistHook<M>
where
M: CompletionModel,
{
async fn on_completion_call(&self, prompt: &Message, _history: &[Message]) -> HookAction {
if let Some(text) = self.render(prompt) {
self.write(&text, "user");
}
HookAction::cont()
}
async fn on_completion_response(
&self,
_prompt: &Message,
response: &CompletionResponse<M::Response>,
) -> HookAction {
for content in response.choice.iter() {
let synthetic = Message::Assistant {
id: None,
content: rig::OneOrMany::one(content.clone()),
};
if let Some(text) = self.render(&synthetic) {
self.write(&text, "assistant");
}
}
HookAction::cont()
}
}