use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
use cap_rs::core::{ClientFrame, Content};
use cap_rs::driver::Driver;
use super::AgentKind;
use super::runtime::{
NotifTarget, run_turn, spawn_driver, spawn_driver_acp, spawn_driver_continue_last,
spawn_driver_resume,
};
/// Upper bound on concurrently registered live sessions; `spawn_session`
/// refuses to start another once this many exist.
const DEFAULT_MAX_SESSIONS: usize = 8;
/// A session whose last recorded activity is older than this is reaped by
/// `gc_idle` (10 minutes).
const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// How long `dispatch_sync` waits for a session actor to reply to one prompt
/// (5 minutes).
const PROMPT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
/// Selects how `spawn_session` starts a session's driver.
#[derive(Clone, Debug)]
enum ResumeMode {
    /// Start a brand-new agent conversation (`spawn_driver`).
    None,
    /// Resume the agent-side session with this id (`spawn_driver_resume`).
    ById(String),
    /// Continue the agent's most recent conversation
    /// (`spawn_driver_continue_last`).
    ContinueLast,
}
/// Owns the pool of long-lived agent driver sessions ("cap_live").
///
/// Each live session is served by its own actor task and is addressed by a
/// manager-generated session id. Alongside that registry the manager keeps two
/// indexes keyed by IM session key:
/// * `sticky` — the live session an IM conversation is currently bound to;
/// * `suspended` — sessions parked per `(im_session_key, agent)` when the
///   conversation re-bound to a different agent, so they can be picked up
///   again by `acquire_session`.
pub struct CapLiveManager {
    /// Registered live sessions, keyed by manager session id.
    sessions: Arc<RwLock<HashMap<String, LiveSessionHandle>>>,
    /// IM session key -> (live session id, agent) currently bound.
    sticky: Arc<RwLock<HashMap<String, (String, AgentKind)>>>,
    /// (IM session key, agent) -> parked live session id.
    suspended: Arc<RwLock<HashMap<(String, AgentKind), String>>>,
    /// Event bus cloned into every actor and passed on to `run_turn`.
    bus: broadcast::Sender<rsclaw_events::AgentEvent>,
    /// Optional outbound-message channel, set via `set_notification_tx`.
    notification_tx: Option<broadcast::Sender<rsclaw_types::OutboundMessage>>,
    /// Maximum number of concurrently registered sessions.
    max_sessions: usize,
    /// Inactivity span after which `gc_idle` reaps a session.
    idle_timeout: Duration,
}
/// Manager-side handle to one live session's actor.
///
/// Cloning is cheap: apart from the `Copy` agent kind, every field is a
/// channel sender or an `Arc` shared with the original handle.
#[derive(Clone)]
struct LiveSessionHandle {
    /// Agent this session's driver runs; prompts for other agents are rejected.
    agent_kind: AgentKind,
    /// Request channel into the session actor.
    tx: mpsc::Sender<LiveRequest>,
    /// Time of the last recorded activity; read by idle GC.
    last_active: Arc<StdMutex<Instant>>,
    /// Agent-side session id captured from the driver's Ready event; the
    /// actor writes it, the manager reads it.
    agent_session_id: Arc<StdMutex<Option<String>>>,
    /// `true` for fresh (non-resumed) sessions until taken once via
    /// `try_take_pending_memory_inject`.
    pending_memory_inject: Arc<std::sync::atomic::AtomicBool>,
}
/// Messages a session actor accepts on its request channel.
enum LiveRequest {
    /// Run one prompt turn; the collected reply text (or the error) is sent
    /// back on `reply`.
    Prompt {
        /// Prompt text forwarded to the agent driver.
        task: String,
        /// Optional notification target handed to `run_turn`.
        notif: Option<NotifTarget>,
        /// One-shot channel carrying the turn result.
        reply: oneshot::Sender<Result<String>>,
    },
    /// Stop the actor: shut the driver down and deregister the session.
    Shutdown,
}
/// Outcome of a successful [`CapLiveManager::dispatch_sync`] call.
pub struct LiveDispatchResult {
    /// Live session that served the prompt (newly created when the caller
    /// passed no session id).
    pub session_id: String,
    /// Agent that produced `output`.
    pub agent_kind: AgentKind,
    /// Reply text collected from the agent during the turn.
    pub output: String,
}
impl CapLiveManager {
    /// Builds an empty manager whose session actors publish turn events on `bus`.
    ///
    /// The session cap and idle timeout start at `DEFAULT_MAX_SESSIONS` and
    /// `DEFAULT_IDLE_TIMEOUT`. No notification channel is attached until
    /// [`Self::set_notification_tx`] is called.
    pub fn new(bus: broadcast::Sender<rsclaw_events::AgentEvent>) -> Self {
        Self {
            sessions: Arc::new(RwLock::new(HashMap::new())),
            sticky: Arc::new(RwLock::new(HashMap::new())),
            suspended: Arc::new(RwLock::new(HashMap::new())),
            bus,
            notification_tx: None,
            max_sessions: DEFAULT_MAX_SESSIONS,
            idle_timeout: DEFAULT_IDLE_TIMEOUT,
        }
    }

    /// Stores the outbound-message channel returned by [`Self::notification_tx`].
    pub fn set_notification_tx(
        &mut self,
        tx: broadcast::Sender<rsclaw_types::OutboundMessage>,
    ) {
        self.notification_tx = Some(tx);
    }

    /// Returns a clone of the outbound-message sender, if one has been set.
    pub fn notification_tx(
        &self,
    ) -> Option<broadcast::Sender<rsclaw_types::OutboundMessage>> {
        self.notification_tx.clone()
    }

    /// Returns a live session id for `im_session_key` running agent `kind`.
    ///
    /// If a driver for `(im_session_key, kind)` is parked in the suspended pool
    /// and is still registered, that session is reused. Otherwise (no parked
    /// entry, or the parked driver has exited) a fresh session is spawned in
    /// `cwd`. The pool entry is consumed in either case.
    ///
    /// # Errors
    /// Propagates failures from [`Self::open_session`].
    pub async fn acquire_session(
        &self,
        im_session_key: &str,
        kind: AgentKind,
        cwd: std::path::PathBuf,
    ) -> Result<String> {
        let resumed = {
            let mut pool = self.suspended.write().await;
            pool.remove(&(im_session_key.to_owned(), kind))
        };
        if let Some(sid) = resumed {
            let alive = self.sessions.read().await.contains_key(&sid);
            if alive {
                tracing::info!(
                    target: "cap",
                    im_session_key,
                    agent = kind.as_str(),
                    session_id = %sid,
                    "cap_live resumed driver from suspended pool"
                );
                return Ok(sid);
            }
            tracing::debug!(
                target: "cap",
                im_session_key,
                agent = kind.as_str(),
                stale_sid = %sid,
                "cap_live suspended pool entry expired; spawning fresh"
            );
        }
        self.open_session(kind, cwd).await
    }

    /// Spawns a fresh (non-resumed) session for `kind` in `cwd` and returns its id.
    ///
    /// # Errors
    /// Fails when the session limit is reached or the driver cannot be spawned.
    pub async fn open_session(
        &self,
        kind: AgentKind,
        cwd: std::path::PathBuf,
    ) -> Result<String> {
        let sid = uuid::Uuid::new_v4().simple().to_string();
        self.spawn_session(&sid, kind, &cwd, ResumeMode::None).await?;
        Ok(sid)
    }

    /// Spawns a session whose driver resumes the agent-side session
    /// `agent_session_id`, and returns the new live session id.
    ///
    /// # Errors
    /// Fails when the session limit is reached or the driver cannot be spawned.
    pub async fn open_session_resume(
        &self,
        kind: AgentKind,
        cwd: std::path::PathBuf,
        agent_session_id: String,
    ) -> Result<String> {
        let sid = uuid::Uuid::new_v4().simple().to_string();
        self.spawn_session(&sid, kind, &cwd, ResumeMode::ById(agent_session_id))
            .await?;
        Ok(sid)
    }

    /// Spawns a session whose driver continues the agent's most recent
    /// conversation, and returns the new live session id.
    ///
    /// # Errors
    /// Fails when the session limit is reached or the driver cannot be spawned.
    pub async fn open_session_continue_last(
        &self,
        kind: AgentKind,
        cwd: std::path::PathBuf,
    ) -> Result<String> {
        let sid = uuid::Uuid::new_v4().simple().to_string();
        self.spawn_session(&sid, kind, &cwd, ResumeMode::ContinueLast)
            .await?;
        Ok(sid)
    }

    /// Binds `im_session_key` to `live_session_id` and returns the previous
    /// binding, if any.
    ///
    /// The previously bound driver is handled according to its agent:
    /// * same agent, different session: the old driver is ended;
    /// * same agent, same session: nothing is torn down, so re-binding the
    ///   current session is harmless;
    /// * different agent: the old driver is parked in the suspended pool under
    ///   `(im_session_key, old_kind)` so [`Self::acquire_session`] can revive it.
    ///   Any *other* driver already parked under that key is ended.
    pub async fn bind_sticky(
        &self,
        im_session_key: String,
        live_session_id: String,
        kind: AgentKind,
    ) -> Option<(String, AgentKind)> {
        let prior = {
            let mut g = self.sticky.write().await;
            g.insert(im_session_key.clone(), (live_session_id.clone(), kind))
        };
        if let Some((old_sid, old_kind)) = &prior {
            if *old_kind == kind {
                // Re-binding the session that is already bound must not end it:
                // that would kill the very driver the key now points at.
                if *old_sid != live_session_id {
                    tracing::info!(
                        target: "cap",
                        old_session_id = %old_sid,
                        new_session_id = %live_session_id,
                        agent = kind.as_str(),
                        "cap_live same-agent rebind — ending prior driver"
                    );
                    let _ = self.end_session(old_sid).await;
                }
            } else {
                let park_key = (im_session_key, *old_kind);
                let evicted = {
                    let mut pool = self.suspended.write().await;
                    pool.insert(park_key, old_sid.clone())
                };
                tracing::info!(
                    target: "cap",
                    old_session_id = %old_sid,
                    old_agent = old_kind.as_str(),
                    new_agent = kind.as_str(),
                    "cap_live sticky rebind — parked prior driver in suspended pool"
                );
                if let Some(evicted_sid) = evicted {
                    if evicted_sid != *old_sid {
                        tracing::info!(
                            target: "cap",
                            evicted_session_id = %evicted_sid,
                            "cap_live pool collision — ending evicted driver"
                        );
                        let _ = self.end_session(&evicted_sid).await;
                    }
                }
            }
        }
        prior
    }

    /// Removes and returns the sticky binding for `im_session_key`.
    ///
    /// The bound session itself is left running.
    pub async fn unbind_sticky(
        &self,
        im_session_key: &str,
    ) -> Option<(String, AgentKind)> {
        let mut g = self.sticky.write().await;
        g.remove(im_session_key)
    }

    /// Returns the sticky binding for `im_session_key` if its session is still
    /// registered.
    ///
    /// A binding to a session that no longer exists is dropped and `None` is
    /// returned.
    pub async fn resolve_sticky(
        &self,
        im_session_key: &str,
    ) -> Option<(String, AgentKind)> {
        let entry = {
            let g = self.sticky.read().await;
            g.get(im_session_key).cloned()
        }?;
        let alive = self.sessions.read().await.contains_key(&entry.0);
        if alive {
            return Some(entry);
        }
        // The bound driver is gone. Between the read above and this write,
        // another task may have re-bound the key to a fresh session, so only
        // drop the entry if it still points at the dead one.
        let mut g = self.sticky.write().await;
        if matches!(g.get(im_session_key), Some((sid, _)) if *sid == entry.0) {
            g.remove(im_session_key);
        }
        None
    }

    /// Sends `task` to a live session and waits for the turn's reply.
    ///
    /// When `session_id` is `None` or blank, a fresh session for `kind` is
    /// spawned in `cwd` first. The session's activity clock is refreshed both
    /// when the prompt is submitted and when the wait ends.
    ///
    /// # Errors
    /// Fails if the session does not exist, is bound to a different agent, its
    /// actor is gone, the reply does not arrive within `PROMPT_TIMEOUT`, or the
    /// turn itself failed.
    pub async fn dispatch_sync(
        &self,
        kind: AgentKind,
        session_id: Option<String>,
        task: String,
        cwd: std::path::PathBuf,
        notif: Option<NotifTarget>,
    ) -> Result<LiveDispatchResult> {
        let sid = match session_id {
            Some(s) if !s.trim().is_empty() => s,
            _ => {
                let new_id = uuid::Uuid::new_v4().simple().to_string();
                self.spawn_session(&new_id, kind, &cwd, ResumeMode::None)
                    .await?;
                new_id
            }
        };
        let handle = {
            let g = self.sessions.read().await;
            g.get(&sid).cloned().ok_or_else(|| {
                anyhow!(
                    "live session `{sid}` not found (expired by idle GC, ended, \
                     or never created — start a new one by omitting session_id)"
                )
            })?
        };
        if handle.agent_kind != kind {
            return Err(anyhow!(
                "live session `{sid}` is bound to `{}`, cannot route a `{}` prompt to it",
                handle.agent_kind.as_str(),
                kind.as_str()
            ));
        }
        Self::touch(&handle);
        let (reply_tx, reply_rx) = oneshot::channel();
        handle
            .tx
            .send(LiveRequest::Prompt {
                task,
                notif,
                reply: reply_tx,
            })
            .await
            .map_err(|_| anyhow!("live session `{sid}` actor closed unexpectedly"))?;
        let waited = tokio::time::timeout(PROMPT_TIMEOUT, reply_rx).await;
        // The end of a turn counts as activity too; otherwise a long turn
        // shortens the idle window that follows it.
        Self::touch(&handle);
        let output = waited
            .map_err(|_| {
                anyhow!(
                    "live session `{sid}`: turn timed out after {}s",
                    PROMPT_TIMEOUT.as_secs()
                )
            })?
            .map_err(|_| anyhow!("live session `{sid}`: actor dropped reply"))??;
        Ok(LiveDispatchResult {
            session_id: sid,
            agent_kind: kind,
            output,
        })
    }

    /// Records "now" as the last activity time of `handle`.
    ///
    /// A poisoned mutex is ignored.
    fn touch(handle: &LiveSessionHandle) {
        if let Ok(mut last) = handle.last_active.lock() {
            *last = Instant::now();
        }
    }

    /// Returns `(im_session_key, live_session_id, agent)` for every sticky binding.
    ///
    /// Despite the name this never blocks. It uses `try_read` and returns an
    /// empty list if the lock cannot be taken immediately.
    pub fn snapshot_sticky_blocking(&self) -> Vec<(String, String, AgentKind)> {
        match self.sticky.try_read() {
            Ok(g) => g
                .iter()
                .map(|(im_key, (sid, kind))| (im_key.clone(), sid.clone(), *kind))
                .collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Atomically reads and clears the session's pending-memory-injection flag.
    ///
    /// This yields `true` at most once per fresh session, and `false` for
    /// unknown ids.
    pub async fn try_take_pending_memory_inject(&self, sid: &str) -> bool {
        let g = self.sessions.read().await;
        match g.get(sid) {
            Some(h) => h
                .pending_memory_inject
                .swap(false, std::sync::atomic::Ordering::SeqCst),
            None => false,
        }
    }

    /// Non-blocking variant of [`Self::get_agent_session_id`] for sync callers.
    ///
    /// Returns `None` if the registry lock cannot be taken immediately, the id
    /// is unknown, or no agent session id has been captured yet.
    pub fn agent_session_id_blocking(&self, sid: &str) -> Option<String> {
        let g = self.sessions.try_read().ok()?;
        let h = g.get(sid)?;
        h.agent_session_id.lock().ok().and_then(|s| s.clone())
    }

    /// Returns `(im_session_key, agent, live_session_id)` for every parked session.
    ///
    /// Never blocks. Returns an empty list if the lock is unavailable.
    pub fn snapshot_suspended_blocking(&self) -> Vec<(String, AgentKind, String)> {
        match self.suspended.try_read() {
            Ok(g) => g
                .iter()
                .map(|((im_key, kind), sid)| (im_key.clone(), *kind, sid.clone()))
                .collect(),
            Err(_) => Vec::new(),
        }
    }

    /// Deregisters `session_id` and asks its actor to shut down.
    ///
    /// Unknown ids are a no-op. This currently always returns `Ok(())`. The
    /// Shutdown message is sent after the registry lock has been released.
    pub async fn end_session(&self, session_id: &str) -> Result<()> {
        let handle = {
            let mut g = self.sessions.write().await;
            g.remove(session_id)
        };
        if let Some(h) = handle {
            let _ = h.tx.send(LiveRequest::Shutdown).await;
        }
        Ok(())
    }

    /// Lists `(live_session_id, agent)` for every registered session.
    #[allow(dead_code)]
    pub async fn list(&self) -> Vec<(String, AgentKind)> {
        let g = self.sessions.read().await;
        g.iter().map(|(k, h)| (k.clone(), h.agent_kind)).collect()
    }

    /// Starts a driver for `kind` in `cwd` per `resume_mode` and registers it
    /// under `session_id`.
    ///
    /// This method first reaps idle sessions and then enforces `max_sessions`.
    /// It spawns the session's actor task and inserts the handle. Fresh
    /// (`ResumeMode::None`) sessions start with `pending_memory_inject` set.
    ///
    /// # Errors
    /// Fails when the session limit is reached or the driver cannot be spawned.
    async fn spawn_session(
        &self,
        session_id: &str,
        kind: AgentKind,
        cwd: &std::path::Path,
        resume_mode: ResumeMode,
    ) -> Result<()> {
        self.gc_idle().await;
        {
            let g = self.sessions.read().await;
            if g.len() >= self.max_sessions {
                return Err(anyhow!(
                    "live session limit reached ({} active); end one via `cap_live_end` first",
                    self.max_sessions
                ));
            }
        }
        let driver = match &resume_mode {
            ResumeMode::ById(rid) => spawn_driver_resume(kind, cwd, rid).await?,
            ResumeMode::ContinueLast => spawn_driver_continue_last(kind, cwd).await?,
            ResumeMode::None => spawn_driver(kind, cwd).await?,
        };
        let (tx, rx) = mpsc::channel::<LiveRequest>(4);
        let last_active = Arc::new(StdMutex::new(Instant::now()));
        let agent_session_id = Arc::new(StdMutex::new(None::<String>));
        let bus = self.bus.clone();
        let sessions_for_gc = Arc::clone(&self.sessions);
        let sticky_for_gc = Arc::clone(&self.sticky);
        let suspended_for_gc = Arc::clone(&self.suspended);
        let sid_owned = session_id.to_owned();
        let agent_sid_slot = Arc::clone(&agent_session_id);
        tokio::spawn(actor_loop(
            sid_owned,
            kind,
            cwd.to_path_buf(),
            driver,
            rx,
            bus,
            sessions_for_gc,
            sticky_for_gc,
            suspended_for_gc,
            agent_sid_slot,
        ));
        let inject = matches!(resume_mode, ResumeMode::None);
        let handle = LiveSessionHandle {
            agent_kind: kind,
            tx,
            last_active,
            agent_session_id,
            pending_memory_inject: Arc::new(std::sync::atomic::AtomicBool::new(inject)),
        };
        let mut g = self.sessions.write().await;
        g.insert(session_id.to_owned(), handle);
        Ok(())
    }

    /// Returns the agent-side session id captured for `sid`, if any.
    pub async fn get_agent_session_id(&self, sid: &str) -> Option<String> {
        let g = self.sessions.read().await;
        g.get(sid).and_then(|h| h.agent_session_id.lock().ok().and_then(|s| s.clone()))
    }

    /// Polls [`Self::get_agent_session_id`] every 200ms until it yields an id
    /// or `timeout` elapses.
    #[allow(dead_code)]
    pub async fn wait_agent_session_id(
        &self,
        sid: &str,
        timeout: std::time::Duration,
    ) -> Option<String> {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if let Some(s) = self.get_agent_session_id(sid).await {
                return Some(s);
            }
            if tokio::time::Instant::now() >= deadline {
                return None;
            }
            tokio::time::sleep(std::time::Duration::from_millis(200)).await;
        }
    }

    /// Whether `handle` has been inactive for longer than `idle` as of `now`.
    ///
    /// A poisoned `last_active` mutex counts as "not idle".
    fn is_idle(handle: &LiveSessionHandle, now: Instant, idle: Duration) -> bool {
        handle
            .last_active
            .lock()
            .map(|last| now.duration_since(*last) > idle)
            .unwrap_or(false)
    }

    /// Reaps sessions that have been inactive for longer than `idle_timeout`.
    ///
    /// Reaped sessions are removed from the registry, the sticky map and the
    /// suspended pool, and their actors are sent `Shutdown`.
    async fn gc_idle(&self) {
        let idle = self.idle_timeout;

        // Pass 1: collect candidates under the shared lock.
        let candidates: Vec<String> = {
            let now = Instant::now();
            let g = self.sessions.read().await;
            g.iter()
                .filter(|(_, h)| Self::is_idle(h, now, idle))
                .map(|(sid, _)| sid.clone())
                .collect()
        };
        if candidates.is_empty() {
            return;
        }

        // Pass 2: remove under the exclusive lock. Each candidate is
        // re-checked, because a dispatch may have touched it since pass 1.
        let reaped: Vec<(String, LiveSessionHandle)> = {
            let now = Instant::now();
            let mut g = self.sessions.write().await;
            let mut out = Vec::with_capacity(candidates.len());
            for sid in candidates {
                if !matches!(g.get(&sid), Some(h) if Self::is_idle(h, now, idle)) {
                    continue;
                }
                if let Some(h) = g.remove(&sid) {
                    out.push((sid, h));
                }
            }
            out
        };
        if reaped.is_empty() {
            return;
        }

        {
            let mut sg = self.sticky.write().await;
            sg.retain(|_, (sid, _)| !reaped.iter().any(|(r, _)| *r == *sid));
        }
        {
            let mut pg = self.suspended.write().await;
            pg.retain(|_, sid| !reaped.iter().any(|(r, _)| *r == *sid));
        }

        // Pass 3: stop the actors with NO lock held. `send` waits for capacity
        // in the bounded channel. Awaiting it under the `sessions` write lock
        // would deadlock against an actor that has stopped reading `rx` and is
        // waiting for that same lock in its teardown.
        for (sid, h) in reaped {
            tracing::info!(
                target: "cap",
                session_id = %sid,
                "live session reaped (idle > {}s)",
                idle.as_secs()
            );
            let _ = h.tx.send(LiveRequest::Shutdown).await;
        }
    }
}
/// Waits up to `timeout` for `driver` to emit its first `AgentEvent::Ready`
/// and returns the session id carried by that event.
///
/// Non-Ready events are skipped. The function returns `None` after logging a
/// warning if the deadline passes or the event stream ends first. A Ready event
/// whose `session_id` is `None` also yields `None`.
#[allow(dead_code)]
async fn capture_ready_session_id(
    driver: &mut dyn Driver,
    timeout: std::time::Duration,
) -> Option<String> {
    use cap_rs::core::AgentEvent;

    let expiry = tokio::time::sleep(timeout);
    tokio::pin!(expiry);
    loop {
        tokio::select! {
            _ = &mut expiry => {
                tracing::warn!(
                    target: "cap",
                    "no Ready event within {}s; agent_session_id will be unknown",
                    timeout.as_secs()
                );
                return None;
            }
            next = driver.next_event() => {
                match next {
                    None => {
                        tracing::warn!(
                            target: "cap",
                            "driver event stream ended before Ready"
                        );
                        return None;
                    }
                    Some(AgentEvent::Ready { session_id, .. }) => {
                        tracing::info!(
                            target: "cap",
                            agent_session_id = ?session_id,
                            "cap_live captured Ready"
                        );
                        return session_id;
                    }
                    // Anything else is noise while waiting for Ready.
                    Some(_) => {}
                }
            }
        }
    }
}
/// Replaces a dead `driver` with a freshly spawned one for `kind` in `cwd`.
///
/// When `force_acp` is set the driver is spawned via `spawn_driver_acp`,
/// otherwise via `spawn_driver`. On success the old driver is shut down (best
/// effort) and swapped out. The function then waits up to 8s for the new
/// driver's first `Ready` event. The session id it carries, or `None` if none
/// arrived, overwrites `agent_sid_slot`. `sid` and `reason` are used only for
/// logging.
///
/// Returns `true` when a new driver is installed. Returns `false` when spawning
/// failed, in which case `driver` is left untouched.
async fn respawn_driver(
    kind: &AgentKind,
    cwd: &std::path::Path,
    driver: &mut Box<dyn Driver>,
    sid: &str,
    reason: &str,
    agent_sid_slot: &Arc<StdMutex<Option<String>>>,
    force_acp: bool,
) -> bool {
    use cap_rs::core::AgentEvent;

    let agent = *kind;
    let spawn_result = if force_acp {
        spawn_driver_acp(agent, cwd).await
    } else {
        spawn_driver(agent, cwd).await
    };
    let replacement = match spawn_result {
        Ok(fresh) => fresh,
        Err(e) => {
            tracing::warn!(
                target: "cap",
                session_id = %sid,
                agent = agent.as_str(),
                reason,
                error = %e,
                "cap_live driver respawn failed; surfacing error"
            );
            return false;
        }
    };

    // Retire the dead driver before installing its replacement.
    if let Err(e) = driver.shutdown().await {
        tracing::debug!(target: "cap", error = %e, "best-effort shutdown of dead driver");
    }
    *driver = replacement;

    // Give the new driver a short window to announce its session id; the
    // deadline is checked first on every poll (`biased`).
    let ready_window = tokio::time::sleep(std::time::Duration::from_secs(8));
    tokio::pin!(ready_window);
    let fresh_agent_sid: Option<String> = loop {
        tokio::select! {
            biased;
            _ = &mut ready_window => break None,
            ev = driver.next_event() => match ev {
                Some(AgentEvent::Ready { session_id, .. }) => break session_id,
                Some(_) => {}
                None => break None,
            }
        }
    };

    if fresh_agent_sid.is_none() {
        tracing::warn!(
            target: "cap",
            session_id = %sid,
            "no Ready captured after respawn; resume id unavailable until next bind"
        );
    }
    if let Ok(mut slot) = agent_sid_slot.lock() {
        *slot = fresh_agent_sid;
    }
    tracing::info!(
        target: "cap",
        session_id = %sid,
        agent = agent.as_str(),
        reason,
        "cap_live respawned driver after death; retrying prompt once"
    );
    true
}
/// Per-session actor: owns one agent driver and serves [`LiveRequest`]s from
/// `rx` one at a time until it stops.
///
/// Lifecycle, as implemented below:
/// 1. **Ready capture.** The actor waits up to 30s for the driver's first
///    `AgentEvent::Ready` and stores its `session_id` in `agent_sid_slot`. If
///    a request arrives first, it is buffered and the deadline is reset to 8s
///    from that moment. This shortens the wait, or extends it if fewer than 8s
///    remained.
/// 2. **Serve loop.** For each `Prompt`, the actor sends the task to the driver
///    and runs one turn under `super::runtime::TURN_TIMEOUT`. An empty reply
///    causes the prompt to be re-sent, up to 2 times. A send failure or turn
///    error triggers at most one driver respawn per prompt, followed by a
///    retry. The respawn goes through `respawn_driver` and is forced to ACP for
///    `AgentKind::Opencode`. If the prompt still fails, the error is sent on
///    `reply` and the actor stops.
/// 3. **Teardown.** On `Shutdown`, a closed request channel, or a failed
///    prompt, the driver is shut down (best effort). `sid` is then removed from
///    `sessions`, and every `sticky`/`suspended` entry pointing at it is
///    dropped.
async fn actor_loop(
    sid: String,
    kind: AgentKind,
    cwd: std::path::PathBuf,
    mut driver: Box<dyn Driver>,
    mut rx: mpsc::Receiver<LiveRequest>,
    bus: broadcast::Sender<rsclaw_events::AgentEvent>,
    sessions: Arc<RwLock<HashMap<String, LiveSessionHandle>>>,
    sticky: Arc<RwLock<HashMap<String, (String, AgentKind)>>>,
    suspended: Arc<RwLock<HashMap<(String, AgentKind), String>>>,
    agent_sid_slot: Arc<StdMutex<Option<String>>>,
) {
    tracing::info!(
        target: "cap",
        session_id = %sid,
        agent = kind.as_str(),
        "cap_live actor started"
    );
    // A request received while still waiting for Ready; the serve loop handles
    // it before reading `rx` again.
    let mut prebuf_request: Option<LiveRequest> = None;
    // Phase 1: capture the agent-side session id from the first Ready event.
    {
        use cap_rs::core::AgentEvent;
        let capture_deadline = tokio::time::sleep(std::time::Duration::from_secs(30));
        tokio::pin!(capture_deadline);
        loop {
            tokio::select! {
                // Poll in declaration order: driver events, then requests,
                // then the deadline.
                biased;
                ev = driver.next_event() => match ev {
                    Some(AgentEvent::Ready { session_id, .. }) => {
                        if let Ok(mut g) = agent_sid_slot.lock() {
                            *g = session_id.clone();
                        }
                        tracing::info!(
                            target: "cap",
                            session_id = %sid,
                            agent_session_id = ?session_id,
                            "cap_live captured Ready"
                        );
                        break;
                    }
                    Some(_) => continue,
                    None => {
                        tracing::warn!(
                            target: "cap",
                            session_id = %sid,
                            "driver event stream ended before Ready"
                        );
                        break;
                    }
                },
                // At most one request is buffered; stop reading `rx` once one is held.
                first_req = rx.recv(), if prebuf_request.is_none() => match first_req {
                    Some(req) => {
                        prebuf_request = Some(req);
                        // A caller is now waiting: allow at most 8 more seconds for Ready.
                        capture_deadline.as_mut().reset(
                            tokio::time::Instant::now() + std::time::Duration::from_secs(8),
                        );
                    }
                    // All senders dropped; the serve loop below will see the
                    // closed channel and exit.
                    None => break,
                },
                _ = &mut capture_deadline => {
                    tracing::warn!(
                        target: "cap",
                        session_id = %sid,
                        agent = kind.as_str(),
                        "no Ready captured before deadline; resume id unavailable until next bind"
                    );
                    break;
                }
            }
        }
    }
    // Session id reported on the event bus for this actor's turns.
    let pseudo_session_id = format!("cap-live-{}-{sid}", kind.as_str());
    // Phase 2: serve requests until Shutdown, channel close, or a failed prompt.
    loop {
        let req = match prebuf_request.take() {
            Some(r) => r,
            None => match rx.recv().await {
                Some(r) => r,
                None => break,
            },
        };
        match req {
            LiveRequest::Prompt { task, notif, reply } => {
                // Opencode drivers are respawned in ACP mode.
                let retry_acp = kind == AgentKind::Opencode;
                // At most one respawn per prompt.
                let mut respawned = false;
                // Number of re-sends triggered by an empty reply (capped at 2).
                let mut empty_resends = 0u8;
                let outcome = loop {
                    let send_res = driver
                        .send(ClientFrame::Prompt {
                            content: vec![Content::text(task.clone())],
                        })
                        .await;
                    if let Err(e) = send_res {
                        if !respawned
                            && respawn_driver(&kind, &cwd, &mut driver, &sid, "send failed", &agent_sid_slot, retry_acp).await
                        {
                            respawned = true;
                            continue;
                        }
                        break Err(anyhow!("cap_live driver send: {e}"));
                    }
                    // Fresh buffer per attempt, so a retry never mixes output
                    // from a previous attempt.
                    let mut reply_buf = String::new();
                    let turn = match tokio::time::timeout(
                        super::runtime::TURN_TIMEOUT,
                        run_turn(
                            driver.as_mut(),
                            &bus,
                            &pseudo_session_id,
                            "cap-live",
                            notif.as_ref(),
                            &mut reply_buf,
                        ),
                    )
                    .await
                    {
                        Ok(r) => r,
                        Err(_) => Err(anyhow!(
                            "cap_live: turn timed out after {}s (driver hang?)",
                            super::runtime::TURN_TIMEOUT.as_secs()
                        )),
                    };
                    match turn {
                        Ok(()) => {
                            if reply_buf.trim().is_empty() && empty_resends < 2 {
                                empty_resends += 1;
                                tracing::info!(
                                    target: "cap",
                                    session_id = %sid,
                                    agent = kind.as_str(),
                                    attempt = empty_resends,
                                    "cap_live empty turn — re-sending prompt to warm driver"
                                );
                                continue;
                            }
                            break Ok(reply_buf);
                        }
                        Err(e) => {
                            if !respawned
                                && respawn_driver(
                                    &kind,
                                    &cwd,
                                    &mut driver,
                                    &sid,
                                    "exited mid-turn",
                                    &agent_sid_slot,
                                    retry_acp,
                                )
                                .await
                            {
                                respawned = true;
                                continue;
                            }
                            break Err(anyhow!("cap_live driver: {e}"));
                        }
                    }
                };
                // The caller may have given up (dropped the receiver), so
                // send failures are ignored.
                match outcome {
                    Ok(reply_buf) => {
                        let _ = reply.send(Ok(reply_buf));
                    }
                    Err(e) => {
                        let _ = reply.send(Err(e));
                        // An unrecoverable prompt failure ends the session.
                        break;
                    }
                }
            }
            LiveRequest::Shutdown => break,
        }
    }
    // Phase 3: teardown. Shut the driver down, then deregister this session
    // everywhere it may still be referenced.
    if let Err(e) = driver.shutdown().await {
        tracing::debug!(target: "cap", error = %e, "best-effort shutdown of dead driver");
    }
    {
        let mut g = sessions.write().await;
        g.remove(&sid);
    }
    {
        let mut sg = sticky.write().await;
        sg.retain(|_, (s, _)| s != &sid);
    }
    {
        let mut pg = suspended.write().await;
        pg.retain(|_, s| s != &sid);
    }
    tracing::info!(
        target: "cap",
        session_id = %sid,
        "cap_live actor exited"
    );
}