// plane_common/typed_socket/mod.rs

use crate::version::PlaneVersionInfo;
use crate::PlaneClientError;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{Receiver, Sender};

pub mod client;
pub mod server;

/// An instruction placed on a socket's outbound queue.
///
/// `Debug` is derived so that actions (and channel errors that carry a rejected action)
/// can be logged. The derived impl is only available when the payload type `T` is
/// itself `Debug`.
#[derive(Debug)]
pub enum SocketAction<T> {
    /// Carries a message of type `T` to be sent.
    Send(T),
    /// Close request; this is what `TypedSocket::close` and `TypedSocketSender::close`
    /// enqueue.
    Close,
}

/// A message type that can travel over a [`TypedSocket`].
///
/// Implementors must be serializable and deserializable with serde, implement `Debug`,
/// and satisfy `Send + Sync + 'static`.
pub trait ChannelMessage: Send + Sync + 'static + DeserializeOwned + Serialize + Debug {
    /// The message type flowing in the opposite direction: a `TypedSocket<T>` sends `T`
    /// and receives `T::Reply`. The bound requires `Reply::Reply` to be `Self`, so the
    /// two types always form a symmetric pair.
    type Reply: ChannelMessage<Reply = Self>;
}

/// One end of a typed message channel: sends `T` and receives `T::Reply`.
pub struct TypedSocket<T: ChannelMessage> {
    /// Outbound queue of actions (a message to send, or a close request).
    send: Sender<SocketAction<T>>,
    /// Inbound queue of reply messages.
    recv: Receiver<T::Reply>,
    /// Handshake associated with the remote peer — presumably the one it presented when
    /// the connection was established; TODO confirm against the client/server modules.
    pub remote_handshake: Handshake,
}

/// A cloneable handle that forwards [`SocketAction`]s through a type-erased callback.
///
/// Instances are built by `TypedSocket::sender`, which wraps the socket's outbound
/// queue together with a conversion from `A` into the socket's own message type.
///
/// `Clone` is implemented by hand rather than derived: `#[derive(Clone)]` would add an
/// `A: Clone` bound, but cloning only bumps the `Arc` reference count and never copies
/// a value of type `A`, so the handle is cloneable for every payload type.
pub struct TypedSocketSender<A> {
    /// Callback invoked for every action; reports whether the action could be queued.
    inner_send:
        Arc<dyn Fn(SocketAction<A>) -> Result<(), TypedSocketError> + 'static + Send + Sync>,
}

impl<A> Clone for TypedSocketSender<A> {
    /// Returns a new handle sharing the same underlying callback.
    fn clone(&self) -> Self {
        Self {
            inner_send: Arc::clone(&self.inner_send),
        }
    }
}

/// Opaque `Debug` output: a fixed label is printed for every payload type, so no
/// `Debug` bound on the payload is required.
impl<A> Debug for TypedSocketSender<A> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "typed socket sender")
    }
}

/// Reasons a [`TypedSocketSender`] can fail to queue an action.
#[derive(Debug, thiserror::Error)]
pub enum TypedSocketError {
    /// The channel is closed (converted from `TrySendError::Closed`).
    #[error("Receiver closed")]
    Closed,
    /// The channel has no free capacity (converted from `TrySendError::Full`).
    #[error("Receiver queue full")]
    Clogged,
}

impl<A> From<TrySendError<A>> for TypedSocketError {
    fn from(e: TrySendError<A>) -> Self {
        match e {
            TrySendError::Full(_) => Self::Clogged,
            TrySendError::Closed(_) => Self::Closed,
        }
    }
}

impl<A: Debug> TypedSocketSender<A> {
    /// Hands `message` to the wrapped callback as a `SocketAction::Send`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TypedSocketError`] the callback reports.
    pub fn send(&self, message: A) -> Result<(), TypedSocketError> {
        (self.inner_send)(SocketAction::Send(message))
    }

    /// Hands a `SocketAction::Close` to the wrapped callback.
    ///
    /// # Errors
    ///
    /// Returns whatever [`TypedSocketError`] the callback reports.
    pub fn close(&mut self) -> Result<(), TypedSocketError> {
        (self.inner_send)(SocketAction::Close)
    }
}

impl<T: ChannelMessage> TypedSocket<T> {
    pub fn send(&mut self, message: T) -> Result<(), PlaneClientError> {
        self.send
            .try_send(SocketAction::Send(message))
            .map_err(|_| PlaneClientError::SendFailed)?;
        Ok(())
    }

    pub async fn recv(&mut self) -> Option<T::Reply> {
        self.recv.recv().await
    }

    pub fn sender<A, F>(&self, transform: F) -> TypedSocketSender<A>
    where
        F: (Fn(A) -> T) + 'static + Send + Sync,
    {
        let sender = self.send.clone();
        let inner_send = move |message: SocketAction<A>| {
            let message = match message {
                SocketAction::Close => SocketAction::Close,
                SocketAction::Send(message) => SocketAction::Send(transform(message)),
            };
            sender.try_send(message).map_err(|e| e.into())
        };

        TypedSocketSender {
            inner_send: Arc::new(inner_send),
        }
    }

    pub async fn close(&mut self) {
        let _ = self.send.send(SocketAction::Close).await;
    }
}

/// Identifying information about one end of a connection; serializable with serde.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Handshake {
    /// Plane version information; `check_compat` compares its `version` and `git_hash`.
    pub version: PlaneVersionInfo,
    /// A name for this end of the connection — exact meaning is not visible here;
    /// TODO confirm against the code that constructs handshakes.
    pub name: String,
}

impl Handshake {
    /// Logs a `tracing` warning when this (local) handshake and `other` (remote)
    /// report different builds.
    ///
    /// The `version` fields are compared first; the `git_hash` fields are compared only
    /// when the versions match, so at most one warning is emitted. Nothing is returned:
    /// a mismatch is only logged.
    pub fn check_compat(&self, other: &Handshake) {
        let (local, remote) = (&self.version, &other.version);

        if local.version != remote.version {
            tracing::warn!(
                local_version = local.version,
                remote_version = remote.version,
                "Client and server have different Plane versions."
            );
            return;
        }

        // NOTE(review): the git-hash warning reuses the `local_version` /
        // `remote_version` field names even though the values are git hashes. Kept
        // as-is so the emitted log fields do not change.
        if local.git_hash != remote.git_hash {
            tracing::warn!(
                local_version = local.git_hash,
                remote_version = remote.git_hash,
                "Client and server have different Plane git hashes.",
            );
        }
    }
}