use crate::backend::resolve_backend;
use crate::builder::StdioBusBuilder;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use stdiobus_core::{
Backend, BackendMode, BusMessage, BusState, BusStats, ConfigSource, DockerOptions, Error, JsonRpcRequest,
JsonRpcResponse, RequestOptions, Result, generate_client_session_id,
};
use tokio::sync::{broadcast, oneshot, Mutex};
/// Bookkeeping for one request that is still waiting for its response.
struct PendingRequest {
    /// Text from `agent_message_chunk` notifications gathered while waiting.
    chunks: Vec<String>,
    /// One-shot channel used to hand the final response back to the caller.
    tx: oneshot::Sender<AggregatedResponse>,
}
/// A decoded JSON-RPC response together with the streamed text gathered for it.
struct AggregatedResponse {
    /// Concatenation of every chunk collected before the response arrived.
    text: String,
    /// The JSON-RPC response message itself.
    response: JsonRpcResponse,
}
/// Client handle for a stdio bus backend.
///
/// Sends JSON-RPC requests and notifications through the backend, matches
/// incoming responses to pending requests by their string `id`, and re-broadcasts
/// inbound notifications to subscribers.
pub struct StdioBus {
// Transport chosen at construction time by `resolve_backend`.
backend: Box<dyn Backend>,
// Timeout used by requests that do not set `RequestOptions::timeout`.
default_timeout: Duration,
// Session ID attached to outgoing messages unless a request overrides it.
client_session_id: String,
// In-flight requests keyed by request ID; shared with the background dispatch task.
pending_requests: Arc<Mutex<HashMap<String, PendingRequest>>>,
// Sender side of the channel that fans inbound notifications out to subscribers.
notification_tx: broadcast::Sender<Value>,
}
impl StdioBus {
pub fn builder() -> StdioBusBuilder {
StdioBusBuilder::new()
}
pub(crate) fn new(
config_source: ConfigSource,
backend_mode: BackendMode,
default_timeout: Duration,
docker_options: Option<DockerOptions>,
) -> Result<Self> {
let backend = resolve_backend(backend_mode, config_source, docker_options)?;
let (notification_tx, _) = broadcast::channel(100);
Ok(Self {
backend,
default_timeout,
client_session_id: generate_client_session_id(),
pending_requests: Arc::new(Mutex::new(HashMap::new())),
notification_tx,
})
}
pub fn client_session_id(&self) -> &str {
&self.client_session_id
}
pub async fn start(&self) -> Result<()> {
self.backend.start().await?;
let pending = self.pending_requests.clone();
let notif_tx = self.notification_tx.clone();
if let Some(mut rx) = self.backend.subscribe() {
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
Self::handle_message(msg, &pending, ¬if_tx).await;
}
});
}
Ok(())
}
pub async fn stop(&self) -> Result<()> {
self.stop_with_timeout(30).await
}
pub async fn stop_with_timeout(&self, timeout_secs: u32) -> Result<()> {
self.backend.stop(timeout_secs).await
}
pub async fn request(&self, method: &str, params: Value) -> Result<Value> {
self.request_with_options(method, params, RequestOptions::default())
.await
}
pub async fn request_with_options(
&self,
method: &str,
params: Value,
options: RequestOptions,
) -> Result<Value> {
if !self.is_running() {
return Err(Error::InvalidState {
expected: "RUNNING".to_string(),
actual: self.state().to_string(),
});
}
let mut request = JsonRpcRequest::new(method, Some(params));
let session_id = options.session_id.unwrap_or_else(|| self.client_session_id.clone());
request = request.with_session_id(session_id);
if let Some(agent_id) = options.agent_id {
request = request.with_agent_id(agent_id);
}
let id = request
.id
.as_ref()
.and_then(|v| v.as_str())
.ok_or_else(|| Error::InternalError {
message: "Request ID not set".to_string(),
})?
.to_string();
let json = serde_json::to_string(&request)?;
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending_requests.lock().await;
pending.insert(id.clone(), PendingRequest { tx, chunks: Vec::new() });
}
self.backend.send(&json).await?;
let timeout = options.timeout.unwrap_or(self.default_timeout);
let aggregated = tokio::time::timeout(timeout, rx)
.await
.map_err(|_| Error::Timeout {
timeout_ms: timeout.as_millis() as u64,
})?
.map_err(|_| Error::InternalError {
message: "Response channel closed".to_string(),
})?;
if let Some(error) = aggregated.response.error {
return Err(Error::TransportError {
message: format!("{}: {}", error.code, error.message),
});
}
let mut result = aggregated.response.result.unwrap_or(Value::Object(Default::default()));
if !aggregated.text.is_empty() {
if let Value::Object(ref mut map) = result {
map.insert("text".to_string(), Value::String(aggregated.text));
}
}
Ok(result)
}
pub async fn notify(&self, method: &str, params: Value) -> Result<()> {
if !self.is_running() {
return Err(Error::InvalidState {
expected: "RUNNING".to_string(),
actual: self.state().to_string(),
});
}
let request = JsonRpcRequest::notification(method, Some(params))
.with_session_id(&self.client_session_id);
let json = serde_json::to_string(&request)?;
self.backend.send(&json).await
}
pub async fn send(&self, message: &str) -> Result<()> {
self.backend.send(message).await
}
pub fn subscribe_notifications(&self) -> broadcast::Receiver<Value> {
self.notification_tx.subscribe()
}
pub fn state(&self) -> BusState {
self.backend.state()
}
pub fn is_running(&self) -> bool {
self.state() == BusState::Running
}
pub fn stats(&self) -> BusStats {
self.backend.stats()
}
pub fn worker_count(&self) -> i32 {
self.backend.worker_count()
}
pub fn client_count(&self) -> i32 {
self.backend.client_count()
}
pub fn backend_type(&self) -> &'static str {
self.backend.backend_type()
}
async fn handle_message(
msg: BusMessage,
pending: &Arc<Mutex<HashMap<String, PendingRequest>>>,
notif_tx: &broadcast::Sender<Value>,
) {
let parsed: Value = match serde_json::from_str(&msg.json) {
Ok(v) => v,
Err(_) => return,
};
if parsed.get("method").is_some() && parsed.get("id").is_none() {
if let Some(params) = parsed.get("params") {
if let Some(update) = params.get("update") {
if update.get("sessionUpdate").and_then(|s| s.as_str()) == Some("agent_message_chunk") {
if let Some(text) = update.get("content")
.and_then(|c| c.get("text"))
.and_then(|t| t.as_str())
{
let mut guard = pending.lock().await;
for req in guard.values_mut() {
req.chunks.push(text.to_string());
}
}
}
}
}
let _ = notif_tx.send(parsed);
return;
}
if let Some(id) = parsed.get("id").and_then(|v| v.as_str()) {
let mut guard = pending.lock().await;
if let Some(req) = guard.remove(id) {
let response: JsonRpcResponse = match serde_json::from_str(&msg.json) {
Ok(r) => r,
Err(_) => return,
};
let text = req.chunks.join("");
let _ = req.tx.send(AggregatedResponse { response, text });
}
}
}
}