use tracing::{debug, warn};
use crate::client::Memory;
/// Middleware that wraps a [`Memory`] client. Before a model call it looks up
/// related memories and renders them as extra system-prompt text. After the
/// call it can optionally save the exchange.
#[derive(Clone)]
pub struct MemoryMiddleware {
    /// Client used both to search for and to save memories.
    memory: Memory,
    /// Messages shorter than this many bytes skip retrieval and auto-save.
    min_message_length: usize,
    /// Upper bound, in bytes, on the rendered system-prompt addition.
    max_context_chars: usize,
    /// Whether `after` persists the exchange at all.
    auto_save: bool,
}
/// Outcome of the pre-call memory lookup performed by the middleware.
#[derive(Debug, Clone)]
pub struct BeforeResult {
    /// Every memory returned by the search, including any that did not fit
    /// in the context budget. Empty when retrieval was skipped or failed.
    pub memories: Vec<crate::types::MemoryRecord>,
    /// Formatted memory context to append to the system prompt. Empty when
    /// retrieval was skipped, failed, or found nothing.
    pub system_prompt_addition: String,
    /// `false` only when the memory search itself returned an error.
    pub success: bool,
}
impl MemoryMiddleware {
    /// Responses shorter than this many bytes are not auto-saved.
    const MIN_RESPONSE_LEN: usize = 50;
    /// Byte cap applied to each memory's content before it is rendered.
    const MAX_MEMORY_LINE_LEN: usize = 300;
    /// Byte cap on the user message stored by `after`.
    const SAVED_USER_LEN: usize = 200;
    /// Byte cap on the AI response stored by `after`.
    const SAVED_RESPONSE_LEN: usize = 500;

    /// Creates a middleware around `memory` with these defaults:
    /// - a 1500-byte context budget,
    /// - auto-save disabled,
    /// - a 15-byte minimum message length.
    pub fn new(memory: Memory) -> Self {
        Self {
            memory,
            max_context_chars: 1500,
            auto_save: false,
            min_message_length: 15,
        }
    }

    /// Enables saving each exchange from [`after`](Self::after).
    #[must_use]
    pub fn with_auto_save(mut self) -> Self {
        self.auto_save = true;
        self
    }

    /// Sets the maximum size, in bytes, of the injected system-prompt text
    /// (the header line included).
    #[must_use]
    pub fn with_max_context(mut self, chars: usize) -> Self {
        self.max_context_chars = chars;
        self
    }

    /// Sets the minimum message length, in bytes, below which retrieval and
    /// auto-save are skipped.
    #[must_use]
    pub fn with_min_message_length(mut self, len: usize) -> Self {
        self.min_message_length = len;
        self
    }

    /// Builds a result that injects nothing. `success` records whether
    /// retrieval was healthy.
    fn empty_result(success: bool) -> BeforeResult {
        BeforeResult {
            memories: Vec::new(),
            system_prompt_addition: String::new(),
            success,
        }
    }

    /// Retrieves memories relevant to `user_message` and renders them as text
    /// to append to the system prompt.
    ///
    /// Short messages and greetings skip the lookup entirely. Memories are
    /// rendered in search order, one bullet each, with every entry truncated
    /// to 300 bytes. Rendering stops at the first entry that would exceed the
    /// context budget. If not even one entry fits, no text is injected.
    ///
    /// This method never fails: a retrieval error is logged and reported as
    /// `success: false` with nothing injected.
    pub async fn before(&self, user_message: &str) -> BeforeResult {
        if user_message.len() < self.min_message_length || is_greeting(user_message) {
            return Self::empty_result(true);
        }

        let result = match self.memory.search(user_message).limit(5).send().await {
            Ok(result) => result,
            Err(e) => {
                warn!(error = %e, "erebyx middleware: memory retrieval failed (proceeding without)");
                return Self::empty_result(false);
            }
        };
        if result.memories.is_empty() {
            return Self::empty_result(true);
        }

        let mut addition = String::from("[EREBYX Memory Context]\n");
        let mut injected = 0usize;
        for m in &result.memories {
            let line = format!(
                "- {}\n",
                truncate_safe(&m.content, Self::MAX_MEMORY_LINE_LEN)
            );
            if addition.len() + line.len() > self.max_context_chars {
                break;
            }
            addition.push_str(&line);
            injected += 1;
        }
        if injected == 0 {
            // The budget cannot hold even one entry. A bare header would waste
            // prompt space and imply context that is not there.
            addition.clear();
        }

        debug!(
            memories = result.memories.len(),
            injected = injected,
            chars = addition.len(),
            "erebyx middleware: injecting context"
        );
        BeforeResult {
            memories: result.memories,
            system_prompt_addition: addition,
            success: true,
        }
    }

    /// Saves a condensed copy of the exchange as an `"episodic"` memory when
    /// auto-save is enabled.
    ///
    /// Skipped for user messages shorter than the configured minimum and for
    /// responses under 50 bytes. The stored text is capped at 200 bytes of the
    /// user message and 500 bytes of the response. Saving is best-effort:
    /// failures are logged and otherwise ignored.
    pub async fn after(&self, user_message: &str, ai_response: &str) {
        if !self.auto_save
            || user_message.len() < self.min_message_length
            || ai_response.len() < Self::MIN_RESPONSE_LEN
        {
            return;
        }
        let content = format!(
            "User asked: {}\n\nResponse: {}",
            truncate_safe(user_message, Self::SAVED_USER_LEN),
            truncate_safe(ai_response, Self::SAVED_RESPONSE_LEN)
        );
        match self.memory.save(&content, "episodic").send().await {
            Ok(_) => debug!("erebyx middleware: auto-saved response"),
            Err(e) => warn!(error = %e, "erebyx middleware: auto-save failed (non-critical)"),
        }
    }
}
/// Returns `true` when `msg` is a short conversational pleasantry (e.g. "hi",
/// "thanks!", "got it") that is not worth a memory lookup.
///
/// A message qualifies when both conditions hold:
/// - After trimming and lowercasing it is under 30 bytes.
/// - It begins with a known greeting as a whole word. The greeting must be
///   followed by end-of-text or a non-alphanumeric character, so "hi there"
///   matches but "history of rome" does not.
fn is_greeting(msg: &str) -> bool {
    const GREETINGS: &[&str] = &[
        "hey",
        "hi",
        "hello",
        "thanks",
        "thank you",
        "bye",
        "ok",
        "yes",
        "no",
        "sure",
        "cool",
        "nice",
        "got it",
        "sounds good",
        "okay",
        "yep",
        "nope",
        "alright",
    ];

    let lower = msg.trim().to_lowercase();
    // Longer messages that merely open with a pleasantry carry a real request.
    if lower.len() >= 30 {
        return false;
    }
    GREETINGS.iter().any(|g| match lower.strip_prefix(g) {
        // Word boundary: "no" must not match "nothing", "yes" not "yesterday".
        Some(rest) => !rest.starts_with(char::is_alphanumeric),
        None => false,
    })
}
/// Returns the longest prefix of `s` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary, so the slice never splits a character.
///
/// Strings already within the limit are returned unchanged.
fn truncate_safe(s: &str, max: usize) -> &str {
    if s.len() > max {
        // Index 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=max)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        &s[..cut]
    } else {
        s
    }
}