1use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::{Arc, Mutex};
16use std::time::{Duration, SystemTime};
17
18use serde::{Deserialize, Serialize};
19
/// Well-known session id `"default"`.
///
/// Not referenced inside this module; presumably the id callers fall back to when
/// they do not supply their own (NOTE(review): confirm at call sites).
pub const DEFAULT_SESSION_ID: &str = "default";
22
/// One recorded question/answer exchange within a session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvestigationTurn {
    /// Question text exactly as passed to `SessionStore::record`.
    pub question: String,
    /// Answer text exactly as passed to `SessionStore::record`.
    pub answer: String,
    /// Wall-clock time at which the turn was recorded.
    pub at: SystemTime,
}
33
/// Everything the store remembers about a single session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvestigationState {
    /// Identifier the session is stored under.
    pub session_id: String,
    /// Retained turns, oldest first. The store appends new turns at the end and
    /// drops from the front once `SessionConfig::max_turns_per_session` is exceeded.
    pub turns: Vec<InvestigationTurn>,
    /// When this session state was created.
    pub created_at: SystemTime,
    /// When a turn was last recorded; compared against `SessionConfig::ttl` to
    /// decide whether the session has expired.
    pub last_active: SystemTime,
    /// Store-wide sequence number of the most recent `record` into this session.
    /// The session with the lowest value is the first eviction candidate.
    /// Skipped by serde, so a deserialized value carries `0` here.
    #[serde(skip)]
    last_seq: u64,
}
50
51impl InvestigationState {
52 fn new(session_id: String) -> Self {
53 let now = SystemTime::now();
54 Self {
55 session_id,
56 turns: Vec::new(),
57 created_at: now,
58 last_active: now,
59 last_seq: 0,
60 }
61 }
62
63 pub fn turn_count(&self) -> usize {
65 self.turns.len()
66 }
67}
68
/// Bounds that keep a `SessionStore` from growing without limit.
///
/// Every bound treats zero as "unbounded".
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Maximum number of sessions kept; when exceeded, the least recently recorded
    /// sessions are evicted. `0` disables the cap.
    pub max_sessions: usize,
    /// Maximum turns kept per session; the oldest turns are dropped first.
    /// `0` disables the cap.
    pub max_turns_per_session: usize,
    /// Idle time after which a session expires. `Duration::ZERO` disables expiry.
    pub ttl: Duration,
}

impl Default for SessionConfig {
    /// 1 000 sessions, 100 turns per session, and a one-day idle TTL.
    fn default() -> Self {
        const ONE_DAY: Duration = Duration::from_secs(24 * 60 * 60);
        SessionConfig {
            ttl: ONE_DAY,
            max_turns_per_session: 100,
            max_sessions: 1000,
        }
    }
}
98
/// In-memory map from session id to [`InvestigationState`], guarded by a mutex.
///
/// Cloning yields another handle to the same sessions: the map and the sequence
/// counter live behind `Arc`s, while the `SessionConfig` is copied.
#[derive(Clone)]
pub struct SessionStore {
    /// Sessions keyed by session id; every access goes through this one mutex.
    inner: Arc<Mutex<HashMap<String, InvestigationState>>>,
    /// Counter stamped onto a session each time it is recorded into; the lowest
    /// stamp identifies the eviction victim.
    seq: Arc<AtomicU64>,
    /// Bounds applied by this store.
    config: SessionConfig,
}
110
111impl Default for SessionStore {
112 fn default() -> Self {
113 Self {
114 inner: Arc::new(Mutex::new(HashMap::new())),
115 seq: Arc::new(AtomicU64::new(0)),
116 config: SessionConfig::default(),
117 }
118 }
119}
120
121impl SessionStore {
122 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn with_config(config: SessionConfig) -> Self {
129 Self {
130 inner: Arc::new(Mutex::new(HashMap::new())),
131 seq: Arc::new(AtomicU64::new(0)),
132 config,
133 }
134 }
135
136 pub fn config(&self) -> &SessionConfig {
138 &self.config
139 }
140
141 pub fn record(&self, session_id: &str, question: &str, answer: &str) -> (usize, usize) {
150 let mut guard = self.lock();
151 let now = SystemTime::now();
152 let seq = self.seq.fetch_add(1, Ordering::Relaxed);
153
154 self.prune_expired(&mut guard, now);
155
156 let state = guard
157 .entry(session_id.to_string())
158 .or_insert_with(|| InvestigationState::new(session_id.to_string()));
159 state.turns.push(InvestigationTurn {
160 question: question.to_string(),
161 answer: answer.to_string(),
162 at: now,
163 });
164 state.last_active = now;
165 state.last_seq = seq;
166
167 let cap = self.config.max_turns_per_session;
169 if cap > 0 && state.turns.len() > cap {
170 let excess = state.turns.len() - cap;
171 state.turns.drain(0..excess);
172 }
173 let count = state.turns.len();
174
175 self.evict_overflow(&mut guard, session_id);
176
177 (count.saturating_sub(1), count)
178 }
179
180 fn prune_expired(
182 &self,
183 guard: &mut HashMap<String, InvestigationState>,
184 now: SystemTime,
185 ) {
186 let ttl = self.config.ttl;
187 if ttl.is_zero() {
188 return;
189 }
190 guard.retain(|_, state| {
191 now.duration_since(state.last_active)
192 .map(|idle| idle <= ttl)
193 .unwrap_or(true)
194 });
195 }
196
197 fn evict_overflow(&self, guard: &mut HashMap<String, InvestigationState>, keep: &str) {
200 let max = self.config.max_sessions;
201 if max == 0 || guard.len() <= max {
202 return;
203 }
204 let mut by_activity: Vec<(String, u64)> = guard
205 .iter()
206 .map(|(id, state)| (id.clone(), state.last_seq))
207 .collect();
208 by_activity.sort_by_key(|(_, seq)| *seq);
210 let mut overflow = guard.len() - max;
211 for (id, _) in by_activity {
212 if overflow == 0 {
213 break;
214 }
215 if id == keep {
216 continue;
217 }
218 guard.remove(&id);
219 overflow -= 1;
220 }
221 }
222
223 pub fn history(&self, session_id: &str) -> Vec<InvestigationTurn> {
225 self.lock()
226 .get(session_id)
227 .map(|state| state.turns.clone())
228 .unwrap_or_default()
229 }
230
231 pub fn recent_history(&self, session_id: &str, n: usize) -> Vec<InvestigationTurn> {
238 self.lock()
239 .get(session_id)
240 .map(|state| {
241 let start = state.turns.len().saturating_sub(n);
242 state.turns.iter().skip(start).cloned().collect()
243 })
244 .unwrap_or_default()
245 }
246
247 pub fn snapshot(&self, session_id: &str) -> Option<InvestigationState> {
249 self.lock().get(session_id).cloned()
250 }
251
252 pub fn session_count(&self) -> usize {
254 self.lock().len()
255 }
256
257 pub fn contains(&self, session_id: &str) -> bool {
259 self.lock().contains_key(session_id)
260 }
261
262 fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<String, InvestigationState>> {
268 self.inner
269 .lock()
270 .unwrap_or_else(|poisoned| poisoned.into_inner())
271 }
272}
273
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::indexing_slicing,
        clippy::panic
    )]
    use super::*;

    /// Rewinds `session_id`'s `last_active` stamp by `by`, so it looks idle for
    /// that long without sleeping.
    fn age_session(store: &SessionStore, session_id: &str, by: Duration) {
        let mut guard = store.lock();
        let state = guard.get_mut(session_id).unwrap();
        state.last_active = state.last_active.checked_sub(by).unwrap();
    }

    #[test]
    fn records_and_reads_back_turns() {
        let store = SessionStore::new();
        let (idx0, count0) = store.record("s1", "q1", "a1");
        assert_eq!((idx0, count0), (0, 1));
        let (idx1, count1) = store.record("s1", "q2", "a2");
        assert_eq!((idx1, count1), (1, 2));

        let history = store.history("s1");
        assert_eq!(history.len(), 2);
        assert_eq!(history.first().unwrap().question, "q1");
        assert_eq!(history.get(1).unwrap().answer, "a2");
    }

    #[test]
    fn sessions_are_isolated() {
        let store = SessionStore::new();
        store.record("a", "qa", "aa");
        store.record("b", "qb", "ab");

        assert_eq!(store.session_count(), 2);
        assert_eq!(store.history("a").len(), 1);
        assert_eq!(store.history("b").len(), 1);
        assert!(store.contains("a"));
        assert!(!store.contains("c"));
        assert!(store.history("c").is_empty());
    }

    #[test]
    fn recent_history_returns_only_last_n_in_order() {
        let store = SessionStore::new();
        for i in 0..10 {
            store.record("s", &format!("q{i}"), &format!("a{i}"));
        }
        let recent = store.recent_history("s", 3);
        assert_eq!(recent.len(), 3);
        assert_eq!(recent.first().unwrap().question, "q7");
        assert_eq!(recent.get(2).unwrap().question, "q9");

        let all = store.recent_history("s", 100);
        assert_eq!(all.len(), 10);
        assert!(store.recent_history("nope", 5).is_empty());
    }

    #[test]
    fn turns_per_session_are_capped_dropping_oldest() {
        let store = SessionStore::with_config(SessionConfig {
            max_turns_per_session: 2,
            ..SessionConfig::default()
        });
        store.record("s", "q1", "a1");
        store.record("s", "q2", "a2");
        let (idx, count) = store.record("s", "q3", "a3");

        assert_eq!(count, 2, "turn count must be capped");
        assert_eq!(idx, 1);
        let history = store.history("s");
        assert_eq!(history.len(), 2);
        assert_eq!(history.first().unwrap().question, "q2");
        assert_eq!(history.get(1).unwrap().question, "q3");
    }

    #[test]
    fn overflowing_sessions_evicts_least_recently_active() {
        let store = SessionStore::with_config(SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        });
        store.record("a", "qa", "aa");
        store.record("b", "qb", "ab");
        store.record("a", "qa2", "aa2");
        store.record("c", "qc", "ac");

        assert_eq!(store.session_count(), 2);
        assert!(store.contains("a"), "recently active session must survive");
        assert!(store.contains("c"), "newest session must survive");
        assert!(!store.contains("b"), "stalest session must be evicted");
    }

    #[test]
    fn eviction_never_drops_the_session_just_written() {
        let store = SessionStore::with_config(SessionConfig {
            max_sessions: 1,
            ..SessionConfig::default()
        });
        store.record("a", "qa", "aa");
        store.record("b", "qb", "ab");

        assert_eq!(store.session_count(), 1);
        assert!(store.contains("b"), "session being written must survive eviction");
        assert!(!store.contains("a"));
    }

    #[test]
    fn idle_sessions_past_ttl_are_pruned_on_record() {
        let store = SessionStore::with_config(SessionConfig {
            ttl: Duration::from_secs(60),
            ..SessionConfig::default()
        });
        store.record("stale", "q", "a");
        store.record("fresh", "q1", "a1");
        age_session(&store, "stale", Duration::from_secs(120));

        store.record("fresh", "q2", "a2");

        assert!(!store.contains("stale"), "session idle past the TTL must be pruned");
        assert_eq!(store.session_count(), 1);
        assert_eq!(store.history("fresh").len(), 2);
    }

    #[test]
    fn recording_into_expired_session_starts_it_over() {
        let store = SessionStore::with_config(SessionConfig {
            ttl: Duration::from_secs(60),
            ..SessionConfig::default()
        });
        store.record("s", "q1", "a1");
        store.record("s", "q2", "a2");
        age_session(&store, "s", Duration::from_secs(120));

        let (idx, count) = store.record("s", "q3", "a3");

        assert_eq!((idx, count), (0, 1), "expired history must not carry over");
        assert_eq!(store.history("s").first().unwrap().question, "q3");
    }

    #[test]
    fn zero_bounds_disable_eviction() {
        let store = SessionStore::with_config(SessionConfig {
            max_sessions: 0,
            max_turns_per_session: 0,
            ttl: Duration::ZERO,
        });
        for i in 0..50 {
            store.record(&format!("s{i}"), "q", "a");
        }
        for _ in 0..50 {
            store.record("s0", "q", "a");
        }
        assert_eq!(store.session_count(), 50, "max_sessions=0 disables the cap");
        assert_eq!(
            store.history("s0").len(),
            51,
            "max_turns_per_session=0 disables the cap"
        );
    }
}