1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
use super::{ChannelMessage, Handshake, SocketAction, TypedSocket};
use crate::plane_version_info;
use anyhow::{anyhow, Context, Result};
use axum::extract::ws::{Message, WebSocket};
use tokio::sync::mpsc::{Receiver, Sender};

/// Bridges a pair of typed mpsc channels to a WebSocket until one side is finished.
///
/// Outgoing direction:
/// * Each [`SocketAction::Send`] read from `messages_to_send` is JSON-encoded with
///   `serde_json::to_string` and written to `socket` as a text frame. Send failures are
///   logged and the loop continues.
/// * [`SocketAction::Close`] closes the socket and ends the loop.
///
/// Incoming direction:
/// * Each text frame read from `socket` is JSON-decoded as `T::Reply` and forwarded on
///   `messages_received`.
/// * Non-text frames and frames that fail to decode are logged and skipped.
///
/// The loop also ends when the socket stream is exhausted (`None`) or yields an error.
/// Returning drops both channel ends held here. The owner of the matching
/// `Sender`/`Receiver` then sees the channels close instead of waiting forever on a
/// connection that is already gone.
pub async fn handle_messages<T: ChannelMessage>(
    mut messages_to_send: Receiver<SocketAction<T>>,
    messages_received: Sender<T::Reply>,
    mut socket: WebSocket,
) {
    loop {
        tokio::select! {
            // This branch is disabled once every sender is dropped. After that, only
            // the socket branch below is polled, so incoming traffic is still forwarded.
            Some(msg) = messages_to_send.recv() => {
                match msg {
                    SocketAction::Send(msg) => {
                        let msg = Message::Text(serde_json::to_string(&msg).expect("Always serializable."));
                        if let Err(err) = socket.send(msg).await {
                            tracing::error!(?err, "Failed to send message on websocket.");
                        }
                    }
                    SocketAction::Close => {
                        if let Err(err) = socket.close().await {
                            tracing::error!(?err, "Failed to close websocket.");
                        }
                        break;
                    }
                }
            }
            // Irrefutable pattern, so the end of the stream is handled here rather than
            // silently disabling the branch. Disabling it would keep this task, and the
            // `messages_received` sender it owns, alive indefinitely.
            msg = socket.recv() => {
                let msg = match msg {
                    None => {
                        tracing::info!("Websocket stream ended.");
                        break;
                    }
                    Some(Ok(Message::Text(msg))) => msg,
                    Some(Err(err)) => {
                        tracing::error!(?err, "Failed to receive message from websocket.");
                        break;
                    }
                    Some(msg) => {
                        tracing::warn!("Received ignored message: {:?}", msg);
                        continue;
                    }
                };
                let msg: T::Reply = match serde_json::from_str(&msg) {
                    Ok(msg) => msg,
                    Err(err) => {
                        tracing::warn!(?err, "Failed to parse message.");
                        continue;
                    }
                };
                if let Err(err) = messages_received.send(msg).await {
                    // The consumer dropped its receiver. Keep the connection up
                    // (best-effort) so outgoing messages can still be sent.
                    tracing::error!(?err, "Failed to receive message.");
                }
            }
        }
    }
}

/// Performs the server side of the typed-socket handshake and returns a [`TypedSocket`].
///
/// Steps, in order:
/// 1. Wait for the client's first frame. It must be a text frame containing a JSON
///    [`Handshake`].
/// 2. Log the client's version, git hash, and name.
/// 3. Reply with this server's handshake, built from `plane_version_info()` and `name`.
/// 4. Call `check_compat` on the local and remote handshakes.
/// 5. Create two bounded mpsc channels (capacity 100) and spawn [`handle_messages`]
///    to bridge them to the socket. The user-facing ends are returned in the
///    [`TypedSocket`] together with the remote handshake.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the socket ends or errors before any frame arrives;
/// * the first frame is not text;
/// * the first frame does not parse as a `Handshake`;
/// * the local handshake cannot be serialized or sent.
pub async fn new_server<T: ChannelMessage>(
    mut ws: WebSocket,
    name: String,
) -> Result<TypedSocket<T>> {
    // The client's first frame is its handshake; any other outcome aborts the connection.
    let handshake_text = match ws.recv().await {
        None => return Err(anyhow!("Socket closed before handshake received.")),
        Some(Err(err)) => return Err(err.into()),
        Some(Ok(Message::Text(text))) => text,
        Some(Ok(other)) => {
            tracing::warn!("Received ignored message: {:?}", other);
            return Err(anyhow!("Handshake message was not text."));
        }
    };

    let remote_handshake = serde_json::from_str::<Handshake>(&handshake_text)
        .context("Parsing handshake from client.")?;
    tracing::info!(
        client_version = %remote_handshake.version.version,
        client_hash = %remote_handshake.version.git_hash,
        client_name = %remote_handshake.name,
        "Client connected"
    );

    // Always answer with our own handshake before checking compatibility.
    let local_handshake = Handshake {
        version: plane_version_info(),
        name,
    };
    let encoded_handshake = serde_json::to_string(&local_handshake)?;
    ws.send(Message::Text(encoded_handshake)).await?;

    local_handshake.check_compat(&remote_handshake);

    // `send`/`recv` are the halves handed to the caller. `outgoing`/`incoming` are
    // owned by the background pump task.
    let (send, outgoing) = tokio::sync::mpsc::channel(100);
    let (incoming, recv) = tokio::sync::mpsc::channel(100);
    tokio::spawn(handle_messages(outgoing, incoming, ws));

    Ok(TypedSocket {
        send,
        recv,
        remote_handshake,
    })
}