// command_autocomplete/connection.rs

1use crate::types::{Message, Request, RequestId, Response};
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4use std::collections::HashMap;
5use std::io::{BufRead, BufReader, Read, Write};
6use std::sync::mpsc::{Receiver, SyncSender};
7use std::sync::{Arc, Mutex};
8
9// TODO: make it a trait
10#[derive(Default)]
11pub struct IdGenerator {
12    next: Mutex<i32>,
13}
14
15impl IdGenerator {
16    pub fn next(&self) -> RequestId {
17        let mut x = self.next.lock().unwrap();
18        let id = RequestId(format!("{}", x));
19        *x += 1;
20        id
21    }
22}
23
/// Pending-response entry, stored until the response to a sent request arrives.
struct ResponseCallback {
    /// Invoked (once) with the received response by `ConnectionReceiver::next_request`.
    callback: Box<dyn FnOnce(Response) + Send + 'static>,
    /// `true` when the request was a `shutdown` request; receiving its response
    /// makes the `ConnectionReceiver` stop returning requests.
    shutdown: bool,
}

// Internal state of the connection
/// State shared (through an `Arc`) by the sending and receiving halves.
#[derive(Default)]
struct ConnectionState {
    /// Callbacks of in-flight requests, keyed by request id. Entries are inserted
    /// by `ConnectionSender::send` and removed when the matching response arrives.
    responses: Mutex<HashMap<RequestId, ResponseCallback>>,
}

/// Sending half of a connection, created by `new_connection`. Cloneable; all
/// clones share the id generator and the pending-response table.
#[derive(Clone)]
pub struct ConnectionSender {
    /// Source of ids for outgoing requests.
    ids: Arc<IdGenerator>,
    /// State shared with the `ConnectionReceiver`.
    state: Arc<ConnectionState>,
    /// Outgoing message channel of the underlying `Transport`.
    sender: SyncSender<Message>,
}
41
// Shutdown protocol
// - send a shutdown request
// - disallow sending new requests
// - close the connection
//
// When A sends shutdown:
// - A can't send any new requests.
// - Once A receives the response to shutdown, no new messages should be received.
// - A can then close its receiver and sender.
//
// When B receives shutdown:
// - It knows it will not receive any new requests.
// - It should respond to any active requests.
// - Once all active requests have been responded to (and no new requests are
//   coming), it should respond to the 'shutdown' request.
// - B can then close its receiver and sender.
//
// Premature shutdown: TODO — behavior not described yet.
60
61pub struct ResponseHandle<R> {
62    receiver: Receiver<Result<R, ResponseError>>,
63}
64
65#[derive(Debug)]
66pub enum ResponseError {
67    /// Error received by the other side.
68    Err(crate::types::Error),
69    /// The connection has been closed and response will not be received.
70    ChannelClosed,
71    /// The received response failed deserialization into provided type.
72    DeserializationError(serde_json::Error),
73}
74
75impl<R> ResponseHandle<R> {
76    pub fn wait(self) -> Result<R, ResponseError> {
77        match self.receiver.recv() {
78            Ok(Ok(result)) => Ok(result),
79            Ok(Err(e)) => Err(e),
80            Err(_) => Err(ResponseError::ChannelClosed),
81        }
82    }
83}
84
85impl std::fmt::Display for ResponseError {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        match self {
88            ResponseError::Err(err) => write!(f, "received error: {:?}", err),
89            ResponseError::ChannelClosed => {
90                write!(f, "response not received, channel has been closed")
91            }
92            ResponseError::DeserializationError(err) => {
93                write!(f, "response has unexpected result type: {err}")
94            }
95        }
96    }
97}
98
99impl std::error::Error for ResponseError {}
100
101impl ConnectionSender {
102    /// Sends the request to the other side of the connection.
103    ///
104    /// Returns a ResponseHandle, that will return a response when received.
105    /// Note that for the response to be received, the ConnectionReceiver has to
106    /// be continuously looped over for new requests.
107    ///
108    /// If the connection is already closed, error is returned.
109    pub fn send<R: for<'a> Deserialize<'a> + 'static + Send>(
110        &self,
111        method: impl Into<String>,
112        params: impl Serialize,
113    ) -> Result<ResponseHandle<R>, SendError> {
114        let id = self.ids.next();
115
116        let method: String = method.into();
117        let shutdown = method == "shutdown";
118
119        self.sender
120            .send(Request::new(id.clone(), method, params).into())
121            .map_err(|_| SendError {})?;
122
123        let (tx, rx) = std::sync::mpsc::sync_channel(0);
124
125        let callback = ResponseCallback {
126            callback: Box::new(move |response: Response| {
127                let r: Result<R, ResponseError> = match response {
128                    Response::Ok { id: _, result } => {
129                        serde_json::from_value(result).map_err(ResponseError::DeserializationError)
130                    }
131                    Response::Err { id: _, error } => Err(ResponseError::Err(error)),
132                };
133                if tx.send(r).is_err() {
134                    log::debug!("response ignored, response handle was dropped");
135                }
136            }),
137            shutdown,
138        };
139
140        self.state.responses.lock().unwrap().insert(id, callback);
141        Ok(ResponseHandle { receiver: rx })
142    }
143
144    /// Sends shutdown request to the other side.
145    ///
146    /// No new requests are allowed to be send after this call.
147    pub fn shutdown(self) -> Result<ResponseHandle<serde_json::Value>, SendError> {
148        self.send("shutdown", json!({}))
149    }
150}
151
/// Receiving half of a connection, created by `new_connection`.
pub struct ConnectionReceiver {
    /// State shared with the `ConnectionSender` (pending response callbacks).
    state: Arc<ConnectionState>,
    /// Incoming messages from the transport.
    receiver: Receiver<Message>,
    /// Outgoing channel, cloned into each `ConnRequest` so it can reply.
    sender: SyncSender<Message>,
    /// Set once the response to this side's own `shutdown` request has been
    /// received; from then on `next_request` returns `None`.
    shutdown: Mutex<bool>,
}
158
159pub struct ConnRequest {
160    inner: Request,
161    sender: SyncSender<Message>,
162}
163
164impl ConnRequest {
165    pub fn inner(&self) -> &Request {
166        &self.inner
167    }
168
169    pub fn reply<R: Serialize>(
170        self,
171        response: Result<R, crate::types::Error>,
172    ) -> Result<(), SendError> {
173        match response {
174            Ok(result) => self.reply_ok(result),
175            Err(err) => self.reply_err(err),
176        }
177    }
178
179    pub fn reply_ok<R: Serialize>(self, result: R) -> Result<(), SendError> {
180        let response = Response::new_ok(self.inner.id, result);
181        self.sender
182            .send(Message::Response(response))
183            .map_err(|_| SendError {})
184    }
185    pub fn reply_err(self, err: crate::types::Error) -> Result<(), SendError> {
186        let response = Response::new_err(self.inner.id, err);
187        self.sender
188            .send(Message::Response(response))
189            .map_err(|_| SendError {})
190    }
191}
192
193impl ConnectionReceiver {
194    // Note: This has to be called / polled continuously to ensure the
195    // responses are populated
196    // returns None when the connection is closed
197    pub fn next_request(&self) -> Option<ConnRequest> {
198        if *self.shutdown.lock().unwrap() {
199            return None;
200        }
201        while let Ok(msg) = self.receiver.recv() {
202            match msg {
203                Message::Request(req) => {
204                    return Some(ConnRequest {
205                        inner: req,
206                        sender: self.sender.clone(),
207                    })
208                }
209                Message::Response(res) => {
210                    let mut r = self.state.responses.lock().unwrap();
211                    let Some(callback) = r.remove(res.id()) else {
212                        log::warn!(
213                            "Received response for id {:?}, but such request was never sent",
214                            res.id()
215                        );
216                        return None;
217                    };
218                    (callback.callback)(res);
219                    if callback.shutdown {
220                        let mut x = self.shutdown.lock().unwrap();
221                        *x = true;
222                        return None;
223                    }
224                }
225            }
226        }
227        None
228    }
229}
230
231pub fn new_connection(transport: Transport) -> (ConnectionSender, ConnectionReceiver) {
232    let state = Arc::new(ConnectionState::default());
233    (
234        ConnectionSender {
235            ids: Default::default(),
236            state: state.clone(),
237            sender: transport.sender.clone(),
238        },
239        ConnectionReceiver {
240            state,
241            receiver: transport.receiver,
242            sender: transport.sender,
243            shutdown: Default::default(),
244        },
245    )
246}
247
/// Message-level transport: a channel of incoming and a channel of outgoing
/// messages (as created by `Transport::raw`, each end is serviced by a thread).
pub struct Transport {
    /// Incoming messages (filled by the reader thread in `Transport::raw`).
    receiver: Receiver<Message>,
    /// Outgoing messages (drained by the writer thread in `Transport::raw`).
    sender: SyncSender<Message>,
}
252
/// Error returned when a message cannot be sent because the underlying channel
/// has been closed (its receiving end was dropped).
///
/// Carries no data, so it is cheap to copy and can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SendError {}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to send a message, the channel is already closed")
    }
}

impl std::error::Error for SendError {}
263
264pub struct JoinHandle {
265    read_join: std::thread::JoinHandle<()>,
266    write_join: std::thread::JoinHandle<()>,
267}
268
269impl JoinHandle {
270    pub fn join(self) -> anyhow::Result<()> {
271        self.read_join.join().unwrap();
272        self.write_join.join().unwrap();
273        Ok(())
274    }
275}
276
277impl Transport {
278    pub fn stdio() -> (Transport, JoinHandle) {
279        Self::raw(std::io::stdin(), std::io::stdout())
280    }
281
282    pub fn raw<R: Read + Send + 'static, W: Write + Send + 'static>(
283        read: R,
284        write: W,
285    ) -> (Transport, JoinHandle) {
286        let (read_tx, read_rx) = std::sync::mpsc::sync_channel(0);
287        let read_join = std::thread::spawn(move || {
288            if let Err(err) = read_loop(read, read_tx) {
289                log::error!("read_loop err: {err}");
290            }
291        });
292        let (write_tx, write_rx) = std::sync::mpsc::sync_channel(0);
293        let write_join = std::thread::spawn(move || {
294            if let Err(err) = write_loop(write, write_rx) {
295                log::error!("write_loop err: {err}");
296            }
297        });
298        (
299            Transport {
300                receiver: read_rx,
301                sender: write_tx,
302            },
303            JoinHandle {
304                read_join,
305                write_join,
306            },
307        )
308    }
309
310    pub fn send(&self, message: Message) -> Result<(), SendError> {
311        self.sender.send(message).map_err(|_| SendError {})
312    }
313
314    // TODO: should Iterator be used here?
315    // TODO: should Result be returned, to differentiate error from
316    // cleanly closed channel?
317    pub fn next_message(&self) -> Option<Message> {
318        self.receiver.recv().ok()
319    }
320}
321
322fn read_loop<R: Read>(read: R, sender: SyncSender<Message>) -> anyhow::Result<()> {
323    let reader = BufReader::new(read);
324    for line in reader.lines() {
325        let msg: Message = serde_json::from_str(&line?)?;
326        log::trace!("received: {:?}", msg);
327        sender.send(msg)?;
328    }
329    log::debug!("read_loop: finished");
330    Ok(())
331}
332
333fn write_loop<W: Write>(mut write: W, receiver: Receiver<Message>) -> anyhow::Result<()> {
334    loop {
335        let Ok(msg) = receiver.recv() else {
336            break;
337        };
338        log::trace!("sending: {:?}", msg);
339        let mut b = serde_json::to_vec(&msg)?;
340        b.push(b'\n');
341        write.write_all(&b)?;
342        write.flush()?;
343    }
344    log::debug!("write_loop: finished");
345    Ok(())
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;
    use serde_json::json;
    use std::{collections::VecDeque, io::Cursor, sync::mpsc::Sender};
    use test_log::test;

    /// Read half of an in-memory pipe created by [`pipe`].
    ///
    /// Blocks until the matching [`PipeWrite`] sends a chunk. Once the writer is
    /// dropped and all buffered bytes are consumed, reads return 0 (EOF).
    struct PipeRead {
        state: VecDeque<u8>,
        receiver: Receiver<Vec<u8>>,
    }

    impl Read for PipeRead {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            if self.state.is_empty() {
                // A disconnected channel leaves `state` empty, so the read below
                // returns 0, i.e. EOF.
                if let Ok(v) = self.receiver.recv() {
                    self.state.extend(&v);
                }
            }
            self.state.read(buf)
        }
    }

    /// Write half of an in-memory pipe. Bytes are buffered and handed to the
    /// reader on `flush` (and, best effort, on drop).
    struct PipeWrite {
        state: Vec<u8>,
        sender: Sender<Vec<u8>>,
    }

    impl Write for PipeWrite {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.state.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let val = std::mem::take(&mut self.state);
            if !val.is_empty() {
                self.sender
                    .send(val)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, e))?;
            }
            Ok(())
        }
    }

    impl Drop for PipeWrite {
        fn drop(&mut self) {
            let pending = std::mem::take(&mut self.state);
            if !pending.is_empty() {
                // Best effort: if the reader is already gone there is nobody to
                // deliver to, and panicking inside `drop` (possibly while already
                // unwinding) would abort the whole test binary.
                let _ = self.sender.send(pending);
            }
        }
    }

    /// Creates a connected in-memory (writer, reader) pipe.
    fn pipe() -> (PipeWrite, PipeRead) {
        let (tx, rx) = std::sync::mpsc::channel();
        (
            PipeWrite {
                state: Default::default(),
                sender: tx,
            },
            PipeRead {
                state: Default::default(),
                receiver: rx,
            },
        )
    }

    #[test(gtest)]
    fn reads_one_message() {
        let input =
            serde_json::to_vec(&json!({"id": "1", "method": "complete", "params":{}})).unwrap();
        let c = Cursor::new(input);
        let output: Vec<u8> = Vec::new();
        let (t, join_handles) = Transport::raw(c, output);
        expect_that!(t.next_message(), some(anything()));
        expect_that!(t.next_message(), none());
        join_handles.join().unwrap();
    }

    #[test(gtest)]
    fn writes_one_message() {
        let (pipe_w, mut pipe_r) = pipe();
        let c = Cursor::new(vec![]);
        let (t, join_handles) = Transport::raw(c, pipe_w);
        let response = Message::Response(Response::new_err(
            RequestId("1".into()),
            crate::types::Error::internal("test"),
        ));
        t.send(response.clone()).unwrap();
        // Drop, to ensure that the pipe is closed (otherwise below read_to_end will never finish).
        drop(t);
        let mut output = vec![];
        pipe_r.read_to_end(&mut output).unwrap();
        let mut expected = serde_json::to_vec(&response).unwrap();
        expected.push(b'\n');
        expect_that!(output, eq(&expected));
        join_handles.join().unwrap();
    }
}