plane_common/typed_socket/mod.rs

1use crate::version::PlaneVersionInfo;
2use crate::PlaneClientError;
3use serde::de::DeserializeOwned;
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::sync::Arc;
7use tokio::sync::mpsc::error::TrySendError;
8use tokio::sync::mpsc::{Receiver, Sender};
9
10pub mod client;
11pub mod server;
12
/// A command queued for the task that drives a typed socket.
pub enum SocketAction<T> {
    /// Queue a message of type `T` for transmission.
    Send(T),
    /// Request that the socket be closed.
    Close,
}
17
18pub trait ChannelMessage: Send + Sync + 'static + DeserializeOwned + Serialize + Debug {
19    type Reply: ChannelMessage<Reply = Self>;
20}
21
22pub struct TypedSocket<T: ChannelMessage> {
23    send: Sender<SocketAction<T>>,
24    recv: Receiver<T::Reply>,
25    pub remote_handshake: Handshake,
26}
27
28#[derive(Clone)]
29pub struct TypedSocketSender<A> {
30    inner_send:
31        Arc<dyn Fn(SocketAction<A>) -> Result<(), TypedSocketError> + 'static + Send + Sync>,
32}
33
34impl<T> Debug for TypedSocketSender<T> {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.write_str("typed socket sender")
37    }
38}
39
40#[derive(Debug, thiserror::Error)]
41pub enum TypedSocketError {
42    #[error("Receiver closed")]
43    Closed,
44    #[error("Receiver queue full")]
45    Clogged,
46}
47
48impl<A> From<TrySendError<A>> for TypedSocketError {
49    fn from(e: TrySendError<A>) -> Self {
50        match e {
51            TrySendError::Full(_) => Self::Clogged,
52            TrySendError::Closed(_) => Self::Closed,
53        }
54    }
55}
56
57impl<A: Debug> TypedSocketSender<A> {
58    pub fn send(&self, message: A) -> Result<(), TypedSocketError> {
59        (self.inner_send)(SocketAction::Send(message))?;
60        Ok(())
61    }
62
63    pub fn close(&mut self) -> Result<(), TypedSocketError> {
64        (self.inner_send)(SocketAction::Close)?;
65        Ok(())
66    }
67}
68
69impl<T: ChannelMessage> TypedSocket<T> {
70    pub fn send(&mut self, message: T) -> Result<(), PlaneClientError> {
71        self.send
72            .try_send(SocketAction::Send(message))
73            .map_err(|_| PlaneClientError::SendFailed)?;
74        Ok(())
75    }
76
77    pub async fn recv(&mut self) -> Option<T::Reply> {
78        self.recv.recv().await
79    }
80
81    pub fn sender<A, F>(&self, transform: F) -> TypedSocketSender<A>
82    where
83        F: (Fn(A) -> T) + 'static + Send + Sync,
84    {
85        let sender = self.send.clone();
86        let inner_send = move |message: SocketAction<A>| {
87            let message = match message {
88                SocketAction::Close => SocketAction::Close,
89                SocketAction::Send(message) => SocketAction::Send(transform(message)),
90            };
91            sender.try_send(message).map_err(|e| e.into())
92        };
93
94        TypedSocketSender {
95            inner_send: Arc::new(inner_send),
96        }
97    }
98
99    pub async fn close(&mut self) {
100        let _ = self.send.send(SocketAction::Close).await;
101    }
102}
103
104#[derive(Serialize, Deserialize, Debug, Clone)]
105pub struct Handshake {
106    pub version: PlaneVersionInfo,
107    pub name: String,
108}
109
110impl Handshake {
111    /// Compare a local and remote handshake, and log a warning if they are not compatible.
112    pub fn check_compat(&self, other: &Handshake) {
113        if self.version.version != other.version.version {
114            tracing::warn!(
115                local_version = self.version.version,
116                remote_version = other.version.version,
117                "Client and server have different Plane versions."
118            );
119        } else if self.version.git_hash != other.version.git_hash {
120            tracing::warn!(
121                local_version = self.version.git_hash,
122                remote_version = other.version.git_hash,
123                "Client and server have different Plane git hashes.",
124            );
125        }
126    }
127}