actor_core_client/drivers/
ws.rs

1use anyhow::{Context, Result};
2use futures_util::{SinkExt, StreamExt};
3use serde_json::Value;
4use std::sync::Arc;
5use tokio::net::TcpStream;
6use tokio::sync::mpsc;
7use tokio::task::JoinHandle;
8use tokio_tungstenite::tungstenite::Message;
9use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
10use tracing::debug;
11
12use crate::encoding::EncodingKind;
13use crate::protocol::{ToClient, ToServer};
14
15use super::{
16    build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind,
17};
18
19pub(crate) async fn connect(
20    endpoint: String,
21    encoding_kind: EncodingKind,
22    parameters: &Option<Value>,
23) -> Result<(
24    DriverHandle,
25    mpsc::Receiver<MessageToClient>,
26    JoinHandle<DriverStopReason>,
27)> {
28    let url = build_conn_url(
29        &endpoint,
30        &TransportKind::WebSocket,
31        encoding_kind,
32        parameters,
33    )?;
34
35    let (ws, _res) = tokio_tungstenite::connect_async(url)
36        .await
37        .context("Failed to connect to WebSocket")?;
38
39    let (in_tx, in_rx) = mpsc::channel::<MessageToClient>(32);
40    let (out_tx, out_rx) = mpsc::channel::<MessageToServer>(32);
41    let task = tokio::spawn(start(ws, encoding_kind, in_tx, out_rx));
42
43    let handle = DriverHandle::new(out_tx, task.abort_handle());
44
45    Ok((handle, in_rx, task))
46}
47
48async fn start(
49    ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
50    encoding_kind: EncodingKind,
51    in_tx: mpsc::Sender<MessageToClient>,
52    mut out_rx: mpsc::Receiver<MessageToServer>,
53) -> DriverStopReason {
54    let (mut ws_sink, mut ws_stream) = ws.split();
55
56    let serialize = get_msg_serializer(encoding_kind);
57    let deserialize = get_msg_deserializer(encoding_kind);
58
59    loop {
60        tokio::select! {
61            // Dispatch ws outgoing queue
62            msg = out_rx.recv() => {
63                // If the sender is dropped, break the loop
64                let Some(msg) = msg else {
65                    debug!("Sender dropped");
66                    return DriverStopReason::UserAborted;
67                };
68
69                let msg = match serialize(&msg) {
70                    Ok(msg) => msg,
71                    Err(e) => {
72                        debug!("Failed to serialize message: {:?}", e);
73                        continue;
74                    }
75                };
76
77                if let Err(e) = ws_sink.send(msg).await {
78                    debug!("Failed to send message: {:?}", e);
79                    continue;
80                }
81            },
82            // Handle ws incoming
83            msg = ws_stream.next() => {
84                let Some(msg) = msg else {
85                    println!("Receiver dropped");
86                    return DriverStopReason::ServerDisconnect;
87                };
88
89                match msg {
90                    Ok(msg) => match msg {
91                        Message::Text(_) | Message::Binary(_) => {
92                            let Ok(msg) = deserialize(&msg) else {
93                                debug!("Failed to parse message: {:?}", msg);
94                                continue;
95                            };
96
97                            if let Err(e) = in_tx.send(Arc::new(msg)).await {
98                                debug!("Failed to send text message: {}", e);
99                                // failure to send means user dropped incoming receiver
100                                return DriverStopReason::UserAborted;
101                            }
102                        },
103                        Message::Close(_) => {
104                            debug!("Close message");
105                            return DriverStopReason::ServerDisconnect;
106                        },
107                        _ => {
108                            debug!("Invalid message type received");
109                        }
110                    }
111                    Err(e) => {
112                        debug!("WebSocket error: {}", e);
113                        return DriverStopReason::ServerError;
114                    }
115                }
116            }
117        }
118    }
119}
120
121fn get_msg_deserializer(encoding_kind: EncodingKind) -> fn(&Message) -> Result<ToClient> {
122    match encoding_kind {
123        EncodingKind::Json => json_msg_deserialize,
124        EncodingKind::Cbor => cbor_msg_deserialize,
125    }
126}
127
128fn get_msg_serializer(encoding_kind: EncodingKind) -> fn(&ToServer) -> Result<Message> {
129    match encoding_kind {
130        EncodingKind::Json => json_msg_serialize,
131        EncodingKind::Cbor => cbor_msg_serialize,
132    }
133}
134
135fn json_msg_deserialize(value: &Message) -> Result<ToClient> {
136    match value {
137        Message::Text(text) => Ok(serde_json::from_str(text)?),
138        Message::Binary(bin) => Ok(serde_json::from_slice(bin)?),
139        _ => Err(anyhow::anyhow!("Invalid message type")),
140    }
141}
142
143fn cbor_msg_deserialize(value: &Message) -> Result<ToClient> {
144    match value {
145        Message::Binary(bin) => Ok(serde_cbor::from_slice(bin)?),
146        Message::Text(text) => Ok(serde_cbor::from_slice(text.as_bytes())?),
147        _ => Err(anyhow::anyhow!("Invalid message type")),
148    }
149}
150
151fn json_msg_serialize(value: &ToServer) -> Result<Message> {
152    Ok(Message::Text(serde_json::to_string(value)?.into()))
153}
154
155fn cbor_msg_serialize(value: &ToServer) -> Result<Message> {
156    Ok(Message::Binary(serde_cbor::to_vec(value)?.into()))
157}