use serde::{Deserialize, Serialize};
use crate::memory::error::MemoryValidationError;
/// Importance level attached to a stored memory.
///
/// Variants are declared from least to most important, so the derived
/// `PartialOrd`/`Ord` rank them `Low < Medium < High`; `Hash` allows the level to
/// be used as a map/set key.
///
/// Note: the derived serde impls use the variant names as written (`"Low"`,
/// `"Medium"`, `"High"`), whereas [`MemoryImportance::as_str`] and
/// [`MemoryImportance::from_db`] use lowercase strings. The two string forms are
/// not interchangeable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub enum MemoryImportance {
    Low,
    Medium,
    High,
}
impl MemoryImportance {
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
Self::Low => "low",
Self::Medium => "medium",
Self::High => "high",
}
}
pub fn from_db(value: &str) -> Result<Self, crate::memory::error::MemoryValidationError> {
match value {
"low" => Ok(Self::Low),
"medium" => Ok(Self::Medium),
"high" => Ok(Self::High),
_ => Err(crate::memory::error::MemoryValidationError::InvalidImportance(
value.to_string(),
)),
}
}
}
/// A single stored memory record.
///
/// Only `PartialEq` (not `Eq`) is derived because `embedding` holds `f32` values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryEntry {
/// Identifier of this entry (uniqueness is presumably enforced by the store — confirm).
pub id: String,
/// Identifier of the session this memory was recorded in.
pub session_id: String,
/// Short topic label for the memory.
pub topic: String,
/// Condensed summary text of the memory.
pub summary: String,
/// Excerpt of the source text (presumably verbatim — confirm with the writer).
pub raw_excerpt: String,
/// Keywords associated with the entry (presumably used for lexical matching — confirm).
pub keywords: Vec<String>,
/// Importance level of this memory.
pub importance: MemoryImportance,
/// Optional embedding vector; `None` when absent. Its dimensionality is not checked here.
pub embedding: Option<Vec<f32>>,
/// Creation time in milliseconds since the Unix epoch (as the field name states).
pub created_at_epoch_ms: i64,
}
/// Parameters of a memory recall (search) request.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecallQuery {
/// Optional session filter; presumably `None` searches across all sessions — confirm
/// against the recall implementation.
pub session_id: Option<String>,
/// Free-text query. Arbitrary Unicode is representable; this type applies no escaping.
pub text: String,
/// Optional embedding of `text` for vector scoring; presumably only lexical scoring is
/// used when `None` — confirm.
pub query_embedding: Option<Vec<f32>>,
/// Requested maximum number of hits. Not validated by this type (0 is representable).
pub limit: usize,
}
/// One scored result of a recall.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RecallHit {
/// The matched memory entry.
pub entry: MemoryEntry,
/// Lexical (BM25) relevance score.
pub bm25_score: f32,
/// Vector-similarity score, or `None` when none was computed (presumably when the query
/// or the entry lacks an embedding — confirm).
pub vector_score: Option<f32>,
/// Combined ranking score; presumably the `HybridWeights`-weighted blend of the scores
/// above — confirm against the scorer.
pub final_score: f32,
}
/// Relative weights used to blend BM25 and vector scores.
///
/// `HybridWeights::new` checks both weights are in `[0.0, 1.0]` and normalizes them to
/// sum to 1. The fields are public and `Deserialize` is derived, however, so values built
/// with a struct literal or deserialized from external data bypass that validation.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HybridWeights {
/// Weight for the BM25 (lexical) score.
pub bm25: f32,
/// Weight for the vector-similarity score.
pub vector: f32,
}
impl HybridWeights {
    /// Builds validated, normalized weights for blending BM25 and vector scores.
    ///
    /// Each input must lie in the closed range `[0.0, 1.0]`; NaN and infinities fail
    /// this check. Their sum must also exceed [`f32::EPSILON`]. Each weight is then
    /// divided by that sum, so the result is proportional to the inputs and adds up to 1
    /// (up to floating-point rounding).
    ///
    /// # Errors
    ///
    /// Returns [`MemoryValidationError::InvalidHybridWeights`] carrying the original,
    /// un-normalized inputs when either weight is out of range or both are (near) zero.
    pub fn new(bm25: f32, vector: f32) -> Result<Self, MemoryValidationError> {
        let in_unit_range = |weight: f32| (0.0..=1.0).contains(&weight);
        let total = bm25 + vector;
        // `total` is only compared once both inputs are known to be finite and
        // non-negative, so it cannot be NaN at that point.
        if in_unit_range(bm25) && in_unit_range(vector) && total > f32::EPSILON {
            Ok(Self {
                bm25: bm25 / total,
                vector: vector / total,
            })
        } else {
            Err(MemoryValidationError::InvalidHybridWeights { bm25, vector })
        }
    }
}
impl Default for HybridWeights {
    /// Favors vector similarity: 0.3 for BM25, 0.7 for vectors.
    ///
    /// These already sum to exactly 1.0 in `f32`, so the result equals
    /// `HybridWeights::new(0.3, 0.7)`.
    fn default() -> Self {
        let (bm25, vector) = (0.3, 0.7);
        Self { bm25, vector }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_importance_as_str() {
        assert_eq!(MemoryImportance::Low.as_str(), "low");
        assert_eq!(MemoryImportance::Medium.as_str(), "medium");
        assert_eq!(MemoryImportance::High.as_str(), "high");
    }

    #[test]
    fn memory_importance_from_db_roundtrip() {
        for variant in [MemoryImportance::Low, MemoryImportance::Medium, MemoryImportance::High] {
            assert_eq!(MemoryImportance::from_db(variant.as_str()), Ok(variant));
        }
        assert!(MemoryImportance::from_db("unknown").is_err());
    }

    #[test]
    fn memory_importance_from_db_is_strict() {
        // Parsing is case-sensitive and does not trim, so near-misses are rejected.
        assert!(MemoryImportance::from_db("Low").is_err());
        assert!(MemoryImportance::from_db(" low").is_err());
        assert!(MemoryImportance::from_db("").is_err());
    }

    #[test]
    fn memory_entry_construction() {
        let entry = MemoryEntry {
            id: "id-1".to_string(),
            session_id: "sess-1".to_string(),
            topic: "greetings".to_string(),
            summary: "hello world".to_string(),
            raw_excerpt: "raw".to_string(),
            keywords: vec!["hello".to_string()],
            importance: MemoryImportance::High,
            embedding: None,
            created_at_epoch_ms: 1_000,
        };
        assert_eq!(entry.id, "id-1");
        assert_eq!(entry.session_id, "sess-1");
        assert_eq!(entry.topic, "greetings");
        assert_eq!(entry.importance, MemoryImportance::High);
        assert!(entry.embedding.is_none());
        assert_eq!(entry.created_at_epoch_ms, 1_000);
    }

    #[test]
    fn recall_query_with_special_chars() {
        let query = RecallQuery {
            session_id: Some("s-\u{1F600}".to_string()),
            text: "\u{4F60}\u{597D} hello <>&\"'".to_string(),
            query_embedding: None,
            limit: 10,
        };
        assert!(query.text.contains('\u{4F60}'));
        assert_eq!(query.limit, 10);
    }

    #[test]
    fn hybrid_weights_default_sums_to_one() {
        let w = HybridWeights::default();
        let sum = w.bm25 + w.vector;
        assert!((sum - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn hybrid_weights_default_matches_new() {
        let built = HybridWeights::new(0.3, 0.7).expect("default weights are valid");
        assert_eq!(built, HybridWeights::default());
    }

    #[test]
    fn hybrid_weights_new_normalizes() {
        let w = HybridWeights::new(0.2, 0.6).expect("in-range weights are valid");
        assert!((w.bm25 - 0.25).abs() < 1e-6);
        assert!((w.vector - 0.75).abs() < 1e-6);
        assert!((w.bm25 + w.vector - 1.0).abs() < 1e-6);
    }

    #[test]
    fn hybrid_weights_new_accepts_single_nonzero_component() {
        let w = HybridWeights::new(0.0, 0.4).expect("one non-zero weight is valid");
        assert!(w.bm25.abs() < f32::EPSILON);
        assert!((w.vector - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn hybrid_weights_new_rejects_invalid_inputs() {
        let cases = [
            (-0.1, 0.5),
            (0.5, 1.1),
            (f32::NAN, 0.5),
            (0.5, f32::INFINITY),
            (0.0, 0.0),
        ];
        for (bm25, vector) in cases {
            assert!(
                matches!(
                    HybridWeights::new(bm25, vector),
                    Err(MemoryValidationError::InvalidHybridWeights { .. })
                ),
                "expected ({bm25}, {vector}) to be rejected"
            );
        }
    }
}