plane_common/typed_socket/
server.rs1use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
2use crate::version::plane_version_info;
3use axum::extract::ws::{CloseFrame, Message, WebSocket};
4use tokio::sync::mpsc::{Receiver, Sender};
5
6#[derive(Debug, thiserror::Error)]
7pub enum Error {
8 #[error("Handshake message was not text.")]
9 HandshakeNotText,
10
11 #[error("Socket closed before handshake received.")]
12 SocketClosedBeforeHandshake,
13
14 #[error("Failed to parse message.")]
15 ParseMessage(#[from] serde_json::Error),
16
17 #[error("Failed to send message on websocket.")]
18 SendMessage(#[from] axum::Error),
19}
20
21pub async fn handle_messages<T: ChannelMessage>(
22 mut messages_to_send: Receiver<SocketAction<T>>,
23 messages_received: Sender<T::Reply>,
24 mut socket: WebSocket,
25) {
26 loop {
27 tokio::select! {
28 Some(msg) = messages_to_send.recv() => {
29 match msg {
30 SocketAction::Send(msg) => {
31 let msg = Message::Text(serde_json::to_string(&msg).expect("Always serializable."));
32 if let Err(err) = socket.send(msg.clone()).await {
33 tracing::error!(?err, message=?msg, "Failed to send message on websocket.");
34 }
35 }
36 SocketAction::Close => {
37 if let Err(err) = socket.close().await {
38 tracing::error!(?err, "Failed to close websocket.");
39 }
40 break;
41 }
42 }
43 }
44 Some(msg) = socket.recv() => {
45 let msg = match msg {
46 Ok(Message::Text(msg)) => msg,
47 Err(err) => {
48 tracing::error!(?err, "Failed to receive message from websocket.");
49 break;
50 }
51 Ok(Message::Close(Some(CloseFrame { code: 1001, .. }))) => {
52 tracing::warn!("Websocket connection closed.");
53 break;
54 }
55 msg => {
56 tracing::warn!("Received ignored message: {:?}", msg);
57 continue;
58 }
59 };
60 let msg: T::Reply = match serde_json::from_str(&msg) {
61 Ok(msg) => msg,
62 Err(err) => {
63 tracing::warn!(?err, "Failed to parse message.");
64 continue;
65 }
66 };
67 if let Err(err) = messages_received.send(msg).await {
68 tracing::error!(?err, "Failed to receive message.");
69 break;
70 }
71 }
72 else => {
73 break;
74 }
75 }
76 }
77}
78
79pub async fn new_server<T: ChannelMessage>(
80 mut ws: WebSocket,
81 name: String,
82) -> Result<TypedSocket<T>, Error> {
83 let msg = ws
84 .recv()
85 .await
86 .ok_or(Error::SocketClosedBeforeHandshake)?
87 .map_err(Error::from)?;
88 let msg = match msg {
89 Message::Text(msg) => msg,
90 msg => {
91 tracing::warn!("Received ignored message: {:?}", msg);
92 return Err(Error::HandshakeNotText);
93 }
94 };
95 let remote_handshake: Handshake = serde_json::from_str(&msg).map_err(Error::ParseMessage)?;
96 tracing::info!(
97 client_version = %remote_handshake.version.version,
98 client_hash = %remote_handshake.version.git_hash,
99 client_name = %remote_handshake.name,
100 "Client connected"
101 );
102
103 let local_handshake = Handshake {
104 version: plane_version_info(),
105 name,
106 };
107 ws.send(Message::Text(serde_json::to_string(&local_handshake)?))
108 .await?;
109
110 local_handshake.check_compat(&remote_handshake);
111
112 let (outgoing_message_sender, outgoing_message_receiver) = tokio::sync::mpsc::channel(100);
113 let (incoming_message_sender, incoming_message_receiver) = tokio::sync::mpsc::channel(100);
114 tokio::spawn(async move {
115 handle_messages(outgoing_message_receiver, incoming_message_sender, ws).await;
116 });
117
118 Ok(TypedSocket {
119 send: outgoing_message_sender,
120 recv: incoming_message_receiver,
121 remote_handshake,
122 })
123}