plane_common/typed_socket/
server.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
use crate::version::plane_version_info;
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use tokio::sync::mpsc::{Receiver, Sender};

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("Handshake message was not text.")]
    HandshakeNotText,

    #[error("Socket closed before handshake received.")]
    SocketClosedBeforeHandshake,

    #[error("Failed to parse message.")]
    ParseMessage(#[from] serde_json::Error),

    #[error("Failed to send message on websocket.")]
    SendMessage(#[from] axum::Error),
}

pub async fn handle_messages<T: ChannelMessage>(
    mut messages_to_send: Receiver<SocketAction<T>>,
    messages_received: Sender<T::Reply>,
    mut socket: WebSocket,
) {
    loop {
        tokio::select! {
            Some(msg) = messages_to_send.recv() => {
                match msg {
                    SocketAction::Send(msg) => {
                        let msg = Message::Text(serde_json::to_string(&msg).expect("Always serializable."));
                        if let Err(err) = socket.send(msg.clone()).await {
                            tracing::error!(?err, message=?msg, "Failed to send message on websocket.");
                        }
                    }
                    SocketAction::Close => {
                        if let Err(err) = socket.close().await {
                            tracing::error!(?err, "Failed to close websocket.");
                        }
                        break;
                    }
                }
            }
            Some(msg) = socket.recv() => {
                let msg = match msg {
                    Ok(Message::Text(msg)) => msg,
                    Err(err) => {
                        tracing::error!(?err, "Failed to receive message from websocket.");
                        break;
                    }
                    Ok(Message::Close(Some(CloseFrame { code: 1001, .. }))) => {
                        tracing::warn!("Websocket connection closed.");
                        break;
                    }
                    msg => {
                        tracing::warn!("Received ignored message: {:?}", msg);
                        continue;
                    }
                };
                let msg: T::Reply = match serde_json::from_str(&msg) {
                    Ok(msg) => msg,
                    Err(err) => {
                        tracing::warn!(?err, "Failed to parse message.");
                        continue;
                    }
                };
                if let Err(err) = messages_received.send(msg).await {
                    tracing::error!(?err, "Failed to receive message.");
                    break;
                }
            }
            else => {
                break;
            }
        }
    }
}

pub async fn new_server<T: ChannelMessage>(
    mut ws: WebSocket,
    name: String,
) -> Result<TypedSocket<T>, Error> {
    let msg = ws
        .recv()
        .await
        .ok_or(Error::SocketClosedBeforeHandshake)?
        .map_err(Error::from)?;
    let msg = match msg {
        Message::Text(msg) => msg,
        msg => {
            tracing::warn!("Received ignored message: {:?}", msg);
            return Err(Error::HandshakeNotText);
        }
    };
    let remote_handshake: Handshake = serde_json::from_str(&msg).map_err(Error::ParseMessage)?;
    tracing::info!(
        client_version = %remote_handshake.version.version,
        client_hash = %remote_handshake.version.git_hash,
        client_name = %remote_handshake.name,
        "Client connected"
    );

    let local_handshake = Handshake {
        version: plane_version_info(),
        name,
    };
    ws.send(Message::Text(serde_json::to_string(&local_handshake)?))
        .await?;

    local_handshake.check_compat(&remote_handshake);

    let (outgoing_message_sender, outgoing_message_receiver) = tokio::sync::mpsc::channel(100);
    let (incoming_message_sender, incoming_message_receiver) = tokio::sync::mpsc::channel(100);
    tokio::spawn(async move {
        handle_messages(outgoing_message_receiver, incoming_message_sender, ws).await;
    });

    Ok(TypedSocket {
        send: outgoing_message_sender,
        recv: incoming_message_receiver,
        remote_handshake,
    })
}