1use crate::plugin_protocol::{PluginMessage, RpcRequest, RpcResponse};
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use serde_json::Value;
5use std::collections::{HashMap, VecDeque};
6use std::io::{self, BufRead, Write};
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
11pub struct PluginError {
12 pub code: Option<String>,
13 pub message: String,
14 pub data: Option<Value>,
15}
16
17impl PluginError {
18 pub fn message(message: impl Into<String>) -> Self {
19 Self {
20 code: None,
21 message: message.into(),
22 data: None,
23 }
24 }
25
26 pub fn code(code: impl Into<String>, message: impl Into<String>) -> Self {
27 Self {
28 code: Some(code.into()),
29 message: message.into(),
30 data: None,
31 }
32 }
33
34 pub fn with_data(mut self, data: Value) -> Self {
35 self.data = Some(data);
36 self
37 }
38}
39
40pub fn err_display(err: impl std::fmt::Display) -> PluginError {
41 PluginError::message(err.to_string())
42}
43
44pub fn receive_message<R: BufRead>(stdin: &mut R) -> Option<PluginMessage> {
45 let mut line = String::new();
46 loop {
47 line.clear();
48 let n = stdin.read_line(&mut line).ok()?;
49 if n == 0 {
50 return None;
51 }
52 let trimmed = line.trim();
53 if trimmed.is_empty() {
54 continue;
55 }
56 if let Ok(msg) = serde_json::from_str::<PluginMessage>(trimmed) {
57 return Some(msg);
58 }
59 }
60}
61
62pub fn write_message<W: Write>(out: &mut W, msg: &PluginMessage) -> io::Result<()> {
63 let line = serde_json::to_string(msg).unwrap_or_else(|_| "{}".into());
64 writeln!(out, "{line}")?;
65 out.flush()?;
66 Ok(())
67}
68
69pub fn send_message_shared<W: Write>(out: &Arc<Mutex<W>>, msg: &PluginMessage) {
70 if let Ok(mut w) = out.lock() {
71 let _ = write_message(&mut *w, msg);
72 }
73}
74
75pub fn send_request_shared<W: Write>(out: &Arc<Mutex<W>>, req: RpcRequest) {
76 send_message_shared(out, &PluginMessage::Request(req));
77}
78
79pub fn send_request<W: Write>(out: &mut W, req: RpcRequest) -> io::Result<()> {
80 write_message(out, &PluginMessage::Request(req))
81}
82
83pub fn receive_request<R: BufRead>(stdin: &mut R) -> Option<RpcRequest> {
84 loop {
85 match receive_message(stdin)? {
86 PluginMessage::Request(req) => return Some(req),
87 PluginMessage::Response(_) | PluginMessage::Event { .. } => {}
88 }
89 }
90}
91
92pub fn respond_shared<W: Write>(out: &Arc<Mutex<W>>, id: u64, res: Result<Value, PluginError>) {
93 let response = match res {
94 Ok(result) => RpcResponse {
95 id,
96 ok: true,
97 result,
98 error: None,
99 error_code: None,
100 error_data: None,
101 },
102 Err(err) => RpcResponse {
103 id,
104 ok: false,
105 result: Value::Null,
106 error: Some(err.message),
107 error_code: err.code,
108 error_data: err.data,
109 },
110 };
111
112 send_message_shared(out, &PluginMessage::Response(response));
113}
114
115pub fn ok<T: Serialize>(value: T) -> Result<Value, PluginError> {
116 serde_json::to_value(value).map_err(|e| PluginError::code("plugin.serialize", e.to_string()))
117}
118
119pub fn ok_null() -> Result<Value, PluginError> {
120 Ok(Value::Null)
121}
122
123pub fn parse_json_params<T: DeserializeOwned>(value: Value) -> Result<T, String> {
124 serde_json::from_value(value).map_err(|e| format!("invalid params: {e}"))
125}
126
/// Counter from which `call_host` draws ids for outgoing requests.
#[derive(Debug)]
pub struct RequestIdState {
    /// Id that the next outgoing request will use; `call_host` reads it and
    /// then advances it by one (saturating at `u64::MAX`).
    pub next_id: u64,
}
131
132pub fn call_host<W: Write, R: BufRead>(
133 out: &Arc<Mutex<W>>,
134 stdin: &Arc<Mutex<R>>,
135 queue: &Arc<Mutex<VecDeque<crate::plugin_protocol::RpcRequest>>>,
136 ids: &Arc<Mutex<RequestIdState>>,
137 method: &str,
138 params: Value,
139 timeout: Duration,
140) -> Result<Value, PluginError> {
141 let id = {
142 let mut lock = ids
143 .lock()
144 .map_err(|_| PluginError::message("pending lock poisoned"))?;
145 let id = lock.next_id;
146 lock.next_id = lock.next_id.saturating_add(1);
147 id
148 };
149
150 send_request_shared(
151 out,
152 crate::plugin_protocol::RpcRequest {
153 id,
154 method: method.to_string(),
155 params,
156 },
157 );
158
159 let deadline = Instant::now() + timeout;
160 let mut stash: HashMap<u64, RpcResponse> = HashMap::new();
161
162 loop {
163 if Instant::now() > deadline {
164 return Err(PluginError::code("host.timeout", "host call timed out"));
165 }
166
167 if let Some(resp) = stash.remove(&id) {
168 return if resp.ok {
169 Ok(resp.result)
170 } else {
171 Err(PluginError {
172 code: resp.error_code.or(Some("host.error".into())),
173 message: resp.error.unwrap_or_else(|| "error".into()),
174 data: resp.error_data,
175 })
176 };
177 }
178
179 let msg = {
180 let mut lock = stdin
181 .lock()
182 .map_err(|_| PluginError::message("stdin lock poisoned"))?;
183 receive_message(&mut *lock).ok_or_else(|| PluginError::message("host closed stdin"))?
184 };
185
186 match msg {
187 PluginMessage::Response(resp) => {
188 if resp.id == id {
189 return if resp.ok {
190 Ok(resp.result)
191 } else {
192 Err(PluginError {
193 code: resp.error_code.or(Some("host.error".into())),
194 message: resp.error.unwrap_or_else(|| "error".into()),
195 data: resp.error_data,
196 })
197 };
198 }
199 stash.insert(resp.id, resp);
200 }
201 PluginMessage::Request(req) => {
202 if let Ok(mut q) = queue.lock() {
203 q.push_back(req);
204 }
205 }
206 PluginMessage::Event { .. } => {}
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::VcsEvent;
    use std::io::Cursor;

    // `message` leaves code and data unset; `code` + `with_data` populate every field.
    #[test]
    fn plugin_error_builders_set_fields() {
        let err = PluginError::message("nope");
        assert!(err.code.is_none());
        assert_eq!(err.message, "nope");
        assert!(err.data.is_none());

        let err = PluginError::code("x.y", "bad").with_data(serde_json::json!({"k": 1}));
        assert_eq!(err.code.as_deref(), Some("x.y"));
        assert_eq!(err.message, "bad");
        assert_eq!(err.data, Some(serde_json::json!({"k": 1})));
    }

    // Empty, whitespace-only, and non-JSON lines before the first valid message are skipped.
    #[test]
    fn receive_message_skips_blank_and_invalid_lines() {
        let input = b"\n \nnot json\n{\"id\":1,\"method\":\"ping\"}\n";
        let mut cursor = Cursor::new(&input[..]);

        let msg = receive_message(&mut cursor).expect("message");
        match msg {
            PluginMessage::Request(req) => {
                assert_eq!(req.id, 1);
                assert_eq!(req.method, "ping");
                // `params` is absent on the wire and is expected to come back as `Null`.
                assert_eq!(req.params, Value::Null);
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    // A response (id 7) and an event precede the request; both must be skipped.
    #[test]
    fn receive_request_ignores_non_request_messages() {
        let input = b"{\"id\":7,\"ok\":true,\"result\":null}\n{\"event\":{\"type\":\"info\",\"msg\":\"hi\"}}\n{\"id\":1,\"method\":\"ping\"}\n";
        let mut cursor = Cursor::new(&input[..]);

        let req = receive_request(&mut cursor).expect("request");
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "ping");
    }

    // Round trip: the output is newline-terminated and parses back to the same event.
    #[test]
    fn write_message_writes_one_json_line() {
        let msg = PluginMessage::Event {
            event: VcsEvent::Info {
                msg: "hello".into(),
            },
        };

        let mut out = Vec::<u8>::new();
        write_message(&mut out, &msg).expect("write ok");
        assert!(out.ends_with(b"\n"));

        let line = std::str::from_utf8(&out).expect("utf-8");
        let parsed: PluginMessage = serde_json::from_str(line.trim()).expect("valid message");
        match parsed {
            PluginMessage::Event { event } => match event {
                VcsEvent::Info { msg } => assert_eq!(msg, "hello"),
                other => panic!("unexpected event: {other:?}"),
            },
            other => panic!("unexpected message: {other:?}"),
        }
    }

    // A JSON string cannot deserialize into a map, so the error carries the prefix.
    #[test]
    fn parse_json_params_errors_are_prefixed() {
        let err = parse_json_params::<serde_json::Map<String, Value>>(Value::String("x".into()))
            .expect_err("should fail");
        assert!(err.starts_with("invalid params:"));
    }

    // The host emits an unrelated response (id 999) and a host-initiated request
    // (id 77) before answering our request (id 5). The answer is returned, the
    // host request is queued, and the first line written to `out` is our
    // outgoing request with id 5.
    #[test]
    fn call_host_returns_ok_and_queues_incoming_requests() {
        let out = Arc::new(Mutex::new(Vec::<u8>::new()));
        let stdin = Arc::new(Mutex::new(Cursor::new(
            b"{\"id\":999,\"ok\":true,\"result\":{\"ignored\":true}}\n\
            {\"id\":77,\"method\":\"noop\",\"params\":null}\n\
            {\"id\":5,\"ok\":true,\"result\":{\"answer\":42}}\n" as &[u8],
        )));
        let queue = Arc::new(Mutex::new(VecDeque::<RpcRequest>::new()));
        // Seed the counter so the outgoing request gets id 5, matching the fixture.
        let ids = Arc::new(Mutex::new(RequestIdState { next_id: 5 }));

        let result = call_host(
            &out,
            &stdin,
            &queue,
            &ids,
            "math.answer",
            serde_json::json!({}),
            Duration::from_secs(1),
        )
        .expect("host call ok");
        assert_eq!(result, serde_json::json!({"answer": 42}));

        let queue = queue.lock().expect("queue lock");
        assert_eq!(queue.len(), 1);
        assert_eq!(queue[0].id, 77);
        assert_eq!(queue[0].method, "noop");

        let out = out.lock().expect("out lock");
        let line = std::str::from_utf8(&out).expect("utf-8");
        let first = line.lines().next().expect("at least one line");
        let sent: PluginMessage = serde_json::from_str(first).expect("valid sent message");
        match sent {
            PluginMessage::Request(req) => {
                assert_eq!(req.id, 5);
                assert_eq!(req.method, "math.answer");
            }
            other => panic!("unexpected sent message: {other:?}"),
        }
    }
}