plugins-protocol 0.6.8

Newt-Agent provider-plugin JSON-RPC schema + reference client SDK
Documentation
use async_trait::async_trait;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};

use crate::Result;
use crate::{
    CompleteRequest, CompleteResponse, InitializeRequest, InitializeResponse, ListModelsResponse,
    RpcErrorObject,
};

/// Implemented by a provider plugin to service requests dispatched by [`PluginServer`].
///
/// Each method corresponds to one JSON-RPC method name that `PluginServer::run`
/// dispatches ("initialize", "list_models", "complete", "shutdown"). An `Err`
/// returned from any method is sent back to the peer as a JSON-RPC error
/// response with code `-32000` and the error's `Display` text as the message.
#[async_trait]
pub trait PluginHandler: Send + Sync {
    /// Handles `initialize`; `req` is decoded from the request's `params`.
    async fn initialize(&self, req: InitializeRequest) -> Result<InitializeResponse>;
    /// Handles `list_models`; the request's `params` are not consulted.
    async fn list_models(&self) -> Result<ListModelsResponse>;
    /// Handles `complete`; `req` is decoded from the request's `params`.
    async fn complete(&self, req: CompleteRequest) -> Result<CompleteResponse>;

    /// Handles `shutdown`. The server stops reading requests after writing the
    /// reply, whether or not this hook returns an error. The default does nothing.
    async fn shutdown(&self) -> Result<()> {
        Ok(())
    }
}

pub struct PluginServer<H> {
    handler: H,
}

impl<H> PluginServer<H>
where
    H: PluginHandler,
{
    pub fn new(handler: H) -> Self {
        Self { handler }
    }

    pub async fn run_stdio(self) -> Result<()> {
        let stdin = tokio::io::stdin();
        let mut stdout = tokio::io::stdout();
        self.run(stdin, &mut stdout).await
    }

    pub async fn run<R, W>(&self, reader: R, writer: &mut W) -> Result<()>
    where
        R: AsyncRead + Unpin,
        W: AsyncWrite + Unpin,
    {
        let mut lines = BufReader::new(reader).lines();
        while let Some(line) = lines.next_line().await? {
            if line.trim().is_empty() {
                continue;
            }

            let request: std::result::Result<RpcRequest, serde_json::Error> =
                serde_json::from_str(&line);
            let request = match request {
                Ok(request) => request,
                Err(e) => {
                    write_response(
                        writer,
                        &RpcResponse::error(None, -32700, format!("parse error: {e}")),
                    )
                    .await?;
                    continue;
                }
            };

            let id = request.id;
            let mut should_shutdown = false;
            let response = match request.method.as_str() {
                "initialize" => match serde_json::from_value::<InitializeRequest>(request.params) {
                    Ok(req) => match self.handler.initialize(req).await {
                        Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
                    },
                    Err(e) => RpcResponse::error(id, -32602, e.to_string()),
                },
                "list_models" => match self.handler.list_models().await {
                    Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
                    Err(e) => RpcResponse::error(id, -32000, e.to_string()),
                },
                "complete" => match serde_json::from_value::<CompleteRequest>(request.params) {
                    Ok(req) => match self.handler.complete(req).await {
                        Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
                    },
                    Err(e) => RpcResponse::error(id, -32602, e.to_string()),
                },
                "stream" => RpcResponse::error(id, -32601, "stream is not supported".to_string()),
                "shutdown" => {
                    should_shutdown = true;
                    match self.handler.shutdown().await {
                        Ok(()) => RpcResponse::ok(id, serde_json::json!({})),
                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
                    }
                }
                other => RpcResponse::error(id, -32601, format!("unknown method: {other}")),
            };
            write_response(writer, &response).await?;
            if should_shutdown {
                break;
            }
        }
        Ok(())
    }
}

/// Serializes `response` as a single line of JSON and flushes it to `writer`.
///
/// The trailing `\n` frames the message: the wire format is newline-delimited
/// JSON, one message per line (the same framing `PluginServer::run` uses when
/// reading requests). Flushing after every message ensures the response is
/// handed to the underlying writer immediately instead of sitting in a buffer.
///
/// # Errors
///
/// Returns an error if serialization, the write, or the flush fails.
async fn write_response<W>(writer: &mut W, response: &RpcResponse) -> Result<()>
where
    W: AsyncWrite + Unpin,
{
    let mut frame = Vec::new();
    serde_json::to_writer(&mut frame, response)?;
    frame.extend_from_slice(b"\n");

    writer.write_all(&frame).await?;
    writer.flush().await?;
    Ok(())
}

#[derive(Debug, serde::Deserialize)]
struct RpcRequest {
    id: Option<u64>,
    method: String,
    #[serde(default)]
    params: Value,
}

#[derive(Debug, serde::Serialize)]
struct RpcResponse {
    jsonrpc: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<RpcErrorObject>,
}

impl RpcResponse {
    fn ok(id: Option<u64>, result: Value) -> Self {
        Self {
            jsonrpc: "2.0",
            id,
            result: Some(result),
            error: None,
        }
    }

    fn error(id: Option<u64>, code: i64, message: String) -> Self {
        Self {
            jsonrpc: "2.0",
            id,
            result: None,
            error: Some(RpcErrorObject { code, message }),
        }
    }
}