use crate::{
error::{AoError, Result},
session_manager::SessionManager,
traits::{Agent, Runtime, Workspace},
types::{Session, SessionStatus},
};
/// Result of a successful [`restore_session`] call.
#[derive(Debug, Clone)]
pub struct RestoreOutcome {
    /// The session as persisted at the end of the restore
    /// (status `Spawning`, activity cleared, new runtime handle recorded).
    pub session: Session,
    /// The command produced by `Agent::launch_command` and handed to `Runtime::create`.
    pub launch_command: String,
    /// Handle returned by `Runtime::create` for the newly created runtime.
    pub runtime_handle: String,
    /// Whether the agent's initial prompt was delivered. `false` when the prompt
    /// was blank or when `Runtime::send_message` returned an error.
    pub prompt_sent: bool,
}
/// Restores a previously spawned session by re-creating its runtime in the
/// existing workspace and re-sending the agent's initial prompt.
///
/// Steps:
/// 1. Resolve `id_or_prefix` to exactly one session via `SessionManager::find_by_prefix`.
/// 2. Enrich the in-memory status: a non-terminal session whose runtime handle is
///    missing, reported dead, or whose liveness probe errors is treated as
///    [`SessionStatus::Terminated`]. Nothing is written at this point.
/// 3. Refuse to continue unless `Session::is_restorable` holds.
/// 4. Verify the recorded workspace still exists (via `Workspace::exists`).
/// 5. Destroy the old runtime (best effort), then create a new one. The new runtime
///    reuses the previous handle as its name, or the first 8 characters of the
///    session id when no handle was recorded.
/// 6. Persist the session as [`SessionStatus::Spawning`] with activity cleared.
/// 7. Send the agent's initial prompt if it is not blank.
///
/// # Errors
/// - errors from `find_by_prefix` (e.g. unknown or ambiguous id/prefix);
/// - [`AoError::Runtime`] when the session is not restorable;
/// - [`AoError::Workspace`] when the session has no workspace path or the
///   workspace no longer exists;
/// - errors propagated from `Workspace::exists`, `Runtime::create`, and
///   `SessionManager::save`. If saving fails after the runtime was created, the
///   new runtime is destroyed (best effort) before the error is returned, so no
///   runtime is left running without a matching persisted record.
///
/// Failure to deliver the initial prompt is not an error; it is reported via
/// [`RestoreOutcome::prompt_sent`].
pub async fn restore_session(
    id_or_prefix: &str,
    sessions: &SessionManager,
    runtime: &dyn Runtime,
    agent: &dyn Agent,
    workspace: &dyn Workspace,
) -> Result<RestoreOutcome> {
    let mut session = sessions.find_by_prefix(id_or_prefix).await?;

    // A record may still claim a live status after its runtime went away. A missing
    // handle, a dead runtime, and a failed liveness probe all count as "not alive".
    let alive = match session.runtime_handle.as_deref() {
        Some(handle) => runtime.is_alive(handle).await.unwrap_or(false),
        None => false,
    };
    if !alive && !session.status.is_terminal() {
        session.status = SessionStatus::Terminated;
    }

    if !session.is_restorable() {
        return Err(AoError::Runtime(format!(
            "session {} is not restorable (status={})",
            session.id,
            session.status.as_str()
        )));
    }

    let workspace_path = session
        .workspace_path
        .clone()
        .ok_or_else(|| AoError::Workspace("session has no workspace_path".into()))?;
    if !workspace.exists(&workspace_path).await? {
        return Err(AoError::Workspace(format!(
            "workspace missing: {}",
            workspace_path.display()
        )));
    }

    // Best effort: errors from tearing down the old runtime are ignored so they
    // cannot block the restore.
    if let Some(handle) = session.runtime_handle.as_deref() {
        let _ = runtime.destroy(handle).await;
    }

    // Reuse the previous runtime name when there is one; otherwise fall back to
    // the first 8 characters of the session id.
    let new_name = session
        .runtime_handle
        .clone()
        .unwrap_or_else(|| session.id.0.chars().take(8).collect());
    let launch_command = agent.launch_command(&session);
    let env = agent.environment(&session);
    let new_handle = runtime
        .create(&new_name, &workspace_path, &launch_command, &env)
        .await?;

    session.runtime_handle = Some(new_handle.clone());
    session.status = SessionStatus::Spawning;
    session.activity = None;
    let saved = sessions.save(&session).await;
    if saved.is_err() {
        // The persisted record still describes the pre-restore state; don't leave
        // the freshly created runtime running behind its back.
        let _ = runtime.destroy(&new_handle).await;
    }
    saved?;

    let prompt = agent.initial_prompt(&session);
    let prompt_sent = if prompt.trim().is_empty() {
        false
    } else {
        runtime.send_message(&new_handle, &prompt).await.is_ok()
    };

    Ok(RestoreOutcome {
        session,
        launch_command,
        runtime_handle: new_handle,
        prompt_sent,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{now_ms, ActivityState, SessionId, WorkspaceCreateConfig};
    use async_trait::async_trait;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Mutex;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a fresh, not-yet-created path under the system temp dir.
    ///
    /// Combines wall-clock nanoseconds with a process-wide counter so tests
    /// running in parallel with the same label get distinct paths.
    fn unique_temp_dir(label: &str) -> PathBuf {
        static COUNTER: AtomicUsize = AtomicUsize::new(0);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let n = COUNTER.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!("ao-rs-restore-{label}-{nanos}-{n}"))
    }

    /// Owns a unique temp directory path and removes it on drop, so cleanup also
    /// happens when a test panics on a failed assertion.
    struct TempDirGuard(PathBuf);

    impl TempDirGuard {
        fn new(label: &str) -> Self {
            Self(unique_temp_dir(label))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Runtime double that records `create`/`send`/`destroy` calls in order, plus
    /// every message body. `is_alive` answers from a fixed flag and is not recorded.
    #[derive(Default)]
    struct RecorderRuntime {
        alive: AtomicBool,
        calls: Mutex<Vec<String>>,
        messages: Mutex<Vec<String>>,
    }

    impl RecorderRuntime {
        fn new(alive: bool) -> Self {
            Self {
                alive: AtomicBool::new(alive),
                calls: Mutex::new(Vec::new()),
                messages: Mutex::new(Vec::new()),
            }
        }

        fn calls(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }

        fn messages(&self) -> Vec<String> {
            self.messages.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Runtime for RecorderRuntime {
        /// Records `create:<name>` and echoes the requested name back as the handle.
        async fn create(
            &self,
            session_id: &str,
            _cwd: &Path,
            _launch_command: &str,
            _env: &[(String, String)],
        ) -> Result<String> {
            self.calls
                .lock()
                .unwrap()
                .push(format!("create:{session_id}"));
            Ok(session_id.to_string())
        }

        async fn send_message(&self, handle: &str, msg: &str) -> Result<()> {
            self.calls.lock().unwrap().push(format!("send:{handle}"));
            self.messages.lock().unwrap().push(msg.to_string());
            Ok(())
        }

        async fn is_alive(&self, _handle: &str) -> Result<bool> {
            Ok(self.alive.load(Ordering::SeqCst))
        }

        async fn destroy(&self, handle: &str) -> Result<()> {
            self.calls.lock().unwrap().push(format!("destroy:{handle}"));
            Ok(())
        }
    }

    /// Agent double with a fixed launch command, no environment, and a non-empty prompt.
    struct StubAgent;

    #[async_trait]
    impl Agent for StubAgent {
        fn launch_command(&self, _s: &Session) -> String {
            "mock-launch".into()
        }

        fn environment(&self, _s: &Session) -> Vec<(String, String)> {
            vec![]
        }

        fn initial_prompt(&self, _s: &Session) -> String {
            "hello from restore".into()
        }

        async fn detect_activity(&self, _s: &Session) -> Result<ActivityState> {
            Ok(ActivityState::Ready)
        }
    }

    /// Workspace double that keeps the trait's default `exists`. The
    /// missing-workspace test relies on that default reporting a nonexistent
    /// path as missing.
    struct StubWorkspace;

    #[async_trait]
    impl Workspace for StubWorkspace {
        async fn create(&self, _cfg: &WorkspaceCreateConfig) -> Result<PathBuf> {
            Ok(PathBuf::from("/tmp/ws"))
        }

        async fn destroy(&self, _workspace_path: &Path) -> Result<()> {
            Ok(())
        }
    }

    /// Workspace double whose `exists` returns a fixed answer, independent of the
    /// filesystem.
    struct ExistsWorkspace {
        reports_exists: bool,
    }

    #[async_trait]
    impl Workspace for ExistsWorkspace {
        async fn create(&self, _cfg: &WorkspaceCreateConfig) -> Result<PathBuf> {
            Ok(PathBuf::from("/tmp/ws"))
        }

        async fn destroy(&self, _workspace_path: &Path) -> Result<()> {
            Ok(())
        }

        async fn exists(&self, _workspace_path: &Path) -> Result<bool> {
            Ok(self.reports_exists)
        }
    }

    /// Builds a session with the given id/status/workspace and
    /// `runtime_handle = "old-handle"`, saves it through `manager`, and returns it.
    async fn persist_session(
        manager: &SessionManager,
        id: &str,
        status: SessionStatus,
        workspace: &Path,
    ) -> Session {
        let session = Session {
            id: SessionId(id.into()),
            project_id: "demo".into(),
            status,
            agent: "claude-code".into(),
            agent_config: None,
            branch: format!("ao-{id}"),
            task: "restored task".into(),
            workspace_path: Some(workspace.to_path_buf()),
            runtime_handle: Some("old-handle".into()),
            runtime: "tmux".into(),
            activity: None,
            created_at: now_ms(),
            cost: None,
            issue_id: None,
            issue_url: None,
            claimed_pr_number: None,
            claimed_pr_url: None,
            initial_prompt_override: None,
            spawned_by: None,
            last_merge_conflict_dispatched: None,
            last_review_backlog_fingerprint: None,
        };
        manager.save(&session).await.unwrap();
        session
    }

    #[tokio::test]
    async fn restore_terminal_session_respawns_runtime_and_persists_spawning() {
        let tmp = TempDirGuard::new("ok");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(&manager, "sess-ok", SessionStatus::Terminated, &ws).await;
        let rt = RecorderRuntime::new(false);
        let agent = StubAgent;

        let out = restore_session("sess-ok", &manager, &rt, &agent, &StubWorkspace)
            .await
            .unwrap();

        let calls = rt.calls();
        let destroy_idx = calls.iter().position(|c| c == "destroy:old-handle");
        let create_idx = calls.iter().position(|c| c == "create:old-handle");
        let send_idx = calls.iter().position(|c| c == "send:old-handle");
        assert!(destroy_idx.is_some(), "destroy not called: {calls:?}");
        assert!(create_idx.is_some(), "create not called: {calls:?}");
        assert!(destroy_idx < create_idx, "destroy must come before create");
        assert!(send_idx.is_some(), "send not called: {calls:?}");
        assert!(create_idx < send_idx, "create must come before send");

        assert_eq!(out.session.status, SessionStatus::Spawning);
        assert_eq!(out.session.activity, None);
        assert_eq!(out.runtime_handle, "old-handle");
        assert_eq!(out.launch_command, "mock-launch");
        assert!(out.prompt_sent, "expected prompt_sent=true");

        let msgs = rt.messages();
        assert_eq!(msgs.len(), 1, "expected exactly one message: {msgs:?}");
        assert!(
            !msgs[0].trim().is_empty(),
            "expected non-empty prompt, got: {:?}",
            msgs[0]
        );

        let reread = manager.list().await.unwrap();
        assert_eq!(reread.len(), 1);
        assert_eq!(reread[0].status, SessionStatus::Spawning);
    }

    #[tokio::test]
    async fn restore_missing_runtime_handle_creates_new_handle_without_destroy() {
        let tmp = TempDirGuard::new("no-handle");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        let mut s =
            persist_session(&manager, "sess-nohandle", SessionStatus::Terminated, &ws).await;
        s.runtime_handle = None;
        manager.save(&s).await.unwrap();
        let rt = RecorderRuntime::new(false);

        let out = restore_session("sess-nohandle", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap();

        let calls = rt.calls();
        assert!(
            !calls.iter().any(|c| c.starts_with("destroy:")),
            "unexpected destroy call(s): {calls:?}"
        );
        assert!(
            calls.iter().any(|c| c == "create:sess-noh"),
            "expected create with short id (sess-noh), got calls: {calls:?}"
        );
        assert_eq!(out.runtime_handle, "sess-noh");
        assert_eq!(out.session.status, SessionStatus::Spawning);
        assert!(out.prompt_sent, "expected prompt_sent=true");

        let reread = manager.find_by_prefix("sess-nohandle").await.unwrap();
        assert_eq!(reread.runtime_handle.as_deref(), Some("sess-noh"));
    }

    #[tokio::test]
    async fn crashed_working_session_is_enriched_to_terminated_then_restored() {
        let tmp = TempDirGuard::new("enrich");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(&manager, "sess-crash", SessionStatus::Working, &ws).await;
        let rt = RecorderRuntime::new(false);

        let out = restore_session("sess-crash", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap();

        assert_eq!(out.session.status, SessionStatus::Spawning);
        assert!(out.prompt_sent, "expected prompt_sent=true");
    }

    #[tokio::test]
    async fn merged_session_is_not_restorable() {
        let tmp = TempDirGuard::new("merged");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(&manager, "sess-merged", SessionStatus::Merged, &ws).await;
        let rt = RecorderRuntime::new(false);

        let err = restore_session("sess-merged", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap_err();

        assert!(
            format!("{err}").contains("not restorable"),
            "unexpected: {err}"
        );
        let reread = manager.list().await.unwrap();
        assert_eq!(reread[0].status, SessionStatus::Merged);
    }

    #[tokio::test]
    async fn missing_workspace_errors_before_touching_runtime() {
        let tmp = TempDirGuard::new("nows");
        let base = tmp.path();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(
            &manager,
            "sess-ghost",
            SessionStatus::Terminated,
            &PathBuf::from("/nonexistent/ao-rs/does-not-exist"),
        )
        .await;
        let rt = RecorderRuntime::new(false);

        let err = restore_session("sess-ghost", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap_err();

        assert!(format!("{err}").contains("workspace missing"), "got: {err}");
        // `is_alive` is not recorded, so this checks that no create/send/destroy happened.
        assert!(
            rt.calls().is_empty(),
            "runtime was called: {:?}",
            rt.calls()
        );
    }

    #[tokio::test]
    async fn corrupted_workspace_reports_missing_via_plugin_exists() {
        let tmp = TempDirGuard::new("corrupt");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(&manager, "sess-corrupt", SessionStatus::Terminated, &ws).await;
        let rt = RecorderRuntime::new(false);
        let workspace = ExistsWorkspace {
            reports_exists: false,
        };

        let err = restore_session("sess-corrupt", &manager, &rt, &StubAgent, &workspace)
            .await
            .unwrap_err();

        assert!(format!("{err}").contains("workspace missing"), "got: {err}");
        assert!(
            rt.calls().is_empty(),
            "runtime was called: {:?}",
            rt.calls()
        );
    }

    #[tokio::test]
    async fn unknown_session_id_errors() {
        let tmp = TempDirGuard::new("missing");
        let manager = SessionManager::new(tmp.path().to_path_buf());
        let rt = RecorderRuntime::new(false);

        let err = restore_session("nope", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap_err();

        assert!(matches!(err, AoError::SessionNotFound(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn ambiguous_prefix_errors() {
        let tmp = TempDirGuard::new("ambig");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(&manager, "abcd-1111", SessionStatus::Terminated, &ws).await;
        persist_session(&manager, "abcd-2222", SessionStatus::Terminated, &ws).await;
        let rt = RecorderRuntime::new(false);

        let err = restore_session("abcd", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap_err();

        assert!(format!("{err}").contains("ambiguous"), "got: {err}");
    }

    #[tokio::test]
    async fn prefix_match_resolves_to_unique_session() {
        let tmp = TempDirGuard::new("prefix");
        let base = tmp.path();
        let ws = base.join("ws");
        std::fs::create_dir_all(&ws).unwrap();
        let manager = SessionManager::new(base.to_path_buf());
        persist_session(
            &manager,
            "deadbeef-uuid-long",
            SessionStatus::Terminated,
            &ws,
        )
        .await;
        let rt = RecorderRuntime::new(false);

        let out = restore_session("deadbeef", &manager, &rt, &StubAgent, &StubWorkspace)
            .await
            .unwrap();

        assert_eq!(out.session.id.0, "deadbeef-uuid-long");
        assert!(out.prompt_sent, "expected prompt_sent=true");
    }
}