actor_core_client/drivers/
ws.rs1use anyhow::{Context, Result};
2use futures_util::{SinkExt, StreamExt};
3use serde_json::Value;
4use std::sync::Arc;
5use tokio::net::TcpStream;
6use tokio::sync::mpsc;
7use tokio::task::JoinHandle;
8use tokio_tungstenite::tungstenite::Message;
9use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
10use tracing::debug;
11
12use crate::encoding::EncodingKind;
13use crate::protocol::{ToClient, ToServer};
14
15use super::{
16 build_conn_url, DriverHandle, DriverStopReason, MessageToClient, MessageToServer, TransportKind,
17};
18
19pub(crate) async fn connect(
20 endpoint: String,
21 encoding_kind: EncodingKind,
22 parameters: &Option<Value>,
23) -> Result<(
24 DriverHandle,
25 mpsc::Receiver<MessageToClient>,
26 JoinHandle<DriverStopReason>,
27)> {
28 let url = build_conn_url(
29 &endpoint,
30 &TransportKind::WebSocket,
31 encoding_kind,
32 parameters,
33 )?;
34
35 let (ws, _res) = tokio_tungstenite::connect_async(url)
36 .await
37 .context("Failed to connect to WebSocket")?;
38
39 let (in_tx, in_rx) = mpsc::channel::<MessageToClient>(32);
40 let (out_tx, out_rx) = mpsc::channel::<MessageToServer>(32);
41 let task = tokio::spawn(start(ws, encoding_kind, in_tx, out_rx));
42
43 let handle = DriverHandle::new(out_tx, task.abort_handle());
44
45 Ok((handle, in_rx, task))
46}
47
48async fn start(
49 ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
50 encoding_kind: EncodingKind,
51 in_tx: mpsc::Sender<MessageToClient>,
52 mut out_rx: mpsc::Receiver<MessageToServer>,
53) -> DriverStopReason {
54 let (mut ws_sink, mut ws_stream) = ws.split();
55
56 let serialize = get_msg_serializer(encoding_kind);
57 let deserialize = get_msg_deserializer(encoding_kind);
58
59 loop {
60 tokio::select! {
61 msg = out_rx.recv() => {
63 let Some(msg) = msg else {
65 debug!("Sender dropped");
66 return DriverStopReason::UserAborted;
67 };
68
69 let msg = match serialize(&msg) {
70 Ok(msg) => msg,
71 Err(e) => {
72 debug!("Failed to serialize message: {:?}", e);
73 continue;
74 }
75 };
76
77 if let Err(e) = ws_sink.send(msg).await {
78 debug!("Failed to send message: {:?}", e);
79 continue;
80 }
81 },
82 msg = ws_stream.next() => {
84 let Some(msg) = msg else {
85 println!("Receiver dropped");
86 return DriverStopReason::ServerDisconnect;
87 };
88
89 match msg {
90 Ok(msg) => match msg {
91 Message::Text(_) | Message::Binary(_) => {
92 let Ok(msg) = deserialize(&msg) else {
93 debug!("Failed to parse message: {:?}", msg);
94 continue;
95 };
96
97 if let Err(e) = in_tx.send(Arc::new(msg)).await {
98 debug!("Failed to send text message: {}", e);
99 return DriverStopReason::UserAborted;
101 }
102 },
103 Message::Close(_) => {
104 debug!("Close message");
105 return DriverStopReason::ServerDisconnect;
106 },
107 _ => {
108 debug!("Invalid message type received");
109 }
110 }
111 Err(e) => {
112 debug!("WebSocket error: {}", e);
113 return DriverStopReason::ServerError;
114 }
115 }
116 }
117 }
118 }
119}
120
121fn get_msg_deserializer(encoding_kind: EncodingKind) -> fn(&Message) -> Result<ToClient> {
122 match encoding_kind {
123 EncodingKind::Json => json_msg_deserialize,
124 EncodingKind::Cbor => cbor_msg_deserialize,
125 }
126}
127
128fn get_msg_serializer(encoding_kind: EncodingKind) -> fn(&ToServer) -> Result<Message> {
129 match encoding_kind {
130 EncodingKind::Json => json_msg_serialize,
131 EncodingKind::Cbor => cbor_msg_serialize,
132 }
133}
134
135fn json_msg_deserialize(value: &Message) -> Result<ToClient> {
136 match value {
137 Message::Text(text) => Ok(serde_json::from_str(text)?),
138 Message::Binary(bin) => Ok(serde_json::from_slice(bin)?),
139 _ => Err(anyhow::anyhow!("Invalid message type")),
140 }
141}
142
143fn cbor_msg_deserialize(value: &Message) -> Result<ToClient> {
144 match value {
145 Message::Binary(bin) => Ok(serde_cbor::from_slice(bin)?),
146 Message::Text(text) => Ok(serde_cbor::from_slice(text.as_bytes())?),
147 _ => Err(anyhow::anyhow!("Invalid message type")),
148 }
149}
150
151fn json_msg_serialize(value: &ToServer) -> Result<Message> {
152 Ok(Message::Text(serde_json::to_string(value)?.into()))
153}
154
155fn cbor_msg_serialize(value: &ToServer) -> Result<Message> {
156 Ok(Message::Binary(serde_cbor::to_vec(value)?.into()))
157}