use anyhow::{Context, Result};
use futures_util::{SinkExt, StreamExt};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
use crate::{
protocol::{codec, to_client, to_server},
EncodingKind,
};
use super::{
DriverConnectArgs, DriverConnection, DriverHandle, DriverStopReason, MessageToClient,
MessageToServer,
};
/// Resolves the target actor, opens a WebSocket to it through the gateway, and spawns the
/// driver task that pumps messages over that socket.
///
/// Returns a tuple of:
/// - the [`DriverHandle`] wrapping the outgoing-message sender and the task's abort handle,
/// - the receiver on which decoded messages from the server are delivered,
/// - the spawned task's join handle.
///
/// # Errors
/// Propagates failures from resolving the actor id, and fails with
/// "Failed to connect to WebSocket via gateway" if the socket cannot be opened.
pub(crate) async fn connect(args: DriverConnectArgs) -> Result<DriverConnection> {
    let encoding_kind = args.encoding_kind;
    let manager = &args.remote_manager;

    let actor_id = manager.resolve_actor_id(&args.query).await?;
    debug!("Opening WebSocket connection to actor via gateway: {}", actor_id);

    let socket = manager
        .open_websocket(
            &actor_id,
            encoding_kind,
            args.parameters,
            args.conn_id,
            args.conn_token,
        )
        .await
        .context("Failed to connect to WebSocket via gateway")?;

    // One channel carries decoded server messages to the caller, the other carries
    // caller messages to the driver task for encoding and sending.
    let (to_client_tx, to_client_rx) = mpsc::unbounded_channel::<MessageToClient>();
    let (to_server_tx, to_server_rx) = mpsc::unbounded_channel::<MessageToServer>();

    let driver_task = tokio::spawn(start(socket, encoding_kind, to_client_tx, to_server_rx));
    let handle = DriverHandle::new(to_server_tx, driver_task.abort_handle());

    Ok((handle, to_client_rx, driver_task))
}
/// Drives one WebSocket connection until either side goes away.
///
/// Loops over two sources with `tokio::select!`:
/// - `out_rx`: messages queued by the user. Each one is encoded for `encoding_kind` and written
///   to the socket. If encoding or writing fails, the error is logged, the message is dropped,
///   and the loop keeps running.
/// - the socket's read half: text and binary frames are decoded for `encoding_kind` and forwarded
///   to `in_tx` wrapped in an `Arc`. Frames that fail to decode are logged and skipped. Other
///   non-close frames (ping/pong/raw) are logged and ignored.
///
/// Returns why the loop stopped:
/// - [`DriverStopReason::UserAborted`] when every sender for `out_rx` is dropped, or when the
///   receiver for `in_tx` is gone.
/// - [`DriverStopReason::ServerDisconnect`] when the socket stream ends or a Close frame arrives.
/// - [`DriverStopReason::ServerError`] when reading from the socket yields an error.
async fn start(
    ws: tokio_tungstenite::WebSocketStream<
        tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
    >,
    encoding_kind: EncodingKind,
    in_tx: mpsc::UnboundedSender<MessageToClient>,
    mut out_rx: mpsc::UnboundedReceiver<MessageToServer>,
) -> DriverStopReason {
    let (mut ws_sink, mut ws_stream) = ws.split();
    let serialize = get_msg_serializer(encoding_kind);
    let deserialize = get_msg_deserializer(encoding_kind);

    loop {
        tokio::select! {
            msg = out_rx.recv() => {
                let Some(msg) = msg else {
                    debug!("Sender dropped");
                    return DriverStopReason::UserAborted;
                };
                let msg = match serialize(&msg) {
                    Ok(msg) => msg,
                    Err(e) => {
                        debug!("Failed to serialize message: {:?}", e);
                        continue;
                    }
                };
                // Best-effort: a failed write drops this message only. If the connection is
                // really gone, the read branch reports it on a later iteration.
                if let Err(e) = ws_sink.send(msg).await {
                    debug!("Failed to send message: {:?}", e);
                }
            },
            msg = ws_stream.next() => {
                match msg {
                    None => {
                        debug!("WebSocket stream ended");
                        return DriverStopReason::ServerDisconnect;
                    }
                    Some(Ok(msg @ (Message::Text(_) | Message::Binary(_)))) => {
                        let decoded = match deserialize(&msg) {
                            Ok(decoded) => decoded,
                            Err(e) => {
                                // Keep the decode error so the log says why parsing failed.
                                debug!("Failed to parse message {:?}: {:?}", msg, e);
                                continue;
                            }
                        };
                        if let Err(e) = in_tx.send(Arc::new(decoded)) {
                            // The receiving side is gone, so nobody can consume messages.
                            debug!("Failed to forward message to client: {}", e);
                            return DriverStopReason::UserAborted;
                        }
                    }
                    Some(Ok(Message::Close(frame))) => {
                        debug!("Close message: {:?}", frame);
                        return DriverStopReason::ServerDisconnect;
                    }
                    Some(Ok(other)) => {
                        // Ping/Pong/raw frames carry no application payload. Per the tungstenite
                        // docs, pong replies to pings are queued by the library itself.
                        debug!("Ignoring non-data WebSocket message: {:?}", other);
                    }
                    Some(Err(e)) => {
                        debug!("WebSocket error: {}", e);
                        return DriverStopReason::ServerError;
                    }
                }
            }
        }
    }
}
/// Picks the frame decoder that matches `encoding_kind`.
///
/// The returned function pointer turns an incoming WebSocket [`Message`] into a
/// `to_client::ToClient`.
fn get_msg_deserializer(
    encoding_kind: EncodingKind,
) -> fn(&Message) -> Result<to_client::ToClient> {
    match encoding_kind {
        EncodingKind::Bare => bare_msg_deserialize,
        EncodingKind::Cbor => cbor_msg_deserialize,
        EncodingKind::Json => json_msg_deserialize,
    }
}
/// Picks the frame encoder that matches `encoding_kind`.
///
/// The returned function pointer turns a `to_server::ToServer` into the WebSocket [`Message`]
/// to send.
fn get_msg_serializer(encoding_kind: EncodingKind) -> fn(&to_server::ToServer) -> Result<Message> {
    match encoding_kind {
        EncodingKind::Bare => bare_msg_serialize,
        EncodingKind::Cbor => cbor_msg_serialize,
        EncodingKind::Json => json_msg_serialize,
    }
}
/// Decodes a JSON-encoded server frame.
///
/// Text frames are decoded from their UTF-8 bytes and binary frames from their raw bytes.
/// Any other frame kind is rejected with "Invalid message type".
fn json_msg_deserialize(value: &Message) -> Result<to_client::ToClient> {
    let payload: &[u8] = match value {
        Message::Text(text) => text.as_bytes(),
        Message::Binary(bin) => bin,
        _ => return Err(anyhow::anyhow!("Invalid message type")),
    };
    codec::decode_to_client(EncodingKind::Json, payload)
}
/// Decodes a CBOR-encoded server frame.
///
/// Binary frames are decoded from their raw bytes. Text frames are also accepted and decoded
/// from their UTF-8 bytes. Any other frame kind is rejected with "Invalid message type".
fn cbor_msg_deserialize(value: &Message) -> Result<to_client::ToClient> {
    let payload: &[u8] = match value {
        Message::Binary(bin) => bin,
        Message::Text(text) => text.as_bytes(),
        _ => return Err(anyhow::anyhow!("Invalid message type")),
    };
    codec::decode_to_client(EncodingKind::Cbor, payload)
}
/// Encodes `value` as JSON and wraps it in a text frame.
///
/// # Errors
/// Fails if encoding fails or if the encoded bytes are not valid UTF-8.
fn json_msg_serialize(value: &to_server::ToServer) -> Result<Message> {
    let encoded = codec::encode_to_server(EncodingKind::Json, value)?;
    let text = String::from_utf8(encoded)?;
    Ok(Message::Text(text.into()))
}
/// Encodes `value` as CBOR and wraps it in a binary frame.
///
/// # Errors
/// Propagates any encoding failure from `codec::encode_to_server`.
fn cbor_msg_serialize(value: &to_server::ToServer) -> Result<Message> {
    let encoded = codec::encode_to_server(EncodingKind::Cbor, value)?;
    Ok(Message::Binary(encoded.into()))
}
/// Decodes a BARE-encoded server frame.
///
/// Binary frames are decoded from their raw bytes. Text frames are also accepted and decoded
/// from their UTF-8 bytes. Any other frame kind is rejected with "Invalid message type".
fn bare_msg_deserialize(value: &Message) -> Result<to_client::ToClient> {
    let payload: &[u8] = match value {
        Message::Binary(bin) => bin,
        Message::Text(text) => text.as_bytes(),
        _ => return Err(anyhow::anyhow!("Invalid message type")),
    };
    codec::decode_to_client(EncodingKind::Bare, payload)
}
/// Encodes `value` with BARE and wraps it in a binary frame.
///
/// # Errors
/// Propagates any encoding failure from `codec::encode_to_server`.
fn bare_msg_serialize(value: &to_server::ToServer) -> Result<Message> {
    let encoded = codec::encode_to_server(EncodingKind::Bare, value)?;
    Ok(Message::Binary(encoded.into()))
}