use async_trait::async_trait;
use oharness_core::{ConversationView, Message, NullSink, RunId, ScopedEmitter};
use oharness_memory::{MemoryContext, MemoryError, MemoryPolicy};
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
/// A [`MemoryPolicy`] that keeps every system message plus the `n` most recent
/// non-system messages of a conversation.
///
/// This is a plain configuration value, so it derives `Debug`, `Clone` and `Copy`.
#[derive(Debug, Clone, Copy)]
struct KeepLastN {
    /// How many of the most recent non-system messages to retain.
    n: usize,
}
#[async_trait]
impl MemoryPolicy for KeepLastN {
    /// Returns every system message (in original order) followed by the last
    /// `self.n` non-system messages (also in original order).
    ///
    /// Any system message that appears mid-conversation ends up ahead of the
    /// kept tail. Non-system messages older than the window are dropped. This
    /// implementation never returns an error, and `_ctx` is not consulted.
    async fn transform(
        &self,
        conversation: ConversationView<'_>,
        _ctx: &MemoryContext,
    ) -> Result<Vec<Message>, MemoryError> {
        fn is_system(msg: &Message) -> bool {
            matches!(msg, Message::System { .. })
        }

        let history = conversation.messages();

        // How many of the oldest non-system messages fall outside the window.
        // Saturates at 0 when there are fewer than `n` of them.
        let dropped = history
            .iter()
            .filter(|msg| !is_system(msg))
            .count()
            .saturating_sub(self.n);

        let mut kept: Vec<Message> = history
            .iter()
            .filter(|msg| is_system(msg))
            .cloned()
            .collect();
        kept.extend(
            history
                .iter()
                .filter(|msg| !is_system(msg))
                .skip(dropped)
                .cloned(),
        );
        Ok(kept)
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let messages = vec![
Message::system("You are a helpful assistant."),
Message::user_text("first user turn"),
Message::assistant_text("first assistant reply"),
Message::user_text("second user turn"),
Message::assistant_text("second assistant reply"),
Message::user_text("third user turn (the latest)"),
];
let policy = KeepLastN { n: 3 };
let view = ConversationView::new(&messages);
let ctx = MemoryContext {
events: ScopedEmitter::new(
Arc::new(NullSink) as Arc<dyn oharness_core::EventSink>,
RunId::new(),
Arc::new(AtomicU64::new(0)),
),
token_budget: 4_000,
};
let transformed = policy.transform(view, &ctx).await?;
println!("Before ({} messages):", messages.len());
for (i, m) in messages.iter().enumerate() {
println!(" {i}: {}", describe(m));
}
println!("After ({} messages):", transformed.len());
for (i, m) in transformed.iter().enumerate() {
println!(" {i}: {}", describe(m));
}
assert_eq!(transformed.len(), 4, "system + last 3 non-system");
assert!(
matches!(transformed[0], Message::System { .. }),
"leading system preserved"
);
println!("\nPolicy preserved the system prompt and kept the last 3 non-system messages ✔");
Ok(())
}
/// Renders a message as a one-line `[role] body` summary for printing.
///
/// The system content is shown with its `Debug` formatting. User and assistant
/// content is reduced to its text parts by `flatten`.
fn describe(m: &Message) -> String {
    let (role, body) = match m {
        Message::System { content, .. } => ("system", format!("{content:?}")),
        Message::User { content, .. } => ("user", flatten(content)),
        Message::Assistant { content, .. } => ("assistant", flatten(content)),
    };
    format!("[{role}] {body}")
}
/// Joins the `Text` parts of `content` with `" | "`, skipping every other
/// content kind. Returns an empty string when there are no text parts.
fn flatten(content: &[oharness_core::Content]) -> String {
    let mut joined = String::new();
    // The separator is emitted before every text part except the first.
    // Tracking it separately, instead of checking `joined.is_empty()`, keeps
    // empty text parts from being mistaken for "nothing written yet".
    let mut separator = "";
    for part in content {
        if let oharness_core::Content::Text { text } = part {
            joined.push_str(separator);
            joined.push_str(text.as_str());
            separator = " | ";
        }
    }
    joined
}