plane_common/typed_socket/server.rs

1use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
2use crate::version::plane_version_info;
3use axum::extract::ws::{CloseFrame, Message, WebSocket};
4use tokio::sync::mpsc::{Receiver, Sender};
5
6#[derive(Debug, thiserror::Error)]
7pub enum Error {
8    #[error("Handshake message was not text.")]
9    HandshakeNotText,
10
11    #[error("Socket closed before handshake received.")]
12    SocketClosedBeforeHandshake,
13
14    #[error("Failed to parse message.")]
15    ParseMessage(#[from] serde_json::Error),
16
17    #[error("Failed to send message on websocket.")]
18    SendMessage(#[from] axum::Error),
19}
20
21pub async fn handle_messages<T: ChannelMessage>(
22    mut messages_to_send: Receiver<SocketAction<T>>,
23    messages_received: Sender<T::Reply>,
24    mut socket: WebSocket,
25) {
26    loop {
27        tokio::select! {
28            Some(msg) = messages_to_send.recv() => {
29                match msg {
30                    SocketAction::Send(msg) => {
31                        let msg = Message::Text(serde_json::to_string(&msg).expect("Always serializable."));
32                        if let Err(err) = socket.send(msg.clone()).await {
33                            tracing::error!(?err, message=?msg, "Failed to send message on websocket.");
34                        }
35                    }
36                    SocketAction::Close => {
37                        if let Err(err) = socket.close().await {
38                            tracing::error!(?err, "Failed to close websocket.");
39                        }
40                        break;
41                    }
42                }
43            }
44            Some(msg) = socket.recv() => {
45                let msg = match msg {
46                    Ok(Message::Text(msg)) => msg,
47                    Err(err) => {
48                        tracing::error!(?err, "Failed to receive message from websocket.");
49                        break;
50                    }
51                    Ok(Message::Close(Some(CloseFrame { code: 1001, .. }))) => {
52                        tracing::warn!("Websocket connection closed.");
53                        break;
54                    }
55                    msg => {
56                        tracing::warn!("Received ignored message: {:?}", msg);
57                        continue;
58                    }
59                };
60                let msg: T::Reply = match serde_json::from_str(&msg) {
61                    Ok(msg) => msg,
62                    Err(err) => {
63                        tracing::warn!(?err, "Failed to parse message.");
64                        continue;
65                    }
66                };
67                if let Err(err) = messages_received.send(msg).await {
68                    tracing::error!(?err, "Failed to receive message.");
69                    break;
70                }
71            }
72            else => {
73                break;
74            }
75        }
76    }
77}
78
79pub async fn new_server<T: ChannelMessage>(
80    mut ws: WebSocket,
81    name: String,
82) -> Result<TypedSocket<T>, Error> {
83    let msg = ws
84        .recv()
85        .await
86        .ok_or(Error::SocketClosedBeforeHandshake)?
87        .map_err(Error::from)?;
88    let msg = match msg {
89        Message::Text(msg) => msg,
90        msg => {
91            tracing::warn!("Received ignored message: {:?}", msg);
92            return Err(Error::HandshakeNotText);
93        }
94    };
95    let remote_handshake: Handshake = serde_json::from_str(&msg).map_err(Error::ParseMessage)?;
96    tracing::info!(
97        client_version = %remote_handshake.version.version,
98        client_hash = %remote_handshake.version.git_hash,
99        client_name = %remote_handshake.name,
100        "Client connected"
101    );
102
103    let local_handshake = Handshake {
104        version: plane_version_info(),
105        name,
106    };
107    ws.send(Message::Text(serde_json::to_string(&local_handshake)?))
108        .await?;
109
110    local_handshake.check_compat(&remote_handshake);
111
112    let (outgoing_message_sender, outgoing_message_receiver) = tokio::sync::mpsc::channel(100);
113    let (incoming_message_sender, incoming_message_receiver) = tokio::sync::mpsc::channel(100);
114    tokio::spawn(async move {
115        handle_messages(outgoing_message_receiver, incoming_message_sender, ws).await;
116    });
117
118    Ok(TypedSocket {
119        send: outgoing_message_sender,
120        recv: incoming_message_receiver,
121        remote_handshake,
122    })
123}