use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::{Message, Result};
use cognis_llm::chat::ChatResponse;
use super::{Middleware, MiddlewareCtx, Next};
/// A transformation applied to the message history before it is forwarded
/// to the next middleware stage.
///
/// Implementations take ownership of the message list and return the list
/// to forward in its place, so they may drop, rewrite or insert messages.
pub trait EditPolicy: Send + Sync {
    /// Consumes `messages` and returns the edited list.
    fn edit(&self, messages: Vec<Message>) -> Vec<Message>;
}

/// Lets any thread-safe `Fn(Vec<Message>) -> Vec<Message>` closure or
/// function be used directly as an [`EditPolicy`].
impl<Func> EditPolicy for Func
where
    Func: Fn(Vec<Message>) -> Vec<Message> + Send + Sync,
{
    fn edit(&self, messages: Vec<Message>) -> Vec<Message> {
        // `self` is `&Func`; dereference and invoke the wrapped callable.
        (*self)(messages)
    }
}
/// Edit policy that truncates every message's content to at most
/// `max_chars` Unicode scalar values (`char`s).
///
/// Truncation always happens on a `char` boundary, so multi-byte UTF-8 text
/// is never split mid-character. Messages already within the limit are left
/// untouched, and message order and count are preserved.
#[derive(Debug, Clone)]
pub struct CapMessageLength {
    /// Maximum number of `char`s kept per message; `0` empties every message.
    pub max_chars: usize,
}

/// Shortens `text` in place to at most `max_chars` `char`s.
///
/// Locates the byte offset of the first `char` past the limit in a single
/// pass and truncates there. This reuses the existing allocation instead of
/// counting the chars and then re-collecting them into a new `String`.
fn truncate_chars(text: &mut String, max_chars: usize) {
    // `nth(max_chars)` is `None` when the text has `max_chars` chars or fewer.
    if let Some((cut, _)) = text.char_indices().nth(max_chars) {
        text.truncate(cut);
    }
}

impl EditPolicy for CapMessageLength {
    fn edit(&self, mut messages: Vec<Message>) -> Vec<Message> {
        for message in &mut messages {
            // Every variant carries its text in a `content` field; cap each one.
            let content = match message {
                Message::Human(h) => &mut h.content,
                Message::Ai(a) => &mut a.content,
                Message::System(s) => &mut s.content,
                Message::Tool(t) => &mut t.content,
            };
            truncate_chars(content, self.max_chars);
        }
        messages
    }
}
/// Edit policy that removes every message for which a caller-supplied
/// predicate returns `true`.
pub struct DropMatching {
    /// Returns `true` for messages that should be dropped.
    predicate: Arc<dyn Fn(&Message) -> bool + Send + Sync>,
}

impl DropMatching {
    /// Builds a policy that drops every message matching `predicate`.
    pub fn new<P>(predicate: P) -> Self
    where
        P: Fn(&Message) -> bool + Send + Sync + 'static,
    {
        let predicate: Arc<dyn Fn(&Message) -> bool + Send + Sync> = Arc::new(predicate);
        Self { predicate }
    }
}
impl EditPolicy for DropMatching {
    /// Keeps only the messages the predicate does not match, preserving order.
    fn edit(&self, mut messages: Vec<Message>) -> Vec<Message> {
        let should_drop = &self.predicate;
        // `retain` visits each element exactly once, in order, and keeps those
        // for which the closure returns `true`.
        messages.retain(|message| !should_drop(message));
        messages
    }
}
/// Middleware that rewrites the outgoing message history with an
/// [`EditPolicy`] before handing the request to the next stage.
pub struct ContextEditing {
    /// Policy applied to the context's messages on every call.
    policy: Arc<dyn EditPolicy>,
}

impl ContextEditing {
    /// Wraps `policy` in a new middleware. The policy can be a policy struct
    /// or any type implementing [`EditPolicy`], such as a closure.
    pub fn new<Pol>(policy: Pol) -> Self
    where
        Pol: EditPolicy + 'static,
    {
        let policy: Arc<dyn EditPolicy> = Arc::new(policy);
        Self { policy }
    }
}
#[async_trait]
impl Middleware for ContextEditing {
    fn name(&self) -> &str {
        "ContextEditing"
    }

    /// Runs the configured policy over the context's messages, then forwards
    /// the edited context to `next` and returns its result unchanged.
    async fn call(&self, mut ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse> {
        // Move the messages out (leaving an empty Vec behind) so the policy
        // can take ownership, then store the edited list back.
        ctx.messages = self.policy.edit(std::mem::take(&mut ctx.messages));
        next.invoke(ctx).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::middleware::tests_util::{ok_resp, RecordingNext};

    /// Runs `mw` over `messages` with a recording `Next`, asserts the call
    /// succeeded, and returns the recorder so tests can inspect what was
    /// forwarded.
    async fn run(mw: &ContextEditing, messages: Vec<Message>) -> Arc<RecordingNext> {
        let recorder = Arc::new(RecordingNext::new(ok_resp("ok")));
        let next: Arc<dyn Next> = recorder.clone();
        let result = mw
            .call(MiddlewareCtx::new(messages, vec![], Default::default()), next)
            .await;
        // Previously discarded with `let _ =`; a failing chain must fail the test.
        assert!(result.is_ok(), "middleware call should succeed");
        recorder
    }

    #[tokio::test]
    async fn cap_message_length_truncates() {
        let mw = ContextEditing::new(CapMessageLength { max_chars: 5 });
        let recorder = run(&mw, vec![Message::human("longer than five chars")]).await;
        let seen = recorder.seen.lock().unwrap();
        let content = seen[0].messages[0].content();
        assert_eq!(content.chars().count(), 5);
        // The *leading* characters must be the ones kept.
        assert!(content.starts_with("longe"));
    }

    #[tokio::test]
    async fn cap_message_length_respects_char_boundaries_and_short_messages() {
        let mw = ContextEditing::new(CapMessageLength { max_chars: 3 });
        let recorder = run(&mw, vec![Message::human("héllo"), Message::system("ok")]).await;
        let seen = recorder.seen.lock().unwrap();
        // Multi-byte `é` must survive intact: count chars, not bytes.
        assert_eq!(seen[0].messages[0].content().chars().count(), 3);
        assert!(seen[0].messages[0].content().starts_with("hél"));
        // Messages already under the limit pass through untouched.
        assert_eq!(seen[0].messages[1].content().chars().count(), 2);
        assert!(seen[0].messages[1].content().starts_with("ok"));
    }

    #[tokio::test]
    async fn drop_matching_removes() {
        let mw = ContextEditing::new(DropMatching::new(|m| m.content().contains("REDACT")));
        let recorder = run(
            &mw,
            vec![
                Message::human("keep me"),
                Message::human("REDACT THIS"),
                Message::human("also keep"),
            ],
        )
        .await;
        let seen = recorder.seen.lock().unwrap();
        assert_eq!(seen[0].messages.len(), 2);
        assert!(seen[0]
            .messages
            .iter()
            .all(|m| !m.content().contains("REDACT")));
        // Survivors keep their original relative order.
        assert!(seen[0].messages[0].content().starts_with("keep me"));
        assert!(seen[0].messages[1].content().starts_with("also keep"));
    }

    #[tokio::test]
    async fn closure_policy_works() {
        let mw = ContextEditing::new(|msgs: Vec<Message>| {
            let mut out = vec![Message::system("policy injected")];
            out.extend(msgs);
            out
        });
        let recorder = run(&mw, vec![Message::human("hi")]).await;
        let seen = recorder.seen.lock().unwrap();
        assert_eq!(seen[0].messages.len(), 2);
        assert!(matches!(seen[0].messages[0], Message::System(_)));
    }
}