// plugins_protocol/server.rs

1use async_trait::async_trait;
2use serde_json::Value;
3use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
4
5use crate::Result;
6use crate::{
7    CompleteRequest, CompleteResponse, InitializeRequest, InitializeResponse, ListModelsResponse,
8    RpcErrorObject,
9};
10
11#[async_trait]
12pub trait PluginHandler: Send + Sync {
13    async fn initialize(&self, req: InitializeRequest) -> Result<InitializeResponse>;
14    async fn list_models(&self) -> Result<ListModelsResponse>;
15    async fn complete(&self, req: CompleteRequest) -> Result<CompleteResponse>;
16
17    async fn shutdown(&self) -> Result<()> {
18        Ok(())
19    }
20}
21
22pub struct PluginServer<H> {
23    handler: H,
24}
25
26impl<H> PluginServer<H>
27where
28    H: PluginHandler,
29{
30    pub fn new(handler: H) -> Self {
31        Self { handler }
32    }
33
34    pub async fn run_stdio(self) -> Result<()> {
35        let stdin = tokio::io::stdin();
36        let mut stdout = tokio::io::stdout();
37        self.run(stdin, &mut stdout).await
38    }
39
40    pub async fn run<R, W>(&self, reader: R, writer: &mut W) -> Result<()>
41    where
42        R: AsyncRead + Unpin,
43        W: AsyncWrite + Unpin,
44    {
45        let mut lines = BufReader::new(reader).lines();
46        while let Some(line) = lines.next_line().await? {
47            if line.trim().is_empty() {
48                continue;
49            }
50
51            let request: std::result::Result<RpcRequest, serde_json::Error> =
52                serde_json::from_str(&line);
53            let request = match request {
54                Ok(request) => request,
55                Err(e) => {
56                    write_response(
57                        writer,
58                        &RpcResponse::error(None, -32700, format!("parse error: {e}")),
59                    )
60                    .await?;
61                    continue;
62                }
63            };
64
65            let id = request.id;
66            let mut should_shutdown = false;
67            let response = match request.method.as_str() {
68                "initialize" => match serde_json::from_value::<InitializeRequest>(request.params) {
69                    Ok(req) => match self.handler.initialize(req).await {
70                        Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
71                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
72                    },
73                    Err(e) => RpcResponse::error(id, -32602, e.to_string()),
74                },
75                "list_models" => match self.handler.list_models().await {
76                    Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
77                    Err(e) => RpcResponse::error(id, -32000, e.to_string()),
78                },
79                "complete" => match serde_json::from_value::<CompleteRequest>(request.params) {
80                    Ok(req) => match self.handler.complete(req).await {
81                        Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
82                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
83                    },
84                    Err(e) => RpcResponse::error(id, -32602, e.to_string()),
85                },
86                "stream" => RpcResponse::error(id, -32601, "stream is not supported".to_string()),
87                "shutdown" => {
88                    should_shutdown = true;
89                    match self.handler.shutdown().await {
90                        Ok(()) => RpcResponse::ok(id, serde_json::json!({})),
91                        Err(e) => RpcResponse::error(id, -32000, e.to_string()),
92                    }
93                }
94                other => RpcResponse::error(id, -32601, format!("unknown method: {other}")),
95            };
96            write_response(writer, &response).await?;
97            if should_shutdown {
98                break;
99            }
100        }
101        Ok(())
102    }
103}
104
105async fn write_response<W>(writer: &mut W, response: &RpcResponse) -> Result<()>
106where
107    W: AsyncWrite + Unpin,
108{
109    let mut bytes = serde_json::to_vec(response)?;
110    bytes.push(b'\n');
111    writer.write_all(&bytes).await?;
112    writer.flush().await?;
113    Ok(())
114}
115
116#[derive(Debug, serde::Deserialize)]
117struct RpcRequest {
118    id: Option<u64>,
119    method: String,
120    #[serde(default)]
121    params: Value,
122}
123
124#[derive(Debug, serde::Serialize)]
125struct RpcResponse {
126    jsonrpc: &'static str,
127    #[serde(skip_serializing_if = "Option::is_none")]
128    id: Option<u64>,
129    #[serde(skip_serializing_if = "Option::is_none")]
130    result: Option<Value>,
131    #[serde(skip_serializing_if = "Option::is_none")]
132    error: Option<RpcErrorObject>,
133}
134
135impl RpcResponse {
136    fn ok(id: Option<u64>, result: Value) -> Self {
137        Self {
138            jsonrpc: "2.0",
139            id,
140            result: Some(result),
141            error: None,
142        }
143    }
144
145    fn error(id: Option<u64>, code: i64, message: String) -> Self {
146        Self {
147            jsonrpc: "2.0",
148            id,
149            result: None,
150            error: Some(RpcErrorObject { code, message }),
151        }
152    }
153}