1#![cfg_attr(
2 test,
3 allow(
4 clippy::expect_used,
5 clippy::indexing_slicing,
6 clippy::panic,
7 clippy::unwrap_used,
8 clippy::unreachable
9 )
10)]
11use std::{
43 collections::HashMap,
44 sync::{Arc, Mutex as StdMutex},
45};
46
47pub use rig_core::memory::{
50 Compactor, ConversationMemory, DemotionHook, InMemoryConversationMemory, MemoryError,
51 NoopDemotionHook,
52};
53
54use rig_core::completion::Message;
55use rig_core::message::UserContent;
56use rig_core::wasm_compat::{WasmBoxedFuture, WasmCompatSend, WasmCompatSync};
57
58pub trait MemoryPolicy: WasmCompatSend + WasmCompatSync {
64 fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError>;
67
68 fn apply_with_demoted(
82 &self,
83 messages: Vec<Message>,
84 ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
85 Ok((self.apply(messages)?, Vec::new()))
86 }
87}
88
89impl<P> MemoryPolicy for Arc<P>
90where
91 P: MemoryPolicy + ?Sized,
92{
93 fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
94 (**self).apply(messages)
95 }
96
97 fn apply_with_demoted(
98 &self,
99 messages: Vec<Message>,
100 ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
101 (**self).apply_with_demoted(messages)
102 }
103}
104
105impl<P> MemoryPolicy for Box<P>
106where
107 P: MemoryPolicy + ?Sized,
108{
109 fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
110 (**self).apply(messages)
111 }
112
113 fn apply_with_demoted(
114 &self,
115 messages: Vec<Message>,
116 ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
117 (**self).apply_with_demoted(messages)
118 }
119}
120
121pub trait IntoFilter: MemoryPolicy + Sized + 'static {
128 #[cfg(not(target_family = "wasm"))]
135 fn into_filter(self) -> Box<dyn Fn(Vec<Message>) -> Vec<Message> + Send + Sync> {
136 let policy = Arc::new(self);
137 Box::new(move |msgs| {
138 let fallback = msgs.clone();
139 match policy.apply(msgs) {
140 Ok(out) => out,
141 Err(err) => {
142 tracing::warn!(error = %err, "memory policy failed; returning unfiltered history");
143 fallback
144 }
145 }
146 })
147 }
148
149 #[cfg(target_family = "wasm")]
156 fn into_filter(self) -> Box<dyn Fn(Vec<Message>) -> Vec<Message>> {
157 let policy = Arc::new(self);
158 Box::new(move |msgs| {
159 let fallback = msgs.clone();
160 match policy.apply(msgs) {
161 Ok(out) => out,
162 Err(err) => {
163 tracing::warn!(error = %err, "memory policy failed; returning unfiltered history");
164 fallback
165 }
166 }
167 })
168 }
169}
170
171impl<P> IntoFilter for P where P: MemoryPolicy + 'static {}
172
/// A [`MemoryPolicy`] that leaves the history untouched.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopMemoryPolicy;

impl MemoryPolicy for NoopMemoryPolicy {
    /// Returns `messages` unchanged; this implementation never fails.
    fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
        Ok(messages)
    }
}
182
/// Policy that keeps at most a fixed number of the most recent messages.
#[derive(Debug, Clone, Copy)]
pub struct SlidingWindowMemory {
    /// Upper bound on the number of messages retained in the window.
    max_messages: usize,
}

impl SlidingWindowMemory {
    /// Creates a policy that retains at most the last `n` messages.
    pub fn last_messages(n: usize) -> Self {
        SlidingWindowMemory { max_messages: n }
    }
}
200
201impl MemoryPolicy for SlidingWindowMemory {
202 fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
203 Ok(self.apply_with_demoted(messages)?.0)
204 }
205
206 fn apply_with_demoted(
207 &self,
208 messages: Vec<Message>,
209 ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
210 if messages.len() <= self.max_messages {
211 return Ok((messages, Vec::new()));
212 }
213
214 let start = messages.len() - self.max_messages;
215 let mut iter = messages.into_iter();
216 let mut demoted: Vec<Message> = (&mut iter).take(start).collect();
217 let mut window: Vec<Message> = iter.collect();
218
219 if let Some(Message::User { content }) = window.first()
223 && matches!(content.first_ref(), UserContent::ToolResult(_))
224 {
225 demoted.push(window.remove(0));
226 }
227
228 Ok((window, demoted))
229 }
230}
231
/// Estimates the token cost of a single [`Message`].
pub trait TokenCounter: WasmCompatSend + WasmCompatSync {
    /// Returns the number of tokens attributed to `message`.
    fn count(&self, message: &Message) -> usize;
}
241
242impl<F> TokenCounter for F
243where
244 F: Fn(&Message) -> usize + WasmCompatSend + WasmCompatSync,
245{
246 fn count(&self, message: &Message) -> usize {
247 (self)(message)
248 }
249}
250
251impl<C> TokenCounter for Arc<C>
252where
253 C: TokenCounter + ?Sized,
254{
255 fn count(&self, message: &Message) -> usize {
256 (**self).count(message)
257 }
258}
259
260impl TokenCounter for Box<dyn TokenCounter> {
261 fn count(&self, message: &Message) -> usize {
262 (**self).count(message)
263 }
264}
265
/// Byte-length based [`TokenCounter`]: text costs `ceil(bytes / bytes_per_token)`
/// tokens, attachments cost a flat amount, and every message pays a fixed
/// overhead on top.
#[derive(Debug, Clone, Copy)]
pub struct HeuristicTokenCounter {
    /// Divisor applied to UTF-8 byte lengths; the constructor keeps it finite
    /// and at least `1.0`.
    bytes_per_token: f32,
    /// Flat token cost added once to every counted message.
    per_message_overhead: usize,
    /// Flat token cost charged for each non-text attachment.
    per_attachment_tokens: usize,
}
313
314impl HeuristicTokenCounter {
315 pub fn new(
320 bytes_per_token: f32,
321 per_message_overhead: usize,
322 per_attachment_tokens: usize,
323 ) -> Self {
324 let bytes_per_token = if bytes_per_token.is_finite() && bytes_per_token >= 1.0 {
325 bytes_per_token
326 } else {
327 1.0
328 };
329 Self {
330 bytes_per_token,
331 per_message_overhead,
332 per_attachment_tokens,
333 }
334 }
335
336 pub fn openai() -> Self {
340 Self::new(4.0, 4, 256)
341 }
342
343 pub fn anthropic() -> Self {
345 Self::new(3.5, 4, 256)
346 }
347
348 pub fn gemini() -> Self {
350 Self::new(4.0, 4, 256)
351 }
352
353 fn bytes_to_tokens(&self, bytes: usize) -> usize {
354 let tokens = (bytes as f32) / self.bytes_per_token;
358 tokens.ceil() as usize
359 }
360
361 fn count_user(&self, content: &rig_core::message::UserContent) -> usize {
362 use rig_core::message::UserContent;
363 match content {
364 UserContent::Text(text) => self.bytes_to_tokens(text.text.len()),
365 UserContent::ToolResult(result) => result
366 .content
367 .iter()
368 .map(|c| match c {
369 rig_core::message::ToolResultContent::Text(t) => {
370 self.bytes_to_tokens(t.text.len())
371 }
372 rig_core::message::ToolResultContent::Image(_) => self.per_attachment_tokens,
373 })
374 .sum(),
375 UserContent::Image(_)
376 | UserContent::Audio(_)
377 | UserContent::Video(_)
378 | UserContent::Document(_) => self.per_attachment_tokens,
379 }
380 }
381
382 fn count_assistant(&self, content: &rig_core::message::AssistantContent) -> usize {
383 use rig_core::message::AssistantContent;
384 match content {
385 AssistantContent::Text(text) => self.bytes_to_tokens(text.text.len()),
386 AssistantContent::Reasoning(reasoning) => {
387 self.bytes_to_tokens(reasoning.display_text().len())
388 }
389 AssistantContent::ToolCall(call) => {
390 let name_bytes = call.function.name.len();
391 let args_bytes = call.function.arguments.to_string().len();
396 self.bytes_to_tokens(name_bytes + args_bytes)
397 }
398 AssistantContent::Image(_) => self.per_attachment_tokens,
399 }
400 }
401}
402
403impl Default for HeuristicTokenCounter {
404 fn default() -> Self {
405 Self::openai()
406 }
407}
408
409impl TokenCounter for HeuristicTokenCounter {
410 fn count(&self, message: &Message) -> usize {
411 let content_tokens: usize = match message {
412 Message::User { content } => content.iter().map(|c| self.count_user(c)).sum(),
413 Message::Assistant { content, .. } => {
414 content.iter().map(|c| self.count_assistant(c)).sum()
415 }
416 Message::System { content } => self.bytes_to_tokens(content.len()),
417 };
418 content_tokens.saturating_add(self.per_message_overhead)
419 }
420}
421
422pub struct TokenWindowMemory {
431 max_tokens: usize,
432 counter: Arc<dyn TokenCounter>,
433}
434
435impl TokenWindowMemory {
436 pub fn new<C>(max_tokens: usize, counter: C) -> Self
438 where
439 C: TokenCounter + 'static,
440 {
441 Self {
442 max_tokens,
443 counter: Arc::new(counter),
444 }
445 }
446}
447
448impl std::fmt::Debug for TokenWindowMemory {
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 f.debug_struct("TokenWindowMemory")
451 .field("max_tokens", &self.max_tokens)
452 .field("counter", &"<counter>")
453 .finish()
454 }
455}
456
457impl MemoryPolicy for TokenWindowMemory {
458 fn apply(&self, messages: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
459 Ok(self.apply_with_demoted(messages)?.0)
460 }
461
462 fn apply_with_demoted(
463 &self,
464 messages: Vec<Message>,
465 ) -> Result<(Vec<Message>, Vec<Message>), MemoryError> {
466 let mut budget = self.max_tokens;
467 let mut keep_from = messages.len();
468
469 for (idx, msg) in messages.iter().enumerate().rev() {
470 let cost = self.counter.count(msg);
471 if cost > budget {
472 break;
473 }
474 budget -= cost;
475 keep_from = idx;
476 }
477
478 let mut iter = messages.into_iter();
479 let mut demoted: Vec<Message> = (&mut iter).take(keep_from).collect();
480 let mut window: Vec<Message> = iter.collect();
481
482 if let Some(Message::User { content }) = window.first()
483 && matches!(content.first_ref(), UserContent::ToolResult(_))
484 {
485 demoted.push(window.remove(0));
486 }
487
488 Ok((window, demoted))
489 }
490}
491
/// Pairs a memory backend `M` with a policy `P`.
///
/// When both implement the respective traits, loading applies the policy to
/// whatever the backend returns (see the `ConversationMemory` impl).
#[derive(Debug, Clone, Copy)]
pub struct PolicyMemory<M, P> {
    /// The wrapped memory backend.
    inner: M,
    /// The policy applied to loaded histories.
    policy: P,
}

impl<M, P> PolicyMemory<M, P> {
    /// Wraps `inner` so that `policy` is applied on load.
    pub fn new(inner: M, policy: P) -> Self {
        PolicyMemory { inner, policy }
    }

    /// Borrows the wrapped memory backend.
    pub fn inner(&self) -> &M {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the policy.
    pub fn policy(&self) -> &P {
        let Self { policy, .. } = self;
        policy
    }

    /// Consumes the wrapper and returns `(inner, policy)`.
    pub fn into_inner(self) -> (M, P) {
        let Self { inner, policy } = self;
        (inner, policy)
    }
}
537
538impl<M, P> ConversationMemory for PolicyMemory<M, P>
539where
540 M: ConversationMemory,
541 P: MemoryPolicy,
542{
543 fn load<'a>(
544 &'a self,
545 conversation_id: &'a str,
546 ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
547 Box::pin(async move {
548 let messages = self.inner.load(conversation_id).await?;
549 self.policy.apply(messages)
550 })
551 }
552
553 fn append<'a>(
554 &'a self,
555 conversation_id: &'a str,
556 messages: Vec<Message>,
557 ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
558 self.inner.append(conversation_id, messages)
559 }
560
561 fn clear<'a>(
562 &'a self,
563 conversation_id: &'a str,
564 ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
565 self.inner.clear(conversation_id)
566 }
567}
568
/// Memory wrapper that applies a policy on load and hands the messages the
/// policy demoted to a hook, tracking per conversation how many demoted
/// messages have already been delivered.
pub struct DemotingPolicyMemory<M, P, H> {
    /// The wrapped memory backend.
    inner: M,
    /// Policy that splits a loaded history into kept and demoted messages.
    policy: P,
    /// Hook that receives newly demoted messages.
    hook: H,
    /// Per-conversation delivery bookkeeping, keyed by conversation id.
    state: StdMutex<HashMap<String, ConversationDemotionState>>,
}
626
/// Identity token for one in-progress hook/compactor call. It carries no data;
/// holders compare tokens with `Arc::ptr_eq` to tell whether a stored
/// reservation is still their own.
type InFlightReservation = Arc<()>;

/// Delivery bookkeeping for a single conversation.
#[derive(Debug, Default, Clone)]
struct ConversationDemotionState {
    /// Number of demoted messages (counted from the start of the demoted
    /// list) that a successful hook call has already covered.
    delivered: usize,
    /// `Some` while a hook call for this conversation is outstanding.
    in_flight: Option<InFlightReservation>,
}
639
640impl<M, P, H> DemotingPolicyMemory<M, P, H> {
641 pub fn new(inner: M, policy: P, hook: H) -> Self {
644 Self {
645 inner,
646 policy,
647 hook,
648 state: StdMutex::new(HashMap::new()),
649 }
650 }
651
652 pub fn inner(&self) -> &M {
654 &self.inner
655 }
656
657 pub fn policy(&self) -> &P {
659 &self.policy
660 }
661
662 pub fn hook(&self) -> &H {
664 &self.hook
665 }
666
667 pub fn into_inner(self) -> (M, P, H) {
669 (self.inner, self.policy, self.hook)
670 }
671
672 pub fn forget(&self, conversation_id: &str) {
682 if let Ok(mut guard) = self.state.lock() {
683 guard.remove(conversation_id);
684 }
685 }
686
687 pub fn tracked_conversations(&self) -> usize {
691 self.state.lock().map(|g| g.len()).unwrap_or(0)
692 }
693}
694
695impl<M, P, H> std::fmt::Debug for DemotingPolicyMemory<M, P, H>
696where
697 M: std::fmt::Debug,
698 P: std::fmt::Debug,
699{
700 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
701 f.debug_struct("DemotingPolicyMemory")
702 .field("inner", &self.inner)
703 .field("policy", &self.policy)
704 .field("hook", &"<hook>")
705 .finish()
706 }
707}
708
impl<M, P, H> ConversationMemory for DemotingPolicyMemory<M, P, H>
where
    M: ConversationMemory,
    P: MemoryPolicy,
    H: DemotionHook,
{
    /// Loads the history, applies the policy, and delivers not-yet-delivered
    /// demoted messages to the hook before returning the kept window.
    ///
    /// # Errors
    ///
    /// Returns errors from the inner memory, the policy, or the hook, and
    /// `MemoryError::Internal` if the state mutex is poisoned.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        Box::pin(async move {
            let messages = self.inner.load(conversation_id).await?;
            let (kept, mut demoted) = self.policy.apply_with_demoted(messages)?;
            let demoted_count = demoted.len();

            // Decide, under the lock, which demoted messages (if any) this call
            // must deliver. The lock guard is dropped at the end of this block,
            // before the hook is awaited.
            let (pending, reservation) = {
                let mut guard = self.state.lock().map_err(poisoned)?;
                if let Some(entry) = guard.get_mut(conversation_id) {
                    if entry.in_flight.is_some() {
                        // Another delivery for this conversation is outstanding;
                        // return the window without calling the hook.
                        return Ok(kept);
                    }
                    if entry.delivered >= demoted_count {
                        // Everything currently demoted was already delivered.
                        (Vec::new(), None)
                    } else {
                        // Deliver only the part after the delivered prefix.
                        let split = entry.delivered;
                        let reservation = Arc::new(());
                        entry.in_flight = Some(reservation.clone());
                        (demoted.split_off(split), Some(reservation))
                    }
                } else if demoted_count == 0 {
                    // Nothing demoted and nothing tracked: no entry is created.
                    (Vec::new(), None)
                } else {
                    // First demotion for this conversation: track it and
                    // deliver the whole demoted list.
                    let reservation = Arc::new(());
                    guard.insert(
                        conversation_id.to_string(),
                        ConversationDemotionState {
                            delivered: 0,
                            in_flight: Some(reservation.clone()),
                        },
                    );
                    (std::mem::take(&mut demoted), Some(reservation))
                }
            };

            let Some(reservation) = reservation else {
                return Ok(kept);
            };

            // While armed, dropping this guard (early return, or the future
            // being dropped during the await below) clears our `in_flight`
            // marker.
            let in_flight_guard =
                DemotionInFlightGuard::new(&self.state, conversation_id, reservation.clone());

            let result = self.hook.on_demote(conversation_id, pending).await;

            {
                let mut guard = self.state.lock().map_err(poisoned)?;
                // Only touch the entry if it still carries *our* reservation;
                // it may have been removed (e.g. via `forget`) and re-created
                // by another call in the meantime.
                if let Some(entry) = guard.get_mut(conversation_id)
                    && entry
                        .in_flight
                        .as_ref()
                        .is_some_and(|current| Arc::ptr_eq(current, &reservation))
                {
                    entry.in_flight = None;
                    // The watermark advances only on success, so a failed
                    // delivery is attempted again on a later load.
                    if result.is_ok() {
                        entry.delivered = demoted_count;
                    }
                }
            }
            // The marker was handled explicitly above; nothing left for the
            // guard to do.
            in_flight_guard.disarm();
            result?;
            Ok(kept)
        })
    }

    /// Passes the append straight through to the inner memory.
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        self.inner.append(conversation_id, messages)
    }

    /// Clears the inner memory and, only if that succeeded, forgets the
    /// delivery bookkeeping for the conversation.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        Box::pin(async move {
            self.inner.clear(conversation_id).await?;
            self.forget(conversation_id);
            Ok(())
        })
    }
}
829
830fn poisoned<E: std::fmt::Display>(err: E) -> MemoryError {
831 MemoryError::Internal(err.to_string())
832}
833
834struct DemotionInFlightGuard<'a> {
844 state: &'a StdMutex<HashMap<String, ConversationDemotionState>>,
845 key: &'a str,
846 reservation: InFlightReservation,
847 armed: bool,
848}
849
850impl<'a> DemotionInFlightGuard<'a> {
851 fn new(
852 state: &'a StdMutex<HashMap<String, ConversationDemotionState>>,
853 key: &'a str,
854 reservation: InFlightReservation,
855 ) -> Self {
856 Self {
857 state,
858 key,
859 reservation,
860 armed: true,
861 }
862 }
863
864 fn disarm(mut self) {
867 self.armed = false;
868 }
869}
870
871impl Drop for DemotionInFlightGuard<'_> {
872 fn drop(&mut self) {
873 if !self.armed {
874 return;
875 }
876 if let Ok(mut guard) = self.state.lock()
877 && let Some(entry) = guard.get_mut(self.key)
878 && entry
879 .in_flight
880 .as_ref()
881 .is_some_and(|current| Arc::ptr_eq(current, &self.reservation))
882 {
883 entry.in_flight = None;
884 }
885 }
886}
887
888struct InFlightGuard<'a, A> {
899 state: &'a StdMutex<HashMap<String, ConversationCompactionState<A>>>,
900 key: &'a str,
901 reservation: InFlightReservation,
902 armed: bool,
903}
904
905impl<'a, A> InFlightGuard<'a, A> {
906 fn new(
907 state: &'a StdMutex<HashMap<String, ConversationCompactionState<A>>>,
908 key: &'a str,
909 reservation: InFlightReservation,
910 ) -> Self {
911 Self {
912 state,
913 key,
914 reservation,
915 armed: true,
916 }
917 }
918
919 fn disarm(mut self) {
922 self.armed = false;
923 }
924}
925
926impl<A> Drop for InFlightGuard<'_, A> {
927 fn drop(&mut self) {
928 if !self.armed {
929 return;
930 }
931 if let Ok(mut guard) = self.state.lock()
932 && let Some(entry) = guard.get_mut(self.key)
933 && entry
934 .in_flight
935 .as_ref()
936 .is_some_and(|current| Arc::ptr_eq(current, &self.reservation))
937 {
938 entry.in_flight = None;
939 }
940 }
941}
942
/// Memory wrapper that applies a policy on load, feeds the messages the policy
/// demoted to a [`Compactor`], and places the resulting artifact in front of
/// the kept window.
pub struct CompactingMemory<M, P, C: Compactor> {
    /// The wrapped memory backend.
    inner: M,
    /// Policy that splits a loaded history into kept and demoted messages.
    policy: P,
    /// Compactor that turns demoted messages into a summary artifact.
    compactor: C,
    /// Per-conversation compaction bookkeeping, keyed by conversation id.
    state: StdMutex<HashMap<String, ConversationCompactionState<C::Artifact>>>,
}
1018
/// Compaction bookkeeping for a single conversation.
struct ConversationCompactionState<A> {
    /// Artifact from the most recent successful compaction, if any.
    summary: Option<A>,
    /// Number of demoted messages (counted from the start of the demoted
    /// list) already folded into `summary`.
    absorbed: usize,
    /// `Some` while a compaction for this conversation is outstanding.
    in_flight: Option<InFlightReservation>,
}
1031
1032impl<M, P, C: Compactor> CompactingMemory<M, P, C> {
1033 pub fn new(inner: M, policy: P, compactor: C) -> Self {
1036 Self {
1037 inner,
1038 policy,
1039 compactor,
1040 state: StdMutex::new(HashMap::new()),
1041 }
1042 }
1043
1044 pub fn inner(&self) -> &M {
1046 &self.inner
1047 }
1048
1049 pub fn policy(&self) -> &P {
1051 &self.policy
1052 }
1053
1054 pub fn compactor(&self) -> &C {
1056 &self.compactor
1057 }
1058
1059 pub fn into_inner(self) -> (M, P, C) {
1061 (self.inner, self.policy, self.compactor)
1062 }
1063
1064 pub fn forget(&self, conversation_id: &str) {
1070 if let Ok(mut guard) = self.state.lock() {
1071 guard.remove(conversation_id);
1072 }
1073 }
1074
1075 pub fn tracked_conversations(&self) -> usize {
1079 self.state.lock().map(|g| g.len()).unwrap_or(0)
1080 }
1081}
1082
1083impl<M, P, C> std::fmt::Debug for CompactingMemory<M, P, C>
1084where
1085 M: std::fmt::Debug,
1086 P: std::fmt::Debug,
1087 C: Compactor,
1088{
1089 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1090 f.debug_struct("CompactingMemory")
1091 .field("inner", &self.inner)
1092 .field("policy", &self.policy)
1093 .field("compactor", &"<compactor>")
1094 .finish()
1095 }
1096}
1097
impl<M, P, C> ConversationMemory for CompactingMemory<M, P, C>
where
    M: ConversationMemory,
    P: MemoryPolicy,
    C: Compactor,
{
    /// Loads the history, applies the policy, compacts demoted messages that
    /// have not been absorbed yet, and returns the kept window with the current
    /// summary (if any) placed in front.
    ///
    /// # Errors
    ///
    /// Returns errors from the inner memory, the policy, or the compactor, and
    /// `MemoryError::Internal` if the state mutex is poisoned or the stored
    /// watermark lies beyond the demoted list.
    fn load<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<Vec<Message>, MemoryError>> {
        Box::pin(async move {
            let messages = self.inner.load(conversation_id).await?;
            let (kept, demoted) = self.policy.apply_with_demoted(messages)?;
            let demoted_count = demoted.len();

            // Decide, under the lock, whether this call has to compact. The
            // lock guard is dropped at the end of this block, before the
            // compactor is awaited.
            let plan = {
                let mut guard = self.state.lock().map_err(poisoned)?;
                if let Some(entry) = guard.get_mut(conversation_id) {
                    if entry.in_flight.is_some() {
                        // Another compaction is outstanding: serve the summary
                        // stored so far instead of starting a second one.
                        return Ok(splice(entry.summary.clone(), kept));
                    }
                    if demoted_count <= entry.absorbed {
                        // Nothing new was demoted since the last compaction.
                        return Ok(splice(entry.summary.clone(), kept));
                    }
                    let reservation = Arc::new(());
                    entry.in_flight = Some(reservation.clone());
                    CompactionPlan {
                        carry_over: entry.summary.clone(),
                        skip: entry.absorbed,
                        reservation,
                    }
                } else if demoted_count == 0 {
                    // Nothing demoted and nothing tracked: no entry is created.
                    return Ok(kept);
                } else {
                    // First compaction for this conversation.
                    let reservation = Arc::new(());
                    guard.insert(
                        conversation_id.to_string(),
                        ConversationCompactionState {
                            summary: None,
                            absorbed: 0,
                            in_flight: Some(reservation.clone()),
                        },
                    );
                    CompactionPlan {
                        carry_over: None,
                        skip: 0,
                        reservation,
                    }
                }
            };

            let CompactionPlan {
                carry_over,
                skip,
                reservation,
            } = plan;

            // While armed, dropping this guard (early return, or the future
            // being dropped during the await below) clears our `in_flight`
            // marker.
            let in_flight_guard =
                InFlightGuard::new(&self.state, conversation_id, reservation.clone());

            // Only the messages after the absorbed prefix are new.
            let new_slice = match demoted.get(skip..) {
                Some(s) => s,
                None => {
                    // Dropping the armed guard releases the reservation.
                    drop(in_flight_guard);
                    return Err(MemoryError::Internal(
                        "compaction watermark exceeds demoted slice length".into(),
                    ));
                }
            };

            let result = self
                .compactor
                .compact(conversation_id, new_slice, carry_over.as_ref())
                .await;

            let summary_for_splice = match result {
                Ok(artifact) => {
                    let mut guard = self.state.lock().map_err(poisoned)?;
                    if let Some(entry) = guard.get_mut(conversation_id) {
                        // Commit only if the entry still carries *our*
                        // reservation; otherwise the artifact is discarded.
                        if entry
                            .in_flight
                            .as_ref()
                            .is_some_and(|current| Arc::ptr_eq(current, &reservation))
                        {
                            entry.in_flight = None;
                            entry.absorbed = demoted_count;
                            entry.summary = Some(artifact.clone());
                            Some(artifact)
                        } else {
                            None
                        }
                    } else {
                        // The entry was removed while compacting (e.g. via
                        // `forget`): return the window without a summary.
                        None
                    }
                }
                Err(err) => {
                    // Release our reservation but keep the previous summary and
                    // watermark, so a later load attempts the compaction again.
                    let mut guard = self.state.lock().map_err(poisoned)?;
                    if let Some(entry) = guard.get_mut(conversation_id)
                        && entry
                            .in_flight
                            .as_ref()
                            .is_some_and(|current| Arc::ptr_eq(current, &reservation))
                    {
                        entry.in_flight = None;
                    }
                    return Err(err);
                }
            };

            // The marker was handled explicitly above; nothing left for the
            // guard to do.
            in_flight_guard.disarm();

            Ok(splice(summary_for_splice, kept))
        })
    }

    /// Passes the append straight through to the inner memory.
    fn append<'a>(
        &'a self,
        conversation_id: &'a str,
        messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        self.inner.append(conversation_id, messages)
    }

    /// Clears the inner memory and, only if that succeeded, forgets the
    /// compaction bookkeeping for the conversation.
    fn clear<'a>(
        &'a self,
        conversation_id: &'a str,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        Box::pin(async move {
            self.inner.clear(conversation_id).await?;
            self.forget(conversation_id);
            Ok(())
        })
    }
}
1272
/// What a single `CompactingMemory::load` call decided to compact, captured
/// while the state lock was held.
struct CompactionPlan<A> {
    /// Previously stored summary to hand to the compactor, if any.
    carry_over: Option<A>,
    /// Number of leading demoted messages that are already absorbed and must
    /// be skipped.
    skip: usize,
    /// Reservation identifying this compaction in the bookkeeping entry.
    reservation: InFlightReservation,
}
1278
1279fn splice<A>(summary: Option<A>, kept: Vec<Message>) -> Vec<Message>
1280where
1281 A: Into<Message>,
1282{
1283 match summary {
1284 Some(artifact) => {
1285 let mut out = Vec::with_capacity(kept.len() + 1);
1286 out.push(artifact.into());
1287 out.extend(kept);
1288 out
1289 }
1290 None => kept,
1291 }
1292}
1293
/// Compactor configuration: a header line that opens every summary and an
/// optional byte cap on the summary's size.
#[derive(Debug, Clone)]
pub struct TemplateCompactor {
    /// First line of every produced summary.
    header: String,
    /// Byte cap for the summary; `None` means uncapped.
    max_bytes: Option<usize>,
}

impl TemplateCompactor {
    /// Creates a compactor with the default header and no byte cap.
    pub fn new() -> Self {
        Self::with_header("[Conversation summary so far]")
    }

    /// Creates a compactor with a custom header and no byte cap.
    pub fn with_header(header: impl Into<String>) -> Self {
        let header = header.into();
        TemplateCompactor {
            header,
            max_bytes: None,
        }
    }

    /// Sets the byte cap; passing `0` removes the cap.
    pub fn with_max_bytes(mut self, max_bytes: usize) -> Self {
        self.max_bytes = match max_bytes {
            0 => None,
            cap => Some(cap),
        };
        self
    }
}

impl Default for TemplateCompactor {
    /// Same as [`TemplateCompactor::new`].
    fn default() -> Self {
        TemplateCompactor::new()
    }
}
1370
/// Plain-text summary artifact; a thin wrapper around the summary string.
#[derive(Debug, Clone)]
pub struct TextSummary(String);

impl TextSummary {
    /// Borrows the summary text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the summary and returns the owned text.
    pub fn into_string(self) -> String {
        let TextSummary(text) = self;
        text
    }
}
1391
1392impl From<TextSummary> for Message {
1393 fn from(value: TextSummary) -> Self {
1394 Message::System { content: value.0 }
1395 }
1396}
1397
1398impl Compactor for TemplateCompactor {
1399 type Artifact = TextSummary;
1400
1401 fn compact<'a>(
1402 &'a self,
1403 _conversation_id: &'a str,
1404 evicted: &'a [Message],
1405 carry_over: Option<&'a Self::Artifact>,
1406 ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
1407 Box::pin(async move {
1408 let mut buf = String::new();
1409 buf.push_str(&self.header);
1410 buf.push('\n');
1411 if let Some(prev) = carry_over {
1412 buf.push_str(prev.as_str());
1413 buf.push('\n');
1414 }
1415 for msg in evicted {
1416 let line = render_message_line(msg);
1417 if !line.is_empty() {
1418 buf.push_str(&line);
1419 buf.push('\n');
1420 }
1421 }
1422 if let Some(cap) = self.max_bytes
1423 && buf.len() > cap
1424 {
1425 buf = truncate_summary(&buf, cap);
1426 }
1427 Ok(TextSummary(buf))
1428 })
1429 }
1430}
1431
/// Shortens `buf` towards `cap` bytes by dropping the *oldest* body text.
///
/// The first line of `buf` (up to and including its newline) is treated as the
/// header and always kept, followed by a truncation marker line and as much of
/// the tail of the body as still fits. The cut is moved forward to the next
/// UTF-8 character boundary, so the result is always valid text.
///
/// `buf` is returned unchanged when it already fits within `cap`, when it
/// contains no newline, or when nothing follows the header line. If `cap` is
/// smaller than header plus marker, the result is just header plus marker and
/// therefore exceeds `cap`.
fn truncate_summary(buf: &str, cap: usize) -> String {
    const MARKER: &str = "[\u{2026}truncated\u{2026}]\n";

    // Nothing to do when the text already fits; without this guard a fitting
    // input would still get the marker inserted and grow.
    if buf.len() <= cap {
        return buf.to_string();
    }

    let Some(newline_at) = buf.find('\n') else {
        return buf.to_string();
    };
    let header_len = newline_at + 1;
    if buf.len() <= header_len {
        return buf.to_string();
    }

    // Checked slicing throughout: this function must never panic.
    let (Some(header), Some(body)) = (buf.get(..header_len), buf.get(header_len..)) else {
        return buf.to_string();
    };

    // Bytes of body that may survive once header and marker are accounted for.
    let keep_bytes = cap.saturating_sub(header_len + MARKER.len());

    // Keep the tail of the body; never start in the middle of a character.
    let mut cut = body.len().saturating_sub(keep_bytes);
    while cut < body.len() && !body.is_char_boundary(cut) {
        cut += 1;
    }
    let suffix = body.get(cut..).unwrap_or_default();

    let mut out = String::with_capacity(header_len + MARKER.len() + suffix.len());
    out.push_str(header);
    out.push_str(MARKER);
    out.push_str(suffix);
    out
}
1475
1476fn render_message_line(msg: &Message) -> String {
1482 use rig_core::message::AssistantContent;
1483
1484 match msg {
1485 Message::System { content } => {
1486 if content.is_empty() {
1487 String::new()
1488 } else {
1489 format!("system: {content}")
1490 }
1491 }
1492 Message::User { content } => {
1493 let mut text = String::new();
1494 for c in content.iter() {
1495 match c {
1496 UserContent::Text(t) => {
1497 if !text.is_empty() {
1498 text.push(' ');
1499 }
1500 text.push_str(&t.text);
1501 }
1502 UserContent::ToolResult(_) => {
1503 if !text.is_empty() {
1504 text.push(' ');
1505 }
1506 text.push_str("[tool result]");
1507 }
1508 _ => {
1509 if !text.is_empty() {
1510 text.push(' ');
1511 }
1512 text.push_str("[attachment]");
1513 }
1514 }
1515 }
1516 if text.is_empty() {
1517 String::new()
1518 } else {
1519 format!("user: {text}")
1520 }
1521 }
1522 Message::Assistant { content, .. } => {
1523 let mut text = String::new();
1524 for c in content.iter() {
1525 match c {
1526 AssistantContent::Text(t) => {
1527 if !text.is_empty() {
1528 text.push(' ');
1529 }
1530 text.push_str(&t.text);
1531 }
1532 AssistantContent::ToolCall(call) => {
1533 if !text.is_empty() {
1534 text.push(' ');
1535 }
1536 text.push_str(&format!("[tool call: {}]", call.function.name));
1537 }
1538 _ => {
1539 if !text.is_empty() {
1540 text.push(' ');
1541 }
1542 text.push_str("[reasoning]");
1543 }
1544 }
1545 }
1546 if text.is_empty() {
1547 String::new()
1548 } else {
1549 format!("assistant: {text}")
1550 }
1551 }
1552 }
1553}
1554
1555#[cfg(test)]
1556mod tests {
1557 use super::*;
1558 use rig_core::OneOrMany;
1559 use rig_core::message::{
1560 AssistantContent, ToolCall, ToolFunction, ToolResult, ToolResultContent, UserContent,
1561 };
1562 use std::sync::Mutex;
1563
1564 fn user(text: &str) -> Message {
1565 Message::user(text)
1566 }
1567
1568 fn assistant(text: &str) -> Message {
1569 Message::assistant(text)
1570 }
1571
1572 fn tool_call_msg() -> Message {
1573 Message::Assistant {
1574 id: None,
1575 content: OneOrMany::one(AssistantContent::ToolCall(ToolCall::new(
1576 "call_1".into(),
1577 ToolFunction::new("t".into(), serde_json::json!({})),
1578 ))),
1579 }
1580 }
1581
1582 fn tool_result_msg() -> Message {
1583 Message::User {
1584 content: OneOrMany::one(UserContent::ToolResult(ToolResult {
1585 id: "call_1".into(),
1586 call_id: None,
1587 content: OneOrMany::one(ToolResultContent::text("ok")),
1588 })),
1589 }
1590 }
1591
1592 #[test]
1593 fn noop_policy_is_identity() {
1594 let msgs = vec![user("a"), assistant("b")];
1595 let out = NoopMemoryPolicy.apply(msgs).unwrap();
1596 assert_eq!(out.len(), 2);
1597 }
1598
1599 #[test]
1600 fn sliding_window_passthrough_when_under_limit() {
1601 let policy = SlidingWindowMemory::last_messages(5);
1602 let out = policy.apply(vec![user("1"), assistant("2")]).unwrap();
1603 assert_eq!(out.len(), 2);
1604 }
1605
1606 #[tokio::test]
1607 async fn sliding_window_truncates_via_filter() {
1608 let mem = InMemoryConversationMemory::new()
1609 .with_filter(SlidingWindowMemory::last_messages(2).into_filter());
1610
1611 mem.append(
1612 "c",
1613 vec![user("1"), assistant("2"), user("3"), assistant("4")],
1614 )
1615 .await
1616 .unwrap();
1617
1618 let loaded = mem.load("c").await.unwrap();
1619 assert_eq!(loaded.len(), 2);
1620 }
1621
1622 #[test]
1623 fn sliding_window_drops_leading_orphan_tool_result() {
1624 let policy = SlidingWindowMemory::last_messages(3);
1625 let out = policy
1626 .apply(vec![
1627 tool_call_msg(),
1628 tool_result_msg(),
1629 user("after"),
1630 assistant("done"),
1631 ])
1632 .unwrap();
1633
1634 assert_eq!(out.len(), 2);
1635 assert!(matches!(out.first(), Some(Message::User { content })
1636 if matches!(content.first(), UserContent::Text(_))));
1637 }
1638
1639 #[test]
1640 fn token_window_keeps_within_budget() {
1641 let msgs = vec![
1642 user("aaaa"),
1643 assistant("bbbb"),
1644 user("cccc"),
1645 assistant("dddd"),
1646 ];
1647 let policy = TokenWindowMemory::new(2, |_: &Message| 1);
1648 let out = policy.apply(msgs).unwrap();
1649 assert_eq!(out.len(), 2);
1650 }
1651
1652 #[test]
1653 fn token_window_passes_through_when_under_budget() {
1654 let msgs = vec![user("a"), assistant("b")];
1655 let policy = TokenWindowMemory::new(usize::MAX, |_: &Message| 1);
1656 let out = policy.apply(msgs).unwrap();
1657 assert_eq!(out.len(), 2);
1658 }
1659
1660 #[test]
1661 fn token_window_drops_leading_orphan_tool_result() {
1662 let policy = TokenWindowMemory::new(25, |_: &Message| 10);
1663 let out = policy
1664 .apply(vec![tool_call_msg(), tool_result_msg(), user("after")])
1665 .unwrap();
1666 assert_eq!(out.len(), 1);
1667 assert!(matches!(out.first(), Some(Message::User { content })
1668 if matches!(content.first(), UserContent::Text(_))));
1669 }
1670
1671 #[test]
1672 fn token_window_skips_message_larger_than_budget() {
1673 let policy = TokenWindowMemory::new(5, |_: &Message| 10);
1674 let out = policy.apply(vec![user("anything")]).unwrap();
1675 assert!(out.is_empty());
1676 }
1677
1678 #[test]
1679 fn heuristic_counter_charges_overhead_per_message() {
1680 let counter = HeuristicTokenCounter::default();
1681 let empty = counter.count(&user(""));
1682 assert!(
1683 empty >= 4,
1684 "default per-message overhead is at least 4 tokens"
1685 );
1686 }
1687
1688 #[test]
1689 fn heuristic_counter_is_monotonic_in_text_length() {
1690 let counter = HeuristicTokenCounter::default();
1691 let small = counter.count(&user("hi"));
1692 let big = counter.count(&user(&"x".repeat(400)));
1693 assert!(big > small);
1694 }
1695
1696 #[test]
1697 fn heuristic_counter_handles_tool_calls() {
1698 let counter = HeuristicTokenCounter::default();
1699 let cost = counter.count(&tool_call_msg());
1700 assert!(cost > 0);
1701 }
1702
1703 #[test]
1704 fn heuristic_counter_handles_system_messages() {
1705 let counter = HeuristicTokenCounter::default();
1706 let cost = counter.count(&Message::System {
1707 content: "you are helpful".into(),
1708 });
1709 assert!(cost > 0);
1710 }
1711
1712 #[test]
1713 fn heuristic_counter_clamps_invalid_bytes_per_token() {
1714 let counter = HeuristicTokenCounter::new(0.0, 0, 0);
1716 assert!(counter.count(&user("abcd")) >= 4);
1717 let nan = HeuristicTokenCounter::new(f32::NAN, 0, 0);
1718 assert!(nan.count(&user("abcd")) >= 4);
1719 }
1720
1721 #[test]
1722 fn heuristic_counter_drives_token_window() {
1723 let policy = TokenWindowMemory::new(100, HeuristicTokenCounter::default());
1724 let msgs = vec![user(&"a".repeat(2_000)), user("short")];
1725 let out = policy.apply(msgs).unwrap();
1726 assert_eq!(out.len(), 1);
1728 }
1729
1730 #[test]
1731 fn arc_token_counter_can_drive_token_window() {
1732 let counter: Arc<dyn TokenCounter> = Arc::new(|_: &Message| 1);
1733 let policy = TokenWindowMemory::new(2, counter);
1734 let out = policy
1735 .apply(vec![user("a"), assistant("b"), user("c")])
1736 .unwrap();
1737
1738 assert_eq!(out.len(), 2);
1739 }
1740
1741 #[test]
1742 fn boxed_token_counter_forwards_count() {
1743 let counter: Box<dyn TokenCounter> = Box::new(|_: &Message| 7);
1744 assert_eq!(counter.count(&user("a")), 7);
1745 }
1746
1747 #[test]
1748 fn into_filter_returns_input_on_policy_error() {
1749 struct FailingPolicy;
1750 impl MemoryPolicy for FailingPolicy {
1751 fn apply(&self, _: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
1752 Err(MemoryError::Policy("intentional failure".into()))
1753 }
1754 }
1755
1756 let filter = FailingPolicy.into_filter();
1757 let input = vec![user("a"), assistant("b"), user("c")];
1758 let out = filter(input.clone());
1759 assert_eq!(
1760 out.len(),
1761 input.len(),
1762 "history must be preserved on policy error"
1763 );
1764 }
1765
1766 #[tokio::test]
1767 async fn policy_memory_truncates_loaded_history() {
1768 let mem = PolicyMemory::new(
1769 InMemoryConversationMemory::new(),
1770 SlidingWindowMemory::last_messages(2),
1771 );
1772
1773 mem.append(
1774 "c",
1775 vec![user("1"), assistant("2"), user("3"), assistant("4")],
1776 )
1777 .await
1778 .unwrap();
1779
1780 let loaded = mem.load("c").await.unwrap();
1781 assert_eq!(loaded.len(), 2);
1782 }
1783
1784 #[tokio::test]
1785 async fn policy_memory_propagates_policy_errors() {
1786 struct FailingPolicy;
1787 impl MemoryPolicy for FailingPolicy {
1788 fn apply(&self, _: Vec<Message>) -> Result<Vec<Message>, MemoryError> {
1789 Err(MemoryError::Policy("intentional failure".into()))
1790 }
1791 }
1792
1793 let mem = PolicyMemory::new(InMemoryConversationMemory::new(), FailingPolicy);
1794 mem.append("c", vec![user("1"), assistant("2")])
1795 .await
1796 .unwrap();
1797
1798 let result = mem.load("c").await;
1799 assert!(matches!(result, Err(MemoryError::Policy(_))));
1800 }
1801
1802 #[tokio::test]
1803 async fn policy_memory_append_and_clear_delegate_to_inner() {
1804 let mem = PolicyMemory::new(InMemoryConversationMemory::new(), NoopMemoryPolicy);
1805 mem.append("c", vec![user("hi"), assistant("ok")])
1806 .await
1807 .unwrap();
1808 assert_eq!(mem.load("c").await.unwrap().len(), 2);
1809
1810 mem.clear("c").await.unwrap();
1811 assert!(mem.load("c").await.unwrap().is_empty());
1812 }
1813
/// A two-message sliding window over four messages keeps two and reports two as demoted.
#[test]
fn sliding_window_reports_demoted_prefix() {
    let history = vec![
        user("oldest"),
        assistant("old"),
        user("recent"),
        assistant("latest"),
    ];

    let (retained, evicted) = SlidingWindowMemory::last_messages(2)
        .apply_with_demoted(history)
        .unwrap();

    assert_eq!(retained.len(), 2);
    assert_eq!(evicted.len(), 2);
}
1828
/// A token window with budget two and a one-token-per-message counter keeps
/// two of four messages and reports the other two as demoted.
#[test]
fn token_window_reports_demoted_prefix() {
    let window = TokenWindowMemory::new(2, |_: &Message| 1);
    let history = vec![user("a"), assistant("b"), user("c"), assistant("d")];

    let (retained, evicted) = window.apply_with_demoted(history).unwrap();

    assert_eq!(retained.len(), 2);
    assert_eq!(evicted.len(), 2);
}
1838
/// The no-op policy keeps every message and reports an empty demoted list.
#[test]
fn noop_policy_demotes_nothing() {
    let history = vec![user("a"), assistant("b")];

    let (retained, evicted) = NoopMemoryPolicy.apply_with_demoted(history).unwrap();

    assert_eq!(retained.len(), 2);
    assert!(evicted.is_empty());
}
1847
/// `apply_with_demoted` called through `Arc<dyn MemoryPolicy>` still reports
/// the demoted messages of the wrapped sliding window.
#[test]
fn arc_memory_policy_preserves_demoted_metadata() {
    let shared: Arc<dyn MemoryPolicy> = Arc::new(SlidingWindowMemory::last_messages(1));
    let history = vec![user("old"), assistant("new")];

    let (retained, evicted) = shared.apply_with_demoted(history).unwrap();

    assert_eq!(retained.len(), 1);
    assert_eq!(evicted.len(), 1);
}
1858
/// `apply_with_demoted` called through `Box<dyn MemoryPolicy>` still reports
/// the demoted messages of the wrapped sliding window.
#[test]
fn boxed_memory_policy_preserves_demoted_metadata() {
    let boxed: Box<dyn MemoryPolicy> = Box::new(SlidingWindowMemory::last_messages(1));
    let history = vec![user("old"), assistant("new")];

    let (retained, evicted) = boxed.apply_with_demoted(history).unwrap();

    assert_eq!(retained.len(), 1);
    assert_eq!(evicted.len(), 1);
}
1869
/// A tool call and its result at the head of the history are both demoted by a
/// two-message window, and the retained window starts with a plain-text user
/// message.
#[test]
fn sliding_window_demotes_orphan_tool_result_with_prefix() {
    let history = vec![
        tool_call_msg(),
        tool_result_msg(),
        user("after"),
        assistant("done"),
    ];

    let (retained, evicted) = SlidingWindowMemory::last_messages(2)
        .apply_with_demoted(history)
        .unwrap();

    assert_eq!(retained.len(), 2);
    let starts_with_user_text = matches!(
        retained.first(),
        Some(Message::User { content }) if matches!(content.first(), UserContent::Text(_))
    );
    assert!(starts_with_user_text);
    assert_eq!(evicted.len(), 2);
}
1889
1890 #[derive(Default)]
1891 struct CountingHook {
1892 seen: Mutex<Vec<(String, Vec<Message>)>>,
1893 }
1894
1895 impl CountingHook {
1896 fn calls(&self) -> usize {
1897 self.seen.lock().unwrap().len()
1898 }
1899 fn last_demoted_count(&self) -> usize {
1900 self.seen
1901 .lock()
1902 .unwrap()
1903 .last()
1904 .map(|(_, m)| m.len())
1905 .unwrap_or(0)
1906 }
1907 }
1908
1909 impl DemotionHook for CountingHook {
1910 fn on_demote<'a>(
1911 &'a self,
1912 conversation_id: &'a str,
1913 messages: Vec<Message>,
1914 ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
1915 Box::pin(async move {
1916 self.seen
1917 .lock()
1918 .unwrap()
1919 .push((conversation_id.to_string(), messages));
1920 Ok(())
1921 })
1922 }
1923 }
1924
1925 #[tokio::test]
1926 async fn demoting_policy_memory_invokes_hook_on_truncation() {
1927 let hook = Arc::new(CountingHook::default());
1928 let mem = DemotingPolicyMemory::new(
1929 InMemoryConversationMemory::new(),
1930 SlidingWindowMemory::last_messages(2),
1931 hook.clone(),
1932 );
1933
1934 mem.append(
1935 "c",
1936 vec![user("1"), assistant("2"), user("3"), assistant("4")],
1937 )
1938 .await
1939 .unwrap();
1940
1941 let kept = mem.load("c").await.unwrap();
1942 assert_eq!(kept.len(), 2);
1943 assert_eq!(hook.calls(), 1);
1944 assert_eq!(hook.last_demoted_count(), 2);
1945 }
1946
1947 #[tokio::test]
1948 async fn demoting_policy_memory_does_not_replay_demotions() {
1949 let hook = Arc::new(CountingHook::default());
1950 let mem = DemotingPolicyMemory::new(
1951 InMemoryConversationMemory::new(),
1952 SlidingWindowMemory::last_messages(2),
1953 hook.clone(),
1954 );
1955
1956 mem.append(
1957 "c",
1958 vec![user("1"), assistant("2"), user("3"), assistant("4")],
1959 )
1960 .await
1961 .unwrap();
1962
1963 mem.load("c").await.unwrap();
1964 mem.load("c").await.unwrap();
1965 assert_eq!(hook.calls(), 1);
1966 assert_eq!(hook.last_demoted_count(), 2);
1967 }
1968
1969 #[tokio::test]
1970 async fn demoting_policy_memory_only_reports_newly_demoted_messages() {
1971 let hook = Arc::new(CountingHook::default());
1972 let mem = DemotingPolicyMemory::new(
1973 InMemoryConversationMemory::new(),
1974 SlidingWindowMemory::last_messages(2),
1975 hook.clone(),
1976 );
1977
1978 mem.append(
1979 "c",
1980 vec![user("1"), assistant("2"), user("3"), assistant("4")],
1981 )
1982 .await
1983 .unwrap();
1984 mem.load("c").await.unwrap();
1985
1986 mem.append("c", vec![user("5")]).await.unwrap();
1987 mem.load("c").await.unwrap();
1988
1989 assert_eq!(hook.calls(), 2);
1990 assert_eq!(hook.last_demoted_count(), 1);
1991 }
1992
1993 #[derive(Default)]
1994 struct FailingHook {
1995 calls: Mutex<usize>,
1996 }
1997
1998 impl DemotionHook for FailingHook {
1999 fn on_demote<'a>(
2000 &'a self,
2001 _conversation_id: &'a str,
2002 _messages: Vec<Message>,
2003 ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
2004 Box::pin(async move {
2005 *self.calls.lock().unwrap() += 1;
2006 Err(MemoryError::backend(std::io::Error::other("hook failed")))
2007 })
2008 }
2009 }
2010
2011 #[tokio::test]
2012 async fn demoting_policy_memory_does_not_advance_watermark_on_hook_failure() {
2013 let hook = Arc::new(FailingHook::default());
2014 let mem = DemotingPolicyMemory::new(
2015 InMemoryConversationMemory::new(),
2016 SlidingWindowMemory::last_messages(1),
2017 hook.clone(),
2018 );
2019 mem.append("c", vec![user("1"), assistant("2")])
2020 .await
2021 .unwrap();
2022
2023 assert!(mem.load("c").await.is_err());
2024 assert!(mem.load("c").await.is_err());
2025 assert_eq!(*hook.calls.lock().unwrap(), 2);
2026 }
2027
2028 #[tokio::test]
2029 async fn demoting_policy_memory_clear_resets_watermark() {
2030 let hook = Arc::new(CountingHook::default());
2031 let mem = DemotingPolicyMemory::new(
2032 InMemoryConversationMemory::new(),
2033 SlidingWindowMemory::last_messages(1),
2034 hook.clone(),
2035 );
2036
2037 mem.append("c", vec![user("1"), assistant("2")])
2038 .await
2039 .unwrap();
2040 mem.load("c").await.unwrap();
2041 mem.clear("c").await.unwrap();
2042 mem.append("c", vec![user("3"), assistant("4")])
2043 .await
2044 .unwrap();
2045 mem.load("c").await.unwrap();
2046
2047 assert_eq!(hook.calls(), 2);
2048 assert_eq!(hook.last_demoted_count(), 1);
2049 }
2050
2051 #[tokio::test]
2052 async fn demoting_policy_memory_skips_hook_when_nothing_evicted() {
2053 let hook = Arc::new(CountingHook::default());
2054 let mem = DemotingPolicyMemory::new(
2055 InMemoryConversationMemory::new(),
2056 SlidingWindowMemory::last_messages(10),
2057 hook.clone(),
2058 );
2059
2060 mem.append("c", vec![user("1"), assistant("2")])
2061 .await
2062 .unwrap();
2063 mem.load("c").await.unwrap();
2064 assert_eq!(hook.calls(), 0);
2065 }
2066
2067 #[tokio::test]
2068 async fn demoting_policy_memory_with_noop_hook_behaves_like_policy_memory() {
2069 let mem = DemotingPolicyMemory::new(
2070 InMemoryConversationMemory::new(),
2071 SlidingWindowMemory::last_messages(1),
2072 NoopDemotionHook,
2073 );
2074 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2075 .await
2076 .unwrap();
2077 assert_eq!(mem.load("c").await.unwrap().len(), 1);
2078 }
2079
/// Test hook whose `on_demote` future counts the call, signals `rendezvous`,
/// and then waits on `release` before returning `Ok(())`.
///
/// This lets a test hold a `load` inside the hook and observe what other
/// operations do in the meantime.
struct GatedHook {
    /// Incremented once each time an `on_demote` future starts running.
    calls: Arc<std::sync::atomic::AtomicUsize>,
    /// Notified right after `calls` is incremented.
    rendezvous: Arc<tokio::sync::Notify>,
    /// Awaited by the hook before it completes; a test notifies it to let the hook finish.
    release: Arc<tokio::sync::Notify>,
}

impl DemotionHook for GatedHook {
    /// Ignores its arguments; see the struct docs for the count, signal, wait sequence.
    fn on_demote<'a>(
        &'a self,
        _conversation_id: &'a str,
        _messages: Vec<Message>,
    ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
        // Clone the shared handles so the future owns them instead of borrowing `self`.
        let calls = self.calls.clone();
        let rendezvous = self.rendezvous.clone();
        let release = self.release.clone();
        Box::pin(async move {
            // The increment happens inside the future, i.e. when it is polled,
            // not when `on_demote` is called.
            calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            rendezvous.notify_one();
            // Park here until the test notifies `release`.
            release.notified().await;
            Ok(())
        })
    }
}
2105
/// Issues overlapping `load` calls on one conversation while the demotion hook
/// is parked inside `GatedHook`, and asserts that the hook is entered exactly
/// once across all of them.
#[tokio::test]
async fn demoting_policy_memory_serialises_concurrent_loads() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Handles shared with the hook so the test can observe and gate it.
    let calls = Arc::new(AtomicUsize::new(0));
    let rendezvous = Arc::new(tokio::sync::Notify::new());
    let release = Arc::new(tokio::sync::Notify::new());
    let hook = GatedHook {
        calls: calls.clone(),
        rendezvous: rendezvous.clone(),
        release: release.clone(),
    };

    let mem = Arc::new(DemotingPolicyMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        hook,
    ));

    // Three messages through a one-message window.
    mem.append("c", vec![user("1"), assistant("2"), user("3")])
        .await
        .unwrap();

    // First load runs on its own task; its hook call parks until `release` is notified.
    let m1 = mem.clone();
    let first = tokio::spawn(async move { m1.load("c").await });

    // Wait until the hook has signalled that it is running.
    rendezvous.notified().await;
    assert_eq!(calls.load(Ordering::SeqCst), 1);

    // A second load issued while the first is still parked in the hook must
    // return the windowed history without entering the hook again.
    let kept = mem.load("c").await.unwrap();
    assert_eq!(kept.len(), 1);
    assert_eq!(calls.load(Ordering::SeqCst), 1, "hook must not double-fire");

    // Unblock the parked hook and let the first load finish.
    release.notify_one();
    let kept_first = first.await.unwrap().unwrap();
    assert_eq!(kept_first.len(), 1);
    assert_eq!(calls.load(Ordering::SeqCst), 1);

    // A later load over the same history must not re-enter the hook either.
    mem.load("c").await.unwrap();
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
2152
/// Aborts a `load` while its hook call is parked, then asserts that a retry
/// enters the hook again and completes once released.
#[tokio::test]
async fn demoting_policy_memory_dropped_load_releases_in_flight_gate() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let calls = Arc::new(AtomicUsize::new(0));
    let rendezvous = Arc::new(tokio::sync::Notify::new());
    let release = Arc::new(tokio::sync::Notify::new());
    // `rendezvous` is moved into the hook and not awaited here; progress is
    // tracked by polling `calls` instead.
    let hook = GatedHook {
        calls: calls.clone(),
        rendezvous,
        release: release.clone(),
    };

    let mem = Arc::new(DemotingPolicyMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        hook,
    ));

    mem.append("c", vec![user("1"), assistant("2"), user("3")])
        .await
        .unwrap();

    let mem_load = mem.clone();
    let handle = tokio::spawn(async move { mem_load.load("c").await });
    // Yield until the spawned load has entered the hook.
    // NOTE(review): this wait has no iteration cap, unlike the retry wait below.
    while calls.load(Ordering::SeqCst) == 0 {
        tokio::task::yield_now().await;
    }
    // Cancel the load while its hook call is still parked; the join result is
    // deliberately ignored.
    handle.abort();
    let _ = handle.await;

    let mem_load = mem.clone();
    let retry = tokio::spawn(async move { mem_load.load("c").await });
    // Give the retry up to 1000 scheduler yields to reach the hook.
    for _ in 0..1_000 {
        if calls.load(Ordering::SeqCst) >= 2 {
            break;
        }
        tokio::task::yield_now().await;
    }
    assert_eq!(
        calls.load(Ordering::SeqCst),
        2,
        "retry must re-enter the hook after cancellation"
    );

    // Let the retry's hook call finish.
    release.notify_one();
    let kept = retry.await.unwrap().unwrap();
    assert_eq!(kept.len(), 1);

    // After the retry succeeded, a further load must not call the hook again.
    mem.load("c").await.unwrap();
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
2217
/// Starts a load, clears and refills the conversation, starts a second load,
/// then aborts the first. A third load issued afterwards must complete without
/// triggering another hook call while the second load is still parked.
#[tokio::test]
async fn demoting_stale_cancelled_load_does_not_clear_new_reservation() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    let calls = Arc::new(AtomicUsize::new(0));
    let rendezvous = Arc::new(tokio::sync::Notify::new());
    let release = Arc::new(tokio::sync::Notify::new());
    let hook = GatedHook {
        calls: calls.clone(),
        rendezvous: rendezvous.clone(),
        release: release.clone(),
    };

    let mem = Arc::new(DemotingPolicyMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        hook,
    ));

    mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
        .await
        .unwrap();

    // "Stale" load: enters the hook and parks there.
    let mem_load = mem.clone();
    let stale = tokio::spawn(async move { mem_load.load("c").await });
    rendezvous.notified().await;
    assert_eq!(calls.load(Ordering::SeqCst), 1);

    // Replace the conversation contents while the stale load is still parked.
    mem.clear("c").await.unwrap();
    mem.append(
        "c",
        vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
    )
    .await
    .unwrap();

    // "Fresh" load over the new contents: also enters the hook and parks.
    let mem_load = mem.clone();
    let fresh = tokio::spawn(async move { mem_load.load("c").await });
    rendezvous.notified().await;
    assert_eq!(calls.load(Ordering::SeqCst), 2);

    // Cancel the stale load; the join result is deliberately ignored.
    stale.abort();
    let _ = stale.await;

    // A third load must finish on its own. If `rendezvous` fires first, a
    // third hook call has started, which this test treats as a failure.
    let mem_load = mem.clone();
    let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
    let concurrent_kept = tokio::select! {
        result = &mut concurrent => result.unwrap().unwrap(),
        _ = rendezvous.notified() => {
            panic!("stale guard must not clear the fresh in-flight reservation")
        }
    };
    assert_eq!(
        calls.load(Ordering::SeqCst),
        2,
        "stale guard must not clear the fresh in-flight reservation"
    );

    // Release the fresh load's parked hook call and check both results.
    release.notify_one();
    assert_eq!(fresh.await.unwrap().unwrap().len(), 1);
    assert_eq!(concurrent_kept.len(), 1);
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
2281
/// Same scenario as the cancelled-load variant, except the stale load is
/// allowed to finish successfully instead of being aborted. A third load
/// issued afterwards must complete without a third hook call while the fresh
/// load is still parked.
#[tokio::test]
async fn demoting_stale_successful_load_does_not_clear_new_reservation() {
    /// Hook that gives every `on_demote` call its own release gate, so the
    /// test can let individual calls finish in a chosen order.
    #[derive(Default)]
    struct IndividuallyGatedHook {
        /// One `Notify` per `on_demote` call, in call order.
        releases: Mutex<Vec<Arc<tokio::sync::Notify>>>,
    }

    impl IndividuallyGatedHook {
        /// Number of `on_demote` calls made so far.
        fn call_count(&self) -> usize {
            self.releases.lock().unwrap().len()
        }

        /// Yields to the scheduler until at least `expected` calls have been
        /// made. There is no iteration cap.
        async fn wait_for_call_count(&self, expected: usize) {
            while self.call_count() < expected {
                tokio::task::yield_now().await;
            }
        }

        /// Lets the `index`-th hook call (zero-based) finish. Panics if that
        /// call has not happened yet.
        fn release_call(&self, index: usize) {
            let release = self.releases.lock().unwrap()[index].clone();
            release.notify_one();
        }
    }

    impl DemotionHook for IndividuallyGatedHook {
        fn on_demote<'a>(
            &'a self,
            _conversation_id: &'a str,
            _messages: Vec<Message>,
        ) -> WasmBoxedFuture<'a, Result<(), MemoryError>> {
            // The gate is registered when `on_demote` is called, before the
            // returned future is polled.
            let release = Arc::new(tokio::sync::Notify::new());
            self.releases.lock().unwrap().push(release.clone());
            Box::pin(async move {
                release.notified().await;
                Ok(())
            })
        }
    }

    let hook = Arc::new(IndividuallyGatedHook::default());
    let mem = Arc::new(DemotingPolicyMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        hook.clone(),
    ));

    mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
        .await
        .unwrap();

    // "Stale" load: reaches the hook (call 0) and parks on its gate.
    let mem_load = mem.clone();
    let stale = tokio::spawn(async move { mem_load.load("c").await });
    hook.wait_for_call_count(1).await;

    // Replace the conversation contents while the stale load is still parked.
    mem.clear("c").await.unwrap();
    mem.append(
        "c",
        vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
    )
    .await
    .unwrap();

    // "Fresh" load over the new contents: reaches the hook (call 1) and parks.
    let mem_load = mem.clone();
    let fresh = tokio::spawn(async move { mem_load.load("c").await });
    hook.wait_for_call_count(2).await;

    // Let only the stale load finish; the fresh one stays parked.
    hook.release_call(0);
    assert_eq!(stale.await.unwrap().unwrap().len(), 1);
    assert_eq!(hook.call_count(), 2);

    // A third load must finish on its own. If a third hook call is observed
    // first, the test fails.
    let mem_load = mem.clone();
    let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
    let hook_wait = hook.clone();
    let concurrent_kept = tokio::select! {
        result = &mut concurrent => result.unwrap().unwrap(),
        _ = hook_wait.wait_for_call_count(3) => {
            panic!("stale successful load must not clear the fresh in-flight reservation")
        }
    };
    assert_eq!(
        hook.call_count(),
        2,
        "stale successful load must not clear the fresh in-flight reservation"
    );

    // Now release the fresh load and check both results.
    hook.release_call(1);
    assert_eq!(fresh.await.unwrap().unwrap().len(), 1);
    assert_eq!(concurrent_kept.len(), 1);

    // With everything settled, another load must not call the hook again.
    mem.load("c").await.unwrap();
    assert_eq!(hook.call_count(), 2);
}
2377
2378 #[tokio::test]
2379 async fn forget_drops_in_process_watermark() {
2380 let hook = Arc::new(CountingHook::default());
2381 let mem = DemotingPolicyMemory::new(
2382 InMemoryConversationMemory::new(),
2383 SlidingWindowMemory::last_messages(1),
2384 hook.clone(),
2385 );
2386
2387 mem.append("c", vec![user("1"), assistant("2")])
2388 .await
2389 .unwrap();
2390 mem.load("c").await.unwrap();
2391 assert_eq!(mem.tracked_conversations(), 1);
2392 assert_eq!(hook.calls(), 1);
2393
2394 mem.forget("c");
2399 assert_eq!(mem.tracked_conversations(), 0);
2400 mem.load("c").await.unwrap();
2401 assert_eq!(hook.calls(), 2);
2402 }
2403
2404 #[tokio::test]
2409 async fn compacting_no_demotion_returns_kept_only() {
2410 let mem = CompactingMemory::new(
2411 InMemoryConversationMemory::new(),
2412 SlidingWindowMemory::last_messages(10),
2413 TemplateCompactor::new(),
2414 );
2415
2416 mem.append("c", vec![user("hi"), assistant("hello")])
2417 .await
2418 .unwrap();
2419 let loaded = mem.load("c").await.unwrap();
2420 assert_eq!(loaded.len(), 2);
2421 assert!(matches!(&loaded[0], Message::User { .. }));
2425 }
2426
2427 #[tokio::test]
2428 async fn compacting_splices_summary_when_demoted() {
2429 let mem = CompactingMemory::new(
2430 InMemoryConversationMemory::new(),
2431 SlidingWindowMemory::last_messages(2),
2432 TemplateCompactor::new(),
2433 );
2434
2435 mem.append(
2436 "c",
2437 vec![
2438 user("first"),
2439 assistant("second"),
2440 user("third"),
2441 assistant("fourth"),
2442 ],
2443 )
2444 .await
2445 .unwrap();
2446
2447 let loaded = mem.load("c").await.unwrap();
2448 assert_eq!(loaded.len(), 3);
2450 let Message::System { content } = &loaded[0] else {
2451 panic!("expected summary as system message");
2452 };
2453 assert!(content.contains("[Conversation summary so far]"));
2454 assert!(content.contains("user: first"));
2455 assert!(content.contains("assistant: second"));
2456 let Message::User { content } = &loaded[1] else {
2458 panic!("expected kept user message");
2459 };
2460 let UserContent::Text(t) = content.first_ref() else {
2461 panic!("expected text");
2462 };
2463 assert_eq!(t.text, "third");
2464 }
2465
2466 #[tokio::test]
2467 async fn compacting_rolls_summary_forward() {
2468 let mem = CompactingMemory::new(
2469 InMemoryConversationMemory::new(),
2470 SlidingWindowMemory::last_messages(2),
2471 TemplateCompactor::new(),
2472 );
2473
2474 mem.append(
2475 "c",
2476 vec![user("a"), assistant("b"), user("c"), assistant("d")],
2477 )
2478 .await
2479 .unwrap();
2480
2481 let first = mem.load("c").await.unwrap();
2482 let Message::System { content } = &first[0] else {
2483 panic!("summary missing");
2484 };
2485 let first_summary = content.clone();
2486 assert!(first_summary.contains("user: a"));
2487 assert!(first_summary.contains("assistant: b"));
2488
2489 mem.append("c", vec![user("e"), assistant("f")])
2492 .await
2493 .unwrap();
2494 let second = mem.load("c").await.unwrap();
2495 let Message::System { content } = &second[0] else {
2496 panic!("summary missing");
2497 };
2498 assert!(content.contains(&first_summary));
2501 assert!(content.contains("user: c"));
2502 assert!(content.contains("assistant: d"));
2503 }
2504
2505 #[tokio::test]
2506 async fn compacting_idempotent_within_process() {
2507 let mem = CompactingMemory::new(
2512 InMemoryConversationMemory::new(),
2513 SlidingWindowMemory::last_messages(1),
2514 TemplateCompactor::new(),
2515 );
2516 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2517 .await
2518 .unwrap();
2519
2520 let first = mem.load("c").await.unwrap();
2521 let second = mem.load("c").await.unwrap();
2522 assert_eq!(first.len(), second.len());
2523 let Message::System { content: c1 } = &first[0] else {
2524 panic!()
2525 };
2526 let Message::System { content: c2 } = &second[0] else {
2527 panic!()
2528 };
2529 assert_eq!(c1, c2);
2530 }
2531
2532 #[tokio::test]
2533 async fn compacting_clear_drops_summary() {
2534 let mem = CompactingMemory::new(
2535 InMemoryConversationMemory::new(),
2536 SlidingWindowMemory::last_messages(1),
2537 TemplateCompactor::new(),
2538 );
2539 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2540 .await
2541 .unwrap();
2542 mem.load("c").await.unwrap();
2543 assert_eq!(mem.tracked_conversations(), 1);
2544
2545 mem.clear("c").await.unwrap();
2546 assert_eq!(mem.tracked_conversations(), 0);
2547 assert!(mem.load("c").await.unwrap().is_empty());
2548 }
2549
2550 #[derive(Default)]
2553 struct FlakyCompactor {
2554 calls: std::sync::atomic::AtomicUsize,
2555 }
2556
2557 impl Compactor for FlakyCompactor {
2558 type Artifact = TextSummary;
2559
2560 fn compact<'a>(
2561 &'a self,
2562 _conversation_id: &'a str,
2563 evicted: &'a [Message],
2564 _carry_over: Option<&'a Self::Artifact>,
2565 ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
2566 Box::pin(async move {
2567 let n = self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2568 if n == 0 {
2569 Err(MemoryError::Policy("flaky".into()))
2570 } else {
2571 Ok(TextSummary(format!("compacted {} messages", evicted.len())))
2572 }
2573 })
2574 }
2575 }
2576
2577 #[tokio::test]
2578 async fn compacting_failure_does_not_advance_watermark() {
2579 let mem = CompactingMemory::new(
2580 InMemoryConversationMemory::new(),
2581 SlidingWindowMemory::last_messages(1),
2582 FlakyCompactor::default(),
2583 );
2584 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2585 .await
2586 .unwrap();
2587
2588 let err = mem.load("c").await.unwrap_err();
2589 assert!(matches!(err, MemoryError::Policy(_)));
2590
2591 let loaded = mem.load("c").await.unwrap();
2593 assert_eq!(loaded.len(), 2);
2594 let Message::System { content } = &loaded[0] else {
2595 panic!("expected summary")
2596 };
2597 assert!(content.contains("compacted"));
2598 }
2599
2600 #[derive(Default)]
2603 struct CountingCompactor {
2604 log: Mutex<Vec<(usize, bool)>>,
2605 }
2606
2607 impl CountingCompactor {
2608 fn calls(&self) -> Vec<(usize, bool)> {
2609 self.log.lock().unwrap().clone()
2610 }
2611 }
2612
2613 impl Compactor for CountingCompactor {
2614 type Artifact = TextSummary;
2615
2616 fn compact<'a>(
2617 &'a self,
2618 _conversation_id: &'a str,
2619 evicted: &'a [Message],
2620 carry_over: Option<&'a Self::Artifact>,
2621 ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
2622 Box::pin(async move {
2623 self.log
2624 .lock()
2625 .unwrap()
2626 .push((evicted.len(), carry_over.is_some()));
2627 let prev = carry_over.map(|s| s.as_str()).unwrap_or("");
2628 Ok(TextSummary(format!("{prev}|{}", evicted.len())))
2629 })
2630 }
2631 }
2632
2633 #[tokio::test]
2634 async fn compacting_no_demotion_does_not_invoke_compactor() {
2635 let compactor = Arc::new(CountingCompactor::default());
2636 let mem = CompactingMemory::new(
2637 InMemoryConversationMemory::new(),
2638 SlidingWindowMemory::last_messages(10),
2639 compactor.clone(),
2640 );
2641
2642 mem.append("c", vec![user("a"), assistant("b")])
2643 .await
2644 .unwrap();
2645 mem.load("c").await.unwrap();
2646 mem.load("c").await.unwrap();
2647 mem.load("c").await.unwrap();
2648 assert!(compactor.calls().is_empty());
2649 assert_eq!(mem.tracked_conversations(), 0);
2651 }
2652
2653 #[tokio::test]
2654 async fn compacting_invokes_compactor_only_on_new_demotions() {
2655 let compactor = Arc::new(CountingCompactor::default());
2656 let mem = CompactingMemory::new(
2657 InMemoryConversationMemory::new(),
2658 SlidingWindowMemory::last_messages(2),
2659 compactor.clone(),
2660 );
2661
2662 mem.append(
2664 "c",
2665 vec![user("a"), assistant("b"), user("c"), assistant("d")],
2666 )
2667 .await
2668 .unwrap();
2669 mem.load("c").await.unwrap();
2670 mem.load("c").await.unwrap();
2672 mem.load("c").await.unwrap();
2673 let calls = compactor.calls();
2674 assert_eq!(
2675 calls.len(),
2676 1,
2677 "compactor invoked more than once: {calls:?}"
2678 );
2679 assert_eq!(calls[0], (2, false));
2680
2681 mem.append("c", vec![user("e"), assistant("f")])
2684 .await
2685 .unwrap();
2686 mem.load("c").await.unwrap();
2687 mem.load("c").await.unwrap();
2688 let calls = compactor.calls();
2689 assert_eq!(calls.len(), 2, "expected exactly one new call: {calls:?}");
2690 assert_eq!(calls[1], (2, true));
2693 }
2694
2695 #[tokio::test]
2696 async fn compacting_serialises_concurrent_loads() {
2697 let compactor = Arc::new(CountingCompactor::default());
2700 let mem = Arc::new(CompactingMemory::new(
2701 InMemoryConversationMemory::new(),
2702 SlidingWindowMemory::last_messages(2),
2703 compactor.clone(),
2704 ));
2705 mem.append(
2706 "c",
2707 vec![user("a"), assistant("b"), user("c"), assistant("d")],
2708 )
2709 .await
2710 .unwrap();
2711
2712 let mut handles = Vec::new();
2713 for _ in 0..32 {
2714 let mem = mem.clone();
2715 handles.push(tokio::spawn(async move {
2716 mem.load("c").await.unwrap();
2717 }));
2718 }
2719 for h in handles {
2720 h.await.unwrap();
2721 }
2722
2723 let calls = compactor.calls();
2726 assert_eq!(calls.len(), 1, "expected exactly 1 call: {calls:?}");
2727 }
2728
2729 #[tokio::test]
2730 async fn compacting_clear_drops_summary_carry_over() {
2731 let compactor = Arc::new(CountingCompactor::default());
2735 let mem = CompactingMemory::new(
2736 InMemoryConversationMemory::new(),
2737 SlidingWindowMemory::last_messages(1),
2738 compactor.clone(),
2739 );
2740 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2741 .await
2742 .unwrap();
2743 mem.load("c").await.unwrap();
2744 assert_eq!(compactor.calls()[0], (2, false));
2745
2746 mem.clear("c").await.unwrap();
2747 assert_eq!(mem.tracked_conversations(), 0);
2748
2749 mem.append("c", vec![user("x"), assistant("y"), user("z")])
2750 .await
2751 .unwrap();
2752 mem.load("c").await.unwrap();
2753 let calls = compactor.calls();
2754 assert_eq!(calls.len(), 2);
2755 assert_eq!(calls[1], (2, false));
2757 }
2758
2759 #[tokio::test]
2760 async fn compacting_forget_drops_summary() {
2761 let compactor = Arc::new(CountingCompactor::default());
2762 let mem = CompactingMemory::new(
2763 InMemoryConversationMemory::new(),
2764 SlidingWindowMemory::last_messages(1),
2765 compactor.clone(),
2766 );
2767 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2768 .await
2769 .unwrap();
2770 mem.load("c").await.unwrap();
2771 assert_eq!(mem.tracked_conversations(), 1);
2772 mem.forget("c");
2773 assert_eq!(mem.tracked_conversations(), 0);
2774
2775 mem.load("c").await.unwrap();
2778 let calls = compactor.calls();
2779 assert_eq!(calls.len(), 2);
2780 assert_eq!(calls[1], (2, false));
2781 }
2782
2783 #[tokio::test]
2784 async fn compacting_arc_compactor_works() {
2785 let compactor: Arc<dyn Compactor<Artifact = TextSummary>> =
2788 Arc::new(TemplateCompactor::new());
2789 let mem = CompactingMemory::new(
2790 InMemoryConversationMemory::new(),
2791 SlidingWindowMemory::last_messages(1),
2792 compactor,
2793 );
2794 mem.append("c", vec![user("a"), assistant("b"), user("c")])
2795 .await
2796 .unwrap();
2797 let loaded = mem.load("c").await.unwrap();
2798 assert_eq!(loaded.len(), 2);
2799 assert!(matches!(&loaded[0], Message::System { .. }));
2800 }
2801
2802 #[tokio::test]
2803 async fn compacting_into_inner_returns_components() {
2804 let mem = CompactingMemory::new(
2805 InMemoryConversationMemory::new(),
2806 SlidingWindowMemory::last_messages(1),
2807 TemplateCompactor::new(),
2808 );
2809 let (_inner, _policy, _compactor) = mem.into_inner();
2810 }
2811
2812 #[tokio::test]
2813 async fn compacting_isolates_conversations() {
2814 let compactor = Arc::new(CountingCompactor::default());
2815 let mem = CompactingMemory::new(
2816 InMemoryConversationMemory::new(),
2817 SlidingWindowMemory::last_messages(1),
2818 compactor.clone(),
2819 );
2820 mem.append("a", vec![user("a1"), assistant("a2"), user("a3")])
2821 .await
2822 .unwrap();
2823 mem.append("b", vec![user("b1"), assistant("b2"), user("b3")])
2824 .await
2825 .unwrap();
2826
2827 let a = mem.load("a").await.unwrap();
2828 let b = mem.load("b").await.unwrap();
2829 assert_eq!(a.len(), 2);
2831 assert_eq!(b.len(), 2);
2832 assert_eq!(compactor.calls().len(), 2);
2833 assert_eq!(mem.tracked_conversations(), 2);
2834 }
2835
2836 #[tokio::test]
2837 async fn compacting_composes_with_token_window() {
2838 let mem = CompactingMemory::new(
2841 InMemoryConversationMemory::new(),
2842 TokenWindowMemory::new(30, HeuristicTokenCounter::openai()),
2843 TemplateCompactor::new(),
2844 );
2845 mem.append(
2846 "c",
2847 vec![
2848 user("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
2849 assistant("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"),
2850 user("cccccccccccccccccccc"),
2851 assistant("d"),
2852 ],
2853 )
2854 .await
2855 .unwrap();
2856 let loaded = mem.load("c").await.unwrap();
2857 assert!(loaded.len() >= 2);
2859 assert!(matches!(&loaded[0], Message::System { .. }));
2860 }
2861
2862 #[tokio::test]
2863 async fn template_compactor_renders_system_messages() {
2864 let compactor = TemplateCompactor::new();
2865 let evicted = vec![
2866 Message::System {
2867 content: "you are helpful".into(),
2868 },
2869 user("hi"),
2870 assistant("hello"),
2871 ];
2872 let summary = compactor.compact("c", &evicted, None).await.unwrap();
2873 let s = summary.as_str();
2874 assert!(s.contains("system: you are helpful"), "got: {s}");
2875 assert!(s.contains("user: hi"));
2876 assert!(s.contains("assistant: hello"));
2877 }
2878
2879 #[tokio::test]
2880 async fn template_compactor_renders_tool_call_marker() {
2881 let compactor = TemplateCompactor::new();
2882 let evicted = vec![tool_call_msg(), tool_result_msg()];
2883 let summary = compactor.compact("c", &evicted, None).await.unwrap();
2884 let s = summary.as_str();
2885 assert!(s.contains("[tool call: t]"), "got: {s}");
2886 assert!(s.contains("[tool result]"), "got: {s}");
2887 }
2888
2889 #[tokio::test]
2890 async fn template_compactor_carry_over_threaded() {
2891 let compactor = TemplateCompactor::new();
2892 let first = compactor
2893 .compact("c", &[user("hello")], None)
2894 .await
2895 .unwrap();
2896 assert!(!first.as_str().is_empty());
2897
2898 let second = compactor
2899 .compact("c", &[assistant("world")], Some(&first))
2900 .await
2901 .unwrap();
2902 assert!(second.as_str().contains(first.as_str()));
2904 assert!(second.as_str().contains("assistant: world"));
2905 }
2906
/// Converting a `TextSummary` into a `Message` yields a system message that
/// carries the summary text unchanged.
///
/// The conversion is synchronous and nothing here is awaited, so this is a
/// plain `#[test]` rather than an async test that would start a Tokio runtime
/// for no reason.
#[test]
fn template_compactor_artifact_into_message() {
    let summary = TextSummary("rolled-up text".into());
    let message: Message = summary.into();
    let Message::System { content } = message else {
        panic!("expected system message");
    };
    assert_eq!(content, "rolled-up text");
}
2916
2917 #[tokio::test]
2918 async fn template_compactor_caps_summary_at_max_bytes() {
2919 let cap = 256;
2920 let compactor = TemplateCompactor::new().with_max_bytes(cap);
2921 let mut evicted = Vec::new();
2923 for i in 0..50 {
2924 evicted.push(user(&format!("message number {i} with some filler")));
2925 }
2926 let summary = compactor.compact("c", &evicted, None).await.unwrap();
2927 assert!(
2928 summary.as_str().len()
2929 <= cap + "[Conversation summary so far]\n[\u{2026}truncated\u{2026}]\n".len(),
2930 "summary len {} exceeds cap {} (plus header+marker)",
2931 summary.as_str().len(),
2932 cap,
2933 );
2934 assert!(
2936 summary
2937 .as_str()
2938 .starts_with("[Conversation summary so far]\n")
2939 );
2940 assert!(summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
2942 assert!(summary.as_str().contains("message number 49"));
2944 }
2945
2946 #[tokio::test]
2947 async fn template_compactor_unbounded_by_default() {
2948 let compactor = TemplateCompactor::new();
2949 let mut evicted = Vec::new();
2950 for i in 0..200 {
2951 evicted.push(user(&format!("msg {i}")));
2952 }
2953 let summary = compactor.compact("c", &evicted, None).await.unwrap();
2954 assert!(!summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
2956 assert!(summary.as_str().contains("msg 0"));
2958 assert!(summary.as_str().contains("msg 199"));
2959 }
2960
2961 #[tokio::test]
2962 async fn template_compactor_with_max_bytes_zero_is_unbounded() {
2963 let compactor = TemplateCompactor::new().with_max_bytes(0);
2964 let mut evicted = Vec::new();
2965 for i in 0..200 {
2966 evicted.push(user(&format!("msg {i}")));
2967 }
2968 let summary = compactor.compact("c", &evicted, None).await.unwrap();
2969 assert!(!summary.as_str().contains("[\u{2026}truncated\u{2026}]"));
2970 }
2971
2972 #[tokio::test]
2973 async fn compacting_summary_stays_bounded_across_rolls() {
2974 let cap = 512;
2977 let mem = CompactingMemory::new(
2978 InMemoryConversationMemory::new(),
2979 SlidingWindowMemory::last_messages(2),
2980 TemplateCompactor::new().with_max_bytes(cap),
2981 );
2982 mem.append("c", vec![user("seed-a"), assistant("seed-b")])
2983 .await
2984 .unwrap();
2985 for i in 0..30 {
2986 mem.append(
2987 "c",
2988 vec![
2989 user(&format!("user line {i} ----- padding padding padding")),
2990 assistant(&format!("assistant line {i} ----- padding padding")),
2991 ],
2992 )
2993 .await
2994 .unwrap();
2995 mem.load("c").await.unwrap();
2996 }
2997 let loaded = mem.load("c").await.unwrap();
2998 let Message::System { content } = &loaded[0] else {
2999 panic!("expected summary");
3000 };
3001 let slack = "[Conversation summary so far]\n[\u{2026}truncated\u{2026}]\n".len();
3003 assert!(
3004 content.len() <= cap + slack,
3005 "summary grew to {} bytes (cap {}, slack {})",
3006 content.len(),
3007 cap,
3008 slack,
3009 );
3010 }
3011
/// Races a `load` that is parked inside the compactor against a `clear` of the
/// same conversation, then checks that once the parked compaction is released
/// the memory tracks no conversations and a fresh `load` returns nothing.
#[tokio::test]
async fn compacting_concurrent_with_clear_does_not_resurrect_state() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Compactor that flags entry and then waits on `release`, so the test
    // controls when the compaction completes.
    struct GatedCompactor {
        release: tokio::sync::Notify,
        entered: AtomicBool,
    }

    impl Compactor for GatedCompactor {
        type Artifact = TextSummary;

        fn compact<'a>(
            &'a self,
            _conversation_id: &'a str,
            _evicted: &'a [Message],
            _carry_over: Option<&'a Self::Artifact>,
        ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
            Box::pin(async move {
                // Signal that compaction has started, then park until released.
                self.entered.store(true, Ordering::SeqCst);
                self.release.notified().await;
                Ok(TextSummary("late summary".into()))
            })
        }
    }

    let compactor = Arc::new(GatedCompactor {
        release: tokio::sync::Notify::new(),
        entered: AtomicBool::new(false),
    });
    let mem = Arc::new(CompactingMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        compactor.clone(),
    ));
    // Three messages against a window of one.
    mem.append("c", vec![user("a"), assistant("b"), user("c")])
        .await
        .unwrap();

    let mem_load = mem.clone();
    // Run the load on its own task so it can sit inside the gated compactor
    // while this task continues.
    let load_handle = tokio::spawn(async move { mem_load.load("c").await });

    // Yield until the spawned load has entered `compact`.
    while !compactor.entered.load(Ordering::SeqCst) {
        tokio::task::yield_now().await;
    }

    // Clear while the compaction is still parked on `release`.
    mem.clear("c").await.unwrap();

    // Let the parked compaction finish. The load's own result is ignored; the
    // `unwrap` only requires that the task itself did not panic.
    compactor.release.notify_one();
    let _ = load_handle.await.unwrap();

    // The late compaction must not have brought the cleared conversation back.
    assert_eq!(mem.tracked_conversations(), 0);
    assert!(mem.load("c").await.unwrap().is_empty());
}
3074
/// Aborts a `load` while it is parked inside the compactor, then checks that a
/// second `load` of the same conversation reaches the compactor again and
/// returns the summary it produces.
#[tokio::test]
async fn compacting_dropped_load_releases_in_flight_gate() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Compactor that counts entries and then waits on `release`.
    struct GatedCompactor {
        release: tokio::sync::Notify,
        entered: AtomicUsize,
    }

    impl Compactor for GatedCompactor {
        type Artifact = TextSummary;

        fn compact<'a>(
            &'a self,
            _conversation_id: &'a str,
            _evicted: &'a [Message],
            _carry_over: Option<&'a Self::Artifact>,
        ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
            Box::pin(async move {
                // Count this entry, then park until released.
                self.entered.fetch_add(1, Ordering::SeqCst);
                self.release.notified().await;
                Ok(TextSummary("ran".into()))
            })
        }
    }

    let compactor = Arc::new(GatedCompactor {
        release: tokio::sync::Notify::new(),
        entered: AtomicUsize::new(0),
    });
    let mem = Arc::new(CompactingMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        compactor.clone(),
    ));
    mem.append("c", vec![user("a"), assistant("b"), user("c")])
        .await
        .unwrap();

    let mem_load = mem.clone();
    // First load: wait until it is inside `compact`, then abort it there.
    let handle = tokio::spawn(async move { mem_load.load("c").await });
    while compactor.entered.load(Ordering::SeqCst) == 0 {
        tokio::task::yield_now().await;
    }
    handle.abort();
    // Wait for the aborted task to finish; its outcome is deliberately ignored.
    let _ = handle.await;

    let mem_load = mem.clone();
    // Second load: the loop below only ends once `compact` has been entered a
    // second time, i.e. the retry was not held back by the aborted load.
    let retry = tokio::spawn(async move { mem_load.load("c").await });
    while compactor.entered.load(Ordering::SeqCst) < 2 {
        tokio::task::yield_now().await;
    }
    compactor.release.notify_one();
    let loaded = retry.await.unwrap().unwrap();
    // Two messages expected, the first being the system summary "ran".
    assert_eq!(loaded.len(), 2);
    let Message::System { content } = &loaded[0] else {
        panic!("expected summary")
    };
    assert_eq!(content, "ran");
}
3148
/// Starts a load, clears and refills the conversation, starts a second load,
/// and only then aborts the first one. A third load issued afterwards must
/// return without entering the compactor again.
#[tokio::test]
async fn compacting_stale_cancelled_load_does_not_clear_new_reservation() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Compactor that counts entries, pings `rendezvous` on each entry so the
    // test can wait for it, and then parks on `release`.
    struct GatedCompactor {
        release: tokio::sync::Notify,
        rendezvous: tokio::sync::Notify,
        entered: AtomicUsize,
    }

    impl Compactor for GatedCompactor {
        type Artifact = TextSummary;

        fn compact<'a>(
            &'a self,
            _conversation_id: &'a str,
            _evicted: &'a [Message],
            _carry_over: Option<&'a Self::Artifact>,
        ) -> WasmBoxedFuture<'a, Result<Self::Artifact, MemoryError>> {
            Box::pin(async move {
                self.entered.fetch_add(1, Ordering::SeqCst);
                // Tell the test that a compaction has started.
                self.rendezvous.notify_one();
                self.release.notified().await;
                Ok(TextSummary("ran".into()))
            })
        }
    }

    let compactor = Arc::new(GatedCompactor {
        release: tokio::sync::Notify::new(),
        rendezvous: tokio::sync::Notify::new(),
        entered: AtomicUsize::new(0),
    });
    let mem = Arc::new(CompactingMemory::new(
        InMemoryConversationMemory::new(),
        SlidingWindowMemory::last_messages(1),
        compactor.clone(),
    ));

    mem.append("c", vec![user("old 1"), assistant("old 2"), user("old 3")])
        .await
        .unwrap();

    // Stale load: wait until it is parked inside `compact`.
    let mem_load = mem.clone();
    let stale = tokio::spawn(async move { mem_load.load("c").await });
    compactor.rendezvous.notified().await;
    assert_eq!(compactor.entered.load(Ordering::SeqCst), 1);

    // Replace the conversation contents while the stale load is still parked.
    mem.clear("c").await.unwrap();
    mem.append(
        "c",
        vec![user("fresh 1"), assistant("fresh 2"), user("fresh 3")],
    )
    .await
    .unwrap();

    // Fresh load: wait until it, too, is parked inside `compact`.
    let mem_load = mem.clone();
    let fresh = tokio::spawn(async move { mem_load.load("c").await });
    compactor.rendezvous.notified().await;
    assert_eq!(compactor.entered.load(Ordering::SeqCst), 2);

    // Cancel the stale load only now, with the fresh one still in flight.
    stale.abort();
    let _ = stale.await;

    // A third load must complete on its own. If `rendezvous` fires instead, a
    // third compaction was started, which is the failure this test guards.
    let mem_load = mem.clone();
    let mut concurrent = tokio::spawn(async move { mem_load.load("c").await });
    let concurrent_kept = tokio::select! {
        result = &mut concurrent => result.unwrap().unwrap(),
        _ = compactor.rendezvous.notified() => {
            panic!("stale guard must not clear the fresh in-flight reservation")
        }
    };
    assert_eq!(
        compactor.entered.load(Ordering::SeqCst),
        2,
        "stale guard must not clear the fresh in-flight reservation"
    );

    // Release the fresh compaction. The fresh load yields two messages, the
    // concurrent load yielded one, and the compactor ran exactly twice overall.
    compactor.release.notify_one();
    assert_eq!(fresh.await.unwrap().unwrap().len(), 2);
    assert_eq!(concurrent_kept.len(), 1);
    assert_eq!(compactor.entered.load(Ordering::SeqCst), 2);
}
3232
3233 #[tokio::test]
3234 async fn template_compactor_caps_summary_with_multiline_header() {
3235 let cap = 256;
3240 let compactor = TemplateCompactor::with_header("line one\nline two").with_max_bytes(cap);
3241 let mut evicted = Vec::new();
3242 for i in 0..50 {
3243 evicted.push(user(&format!("message number {i} with some filler")));
3244 }
3245 let summary = compactor.compact("c", &evicted, None).await.unwrap();
3246 let text = summary.as_str();
3247
3248 assert!(text.starts_with("line one\n"));
3250 assert!(text.contains("[\u{2026}truncated\u{2026}]"));
3252 assert!(text.contains("message number 49"));
3253 let overhead = "line one\n".len() + "[\u{2026}truncated\u{2026}]\n".len();
3255 assert!(
3256 text.len() <= cap + overhead,
3257 "summary len {} exceeds cap {} plus overhead {}",
3258 text.len(),
3259 cap,
3260 overhead,
3261 );
3262 }
3263}