1use devcade_onboard_types::{Request, RequestBody, Response, ResponseBody};
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::Arc;
5use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::net::UnixStream;
7use tokio::sync::{mpsc, oneshot, Mutex, OnceCell};
8
9pub struct BackendClient {
10 connection: OnceCell<SynchronizedConnection>,
11}
12
13type RequestSender = oneshot::Sender<Result<ResponseBody, RequestError>>;
14struct SynchronizedConnection {
15 requests_tx: mpsc::Sender<(RequestBody, RequestSender)>,
16}
17
18#[derive(Debug)]
19pub enum RequestError {
20 IoError(io::Error),
21 ResponseError(String),
22 UnexpectedResponse(ResponseBody),
23 ChannelClosed,
24}
25
26impl fmt::Display for RequestError {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 Self::IoError(err) => write!(f, "IoError({err})"),
30 Self::ResponseError(err) => write!(f, "ResponseError({err})"),
31 Self::UnexpectedResponse(response) => write!(f, "UnexpectedResponse({response})"),
32 Self::ChannelClosed => write!(f, "ChannelClosed"),
33 }
34 }
35}
36
37impl From<io::Error> for RequestError {
38 fn from(error: io::Error) -> Self {
39 Self::IoError(error)
40 }
41}
42
43impl Default for BackendClient {
44 fn default() -> Self {
45 Self {
46 connection: OnceCell::new(),
47 }
48 }
49}
50
51impl BackendClient {
64 async fn create_connection() -> Result<SynchronizedConnection, io::Error> {
65 let (connection_reader, mut connection_writer) = UnixStream::connect(
66 std::env::var("DEVCADE_ONBOARD_PATH").unwrap_or("/tmp/devcade/game.sock".to_owned()),
67 )
68 .await?
69 .into_split();
70 let (requests_tx, mut requests_rx) = mpsc::channel::<(RequestBody, RequestSender)>(100);
71 let listeners = Arc::new(Mutex::new(HashMap::<u32, RequestSender>::new()));
72 {
73 let listeners = listeners.clone();
74 tokio::spawn(async move {
75 let mut request_id_counter = 0;
76 while let Some((body, callback_tx)) = requests_rx.recv().await {
77 let mut listeners = listeners.lock().await;
78 while listeners.contains_key(&request_id_counter) {
79 request_id_counter = request_id_counter.wrapping_add(1);
80 }
81 let request_id = request_id_counter;
82 let request = Request { request_id, body };
83
84 let mut frame = serde_json::to_vec(&request).expect("Couldn't serialize RequestBody?");
85 frame.push(b'\n');
86 if let Err(err) = connection_writer.write_all(&frame).await {
87 if let Err(Err(err)) = callback_tx.send(Err(err.into())) {
88 log::error!("Couldn't send message to callback! Message we were asked to send was: {request:?}. Failed because {err}");
89 }
90 return;
91 }
92 listeners.insert(request_id, callback_tx);
93 }
94 });
95 }
96 tokio::spawn(async move {
97 let connection_reader = BufReader::new(connection_reader);
98 let mut lines = connection_reader.lines();
99 while let Ok(Some(line)) = lines.next_line().await {
100 let response: Response = match serde_json::from_str(&line) {
101 Ok(response) => response,
102 Err(err) => {
103 log::error!("Couldn't decode response ({line}) {err}");
104 continue;
105 }
106 };
107
108 let request_id = &response.request_id;
109 let mut listeners = listeners.lock().await;
110 let handler = match listeners.remove(request_id) {
111 Some(handler) => handler,
112 None => {
113 log::error!(
114 "Got response for request ID {request_id} that we weren't expecting! {response}"
115 );
116 continue;
117 }
118 };
119 std::mem::drop(listeners);
120
121 if handler
122 .send(match response.body {
123 ResponseBody::Err(err) => Err(RequestError::ResponseError(err)),
124 body => Ok(body),
125 })
126 .is_err()
127 {
128 log::error!("Failed to send response for {request_id} because the other side of the callback closed");
129 }
130 }
131 });
132 Ok(SynchronizedConnection { requests_tx })
133 }
134
135 async fn get_connection(&self) -> Result<&SynchronizedConnection, io::Error> {
136 self
137 .connection
138 .get_or_try_init(Self::create_connection)
139 .await
140 }
141
142 pub async fn send(&self, body: RequestBody) -> Result<ResponseBody, RequestError> {
147 let connection = self.get_connection().await?;
148 let (tx, rx) = oneshot::channel();
149 connection
150 .requests_tx
151 .send((body, tx))
152 .await
153 .map_err(|_| RequestError::ChannelClosed)?;
154 match rx.await.map_err(|_| RequestError::ChannelClosed) {
155 Ok(Ok(response)) => Ok(response),
156 Ok(Err(err)) | Err(err) => Err(err),
157 }
158 }
159}