use std::io;
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::ENVELOPE_VERSION;
use super::framing::{FrameError, read_frame, write_frame_json};
use super::methods::{
ConnectionState, HandlerContext, dispatch as dispatch_request, internal_error_response,
};
use super::protocol::{
DaemonHello, DaemonHelloResponse, JsonRpcResponse, ShimProtocol, ShimRegister, ShimRegisterAck,
};
use super::shim_registry::ShimHandle;
use super::validation::{ValidationError, parse_error_response, validate_request_value};
/// Error type returned by the connection handlers in this module.
#[derive(Debug, Error)]
pub enum ConnectionError {
/// An I/O failure on the connection. `FrameError` values are converted
/// into this variant as well (see the `From<FrameError>` impl).
#[error("connection io: {0}")]
Io(#[from] io::Error),
}
/// Folds framing failures into [`ConnectionError::Io`].
///
/// I/O failures are passed through untouched; JSON framing failures are
/// re-expressed as an `InvalidData` I/O error whose message carries the
/// original JSON error text.
impl From<FrameError> for ConnectionError {
    fn from(err: FrameError) -> Self {
        let as_io = match err {
            FrameError::Io(e) => e,
            FrameError::Json(e) => {
                io::Error::new(io::ErrorKind::InvalidData, format!("frame json: {e}"))
            }
        };
        Self::Io(as_io)
    }
}
/// What should be sent back to the peer for one inbound frame, as decided
/// by `handle_frame`.
#[derive(Debug)]
enum FrameResponse {
/// Nothing is written back for this frame.
None,
/// Exactly one response object is written.
Single(JsonRpcResponse),
/// A list of responses, written together as a single frame.
Batch(Vec<JsonRpcResponse>),
/// The frame was not valid JSON. The hello request loop writes this
/// response on a best-effort basis and then stops serving the connection.
ParseError(JsonRpcResponse),
}
/// Serves a single client connection, choosing the protocol from its first frame.
///
/// The stream is split into read and write halves and exactly one frame is
/// read up front:
///
/// * If the peer closes before sending a frame, the function returns `Ok(())`.
/// * If the frame is not valid JSON, an `InvalidRequest` response is written
///   (best effort), the writer is shut down, and `Ok(())` is returned.
/// * If the frame is a JSON object containing both a `protocol` and a `pid`
///   key, it is treated as a `ShimRegister`. A frame that has that shape but
///   fails to deserialize is answered with a rejecting `ShimRegisterAck`;
///   otherwise the connection is handed to `run_shim_connection`.
/// * Every other frame is handed to `run_hello_connection`.
///
/// # Errors
///
/// Returns [`ConnectionError`] when reading the first frame fails, or when
/// the selected sub-handler returns an error.
pub async fn run_connection<S>(stream: S, ctx: HandlerContext) -> Result<(), ConnectionError>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (mut rx, mut tx) = tokio::io::split(stream);

    // A clean close before any frame arrives is not an error.
    let Some(handshake_bytes) = read_frame(&mut rx).await? else {
        return Ok(());
    };

    let handshake: serde_json::Value = match serde_json::from_slice(&handshake_bytes) {
        Ok(parsed) => parsed,
        Err(parse_err) => {
            let rejection = ValidationError::InvalidRequest {
                reason: "first frame must be DaemonHello or ShimRegister",
                context: Some(parse_err.to_string()),
            }
            .into_jsonrpc_response();
            // Best effort only: the peer may already have gone away.
            let _ = write_frame_json(&mut tx, &rejection).await;
            let _ = tx.shutdown().await;
            return Ok(());
        }
    };

    // Anything that does not look like a shim registration takes the hello path.
    if !is_shim_register_shape(&handshake) {
        return run_hello_connection(rx, tx, handshake, ctx).await;
    }

    match serde_json::from_value::<ShimRegister>(handshake) {
        Ok(registration) => run_shim_connection(rx, tx, registration, ctx).await,
        Err(e) => {
            // Right shape, wrong contents: refuse with an explicit reason.
            let nack = ShimRegisterAck {
                accepted: false,
                daemon_version: ctx.daemon_version.to_owned(),
                reason: Some(format!("invalid ShimRegister: {e}")),
                envelope_version: ENVELOPE_VERSION,
            };
            let _ = write_frame_json(&mut tx, &nack).await;
            let _ = tx.shutdown().await;
            Ok(())
        }
    }
}
#[must_use]
fn is_shim_register_shape(v: &serde_json::Value) -> bool {
v.as_object()
.is_some_and(|m| m.get("protocol").is_some() && m.get("pid").is_some())
}
async fn run_hello_connection<R, W>(
mut reader: R,
mut writer: W,
first_value: serde_json::Value,
ctx: HandlerContext,
) -> Result<(), ConnectionError>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let hello: DaemonHello = match serde_json::from_value(first_value) {
Ok(h) => h,
Err(e) => {
let resp = ValidationError::InvalidRequest {
reason: "first frame must be a DaemonHello handshake",
context: Some(e.to_string()),
}
.into_jsonrpc_response();
let _ = write_frame_json(&mut writer, &resp).await;
let _ = writer.shutdown().await;
return Ok(());
}
};
let compatible = hello.protocol_version == 1;
let hello_resp = DaemonHelloResponse {
compatible,
daemon_version: ctx.daemon_version.to_owned(),
envelope_version: ENVELOPE_VERSION,
};
write_frame_json(&mut writer, &hello_resp).await?;
if !compatible {
let _ = writer.shutdown().await;
return Ok(());
}
let conn_state = ConnectionState::from_hello(hello.logical_workspace);
loop {
tokio::select! {
biased;
() = ctx.shutdown.cancelled() => break,
frame = read_frame(&mut reader) => match frame {
Ok(None) => break,
Ok(Some(bytes)) => {
let outcome = handle_frame(&ctx, &conn_state, &bytes).await;
match outcome {
FrameResponse::None => {}
FrameResponse::Single(resp) => {
write_frame_json(&mut writer, &resp).await?;
}
FrameResponse::Batch(responses) => {
write_frame_json(&mut writer, &responses).await?;
}
FrameResponse::ParseError(resp) => {
let _ = write_frame_json(&mut writer, &resp).await;
break;
}
}
}
Err(e) => return Err(e.into()),
}
}
}
let _ = writer.shutdown().await;
Ok(())
}
async fn run_shim_connection<R, W>(
reader: R,
mut writer: W,
req: ShimRegister,
ctx: HandlerContext,
) -> Result<(), ConnectionError>
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
let handle_result = ctx.shim_registry.try_register_bounded(
req.protocol,
req.pid,
ctx.config.max_shim_connections,
);
let _handle: ShimHandle = match handle_result {
Ok(h) => {
let ack = ShimRegisterAck {
accepted: true,
daemon_version: ctx.daemon_version.to_owned(),
reason: None,
envelope_version: ENVELOPE_VERSION,
};
write_frame_json(&mut writer, &ack).await?;
h
}
Err(reject) => {
let ack = ShimRegisterAck {
accepted: false,
daemon_version: ctx.daemon_version.to_owned(),
reason: Some(reject.to_string()),
envelope_version: ENVELOPE_VERSION,
};
let _ = write_frame_json(&mut writer, &ack).await;
let _ = writer.shutdown().await;
return Ok(());
}
};
match req.protocol {
ShimProtocol::Lsp => {
let session =
sqry_lsp::session::SessionManager::new(sqry_lsp::LspOptions::default_daemon());
if let Err(e) = sqry_lsp::daemon_host::host_on_streams(
reader,
writer,
session,
ctx.shutdown.clone(),
)
.await
{
tracing::warn!(error = %e, "lsp shim host errored");
}
}
ShimProtocol::Mcp => {
let tool_timeout = std::time::Duration::from_secs(ctx.config.tool_timeout_secs);
if let Err(e) = crate::mcp_host::host_mcp_on_streams(
reader,
writer,
Arc::clone(&ctx.manager),
Arc::clone(&ctx.workspace_builder),
Arc::clone(&ctx.tool_executor),
tool_timeout,
ctx.daemon_version,
ctx.shutdown.clone(),
)
.await
{
tracing::warn!(error = %e, "mcp shim host errored");
}
}
}
Ok(())
}
/// Parses one frame and routes it by its top-level JSON shape.
///
/// * Invalid JSON yields [`FrameResponse::ParseError`].
/// * A JSON array is processed as a batch via `dispatch_batch`.
/// * A JSON object is processed as a single request via `dispatch_single`.
/// * Any other JSON value yields a single `InvalidRequest` response.
async fn handle_frame(ctx: &HandlerContext, conn: &ConnectionState, bytes: &[u8]) -> FrameResponse {
    let parsed = match serde_json::from_slice::<serde_json::Value>(bytes) {
        Ok(json) => json,
        Err(parse_err) => return FrameResponse::ParseError(parse_error_response(parse_err)),
    };

    match parsed {
        serde_json::Value::Array(items) => dispatch_batch(ctx, conn, items).await,
        request @ serde_json::Value::Object(_) => dispatch_single(ctx, conn, request).await,
        _ => {
            let invalid = ValidationError::InvalidRequest {
                reason: "request must be an object or array",
                context: None,
            };
            FrameResponse::Single(invalid.into_jsonrpc_response())
        }
    }
}
/// Processes a batch of requests, one element at a time and in order.
///
/// An empty batch is answered with a single `InvalidRequest` response.
/// A nested array element gets its own `InvalidRequest` response; every
/// other element goes through `dispatch_single`. Elements that produce no
/// response contribute nothing. If no element produced a response, the
/// result is [`FrameResponse::None`]; otherwise the collected responses are
/// returned as [`FrameResponse::Batch`].
async fn dispatch_batch(
    ctx: &HandlerContext,
    conn: &ConnectionState,
    items: Vec<serde_json::Value>,
) -> FrameResponse {
    if items.is_empty() {
        let empty_batch = ValidationError::InvalidRequest {
            reason: "batch must contain at least one request",
            context: None,
        };
        return FrameResponse::Single(empty_batch.into_jsonrpc_response());
    }

    let mut replies = Vec::with_capacity(items.len());
    for element in items {
        let reply = if element.is_array() {
            // Batches do not nest: an array inside a batch is rejected outright.
            let nested = ValidationError::InvalidRequest {
                reason: "batch element must be a request object",
                context: None,
            };
            Some(nested.into_jsonrpc_response())
        } else {
            match dispatch_single(ctx, conn, element).await {
                FrameResponse::Single(resp) => Some(resp),
                FrameResponse::None => None,
                // `dispatch_single` is only expected to yield `Single` or `None`;
                // anything else is reported to the peer as an internal error.
                other => Some(internal_error_response(
                    None,
                    &format!("batch dispatch returned unexpected shape: {other:?}"),
                )),
            }
        };
        if let Some(resp) = reply {
            replies.push(resp);
        }
    }

    if replies.is_empty() {
        FrameResponse::None
    } else {
        FrameResponse::Batch(replies)
    }
}
/// Validates one JSON value as a request and dispatches it.
///
/// A value that fails validation is answered with the validation error's
/// JSON-RPC response. A valid request is passed to the method dispatcher;
/// its response, if any, is returned as [`FrameResponse::Single`], and the
/// absence of a response becomes [`FrameResponse::None`].
async fn dispatch_single(
    ctx: &HandlerContext,
    conn: &ConnectionState,
    value: serde_json::Value,
) -> FrameResponse {
    let request = match validate_request_value(value) {
        Ok(valid) => valid,
        Err(invalid) => return FrameResponse::Single(invalid.into_jsonrpc_response()),
    };

    dispatch_request(ctx, conn, request)
        .await
        .map_or(FrameResponse::None, FrameResponse::Single)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the shape probe over an owned JSON value.
    fn has_shim_shape(candidate: serde_json::Value) -> bool {
        is_shim_register_shape(&candidate)
    }

    // Both required keys present.
    #[test]
    fn is_shim_register_shape_accepts_full_object() {
        assert!(has_shim_shape(
            serde_json::json!({"protocol": "lsp", "pid": 42})
        ));
    }

    // Unknown extra keys do not change the verdict.
    #[test]
    fn is_shim_register_shape_accepts_object_with_extra_keys() {
        assert!(has_shim_shape(
            serde_json::json!({"protocol": "lsp", "pid": 42, "other": 1})
        ));
    }

    #[test]
    fn is_shim_register_shape_rejects_missing_protocol() {
        assert!(!has_shim_shape(serde_json::json!({"pid": 42})));
    }

    #[test]
    fn is_shim_register_shape_rejects_missing_pid() {
        assert!(!has_shim_shape(serde_json::json!({"protocol": "lsp"})));
    }

    // A DaemonHello-style object must fall through to the hello path.
    #[test]
    fn is_shim_register_shape_rejects_daemon_hello_shape() {
        assert!(!has_shim_shape(
            serde_json::json!({"client_version": "8.0.6", "protocol_version": 1})
        ));
    }

    // Null, arrays, strings and numbers are never shim registrations.
    #[test]
    fn is_shim_register_shape_rejects_non_object() {
        let non_objects = [
            serde_json::Value::Null,
            serde_json::json!([1, 2, 3]),
            serde_json::json!("hello"),
            serde_json::json!(42),
        ];
        for candidate in non_objects {
            assert!(!has_shim_shape(candidate));
        }
    }
}