1use anyhow::{bail, Result};
7use chrono::{DateTime, Utc};
8use serde::{de::DeserializeOwned, Deserialize, Deserializer, Serialize, Serializer};
9use std::path::PathBuf;
10use tokio::fs;
11
/// Identifier of a [`Session`], represented as a plain string.
///
/// The wrapped string is public, so a `SessionId` can hold any text, not only
/// values produced by [`SessionId::new`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct SessionId(pub String);
15
16impl SessionId {
17 pub fn new() -> Self {
19 Self(uuid::Uuid::new_v4().to_string())
20 }
21}
22
23impl Default for SessionId {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl std::fmt::Display for SessionId {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 write!(f, "{}", self.0)
32 }
33}
34
35impl Serialize for SessionId {
36 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37 where
38 S: Serializer,
39 {
40 serializer.serialize_str(&self.0)
41 }
42}
43
44impl<'de> Deserialize<'de> for SessionId {
45 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46 where
47 D: Deserializer<'de>,
48 {
49 let s = String::deserialize(deserializer)?;
50 Ok(Self(s))
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct UserMessage {
57 pub content: String,
59 pub timestamp: DateTime<Utc>,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
65pub struct AgentResponse {
66 pub content: String,
68 pub session_id: Option<String>,
70 pub seed_id: Option<String>,
72 pub phase_reached: Option<String>,
74 pub evaluation_passed: Option<bool>,
76 pub timestamp: DateTime<Utc>,
78}
79
80pub type SessionMetadata = std::collections::HashMap<String, serde_json::Value>;
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct Session {
90 pub id: SessionId,
92 pub user_id: String,
94 #[serde(default)]
96 pub user_messages: Vec<UserMessage>,
97 #[serde(default)]
99 pub agent_responses: Vec<AgentResponse>,
100 #[serde(skip_serializing_if = "Option::is_none")]
102 pub active_seed_id: Option<String>,
103 #[serde(skip_serializing_if = "Option::is_none")]
105 pub active_persona_id: Option<String>,
106 pub created_at: DateTime<Utc>,
108 pub updated_at: DateTime<Utc>,
110 #[serde(default)]
112 pub metadata: SessionMetadata,
113}
114
115impl Session {
116 pub fn new(user_id: impl Into<String>) -> Self {
118 let now = Utc::now();
119 Self {
120 id: SessionId::new(),
121 user_id: user_id.into(),
122 user_messages: Vec::new(),
123 agent_responses: Vec::new(),
124 active_seed_id: None,
125 active_persona_id: None,
126 created_at: now,
127 updated_at: now,
128 metadata: SessionMetadata::new(),
129 }
130 }
131
132 pub fn with_id(user_id: impl Into<String>, session_id: SessionId) -> Self {
134 let now = Utc::now();
135 Self {
136 id: session_id,
137 user_id: user_id.into(),
138 user_messages: Vec::new(),
139 agent_responses: Vec::new(),
140 active_seed_id: None,
141 active_persona_id: None,
142 created_at: now,
143 updated_at: now,
144 metadata: SessionMetadata::new(),
145 }
146 }
147
148 pub fn add_user_message(&mut self, content: impl Into<String>) {
150 self.user_messages.push(UserMessage {
151 content: content.into(),
152 timestamp: Utc::now(),
153 });
154 self.updated_at = Utc::now();
155 }
156
157 pub fn add_agent_response(&mut self, response: AgentResponse) {
159 self.agent_responses.push(response);
160 self.updated_at = Utc::now();
161 }
162
163 pub fn set_active_seed(&mut self, seed_id: Option<String>) {
165 self.active_seed_id = seed_id;
166 self.updated_at = Utc::now();
167 }
168
169 pub fn set_active_persona(&mut self, persona_id: Option<String>) {
171 self.active_persona_id = persona_id;
172 self.updated_at = Utc::now();
173 }
174
175 pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
177 self.metadata.insert(key.into(), value);
178 self.updated_at = Utc::now();
179 }
180
181 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
183 self.metadata.get(key)
184 }
185
186 pub fn exchange_count(&self) -> usize {
188 self.user_messages.len().min(self.agent_responses.len())
189 }
190
191 pub fn is_empty(&self) -> bool {
193 self.user_messages.is_empty()
194 }
195}
/// File-backed store that keeps Markdown and JSON documents in per-category
/// directories underneath `base_path`.
#[derive(Clone)]
pub struct StateStore {
    /// Root directory; each category is resolved relative to it.
    pub base_path: PathBuf,
}
205
206impl StateStore {
207 pub fn new(base_path: PathBuf) -> Result<Self> {
218 Ok(Self { base_path })
219 }
220
221 fn validate_category(category: &str) -> Result<()> {
223 if category.contains("..") || category.contains('\\') {
224 bail!("invalid category name: '{}'", category);
225 }
226 if category.is_empty()
227 || category.starts_with('/')
228 || category.ends_with('/')
229 || category.contains("//")
230 {
231 bail!("invalid category name: '{}'", category);
232 }
233 Ok(())
234 }
235
236 fn validate_name(name: &str) -> Result<()> {
238 if name.contains("..") || name.contains('/') || name.contains('\\') {
239 bail!("invalid file name: '{}'", name);
240 }
241 Ok(())
242 }
243
244 pub async fn save_markdown(&self, category: &str, name: &str, content: &str) -> Result<()> {
246 Self::validate_category(category)?;
247 Self::validate_name(name)?;
248 let dir = self.base_path.join(category);
249 fs::create_dir_all(&dir).await?;
250 let path = dir.join(format!("{name}.md"));
251
252 let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
254 fs::write(&temp_path, content).await?;
255 tokio::fs::rename(&temp_path, &path).await?;
256
257 Ok(())
258 }
259
260 pub async fn load_markdown(&self, category: &str, name: &str) -> Result<Option<String>> {
262 Self::validate_category(category)?;
263 Self::validate_name(name)?;
264 let path = self.base_path.join(category).join(format!("{name}.md"));
265 match fs::read_to_string(&path).await {
266 Ok(content) => Ok(Some(content)),
267 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
268 Err(e) => Err(e.into()),
269 }
270 }
271
272 pub async fn list_category(&self, category: &str) -> Result<Vec<String>> {
274 Self::validate_category(category)?;
275 let dir = self.base_path.join(category);
276 if !dir.exists() {
277 return Ok(Vec::new());
278 }
279 let mut entries = fs::read_dir(&dir).await?;
280 let mut names = Vec::new();
281 while let Some(entry) = entries.next_entry().await? {
282 let path = entry.path();
283 if let Some(ext) = path.extension() {
284 if ext == "md" || ext == "json" {
285 if let Some(stem) = path.file_stem() {
286 names.push(stem.to_string_lossy().into_owned());
287 }
288 }
289 }
290 }
291 names.sort();
292 Ok(names)
293 }
294
295 pub async fn save_json<T: Serialize>(
297 &self,
298 category: &str,
299 name: &str,
300 data: &T,
301 ) -> Result<()> {
302 Self::validate_category(category)?;
303 Self::validate_name(name)?;
304 let dir = self.base_path.join(category);
305 fs::create_dir_all(&dir).await?;
306 let path = dir.join(format!("{name}.json"));
307
308 let content = serde_json::to_string_pretty(data)?;
309
310 let temp_path = dir.join(format!("{name}.{}.tmp", std::process::id()));
312 fs::write(&temp_path, &content).await?;
313 tokio::fs::rename(&temp_path, &path).await?;
314
315 Ok(())
316 }
317
318 pub async fn load_json<T: DeserializeOwned>(
320 &self,
321 category: &str,
322 name: &str,
323 ) -> Result<Option<T>> {
324 Self::validate_category(category)?;
325 Self::validate_name(name)?;
326 let path = self.base_path.join(category).join(format!("{name}.json"));
327 match fs::read_to_string(&path).await {
328 Ok(content) => Ok(Some(serde_json::from_str(&content)?)),
329 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
330 Err(e) => Err(e.into()),
331 }
332 }
333
334 pub async fn delete_file(&self, category: &str, name: &str) -> Result<bool> {
336 Self::validate_category(category)?;
337 Self::validate_name(name)?;
338 let path = self.base_path.join(category).join(format!("{name}.json"));
339 if path.exists() {
340 tokio::fs::remove_file(path).await?;
341 Ok(true)
342 } else {
343 let path = self.base_path.join(category).join(format!("{name}.md"));
344 if path.exists() {
345 tokio::fs::remove_file(path).await?;
346 Ok(true)
347 } else {
348 Ok(false)
349 }
350 }
351 }
352}
353
354impl std::fmt::Debug for StateStore {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("StateStore")
357 .field("base_path", &self.base_path)
358 .finish()
359 }
360}
361
362impl StateStore {
363 pub async fn save_session(&self, session: &Session) -> Result<()> {
365 self.save_json("sessions", &session.id.0, session).await
366 }
367
368 pub async fn load_session(&self, session_id: &SessionId) -> Result<Option<Session>> {
370 self.load_json("sessions", &session_id.0).await
371 }
372
373 pub async fn list_sessions(&self) -> Result<Vec<SessionSummary>> {
375 let mut sessions = Vec::new();
376
377 if let Ok(names) = self.list_category("sessions").await {
378 for name in names {
379 if let Ok(Some(session)) = self.load_json::<Session>("sessions", &name).await {
380 sessions.push(SessionSummary {
381 id: session.id.0.clone(),
382 user_id: session.user_id.clone(),
383 message_count: session.user_messages.len(),
384 active_seed_id: session.active_seed_id.clone(),
385 created_at: session.created_at,
386 updated_at: session.updated_at,
387 });
388 }
389 }
390 }
391
392 sessions.sort_by_key(|b| std::cmp::Reverse(b.updated_at));
394 Ok(sessions)
395 }
396
397 pub async fn delete_session(&self, session_id: &SessionId) -> Result<bool> {
399 let path = self
400 .base_path
401 .join("sessions")
402 .join(format!("{}.json", session_id.0));
403 match fs::remove_file(&path).await {
404 Ok(()) => {
405 tracing::info!(session_id = %session_id, "Session deleted");
406 Ok(true)
407 }
408 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(false),
409 Err(e) => Err(e.into()),
410 }
411 }
412
413 pub async fn get_or_create_session(
415 &self,
416 user_id: &str,
417 session_id: Option<&SessionId>,
418 ) -> Result<Session> {
419 if let Some(sid) = session_id {
420 if let Some(existing) = self.load_session(sid).await? {
421 return Ok(existing);
422 }
423 }
424
425 let session = match session_id {
427 Some(sid) => Session::with_id(user_id, sid.clone()),
428 None => Session::new(user_id),
429 };
430
431 self.save_session(&session).await?;
432 Ok(session)
433 }
434
435 pub async fn update_session(&self, session: &Session) -> Result<()> {
437 self.save_session(session).await
438 }
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct SessionSummary {
444 pub id: String,
446 pub user_id: String,
448 pub message_count: usize,
450 #[serde(skip_serializing_if = "Option::is_none")]
452 pub active_seed_id: Option<String>,
453 pub created_at: DateTime<Utc>,
455 pub updated_at: DateTime<Utc>,
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store rooted in a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the store so the directory
    /// is only removed once the calling test is done with it.
    fn scratch_store() -> (tempfile::TempDir, StateStore) {
        let guard = tempfile::tempdir().unwrap();
        let store = StateStore::new(guard.path().to_path_buf()).unwrap();
        (guard, store)
    }

    /// A saved session can be loaded back with its user id and messages intact.
    #[tokio::test]
    async fn test_session_creation_and_persistence() {
        let (_guard, store) = scratch_store();

        let mut session = Session::new("user-123");
        session.add_user_message("Hello");
        store.save_session(&session).await.unwrap();

        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_some());
        let loaded = loaded.unwrap();
        assert_eq!(loaded.user_id, "user-123");
        assert_eq!(loaded.user_messages.len(), 1);
    }

    /// Listing returns every session with the most recently updated one first.
    #[tokio::test]
    async fn test_session_list_sorts_by_updated() {
        let (_guard, store) = scratch_store();

        for i in 0..3 {
            let mut session = Session::new(format!("user-{i}"));
            session.add_user_message(format!("Message {i}"));
            store.save_session(&session).await.unwrap();
            // Space the sessions out so their `updated_at` values differ.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        }

        let sessions = store.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 3);
        assert_eq!(sessions[0].user_id, "user-2");
    }

    /// Deleting a stored session reports success and makes it unloadable.
    #[tokio::test]
    async fn test_delete_session() {
        let (_guard, store) = scratch_store();

        let session = Session::new("user-123");
        store.save_session(&session).await.unwrap();

        let deleted = store.delete_session(&session.id).await.unwrap();
        assert!(deleted);

        let loaded = store.load_session(&session.id).await.unwrap();
        assert!(loaded.is_none());
    }

    /// Asking for an id that is already stored returns the stored session.
    #[tokio::test]
    async fn test_get_or_create_session_existing() {
        let (_guard, store) = scratch_store();

        let mut existing = Session::new("user-123");
        existing.add_user_message("Original message");
        store.save_session(&existing).await.unwrap();

        let retrieved = store
            .get_or_create_session("user-123", Some(&existing.id))
            .await
            .unwrap();
        assert_eq!(retrieved.id, existing.id);
        assert_eq!(retrieved.user_messages.len(), 1);
    }

    /// Without an id, a new empty session is created for the user.
    #[tokio::test]
    async fn test_get_or_create_session_new() {
        let (_guard, store) = scratch_store();

        let session = store.get_or_create_session("user-456", None).await.unwrap();
        assert_eq!(session.user_id, "user-456");
        assert!(session.user_messages.is_empty());
    }
}