plugins_protocol/
client.rs1use std::process::Stdio;
2use std::time::Duration;
3
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6use serde_json::Value;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::process::{Child, ChildStdin, ChildStdout, Command};
9
10use crate::{
11 CompleteRequest, CompleteResponse, InitializeRequest, InitializeResponse, ListModelsResponse,
12 RpcErrorObject,
13};
14use crate::{Error, Result};
15
/// Timeout applied while waiting on plugin I/O in `PluginClient::request` (30 s).
const REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
17
18pub struct PluginClient {
19 child: Child,
20 stdin: ChildStdin,
21 stdout: BufReader<ChildStdout>,
22 next_id: u64,
23}
24
25impl PluginClient {
26 pub fn spawn(command: &str, env_pass: &[String]) -> Result<Self> {
27 let mut cmd = Command::new(command);
28 cmd.env_clear()
29 .stdin(Stdio::piped())
30 .stdout(Stdio::piped())
31 .stderr(Stdio::inherit());
32 for name in env_pass {
33 if let Ok(value) = std::env::var(name) {
34 cmd.env(name, value);
35 }
36 }
37 preserve_platform_process_env(&mut cmd);
38 Self::spawn_command(cmd)
39 }
40
41 pub fn spawn_command(mut command: Command) -> Result<Self> {
42 command
43 .stdin(Stdio::piped())
44 .stdout(Stdio::piped())
45 .stderr(Stdio::inherit());
46 let mut child = spawn_with_etxtbsy_retry(&mut command)?;
47 let stdin = child
48 .stdin
49 .take()
50 .ok_or_else(|| Error::Protocol("plugin stdin was not piped".to_string()))?;
51 let stdout = child
52 .stdout
53 .take()
54 .ok_or_else(|| Error::Protocol("plugin stdout was not piped".to_string()))?;
55 Ok(Self {
56 child,
57 stdin,
58 stdout: BufReader::new(stdout),
59 next_id: 1,
60 })
61 }
62
63 pub async fn initialize(&mut self, req: InitializeRequest) -> Result<InitializeResponse> {
64 self.request("initialize", req).await
65 }
66
67 pub async fn list_models(&mut self) -> Result<ListModelsResponse> {
68 self.request("list_models", serde_json::json!({})).await
69 }
70
71 pub async fn complete(&mut self, req: CompleteRequest) -> Result<CompleteResponse> {
72 self.request("complete", req).await
73 }
74
75 pub async fn shutdown(&mut self) -> Result<()> {
76 let _: Value = self.request("shutdown", serde_json::json!({})).await?;
77 Ok(())
78 }
79
80 async fn request<P, R>(&mut self, method: &str, params: P) -> Result<R>
81 where
82 P: Serialize,
83 R: DeserializeOwned,
84 {
85 let id = self.next_id;
86 self.next_id += 1;
87 let frame = serde_json::json!({
88 "jsonrpc": "2.0",
89 "id": id,
90 "method": method,
91 "params": params,
92 });
93 let mut encoded = serde_json::to_vec(&frame)?;
94 encoded.push(b'\n');
95 self.stdin.write_all(&encoded).await?;
96 self.stdin.flush().await?;
97
98 loop {
99 let mut line = String::new();
100 let read = tokio::time::timeout(REQUEST_TIMEOUT, self.stdout.read_line(&mut line))
101 .await
102 .map_err(|_| Error::Timeout {
103 method: method.to_string(),
104 })??;
105 if read == 0 {
106 return Err(Error::Protocol(format!(
107 "plugin exited before responding to {method}"
108 )));
109 }
110 let response: RpcResponse = serde_json::from_str(&line)?;
111 if response.id != Some(id) {
112 continue;
113 }
114 if let Some(error) = response.error {
115 return Err(Error::Rpc {
116 code: error.code,
117 message: error.message,
118 });
119 }
120 let result = response.result.ok_or_else(|| {
121 Error::Protocol(format!("plugin response to {method} missing result"))
122 })?;
123 return Ok(serde_json::from_value(result)?);
124 }
125 }
126}
127
128impl Drop for PluginClient {
129 fn drop(&mut self) {
130 let _ = self.child.start_kill();
131 }
132}
133
134#[derive(Debug, serde::Deserialize)]
135struct RpcResponse {
136 id: Option<u64>,
137 #[serde(default)]
138 result: Option<Value>,
139 #[serde(default)]
140 error: Option<RpcErrorObject>,
141}
142
/// Copies `ComSpec`, `SystemRoot` and `PATHEXT` from the current process into
/// `cmd`, so they survive an earlier `env_clear()`. Variables that are unset, or
/// whose values are not valid Unicode, are skipped.
#[cfg(windows)]
fn preserve_platform_process_env(cmd: &mut Command) {
    let inherited = ["ComSpec", "SystemRoot", "PATHEXT"]
        .iter()
        .filter_map(|&key| std::env::var(key).ok().map(|value| (key, value)));
    cmd.envs(inherited);
}
151
152#[cfg(not(windows))]
153fn preserve_platform_process_env(_cmd: &mut Command) {}
154
155fn spawn_with_etxtbsy_retry(command: &mut Command) -> std::io::Result<Child> {
165 const MAX_ATTEMPTS: u32 = 10;
166 let mut delay = Duration::from_millis(5);
167 for _ in 1..MAX_ATTEMPTS {
168 match command.spawn() {
169 Err(e) if is_etxtbsy(&e) => {
170 std::thread::sleep(delay);
171 delay = (delay * 2).min(Duration::from_millis(100));
172 }
173 other => return other,
174 }
175 }
176 command.spawn()
177}
178
/// Returns `true` if `e` carries the raw OS error code 26, used here as
/// `ETXTBSY` ("text file busy"). NOTE(review): 26 is ETXTBSY on Linux and macOS;
/// confirm for any other unix target this builds for.
#[cfg(unix)]
fn is_etxtbsy(e: &std::io::Error) -> bool {
    const ETXTBSY: i32 = 26;
    matches!(e.raw_os_error(), Some(ETXTBSY))
}
185
/// Non-unix targets never classify a spawn error as `ETXTBSY`, so the spawn
/// retry loop returns on the first failure.
#[cfg(not(unix))]
fn is_etxtbsy(_err: &std::io::Error) -> bool {
    false
}
190
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::os::unix::fs::PermissionsExt;

    /// Holds a write handle on a freshly written executable so that exec can
    /// fail with ETXTBSY, releases it after ~50 ms from another thread, and
    /// checks that `spawn_command` retries until the spawn succeeds.
    #[tokio::test]
    async fn spawn_retries_past_transient_etxtbsy() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let plugin_path = tmp.path().join("busy-plugin");
        std::fs::write(&plugin_path, "#!/bin/sh\nexit 0\n").expect("write script");
        let exec_mode = std::fs::Permissions::from_mode(0o755);
        std::fs::set_permissions(&plugin_path, exec_mode).expect("chmod script");

        // Keep a writable descriptor open so the executable is reported busy.
        let writer = std::fs::OpenOptions::new()
            .append(true)
            .open(&plugin_path)
            .expect("hold write fd");
        let release = std::thread::spawn(move || {
            std::thread::sleep(Duration::from_millis(50));
            drop(writer);
        });

        let mut command = Command::new(&plugin_path);
        command.env_clear();
        let outcome = PluginClient::spawn_command(command);
        release.join().expect("releaser thread");

        assert!(
            outcome.is_ok(),
            "spawn must retry past a transient ETXTBSY, got: {:?}",
            outcome.err()
        );
    }
}