#![cfg_attr(
test,
allow(
clippy::expect_used,
clippy::indexing_slicing,
clippy::panic,
clippy::unwrap_used,
clippy::unreachable
)
)]
use std::{
collections::HashMap,
sync::{Arc, Mutex as StdMutex},
};
pub use rig_core::memory::{
Compactor, ConversationMemory, DemotionHook, InMemoryConversationMemory, MemoryError,
NoopDemotionHook,
};
use rig_core::completion::Message;
use rig_core::message::UserContent;
use rig_core::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
/// A synchronous transformation applied to a conversation's history before it
/// is used.
///
/// Messages arrive in stored order; the window policies in this module treat
/// the end of the vector as the most recent message.
pub trait MemoryPolicy: WasmCompatSend + WasmCompatSync {
    /// Transform `messages` into the history that should be used.
    ///
    /// # Errors
    /// Returns a [`MemoryError`] when the policy cannot be applied.
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError>;

    /// Like [`MemoryPolicy::apply`], but also reports what was dropped, as
    /// `(kept, demoted)`.
    ///
    /// The default delegates to `apply` and reports nothing as demoted;
    /// policies that drop messages should override it.
    ///
    /// # Errors
    /// Propagates any error from `apply`.
    fn apply_with_demoted(
        &self,
        messages: Vec<Message>,
    ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
        self.apply(messages).map(|kept| (kept, Vec::new()))
    }
}
/// Forwards to the shared policy, so `Arc<P>` (including
/// `Arc<dyn MemoryPolicy>`) works anywhere a policy is expected. Both methods
/// are forwarded so an overridden `apply_with_demoted` is not lost.
impl<P> MemoryPolicy for Arc<P>
where
    P: MemoryPolicy + ?Sized,
{
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        let policy: &P = self;
        policy.apply(messages)
    }

    fn apply_with_demoted(
        &self,
        messages: Vec<Message>,
    ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
        let policy: &P = self;
        policy.apply_with_demoted(messages)
    }
}

/// Forwards to the boxed policy, so `Box<P>` (including
/// `Box<dyn MemoryPolicy>`) works anywhere a policy is expected. Both methods
/// are forwarded so an overridden `apply_with_demoted` is not lost.
impl<P> MemoryPolicy for Box<P>
where
    P: MemoryPolicy + ?Sized,
{
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        let policy: &P = self;
        policy.apply(messages)
    }

    fn apply_with_demoted(
        &self,
        messages: Vec<Message>,
    ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
        let policy: &P = self;
        policy.apply_with_demoted(messages)
    }
}
/// Run `policy` over `messages`. If the policy fails, log a warning and return
/// the unfiltered history.
fn apply_or_fallback<P>(policy: &P, messages: Vec<Message>) -> Vec<Message>
where
    P: MemoryPolicy + ?Sized,
{
    // `apply` consumes its input, so keep a copy to hand back on failure.
    let unfiltered = messages.clone();
    policy.apply(messages).unwrap_or_else(|err| {
        tracing::warn!(error = %err, "memory policy failed; returning unfiltered history");
        unfiltered
    })
}

/// Adapts a [`MemoryPolicy`] into a plain, infallible filter closure (the
/// shape passed to `InMemoryConversationMemory::with_filter`).
pub trait IntoFilter: MemoryPolicy + Sized + 'static {
    /// Box this policy as a filter. A policy error is logged with
    /// `tracing::warn!` and the input history is returned unchanged.
    #[cfg(not(target_family = "wasm"))]
    fn into_filter(self) -> Box<dyn Fn(Vec<Message>) -> Vec<Message> + Send + Sync> {
        Box::new(move |messages| apply_or_fallback(&self, messages))
    }

    /// Box this policy as a filter (wasm variant, without `Send + Sync`).
    /// A policy error is logged and the input history is returned unchanged.
    #[cfg(target_family = "wasm")]
    fn into_filter(self) -> Box<dyn Fn(Vec<Message>) -> Vec<Message>> {
        Box::new(move |messages| apply_or_fallback(&self, messages))
    }
}

/// Every `'static` policy gets `into_filter`.
impl<P> IntoFilter for P where P: MemoryPolicy + 'static {}
/// A policy that returns the history untouched and never demotes anything.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopMemoryPolicy;

impl MemoryPolicy for NoopMemoryPolicy {
    fn apply(&self, history: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        // Identity transform; the default `apply_with_demoted` reports no demotions.
        Ok(history)
    }
}
/// Keeps only the most recent `max_messages` messages of a conversation.
///
/// When truncation leaves a tool result at the start of the window (its tool
/// call having been cut), the policy demotes that tool result as well.
#[derive(Debug, Clone, Copy)]
pub struct SlidingWindowMemory {
    /// Upper bound on the number of messages kept.
    max_messages: usize,
}

impl SlidingWindowMemory {
    /// Build a window that keeps at most the last `max_messages` messages.
    /// A value of `0` keeps nothing.
    pub fn last_messages(max_messages: usize) -> Self {
        Self { max_messages }
    }
}
impl MemoryPolicy for SlidingWindowMemory {
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        Ok(self.apply_with_demoted(messages)?.0)
    }

    /// Split `messages` into `(window, demoted)`. `window` holds at most the
    /// last `max_messages` entries and `demoted` holds the older prefix.
    ///
    /// Every tool-result message left at the very front of the window is moved
    /// to `demoted` as well, because its tool call was cut off. The original
    /// code moved only the first one. Both halves keep the original order.
    /// When nothing needs truncating, the input is returned unchanged.
    fn apply_with_demoted(
        &self,
        messages: Vec<Message>,
    ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
        if messages.len() <= self.max_messages {
            return Ok((messages, Vec::new()));
        }
        let start = messages.len() - self.max_messages;
        let mut demoted = messages;
        let mut window = demoted.split_off(start);

        // Count the whole run of leading tool results, not just the first, so
        // several results for one cut-off call are all demoted together.
        let orphans = window
            .iter()
            .take_while(|msg| {
                matches!(msg, Message::User { content }
                    if matches!(content.first_ref(), UserContent::ToolResult(_)))
            })
            .count();
        demoted.extend(window.drain(..orphans));
        Ok((window, demoted))
    }
}
/// Estimates how many tokens a single message costs.
pub trait TokenCounter: WasmCompatSend + WasmCompatSync {
    /// Return the (estimated) token cost of `message`.
    fn count(&self, message: &Message) -> usize;
}

/// Any closure `Fn(&Message) -> usize` that satisfies the `WasmCompat*` bounds
/// can be used as a counter.
impl<F> TokenCounter for F
where
    F: Fn(&Message) -> usize + WasmCompatSend + WasmCompatSync,
{
    fn count(&self, message: &Message) -> usize {
        self(message)
    }
}

/// Forwards to the shared counter (including `Arc<dyn TokenCounter>`).
impl<C> TokenCounter for Arc<C>
where
    C: TokenCounter + ?Sized,
{
    fn count(&self, message: &Message) -> usize {
        let inner: &C = self;
        inner.count(message)
    }
}

/// Forwards to the boxed trait object.
impl TokenCounter for Box<dyn TokenCounter> {
    fn count(&self, message: &Message) -> usize {
        let inner: &dyn TokenCounter = self;
        inner.count(message)
    }
}
/// A byte-length token estimator, not a real tokenizer.
///
/// - Text is priced at `ceil(bytes / bytes_per_token)`.
/// - Every non-text part is charged a flat `per_attachment_tokens`.
/// - Every message pays an extra `per_message_overhead`.
#[derive(Debug, Clone, Copy)]
pub struct HeuristicTokenCounter {
    /// Average bytes per token; always finite and `>= 1.0` (see [`Self::new`]).
    bytes_per_token: f32,
    /// Flat cost added once per message.
    per_message_overhead: usize,
    /// Flat cost for each non-text part (image, audio, video, document).
    per_attachment_tokens: usize,
}

impl HeuristicTokenCounter {
    /// Build a counter.
    ///
    /// A `bytes_per_token` that is NaN, infinite, or below `1.0` is replaced
    /// by `1.0`, so text is never priced at more than one token per byte.
    pub fn new(
        bytes_per_token: f32,
        per_message_overhead: usize,
        per_attachment_tokens: usize,
    ) -> Self {
        let bytes_per_token = if bytes_per_token.is_finite() && bytes_per_token >= 1.0 {
            bytes_per_token
        } else {
            1.0
        };
        Self {
            bytes_per_token,
            per_message_overhead,
            per_attachment_tokens,
        }
    }

    /// Preset: 4.0 bytes/token, 4 tokens per message, 256 per attachment.
    pub fn openai() -> Self {
        Self::new(4.0, 4, 256)
    }

    /// Preset: 3.5 bytes/token, 4 tokens per message, 256 per attachment.
    pub fn anthropic() -> Self {
        Self::new(3.5, 4, 256)
    }

    /// Preset: the same parameters as [`Self::openai`].
    pub fn gemini() -> Self {
        Self::new(4.0, 4, 256)
    }

    /// Convert a byte length into an estimated token count, rounding up.
    fn bytes_to_tokens(&self, bytes: usize) -> usize {
        // Divide in f64. f32 has only a 24-bit mantissa, so large byte counts
        // would be rounded before `ceil` and could undercount.
        let tokens = bytes as f64 / f64::from(self.bytes_per_token);
        // Float-to-int `as` saturates, so this cannot wrap.
        tokens.ceil() as usize
    }

    /// Price one part of a user message.
    fn count_user(&self, content: &rig_core::message::UserContent) -> usize {
        use rig_core::message::UserContent;
        match content {
            UserContent::Text(text) => self.bytes_to_tokens(text.text.len()),
            // Tool results are priced part by part: text by length, images flat.
            UserContent::ToolResult(result) => result
                .content
                .iter()
                .map(|c| match c {
                    rig_core::message::ToolResultContent::Text(t) => {
                        self.bytes_to_tokens(t.text.len())
                    }
                    rig_core::message::ToolResultContent::Image(_) => self.per_attachment_tokens,
                })
                .sum(),
            UserContent::Image(_)
            | UserContent::Audio(_)
            | UserContent::Video(_)
            | UserContent::Document(_) => self.per_attachment_tokens,
        }
    }

    /// Price one part of an assistant message.
    fn count_assistant(&self, content: &rig_core::message::AssistantContent) -> usize {
        use rig_core::message::AssistantContent;
        match content {
            AssistantContent::Text(text) => self.bytes_to_tokens(text.text.len()),
            AssistantContent::Reasoning(reasoning) => {
                self.bytes_to_tokens(reasoning.display_text().len())
            }
            // A tool call costs its function name plus its serialized arguments.
            AssistantContent::ToolCall(call) => {
                let name_bytes = call.function.name.len();
                let args_bytes = call.function.arguments.to_string().len();
                self.bytes_to_tokens(name_bytes + args_bytes)
            }
            AssistantContent::Image(_) => self.per_attachment_tokens,
        }
    }
}

impl Default for HeuristicTokenCounter {
    /// Same as [`HeuristicTokenCounter::openai`].
    fn default() -> Self {
        Self::openai()
    }
}
/// Sums the per-part estimates of a message, then adds the per-message overhead
/// (saturating rather than overflowing).
impl TokenCounter for HeuristicTokenCounter {
    fn count(&self, message: &Message) -> usize {
        let body = match message {
            Message::System { content } => self.bytes_to_tokens(content.len()),
            Message::User { content } => content
                .iter()
                .map(|part| self.count_user(part))
                .sum::<usize>(),
            Message::Assistant { content, .. } => content
                .iter()
                .map(|part| self.count_assistant(part))
                .sum::<usize>(),
        };
        body.saturating_add(self.per_message_overhead)
    }
}
/// Keeps the most recent messages whose combined cost, as priced by a
/// [`TokenCounter`], fits within `max_tokens`.
pub struct TokenWindowMemory {
    /// Token budget for the kept window.
    max_tokens: usize,
    /// Shared, type-erased counter used to price each message.
    counter: Arc<dyn TokenCounter>,
}

impl TokenWindowMemory {
    /// Create a window with a budget of `max_tokens`, priced by `counter`.
    pub fn new<C>(max_tokens: usize, counter: C) -> Self
    where
        C: TokenCounter + 'static,
    {
        let counter: Arc<dyn TokenCounter> = Arc::new(counter);
        Self {
            counter,
            max_tokens,
        }
    }
}

impl std::fmt::Debug for TokenWindowMemory {
    /// The counter is opaque, so it is printed as a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = "<counter>";
        f.debug_struct("TokenWindowMemory")
            .field("max_tokens", &self.max_tokens)
            .field("counter", &placeholder)
            .finish()
    }
}
impl MemoryPolicy for TokenWindowMemory {
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        Ok(self.apply_with_demoted(messages)?.0)
    }

    /// Split `messages` into `(window, demoted)`, where `window` is the longest
    /// contiguous suffix whose total cost fits within `max_tokens`.
    ///
    /// Every tool-result message left at the very front of the window is moved
    /// to `demoted` too, because its tool call was cut off. The original code
    /// moved only the first one. Both halves keep the original order.
    fn apply_with_demoted(
        &self,
        messages: Vec<Message>,
    ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
        // Walk from newest to oldest and stop at the first message that does not
        // fit. An older message that would fit is never kept past a newer one
        // that doesn't, so the window stays contiguous.
        let mut budget = self.max_tokens;
        let mut keep_from = messages.len();
        for (idx, msg) in messages.iter().enumerate().rev() {
            let Some(remaining) = budget.checked_sub(self.counter.count(msg)) else {
                break;
            };
            budget = remaining;
            keep_from = idx;
        }

        let mut demoted = messages;
        let mut window = demoted.split_off(keep_from);

        // Count the whole run of leading tool results, not just the first.
        let orphans = window
            .iter()
            .take_while(|msg| {
                matches!(msg, Message::User { content }
                    if matches!(content.first_ref(), UserContent::ToolResult(_)))
            })
            .count();
        demoted.extend(window.drain(..orphans));
        Ok((window, demoted))
    }
}
/// A [`ConversationMemory`] adapter that filters every loaded history through a
/// [`MemoryPolicy`]. `append` and `clear` go straight to the wrapped store.
#[derive(Debug, Clone, Copy)]
pub struct PolicyMemory<M, P> {
    /// The wrapped storage backend.
    inner: M,
    /// Policy applied on every `load`.
    policy: P,
}

impl<M, P> PolicyMemory<M, P> {
    /// Wrap `inner` so that loads are filtered by `policy`.
    pub fn new(inner: M, policy: P) -> Self {
        Self { inner, policy }
    }

    /// Borrow the wrapped store.
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Borrow the policy.
    pub fn policy(&self) -> &P {
        &self.policy
    }

    /// Split back into `(inner, policy)`.
    pub fn into_inner(self) -> (M, P) {
        let Self { inner, policy } = self;
        (inner, policy)
    }
}

impl<M, P> ConversationMemory for PolicyMemory<M, P>
where
    M: ConversationMemory,
    P: MemoryPolicy,
{
    /// Load the stored history and return whatever the policy keeps.
    ///
    /// # Errors
    /// Errors from the inner store or from the policy are returned unchanged.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        Box::pin(async move { self.policy.apply(self.inner.load(conversation_id).await?) })
    }

    /// Forwarded unchanged to the inner store.
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        self.inner.append(conversation_id, messages)
    }

    /// Forwarded unchanged to the inner store.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        self.inner.clear(conversation_id)
    }
}
/// A [`ConversationMemory`] wrapper that applies a [`MemoryPolicy`] on load and
/// hands newly demoted messages to a [`DemotionHook`].
///
/// A per-conversation watermark records how many demoted messages the hook has
/// already accepted, so they are not passed to the hook again. The watermark
/// is kept in this struct's in-process state only.
pub struct DemotingPolicyMemory<M, P, H> {
inner: M,
policy: P,
hook: H,
// Per-conversation demotion bookkeeping, keyed by conversation id. Entries
// are removed only by `forget` / `clear`.
state: StdMutex<HashMap<String, ConversationDemotionState>>,
}
/// Token identifying one in-flight hook/compactor call. A holder may clear the
/// in-flight marker only if its `Arc` is pointer-equal to the stored one.
type InFlightReservation = Arc<()>;
/// Demotion bookkeeping for a single conversation.
#[derive(Debug, Default, Clone)]
struct ConversationDemotionState {
// Number of demoted-prefix messages already delivered to the hook successfully.
delivered: usize,
// `Some` while a hook call is running for this conversation.
in_flight: Option<InFlightReservation>,
}
impl<M, P, H> DemotingPolicyMemory<M, P, H> {
    /// Wrap `inner`, demoting with `policy` and reporting demotions to `hook`.
    /// Starts with no per-conversation state.
    pub fn new(inner: M, policy: P, hook: H) -> Self {
        let state = StdMutex::new(HashMap::new());
        Self {
            inner,
            policy,
            hook,
            state,
        }
    }

    /// Borrow the wrapped store.
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Borrow the policy.
    pub fn policy(&self) -> &P {
        &self.policy
    }

    /// Borrow the demotion hook.
    pub fn hook(&self) -> &H {
        &self.hook
    }

    /// Split into `(inner, policy, hook)`, discarding the demotion state.
    pub fn into_inner(self) -> (M, P, H) {
        let Self {
            inner,
            policy,
            hook,
            ..
        } = self;
        (inner, policy, hook)
    }

    /// Drop the demotion watermark for `conversation_id`.
    ///
    /// This is best effort: if the state lock is poisoned, it does nothing.
    pub fn forget(&self, conversation_id: &str) {
        let Ok(mut tracked) = self.state.lock() else {
            return;
        };
        tracked.remove(conversation_id);
    }

    /// Number of conversations with demotion state. Returns `0` if the state
    /// lock is poisoned.
    pub fn tracked_conversations(&self) -> usize {
        match self.state.lock() {
            Ok(tracked) => tracked.len(),
            Err(_) => 0,
        }
    }
}

impl<M, P, H> std::fmt::Debug for DemotingPolicyMemory<M, P, H>
where
    M: std::fmt::Debug,
    P: std::fmt::Debug,
{
    /// The hook is opaque and the state is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hook_placeholder = "<hook>";
        f.debug_struct("DemotingPolicyMemory")
            .field("inner", &self.inner)
            .field("policy", &self.policy)
            .field("hook", &hook_placeholder)
            .finish()
    }
}
/// Loads go through the policy. Demoted messages beyond the per-conversation
/// watermark are handed to the hook before the kept window is returned.
impl<M, P, H> ConversationMemory for DemotingPolicyMemory<M, P, H>
where
M: ConversationMemory,
P: MemoryPolicy,
H: DemotionHook,
{
/// Load the history, apply the policy, and deliver newly demoted messages.
///
/// # Errors
/// Returns errors from the inner store, the policy, a poisoned state lock, or
/// the hook. When the hook fails, the watermark is not advanced, so the same
/// messages are offered again on the next load.
///
/// NOTE(review): the watermark assumes the policy's demoted set is a prefix
/// of the stored history that only grows (true for the window policies in
/// this module). Confirm this for custom policies.
fn load<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
Box::pin(async move {
let messages = self.inner.load(conversation_id).await?;
let (kept, mut demoted) = self.policy.apply_with_demoted(messages)?;
let demoted_count = demoted.len();
// Under the lock, decide what to hand to the hook. Then reserve the
// conversation so a concurrent load does not deliver the same messages.
let (pending, reservation) = {
let mut guard = self.state.lock().map_err(poisoned)?;
if let Some(entry) = guard.get_mut(conversation_id) {
// Another load is already delivering for this conversation: skip
// demotion this time and just return the window.
if entry.in_flight.is_some() {
return Ok(kept);
}
// Everything demoted so far has already been delivered.
if entry.delivered >= demoted_count {
(Vec::new(), None)
} else {
// Only the part beyond the watermark is new.
let split = entry.delivered;
let reservation = Arc::new(());
entry.in_flight = Some(reservation.clone());
(demoted.split_off(split), Some(reservation))
}
} else if demoted_count == 0 {
// Unknown conversation with nothing demoted: no state entry needed.
(Vec::new(), None)
} else {
// First demotion for this conversation: deliver the whole prefix.
let reservation = Arc::new(());
guard.insert(
conversation_id.to_string(),
ConversationDemotionState {
delivered: 0,
in_flight: Some(reservation.clone()),
},
);
(std::mem::take(&mut demoted), Some(reservation))
}
};
let Some(reservation) = reservation else {
return Ok(kept);
};
// If this future is dropped while awaiting the hook, the guard clears our
// in-flight marker so later loads are not blocked forever.
let in_flight_guard =
DemotionInFlightGuard::new(&self.state, conversation_id, reservation.clone());
// The state lock is not held across this await.
let result = self.hook.on_demote(conversation_id, pending).await;
{
let mut guard = self.state.lock().map_err(poisoned)?;
// Touch the entry only if it still carries our reservation; `forget` /
// `clear` may have removed or replaced it in the meantime.
if let Some(entry) = guard.get_mut(conversation_id)
&& entry
.in_flight
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &reservation))
{
entry.in_flight = None;
// Advance the watermark only when the hook succeeded.
if result.is_ok() {
entry.delivered = demoted_count;
}
}
}
// Normal cleanup is done above; disarm so the guard's Drop does not repeat it.
in_flight_guard.disarm();
result?;
Ok(kept)
})
}
/// Forwarded unchanged to the inner store.
fn append<'a>(
&'a self,
conversation_id: &'a str,
messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
self.inner.append(conversation_id, messages)
}
/// Clear the inner store, then drop this conversation's watermark. If the
/// inner clear fails, the watermark is kept.
fn clear<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
Box::pin(async move {
self.inner.clear(conversation_id).await?;
self.forget(conversation_id);
Ok(())
})
}
}
/// Map a lock-poisoning error (or any displayable error) to
/// [`MemoryError::Internal`], keeping its message.
fn poisoned<E: std::fmt::Display>(err: E) -> MemoryError {
    let message = err.to_string();
    MemoryError::Internal(message)
}
struct DemotionInFlightGuard<'a> {
state: &'a StdMutex<HashMap<String, ConversationDemotionState>>,
key: &'a str,
reservation: InFlightReservation,
armed: bool,
}
impl<'a> DemotionInFlightGuard<'a> {
fn new(
state: &'a StdMutex<HashMap<String, ConversationDemotionState>>,
key: &'a str,
reservation: InFlightReservation,
) -> Self {
Self {
state,
key,
reservation,
armed: true,
}
}
fn disarm(mut self) {
self.armed = false;
}
}
impl Drop for DemotionInFlightGuard<'_> {
fn drop(&mut self) {
if !self.armed {
return;
}
if let Ok(mut guard) = self.state.lock()
&& let Some(entry) = guard.get_mut(self.key)
&& entry
.in_flight
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &self.reservation))
{
entry.in_flight = None;
}
}
}
struct InFlightGuard<'a, A> {
state: &'a StdMutex<HashMap<String, ConversationCompactionState<A>>>,
key: &'a str,
reservation: InFlightReservation,
armed: bool,
}
impl<'a, A> InFlightGuard<'a, A> {
fn new(
state: &'a StdMutex<HashMap<String, ConversationCompactionState<A>>>,
key: &'a str,
reservation: InFlightReservation,
) -> Self {
Self {
state,
key,
reservation,
armed: true,
}
}
fn disarm(mut self) {
self.armed = false;
}
}
impl<A> Drop for InFlightGuard<'_, A> {
fn drop(&mut self) {
if !self.armed {
return;
}
if let Ok(mut guard) = self.state.lock()
&& let Some(entry) = guard.get_mut(self.key)
&& entry
.in_flight
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &self.reservation))
{
entry.in_flight = None;
}
}
}
/// A [`ConversationMemory`] wrapper that folds demoted messages into a running
/// artifact produced by a [`Compactor`]. That artifact is prepended to each
/// loaded window.
pub struct CompactingMemory<M, P, C: Compactor> {
inner: M,
policy: P,
compactor: C,
// Per-conversation compaction bookkeeping, keyed by conversation id. Entries
// are removed only by `forget` / `clear`.
state: StdMutex<HashMap<String, ConversationCompactionState<C::Artifact>>>,
}
/// Compaction bookkeeping for a single conversation.
struct ConversationCompactionState<A> {
// Latest artifact produced by the compactor, if any.
summary: Option<A>,
// Number of demoted-prefix messages already folded into `summary`.
absorbed: usize,
// `Some` while a compaction is running for this conversation.
in_flight: Option<InFlightReservation>,
}
impl<M, P, C: Compactor> CompactingMemory<M, P, C> {
    /// Wrap `inner`, demoting with `policy` and summarising with `compactor`.
    /// Starts with no per-conversation state.
    pub fn new(inner: M, policy: P, compactor: C) -> Self {
        let state = StdMutex::new(HashMap::new());
        Self {
            inner,
            policy,
            compactor,
            state,
        }
    }

    /// Borrow the wrapped store.
    pub fn inner(&self) -> &M {
        &self.inner
    }

    /// Borrow the policy.
    pub fn policy(&self) -> &P {
        &self.policy
    }

    /// Borrow the compactor.
    pub fn compactor(&self) -> &C {
        &self.compactor
    }

    /// Split into `(inner, policy, compactor)`, discarding the compaction state.
    pub fn into_inner(self) -> (M, P, C) {
        let Self {
            inner,
            policy,
            compactor,
            ..
        } = self;
        (inner, policy, compactor)
    }

    /// Drop the stored summary and watermark for `conversation_id`.
    ///
    /// This is best effort: if the state lock is poisoned, it does nothing.
    pub fn forget(&self, conversation_id: &str) {
        let Ok(mut tracked) = self.state.lock() else {
            return;
        };
        tracked.remove(conversation_id);
    }

    /// Number of conversations with compaction state. Returns `0` if the state
    /// lock is poisoned.
    pub fn tracked_conversations(&self) -> usize {
        match self.state.lock() {
            Ok(tracked) => tracked.len(),
            Err(_) => 0,
        }
    }
}

impl<M, P, C> std::fmt::Debug for CompactingMemory<M, P, C>
where
    M: std::fmt::Debug,
    P: std::fmt::Debug,
    C: Compactor,
{
    /// The compactor is opaque and the state is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let compactor_placeholder = "<compactor>";
        f.debug_struct("CompactingMemory")
            .field("inner", &self.inner)
            .field("policy", &self.policy)
            .field("compactor", &compactor_placeholder)
            .finish()
    }
}
/// Loads go through the policy. Newly demoted messages are folded into the
/// conversation's artifact, which is prepended to the returned window.
impl<M, P, C> ConversationMemory for CompactingMemory<M, P, C>
where
M: ConversationMemory,
P: MemoryPolicy,
C: Compactor,
{
/// Load the history, apply the policy, compact newly demoted messages, and
/// return `[artifact?] + kept`.
///
/// # Errors
/// Returns errors from the inner store, the policy, a poisoned state lock, or
/// the compactor. When the compactor fails, the watermark and stored summary
/// are left unchanged, so the same messages are retried on the next load.
///
/// NOTE(review): while another compaction is in flight, the existing summary
/// is spliced in as is. Demoted messages it has not absorbed yet are then
/// neither in the summary nor in `kept` for that load.
fn load<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
Box::pin(async move {
let messages = self.inner.load(conversation_id).await?;
let (kept, demoted) = self.policy.apply_with_demoted(messages)?;
let demoted_count = demoted.len();
// Under the lock, decide whether to compact, and if so reserve the
// conversation and capture what the compactor needs.
let plan = {
let mut guard = self.state.lock().map_err(poisoned)?;
if let Some(entry) = guard.get_mut(conversation_id) {
// Another compaction is running; serve the current summary.
if entry.in_flight.is_some() {
return Ok(splice(entry.summary.clone(), kept));
}
// Nothing new has been demoted since the last successful compaction.
if demoted_count <= entry.absorbed {
return Ok(splice(entry.summary.clone(), kept));
}
let reservation = Arc::new(());
entry.in_flight = Some(reservation.clone());
CompactionPlan {
carry_over: entry.summary.clone(),
skip: entry.absorbed,
reservation,
}
} else if demoted_count == 0 {
// Unknown conversation with nothing demoted: no summary, no state.
return Ok(kept);
} else {
// First compaction for this conversation.
let reservation = Arc::new(());
guard.insert(
conversation_id.to_string(),
ConversationCompactionState {
summary: None,
absorbed: 0,
in_flight: Some(reservation.clone()),
},
);
CompactionPlan {
carry_over: None,
skip: 0,
reservation,
}
}
};
let CompactionPlan {
carry_over,
skip,
reservation,
} = plan;
// If this future is dropped while awaiting the compactor, the guard clears
// our in-flight marker.
let in_flight_guard =
InFlightGuard::new(&self.state, conversation_id, reservation.clone());
// `skip < demoted_count` follows from the checks above, so this is a
// defensive check rather than an expected path.
let new_slice = match demoted.get(skip..) {
Some(s) => s,
None => {
drop(in_flight_guard);
return Err(MemoryError::Internal(
"compaction watermark exceeds demoted slice length".into(),
));
}
};
// The state lock is not held across this await.
let result = self
.compactor
.compact(conversation_id, new_slice, carry_over.as_ref())
.await;
let summary_for_splice = match result {
Ok(artifact) => {
let mut guard = self.state.lock().map_err(poisoned)?;
if let Some(entry) = guard.get_mut(conversation_id) {
// Commit only if our reservation is still current. If `forget` /
// `clear` ran meanwhile, the fresh artifact is discarded and not
// spliced.
if entry
.in_flight
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &reservation))
{
entry.in_flight = None;
entry.absorbed = demoted_count;
entry.summary = Some(artifact.clone());
Some(artifact)
} else {
None
}
} else {
None
}
}
Err(err) => {
let mut guard = self.state.lock().map_err(poisoned)?;
// Release our reservation; the watermark and summary are unchanged.
if let Some(entry) = guard.get_mut(conversation_id)
&& entry
.in_flight
.as_ref()
.is_some_and(|current| Arc::ptr_eq(current, &reservation))
{
entry.in_flight = None;
}
// `guard` was declared after `in_flight_guard`, so it is dropped first
// on this return. The in-flight guard's Drop can then lock without
// deadlocking.
return Err(err);
}
};
// Normal cleanup is done above; disarm so the guard's Drop does not repeat it.
in_flight_guard.disarm();
Ok(splice(summary_for_splice, kept))
})
}
/// Forwarded unchanged to the inner store.
fn append<'a>(
&'a self,
conversation_id: &'a str,
messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
self.inner.append(conversation_id, messages)
}
/// Clear the inner store, then drop this conversation's summary and
/// watermark. If the inner clear fails, the state is kept.
fn clear<'a>(
&'a self,
conversation_id: &'a str,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
Box::pin(async move {
self.inner.clear(conversation_id).await?;
self.forget(conversation_id);
Ok(())
})
}
}
/// Inputs for one compaction run, captured while the state lock is held.
struct CompactionPlan<A> {
// Previous summary to fold the new messages into, if any.
carry_over: Option<A>,
// Number of leading demoted messages already absorbed; they are skipped.
skip: usize,
// Reservation identifying this run in the conversation's `in_flight` slot.
reservation: InFlightReservation,
}
/// Prepend `summary`, converted into a [`Message`], to `kept`. With no summary,
/// `kept` is returned as is.
fn splice<A>(summary: Option<A>, kept: Vec<Message>) -> Vec<Message>
where
    A: Into<Message>,
{
    let Some(artifact) = summary else {
        return kept;
    };
    // `once(..).chain(vec)` has an exact size hint, so `collect` allocates once.
    std::iter::once(artifact.into()).chain(kept).collect()
}
/// A model-free [`Compactor`] that renders demoted messages as text lines under
/// a fixed header line.
#[derive(Debug, Clone)]
pub struct TemplateCompactor {
    /// First line of every summary.
    header: String,
    /// Optional byte cap on the rendered summary; `None` means unbounded.
    max_bytes: Option<usize>,
}

impl TemplateCompactor {
    /// Compactor with the `"[Conversation summary so far]"` header and no cap.
    pub fn new() -> Self {
        Self::with_header("[Conversation summary so far]")
    }

    /// Compactor with a custom header line and no cap.
    pub fn with_header(header: impl Into<String>) -> Self {
        let header = header.into();
        Self {
            max_bytes: None,
            header,
        }
    }

    /// Truncate summaries longer than `max_bytes` bytes; `0` removes the cap.
    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
        self.max_bytes = Some(max_bytes).filter(|&cap| cap != 0);
        self
    }
}

impl Default for TemplateCompactor {
    /// Same as [`TemplateCompactor::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Plain-text summary artifact produced by `TemplateCompactor`.
#[derive(Debug, Clone)]
pub struct TextSummary(String);

impl TextSummary {
    /// Borrow the summary text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Take ownership of the summary text.
    pub fn into_string(self) -> String {
        let Self(text) = self;
        text
    }
}
impl From<TextSummary> for Message {
fn from(value: TextSummary) -> Self {
Message::System { content: value.0 }
}
}
impl Compactor for TemplateCompactor {
    type Artifact = TextSummary;

    /// Build a new summary from three parts, in order:
    /// 1. `self.header`.
    /// 2. The body of the previous summary (if any).
    /// 3. One `role: ...` line per evicted message that renders non-empty.
    ///
    /// The result is then truncated to `max_bytes` if a cap is set.
    ///
    /// Every summary this compactor produces starts with `header` + `'\n'` and
    /// ends with `'\n'`. The previous summary's header line is therefore
    /// stripped before carrying it over, and no extra newline is added after
    /// it. Otherwise each round would stack another header copy and a blank
    /// line.
    ///
    /// # Errors
    /// Never fails; the `Result` is required by [`Compactor`].
    fn compact<'a>(
        &'a self,
        _conversation_id: &'a str,
        evicted: &'a [Message],
        carry_over: Option<&'a Self::Artifact>,
    ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
        Box::pin(async move {
            let mut buf = String::new();
            buf.push_str(&self.header);
            buf.push('\n');
            if let Some(prev) = carry_over {
                let prev = prev.as_str();
                // Fall back to the full text if it does not start with our
                // exact header line.
                let prev_body = prev
                    .strip_prefix(self.header.as_str())
                    .and_then(|rest| rest.strip_prefix('\n'))
                    .unwrap_or(prev);
                buf.push_str(prev_body);
                if !prev_body.is_empty() && !prev_body.ends_with('\n') {
                    buf.push('\n');
                }
            }
            for msg in evicted {
                let line = render_message_line(msg);
                if !line.is_empty() {
                    buf.push_str(&line);
                    buf.push('\n');
                }
            }
            if let Some(cap) = self.max_bytes
                && buf.len() > cap
            {
                buf = truncate_summary(&buf, cap);
            }
            Ok(TextSummary(buf))
        })
    }
}
/// Shrink `buf` to about `cap` bytes.
///
/// The result keeps the first line (the header), then a truncation marker,
/// then as much of the *end* of the remaining text as still fits.
///
/// - The tail cut moves forward to the next UTF-8 character boundary, so no
///   character is ever split.
/// - If `cap` is smaller than the header line plus the marker, the result is
///   just those two and exceeds `cap`.
/// - Input with no newline, or with nothing after the first newline, is
///   returned unchanged.
fn truncate_summary(buf: &str, cap: usize) -> String {
    const MARKER: &str = "[\u{2026}truncated\u{2026}]\n";
    let Some((header, body)) = buf.split_once('\n') else {
        return buf.to_string();
    };
    if body.is_empty() {
        return buf.to_string();
    }
    // Bytes left for the tail after the header line, its newline and the marker.
    let room = cap.saturating_sub(header.len() + 1 + MARKER.len());
    let mut start = body.len().saturating_sub(room);
    while start < body.len() && !body.is_char_boundary(start) {
        start += 1;
    }
    let tail = body.get(start..).unwrap_or_default();
    [header, "\n", MARKER, tail].concat()
}
/// Render one message as a single `role: ...` summary line.
///
/// Text parts are copied verbatim. Other parts become placeholders:
/// `[tool result]`, `[tool call: NAME]`, `[reasoning]` and `[attachment]`.
/// Images in assistant messages were previously mislabeled as `[reasoning]`
/// and now render as `[attachment]`. Empty parts are skipped and the rest are
/// joined with single spaces. Returns an empty string when nothing is left to
/// render, so callers can skip the line.
fn render_message_line(msg: &Message) -> String {
    use rig_core::message::AssistantContent;
    use std::borrow::Cow;

    // Join the non-empty parts with spaces and prefix them with `role: `, or
    // return "" when no part has any text.
    fn labelled<'p>(role: &str, parts: impl Iterator<Item = Cow<'p, str>>) -> String {
        let body = parts
            .filter(|part| !part.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        if body.is_empty() {
            String::new()
        } else {
            format!("{role}: {body}")
        }
    }

    match msg {
        Message::System { content } => {
            labelled("system", std::iter::once(Cow::Borrowed(content.as_str())))
        }
        Message::User { content } => labelled(
            "user",
            content.iter().map(|part| match part {
                UserContent::Text(t) => Cow::Borrowed(t.text.as_str()),
                UserContent::ToolResult(_) => Cow::Borrowed("[tool result]"),
                UserContent::Image(_)
                | UserContent::Audio(_)
                | UserContent::Video(_)
                | UserContent::Document(_) => Cow::Borrowed("[attachment]"),
            }),
        ),
        Message::Assistant { content, .. } => labelled(
            "assistant",
            content.iter().map(|part| match part {
                AssistantContent::Text(t) => Cow::Borrowed(t.text.as_str()),
                AssistantContent::ToolCall(call) => {
                    Cow::Owned(format!("[tool call: {}]", call.function.name))
                }
                AssistantContent::Reasoning(_) => Cow::Borrowed("[reasoning]"),
                AssistantContent::Image(_) => Cow::Borrowed("[attachment]"),
            }),
        ),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rig_core::OneOrMany;
use rig_core::message::{
AssistantContent, ToolCall, ToolFunction, ToolResult, ToolResultContent, UserContent,
};
use std::sync::Mutex;
fn user(text: &str) -> Message {
Message::user(text)
}
fn assistant(text: &str) -> Message {
Message::assistant(text)
}
fn tool_call_msg() -> Message {
Message::Assistant {
id: None,
content: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new(
"call_1".into(),
ToolFunction::new("t".into(), serde_json::json!({})),
))),
}
}
fn tool_result_msg() -> Message {
Message::User {
content: OneOrMany::one(UserContent::ToolResult(ToolResult {
id: "call_1".into(),
call_id: None,
content: OneOrMany::one(ToolResultContent::text("ok")),
})),
}
}
#[test]
fn noop_policy_is_identity() {
let msgs = vec![user("a"), assistant("b")];
let out = NoopMemoryPolicy.apply(msgs).unwrap();
assert_eq!(out.len(), 2);
}
#[test]
fn sliding_window_passthrough_when_under_limit() {
let policy = SlidingWindowMemory::last_messages(5);
let out = policy.apply(vec![user("1"), assistant("2")]).unwrap();
assert_eq!(out.len(), 2);
}
#[tokio::test]
async fn sliding_window_truncates_via_filter() {
let mem = InMemoryConversationMemory::new()
.with_filter(SlidingWindowMemory::last_messages(2).into_filter());
mem.append(
"c",
vec![user("1"), assistant("2"), user("3"), assistant("4")],
)
.await
.unwrap();
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 2);
}
#[test]
fn sliding_window_drops_leading_orphan_tool_result() {
let policy = SlidingWindowMemory::last_messages(3);
let out = policy
.apply(vec![
tool_call_msg(),
tool_result_msg(),
user("after"),
assistant("done"),
])
.unwrap();
assert_eq!(out.len(), 2);
assert!(matches!(out.first(), Some(Message::User { content })
if matches!(content.first(), UserContent::Text(_))));
}
#[test]
fn token_window_keeps_within_budget() {
let msgs = vec![
user("aaaa"),
assistant("bbbb"),
user("cccc"),
assistant("dddd"),
];
let policy = TokenWindowMemory::new(2, |_: &Message| 1);
let out = policy.apply(msgs).unwrap();
assert_eq!(out.len(), 2);
}
#[test]
fn token_window_passes_through_when_under_budget() {
let msgs = vec![user("a"), assistant("b")];
let policy = TokenWindowMemory::new(usize::MAX, |_: &Message| 1);
let out = policy.apply(msgs).unwrap();
assert_eq!(out.len(), 2);
}
#[test]
fn token_window_drops_leading_orphan_tool_result() {
let policy = TokenWindowMemory::new(25, |_: &Message| 10);
let out = policy
.apply(vec![tool_call_msg(), tool_result_msg(), user("after")])
.unwrap();
assert_eq!(out.len(), 1);
assert!(matches!(out.first(), Some(Message::User { content })
if matches!(content.first(), UserContent::Text(_))));
}
#[test]
fn token_window_skips_message_larger_than_budget() {
let policy = TokenWindowMemory::new(5, |_: &Message| 10);
let out = policy.apply(vec![user("anything")]).unwrap();
assert!(out.is_empty());
}
#[test]
fn heuristic_counter_charges_overhead_per_message() {
let counter = HeuristicTokenCounter::default();
let empty = counter.count(&user(""));
assert!(
empty >= 4,
"default per-message overhead is at least 4 tokens"
);
}
#[test]
fn heuristic_counter_is_monotonic_in_text_length() {
let counter = HeuristicTokenCounter::default();
let small = counter.count(&user("hi"));
let big = counter.count(&user(&"x".repeat(400)));
assert!(big > small);
}
#[test]
fn heuristic_counter_handles_tool_calls() {
let counter = HeuristicTokenCounter::default();
let cost = counter.count(&tool_call_msg());
assert!(cost > 0);
}
#[test]
fn heuristic_counter_handles_system_messages() {
let counter = HeuristicTokenCounter::default();
let cost = counter.count(&Message::System {
content: "you are helpful".into(),
});
assert!(cost > 0);
}
#[test]
fn heuristic_counter_clamps_invalid_bytes_per_token() {
let counter = HeuristicTokenCounter::new(0.0, 0, 0);
assert!(counter.count(&user("abcd")) >= 4);
let nan = HeuristicTokenCounter::new(f32::NAN, 0, 0);
assert!(nan.count(&user("abcd")) >= 4);
}
#[test]
fn heuristic_counter_drives_token_window() {
let policy = TokenWindowMemory::new(100, HeuristicTokenCounter::default());
let msgs = vec![user(&"a".repeat(2_000)), user("short")];
let out = policy.apply(msgs).unwrap();
assert_eq!(out.len(), 1);
}
#[test]
fn arc_token_counter_can_drive_token_window() {
let counter: Arc<dyn TokenCounter> = Arc::new(|_: &Message| 1);
let policy = TokenWindowMemory::new(2, counter);
let out = policy
.apply(vec![user("a"), assistant("b"), user("c")])
.unwrap();
assert_eq!(out.len(), 2);
}
#[test]
fn boxed_token_counter_forwards_count() {
let counter: Box<dyn TokenCounter> = Box::new(|_: &Message| 7);
assert_eq!(counter.count(&user("a")), 7);
}
#[test]
fn into_filter_returns_input_on_policy_error() {
struct FailingPolicy;
impl MemoryPolicy for FailingPolicy {
fn apply(&self, _: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
Err(MemoryError::Policy("intentional failure".into()))
}
}
let filter = FailingPolicy.into_filter();
let input = vec![user("a"), assistant("b"), user("c")];
let out = filter(input.clone());
assert_eq!(
out.len(),
input.len(),
"history must be preserved on policy error"
);
}
#[tokio::test]
async fn policy_memory_truncates_loaded_history() {
let mem = PolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
);
mem.append(
"c",
vec![user("1"), assistant("2"), user("3"), assistant("4")],
)
.await
.unwrap();
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 2);
}
#[tokio::test]
async fn policy_memory_propagates_policy_errors() {
struct FailingPolicy;
impl MemoryPolicy for FailingPolicy {
fn apply(&self, _: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
Err(MemoryError::Policy("intentional failure".into()))
}
}
let mem = PolicyMemory::new(InMemoryConversationMemory::new(), FailingPolicy);
mem.append("c", vec![user("1"), assistant("2")])
.await
.unwrap();
let result = mem.load("c").await;
assert!(matches!(result, Err(MemoryError::Policy(_))));
}
#[tokio::test]
async fn policy_memory_append_and_clear_delegate_to_inner() {
let mem = PolicyMemory::new(InMemoryConversationMemory::new(), NoopMemoryPolicy);
mem.append("c", vec![user("hi"), assistant("ok")])
.await
.unwrap();
assert_eq!(mem.load("c").await.unwrap().len(), 2);
mem.clear("c").await.unwrap();
assert!(mem.load("c").await.unwrap().is_empty());
}
#[test]
fn sliding_window_reports_demoted_prefix() {
let policy = SlidingWindowMemory::last_messages(2);
let (kept, demoted) = policy
.apply_with_demoted(vec![
user("oldest"),
assistant("old"),
user("recent"),
assistant("latest"),
])
.unwrap();
assert_eq!(kept.len(), 2);
assert_eq!(demoted.len(), 2);
}
#[test]
fn token_window_reports_demoted_prefix() {
let policy = TokenWindowMemory::new(2, |_: &Message| 1);
let (kept, demoted) = policy
.apply_with_demoted(vec![user("a"), assistant("b"), user("c"), assistant("d")])
.unwrap();
assert_eq!(kept.len(), 2);
assert_eq!(demoted.len(), 2);
}
#[test]
fn noop_policy_demotes_nothing() {
let (kept, demoted) = NoopMemoryPolicy
.apply_with_demoted(vec![user("a"), assistant("b")])
.unwrap();
assert_eq!(kept.len(), 2);
assert!(demoted.is_empty());
}
#[test]
fn arc_memory_policy_preserves_demoted_metadata() {
let policy: Arc<dyn MemoryPolicy> = Arc::new(SlidingWindowMemory::last_messages(1));
let (kept, demoted) = policy
.apply_with_demoted(vec![user("old"), assistant("new")])
.unwrap();
assert_eq!(kept.len(), 1);
assert_eq!(demoted.len(), 1);
}
#[test]
fn boxed_memory_policy_preserves_demoted_metadata() {
let policy: Box<dyn MemoryPolicy> = Box::new(SlidingWindowMemory::last_messages(1));
let (kept, demoted) = policy
.apply_with_demoted(vec![user("old"), assistant("new")])
.unwrap();
assert_eq!(kept.len(), 1);
assert_eq!(demoted.len(), 1);
}
#[test]
fn sliding_window_demotes_orphan_tool_result_with_prefix() {
let policy = SlidingWindowMemory::last_messages(2);
let (kept, demoted) = policy
.apply_with_demoted(vec![
tool_call_msg(),
tool_result_msg(),
user("after"),
assistant("done"),
])
.unwrap();
assert_eq!(kept.len(), 2);
assert!(matches!(kept.first(), Some(Message::User { content })
if matches!(content.first(), UserContent::Text(_))));
assert_eq!(demoted.len(), 2);
}
#[derive(Default)]
struct CountingHook {
seen: Mutex<Vec<(String, Vec<Message>)>>,
}
impl CountingHook {
fn calls(&self) -> usize {
self.seen.lock().unwrap().len()
}
fn last_demoted_count(&self) -> usize {
self.seen
.lock()
.unwrap()
.last()
.map(|(_, m)| m.len())
.unwrap_or(0)
}
}
impl DemotionHook for CountingHook {
fn on_demote<'a>(
&'a self,
conversation_id: &'a str,
messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
Box::pin(async move {
self.seen
.lock()
.unwrap()
.push((conversation_id.to_string(), messages));
Ok(())
})
}
}
#[tokio::test]
async fn demoting_policy_memory_invokes_hook_on_truncation() {
let hook = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
hook.clone(),
);
mem.append(
"c",
vec![user("1"), assistant("2"), user("3"), assistant("4")],
)
.await
.unwrap();
let kept = mem.load("c").await.unwrap();
assert_eq!(kept.len(), 2);
assert_eq!(hook.calls(), 1);
assert_eq!(hook.last_demoted_count(), 2);
}
#[tokio::test]
async fn demoting_policy_memory_does_not_replay_demotions() {
let hook = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
hook.clone(),
);
mem.append(
"c",
vec![user("1"), assistant("2"), user("3"), assistant("4")],
)
.await
.unwrap();
mem.load("c").await.unwrap();
mem.load("c").await.unwrap();
assert_eq!(hook.calls(), 1);
assert_eq!(hook.last_demoted_count(), 2);
}
#[tokio::test]
async fn demoting_policy_memory_only_reports_newly_demoted_messages() {
let hook = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
hook.clone(),
);
mem.append(
"c",
vec![user("1"), assistant("2"), user("3"), assistant("4")],
)
.await
.unwrap();
mem.load("c").await.unwrap();
mem.append("c", vec![user("5")]).await.unwrap();
mem.load("c").await.unwrap();
assert_eq!(hook.calls(), 2);
assert_eq!(hook.last_demoted_count(), 1);
}
#[derive(Default)]
struct FailingHook {
calls: Mutex<usize>,
}
impl DemotionHook for FailingHook {
fn on_demote<'a>(
&'a self,
_conversation_id: &'a str,
_messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
Box::pin(async move {
*self.calls.lock().unwrap() += 1;
Err(MemoryError::backend(std::io::Error::other("hook failed")))
})
}
}
/// When the hook fails, the load fails and the watermark stays put, so every
/// subsequent load retries the hook.
#[tokio::test]
async fn demoting_policy_memory_does_not_advance_watermark_on_hook_failure() {
let failing = Arc::new(FailingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&failing),
);
mem.append("c", vec![user("1"), assistant("2")]).await.unwrap();
for _ in 0..2 {
assert!(mem.load("c").await.is_err());
}
let attempts = *failing.calls.lock().unwrap();
assert_eq!(attempts, 2);
}
/// Clearing a conversation also resets its demotion watermark, so history
/// appended afterwards is demoted and reported again.
#[tokio::test]
async fn demoting_policy_memory_clear_resets_watermark() {
let counting = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&counting),
);
// First generation of the conversation (window of 1 over 2 messages).
mem.append("c", vec![user("1"), assistant("2")]).await.unwrap();
mem.load("c").await.unwrap();
mem.clear("c").await.unwrap();
// Second generation: its demotion must be reported again after the clear.
mem.append("c", vec![user("3"), assistant("4")]).await.unwrap();
mem.load("c").await.unwrap();
assert_eq!(counting.calls(), 2);
assert_eq!(counting.last_demoted_count(), 1);
}
/// A window wider than the stored history evicts nothing, so the hook is
/// never called.
#[tokio::test]
async fn demoting_policy_memory_skips_hook_when_nothing_evicted() {
let counting = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(10),
Arc::clone(&counting),
);
mem.append("c", vec![user("1"), assistant("2")]).await.unwrap();
mem.load("c").await.unwrap();
assert_eq!(counting.calls(), 0);
}
/// With `NoopDemotionHook`, loading applies only the window policy.
#[tokio::test]
async fn demoting_policy_memory_with_noop_hook_behaves_like_policy_memory() {
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
NoopDemotionHook,
);
let history = vec![user("a"), assistant("b"), user("c")];
mem.append("c", history).await.unwrap();
let kept = mem.load("c").await.unwrap();
assert_eq!(kept.len(), 1);
}
/// Demotion hook that parks inside `on_demote` until released, so tests can
/// interleave other operations with an in-flight demotion.
struct GatedHook {
/// Incremented once per `on_demote` invocation.
calls: Arc<std::sync::atomic::AtomicUsize>,
/// Notified after `calls` has been incremented for an invocation.
rendezvous: Arc<tokio::sync::Notify>,
/// Awaited by each invocation before it returns `Ok(())`.
release: Arc<tokio::sync::Notify>,
}
/// Counts the call, signals `rendezvous`, then waits on `release` before succeeding.
impl DemotionHook for GatedHook {
fn on_demote<'a>(
&'a self,
_conversation_id: &'a str,
_messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
// Clone the shared handles so the `async move` block owns its own copies.
let calls = self.calls.clone();
let rendezvous = self.rendezvous.clone();
let release = self.release.clone();
Box::pin(async move {
// Count before signalling so a test woken by `rendezvous` already
// observes the incremented value.
calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
rendezvous.notify_one();
// Park until the test explicitly lets this invocation finish.
release.notified().await;
Ok(())
})
}
}
/// Two overlapping loads of one conversation must fire the demotion hook only
/// once. While the first load is parked inside the hook, a second load must
/// complete without entering it.
#[tokio::test]
async fn demoting_policy_memory_serialises_concurrent_loads() {
use std::sync::atomic::{AtomicUsize, Ordering};
let calls = Arc::new(AtomicUsize::new(0));
let rendezvous = Arc::new(tokio::sync::Notify::new());
let release = Arc::new(tokio::sync::Notify::new());
let hook = GatedHook {
calls: calls.clone(),
rendezvous: rendezvous.clone(),
release: release.clone(),
};
let mem = Arc::new(DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
hook,
));
mem.append("c", vec![user("1"), assistant("2"), user("3")])
.await
.unwrap();
let m1 = mem.clone();
let first = tokio::spawn(async move { m1.load("c").await });
// Wait until the first load is inside the hook, blocked on `release`.
rendezvous.notified().await;
assert_eq!(calls.load(Ordering::SeqCst), 1);
// This load overlaps the in-flight demotion and must not re-enter the hook.
let kept = mem.load("c").await.unwrap();
assert_eq!(kept.len(), 1);
assert_eq!(calls.load(Ordering::SeqCst), 1, "hook must not double-fire");
// Let the first load's hook finish.
release.notify_one();
let kept_first = first.await.unwrap().unwrap();
assert_eq!(kept_first.len(), 1);
assert_eq!(calls.load(Ordering::SeqCst), 1);
// After the demotion is committed, a further load must not call the hook again.
mem.load("c").await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// Aborting a load while it is parked in the hook must release the in-flight
/// demotion reservation. A later load must then re-enter the hook, taking the
/// call count to 2.
#[tokio::test]
async fn demoting_policy_memory_dropped_load_releases_in_flight_gate() {
use std::sync::atomic::{AtomicUsize, Ordering};
let calls = Arc::new(AtomicUsize::new(0));
let rendezvous = Arc::new(tokio::sync::Notify::new());
let release = Arc::new(tokio::sync::Notify::new());
// `rendezvous` is moved into the hook and never awaited here; this test
// polls `calls` instead.
let hook = GatedHook {
calls: calls.clone(),
rendezvous,
release: release.clone(),
};
let mem = Arc::new(DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
hook,
));
mem.append("c", vec![user("1"), assistant("2"), user("3")])
.await
.unwrap();
let mem_load = mem.clone();
let handle = tokio::spawn(async move { mem_load.load("c").await });
// Yield until the spawned load has entered the hook.
while calls.load(Ordering::SeqCst) == 0 {
tokio::task::yield_now().await;
}
// Cancel the load mid-hook and wait for the abort to take effect.
handle.abort();
let _ = handle.await;
let mem_load = mem.clone();
let retry = tokio::spawn(async move { mem_load.load("c").await });
// Bounded wait, so the test fails with a message instead of hanging if the
// retry never reaches the hook.
for _ in 0..1_000 {
if calls.load(Ordering::SeqCst) >= 2 {
break;
}
tokio::task::yield_now().await;
}
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"retry must re-enter the hook after cancellation"
);
release.notify_one();
let kept = retry.await.unwrap().unwrap();
assert_eq!(kept.len(), 1);
// Once the retry's demotion has succeeded, a further load must not call the hook.
mem.load("c").await.unwrap();
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// A stale load starts before `clear` and is cancelled afterwards. Its cleanup
/// must not release the reservation held by a load of the fresh history;
/// otherwise a third concurrent load would re-enter the hook.
#[tokio::test]
async fn demoting_stale_cancelled_load_does_not_clear_new_reservation() {
use std::sync::atomic::{AtomicUsize, Ordering};
let calls = Arc::new(AtomicUsize::new(0));
let rendezvous = Arc::new(tokio::sync::Notify::new());
let release = Arc::new(tokio::sync::Notify::new());
let hook = GatedHook {
calls: calls.clone(),
rendezvous: rendezvous.clone(),
release: release.clone(),
};
let mem = Arc::new(DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
hook,
));
mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
.await
.unwrap();
let mem_load = mem.clone();
let stale = tokio::spawn(async move { mem_load.load("c").await });
// The stale load is now parked inside the hook (call 1).
rendezvous.notified().await;
assert_eq!(calls.load(Ordering::SeqCst), 1);
mem.clear("c").await.unwrap();
mem.append(
"c",
vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
)
.await
.unwrap();
let mem_load = mem.clone();
let fresh = tokio::spawn(async move { mem_load.load("c").await });
// The fresh load is now parked inside the hook too (call 2).
rendezvous.notified().await;
assert_eq!(calls.load(Ordering::SeqCst), 2);
// Cancel the stale load; its cleanup runs while the fresh one is still in flight.
stale.abort();
let _ = stale.await;
let mem_load = mem.clone();
let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
// The third load must finish without entering the hook. A rendezvous signal
// here would mean the fresh reservation was wrongly cleared.
let concurrent_kept = tokio::select! {
result = &mut concurrent => result.unwrap().unwrap(),
_ = rendezvous.notified() => {
panic!("stale guard must not clear the fresh in-flight reservation")
}
};
assert_eq!(
calls.load(Ordering::SeqCst),
2,
"stale guard must not clear the fresh in-flight reservation"
);
// Only the fresh load is still waiting on `release` at this point.
release.notify_one();
assert_eq!(fresh.await.unwrap().unwrap().len(), 1);
assert_eq!(concurrent_kept.len(), 1);
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// A stale load that completes successfully after `clear` must not release the
/// reservation held by a load of the fresh history. A third concurrent load
/// must still skip the hook while the fresh demotion is in flight.
#[tokio::test]
async fn demoting_stale_successful_load_does_not_clear_new_reservation() {
// Hook that gives every invocation its own release handle, so individual
// calls can be unblocked independently by call index.
#[derive(Default)]
struct IndividuallyGatedHook {
// One release handle per invocation, in call order; the length is the call count.
releases: Mutex<Vec<Arc<tokio::sync::Notify>>>,
}
impl IndividuallyGatedHook {
fn call_count(&self) -> usize {
self.releases.lock().unwrap().len()
}
// Cooperatively yield until at least `expected` invocations have been registered.
async fn wait_for_call_count(&self, expected: usize) {
while self.call_count() < expected {
tokio::task::yield_now().await;
}
}
// Unblock the invocation at position `index` (panics if it does not exist).
fn release_call(&self, index: usize) {
let release = self.releases.lock().unwrap()[index].clone();
release.notify_one();
}
}
impl DemotionHook for IndividuallyGatedHook {
fn on_demote<'a>(
&'a self,
_conversation_id: &'a str,
_messages: Vec<Message>,
) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
// The handle is registered synchronously, before the returned future
// is polled, so `call_count` reflects the call immediately.
let release = Arc::new(tokio::sync::Notify::new());
self.releases.lock().unwrap().push(release.clone());
Box::pin(async move {
release.notified().await;
Ok(())
})
}
}
let hook = Arc::new(IndividuallyGatedHook::default());
let mem = Arc::new(DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
hook.clone(),
));
mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
.await
.unwrap();
let mem_load = mem.clone();
let stale = tokio::spawn(async move { mem_load.load("c").await });
// The stale load is parked in hook call 0.
hook.wait_for_call_count(1).await;
mem.clear("c").await.unwrap();
mem.append(
"c",
vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
)
.await
.unwrap();
let mem_load = mem.clone();
let fresh = tokio::spawn(async move { mem_load.load("c").await });
// The fresh load is parked in hook call 1.
hook.wait_for_call_count(2).await;
// Let only the stale load finish, successfully.
hook.release_call(0);
assert_eq!(stale.await.unwrap().unwrap().len(), 1);
assert_eq!(hook.call_count(), 2);
let mem_load = mem.clone();
let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
let hook_wait = hook.clone();
// The third load must finish without starting a third hook call.
let concurrent_kept = tokio::select! {
result = &mut concurrent => result.unwrap().unwrap(),
_ = hook_wait.wait_for_call_count(3) => {
panic!("stale successful load must not clear the fresh in-flight reservation")
}
};
assert_eq!(
hook.call_count(),
2,
"stale successful load must not clear the fresh in-flight reservation"
);
// Now let the fresh load's hook call complete.
hook.release_call(1);
assert_eq!(fresh.await.unwrap().unwrap().len(), 1);
assert_eq!(concurrent_kept.len(), 1);
mem.load("c").await.unwrap();
assert_eq!(hook.call_count(), 2);
}
/// `forget` drops the in-process watermark for a conversation, so the next
/// load reports the demotion to the hook again.
#[tokio::test]
async fn forget_drops_in_process_watermark() {
let counting = Arc::new(CountingHook::default());
let mem = DemotingPolicyMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&counting),
);
mem.append("c", vec![user("1"), assistant("2")]).await.unwrap();
mem.load("c").await.unwrap();
assert_eq!((mem.tracked_conversations(), counting.calls()), (1, 1));
mem.forget("c");
assert_eq!(mem.tracked_conversations(), 0);
mem.load("c").await.unwrap();
assert_eq!(counting.calls(), 2);
}
/// When nothing is evicted, no summary is spliced in: the load returns just the
/// kept messages, starting with the user turn.
#[tokio::test]
async fn compacting_no_demotion_returns_kept_only() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(10),
TemplateCompactor::new(),
);
let turns = vec![user("hi"), assistant("hello")];
mem.append("c", turns).await.unwrap();
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 2);
assert!(matches!(loaded.first(), Some(Message::User { .. })));
}
/// Evicted turns are rendered into a leading system summary, followed by the
/// kept messages in order.
#[tokio::test]
async fn compacting_splices_summary_when_demoted() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
TemplateCompactor::new(),
);
let history = vec![
user("first"),
assistant("second"),
user("third"),
assistant("fourth"),
];
mem.append("c", history).await.unwrap();
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 3);
let summary = match &loaded[0] {
Message::System { content } => content,
_ => panic!("expected summary as system message"),
};
for fragment in [
"[Conversation summary so far]",
"user: first",
"assistant: second",
] {
assert!(summary.contains(fragment));
}
let kept_user = match &loaded[1] {
Message::User { content } => content,
_ => panic!("expected kept user message"),
};
let text = match kept_user.first_ref() {
UserContent::Text(t) => t,
_ => panic!("expected text"),
};
assert_eq!(text.text, "third");
}
/// A second demotion folds the previous summary (the carry-over) into the new
/// one, together with the newly evicted turns.
#[tokio::test]
async fn compacting_rolls_summary_forward() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
TemplateCompactor::new(),
);
mem.append("c", vec![user("a"), assistant("b"), user("c"), assistant("d")])
.await
.unwrap();
let first = mem.load("c").await.unwrap();
let first_summary = match &first[0] {
Message::System { content } => content.clone(),
_ => panic!("summary missing"),
};
assert!(first_summary.contains("user: a"));
assert!(first_summary.contains("assistant: b"));
mem.append("c", vec![user("e"), assistant("f")]).await.unwrap();
let second = mem.load("c").await.unwrap();
let rolled = match &second[0] {
Message::System { content } => content,
_ => panic!("summary missing"),
};
assert!(rolled.contains(&first_summary));
assert!(rolled.contains("user: c"));
assert!(rolled.contains("assistant: d"));
}
#[tokio::test]
async fn compacting_idempotent_within_process() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
TemplateCompactor::new(),
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
let first = mem.load("c").await.unwrap();
let second = mem.load("c").await.unwrap();
assert_eq!(first.len(), second.len());
let Message::System { content: c1 } = &first[0] else {
panic!()
};
let Message::System { content: c2 } = &second[0] else {
panic!()
};
assert_eq!(c1, c2);
}
/// `clear` drops the tracked summary state along with the stored history.
#[tokio::test]
async fn compacting_clear_drops_summary() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
TemplateCompactor::new(),
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
mem.load("c").await.unwrap();
assert_eq!(mem.tracked_conversations(), 1);
mem.clear("c").await.unwrap();
assert_eq!(mem.tracked_conversations(), 0);
let after_clear = mem.load("c").await.unwrap();
assert!(after_clear.is_empty());
}
/// Compactor that fails on its first call and succeeds on every later call.
#[derive(Default)]
struct FlakyCompactor {
/// Number of `compact` calls made so far.
calls: std::sync::atomic::AtomicUsize,
}
/// The first attempt returns a policy error; later attempts summarise by count.
impl Compactor for FlakyCompactor {
type Artifact = TextSummary;
fn compact<'a>(
&'a self,
_conversation_id: &'a str,
evicted: &'a [Message],
_carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
Box::pin(async move {
let attempt = self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
match attempt {
0 => Err(MemoryError::Policy("flaky".into())),
_ => Ok(TextSummary(format!("compacted {} messages", evicted.len()))),
}
})
}
}
/// A failed compaction surfaces its error without advancing the watermark, so
/// the next load retries compaction and splices in the summary.
#[tokio::test]
async fn compacting_failure_does_not_advance_watermark() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
FlakyCompactor::default(),
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
let first_attempt = mem.load("c").await;
assert!(matches!(first_attempt, Err(MemoryError::Policy(_))));
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 2);
let Message::System { content: summary } = &loaded[0] else {
panic!("expected summary")
};
assert!(summary.contains("compacted"));
}
/// Compactor that logs `(evicted.len(), carry_over.is_some())` for every call.
#[derive(Default)]
struct CountingCompactor {
/// One entry per `compact` call, in call order.
log: Mutex<Vec<(usize, bool)>>,
}
impl CountingCompactor {
/// Snapshot of every `(evicted count, had carry-over)` pair recorded so far.
fn calls(&self) -> Vec<(usize, bool)> {
let log = self.log.lock().unwrap();
log.to_vec()
}
}
/// Logs each call and produces `"{previous}|{evicted count}"` as the summary.
impl Compactor for CountingCompactor {
type Artifact = TextSummary;
fn compact<'a>(
&'a self,
_conversation_id: &'a str,
evicted: &'a [Message],
carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
Box::pin(async move {
let evicted_count = evicted.len();
self.log
.lock()
.unwrap()
.push((evicted_count, carry_over.is_some()));
let previous = match carry_over {
Some(summary) => summary.as_str(),
None => "",
};
Ok(TextSummary(format!("{previous}|{evicted_count}")))
})
}
}
/// With nothing evicted, the compactor is never called and no per-conversation
/// state is tracked, however many loads happen.
#[tokio::test]
async fn compacting_no_demotion_does_not_invoke_compactor() {
let counting = Arc::new(CountingCompactor::default());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(10),
Arc::clone(&counting),
);
mem.append("c", vec![user("a"), assistant("b")]).await.unwrap();
for _ in 0..3 {
mem.load("c").await.unwrap();
}
assert!(counting.calls().is_empty());
assert_eq!(mem.tracked_conversations(), 0);
}
/// The compactor runs once per batch of newly evicted messages. Later
/// compactions receive the previous summary as carry-over.
#[tokio::test]
async fn compacting_invokes_compactor_only_on_new_demotions() {
let counting = Arc::new(CountingCompactor::default());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
Arc::clone(&counting),
);
mem.append("c", vec![user("a"), assistant("b"), user("c"), assistant("d")])
.await
.unwrap();
for _ in 0..3 {
mem.load("c").await.unwrap();
}
let calls = counting.calls();
assert_eq!(
calls.len(),
1,
"compactor invoked more than once: {calls:?}"
);
assert_eq!(calls[0], (2, false));
mem.append("c", vec![user("e"), assistant("f")]).await.unwrap();
for _ in 0..2 {
mem.load("c").await.unwrap();
}
let calls = counting.calls();
assert_eq!(calls.len(), 2, "expected exactly one new call: {calls:?}");
assert_eq!(calls[1], (2, true));
}
/// Thirty-two concurrent loads of the same conversation must trigger exactly
/// one compaction.
#[tokio::test]
async fn compacting_serialises_concurrent_loads() {
let counting = Arc::new(CountingCompactor::default());
let mem = Arc::new(CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
Arc::clone(&counting),
));
mem.append("c", vec![user("a"), assistant("b"), user("c"), assistant("d")])
.await
.unwrap();
let handles: Vec<_> = (0..32)
.map(|_| {
let mem = Arc::clone(&mem);
tokio::spawn(async move {
mem.load("c").await.unwrap();
})
})
.collect();
for handle in handles {
handle.await.unwrap();
}
let calls = counting.calls();
assert_eq!(calls.len(), 1, "expected exactly 1 call: {calls:?}");
}
/// After `clear`, the next compaction starts with no carry-over summary.
#[tokio::test]
async fn compacting_clear_drops_summary_carry_over() {
let counting = Arc::new(CountingCompactor::default());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&counting),
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
mem.load("c").await.unwrap();
assert_eq!(counting.calls()[0], (2, false));
mem.clear("c").await.unwrap();
assert_eq!(mem.tracked_conversations(), 0);
mem.append("c", vec![user("x"), assistant("y"), user("z")])
.await
.unwrap();
mem.load("c").await.unwrap();
let log = counting.calls();
assert_eq!(log.len(), 2);
assert_eq!(log[1], (2, false));
}
/// `forget` drops tracked summary state. The next load compacts the same
/// evicted messages again, without carry-over.
#[tokio::test]
async fn compacting_forget_drops_summary() {
let counting = Arc::new(CountingCompactor::default());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&counting),
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
mem.load("c").await.unwrap();
assert_eq!(mem.tracked_conversations(), 1);
mem.forget("c");
assert_eq!(mem.tracked_conversations(), 0);
mem.load("c").await.unwrap();
let log = counting.calls();
assert_eq!(log.len(), 2);
assert_eq!(log[1], (2, false));
}
/// A type-erased `Arc<dyn Compactor>` is accepted as the compactor.
#[tokio::test]
async fn compacting_arc_compactor_works() {
let erased: Arc<dyn Compactor<Artifact = TextSummary>> =
Arc::new(TemplateCompactor::new());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
erased,
);
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
let loaded = mem.load("c").await.unwrap();
assert_eq!(loaded.len(), 2);
assert!(matches!(loaded.first(), Some(Message::System { .. })));
}
#[tokio::test]
async fn compacting_into_inner_returns_components() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
TemplateCompactor::new(),
);
let (_inner, _policy, _compactor) = mem.into_inner();
}
/// Each conversation id is compacted and tracked independently.
#[tokio::test]
async fn compacting_isolates_conversations() {
let counting = Arc::new(CountingCompactor::default());
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
Arc::clone(&counting),
);
for (id, turns) in [
("a", vec![user("a1"), assistant("a2"), user("a3")]),
("b", vec![user("b1"), assistant("b2"), user("b3")]),
] {
mem.append(id, turns).await.unwrap();
}
let a = mem.load("a").await.unwrap();
let b = mem.load("b").await.unwrap();
assert_eq!((a.len(), b.len()), (2, 2));
assert_eq!(counting.calls().len(), 2);
assert_eq!(mem.tracked_conversations(), 2);
}
/// `CompactingMemory` composes with a `TokenWindowMemory` policy
/// (`TokenWindowMemory::new(30, ..)`): overflowing history yields a leading
/// summary message.
#[tokio::test]
async fn compacting_composes_with_token_window() {
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
TokenWindowMemory::new(30, HeuristicTokenCounter::openai()),
TemplateCompactor::new(),
);
let history = vec![
user("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
assistant("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"),
user("cccccccccccccccccccc"),
assistant("d"),
];
mem.append("c", history).await.unwrap();
let loaded = mem.load("c").await.unwrap();
assert!(loaded.len() >= 2);
assert!(matches!(loaded.first(), Some(Message::System { .. })));
}
/// System messages in the evicted slice are rendered with a `system:` prefix
/// alongside user and assistant turns.
#[tokio::test]
async fn template_compactor_renders_system_messages() {
let compactor = TemplateCompactor::new();
let system = Message::System {
content: "you are helpful".into(),
};
let evicted = [system, user("hi"), assistant("hello")];
let summary = compactor.compact("c", &evicted, None).await.unwrap();
let s = summary.as_str();
assert!(s.contains("system: you are helpful"), "got: {s}");
for line in ["user: hi", "assistant: hello"] {
assert!(s.contains(line));
}
}
/// Tool calls and tool results are rendered as bracketed markers rather than
/// raw payloads.
#[tokio::test]
async fn template_compactor_renders_tool_call_marker() {
let evicted = [tool_call_msg(), tool_result_msg()];
let summary = TemplateCompactor::new()
.compact("c", &evicted, None)
.await
.unwrap();
let s = summary.as_str();
for marker in ["[tool call: t]", "[tool result]"] {
assert!(s.contains(marker), "got: {s}");
}
}
/// A previous summary passed as carry-over is embedded verbatim in the next one.
#[tokio::test]
async fn template_compactor_carry_over_threaded() {
let compactor = TemplateCompactor::new();
let first = compactor
.compact("c", &[user("hello")], None)
.await
.unwrap();
let first_text = first.as_str();
assert!(!first_text.is_empty());
let second = compactor
.compact("c", &[assistant("world")], Some(&first))
.await
.unwrap();
let second_text = second.as_str();
assert!(second_text.contains(first_text));
assert!(second_text.contains("assistant: world"));
}
/// Converting a `TextSummary` into a `Message` yields a system message that
/// carries the summary text unchanged.
#[tokio::test]
async fn template_compactor_artifact_into_message() {
let msg: Message = TextSummary("rolled-up text".into()).into();
match msg {
Message::System { content } => assert_eq!(content, "rolled-up text"),
_ => panic!("expected system message"),
}
}
#[tokio::test]
async fn template_compactor_caps_summary_at_max_bytes() {
let cap = 256;
let compactor = TemplateCompactor::new().with_max_bytes(cap);
let mut evicted = Vec::new();
for i in 0..50 {
evicted.push(user(&format!("message number {i} with some filler")));
}
let summary = compactor.compact("c", &evicted, None).await.unwrap();
assert!(
summary.as_str().len()
<= cap + "[Conversation summary so far]\n[\u{2026}truncated\u{2026}]\n".len(),
"summary len {} exceeds cap {} (plus header+marker)",
summary.as_str().len(),
cap,
);
assert!(
summary
.as_str()
.starts_with("[Conversation summary so far]\n")
);
assert!(summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
assert!(summary.as_str().contains("message number 49"));
}
#[tokio::test]
async fn template_compactor_unbounded_by_default() {
let compactor = TemplateCompactor::new();
let mut evicted = Vec::new();
for i in 0..200 {
evicted.push(user(&format!("msg {i}")));
}
let summary = compactor.compact("c", &evicted, None).await.unwrap();
assert!(!summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
assert!(summary.as_str().contains("msg 0"));
assert!(summary.as_str().contains("msg 199"));
}
/// `with_max_bytes(0)` means "no cap": the summary is not truncated and keeps
/// both the oldest and newest evicted messages, as in the default
/// (unbounded) configuration.
///
/// Checking only that the marker is absent would miss a regression that drops
/// content without emitting the marker, so the content itself is asserted too.
#[tokio::test]
async fn template_compactor_with_max_bytes_zero_is_unbounded() {
let compactor = TemplateCompactor::new().with_max_bytes(0);
let mut evicted = Vec::new();
for i in 0..200 {
evicted.push(user(&format!("msg {i}")));
}
let summary = compactor.compact("c", &evicted, None).await.unwrap();
assert!(!summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
assert!(summary.as_str().contains("msg 0"));
assert!(summary.as_str().contains("msg 199"));
}
/// Rolling a capped summary forward 30 times keeps it within the cap plus
/// header/marker slack instead of letting it grow without bound.
#[tokio::test]
async fn compacting_summary_stays_bounded_across_rolls() {
let cap = 512;
let slack = "[Conversation summary so far]\n[\u{2026}truncated\u{2026}]\n".len();
let mem = CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(2),
TemplateCompactor::new().with_max_bytes(cap),
);
mem.append("c", vec![user("seed-a"), assistant("seed-b")])
.await
.unwrap();
for round in 0..30 {
let turns = vec![
user(&format!("user line {round} ----- padding padding padding")),
assistant(&format!("assistant line {round} ----- padding padding")),
];
mem.append("c", turns).await.unwrap();
mem.load("c").await.unwrap();
}
let loaded = mem.load("c").await.unwrap();
let summary = match &loaded[0] {
Message::System { content } => content,
_ => panic!("expected summary"),
};
assert!(
summary.len() <= cap + slack,
"summary grew to {} bytes (cap {}, slack {})",
summary.len(),
cap,
slack,
);
}
/// A `clear` issued while a compaction is in flight must win. When the late
/// compaction finishes, no tracked state may be left for the conversation, and
/// the cleared history stays empty.
#[tokio::test]
async fn compacting_concurrent_with_clear_does_not_resurrect_state() {
use std::sync::atomic::{AtomicBool, Ordering};
// Compactor that flags entry, then blocks until `release` is notified.
struct GatedCompactor {
release: tokio::sync::Notify,
entered: AtomicBool,
}
impl Compactor for GatedCompactor {
type Artifact = TextSummary;
fn compact<'a>(
&'a self,
_conversation_id: &'a str,
_evicted: &'a [Message],
_carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
Box::pin(async move {
self.entered.store(true, Ordering::SeqCst);
self.release.notified().await;
Ok(TextSummary("late summary".into()))
})
}
}
let compactor = Arc::new(GatedCompactor {
release: tokio::sync::Notify::new(),
entered: AtomicBool::new(false),
});
let mem = Arc::new(CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
compactor.clone(),
));
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
let mem_load = mem.clone();
let load_handle = tokio::spawn(async move { mem_load.load("c").await });
// Yield until the spawned load is inside the compactor.
while !compactor.entered.load(Ordering::SeqCst) {
tokio::task::yield_now().await;
}
// Clear while the compaction is still pending, then let it finish.
mem.clear("c").await.unwrap();
compactor.release.notify_one();
// The load's own result is ignored; only the post-clear state is checked.
let _ = load_handle.await.unwrap();
assert_eq!(mem.tracked_conversations(), 0);
assert!(mem.load("c").await.unwrap().is_empty());
}
/// Aborting a load while its compaction is in flight must release the
/// in-flight gate, so a retry runs the compactor again and returns its summary.
///
/// Both waits on the compactor are bounded. A regression, such as a leaked gate
/// that stops the retry from ever entering the compactor, then fails the test
/// with a message instead of spinning forever.
#[tokio::test]
async fn compacting_dropped_load_releases_in_flight_gate() {
use std::sync::atomic::{AtomicUsize, Ordering};
// Compactor that counts entries, then blocks until `release` is notified.
struct GatedCompactor {
release: tokio::sync::Notify,
entered: AtomicUsize,
}
impl Compactor for GatedCompactor {
type Artifact = TextSummary;
fn compact<'a>(
&'a self,
_conversation_id: &'a str,
_evicted: &'a [Message],
_carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
Box::pin(async move {
self.entered.fetch_add(1, Ordering::SeqCst);
self.release.notified().await;
Ok(TextSummary("ran".into()))
})
}
}
let compactor = Arc::new(GatedCompactor {
release: tokio::sync::Notify::new(),
entered: AtomicUsize::new(0),
});
let mem = Arc::new(CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
compactor.clone(),
));
mem.append("c", vec![user("a"), assistant("b"), user("c")])
.await
.unwrap();
let mem_load = mem.clone();
let handle = tokio::spawn(async move { mem_load.load("c").await });
// Bounded wait for the first load to reach the compactor.
for _ in 0..1_000 {
if compactor.entered.load(Ordering::SeqCst) >= 1 {
break;
}
tokio::task::yield_now().await;
}
assert_eq!(
compactor.entered.load(Ordering::SeqCst),
1,
"first load must reach the compactor"
);
// Cancel the load mid-compaction and wait for the abort to take effect.
handle.abort();
let _ = handle.await;
let mem_load = mem.clone();
let retry = tokio::spawn(async move { mem_load.load("c").await });
// Bounded wait: if the aborted load leaked its gate, the retry never gets here.
for _ in 0..1_000 {
if compactor.entered.load(Ordering::SeqCst) >= 2 {
break;
}
tokio::task::yield_now().await;
}
assert_eq!(
compactor.entered.load(Ordering::SeqCst),
2,
"retry must re-enter the compactor after cancellation"
);
compactor.release.notify_one();
let loaded = retry.await.unwrap().unwrap();
assert_eq!(loaded.len(), 2);
let Message::System { content } = &loaded[0] else {
panic!("expected summary")
};
assert_eq!(content, "ran");
}
/// A stale load starts before `clear` and is cancelled afterwards. Its cleanup
/// must not release the compaction reservation held by a load of the fresh
/// history; otherwise a third concurrent load would start another compaction.
#[tokio::test]
async fn compacting_stale_cancelled_load_does_not_clear_new_reservation() {
use std::sync::atomic::{AtomicUsize, Ordering};
// Compactor that counts entries, signals `rendezvous`, then blocks on `release`.
struct GatedCompactor {
release: tokio::sync::Notify,
rendezvous: tokio::sync::Notify,
entered: AtomicUsize,
}
impl Compactor for GatedCompactor {
type Artifact = TextSummary;
fn compact<'a>(
&'a self,
_conversation_id: &'a str,
_evicted: &'a [Message],
_carry_over: Option<&'a Self::Artifact>,
) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
Box::pin(async move {
self.entered.fetch_add(1, Ordering::SeqCst);
self.rendezvous.notify_one();
self.release.notified().await;
Ok(TextSummary("ran".into()))
})
}
}
let compactor = Arc::new(GatedCompactor {
release: tokio::sync::Notify::new(),
rendezvous: tokio::sync::Notify::new(),
entered: AtomicUsize::new(0),
});
let mem = Arc::new(CompactingMemory::new(
InMemoryConversationMemory::new(),
SlidingWindowMemory::last_messages(1),
compactor.clone(),
));
mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
.await
.unwrap();
let mem_load = mem.clone();
let stale = tokio::spawn(async move { mem_load.load("c").await });
// The stale load is now parked inside the compactor (entry 1).
compactor.rendezvous.notified().await;
assert_eq!(compactor.entered.load(Ordering::SeqCst), 1);
mem.clear("c").await.unwrap();
mem.append(
"c",
vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
)
.await
.unwrap();
let mem_load = mem.clone();
let fresh = tokio::spawn(async move { mem_load.load("c").await });
// The fresh load is now parked inside the compactor too (entry 2).
compactor.rendezvous.notified().await;
assert_eq!(compactor.entered.load(Ordering::SeqCst), 2);
// Cancel the stale load; its cleanup runs while the fresh one is in flight.
stale.abort();
let _ = stale.await;
let mem_load = mem.clone();
let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
// The third load must finish without entering the compactor. A rendezvous
// signal here would mean the fresh reservation was wrongly cleared.
let concurrent_kept = tokio::select! {
result = &mut concurrent => result.unwrap().unwrap(),
_ = compactor.rendezvous.notified() => {
panic!("stale guard must not clear the fresh in-flight reservation")
}
};
assert_eq!(
compactor.entered.load(Ordering::SeqCst),
2,
"stale guard must not clear the fresh in-flight reservation"
);
// Only the fresh load is still waiting on `release` at this point.
compactor.release.notify_one();
// The fresh load gets summary + kept message. The concurrent load is
// expected to return only the kept message, since it finished before the
// fresh compaction completed.
assert_eq!(fresh.await.unwrap().unwrap().len(), 2);
assert_eq!(concurrent_kept.len(), 1);
assert_eq!(compactor.entered.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn template_compactor_caps_summary_with_multiline_header() {
let cap = 256;
let compactor = TemplateCompactor::with_header("line one\nline two").with_max_bytes(cap);
let mut evicted = Vec::new();
for i in 0..50 {
evicted.push(user(&format!("message number {i} with some filler")));
}
let summary = compactor.compact("c", &evicted, None).await.unwrap();
let text = summary.as_str();
assert!(text.starts_with("line one\n"));
assert!(text.contains("[\u{2026}truncated\u{2026}]"));
assert!(text.contains("message number 49"));
let overhead = "line one\n".len() + "[\u{2026}truncated\u{2026}]\n".len();
assert!(
text.len() <= cap + overhead,
"summary len {} exceeds cap {} plus overhead {}",
text.len(),
cap,
overhead,
);
}
}