use std::sync::Arc;
use async_trait::async_trait;
use regex::Regex;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use crate::{chat::ChatMessage, error::LLMError};
#[derive(Debug, Clone)]
pub struct MessageEvent {
pub role: String,
pub msg: ChatMessage,
}
#[derive(Clone)]
pub enum MessageCondition {
Any,
Eq(String),
Contains(String),
NotContains(String),
RoleIs(String),
RoleNot(String),
LenGt(usize),
Custom(Arc<dyn Fn(&ChatMessage) -> bool + Send + Sync>),
Empty,
All(Vec<MessageCondition>),
AnyOf(Vec<MessageCondition>),
Regex(String),
HasAudio,
}
impl MessageCondition {
pub fn matches(&self, event: &MessageEvent) -> bool {
match self {
MessageCondition::Any => true,
MessageCondition::Eq(text) => event.msg.content == *text,
MessageCondition::Contains(text) => event.msg.content.contains(text),
MessageCondition::NotContains(text) => !event.msg.content.contains(text),
MessageCondition::RoleIs(role) => event.role == *role,
MessageCondition::RoleNot(role) => event.role != *role,
MessageCondition::LenGt(len) => event.msg.content.len() > *len,
MessageCondition::Custom(func) => func(&event.msg),
MessageCondition::Empty => event.msg.content.is_empty(),
MessageCondition::All(inner) => inner.iter().all(|c| c.matches(event)),
MessageCondition::AnyOf(inner) => inner.iter().any(|c| c.matches(event)),
MessageCondition::Regex(regex) => Regex::new(regex)
.map(|re| re.is_match(&event.msg.content))
.unwrap_or(false),
MessageCondition::HasAudio => event.msg.has_audio(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::chat::ChatMessage;
fn event_for(msg: ChatMessage) -> MessageEvent {
MessageEvent {
role: "user".to_string(),
msg,
}
}
#[test]
fn has_audio_condition_matches_audio_message() {
let msg = ChatMessage::user().audio(vec![1]).build();
let event = event_for(msg);
assert!(MessageCondition::HasAudio.matches(&event));
}
#[test]
fn has_audio_condition_rejects_text_message() {
let msg = ChatMessage::user().content("hello").build();
let event = event_for(msg);
assert!(!MessageCondition::HasAudio.matches(&event));
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryType {
SlidingWindow,
}
#[async_trait]
pub trait MemoryProvider: Send + Sync {
async fn remember(&mut self, message: &ChatMessage) -> Result<(), LLMError>;
async fn recall(&self, query: &str, limit: Option<usize>)
-> Result<Vec<ChatMessage>, LLMError>;
async fn clear(&mut self) -> Result<(), LLMError>;
fn memory_type(&self) -> MemoryType;
fn size(&self) -> usize;
fn is_empty(&self) -> bool {
self.size() == 0
}
fn needs_summary(&self) -> bool {
false
}
fn mark_for_summary(&mut self) {}
fn replace_with_summary(&mut self, _summary: String) {}
fn get_event_receiver(&self) -> Option<broadcast::Receiver<MessageEvent>> {
None
}
async fn remember_with_role(
&mut self,
message: &ChatMessage,
_role: String,
) -> Result<(), LLMError> {
self.remember(message).await
}
}