use crate::session::spawn_session_actor_with_state;
use crate::types::{
PersistedSessionFile, PersistenceCommand, PersistenceConfig, SessionCommand, SessionHandle,
SessionStateSnapshot, ThinkingState,
};
use std::time::{Duration, Instant};
/// Validates a session id before it is used as a directory name under the sessions directory.
///
/// An id is accepted when it is non-empty, at most 255 bytes long, and consists only of
/// characters for which [`char::is_alphanumeric`] is true (this includes non-ASCII Unicode
/// letters and digits) plus `-` and `_`. Because path separators and `.` are rejected, an
/// accepted id can only ever name a single child of the sessions directory.
///
/// Returns an owned copy of the validated id.
///
/// # Errors
///
/// Returns an error if the id is empty, longer than 255 bytes, or contains any other character.
fn sanitize_session_id(session_id: &str) -> Result<String, anyhow::Error> {
    if session_id.is_empty() {
        anyhow::bail!("session_id cannot be empty");
    }
    // `str::len` is the UTF-8 byte length (not a character count), which is also the unit
    // filesystems commonly use for their per-name limit, so report it as bytes. Checking the
    // length first also bounds the work done by the character scan below.
    if session_id.len() > 255 {
        anyhow::bail!(
            "session_id too long (max 255 bytes, got {} bytes)",
            session_id.len()
        );
    }
    let is_allowed = |c: char| c.is_alphanumeric() || c == '_' || c == '-';
    if !session_id.chars().all(is_allowed) {
        // Escape the rejected id so control characters (e.g. newlines) cannot forge log lines.
        anyhow::bail!(
            "Invalid session_id '{}': must contain only alphanumeric characters, hyphens, and underscores",
            session_id.escape_debug()
        );
    }
    Ok(session_id.to_owned())
}
async fn verify_path_within_base(
constructed_path: &std::path::Path,
allowed_base: &std::path::Path,
) -> Result<std::path::PathBuf, anyhow::Error> {
use anyhow::Context;
if !tokio::fs::try_exists(allowed_base).await.unwrap_or(false) {
tokio::fs::create_dir_all(allowed_base)
.await
.context("Failed to create sessions base directory")?;
}
let canonical_base = tokio::fs::canonicalize(allowed_base)
.await
.context("Failed to canonicalize base directory")?;
let canonical_path = if tokio::fs::try_exists(constructed_path).await.unwrap_or(false) {
tokio::fs::canonicalize(constructed_path)
.await
.context("Failed to canonicalize path")?
} else {
if let Some(parent) = constructed_path.parent() {
let canonical_parent = if tokio::fs::try_exists(parent).await.unwrap_or(false) {
tokio::fs::canonicalize(parent)
.await
.context("Failed to canonicalize parent")?
} else {
if !parent.starts_with(&canonical_base) {
anyhow::bail!(
"Directory traversal detected: parent '{}' escapes base '{}'",
parent.display(),
canonical_base.display()
);
}
canonical_base.clone()
};
if let Some(filename) = constructed_path.file_name() {
canonical_parent.join(filename)
} else {
anyhow::bail!("Path has no filename");
}
} else {
anyhow::bail!("Path has no parent directory");
}
};
if !canonical_path.starts_with(&canonical_base) {
anyhow::bail!(
"Directory traversal detected: '{}' escapes base '{}'",
canonical_path.display(),
canonical_base.display()
);
}
Ok(canonical_path)
}
pub fn start_persistence_processor(
mut receiver: tokio::sync::mpsc::Receiver<PersistenceCommand>,
) {
let config = PersistenceConfig::default();
tokio::spawn(async move {
if let Err(e) = tokio::fs::create_dir_all(&config.sessions_dir).await {
log::error!("Failed to create sessions directory: {e}");
}
while let Some(cmd) = receiver.recv().await {
match cmd {
PersistenceCommand::Persist {
session_id,
snapshot,
created_at,
last_activity,
} => {
if let Err(e) = persist_session_to_disk(
&config,
&session_id,
&snapshot,
created_at,
last_activity,
)
.await
{
log::error!("Failed to persist session {session_id}: {e}");
}
}
PersistenceCommand::PersistBatch { sessions, completion } => {
let batch_size = sessions.len();
log::info!("Processing batch of {} sessions", batch_size);
let mut success_count = 0usize;
let mut failure_count = 0usize;
for (session_id, snapshot, created_at, last_activity) in sessions {
match persist_session_to_disk(
&config,
&session_id,
&snapshot,
created_at,
last_activity,
)
.await
{
Ok(()) => {
success_count += 1;
}
Err(e) => {
log::error!("Failed to persist session {session_id} in batch: {e}");
failure_count += 1;
}
}
}
log::debug!(
"Batch persistence complete: {}/{} succeeded, {} failed",
success_count, batch_size, failure_count
);
if let Some(tx) = completion {
let _ = tx.send(Ok(success_count));
}
}
PersistenceCommand::Delete { session_id } => {
let safe_session_id = match sanitize_session_id(&session_id) {
Ok(id) => id,
Err(e) => {
log::error!("Invalid session_id for deletion: {e}");
continue;
}
};
let session_dir = config.sessions_dir.join(&safe_session_id);
let verified_dir = match verify_path_within_base(&session_dir, &config.sessions_dir).await {
Ok(path) => path,
Err(e) => {
log::error!("Path verification failed during deletion: {e}");
continue;
}
};
if let Err(e) = tokio::fs::remove_dir_all(&verified_dir).await {
log::debug!("Failed to delete session directory {session_id}: {e}");
} else {
log::info!("Deleted persisted session: {session_id}");
}
}
}
}
log::debug!("Persistence processor terminated");
});
}
async fn persist_session_to_disk(
config: &PersistenceConfig,
session_id: &str,
snapshot: &SessionStateSnapshot,
created_at: std::time::SystemTime,
last_activity: std::time::SystemTime,
) -> Result<(), anyhow::Error> {
use anyhow::Context;
use tokio::io::AsyncWriteExt;
let safe_session_id = sanitize_session_id(session_id)
.context("session_id validation failed")?;
let session_dir = config.sessions_dir.join(&safe_session_id);
let verified_session_dir = verify_path_within_base(&session_dir, &config.sessions_dir)
.await
.context("Path verification failed - potential directory traversal")?;
tokio::fs::create_dir_all(&verified_session_dir)
.await
.context("Failed to create session directory")?;
let session_dir = verified_session_dir;
let session_file = PersistedSessionFile::from_snapshot(
session_id.to_string(),
snapshot,
created_at,
last_activity,
);
let json = serde_json::to_string_pretty(&session_file)
.context("Failed to serialize session data")?;
let final_path = session_dir.join("session.json");
let temp_path = final_path.with_extension("json.tmp");
let mut file = tokio::fs::File::create(&temp_path)
.await
.context("Failed to create temporary session file")?;
file.write_all(json.as_bytes())
.await
.context("Failed to write session data to temp file")?;
file.sync_all()
.await
.context("Failed to sync session file to disk")?;
drop(file);
tokio::fs::rename(&temp_path, &final_path)
.await
.context("Failed to atomically commit session file")?;
log::info!(
"Persisted session {} ({} thoughts, {} branches) atomically to {:?}",
session_id,
snapshot.thought_history.len(),
snapshot.branches.len(),
final_path
);
Ok(())
}
pub async fn try_restore_session(
session_id: &str,
persistence_sender: &tokio::sync::mpsc::Sender<PersistenceCommand>,
) -> Option<SessionHandle> {
let config = PersistenceConfig::default();
let safe_session_id = match sanitize_session_id(session_id) {
Ok(id) => id,
Err(e) => {
log::warn!("Invalid session_id for restoration: {e}");
return None;
}
};
let session_dir = config.sessions_dir.join(&safe_session_id);
let session_dir = match verify_path_within_base(&session_dir, &config.sessions_dir).await {
Ok(path) => path,
Err(e) => {
log::warn!("Path verification failed during restoration: {e}");
return None;
}
};
let session_file = session_dir.join("session.json");
if !tokio::fs::try_exists(&session_file).await.unwrap_or(false) {
log::debug!("Session file not found for {session_id}");
return None;
}
log::debug!("Attempting to restore session {session_id} from {:?}", session_file);
let contents = tokio::fs::read_to_string(&session_file).await.ok()?;
let persisted: PersistedSessionFile = match serde_json::from_str(&contents) {
Ok(p) => p,
Err(e) => {
log::error!("Failed to parse session file for {session_id}: {e}");
let _ = tokio::fs::remove_dir_all(&session_dir).await;
return None;
}
};
if persisted.thought_history.is_empty() && persisted.branches.is_empty() {
log::warn!("Restored session {session_id} has no thoughts, ignoring");
return None;
}
log::info!(
"Restored session {} ({} thoughts, {} branches) from disk",
session_id,
persisted.thought_history.len(),
persisted.branches.len()
);
let snapshot = persisted.to_snapshot();
let (tx, rx) = tokio::sync::mpsc::channel(100);
let restored_state = ThinkingState {
thought_history: snapshot.thought_history,
branches: snapshot.branches,
};
spawn_session_actor_with_state(rx, restored_state);
let created_at = match persisted.created_at.elapsed() {
Ok(elapsed) => Instant::now()
.checked_sub(elapsed)
.unwrap_or_else(Instant::now),
Err(_) => {
log::warn!(
"Clock drift detected for session {}: created_at is in future, using current time",
session_id
);
Instant::now()
}
};
let handle = SessionHandle::with_created_at(tx.clone(), created_at);
let (verify_tx, verify_rx) = tokio::sync::oneshot::channel();
let verify_cmd = SessionCommand::GetState { respond_to: verify_tx };
match tokio::time::timeout(Duration::from_secs(5), async {
if handle.tx.send(verify_cmd).await.is_err() {
return None;
}
verify_rx.await.ok()
}).await {
Ok(Some(_snapshot)) => {
log::info!("Restored session {} verified, deleting disk backup", session_id);
let _ = persistence_sender.try_send(PersistenceCommand::Delete {
session_id: session_id.to_string(),
});
Some(handle)
}
_ => {
log::error!(
"Restored session {} failed verification, preserving disk backup for retry",
session_id
);
None
}
}
}
pub fn start_disk_cleanup_task(
persistence_sender: tokio::sync::mpsc::Sender<PersistenceCommand>,
) {
tokio::spawn(async move {
let config = PersistenceConfig::default();
let mut interval = tokio::time::interval(Duration::from_secs(60 * 60));
loop {
interval.tick().await;
log::debug!("Running disk cleanup task");
let Ok(mut entries) = tokio::fs::read_dir(&config.sessions_dir).await else {
continue;
};
while let Ok(Some(entry)) = entries.next_entry().await {
let Ok(file_type) = entry.file_type().await else {
continue;
};
if !file_type.is_dir() {
continue;
}
let path = entry.path();
let session_file = path.join("session.json");
let Ok(session_json) = tokio::fs::read_to_string(&session_file).await else {
continue;
};
let Ok(session) = serde_json::from_str::<PersistedSessionFile>(&session_json)
else {
if let Some(session_id) = path.file_name().and_then(|n| n.to_str()) {
log::warn!("Found corrupted session file {session_id}, scheduling cleanup");
let _ = persistence_sender.try_send(PersistenceCommand::Delete {
session_id: session_id.to_string(),
});
}
continue;
};
let age = match tokio::fs::metadata(&session_file).await {
Ok(meta) => meta
.modified()
.ok()
.and_then(|mtime| mtime.elapsed().ok())
.unwrap_or_else(|| {
log::warn!(
"Clock drift detected for session {}: file mtime is in future",
session.session_id
);
Duration::ZERO
}),
Err(e) => {
log::debug!(
"Failed to read metadata for session {}: {e}",
session.session_id
);
continue;
}
};
if age > config.cleanup_after {
match persistence_sender.try_send(PersistenceCommand::Delete {
session_id: session.session_id.clone(),
}) {
Ok(_) => {
log::info!(
"Queued old session {} for deletion (age: {:.1} hours)",
session.session_id,
age.as_secs_f64() / 3600.0
);
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
log::warn!(
"Persistence channel full, deferring deletion of session {}",
session.session_id
);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
log::error!("Persistence channel closed, cannot delete session {}", session.session_id);
}
}
}
}
}
});
}