use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use serde::{Deserialize, Serialize};
/// Well-known session id string `"default"`.
///
/// Not referenced in this file; presumably the fallback for callers that do not supply
/// their own session id — confirm against call sites.
pub const DEFAULT_SESSION_ID: &str = "default";

/// One question/answer exchange recorded in an investigation session.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InvestigationTurn {
    /// The question that was asked.
    pub question: String,
    /// The answer recorded for `question`.
    pub answer: String,
    /// Wall-clock time at which the turn was recorded.
    pub at: SystemTime,
}
/// Accumulated state of a single investigation session.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct InvestigationState {
    /// Identifier this session is stored under.
    pub session_id: String,
    /// Retained turns, oldest first.
    pub turns: Vec<InvestigationTurn>,
    /// Time the session was created.
    pub created_at: SystemTime,
    /// Time a turn was last recorded to this session.
    pub last_active: SystemTime,
    /// Store-wide sequence number of the most recent write to this session.
    /// Skipped by serde, so a deserialized state reads back as `0`.
    #[serde(skip)]
    last_seq: u64,
}

impl InvestigationState {
    /// Builds an empty session whose creation and last-activity times are both "now".
    fn new(session_id: String) -> Self {
        let created_at = SystemTime::now();
        Self {
            session_id,
            turns: Vec::new(),
            last_active: created_at,
            created_at,
            last_seq: 0,
        }
    }

    /// Number of turns currently retained for this session.
    pub fn turn_count(&self) -> usize {
        let Self { turns, .. } = self;
        turns.len()
    }
}
/// Bounds applied by a session store.
///
/// Every bound is disabled when set to zero.
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Maximum number of sessions kept at once; `0` means unlimited.
    pub max_sessions: usize,
    /// Maximum number of turns retained per session (oldest dropped first); `0` means unlimited.
    pub max_turns_per_session: usize,
    /// How long a session may sit idle before it is discarded; `Duration::ZERO` disables expiry.
    pub ttl: Duration,
}

impl Default for SessionConfig {
    /// 1 000 sessions, 100 turns per session, and a 24-hour idle TTL.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(86_400);
        Self {
            ttl: ONE_DAY,
            max_turns_per_session: 100,
            max_sessions: 1_000,
        }
    }
}
/// Mutex-guarded, in-memory store of investigation sessions keyed by session id.
///
/// Cloning is cheap: every clone shares the same underlying map and sequence counter.
#[derive(Clone)]
pub struct SessionStore {
    /// Session id -> session state, shared across clones.
    inner: Arc<Mutex<HashMap<String, InvestigationState>>>,
    /// Counter bumped once per recorded turn; orders sessions for eviction.
    seq: Arc<AtomicU64>,
    /// Bounds enforced by this store.
    config: SessionConfig,
}

impl Default for SessionStore {
    /// An empty store using the default bounds.
    fn default() -> Self {
        let config = SessionConfig::default();
        Self {
            config,
            seq: Arc::default(),
            inner: Arc::default(),
        }
    }
}
impl SessionStore {
    /// Creates an empty store with [`SessionConfig::default`] bounds.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty store enforcing the given bounds.
    pub fn with_config(config: SessionConfig) -> Self {
        Self {
            inner: Arc::new(Mutex::new(HashMap::new())),
            seq: Arc::new(AtomicU64::new(0)),
            config,
        }
    }

    /// Returns the bounds this store enforces.
    pub fn config(&self) -> &SessionConfig {
        &self.config
    }

    /// Appends a question/answer turn to `session_id`, creating the session if needed.
    ///
    /// Before the append, every session idle for longer than `config.ttl` is dropped
    /// (unless the TTL is zero), so an expired session starts over empty. After the
    /// append, the session's oldest turns beyond `max_turns_per_session` are discarded and,
    /// if more than `max_sessions` sessions are stored, the least recently recorded-to
    /// sessions other than this one are evicted. Zero-valued bounds are disabled.
    ///
    /// Returns `(index, count)`: the position of the new turn within the session's retained
    /// history, and the number of turns retained after the append (`index == count - 1`).
    pub fn record(&self, session_id: &str, question: &str, answer: &str) -> (usize, usize) {
        let mut guard = self.lock();
        let now = SystemTime::now();
        // Drawn while the map lock is held, so sequence order matches mutation order.
        let seq = self.seq.fetch_add(1, Ordering::Relaxed);
        self.prune_expired(&mut guard, now);

        let state = guard
            .entry(session_id.to_owned())
            .or_insert_with_key(|id| InvestigationState::new(id.clone()));
        state.turns.push(InvestigationTurn {
            question: question.to_owned(),
            answer: answer.to_owned(),
            at: now,
        });
        state.last_active = now;
        state.last_seq = seq;

        let cap = self.config.max_turns_per_session;
        if cap > 0 && state.turns.len() > cap {
            // Drop from the front so the most recent `cap` turns survive.
            let excess = state.turns.len() - cap;
            state.turns.drain(..excess);
        }
        let count = state.turns.len();

        self.evict_overflow(&mut guard, session_id);
        (count.saturating_sub(1), count)
    }

    /// Removes every session whose idle time exceeds `config.ttl` as of `now`.
    ///
    /// A zero TTL disables expiry. Sessions whose `last_active` lies after `now` (the wall
    /// clock stepped backwards) are kept.
    fn prune_expired(&self, guard: &mut HashMap<String, InvestigationState>, now: SystemTime) {
        let ttl = self.config.ttl;
        if ttl.is_zero() {
            return;
        }
        guard.retain(|_, state| {
            now.duration_since(state.last_active)
                .map(|idle| idle <= ttl)
                .unwrap_or(true)
        });
    }

    /// Evicts the least recently recorded-to sessions, never `keep`, until at most
    /// `max_sessions` remain. A zero `max_sessions` disables the cap.
    fn evict_overflow(&self, guard: &mut HashMap<String, InvestigationState>, keep: &str) {
        let max = self.config.max_sessions;
        if max == 0 || guard.len() <= max {
            return;
        }
        let overflow = guard.len() - max;

        // Each `record` call stamps a distinct sequence number on exactly one session, so
        // `last_seq` values are unique and an unstable sort still yields a deterministic
        // oldest-first order. Only the victims' ids are cloned.
        let mut candidates: Vec<(u64, &String)> = guard
            .iter()
            .filter(|(id, _)| id.as_str() != keep)
            .map(|(id, state)| (state.last_seq, id))
            .collect();
        candidates.sort_unstable_by_key(|&(seq, _)| seq);
        let victims: Vec<String> = candidates
            .into_iter()
            .take(overflow)
            .map(|(_, id)| id.clone())
            .collect();

        for id in victims {
            guard.remove(&id);
        }
    }

    /// Returns a copy of every retained turn for `session_id`, oldest first.
    ///
    /// Unknown or expired sessions yield an empty vector.
    pub fn history(&self, session_id: &str) -> Vec<InvestigationTurn> {
        self.live()
            .get(session_id)
            .map(|state| state.turns.clone())
            .unwrap_or_default()
    }

    /// Returns a copy of the last `n` retained turns for `session_id`, oldest first.
    ///
    /// Fewer than `n` turns are returned when the session holds fewer; unknown or expired
    /// sessions yield an empty vector.
    pub fn recent_history(&self, session_id: &str, n: usize) -> Vec<InvestigationTurn> {
        self.live()
            .get(session_id)
            .map(|state| {
                let start = state.turns.len().saturating_sub(n);
                state.turns.iter().skip(start).cloned().collect()
            })
            .unwrap_or_default()
    }

    /// Returns a copy of the full state of `session_id`, or `None` if it is unknown or expired.
    pub fn snapshot(&self, session_id: &str) -> Option<InvestigationState> {
        self.live().get(session_id).cloned()
    }

    /// Number of sessions currently stored, excluding any that have expired.
    pub fn session_count(&self) -> usize {
        self.live().len()
    }

    /// Whether `session_id` is stored and has not expired.
    pub fn contains(&self, session_id: &str) -> bool {
        self.live().contains_key(session_id)
    }

    /// Locks the session map and drops expired sessions first.
    ///
    /// This keeps readers consistent with what the next `record` would see. Without it, a
    /// caller could read an expired session's history as context and have the following
    /// `record` silently start that session over.
    fn live(&self) -> std::sync::MutexGuard<'_, HashMap<String, InvestigationState>> {
        let mut guard = self.lock();
        self.prune_expired(&mut guard, SystemTime::now());
        guard
    }

    /// Locks the session map without pruning.
    ///
    /// A poisoned mutex is recovered instead of propagated, so a panic in one caller does
    /// not disable the store for every other caller.
    fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, InvestigationState>> {
        self.inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;

    /// Moves a session's last-activity time into the past to simulate idleness without
    /// sleeping.
    fn backdate(store: &SessionStore, session_id: &str, by: Duration) {
        let mut guard = store.inner.lock().unwrap();
        guard.get_mut(session_id).unwrap().last_active -= by;
    }

    #[test]
    fn records_and_reads_back_turns() {
        let store = SessionStore::new();
        let (idx0, count0) = store.record("s1", "q1", "a1");
        assert_eq!((idx0, count0), (0, 1));
        let (idx1, count1) = store.record("s1", "q2", "a2");
        assert_eq!((idx1, count1), (1, 2));
        let history = store.history("s1");
        assert_eq!(history.len(), 2);
        assert_eq!(history.first().unwrap().question, "q1");
        assert_eq!(history.get(1).unwrap().answer, "a2");
    }

    #[test]
    fn sessions_are_isolated() {
        let store = SessionStore::new();
        store.record("a", "qa", "aa");
        store.record("b", "qb", "ab");
        assert_eq!(store.session_count(), 2);
        assert_eq!(store.history("a").len(), 1);
        assert_eq!(store.history("b").len(), 1);
        assert!(store.contains("a"));
        assert!(!store.contains("c"));
        assert!(store.history("c").is_empty());
    }

    #[test]
    fn recent_history_returns_only_last_n_in_order() {
        let store = SessionStore::new();
        for i in 0..10 {
            store.record("s", &format!("q{i}"), &format!("a{i}"));
        }
        let recent = store.recent_history("s", 3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent.first().unwrap().question, "q7");
        assert_eq!(recent.get(2).unwrap().question, "q9");
        let all = store.recent_history("s", 100);
        assert_eq!(all.len(), 10);
        assert!(store.recent_history("nope", 5).is_empty());
    }

    #[test]
    fn turns_per_session_are_capped_dropping_oldest() {
        let store = SessionStore::with_config(SessionConfig {
            max_turns_per_session: 2,
            ..SessionConfig::default()
        });
        store.record("s", "q1", "a1");
        store.record("s", "q2", "a2");
        let (idx, count) = store.record("s", "q3", "a3");
        assert_eq!(count, 2, "turn count must be capped");
        assert_eq!(idx, 1);
        let history = store.history("s");
        assert_eq!(history.len(), 2);
        assert_eq!(history.first().unwrap().question, "q2");
        assert_eq!(history.get(1).unwrap().question, "q3");
    }

    #[test]
    fn overflowing_sessions_evicts_least_recently_active() {
        let store = SessionStore::with_config(SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        });
        store.record("a", "qa", "aa");
        store.record("b", "qb", "ab");
        store.record("a", "qa2", "aa2");
        store.record("c", "qc", "ac");
        assert_eq!(store.session_count(), 2);
        assert!(store.contains("a"), "recently active session must survive");
        assert!(store.contains("c"), "newest session must survive");
        assert!(!store.contains("b"), "stalest session must be evicted");
    }

    #[test]
    fn zero_bounds_disable_eviction() {
        let store = SessionStore::with_config(SessionConfig {
            max_sessions: 0,
            max_turns_per_session: 0,
            ttl: Duration::ZERO,
        });
        for i in 0..50 {
            store.record(&format!("s{i}"), "q", "a");
        }
        for _ in 0..50 {
            store.record("s0", "q", "a");
        }
        assert_eq!(store.session_count(), 50, "max_sessions=0 disables the cap");
        assert_eq!(
            store.history("s0").len(),
            51,
            "max_turns_per_session=0 disables the cap"
        );
    }

    #[test]
    fn idle_sessions_are_pruned_on_record() {
        let store = SessionStore::with_config(SessionConfig {
            ttl: Duration::from_secs(60),
            ..SessionConfig::default()
        });
        store.record("stale", "q", "a");
        store.record("fresh", "q", "a");
        backdate(&store, "stale", Duration::from_secs(120));
        store.record("fresh", "q2", "a2");
        assert!(!store.contains("stale"), "session idle past ttl must be pruned");
        assert_eq!(store.session_count(), 1);
        assert_eq!(store.history("fresh").len(), 2);
    }

    #[test]
    fn expired_session_restarts_empty() {
        let store = SessionStore::with_config(SessionConfig {
            ttl: Duration::from_secs(60),
            ..SessionConfig::default()
        });
        store.record("s", "q1", "a1");
        store.record("s", "q2", "a2");
        backdate(&store, "s", Duration::from_secs(120));
        let (idx, count) = store.record("s", "q3", "a3");
        assert_eq!((idx, count), (0, 1), "expired session must start over");
        let history = store.history("s");
        assert_eq!(history.len(), 1);
        assert_eq!(history.first().unwrap().question, "q3");
    }

    #[test]
    fn zero_ttl_never_expires() {
        let store = SessionStore::with_config(SessionConfig {
            ttl: Duration::ZERO,
            ..SessionConfig::default()
        });
        store.record("s", "q1", "a1");
        backdate(&store, "s", Duration::from_secs(365 * 24 * 60 * 60));
        let (idx, count) = store.record("s", "q2", "a2");
        assert_eq!((idx, count), (1, 2), "ttl=0 disables expiry");
        assert!(store.contains("s"));
    }
}