1use async_trait::async_trait;
2use serde_json::Value;
3use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
4
5use crate::Result;
6use crate::{
7 CompleteRequest, CompleteResponse, InitializeRequest, InitializeResponse, ListModelsResponse,
8 RpcErrorObject,
9};
10
11#[async_trait]
12pub trait PluginHandler: Send + Sync {
13 async fn initialize(&self, req: InitializeRequest) -> Result<InitializeResponse>;
14 async fn list_models(&self) -> Result<ListModelsResponse>;
15 async fn complete(&self, req: CompleteRequest) -> Result<CompleteResponse>;
16
17 async fn shutdown(&self) -> Result<()> {
18 Ok(())
19 }
20}
21
/// Request loop that reads JSON requests, forwards them to a handler of type
/// `H`, and writes back one JSON response per request.
pub struct PluginServer<H> {
    /// Handler that receives every dispatched request.
    handler: H,
}
25
26impl<H> PluginServer<H>
27where
28 H: PluginHandler,
29{
30 pub fn new(handler: H) -> Self {
31 Self { handler }
32 }
33
34 pub async fn run_stdio(self) -> Result<()> {
35 let stdin = tokio::io::stdin();
36 let mut stdout = tokio::io::stdout();
37 self.run(stdin, &mut stdout).await
38 }
39
40 pub async fn run<R, W>(&self, reader: R, writer: &mut W) -> Result<()>
41 where
42 R: AsyncRead + Unpin,
43 W: AsyncWrite + Unpin,
44 {
45 let mut lines = BufReader::new(reader).lines();
46 while let Some(line) = lines.next_line().await? {
47 if line.trim().is_empty() {
48 continue;
49 }
50
51 let request: std::result::Result<RpcRequest, serde_json::Error> =
52 serde_json::from_str(&line);
53 let request = match request {
54 Ok(request) => request,
55 Err(e) => {
56 write_response(
57 writer,
58 &RpcResponse::error(None, -32700, format!("parse error: {e}")),
59 )
60 .await?;
61 continue;
62 }
63 };
64
65 let id = request.id;
66 let mut should_shutdown = false;
67 let response = match request.method.as_str() {
68 "initialize" => match serde_json::from_value::<InitializeRequest>(request.params) {
69 Ok(req) => match self.handler.initialize(req).await {
70 Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
71 Err(e) => RpcResponse::error(id, -32000, e.to_string()),
72 },
73 Err(e) => RpcResponse::error(id, -32602, e.to_string()),
74 },
75 "list_models" => match self.handler.list_models().await {
76 Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
77 Err(e) => RpcResponse::error(id, -32000, e.to_string()),
78 },
79 "complete" => match serde_json::from_value::<CompleteRequest>(request.params) {
80 Ok(req) => match self.handler.complete(req).await {
81 Ok(result) => RpcResponse::ok(id, serde_json::to_value(result)?),
82 Err(e) => RpcResponse::error(id, -32000, e.to_string()),
83 },
84 Err(e) => RpcResponse::error(id, -32602, e.to_string()),
85 },
86 "stream" => RpcResponse::error(id, -32601, "stream is not supported".to_string()),
87 "shutdown" => {
88 should_shutdown = true;
89 match self.handler.shutdown().await {
90 Ok(()) => RpcResponse::ok(id, serde_json::json!({})),
91 Err(e) => RpcResponse::error(id, -32000, e.to_string()),
92 }
93 }
94 other => RpcResponse::error(id, -32601, format!("unknown method: {other}")),
95 };
96 write_response(writer, &response).await?;
97 if should_shutdown {
98 break;
99 }
100 }
101 Ok(())
102 }
103}
104
105async fn write_response<W>(writer: &mut W, response: &RpcResponse) -> Result<()>
106where
107 W: AsyncWrite + Unpin,
108{
109 let mut bytes = serde_json::to_vec(response)?;
110 bytes.push(b'\n');
111 writer.write_all(&bytes).await?;
112 writer.flush().await?;
113 Ok(())
114}
115
116#[derive(Debug, serde::Deserialize)]
117struct RpcRequest {
118 id: Option<u64>,
119 method: String,
120 #[serde(default)]
121 params: Value,
122}
123
124#[derive(Debug, serde::Serialize)]
125struct RpcResponse {
126 jsonrpc: &'static str,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 id: Option<u64>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 result: Option<Value>,
131 #[serde(skip_serializing_if = "Option::is_none")]
132 error: Option<RpcErrorObject>,
133}
134
135impl RpcResponse {
136 fn ok(id: Option<u64>, result: Value) -> Self {
137 Self {
138 jsonrpc: "2.0",
139 id,
140 result: Some(result),
141 error: None,
142 }
143 }
144
145 fn error(id: Option<u64>, code: i64, message: String) -> Self {
146 Self {
147 jsonrpc: "2.0",
148 id,
149 result: None,
150 error: Some(RpcErrorObject { code, message }),
151 }
152 }
153}