use std::sync::Arc;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use zeph_config::MemCotConfig;
use zeph_llm::provider::LlmProvider as _;
use super::metrics;
/// Mutable semantic-state record guarded by the accumulator's `RwLock`.
///
/// `Default` yields an empty buffer and zeroed counters.
#[derive(Default)]
pub(crate) struct SemanticState {
/// Latest distilled summary text. An empty buffer is reported as "no state"
/// by `SemanticStateAccumulator::current_state`.
pub(crate) buffer: String,
/// Incremented (saturating) each time a distill task writes a new buffer.
pub(crate) turn_count: u64,
/// Unix time in seconds (from `unix_now_secs`) of the last buffer write.
pub(crate) updated_at_secs: i64,
}
/// Gates and schedules background distill tasks that maintain a short
/// semantic-state summary of assistant output.
pub struct SemanticStateAccumulator {
// Shared configuration; read for the enable flag, gating thresholds,
// timeout and state-size cap.
cfg: Arc<MemCotConfig>,
// Shared state record; cloned into every distill task so the task can
// read the prior buffer and write the new one.
state: Arc<RwLock<SemanticState>>,
/// Unix time in seconds at which a distill was last enqueued; `0` after
/// construction and after `reset_session_counters`.
pub(crate) last_distill_at_secs: Arc<AtomicI64>,
/// Number of distills enqueued since construction or the last
/// `reset_session_counters`; compared against `max_distills_per_session`.
pub(crate) distill_count_session: Arc<AtomicU64>,
}
impl SemanticStateAccumulator {
    /// Creates an accumulator with an empty state buffer and zeroed session
    /// counters.
    #[must_use]
    pub fn new(cfg: Arc<MemCotConfig>) -> Self {
        Self {
            cfg,
            state: Arc::new(RwLock::new(SemanticState::default())),
            last_distill_at_secs: Arc::new(AtomicI64::new(0)),
            distill_count_session: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns a copy of the current semantic-state buffer, or `None` when the
    /// buffer is empty.
    pub async fn current_state(&self) -> Option<String> {
        let guard = self.state.read().await;
        (!guard.buffer.is_empty()).then(|| guard.buffer.clone())
    }

    /// Zeroes the per-session distill counter and last-distill timestamp, then
    /// clears the state buffer.
    ///
    /// Only `buffer` is cleared; `turn_count` and `updated_at_secs` are left
    /// untouched.
    pub async fn reset_session_counters(&self) {
        self.distill_count_session.store(0, Ordering::Relaxed);
        self.last_distill_at_secs.store(0, Ordering::Relaxed);
        self.state.write().await.buffer.clear();
    }

    /// Enqueues a background distill of `assistant_content` if every gate passes.
    ///
    /// Gates, in order (the first failing gate returns without spawning):
    /// 1. `cfg.enabled` must be true.
    /// 2. `assistant_content` must have at least `cfg.min_assistant_chars` chars.
    /// 3. At least `cfg.min_distill_interval_secs` seconds must have elapsed since
    ///    the last enqueued distill (records an `"interval"` skip metric otherwise).
    /// 4. The session distill count must be below `cfg.max_distills_per_session`
    ///    (records a `"session_cap"` skip metric otherwise).
    ///
    /// When all gates pass, the timestamp and counter are updated immediately
    /// (before the task runs) and the task is handed to `supervisor_spawn`
    /// under the name `"memcot_distill"`. The task asks `provider` for a new
    /// summary, bounded by `cfg.distill_timeout_secs`, and on success stores the
    /// trimmed response, truncated to `cfg.max_state_chars` chars, as the new
    /// state buffer. Provider errors and timeouts are logged and counted, and
    /// leave the buffer unchanged.
    pub(crate) fn maybe_enqueue_distill(
        &self,
        assistant_content: &str,
        provider: zeph_llm::any::AnyProvider,
        supervisor_spawn: impl FnOnce(
            &'static str,
            std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>,
        ),
    ) {
        if !self.cfg.enabled {
            return;
        }
        let content_chars = assistant_content.chars().count();
        if content_chars < self.cfg.min_assistant_chars {
            return;
        }
        let now = unix_now_secs();
        let elapsed = now.saturating_sub(self.last_distill_at_secs.load(Ordering::Relaxed));
        if elapsed < i64::try_from(self.cfg.min_distill_interval_secs).unwrap_or(i64::MAX) {
            metrics::distill_skipped("interval");
            return;
        }
        if self.distill_count_session.load(Ordering::Relaxed) >= self.cfg.max_distills_per_session {
            metrics::distill_skipped("session_cap");
            return;
        }
        // NOTE(review): the gate checks above and the updates below are separate
        // atomic operations, so concurrent callers could both pass the gates.
        self.last_distill_at_secs.store(now, Ordering::Relaxed);
        self.distill_count_session.fetch_add(1, Ordering::Relaxed);
        metrics::distill_total();

        let state_arc = Arc::clone(&self.state);
        let cfg = Arc::clone(&self.cfg);
        let content = assistant_content.to_owned();
        let fut = Box::pin(async move {
            // The span is created when the task first runs, then attached with
            // `Instrument` so it is entered only while the inner future is being
            // polled. Holding a `Span::enter()` guard across `.await` would keep
            // the span entered on the worker thread while this task is suspended.
            let span = tracing::info_span!("core.memcot.distill", result = tracing::field::Empty);
            // Separate handle so the outcome is recorded on this exact span.
            let outcome_span = span.clone();
            let work = async move {
                let prompt = build_distill_prompt(&content, &state_arc).await;
                let msgs = vec![zeph_llm::provider::Message::from_legacy(
                    zeph_llm::provider::Role::User,
                    prompt,
                )];
                let timeout = Duration::from_secs(cfg.distill_timeout_secs);
                let result = tokio::time::timeout(timeout, provider.chat(&msgs)).await;
                match result {
                    Ok(Ok(response)) => {
                        outcome_span.record("result", "ok");
                        let mut new_buf = response.trim().to_owned();
                        // Byte offset of the first char past the cap, if any.
                        // `char_indices` only yields char boundaries, so the
                        // truncation can never split a code point.
                        let cut = new_buf
                            .char_indices()
                            .nth(cfg.max_state_chars)
                            .map(|(byte_idx, _)| byte_idx);
                        if let Some(cut) = cut {
                            new_buf.truncate(cut);
                        }
                        let mut state = state_arc.write().await;
                        state.buffer = new_buf;
                        state.turn_count = state.turn_count.saturating_add(1);
                        state.updated_at_secs = unix_now_secs();
                        tracing::debug!(
                            turn = state.turn_count,
                            buf_chars = state.buffer.chars().count(),
                            "memcot: distill complete"
                        );
                    }
                    Ok(Err(e)) => {
                        outcome_span.record("result", "error");
                        metrics::distill_error();
                        tracing::warn!(error = %e, "memcot: distill failed");
                    }
                    Err(_) => {
                        outcome_span.record("result", "timeout");
                        metrics::distill_timeout();
                        tracing::warn!(
                            timeout_secs = cfg.distill_timeout_secs,
                            "memcot: distill timed out"
                        );
                    }
                }
            };
            tracing::Instrument::instrument(work, span).await;
        });
        supervisor_spawn("memcot_distill", fut);
    }
}
/// Returns a copy of `s` in which every line break (`\n`, `\r`) and angle
/// bracket (`<`, `>`) is replaced by a single space, so the text cannot
/// introduce new lines or tag-like markup into the prompt it is embedded in.
fn scrub_content(s: &str) -> String {
    s.chars()
        .map(|ch| match ch {
            '\n' | '\r' | '<' | '>' => ' ',
            other => other,
        })
        .collect()
}
/// Builds the prompt sent to the provider for one distill pass.
///
/// The prior state buffer is copied out under a read lock that is released
/// before any formatting. With no prior state, the prompt asks for a fresh
/// 1-3 sentence summary of the response; otherwise it asks for the prior state
/// to be updated with the new response. Both the prior state and the response
/// are passed through `scrub_content` before being embedded.
async fn build_distill_prompt(assistant_content: &str, state: &RwLock<SemanticState>) -> String {
    // The temporary read guard is dropped at the end of this statement.
    let prior = state.read().await.buffer.clone();
    let safe_content = scrub_content(assistant_content);

    if !prior.is_empty() {
        let safe_prior = scrub_content(&prior);
        return format!(
            "Update the semantic state by integrating the new assistant response. \
             Keep the most important concepts from the prior state and add new ones. \
             Reply with 1-3 short sentences total.\n\n\
             Prior state: {safe_prior}\n\n\
             <turn>{safe_content}</turn>"
        );
    }

    format!(
        "Summarize the key concepts and entities from the following assistant response \
         in 1-3 short sentences. Focus on what changed or was decided.\n\n\
         <turn>{safe_content}</turn>"
    )
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`; a second count that does
/// not fit in `i64` saturates to `i64::MAX`.
fn unix_now_secs() -> i64 {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs());
    i64::try_from(secs).unwrap_or(i64::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider backed by the default mock; the gating tests never poll the
    /// distill future, so it is only needed to satisfy the signature.
    fn null_provider() -> zeph_llm::any::AnyProvider {
        let mock = zeph_llm::mock::MockProvider::default();
        zeph_llm::any::AnyProvider::Mock(mock)
    }

    /// Calls `maybe_enqueue_distill` with `content` and reports whether the
    /// spawn callback was invoked. The produced future is dropped unpolled.
    fn enqueue_and_report_spawn(acc: &SemanticStateAccumulator, content: &str) -> bool {
        let mut spawn_called = false;
        acc.maybe_enqueue_distill(content, null_provider(), |_name, _fut| {
            spawn_called = true;
        });
        spawn_called
    }

    /// A freshly built accumulator exposes no state.
    #[tokio::test]
    async fn accumulator_initial_state_empty() {
        let acc = SemanticStateAccumulator::new(Arc::new(MemCotConfig::default()));
        let state = acc.current_state().await;
        assert!(state.is_none());
    }

    /// Resetting zeroes both counters and empties the buffer.
    #[tokio::test]
    #[allow(clippy::large_futures)]
    async fn reset_session_counters_clears_state() {
        let cfg = Arc::new(MemCotConfig {
            enabled: true,
            ..MemCotConfig::default()
        });
        let acc = SemanticStateAccumulator::new(cfg);

        // Seed non-default values so the reset has something to clear.
        acc.distill_count_session.store(42, Ordering::Relaxed);
        acc.last_distill_at_secs.store(9999, Ordering::Relaxed);
        {
            let mut state = acc.state.write().await;
            state.buffer = "prior semantic state".to_owned();
        }

        acc.reset_session_counters().await;

        assert_eq!(acc.distill_count_session.load(Ordering::Relaxed), 0);
        assert_eq!(acc.last_distill_at_secs.load(Ordering::Relaxed), 0);
        assert!(
            acc.current_state().await.is_none(),
            "buffer must be cleared on reset"
        );
    }

    /// Gate 1: a disabled config never spawns.
    #[tokio::test]
    async fn distill_skipped_when_disabled() {
        let cfg = MemCotConfig {
            enabled: false,
            ..MemCotConfig::default()
        };
        let acc = SemanticStateAccumulator::new(Arc::new(cfg));
        let spawn_called = enqueue_and_report_spawn(&acc, "hello");
        assert!(!spawn_called, "disabled accumulator must never spawn");
    }

    /// Gate 2: content shorter than `min_assistant_chars` is ignored.
    #[tokio::test]
    async fn distill_skipped_when_content_too_short() {
        let cfg = MemCotConfig {
            enabled: true,
            min_assistant_chars: 100,
            min_distill_interval_secs: 0,
            max_distills_per_session: 50,
            ..MemCotConfig::default()
        };
        let acc = SemanticStateAccumulator::new(Arc::new(cfg));
        let spawn_called = enqueue_and_report_spawn(&acc, "hi");
        assert!(!spawn_called, "should not spawn for short content");
    }

    /// Gate 4: once the session counter equals the cap, nothing is spawned.
    #[tokio::test]
    async fn distill_skipped_when_session_cap_reached() {
        let cfg = MemCotConfig {
            enabled: true,
            max_distills_per_session: 2,
            min_distill_interval_secs: 0,
            min_assistant_chars: 1,
            ..MemCotConfig::default()
        };
        let acc = SemanticStateAccumulator::new(Arc::new(cfg));
        acc.distill_count_session.store(2, Ordering::Relaxed);
        let spawn_called = enqueue_and_report_spawn(&acc, "hello world!");
        assert!(!spawn_called, "should not spawn when session cap reached");
    }

    /// Gate 3: a distill recorded "just now" blocks the next one.
    #[tokio::test]
    async fn distill_skipped_when_interval_not_elapsed() {
        let cfg = MemCotConfig {
            enabled: true,
            max_distills_per_session: 50,
            min_distill_interval_secs: 9999,
            min_assistant_chars: 1,
            ..MemCotConfig::default()
        };
        let acc = SemanticStateAccumulator::new(Arc::new(cfg));
        let just_now = unix_now_secs();
        acc.last_distill_at_secs.store(just_now, Ordering::Relaxed);
        let spawn_called = enqueue_and_report_spawn(&acc, "hello world!");
        assert!(!spawn_called, "should not spawn before interval elapses");
    }

    /// With every gate open, the task is spawned and the counter advances.
    #[tokio::test]
    async fn distill_spawned_when_all_gates_pass() {
        let cfg = MemCotConfig {
            enabled: true,
            max_distills_per_session: 50,
            min_distill_interval_secs: 0,
            min_assistant_chars: 1,
            ..MemCotConfig::default()
        };
        let acc = SemanticStateAccumulator::new(Arc::new(cfg));
        let spawn_called = enqueue_and_report_spawn(&acc, "hello world!");
        assert!(spawn_called, "should spawn when all gates pass");
        assert_eq!(acc.distill_count_session.load(Ordering::Relaxed), 1);
    }
}