kaspa_wrpc_server/
connection.rs

1use kaspa_grpc_client::{GrpcClient, GrpcClientNotify};
2use kaspa_notify::{
3    connection::Connection as ConnectionT,
4    error::{Error as NotifyError, Result as NotifyResult},
5    listener::ListenerId,
6    notification::Notification as NotificationT,
7    notifier::Notify,
8};
9use kaspa_rpc_core::{api::ops::RpcApiOps, notify::mode::NotificationMode, Notification};
10use std::{
11    fmt::{Debug, Display},
12    sync::{Arc, Mutex},
13};
14use workflow_log::log_trace;
15use workflow_rpc::{
16    server::{prelude::*, result::Result as WrpcResult},
17    types::{MsgT, OpsT},
18};
19use workflow_serializer::prelude::*;
20
//
// FIXME: Use workflow_rpc::encoding::Encoding directly in the ConnectionT implementation by deriving Hash, Eq and PartialEq in situ
//
/// Wire encoding selector used as the `Encoding` associated type of the
/// notification connection implementation in this module.
///
/// It has one variant per [`Encoding`] variant handled here and converts
/// losslessly in both directions through the `From` impls in this module.
/// It additionally derives `Hash`, `Eq` and `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NotifyEncoding {
    /// Borsh encoding.
    Borsh,
    /// Serde JSON encoding.
    SerdeJson,
}
29impl From<Encoding> for NotifyEncoding {
30    fn from(value: Encoding) -> Self {
31        match value {
32            Encoding::Borsh => NotifyEncoding::Borsh,
33            Encoding::SerdeJson => NotifyEncoding::SerdeJson,
34        }
35    }
36}
37impl From<NotifyEncoding> for Encoding {
38    fn from(value: NotifyEncoding) -> Self {
39        match value {
40            NotifyEncoding::Borsh => Encoding::Borsh,
41            NotifyEncoding::SerdeJson => Encoding::SerdeJson,
42        }
43    }
44}
45
46#[derive(Debug)]
47struct ConnectionInner {
48    pub id: u64,
49    pub peer: SocketAddr,
50    pub messenger: Arc<Messenger>,
51    pub grpc_client: Option<Arc<GrpcClient>>,
52    // not using an atomic in case an Id will change type in the future...
53    pub listener_id: Mutex<Option<ListenerId>>,
54}
55
56impl ConnectionInner {
57    fn send(&self, message: Message) -> crate::result::Result<()> {
58        Ok(self.messenger.send_raw_message(message)?)
59    }
60}
61
62impl Notify<Notification> for ConnectionInner {
63    fn notify(&self, notification: Notification) -> NotifyResult<()> {
64        self.send(Connection::into_message(&notification, &self.messenger.encoding().into()))
65            .map_err(|err| NotifyError::General(err.to_string()))
66    }
67}
68
69impl Display for ConnectionInner {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{}@{}", self.id, self.peer)
72    }
73}
74
75/// [`Connection`] represents a currently connected WebSocket RPC channel.
76/// This struct owns a [`Messenger`] that has [`Messenger::notify`]
77/// function that can be used to post notifications to the connection.
78/// [`Messenger::close`] function can be used to terminate the connection
79/// asynchronously.
80#[derive(Debug, Clone)]
81pub struct Connection {
82    inner: Arc<ConnectionInner>,
83}
84
85impl Connection {
86    pub fn new(id: u64, peer: &SocketAddr, messenger: Arc<Messenger>, grpc_client: Option<Arc<GrpcClient>>) -> Connection {
87        // If a GrpcClient is provided, it has to come configured in direct mode
88        assert!(grpc_client.is_none() || grpc_client.as_ref().unwrap().notification_mode() == NotificationMode::Direct);
89        // Should a gRPC client be provided, no listener_id is required for subscriptions so the listener id is set to default
90        let listener_id = Mutex::new(grpc_client.clone().map(|_| ListenerId::default()));
91        Connection { inner: Arc::new(ConnectionInner { id, peer: *peer, messenger, grpc_client, listener_id }) }
92    }
93
94    /// Obtain the connection id
95    pub fn id(&self) -> u64 {
96        self.inner.id
97    }
98
99    /// Get a reference to the connection [`Messenger`]
100    pub fn messenger(&self) -> &Arc<Messenger> {
101        &self.inner.messenger
102    }
103
104    pub fn grpc_client(&self) -> Arc<GrpcClient> {
105        self.inner
106            .grpc_client
107            .as_ref()
108            .cloned()
109            .unwrap_or_else(|| panic!("Incorrect use: `server::Connection` does not carry RpcApi references"))
110    }
111
112    pub fn grpc_client_notify_target(&self) -> GrpcClientNotify {
113        self.inner.clone()
114    }
115
116    pub fn listener_id(&self) -> Option<ListenerId> {
117        *self.inner.listener_id.lock().unwrap()
118    }
119
120    pub fn register_notification_listener(&self, listener_id: ListenerId) {
121        self.inner.listener_id.lock().unwrap().replace(listener_id);
122    }
123
124    pub fn peer(&self) -> &SocketAddr {
125        &self.inner.peer
126    }
127
128    /// Creates a WebSocket [`Message`] that can be posted to the connection ([`Messenger`]) sink
129    /// directly.
130    pub fn create_serialized_notification_message<Ops, Msg>(encoding: Encoding, op: Ops, msg: Msg) -> WrpcResult<Message>
131    where
132        Ops: OpsT,
133        Msg: MsgT,
134    {
135        match encoding {
136            Encoding::Borsh => workflow_rpc::server::protocol::borsh::create_serialized_notification_message(op, msg),
137            Encoding::SerdeJson => workflow_rpc::server::protocol::serde_json::create_serialized_notification_message(op, msg),
138        }
139    }
140}
141
142impl Display for Connection {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        write!(f, "{}", self.inner)
145    }
146}
147
148#[async_trait::async_trait]
149impl ConnectionT for Connection {
150    type Notification = Notification;
151    type Message = Message;
152    type Encoding = NotifyEncoding;
153    type Error = kaspa_notify::error::Error;
154
155    fn encoding(&self) -> Self::Encoding {
156        self.messenger().encoding().into()
157    }
158
159    fn into_message(notification: &Self::Notification, encoding: &Self::Encoding) -> Self::Message {
160        let op: RpcApiOps = notification.event_type().into();
161        Self::create_serialized_notification_message(encoding.clone().into(), op, Serializable(notification.clone())).unwrap()
162    }
163
164    async fn send(&self, message: Self::Message) -> core::result::Result<(), Self::Error> {
165        self.inner.send(message).map_err(|err| NotifyError::General(err.to_string()))
166    }
167
168    fn close(&self) -> bool {
169        if !self.is_closed() {
170            if let Err(err) = self.messenger().close() {
171                log_trace!("Error closing connection {}: {}", self.peer(), err);
172            } else {
173                return true;
174            }
175        }
176        false
177    }
178
179    fn is_closed(&self) -> bool {
180        self.messenger().sink().is_closed()
181    }
182}
183
184pub type ConnectionReference = Arc<Connection>;