openvcs_core/
plugin_stdio.rs

1use crate::plugin_protocol::{PluginMessage, RpcRequest, RpcResponse};
2use serde::Serialize;
3use serde::de::DeserializeOwned;
4use serde_json::Value;
5use std::collections::{HashMap, VecDeque};
6use std::io::{self, BufRead, Write};
7use std::sync::{Arc, Mutex};
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone)]
11pub struct PluginError {
12    pub code: Option<String>,
13    pub message: String,
14    pub data: Option<Value>,
15}
16
17impl PluginError {
18    pub fn message(message: impl Into<String>) -> Self {
19        Self {
20            code: None,
21            message: message.into(),
22            data: None,
23        }
24    }
25
26    pub fn code(code: impl Into<String>, message: impl Into<String>) -> Self {
27        Self {
28            code: Some(code.into()),
29            message: message.into(),
30            data: None,
31        }
32    }
33
34    pub fn with_data(mut self, data: Value) -> Self {
35        self.data = Some(data);
36        self
37    }
38}
39
40pub fn err_display(err: impl std::fmt::Display) -> PluginError {
41    PluginError::message(err.to_string())
42}
43
44pub fn receive_message<R: BufRead>(stdin: &mut R) -> Option<PluginMessage> {
45    let mut line = String::new();
46    loop {
47        line.clear();
48        let n = stdin.read_line(&mut line).ok()?;
49        if n == 0 {
50            return None;
51        }
52        let trimmed = line.trim();
53        if trimmed.is_empty() {
54            continue;
55        }
56        if let Ok(msg) = serde_json::from_str::<PluginMessage>(trimmed) {
57            return Some(msg);
58        }
59    }
60}
61
62pub fn write_message<W: Write>(out: &mut W, msg: &PluginMessage) -> io::Result<()> {
63    let line = serde_json::to_string(msg).unwrap_or_else(|_| "{}".into());
64    writeln!(out, "{line}")?;
65    out.flush()?;
66    Ok(())
67}
68
69pub fn send_message_shared<W: Write>(out: &Arc<Mutex<W>>, msg: &PluginMessage) {
70    if let Ok(mut w) = out.lock() {
71        let _ = write_message(&mut *w, msg);
72    }
73}
74
75pub fn send_request_shared<W: Write>(out: &Arc<Mutex<W>>, req: RpcRequest) {
76    send_message_shared(out, &PluginMessage::Request(req));
77}
78
79pub fn send_request<W: Write>(out: &mut W, req: RpcRequest) -> io::Result<()> {
80    write_message(out, &PluginMessage::Request(req))
81}
82
83pub fn receive_request<R: BufRead>(stdin: &mut R) -> Option<RpcRequest> {
84    loop {
85        match receive_message(stdin)? {
86            PluginMessage::Request(req) => return Some(req),
87            PluginMessage::Response(_) | PluginMessage::Event { .. } => {}
88        }
89    }
90}
91
92pub fn respond_shared<W: Write>(out: &Arc<Mutex<W>>, id: u64, res: Result<Value, PluginError>) {
93    let response = match res {
94        Ok(result) => RpcResponse {
95            id,
96            ok: true,
97            result,
98            error: None,
99            error_code: None,
100            error_data: None,
101        },
102        Err(err) => RpcResponse {
103            id,
104            ok: false,
105            result: Value::Null,
106            error: Some(err.message),
107            error_code: err.code,
108            error_data: err.data,
109        },
110    };
111
112    send_message_shared(out, &PluginMessage::Response(response));
113}
114
115pub fn ok<T: Serialize>(value: T) -> Result<Value, PluginError> {
116    serde_json::to_value(value).map_err(|e| PluginError::code("plugin.serialize", e.to_string()))
117}
118
119pub fn ok_null() -> Result<Value, PluginError> {
120    Ok(Value::Null)
121}
122
123pub fn parse_json_params<T: DeserializeOwned>(value: Value) -> Result<T, String> {
124    serde_json::from_value(value).map_err(|e| format!("invalid params: {e}"))
125}
126
/// Source of ids for requests the plugin sends to the host (see `call_host`).
#[derive(Debug)]
pub struct RequestIdState {
    /// Id handed to the next host call; `call_host` advances it with `saturating_add(1)`.
    pub next_id: u64,
}
131
132pub fn call_host<W: Write, R: BufRead>(
133    out: &Arc<Mutex<W>>,
134    stdin: &Arc<Mutex<R>>,
135    queue: &Arc<Mutex<VecDeque<crate::plugin_protocol::RpcRequest>>>,
136    ids: &Arc<Mutex<RequestIdState>>,
137    method: &str,
138    params: Value,
139    timeout: Duration,
140) -> Result<Value, PluginError> {
141    let id = {
142        let mut lock = ids
143            .lock()
144            .map_err(|_| PluginError::message("pending lock poisoned"))?;
145        let id = lock.next_id;
146        lock.next_id = lock.next_id.saturating_add(1);
147        id
148    };
149
150    send_request_shared(
151        out,
152        crate::plugin_protocol::RpcRequest {
153            id,
154            method: method.to_string(),
155            params,
156        },
157    );
158
159    let deadline = Instant::now() + timeout;
160    let mut stash: HashMap<u64, RpcResponse> = HashMap::new();
161
162    loop {
163        if Instant::now() > deadline {
164            return Err(PluginError::code("host.timeout", "host call timed out"));
165        }
166
167        if let Some(resp) = stash.remove(&id) {
168            return if resp.ok {
169                Ok(resp.result)
170            } else {
171                Err(PluginError {
172                    code: resp.error_code.or(Some("host.error".into())),
173                    message: resp.error.unwrap_or_else(|| "error".into()),
174                    data: resp.error_data,
175                })
176            };
177        }
178
179        let msg = {
180            let mut lock = stdin
181                .lock()
182                .map_err(|_| PluginError::message("stdin lock poisoned"))?;
183            receive_message(&mut *lock).ok_or_else(|| PluginError::message("host closed stdin"))?
184        };
185
186        match msg {
187            PluginMessage::Response(resp) => {
188                if resp.id == id {
189                    return if resp.ok {
190                        Ok(resp.result)
191                    } else {
192                        Err(PluginError {
193                            code: resp.error_code.or(Some("host.error".into())),
194                            message: resp.error.unwrap_or_else(|| "error".into()),
195                            data: resp.error_data,
196                        })
197                    };
198                }
199                stash.insert(resp.id, resp);
200            }
201            PluginMessage::Request(req) => {
202                if let Ok(mut q) = queue.lock() {
203                    q.push_back(req);
204                }
205            }
206            PluginMessage::Event { .. } => {}
207        }
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::VcsEvent;
    use std::io::Cursor;

    #[test]
    fn plugin_error_builders_set_fields() {
        let err = PluginError::message("nope");
        assert!(err.code.is_none());
        assert_eq!(err.message, "nope");
        assert!(err.data.is_none());

        let err = PluginError::code("x.y", "bad").with_data(serde_json::json!({"k": 1}));
        assert_eq!(err.code.as_deref(), Some("x.y"));
        assert_eq!(err.message, "bad");
        assert_eq!(err.data, Some(serde_json::json!({"k": 1})));
    }

    #[test]
    fn receive_message_skips_blank_and_invalid_lines() {
        let input = b"\n   \nnot json\n{\"id\":1,\"method\":\"ping\"}\n";
        let mut cursor = Cursor::new(&input[..]);

        let msg = receive_message(&mut cursor).expect("message");
        match msg {
            PluginMessage::Request(req) => {
                assert_eq!(req.id, 1);
                assert_eq!(req.method, "ping");
                assert_eq!(req.params, Value::Null);
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn receive_request_ignores_non_request_messages() {
        let input = b"{\"id\":7,\"ok\":true,\"result\":null}\n{\"event\":{\"type\":\"info\",\"msg\":\"hi\"}}\n{\"id\":1,\"method\":\"ping\"}\n";
        let mut cursor = Cursor::new(&input[..]);

        let req = receive_request(&mut cursor).expect("request");
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "ping");
    }

    #[test]
    fn write_message_writes_one_json_line() {
        let msg = PluginMessage::Event {
            event: VcsEvent::Info {
                msg: "hello".into(),
            },
        };

        let mut out = Vec::<u8>::new();
        write_message(&mut out, &msg).expect("write ok");
        assert!(out.ends_with(b"\n"));

        let line = std::str::from_utf8(&out).expect("utf-8");
        let parsed: PluginMessage = serde_json::from_str(line.trim()).expect("valid message");
        match parsed {
            PluginMessage::Event { event } => match event {
                VcsEvent::Info { msg } => assert_eq!(msg, "hello"),
                other => panic!("unexpected event: {other:?}"),
            },
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn parse_json_params_errors_are_prefixed() {
        let err = parse_json_params::<serde_json::Map<String, Value>>(Value::String("x".into()))
            .expect_err("should fail");
        assert!(err.starts_with("invalid params:"));
    }

    #[test]
    fn ok_helpers_produce_json_values() {
        assert_eq!(ok(vec![1, 2]).expect("serializable"), serde_json::json!([1, 2]));
        assert_eq!(ok_null().expect("null result"), Value::Null);
    }

    #[test]
    fn respond_shared_writes_error_fields() {
        let out = Arc::new(Mutex::new(Vec::<u8>::new()));
        respond_shared(&out, 9, Err(PluginError::code("x.y", "bad")));

        let out = out.lock().expect("out lock");
        let line = std::str::from_utf8(&out).expect("utf-8");
        let parsed: PluginMessage = serde_json::from_str(line.trim()).expect("valid message");
        match parsed {
            PluginMessage::Response(resp) => {
                assert_eq!(resp.id, 9);
                assert!(!resp.ok);
                assert_eq!(resp.error.as_deref(), Some("bad"));
                assert_eq!(resp.error_code.as_deref(), Some("x.y"));
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn call_host_returns_ok_and_queues_incoming_requests() {
        let out = Arc::new(Mutex::new(Vec::<u8>::new()));
        let stdin = Arc::new(Mutex::new(Cursor::new(
            b"{\"id\":999,\"ok\":true,\"result\":{\"ignored\":true}}\n\
              {\"id\":77,\"method\":\"noop\",\"params\":null}\n\
              {\"id\":5,\"ok\":true,\"result\":{\"answer\":42}}\n" as &[u8],
        )));
        let queue = Arc::new(Mutex::new(VecDeque::<RpcRequest>::new()));
        let ids = Arc::new(Mutex::new(RequestIdState { next_id: 5 }));

        let result = call_host(
            &out,
            &stdin,
            &queue,
            &ids,
            "math.answer",
            serde_json::json!({}),
            Duration::from_secs(1),
        )
        .expect("host call ok");
        assert_eq!(result, serde_json::json!({"answer": 42}));

        let queue = queue.lock().expect("queue lock");
        assert_eq!(queue.len(), 1);
        assert_eq!(queue[0].id, 77);
        assert_eq!(queue[0].method, "noop");

        let out = out.lock().expect("out lock");
        let line = std::str::from_utf8(&out).expect("utf-8");
        let first = line.lines().next().expect("at least one line");
        let sent: PluginMessage = serde_json::from_str(first).expect("valid sent message");
        match sent {
            PluginMessage::Request(req) => {
                assert_eq!(req.id, 5);
                assert_eq!(req.method, "math.answer");
            }
            other => panic!("unexpected sent message: {other:?}"),
        }
    }

    #[test]
    fn call_host_maps_host_error_response() {
        let out = Arc::new(Mutex::new(Vec::<u8>::new()));
        let stdin = Arc::new(Mutex::new(Cursor::new(
            b"{\"id\":3,\"ok\":false,\"result\":null,\"error\":\"boom\"}\n" as &[u8],
        )));
        let queue = Arc::new(Mutex::new(VecDeque::<RpcRequest>::new()));
        let ids = Arc::new(Mutex::new(RequestIdState { next_id: 3 }));

        let err = call_host(
            &out,
            &stdin,
            &queue,
            &ids,
            "fail.please",
            Value::Null,
            Duration::from_secs(1),
        )
        .expect_err("host reported failure");
        // No error_code in the response, so the default code is applied.
        assert_eq!(err.code.as_deref(), Some("host.error"));
        assert_eq!(err.message, "boom");
        assert_eq!(ids.lock().expect("ids lock").next_id, 4);
    }

    #[test]
    fn call_host_reports_closed_stdin() {
        let out = Arc::new(Mutex::new(Vec::<u8>::new()));
        let stdin = Arc::new(Mutex::new(Cursor::new(b"" as &[u8])));
        let queue = Arc::new(Mutex::new(VecDeque::<RpcRequest>::new()));
        let ids = Arc::new(Mutex::new(RequestIdState { next_id: 1 }));

        let err = call_host(
            &out,
            &stdin,
            &queue,
            &ids,
            "never.answered",
            Value::Null,
            Duration::from_secs(1),
        )
        .expect_err("stdin is empty");
        assert!(err.code.is_none());
        assert_eq!(err.message, "host closed stdin");
    }
}