1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use devcade_onboard_types::{Request, RequestBody, Response, ResponseBody};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixStream;
use tokio::sync::{mpsc, oneshot, Mutex, OnceCell};

/// Handle to the devcade onboard backend.
pub struct BackendClient {
  // Backend connection, opened on first use by `get_connection` and then shared
  // by every request made through this client.
  connection: OnceCell<SynchronizedConnection>,
}

/// One-shot channel on which the connection's background tasks deliver the
/// outcome (response body or error) of a single request.
type RequestSender = oneshot::Sender<Result<ResponseBody, RequestError>>;
/// Handle to an open backend connection.
struct SynchronizedConnection {
  // Queue feeding the background writer task: each request body is paired with
  // the sender its eventual response should be delivered on.
  requests_tx: mpsc::Sender<(RequestBody, RequestSender)>,
}

/// Errors that can occur while sending a request to the backend.
#[derive(Debug)]
pub enum RequestError {
  /// Connecting to, or writing to, the backend socket failed.
  IoError(io::Error),
  /// The backend answered with `ResponseBody::Err`; holds the backend's message.
  ResponseError(String),
  /// A response body of a kind the caller did not expect.
  /// NOTE(review): not constructed in this file — presumably produced by callers
  /// that match on the returned `ResponseBody`; confirm.
  UnexpectedResponse(ResponseBody),
  /// The internal channel to the connection's background tasks closed before a
  /// response was delivered.
  ChannelClosed,
}

impl fmt::Display for RequestError {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    match self {
      Self::IoError(err) => write!(f, "IoError({err})"),
      Self::ResponseError(err) => write!(f, "ResponseError({err})"),
      Self::UnexpectedResponse(response) => write!(f, "UnexpectedResponse({response})"),
      Self::ChannelClosed => write!(f, "ChannelClosed"),
    }
  }
}

impl From<io::Error> for RequestError {
  fn from(error: io::Error) -> Self {
    Self::IoError(error)
  }
}

impl Default for BackendClient {
  fn default() -> Self {
    Self {
      connection: OnceCell::new(),
    }
  }
}

/// Client for the devcade backend;
/// Allows you to send requests to the backend and get their responses.
///
/// This struct represents an underlying connection to the devcade backend, so
/// try not to make more than one.
///
/// # Example
/// ```
/// let backend_client: BackendClient = Default::default();
/// let pong = backend_client.send(RequestBody::Ping).await.unwrap();
/// println!("Pong! {pong}");
/// ```
impl BackendClient {
  async fn create_connection() -> Result<SynchronizedConnection, io::Error> {
    let (connection_reader, mut connection_writer) = UnixStream::connect(
      std::env::var("DEVCADE_ONBOARD_PATH").unwrap_or("/tmp/devcade/onboard.sock".to_owned()),
    )
    .await?
    .into_split();
    let (requests_tx, mut requests_rx) = mpsc::channel::<(RequestBody, RequestSender)>(100);
    let listeners = Arc::new(Mutex::new(HashMap::<u32, RequestSender>::new()));
    {
      let listeners = listeners.clone();
      tokio::spawn(async move {
        let mut request_id_counter = 0;
        while let Some((body, callback_tx)) = requests_rx.recv().await {
          let mut listeners = listeners.lock().await;
          while listeners.contains_key(&request_id_counter) {
            request_id_counter = request_id_counter.wrapping_add(1);
          }
          let request_id = request_id_counter;
          let request = Request { request_id, body };

          let mut frame = serde_json::to_vec(&request).expect("Couldn't serialize RequestBody?");
          frame.push(b'\n');
          if let Err(err) = connection_writer.write_all(&frame).await {
            if let Err(Err(err)) = callback_tx.send(Err(err.into())) {
              log::error!("Couldn't send message to callback! Message we were asked to send was: {request:?}. Failed because {err}");
            }
            return;
          }
          listeners.insert(request_id, callback_tx);
        }
      });
    }
    tokio::spawn(async move {
      let connection_reader = BufReader::new(connection_reader);
      let mut lines = connection_reader.lines();
      while let Ok(Some(line)) = lines.next_line().await {
        let response: Response = match serde_json::from_str(&line) {
          Ok(response) => response,
          Err(err) => {
            log::error!("Couldn't decode response ({line}) {err}");
            continue;
          }
        };

        let request_id = &response.request_id;
        let mut listeners = listeners.lock().await;
        let handler = match listeners.remove(request_id) {
          Some(handler) => handler,
          None => {
            log::error!(
              "Got response for request ID {request_id} that we weren't expecting! {response}"
            );
            continue;
          }
        };
        std::mem::drop(listeners);

        if handler
          .send(match response.body {
            ResponseBody::Err(err) => Err(RequestError::ResponseError(err)),
            body => Ok(body),
          })
          .is_err()
        {
          log::error!("Failed to send response for {request_id} because the other side of the callback closed");
        }
      }
    });
    Ok(SynchronizedConnection { requests_tx })
  }

  async fn get_connection(&self) -> Result<&SynchronizedConnection, io::Error> {
    self
      .connection
      .get_or_try_init(Self::create_connection)
      .await
  }

  /// Sends a request to the backend and returns the corresponding response.
  /// If the response is [`ResponseBody::Err`],
  /// a [`RequestError::ResponseError`] is returned instead with the error
  /// message.
  pub async fn send(&self, body: RequestBody) -> Result<ResponseBody, RequestError> {
    let connection = self.get_connection().await?;
    let (tx, rx) = oneshot::channel();
    connection
      .requests_tx
      .send((body, tx))
      .await
      .map_err(|_| RequestError::ChannelClosed)?;
    match rx.await.map_err(|_| RequestError::ChannelClosed) {
      Ok(Ok(response)) => Ok(response),
      Ok(Err(err)) | Err(err) => Err(err),
    }
  }
}