use anyhow::Context;
use std::collections::BTreeMap;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use futures::StreamExt;
use synaps_cli::{
Runtime, Session, SessionEvent, StreamEvent,
core::rpc_protocol::{RpcAttachment, RpcCommand, RpcEvent, TurnUsage, RPC_PROTOCOL_VERSION},
core::rpc_dispatch::{
accumulate_usage, build_user_content, build_tools_list_body, map_stream_event, parse_frame, MAX_FRAME_BYTES,
},
engine::setup::{self, EngineOpts},
};
use synaps_cli::runtime::openai::registry::{list_models, list_providers};
/// Capacity of the bounded channel that carries outgoing `RpcEvent`s to the writer task.
const WRITER_CHAN_CAP: usize = 256;
/// Bookkeeping for the prompt whose response stream is currently running.
struct InFlight {
// Id of the prompt command that started the stream.
prompt_id: String,
// Token handed to the runtime stream; `handle_abort` cancels it.
cancel: CancellationToken,
// Join handle of the spawned streaming task; `handle_abort` awaits it.
handle: JoinHandle<()>,
}
/// Mutable server state shared (behind `Arc<Mutex<_>>`) by the command loop and the
/// streaming task.
struct RpcState {
// Engine runtime used to start streams and to read/set the model.
runtime: Runtime,
// Session record that `save_session` refreshes and persists.
session: Session,
// Live conversation history sent with each prompt.
api_messages: Vec<serde_json::Value>,
// Running token totals, incremented on each `SessionEvent::Usage`.
total_input_tokens: u64,
total_output_tokens: u64,
// Session cost as loaded at boot; reset to 0.0 by `handle_new_session`.
session_cost: f64,
// `Some` while a prompt stream is running, `None` otherwise.
in_flight: Option<InFlight>,
}
impl RpcState {
async fn save_session(&mut self) {
if self.api_messages.is_empty() {
return;
}
self.session.api_messages = self.api_messages.clone();
self.session.total_input_tokens = self.total_input_tokens;
self.session.total_output_tokens = self.total_output_tokens;
self.session.session_cost = self.session_cost;
self.session.model = self.runtime.model().to_string();
self.session.system_prompt = self.runtime.system_prompt().map(|s| s.to_string());
self.session.thinking_level = self.runtime.thinking_level().to_string();
self.session.updated_at = chrono::Utc::now();
self.session.auto_title();
if let Err(e) = self.session.save().await {
tracing::error!(error = %e, "failed to save session");
}
}
fn is_streaming(&self) -> bool {
self.in_flight.is_some()
}
}
/// Serialises an event to a single-line JSON frame.
///
/// If serialisation fails (which would be a bug in the event types), the failure is logged
/// and an `error` frame describing it is returned instead. The fallback is built with
/// `serde_json::json!` so the error text is escaped and the frame is always valid JSON,
/// even when the message contains quotes or backslashes.
fn encode_event(ev: &RpcEvent) -> String {
    serde_json::to_string(ev).unwrap_or_else(|e| {
        tracing::error!(error = %e, "BUG: failed to serialise RpcEvent");
        serde_json::json!({
            "type": "error",
            "message": format!("internal serialisation error: {e}"),
        })
        .to_string()
    })
}
fn spawn_writer(mut rx: mpsc::Receiver<RpcEvent>) -> JoinHandle<()> {
tokio::spawn(async move {
while let Some(ev) = rx.recv().await {
println!("{}", encode_event(&ev));
}
})
}
async fn spawn_prompt(
prompt_id: String,
state: Arc<Mutex<RpcState>>,
writer_tx: mpsc::Sender<RpcEvent>,
) -> InFlight {
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
let cancel_check = cancel.clone();
let pid = prompt_id.clone();
let wtx = writer_tx.clone();
let handle = tokio::spawn(async move {
let messages = {
let st = state.lock().await;
st.api_messages.clone()
};
let mut stream = {
let st = state.lock().await;
st.runtime
.run_stream_with_messages(messages, cancel_clone, None, None, false)
.await
};
let mut usage_acc = TurnUsage {
input_tokens: 0,
output_tokens: 0,
cache_read_input_tokens: 0,
cache_creation_input_tokens: 0,
model: None,
};
while let Some(ev) = stream.next().await {
match &ev {
StreamEvent::Session(SessionEvent::MessageHistory(msgs)) => {
let mut st = state.lock().await;
st.api_messages = msgs.clone();
st.save_session().await;
continue;
}
StreamEvent::Session(se @ SessionEvent::Usage {
input_tokens,
output_tokens,
..
}) => {
accumulate_usage(&mut usage_acc, se);
let mut st = state.lock().await;
st.total_input_tokens += input_tokens;
st.total_output_tokens += output_tokens;
continue;
}
StreamEvent::Session(SessionEvent::Done) => {
let _ = wtx.send(RpcEvent::AgentEnd { usage: usage_acc.clone() }).await;
let _ = wtx
.send(RpcEvent::Response {
id: pid.clone(),
command: "prompt".to_string(),
body: serde_json::json!({ "ok": true }),
})
.await;
state.lock().await.in_flight = None;
return;
}
StreamEvent::Session(SessionEvent::Error(msg)) => {
if cancel_check.is_cancelled() {
let _ = wtx
.send(RpcEvent::AgentEnd { usage: usage_acc.clone() })
.await;
let _ = wtx
.send(RpcEvent::Response {
id: pid.clone(),
command: "prompt".to_string(),
body: serde_json::json!({ "ok": true, "cancelled": true }),
})
.await;
state.lock().await.in_flight = None;
return;
}
let _ = wtx
.send(RpcEvent::Error {
id: Some(pid.clone()),
message: msg.clone(),
})
.await;
let _ = wtx.send(RpcEvent::AgentEnd { usage: usage_acc.clone() }).await;
let _ = wtx
.send(RpcEvent::Response {
id: pid.clone(),
command: "prompt".to_string(),
body: serde_json::json!({ "ok": false, "error": msg }),
})
.await;
state.lock().await.in_flight = None;
return;
}
_ => {}
}
if let Some(rpc_ev) = map_stream_event(&ev) {
if wtx.send(rpc_ev).await.is_err() {
tracing::warn!("writer channel closed; aborting stream early");
break;
}
}
}
let cancelled = cancel_check.is_cancelled();
let _ = wtx.send(RpcEvent::AgentEnd { usage: usage_acc.clone() }).await;
let body = if cancelled {
serde_json::json!({ "ok": true, "cancelled": true })
} else {
serde_json::json!({
"ok": false,
"error": "stream ended without Done"
})
};
let _ = wtx
.send(RpcEvent::Response {
id: pid.clone(),
command: "prompt".to_string(),
body,
})
.await;
state.lock().await.in_flight = None;
});
InFlight { prompt_id, cancel, handle }
}
/// Handles a `prompt` / `follow_up` command: appends the user message to the history and
/// starts a streaming task for it.
///
/// If a stream is already in flight, the command is rejected with an `Error` frame and the
/// history is left untouched.
///
/// The state lock is held from the in-flight check until the new `InFlight` record has been
/// stored. The spawned task needs the same lock before it can do anything, so it cannot
/// finish and clear `in_flight` before the record is stored here. Without this, a fast
/// task could leave a stale `Some(..)` behind that blocks every later prompt.
async fn handle_prompt(
    id: String,
    message: String,
    attachments: Vec<RpcAttachment>,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    let mut st = state.lock().await;

    if st.is_streaming() {
        // Release the lock before waiting on the writer channel.
        drop(st);
        tracing::warn!(id, "rejected concurrent prompt — stream already in flight");
        let _ = writer_tx
            .send(RpcEvent::Error {
                id: Some(id),
                message: "another prompt is in flight; abort first".to_string(),
            })
            .await;
        return;
    }

    let content = build_user_content(&message, &attachments);
    st.api_messages
        .push(serde_json::json!({"role": "user", "content": content}));

    // `spawn_prompt` only spawns the task; it must not lock `state` itself, because the
    // guard is still held here.
    let in_flight = spawn_prompt(id, Arc::clone(&state), writer_tx).await;
    st.in_flight = Some(in_flight);
}
/// Handles a `compact` command: replaces the conversation history with a single summary
/// message produced by `compact_conversation`, then saves the session.
///
/// The command is rejected with an `Error` frame while a prompt stream is in flight. The
/// streaming task rewrites `api_messages` on every `MessageHistory` event, so compacting
/// at the same time would either be overwritten or discard messages the stream added.
///
/// On success a `compact` `Response` with the summary is sent; on failure the error is
/// logged and sent as an `Error` frame.
async fn handle_compact(
    id: String,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    // Snapshot what compaction needs, or `None` if a stream is running.
    let snapshot = {
        let st = state.lock().await;
        if st.is_streaming() {
            None
        } else {
            Some((st.api_messages.clone(), st.runtime.clone()))
        }
    };

    let (msgs, runtime) = match snapshot {
        Some(pair) => pair,
        None => {
            tracing::warn!(id, "rejected compact — stream in flight");
            let _ = writer_tx
                .send(RpcEvent::Error {
                    id: Some(id),
                    message: "another prompt is in flight; abort first".to_string(),
                })
                .await;
            return;
        }
    };

    let summary_result =
        synaps_cli::core::compaction::compact_conversation(&msgs, &runtime, None).await;

    match summary_result {
        Ok(summary) => {
            {
                let mut st = state.lock().await;
                st.api_messages =
                    vec![serde_json::json!({"role": "user", "content": summary.clone()})];
                st.save_session().await;
            }
            let _ = writer_tx
                .send(RpcEvent::Response {
                    id,
                    command: "compact".to_string(),
                    body: serde_json::json!({ "summary": summary }),
                })
                .await;
        }
        Err(e) => {
            tracing::error!(error = %e, "compact_conversation failed");
            let _ = writer_tx
                .send(RpcEvent::Error {
                    id: Some(id),
                    message: e.to_string(),
                })
                .await;
        }
    }
}
async fn handle_new_session(
id: String,
state: Arc<Mutex<RpcState>>,
writer_tx: mpsc::Sender<RpcEvent>,
) {
{
let st = state.lock().await;
if st.is_streaming() {
tracing::warn!(id, "rejected new_session — stream in flight");
let _ = writer_tx
.send(RpcEvent::Error {
id: Some(id),
message: "another prompt is in flight; abort first".to_string(),
})
.await;
return;
}
}
let new_session_id = {
let mut st = state.lock().await;
st.save_session().await;
let new_sess =
Session::new(st.runtime.model(), st.runtime.thinking_level(), st.runtime.system_prompt());
let sid = new_sess.id.clone();
st.session = new_sess;
st.api_messages.clear();
st.total_input_tokens = 0;
st.total_output_tokens = 0;
st.session_cost = 0.0;
sid
};
let _ = writer_tx
.send(RpcEvent::Response {
id,
command: "new_session".to_string(),
body: serde_json::json!({ "session_id": new_session_id }),
})
.await;
}
/// Handles a `get_messages` command by replying with a copy of the current history.
async fn handle_get_messages(
    id: String,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    // Copy under the lock; release it before sending.
    let history = {
        let guard = state.lock().await;
        guard.api_messages.clone()
    };
    let reply = RpcEvent::Response {
        id,
        command: String::from("get_messages"),
        body: serde_json::json!({ "messages": history }),
    };
    let _ = writer_tx.send(reply).await;
}
/// Handles a `set_model` command: switches the runtime to `model` and echoes it back.
async fn handle_set_model(
    id: String,
    model: String,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    {
        let mut guard = state.lock().await;
        guard.runtime.set_model(model.clone());
    }
    let reply = RpcEvent::Response {
        id,
        command: String::from("set_model"),
        body: serde_json::json!({ "model": model }),
    };
    let _ = writer_tx.send(reply).await;
}
/// Handles a `get_available_models` command: lists every model of every provider reported
/// by the registry (queried with no overrides). Providers for which `list_models` returns
/// `None` contribute nothing.
async fn handle_get_available_models(id: String, writer_tx: mpsc::Sender<RpcEvent>) {
    let no_overrides: BTreeMap<String, String> = BTreeMap::new();
    let providers = list_providers(&no_overrides);

    let mut catalogue: Vec<serde_json::Value> = Vec::new();
    for (provider_key, ..) in &providers {
        if let Some(models) = list_models(provider_key) {
            catalogue.extend(models.into_iter().map(|(model_id, model_name, _)| {
                serde_json::json!({
                    "provider": provider_key,
                    "model_id": model_id,
                    "model_name": model_name,
                })
            }));
        }
    }

    let reply = RpcEvent::Response {
        id,
        command: String::from("get_available_models"),
        body: serde_json::json!({ "models": catalogue }),
    };
    let _ = writer_tx.send(reply).await;
}
async fn handle_abort(
id: String,
state: Arc<Mutex<RpcState>>,
writer_tx: mpsc::Sender<RpcEvent>,
) {
let handle_opt = {
let mut st = state.lock().await;
if let Some(inf) = st.in_flight.take() {
tracing::info!(prompt_id = %inf.prompt_id, abort_id = %id, "aborted in-flight stream");
inf.cancel.cancel();
Some(inf.handle)
} else {
None
}
};
if let Some(handle) = handle_opt {
if let Err(e) = handle.await {
tracing::warn!(error = ?e, "streaming task panicked during abort");
}
}
let _ = writer_tx
.send(RpcEvent::Response {
id,
command: "abort".to_string(),
body: serde_json::json!({ "ok": true }),
})
.await;
}
/// Handles a `get_session_stats` command: replies with token totals, cost, message count,
/// current model and session id.
async fn handle_get_session_stats(
    id: String,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    // Build the payload under the lock; send after releasing it.
    let guard = state.lock().await;
    let stats = serde_json::json!({
        "input_tokens": guard.total_input_tokens,
        "output_tokens": guard.total_output_tokens,
        "cost": guard.session_cost,
        "message_count": guard.api_messages.len(),
        "model": guard.runtime.model(),
        "session_id": guard.session.id,
    });
    drop(guard);

    let reply = RpcEvent::Response {
        id,
        command: String::from("get_session_stats"),
        body: stats,
    };
    let _ = writer_tx.send(reply).await;
}
/// Handles a `get_state` command: replies with the streaming flag, current model, session
/// id and message count.
async fn handle_get_state(
    id: String,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    // Build the payload under the lock; send after releasing it.
    let guard = state.lock().await;
    let snapshot = serde_json::json!({
        "streaming": guard.is_streaming(),
        "model": guard.runtime.model(),
        "session_id": guard.session.id,
        "message_count": guard.api_messages.len(),
    });
    drop(guard);

    let reply = RpcEvent::Response {
        id,
        command: String::from("get_state"),
        body: snapshot,
    };
    let _ = writer_tx.send(reply).await;
}
/// Handles a `tools_list` command: replies with the body built from the runtime's tool
/// schema. A missing request id is answered with the default (empty) id.
async fn handle_tools_list(
    id: Option<String>,
    state: Arc<Mutex<RpcState>>,
    writer_tx: mpsc::Sender<RpcEvent>,
) {
    // Clone the schema while holding the state lock and the registry read guard.
    let schema = {
        let state_guard = state.lock().await;
        let shared_tools = state_guard.runtime.tools_shared();
        let tools = shared_tools.read().await;
        tools.tools_schema().as_ref().clone()
    };

    let reply = RpcEvent::Response {
        id: id.unwrap_or_default(),
        command: String::from("tools_list"),
        body: build_tools_list_body(&schema),
    };
    let _ = writer_tx.send(reply).await;
}
pub async fn run(
continue_id: Option<String>,
system: Option<String>,
model: Option<String>,
profile: Option<String>,
) -> anyhow::Result<()> {
let boot = setup::boot(EngineOpts {
continue_session: continue_id.map(Some),
system,
profile,
no_extensions: false,
})
.await
.context("engine boot failed")?;
let mut runtime = boot.runtime;
let session = boot.session;
let initial_messages = boot.api_messages;
let initial_in = boot.total_input_tokens;
let initial_out = boot.total_output_tokens;
let initial_cost = boot.session_cost;
let ext_manager = boot.ext_manager;
let background = boot.background;
if let Some(ref m) = model {
runtime.set_model(m.clone());
}
let (loader_tx, mut loader_rx) = mpsc::unbounded_channel();
synaps_cli::extensions::loader::spawn_discover_and_load(
Arc::clone(&ext_manager),
loader_tx,
);
let _ = tokio::time::timeout(
std::time::Duration::from_secs(2),
async {
while let Some(ev) = loader_rx.recv().await {
if matches!(ev, synaps_cli::extensions::loader::ExtensionLoaderEvent::Finished { .. }) {
break;
}
}
},
)
.await;
let ready_session_id = session.id.clone();
let ready_model = runtime.model().to_string();
let state = Arc::new(Mutex::new(RpcState {
runtime,
session,
api_messages: initial_messages,
total_input_tokens: initial_in,
total_output_tokens: initial_out,
session_cost: initial_cost,
in_flight: None,
}));
let (writer_tx, writer_rx) = mpsc::channel::<RpcEvent>(WRITER_CHAN_CAP);
let writer_handle = spawn_writer(writer_rx);
writer_tx
.send(RpcEvent::Ready {
session_id: ready_session_id,
model: ready_model,
protocol_version: RPC_PROTOCOL_VERSION,
})
.await
.context("writer channel closed before Ready frame could be sent")?;
tracing::info!("synaps rpc ready");
let stdin = tokio::io::stdin();
let mut lines = BufReader::new(stdin).lines();
loop {
match lines.next_line().await {
Err(e) => {
tracing::error!(error = %e, "stdin read error; exiting");
break;
}
Ok(None) => {
tracing::info!("stdin EOF; saving session and exiting");
state.lock().await.save_session().await;
background.shutdown();
break;
}
Ok(Some(line)) => {
let line = line.trim_end_matches('\r'); if line.trim().is_empty() {
continue;
}
let cmd = match parse_frame(line, MAX_FRAME_BYTES) {
Ok(c) => c,
Err(err_ev) => {
tracing::warn!("frame parse error");
let _ = writer_tx.send(err_ev).await;
continue;
}
};
tracing::debug!(?cmd, "received RpcCommand");
match cmd {
RpcCommand::Prompt { id, message, attachments } => {
handle_prompt(id, message, attachments, state.clone(), writer_tx.clone())
.await;
}
RpcCommand::FollowUp { id, message } => {
handle_prompt(id, message, Vec::new(), state.clone(), writer_tx.clone())
.await;
}
RpcCommand::Compact { id } => {
handle_compact(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::NewSession { id } => {
handle_new_session(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::GetMessages { id } => {
handle_get_messages(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::SetModel { id, model: m } => {
handle_set_model(id, m, state.clone(), writer_tx.clone()).await;
}
RpcCommand::GetAvailableModels { id } => {
handle_get_available_models(id, writer_tx.clone()).await;
}
RpcCommand::Abort { id } => {
handle_abort(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::GetSessionStats { id } => {
handle_get_session_stats(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::GetState { id } => {
handle_get_state(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::ToolsList { id } => {
handle_tools_list(id, state.clone(), writer_tx.clone()).await;
}
RpcCommand::Shutdown => {
tracing::info!("Shutdown received; draining and exiting");
loop {
let done = state.lock().await.in_flight.is_none();
if done {
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
state.lock().await.save_session().await;
background.shutdown();
drop(writer_tx);
match tokio::time::timeout(
tokio::time::Duration::from_secs(1),
writer_handle,
)
.await
{
Ok(Ok(())) => {} Ok(Err(e)) => {
tracing::warn!(error = ?e, "writer task panicked during shutdown")
}
Err(_) => tracing::warn!(
"writer task did not drain within 1s; exiting anyway"
),
}
std::process::exit(0);
}
}
}
}
}
Ok(())
}