devcaders/
client.rs

1use devcade_onboard_types::{Request, RequestBody, Response, ResponseBody};
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::Arc;
5use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::UnixStream;
7use tokio::sync::{mpsc, oneshot, Mutex, OnceCell};
8
9pub struct BackendClient {
10  connection: OnceCell<SynchronizedConnection>,
11}
12
13type RequestSender = oneshot::Sender<Result<ResponseBody, RequestError>>;
14struct SynchronizedConnection {
15  requests_tx: mpsc::Sender<(RequestBody, RequestSender)>,
16}
17
18#[derive(Debug)]
19pub enum RequestError {
20  IoError(io::Error),
21  ResponseError(String),
22  UnexpectedResponse(ResponseBody),
23  ChannelClosed,
24}
25
26impl fmt::Display for RequestError {
27  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28    match self {
29      Self::IoError(err) => write!(f, "IoError({err})"),
30      Self::ResponseError(err) => write!(f, "ResponseError({err})"),
31      Self::UnexpectedResponse(response) => write!(f, "UnexpectedResponse({response})"),
32      Self::ChannelClosed => write!(f, "ChannelClosed"),
33    }
34  }
35}
36
37impl From<io::Error> for RequestError {
38  fn from(error: io::Error) -> Self {
39    Self::IoError(error)
40  }
41}
42
43impl Default for BackendClient {
44  fn default() -> Self {
45    Self {
46      connection: OnceCell::new(),
47    }
48  }
49}
50
51/// Client for the devcade backend;
52/// Allows you to send requests to the backend and get their responses.
53///
54/// This struct represents an underlying connection to the devcade backend, so
55/// try not to make more than one.
56///
57/// # Example
58/// ```
59/// let backend_client: BackendClient = Default::default();
60/// let pong = backend_client.send(RequestBody::Ping).await.unwrap();
61/// println!("Pong! {pong}");
62/// ```
63impl BackendClient {
64  async fn create_connection() -> Result<SynchronizedConnection, io::Error> {
65    let (connection_reader, mut connection_writer) = UnixStream::connect(
66      std::env::var("DEVCADE_ONBOARD_PATH").unwrap_or("/tmp/devcade/game.sock".to_owned()),
67    )
68    .await?
69    .into_split();
70    let (requests_tx, mut requests_rx) = mpsc::channel::<(RequestBody, RequestSender)>(100);
71    let listeners = Arc::new(Mutex::new(HashMap::<u32, RequestSender>::new()));
72    {
73      let listeners = listeners.clone();
74      tokio::spawn(async move {
75        let mut request_id_counter = 0;
76        while let Some((body, callback_tx)) = requests_rx.recv().await {
77          let mut listeners = listeners.lock().await;
78          while listeners.contains_key(&request_id_counter) {
79            request_id_counter = request_id_counter.wrapping_add(1);
80          }
81          let request_id = request_id_counter;
82          let request = Request { request_id, body };
83
84          let mut frame = serde_json::to_vec(&request).expect("Couldn't serialize RequestBody?");
85          frame.push(b'\n');
86          if let Err(err) = connection_writer.write_all(&frame).await {
87            if let Err(Err(err)) = callback_tx.send(Err(err.into())) {
88              log::error!("Couldn't send message to callback! Message we were asked to send was: {request:?}. Failed because {err}");
89            }
90            return;
91          }
92          listeners.insert(request_id, callback_tx);
93        }
94      });
95    }
96    tokio::spawn(async move {
97      let connection_reader = BufReader::new(connection_reader);
98      let mut lines = connection_reader.lines();
99      while let Ok(Some(line)) = lines.next_line().await {
100        let response: Response = match serde_json::from_str(&line) {
101          Ok(response) => response,
102          Err(err) => {
103            log::error!("Couldn't decode response ({line}) {err}");
104            continue;
105          }
106        };
107
108        let request_id = &response.request_id;
109        let mut listeners = listeners.lock().await;
110        let handler = match listeners.remove(request_id) {
111          Some(handler) => handler,
112          None => {
113            log::error!(
114              "Got response for request ID {request_id} that we weren't expecting! {response}"
115            );
116            continue;
117          }
118        };
119        std::mem::drop(listeners);
120
121        if handler
122          .send(match response.body {
123            ResponseBody::Err(err) => Err(RequestError::ResponseError(err)),
124            body => Ok(body),
125          })
126          .is_err()
127        {
128          log::error!("Failed to send response for {request_id} because the other side of the callback closed");
129        }
130      }
131    });
132    Ok(SynchronizedConnection { requests_tx })
133  }
134
135  async fn get_connection(&self) -> Result<&SynchronizedConnection, io::Error> {
136    self
137      .connection
138      .get_or_try_init(Self::create_connection)
139      .await
140  }
141
142  /// Sends a request to the backend and returns the corresponding response.
143  /// If the response is [`ResponseBody::Err`],
144  /// a [`RequestError::ResponseError`] is returned instead with the error
145  /// message.
146  pub async fn send(&self, body: RequestBody) -> Result<ResponseBody, RequestError> {
147    let connection = self.get_connection().await?;
148    let (tx, rx) = oneshot::channel();
149    connection
150      .requests_tx
151      .send((body, tx))
152      .await
153      .map_err(|_| RequestError::ChannelClosed)?;
154    match rx.await.map_err(|_| RequestError::ChannelClosed) {
155      Ok(Ok(response)) => Ok(response),
156      Ok(Err(err)) | Err(err) => Err(err),
157    }
158  }
159}