1use crate::types::{Message, Request, RequestId, Response};
2use serde::{Deserialize, Serialize};
3use serde_json::json;
4use std::collections::HashMap;
5use std::io::{BufRead, BufReader, Read, Write};
6use std::sync::mpsc::{Receiver, SyncSender};
7use std::sync::{Arc, Mutex};
8
9#[derive(Default)]
11pub struct IdGenerator {
12 next: Mutex<i32>,
13}
14
15impl IdGenerator {
16 pub fn next(&self) -> RequestId {
17 let mut x = self.next.lock().unwrap();
18 let id = RequestId(format!("{}", x));
19 *x += 1;
20 id
21 }
22}
23
24struct ResponseCallback {
25 callback: Box<dyn FnOnce(Response) + Send + 'static>,
26 shutdown: bool,
27}
28
29#[derive(Default)]
31struct ConnectionState {
32 responses: Mutex<HashMap<RequestId, ResponseCallback>>,
33}
34
35#[derive(Clone)]
36pub struct ConnectionSender {
37 ids: Arc<IdGenerator>,
38 state: Arc<ConnectionState>,
39 sender: SyncSender<Message>,
40}
41
42pub struct ResponseHandle<R> {
62 receiver: Receiver<Result<R, ResponseError>>,
63}
64
65#[derive(Debug)]
66pub enum ResponseError {
67 Err(crate::types::Error),
69 ChannelClosed,
71 DeserializationError(serde_json::Error),
73}
74
75impl<R> ResponseHandle<R> {
76 pub fn wait(self) -> Result<R, ResponseError> {
77 match self.receiver.recv() {
78 Ok(Ok(result)) => Ok(result),
79 Ok(Err(e)) => Err(e),
80 Err(_) => Err(ResponseError::ChannelClosed),
81 }
82 }
83}
84
85impl std::fmt::Display for ResponseError {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 match self {
88 ResponseError::Err(err) => write!(f, "received error: {:?}", err),
89 ResponseError::ChannelClosed => {
90 write!(f, "response not received, channel has been closed")
91 }
92 ResponseError::DeserializationError(err) => {
93 write!(f, "response has unexpected result type: {err}")
94 }
95 }
96 }
97}
98
99impl std::error::Error for ResponseError {}
100
101impl ConnectionSender {
102 pub fn send<R: for<'a> Deserialize<'a> + 'static + Send>(
110 &self,
111 method: impl Into<String>,
112 params: impl Serialize,
113 ) -> Result<ResponseHandle<R>, SendError> {
114 let id = self.ids.next();
115
116 let method: String = method.into();
117 let shutdown = method == "shutdown";
118
119 self.sender
120 .send(Request::new(id.clone(), method, params).into())
121 .map_err(|_| SendError {})?;
122
123 let (tx, rx) = std::sync::mpsc::sync_channel(0);
124
125 let callback = ResponseCallback {
126 callback: Box::new(move |response: Response| {
127 let r: Result<R, ResponseError> = match response {
128 Response::Ok { id: _, result } => {
129 serde_json::from_value(result).map_err(ResponseError::DeserializationError)
130 }
131 Response::Err { id: _, error } => Err(ResponseError::Err(error)),
132 };
133 if tx.send(r).is_err() {
134 log::debug!("response ignored, response handle was dropped");
135 }
136 }),
137 shutdown,
138 };
139
140 self.state.responses.lock().unwrap().insert(id, callback);
141 Ok(ResponseHandle { receiver: rx })
142 }
143
144 pub fn shutdown(self) -> Result<ResponseHandle<serde_json::Value>, SendError> {
148 self.send("shutdown", json!({}))
149 }
150}
151
152pub struct ConnectionReceiver {
153 state: Arc<ConnectionState>,
154 receiver: Receiver<Message>,
155 sender: SyncSender<Message>,
156 shutdown: Mutex<bool>,
157}
158
159pub struct ConnRequest {
160 inner: Request,
161 sender: SyncSender<Message>,
162}
163
164impl ConnRequest {
165 pub fn inner(&self) -> &Request {
166 &self.inner
167 }
168
169 pub fn reply<R: Serialize>(
170 self,
171 response: Result<R, crate::types::Error>,
172 ) -> Result<(), SendError> {
173 match response {
174 Ok(result) => self.reply_ok(result),
175 Err(err) => self.reply_err(err),
176 }
177 }
178
179 pub fn reply_ok<R: Serialize>(self, result: R) -> Result<(), SendError> {
180 let response = Response::new_ok(self.inner.id, result);
181 self.sender
182 .send(Message::Response(response))
183 .map_err(|_| SendError {})
184 }
185 pub fn reply_err(self, err: crate::types::Error) -> Result<(), SendError> {
186 let response = Response::new_err(self.inner.id, err);
187 self.sender
188 .send(Message::Response(response))
189 .map_err(|_| SendError {})
190 }
191}
192
193impl ConnectionReceiver {
194 pub fn next_request(&self) -> Option<ConnRequest> {
198 if *self.shutdown.lock().unwrap() {
199 return None;
200 }
201 while let Ok(msg) = self.receiver.recv() {
202 match msg {
203 Message::Request(req) => {
204 return Some(ConnRequest {
205 inner: req,
206 sender: self.sender.clone(),
207 })
208 }
209 Message::Response(res) => {
210 let mut r = self.state.responses.lock().unwrap();
211 let Some(callback) = r.remove(res.id()) else {
212 log::warn!(
213 "Received response for id {:?}, but such request was never sent",
214 res.id()
215 );
216 return None;
217 };
218 (callback.callback)(res);
219 if callback.shutdown {
220 let mut x = self.shutdown.lock().unwrap();
221 *x = true;
222 return None;
223 }
224 }
225 }
226 }
227 None
228 }
229}
230
231pub fn new_connection(transport: Transport) -> (ConnectionSender, ConnectionReceiver) {
232 let state = Arc::new(ConnectionState::default());
233 (
234 ConnectionSender {
235 ids: Default::default(),
236 state: state.clone(),
237 sender: transport.sender.clone(),
238 },
239 ConnectionReceiver {
240 state,
241 receiver: transport.receiver,
242 sender: transport.sender,
243 shutdown: Default::default(),
244 },
245 )
246}
247
248pub struct Transport {
249 receiver: Receiver<Message>,
250 sender: SyncSender<Message>,
251}
252
/// Error returned when a message cannot be sent because the channel is closed.
#[derive(Debug)]
pub struct SendError {}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to send a message, the channel is already closed")
    }
}

impl std::error::Error for SendError {}
263
264pub struct JoinHandle {
265 read_join: std::thread::JoinHandle<()>,
266 write_join: std::thread::JoinHandle<()>,
267}
268
269impl JoinHandle {
270 pub fn join(self) -> anyhow::Result<()> {
271 self.read_join.join().unwrap();
272 self.write_join.join().unwrap();
273 Ok(())
274 }
275}
276
277impl Transport {
278 pub fn stdio() -> (Transport, JoinHandle) {
279 Self::raw(std::io::stdin(), std::io::stdout())
280 }
281
282 pub fn raw<R: Read + Send + 'static, W: Write + Send + 'static>(
283 read: R,
284 write: W,
285 ) -> (Transport, JoinHandle) {
286 let (read_tx, read_rx) = std::sync::mpsc::sync_channel(0);
287 let read_join = std::thread::spawn(move || {
288 if let Err(err) = read_loop(read, read_tx) {
289 log::error!("read_loop err: {err}");
290 }
291 });
292 let (write_tx, write_rx) = std::sync::mpsc::sync_channel(0);
293 let write_join = std::thread::spawn(move || {
294 if let Err(err) = write_loop(write, write_rx) {
295 log::error!("write_loop err: {err}");
296 }
297 });
298 (
299 Transport {
300 receiver: read_rx,
301 sender: write_tx,
302 },
303 JoinHandle {
304 read_join,
305 write_join,
306 },
307 )
308 }
309
310 pub fn send(&self, message: Message) -> Result<(), SendError> {
311 self.sender.send(message).map_err(|_| SendError {})
312 }
313
314 pub fn next_message(&self) -> Option<Message> {
318 self.receiver.recv().ok()
319 }
320}
321
322fn read_loop<R: Read>(read: R, sender: SyncSender<Message>) -> anyhow::Result<()> {
323 let reader = BufReader::new(read);
324 for line in reader.lines() {
325 let msg: Message = serde_json::from_str(&line?)?;
326 log::trace!("received: {:?}", msg);
327 sender.send(msg)?;
328 }
329 log::debug!("read_loop: finished");
330 Ok(())
331}
332
333fn write_loop<W: Write>(mut write: W, receiver: Receiver<Message>) -> anyhow::Result<()> {
334 loop {
335 let Ok(msg) = receiver.recv() else {
336 break;
337 };
338 log::trace!("sending: {:?}", msg);
339 let mut b = serde_json::to_vec(&msg)?;
340 b.push(b'\n');
341 write.write_all(&b)?;
342 write.flush()?;
343 }
344 log::debug!("write_loop: finished");
345 Ok(())
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;
    use serde_json::json;
    use std::{collections::VecDeque, io::Cursor, sync::mpsc::Sender};
    use test_log::test;

    /// Read half of an in-memory pipe: yields the byte chunks sent by the
    /// matching [`PipeWrite`].
    struct PipeRead {
        /// Bytes received but not yet handed to the reader.
        state: VecDeque<u8>,
        receiver: Receiver<Vec<u8>>,
    }

    impl Read for PipeRead {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            if self.state.is_empty() {
                // Block for the next chunk; when the writer is gone nothing is
                // added and the read below reports end of input (0 bytes).
                if let Ok(chunk) = self.receiver.recv() {
                    self.state.extend(&chunk);
                }
            }
            self.state.read(buf)
        }
    }

    /// Write half of an in-memory pipe: buffers written bytes and ships them
    /// as one chunk on flush (or on drop).
    struct PipeWrite {
        /// Bytes written since the last flush.
        state: Vec<u8>,
        sender: Sender<Vec<u8>>,
    }

    impl Write for PipeWrite {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.state.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            let pending = std::mem::take(&mut self.state);
            if !pending.is_empty() {
                self.sender
                    .send(pending)
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, e))?;
            }
            Ok(())
        }
    }

    impl Drop for PipeWrite {
        fn drop(&mut self) {
            let pending = std::mem::take(&mut self.state);
            if !pending.is_empty() {
                // Best effort only: panicking inside `drop` while another panic
                // is unwinding would abort the whole test process.
                let _ = self.sender.send(pending);
            }
        }
    }

    /// Creates a connected in-memory writer/reader pair.
    fn pipe() -> (PipeWrite, PipeRead) {
        let (tx, rx) = std::sync::mpsc::channel();
        (
            PipeWrite {
                state: Default::default(),
                sender: tx,
            },
            PipeRead {
                state: Default::default(),
                receiver: rx,
            },
        )
    }

    #[test(gtest)]
    fn reads_one_message() {
        let input =
            serde_json::to_vec(&json!({"id": "1", "method": "complete", "params":{}})).unwrap();
        let c = Cursor::new(input);
        let output: Vec<u8> = Vec::new();
        let (t, join_handles) = Transport::raw(c, output);
        expect_that!(t.next_message(), some(anything()));
        expect_that!(t.next_message(), none());
        // The write loop only terminates once every sender of its channel is
        // gone; the transport owns that sender, so it must be dropped before
        // joining or the join would block forever.
        drop(t);
        join_handles.join().unwrap();
    }

    #[test(gtest)]
    fn writes_one_message() {
        let (pipe_w, mut pipe_r) = pipe();
        let c = Cursor::new(vec![]);
        let (t, join_handles) = Transport::raw(c, pipe_w);
        let response = Message::Response(Response::new_err(
            RequestId("1".into()),
            crate::types::Error::internal("test"),
        ));
        t.send(response.clone()).unwrap();
        // Dropping the transport closes the outgoing channel, which lets the
        // write loop finish and the pipe reach end of input.
        drop(t);
        let mut output = vec![];
        pipe_r.read_to_end(&mut output).unwrap();
        let mut expected = serde_json::to_vec(&response).unwrap();
        expected.push(b'\n');
        expect_that!(output, eq(&expected));
        join_handles.join().unwrap();
    }
}