use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use serde_json::{json, Value};
use tokio::sync::{broadcast, mpsc, oneshot, Mutex, RwLock};
use crate::agent::{spawn_agent, Agent};
use crate::config::Config;
use crate::ws::handle_agent_message;
/// How long a hub may sit with zero attached clients before its loop exits and
/// the agent session is shut down.
const GRACE_PERIOD: Duration = Duration::from_secs(30);
/// Source of process-unique `AttachedHub::attach_id` values.
static NEXT_ATTACH_ID: AtomicU64 = AtomicU64::new(1);
/// Capacity of each hub's outbound broadcast channel (tokio `broadcast`:
/// receivers that fall further behind than this lag and skip events).
const BROADCAST_CAPACITY: usize = 1024;
/// Capacity of each hub's client -> hub command channel.
const COMMAND_CAPACITY: usize = 256;
/// A command sent from an attached client to its hub loop.
#[derive(Debug)]
pub enum HubCommand {
/// Start a user turn via `session/prompt` with `blocks` as the prompt content.
/// `attach_id` identifies the sending client so permission / OAuth requests
/// raised during the turn can be targeted at it. Empty `blocks` are ignored.
Prompt { blocks: Vec<Value>, attach_id: u64 },
/// Answer permission request `id` with `option_id`; repeated answers for the
/// same `id` are ignored by the hub.
PermissionResponse { id: Value, option_id: String },
/// Send a `session/cancel` notification for the session.
Cancel,
/// Switch the session mode via `session/set_mode`.
SetMode { mode_id: String },
/// Switch the session model via `session/set_model`.
SetModel { model_id: String },
}
/// Frames captured while negotiating the session; copies are handed to every
/// client when it attaches.
#[derive(Clone)]
struct NegotiationSnapshot {
/// The `ready` frame produced by negotiation (carries `sessionId`).
ready: Value,
/// The `session_info` frame, if negotiation produced one. Patched in place when
/// the mode or model is changed.
session_info: Option<Value>,
}
/// Registry-side handle to a running hub loop: everything needed to attach a
/// new client to it.
pub struct SessionHub {
/// Client -> hub command channel (cloned into each attachment).
commands: mpsc::Sender<HubCommand>,
/// Hub -> clients fan-out; each attachment subscribes its own receiver.
outbound: broadcast::Sender<Arc<Value>>,
/// Negotiation frames handed to newly attached clients.
snapshot: Arc<Mutex<NegotiationSnapshot>>,
/// Agent session id; also the key under which the hub is registered.
session_id: String,
/// Attached-client count shared with the hub loop (drives the grace period).
counter: Arc<Counter>,
}
/// Counts attached clients and reports count transitions to the hub loop.
struct Counter {
/// Current count plus the pending grace-cancel signal.
state: Mutex<CounterState>,
/// Sender half of the hub loop's `grace_rx`.
grace_tx: mpsc::Sender<GraceEvent>,
}
/// Mutable state guarded by `Counter::state`.
#[derive(Default)]
struct CounterState {
/// Number of live `AttachedHub`s.
count: usize,
/// Installed by the hub loop when it starts a grace timer; fired by the next
/// 0 -> 1 increment to cut that timer short.
cancel_grace: Option<oneshot::Sender<()>>,
}
/// Attached-count changes reported from `Counter` to the hub loop.
#[derive(Debug)]
enum GraceEvent {
/// A decrement left the count at zero; the hub loop starts a grace timer.
Empty,
/// An increment raised the count from zero; the hub loop drops its grace timer.
Refilled,
}
impl Counter {
    /// Creates a counter at zero that reports transitions on `grace_tx`.
    fn new(grace_tx: mpsc::Sender<GraceEvent>) -> Self {
        Self {
            state: Mutex::new(CounterState::default()),
            grace_tx,
        }
    }

    /// Records one more attached client and returns the new count.
    ///
    /// On a 0 -> 1 transition this fires the pending grace-cancel signal (if the
    /// hub loop installed one) and sends [`GraceEvent::Refilled`].
    ///
    /// Events are sent while the state lock is held so the hub loop receives
    /// `Empty` / `Refilled` in the same order as the transitions that caused
    /// them; sending after unlocking could let a `Refilled` overtake the
    /// `Empty` before it and leave an empty hub without a grace timer.
    ///
    /// NOTE(review): `grace_tx` is a bounded tokio sender and `install_cancel`
    /// (called from the hub loop) takes this same lock. If the channel were ever
    /// full while the loop waits in `install_cancel`, the two would block each
    /// other.
    async fn increment(&self) -> usize {
        let mut state = self.state.lock().await;
        let was_zero = state.count == 0;
        state.count += 1;
        if was_zero {
            if let Some(cancel) = state.cancel_grace.take() {
                let _ = cancel.send(());
            }
            let _ = self.grace_tx.send(GraceEvent::Refilled).await;
        }
        state.count
    }

    /// Records one detached client and returns the new count.
    ///
    /// Sends [`GraceEvent::Empty`] only on a 1 -> 0 transition. A decrement at
    /// zero (more detaches than attaches) is ignored instead of re-sending
    /// `Empty`, which would make the hub loop restart its grace timer.
    async fn decrement(&self) -> usize {
        let mut state = self.state.lock().await;
        if state.count == 0 {
            return 0;
        }
        state.count -= 1;
        if state.count == 0 {
            let _ = self.grace_tx.send(GraceEvent::Empty).await;
        }
        state.count
    }

    /// Installs a fresh grace-cancel signal and returns its receiver.
    ///
    /// Any previously installed sender is replaced (and dropped).
    async fn install_cancel(&self) -> oneshot::Receiver<()> {
        let (tx, rx) = oneshot::channel();
        let mut state = self.state.lock().await;
        state.cancel_grace = Some(tx);
        rx
    }

    /// Returns the current number of attached clients.
    async fn count(&self) -> usize {
        self.state.lock().await.count
    }
}
/// One client's attachment to a hub. Dropping it decrements the hub's
/// attached-client count.
pub struct AttachedHub {
/// Commands to the hub loop.
pub commands: mpsc::Sender<HubCommand>,
/// This client's subscription to hub events.
pub outbound: broadcast::Receiver<Arc<Value>>,
/// Copy of the negotiation `ready` frame with `"resumed": true` inserted
/// (when the frame is a JSON object).
pub snapshot_ready: Value,
/// Copy of the cached `session_info` frame, if any.
pub snapshot_session_info: Option<Value>,
/// Agent session id of the hub.
pub session_id: String,
/// Process-unique id of this attachment; used to target permission / OAuth
/// requests at the client that sent the current prompt.
pub attach_id: u64,
/// Shared attached-client count, decremented on drop.
counter: Arc<Counter>,
}
impl Drop for AttachedHub {
fn drop(&mut self) {
let counter = self.counter.clone();
tokio::spawn(async move {
counter.decrement().await;
});
}
}
/// Shared map of live hubs keyed by session id. Clones share the same maps.
#[derive(Clone, Default)]
pub struct HubRegistry {
/// Live hubs by session id.
inner: Arc<RwLock<HashMap<String, Arc<SessionHub>>>>,
/// Per-session-id build locks that serialize concurrent resumes of one id.
building: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
}
impl HubRegistry {
pub fn new() -> Self {
Self::default()
}
pub async fn attach_or_create(
&self,
cfg: Arc<Config>,
resume_session_id: Option<String>,
cwd_override: Option<String>,
build_id: &str,
) -> Result<AttachedHub> {
if let Some(sid) = resume_session_id.as_deref() {
let map = self.inner.read().await;
if let Some(hub) = map.get(sid).cloned() {
drop(map);
return Ok(self.subscribe(hub).await);
}
}
if let Some(sid) = resume_session_id.as_deref() {
let key_mutex = {
let mut building = self.building.lock().await;
building
.entry(sid.to_string())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
};
let _guard = key_mutex.lock().await;
{
let map = self.inner.read().await;
if let Some(hub) = map.get(sid).cloned() {
drop(map);
let attached = self.subscribe(hub).await;
self.cleanup_build_slot(sid).await;
return Ok(attached);
}
}
let result = self
.build_and_register(cfg, Some(sid.to_string()), cwd_override, build_id)
.await;
self.cleanup_build_slot(sid).await;
return result;
}
self.build_and_register(cfg, None, cwd_override, build_id)
.await
}
async fn cleanup_build_slot(&self, sid: &str) {
let mut building = self.building.lock().await;
if let Some(entry) = building.get(sid) {
if Arc::strong_count(entry) == 1 {
building.remove(sid);
}
}
}
async fn build_and_register(
&self,
cfg: Arc<Config>,
resume_session_id: Option<String>,
cwd_override: Option<String>,
build_id: &str,
) -> Result<AttachedHub> {
let hub = build_hub(cfg, resume_session_id, cwd_override, build_id, self.clone()).await?;
let session_id = hub.session_id.clone();
let mut map = self.inner.write().await;
let entry = map.entry(session_id).or_insert_with(|| Arc::new(hub));
let hub = entry.clone();
drop(map);
Ok(self.subscribe(hub).await)
}
async fn subscribe(&self, hub: Arc<SessionHub>) -> AttachedHub {
hub.counter.increment().await;
let snapshot = hub.snapshot.lock().await;
let mut snapshot_ready = snapshot.ready.clone();
if let Some(map) = snapshot_ready.as_object_mut() {
map.insert("resumed".into(), Value::Bool(true));
}
let snapshot_session_info = snapshot.session_info.clone();
drop(snapshot);
AttachedHub {
commands: hub.commands.clone(),
outbound: hub.outbound.subscribe(),
snapshot_ready,
snapshot_session_info,
session_id: hub.session_id.clone(),
attach_id: NEXT_ATTACH_ID.fetch_add(1, Ordering::Relaxed),
counter: hub.counter.clone(),
}
}
async fn remove(&self, session_id: &str) {
let mut map = self.inner.write().await;
map.remove(session_id);
}
#[doc(hidden)]
pub async fn register_for_test(
&self,
agent: Arc<Agent>,
session_id: String,
updates_rx: mpsc::UnboundedReceiver<Value>,
ready: Value,
session_info: Option<Value>,
) -> AttachedHub {
let (cmd_tx, cmd_rx) = mpsc::channel::<HubCommand>(COMMAND_CAPACITY);
let (out_tx, _) = broadcast::channel::<Arc<Value>>(BROADCAST_CAPACITY);
let (grace_tx, grace_rx) = mpsc::channel::<GraceEvent>(8);
let counter = Arc::new(Counter::new(grace_tx));
let suppress_replay = Arc::new(Mutex::new(false));
let snapshot = Arc::new(Mutex::new(NegotiationSnapshot {
ready,
session_info,
}));
let hub = SessionHub {
commands: cmd_tx,
outbound: out_tx.clone(),
snapshot: snapshot.clone(),
session_id: session_id.clone(),
counter: counter.clone(),
};
tokio::spawn(run_hub_loop(HubLoopState {
agent,
session_id: session_id.clone(),
outbound: out_tx,
commands: cmd_rx,
updates: updates_rx,
suppress_replay,
counter,
grace_rx,
registry: self.clone(),
snapshot,
}));
let mut map = self.inner.write().await;
let entry = map.entry(session_id).or_insert_with(|| Arc::new(hub));
let hub = entry.clone();
drop(map);
self.subscribe(hub).await
}
#[doc(hidden)]
pub async fn attach_existing_for_test(&self, session_id: &str) -> Option<AttachedHub> {
let map = self.inner.read().await;
let hub = map.get(session_id).cloned()?;
drop(map);
Some(self.subscribe(hub).await)
}
}
async fn build_hub(
cfg: Arc<Config>,
resume_session_id: Option<String>,
cwd_override: Option<String>,
build_id: &str,
registry: HubRegistry,
) -> Result<SessionHub> {
let (agent, updates_rx) = spawn_agent(&cfg).await?;
let agent = Arc::new(agent);
let (snapshot_tx, mut snapshot_rx) = mpsc::unbounded_channel::<axum::extract::ws::Message>();
let _outcome = crate::ws::negotiate_session(
&agent,
&snapshot_tx,
resume_session_id,
cwd_override,
build_id,
)
.await?;
drop(snapshot_tx);
let mut ready: Option<Value> = None;
let mut session_info: Option<Value> = None;
while let Some(msg) = snapshot_rx.recv().await {
let axum::extract::ws::Message::Text(text) = msg else {
continue;
};
let value: Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(_) => continue,
};
match value.get("type").and_then(Value::as_str) {
Some("ready") => ready = Some(value),
Some("session_info") => session_info = Some(value),
_ => {}
}
}
let ready = ready.ok_or_else(|| anyhow::anyhow!("Negotiation produced no `ready` event"))?;
let session_id = ready
.get("sessionId")
.and_then(Value::as_str)
.ok_or_else(|| anyhow::anyhow!("`ready` missing sessionId"))?
.to_string();
let suppress_replay = Arc::new(Mutex::new(
ready
.get("resumed")
.and_then(Value::as_bool)
.unwrap_or(false),
));
let (cmd_tx, cmd_rx) = mpsc::channel::<HubCommand>(COMMAND_CAPACITY);
let (out_tx, _) = broadcast::channel::<Arc<Value>>(BROADCAST_CAPACITY);
let (grace_tx, grace_rx) = mpsc::channel::<GraceEvent>(8);
let counter = Arc::new(Counter::new(grace_tx));
let snapshot = Arc::new(Mutex::new(NegotiationSnapshot {
ready,
session_info,
}));
let hub = SessionHub {
commands: cmd_tx,
outbound: out_tx.clone(),
snapshot: snapshot.clone(),
session_id: session_id.clone(),
counter: counter.clone(),
};
tokio::spawn(run_hub_loop(HubLoopState {
agent,
session_id,
outbound: out_tx,
commands: cmd_rx,
updates: updates_rx,
suppress_replay,
counter,
grace_rx,
registry,
snapshot,
}));
Ok(hub)
}
/// Everything `run_hub_loop` owns, bundled so the loop can be spawned with a
/// single argument.
struct HubLoopState {
/// Agent the hub drives.
agent: Arc<Agent>,
/// Session id on `agent`; also the registry key.
session_id: String,
/// Sender side of the client fan-out.
outbound: broadcast::Sender<Arc<Value>>,
/// Receiver for client commands.
commands: mpsc::Receiver<HubCommand>,
/// Raw messages from the agent.
updates: mpsc::UnboundedReceiver<Value>,
/// Flag passed to `handle_agent_message`. Initialised by the builder
/// (`build_hub` uses `ready.resumed`) and cleared when a client prompts.
suppress_replay: Arc<Mutex<bool>>,
/// Attached-client count shared with `SessionHub` / `AttachedHub`.
counter: Arc<Counter>,
/// Count transitions reported by `counter`.
grace_rx: mpsc::Receiver<GraceEvent>,
/// Registry to unregister from when the loop exits.
registry: HubRegistry,
/// Negotiation snapshot, patched when the mode or model changes.
snapshot: Arc<Mutex<NegotiationSnapshot>>,
}
async fn run_hub_loop(state: HubLoopState) {
let HubLoopState {
agent,
session_id,
outbound,
mut commands,
mut updates,
suppress_replay,
counter,
mut grace_rx,
registry,
snapshot,
} = state;
let (relay_tx, mut relay_rx) = mpsc::unbounded_channel::<axum::extract::ws::Message>();
let mut answered_permissions: std::collections::HashSet<String> =
std::collections::HashSet::new();
let current_prompter: Arc<Mutex<Option<u64>>> = Arc::new(Mutex::new(None));
let mut grace_deadline: Option<Pin<Box<dyn Future<Output = ()> + Send>>> = None;
loop {
tokio::select! {
cmd = commands.recv() => {
match cmd {
Some(c) => handle_command(
&agent,
&session_id,
c,
&suppress_replay,
&mut answered_permissions,
&outbound,
&snapshot,
¤t_prompter,
).await,
None => break, }
}
agent_msg = updates.recv() => {
let Some(msg) = agent_msg else {
break; };
let suppress = *suppress_replay.lock().await;
handle_agent_message(&relay_tx, msg, suppress).await;
while let Ok(frame) = relay_rx.try_recv() {
if let axum::extract::ws::Message::Text(text) = frame {
if let Ok(mut value) = serde_json::from_str::<Value>(&text) {
let event_type = value
.get("type")
.and_then(Value::as_str)
.unwrap_or("")
.to_string();
let target = *current_prompter.lock().await;
if let Some(target) = target {
if matches!(
event_type.as_str(),
"permission_request" | "mcp_oauth_request"
) {
if let Some(map) = value.as_object_mut() {
map.insert(
"_target".into(),
Value::Number(target.into()),
);
}
}
}
if event_type == "prompt_done" {
*current_prompter.lock().await = None;
}
let _ = outbound.send(Arc::new(value));
}
}
}
}
grace_evt = grace_rx.recv() => {
match grace_evt {
Some(GraceEvent::Empty) => {
let cancel_rx = counter.install_cancel().await;
grace_deadline = Some(Box::pin(async move {
tokio::select! {
_ = tokio::time::sleep(GRACE_PERIOD) => {}
_ = cancel_rx => {}
}
}));
}
Some(GraceEvent::Refilled) => {
grace_deadline = None;
}
None => {} }
}
_ = async {
match grace_deadline.as_mut() {
Some(f) => f.await,
None => std::future::pending().await,
}
}, if grace_deadline.is_some() => {
if counter.count().await == 0 {
break;
}
grace_deadline = None;
}
}
}
agent.shutdown(Some(&session_id)).await;
registry.remove(&session_id).await;
}
/// Executes one client command against the hub's agent session.
///
/// Agent requests run as detached tasks, so the hub loop keeps relaying agent
/// output while a request is in flight. Results and failures reach clients
/// through `outbound`.
// The hub loop's state is passed piecewise; bundling it would only move the noise.
#[allow(clippy::too_many_arguments)]
async fn handle_command(
    agent: &Arc<Agent>,
    session_id: &str,
    cmd: HubCommand,
    suppress_replay: &Mutex<bool>,
    answered: &mut std::collections::HashSet<String>,
    outbound: &broadcast::Sender<Arc<Value>>,
    snapshot: &Arc<Mutex<NegotiationSnapshot>>,
    current_prompter: &Arc<Mutex<Option<u64>>>,
) {
    match cmd {
        HubCommand::Prompt { blocks, attach_id } => {
            if blocks.is_empty() {
                return;
            }
            // Permission / OAuth requests raised during this turn are targeted at the sender.
            *current_prompter.lock().await = Some(attach_id);
            *suppress_replay.lock().await = false;
            let echo_text = extract_user_text(&blocks);
            if !echo_text.is_empty() {
                let _ = outbound.send(Arc::new(json!({
                    "type": "append",
                    "role": "user",
                    "text": format!("> {echo_text}\n")
                })));
            }
            let agent = Arc::clone(agent);
            let sid = session_id.to_string();
            let events = outbound.clone();
            let prompter = Arc::clone(current_prompter);
            tokio::spawn(async move {
                let res = agent
                    .request(
                        "session/prompt",
                        json!({ "sessionId": sid, "prompt": blocks }),
                    )
                    .await;
                if let Err(e) = res {
                    let _ = events.send(Arc::new(json!({
                        "type": "error",
                        "message": format!("{e}")
                    })));
                }
                // Clear the target only if it is still ours. A newer prompt may
                // have replaced it, and clearing it would untarget that prompt's
                // permission requests.
                {
                    let mut current = prompter.lock().await;
                    if *current == Some(attach_id) {
                        *current = None;
                    }
                }
                let _ = events.send(Arc::new(json!({ "type": "prompt_done" })));
            });
        }
        HubCommand::PermissionResponse { id, option_id } => {
            // Only the first answer for a given request id is forwarded.
            let key = id.to_string();
            if !answered.insert(key) {
                return;
            }
            let agent = Arc::clone(agent);
            tokio::spawn(async move {
                let _ = agent
                    .respond(
                        id,
                        json!({
                            "outcome": {
                                "outcome": "selected",
                                "optionId": option_id
                            }
                        }),
                    )
                    .await;
            });
        }
        HubCommand::Cancel => {
            let agent = Arc::clone(agent);
            let sid = session_id.to_string();
            tokio::spawn(async move {
                let _ = agent
                    .notify("session/cancel", json!({ "sessionId": sid }))
                    .await;
            });
        }
        HubCommand::SetMode { mode_id } => {
            spawn_set_session_option(agent, session_id, outbound, snapshot, MODE_OPTION, mode_id);
        }
        HubCommand::SetModel { model_id } => {
            spawn_set_session_option(
                agent,
                session_id,
                outbound,
                snapshot,
                MODEL_OPTION,
                model_id,
            );
        }
    }
}

/// Static description of one `session/set_*` option (mode or model).
#[derive(Clone, Copy)]
struct SessionOptionSpec {
    /// Agent request method, e.g. `session/set_mode`.
    method: &'static str,
    /// Request parameter that carries the new id.
    id_key: &'static str,
    /// Object inside the cached `session_info.info` that holds the current id.
    outer_field: &'static str,
    /// Key inside `outer_field` that is overwritten with the new id.
    current_field: &'static str,
}

const MODE_OPTION: SessionOptionSpec = SessionOptionSpec {
    method: "session/set_mode",
    id_key: "modeId",
    outer_field: "modes",
    current_field: "currentModeId",
};

const MODEL_OPTION: SessionOptionSpec = SessionOptionSpec {
    method: "session/set_model",
    id_key: "modelId",
    outer_field: "models",
    current_field: "currentModelId",
};

/// Spawns the `spec.method` request setting the option to `new_id`.
///
/// On success, patches the cached `session_info` snapshot and broadcasts the
/// updated info. On failure, broadcasts a `sys` line such as
/// `[set_mode failed: …]`.
fn spawn_set_session_option(
    agent: &Arc<Agent>,
    session_id: &str,
    outbound: &broadcast::Sender<Arc<Value>>,
    snapshot: &Arc<Mutex<NegotiationSnapshot>>,
    spec: SessionOptionSpec,
    new_id: String,
) {
    let agent = Arc::clone(agent);
    let mut params = json!({ "sessionId": session_id });
    params[spec.id_key] = Value::String(new_id.clone());
    let events = outbound.clone();
    let snapshot = Arc::clone(snapshot);
    tokio::spawn(async move {
        if let Err(e) = agent.request(spec.method, params).await {
            let label = spec.method.strip_prefix("session/").unwrap_or(spec.method);
            let _ = events.send(Arc::new(json!({
                "type": "append",
                "role": "sys",
                "text": format!("\n[{label} failed: {e}]\n")
            })));
            return;
        }
        let next =
            update_session_info_field(snapshot, spec.outer_field, spec.current_field, &new_id)
                .await;
        if let Some(info) = next {
            let _ = events.send(Arc::new(json!({
                "type": "session_info",
                "info": info
            })));
        }
    });
}
/// Joins the `text` of every `{"type": "text", "text": ...}` content block
/// with newlines.
///
/// Blocks of any other type, or text blocks without a string `text`, are
/// skipped. Returns an empty string when nothing matches.
fn extract_user_text(blocks: &[Value]) -> String {
    let mut pieces: Vec<&str> = Vec::new();
    for block in blocks {
        let is_text_block = block.get("type").and_then(Value::as_str) == Some("text");
        if !is_text_block {
            continue;
        }
        if let Some(text) = block.get("text").and_then(Value::as_str) {
            pieces.push(text);
        }
    }
    pieces.join("\n")
}
/// Sets `session_info.info[outer_field][current_field] = new_id` in the cached
/// snapshot and returns a copy of the updated `info` object.
///
/// Returns `None`, leaving the snapshot untouched, when there is no cached
/// `session_info`, it has no `info`, or `info[outer_field]` is missing or not a
/// JSON object.
async fn update_session_info_field(
    snapshot: Arc<Mutex<NegotiationSnapshot>>,
    outer_field: &str,
    current_field: &str,
    new_id: &str,
) -> Option<Value> {
    let mut guard = snapshot.lock().await;
    let info = guard.session_info.as_mut()?.get_mut("info")?;
    info.get_mut(outer_field)?
        .as_object_mut()?
        .insert(current_field.to_string(), Value::String(new_id.to_string()));
    Some(info.clone())
}