ic_websocket_cdk/
types.rs

1use std::{collections::VecDeque, fmt, time::Duration};
2
3use candid::{decode_one, CandidType, Principal};
4use serde::{Deserialize, Serialize};
5use serde_cbor::Serializer;
6
7use crate::{
8    custom_trap, errors::WsError, utils::get_current_time, CLIENT_KEEP_ALIVE_TIMEOUT_MS,
9    DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES, DEFAULT_SEND_ACK_INTERVAL_MS,
10    INITIAL_OUTGOING_MESSAGE_NONCE,
11};
12
13pub type ClientPrincipal = Principal;
14#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)]
15pub struct ClientKey {
16    pub client_principal: ClientPrincipal,
17    pub client_nonce: u64,
18}
19
20impl ClientKey {
21    /// Creates a new instance of ClientKey.
22    pub fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self {
23        Self {
24            client_principal,
25            client_nonce,
26        }
27    }
28}
29
30impl fmt::Display for ClientKey {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        write!(f, "{}_{}", self.client_principal, self.client_nonce)
33    }
34}
35
36/// The result of [ws_open](crate::ws_open).
37pub type CanisterWsOpenResult = Result<(), String>;
38/// The result of [ws_close](crate::ws_close).
39pub type CanisterWsCloseResult = Result<(), String>;
40/// The result of [ws_message](crate::ws_message).
41pub type CanisterWsMessageResult = Result<(), String>;
42/// The result of [ws_get_messages](crate::ws_get_messages).
43pub type CanisterWsGetMessagesResult = Result<CanisterOutputCertifiedMessages, String>;
44/// The result of [send](crate::send).
45pub type CanisterSendResult = Result<(), String>;
46#[deprecated(since = "0.3.2", note = "use `CanisterSendResult` instead")]
47pub type CanisterWsSendResult = Result<(), String>;
48/// The result of [close](crate::close).
49pub type CanisterCloseResult = Result<(), String>;
50
51/// The arguments for [ws_open](crate::ws_open).
52#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
53pub struct CanisterWsOpenArguments {
54    pub client_nonce: u64,
55    pub gateway_principal: GatewayPrincipal,
56}
57
58impl CanisterWsOpenArguments {
59    pub fn new(client_nonce: u64, gateway_principal: GatewayPrincipal) -> Self {
60        Self {
61            client_nonce,
62            gateway_principal,
63        }
64    }
65}
66/// The arguments for [ws_close](crate::ws_close).
67#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
68pub struct CanisterWsCloseArguments {
69    pub client_key: ClientKey,
70}
71
72impl CanisterWsCloseArguments {
73    pub fn new(client_key: ClientKey) -> Self {
74        Self { client_key }
75    }
76}
77
78/// The arguments for [ws_message](crate::ws_message).
79#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
80pub struct CanisterWsMessageArguments {
81    pub msg: WebsocketMessage,
82}
83
84impl CanisterWsMessageArguments {
85    pub fn new(msg: WebsocketMessage) -> Self {
86        Self { msg }
87    }
88}
89
90/// The arguments for [ws_get_messages](crate::ws_get_messages).
91#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
92pub struct CanisterWsGetMessagesArguments {
93    pub nonce: u64,
94}
95
96impl CanisterWsGetMessagesArguments {
97    pub fn new(nonce: u64) -> Self {
98        Self { nonce }
99    }
100}
101
102/// Messages exchanged through the WebSocket.
103///
104/// **Note:** You should only use this struct in tests.
105#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
106pub struct WebsocketMessage {
107    pub client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message.
108    pub sequence_num: u64,     // Both ways, messages should arrive with sequence numbers 0, 1, 2...
109    pub timestamp: TimestampNs, // Timestamp of when the message was made for the recipient to inspect.
110    pub is_service_message: bool, // Whether the message is a service message sent by the CDK to the client or vice versa.
111    #[serde(with = "serde_bytes")]
112    pub content: Vec<u8>, // Application message encoded in binary.
113}
114
115impl WebsocketMessage {
116    pub fn new(
117        client_key: ClientKey,
118        sequence_num: u64,
119        timestamp: TimestampNs,
120        is_service_message: bool,
121        content: Vec<u8>,
122    ) -> Self {
123        Self {
124            client_key,
125            sequence_num,
126            timestamp,
127            is_service_message,
128            content,
129        }
130    }
131
132    /// Serializes the message into a Vec<u8>, using CBOR.
133    pub fn cbor_serialize(&self) -> Result<Vec<u8>, String> {
134        let mut data = vec![];
135        let mut serializer = Serializer::new(&mut data);
136        serializer.self_describe().map_err(|e| e.to_string())?;
137        self.serialize(&mut serializer).map_err(|e| e.to_string())?;
138        Ok(data)
139    }
140}
141
142/// Element of the list of messages returned to the WS Gateway after polling.
143///
144/// **Note:** You should only use this struct in tests.
145#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
146pub struct CanisterOutputMessage {
147    pub client_key: ClientKey, // The client that the gateway will forward the message to or that sent the message.
148    pub key: String,           // Key for certificate verification.
149    #[serde(with = "serde_bytes")]
150    pub content: Vec<u8>, // The message to be relayed, that contains the application message.
151}
152
153/// List of messages returned to the WS Gateway after polling.
154///
155/// **Note:** You should only use this struct in tests.
156#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
157pub struct CanisterOutputCertifiedMessages {
158    pub messages: Vec<CanisterOutputMessage>, // List of messages.
159    #[serde(with = "serde_bytes")]
160    pub cert: Vec<u8>, // cert+tree constitute the certificate for all returned messages.
161    #[serde(with = "serde_bytes")]
162    pub tree: Vec<u8>, // cert+tree constitute the certificate for all returned messages.
163    pub is_end_of_queue: bool, // Whether the end of the messages queue has been reached.
164}
165
166impl CanisterOutputCertifiedMessages {
167    pub fn empty() -> Self {
168        Self {
169            messages: vec![],
170            cert: vec![],
171            tree: vec![],
172            is_end_of_queue: true,
173        }
174    }
175}
176
/// A range of a gateway's message queue selected for a polling response.
pub(crate) struct MessagesForGatewayRange {
    /// Index of the first message in the range.
    pub start_index: usize,
    /// End index of the range (presumably exclusive — confirm at the call site).
    pub end_index: usize,
    /// Whether the range reaches the end of the queue.
    pub is_end_of_queue: bool,
}
182
/// A point in time expressed in nanoseconds.
pub(crate) type TimestampNs = u64;

/// Bookkeeping entry for a queued message: the timestamp used to decide
/// when the message is old enough to be deleted.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub(crate) struct MessageToDelete {
    timestamp: TimestampNs,
}
189
190pub(crate) type GatewayPrincipal = Principal;
191
192/// Contains data about the registered WS Gateway.
193#[derive(Clone, Debug, Default, Eq, PartialEq)]
194pub(crate) struct RegisteredGateway {
195    /// The queue of the messages that the gateway can poll.
196    pub(crate) messages_queue: VecDeque<CanisterOutputMessage>,
197    /// The queue of messages' keys to delete.
198    pub(crate) messages_to_delete: VecDeque<MessageToDelete>,
199    /// Keeps track of the nonce which:
200    /// - the WS Gateway uses to specify the first index of the certified messages to be returned when polling
201    /// - the client uses as part of the path in the Merkle tree in order to verify the certificate of the messages relayed by the WS Gateway
202    pub(crate) outgoing_message_nonce: u64,
203    /// The number of clients connected to this gateway.
204    pub(crate) connected_clients_count: u64,
205}
206
207impl RegisteredGateway {
208    /// Creates a new instance of RegisteredGateway.
209    pub(crate) fn new() -> Self {
210        Self {
211            messages_queue: VecDeque::new(),
212            messages_to_delete: VecDeque::new(),
213            outgoing_message_nonce: INITIAL_OUTGOING_MESSAGE_NONCE,
214            connected_clients_count: 0,
215        }
216    }
217
218    /// Increments the outgoing message nonce by 1.
219    pub(crate) fn increment_nonce(&mut self) {
220        self.outgoing_message_nonce += 1;
221    }
222
223    /// Increments the connected clients count by 1.
224    pub(crate) fn increment_clients_count(&mut self) {
225        self.connected_clients_count += 1;
226    }
227
228    /// Decrements the connected clients count by 1, returning the new value.
229    pub(crate) fn decrement_clients_count(&mut self) -> u64 {
230        self.connected_clients_count = self.connected_clients_count.saturating_sub(1);
231        self.connected_clients_count
232    }
233
234    /// Adds the message to the queue and its metadata to the `messages_to_delete` queue.
235    pub(crate) fn add_message_to_queue(
236        &mut self,
237        message: CanisterOutputMessage,
238        message_timestamp: TimestampNs,
239    ) {
240        self.messages_queue.push_back(message.clone());
241        self.messages_to_delete.push_back(MessageToDelete {
242            timestamp: message_timestamp,
243        });
244    }
245
246    /// Deletes the oldest `n` messages that are older than `message_max_age_ms` from the queue.
247    ///
248    /// Returns the deleted messages keys.
249    pub(crate) fn delete_old_messages(&mut self, n: usize, message_max_age_ms: u64) -> Vec<String> {
250        let time = get_current_time();
251        let mut deleted_keys = vec![];
252
253        for _ in 0..n {
254            if let Some(message_to_delete) = self.messages_to_delete.front() {
255                if Duration::from_nanos(time - message_to_delete.timestamp)
256                    > Duration::from_millis(message_max_age_ms)
257                {
258                    // unwrap is safe because there is no case in which the messages_to_delete queue is populated
259                    // while the messages_queue is empty
260                    let deleted_message = self.messages_queue.pop_front().unwrap();
261                    deleted_keys.push(deleted_message.key.clone());
262                    self.messages_to_delete.pop_front();
263                } else {
264                    // In this case, no messages can be deleted because
265                    // they're all not older than `message_max_age_ms`.
266                    break;
267                }
268            } else {
269                // There are no messages in the queue. Shouldn't happen.
270                break;
271            }
272        }
273
274        deleted_keys
275    }
276}
277
278/// The metadata about a registered client.
279#[derive(Clone, Debug, Eq, PartialEq)]
280pub(crate) struct RegisteredClient {
281    pub(crate) last_keep_alive_timestamp: TimestampNs,
282    pub(crate) gateway_principal: GatewayPrincipal,
283}
284
285impl RegisteredClient {
286    /// Creates a new instance of RegisteredClient.
287    pub(crate) fn new(gateway_principal: GatewayPrincipal) -> Self {
288        Self {
289            last_keep_alive_timestamp: get_current_time(),
290            gateway_principal,
291        }
292    }
293
294    /// Gets the last keep alive timestamp.
295    pub(crate) fn get_last_keep_alive_timestamp(&self) -> TimestampNs {
296        self.last_keep_alive_timestamp
297    }
298
299    /// Set the last keep alive timestamp to the current time.
300    pub(crate) fn update_last_keep_alive_timestamp(&mut self) {
301        self.last_keep_alive_timestamp = get_current_time();
302    }
303}
304
305/// **Note:** You should only use this struct in tests.
306#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
307pub struct CanisterOpenMessageContent {
308    pub client_key: ClientKey,
309}
310
311/// **Note:** You should only use this struct in tests.
312#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
313pub struct CanisterAckMessageContent {
314    pub last_incoming_sequence_num: u64,
315}
316
317/// **Note:** You should only use this struct in tests.
318#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
319pub struct ClientKeepAliveMessageContent {
320    pub last_incoming_sequence_num: u64,
321}
322
323/// **Note:** You should only use this struct in tests.
324#[derive(CandidType, Clone, Debug, Deserialize, PartialEq, Eq)]
325pub enum CloseMessageReason {
326    /// When the canister receives a wrong sequence number from the client.
327    WrongSequenceNumber,
328    /// When the canister receives an invalid service message from the client.
329    InvalidServiceMessage,
330    /// When the canister doesn't receive the keep alive message from the client in time.
331    KeepAliveTimeout,
332    /// When the developer calls the `close` function.
333    ClosedByApplication,
334}
335
336/// **Note:** You should only use this struct in tests.
337#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
338pub struct CanisterCloseMessageContent {
339    pub reason: CloseMessageReason,
340}
341
342/// A service message sent by the CDK to the client or vice versa.
343///
344/// **Note:** You should only use this struct in tests.
345#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
346pub enum WebsocketServiceMessageContent {
347    /// Message sent by the **canister** when a client opens a connection.
348    OpenMessage(CanisterOpenMessageContent),
349    /// Message sent _periodically_ by the **canister** to the client to acknowledge the messages received.
350    AckMessage(CanisterAckMessageContent),
351    /// Message sent by the **client** in response to an acknowledgement message from the canister.
352    KeepAliveMessage(ClientKeepAliveMessageContent),
353    /// Message sent by the **canister** when it wants to close the connection.
354    CloseMessage(CanisterCloseMessageContent),
355}
356
357impl WebsocketServiceMessageContent {
358    pub fn from_candid_bytes(bytes: &[u8]) -> Result<Self, String> {
359        decode_one(&bytes).map_err(|err| WsError::DecodeServiceMessageContent { err }.to_string())
360    }
361}
362
363/// Arguments passed to the `on_open` handler.
364pub struct OnOpenCallbackArgs {
365    pub client_principal: ClientPrincipal,
366}
367/// Handler initialized by the canister
368/// and triggered by the CDK once the IC WebSocket connection is established.
369type OnOpenCallback = fn(OnOpenCallbackArgs);
370
371/// Arguments passed to the `on_message` handler.
372/// The `message` argument is the message received from the client, serialized in Candid.
373/// To deserialize the message, use [candid::decode_one].
374///
375/// # Example
376/// This example is the deserialize equivalent of the [send's example](fn.send.html#example) serialize one.
377/// ```rust
378/// use candid::{decode_one, CandidType};
379/// use ic_websocket_cdk::OnMessageCallbackArgs;
380/// use serde::Deserialize;
381///
382/// #[derive(CandidType, Deserialize)]
383/// struct MyMessage {
384///     some_field: String,
385/// }
386///
387/// fn on_message(args: OnMessageCallbackArgs) {
388///     let received_message: MyMessage = decode_one(&args.message).unwrap();
389///
390///     println!("Received message: some_field: {:?}", received_message.some_field);
391/// }
392/// ```
393pub struct OnMessageCallbackArgs {
394    /// The principal of the client sending the message to the canister.
395    pub client_principal: ClientPrincipal,
396    /// The message received from the client, serialized in Candid. See [OnMessageCallbackArgs] for an example on how to deserialize the message.
397    pub message: Vec<u8>,
398}
399/// Handler initialized by the canister
400/// and triggered by the CDK once an IC WebSocket message is received.
401type OnMessageCallback = fn(OnMessageCallbackArgs);
402
403/// Arguments passed to the `on_close` handler.
404pub struct OnCloseCallbackArgs {
405    pub client_principal: ClientPrincipal,
406}
407/// Handler initialized by the canister
408/// and triggered by the CDK once the WS Gateway closes the IC WebSocket connection
409/// for that client.
410///
411/// Make sure you **don't** call the [close](crate::close) function in this callback.
412type OnCloseCallback = fn(OnCloseCallbackArgs);
413
414/// Handlers initialized by the canister and triggered by the CDK.
415///
416/// **Note**: if the callbacks that you define here trap for some reason,
417/// the CDK will disconnect the client with principal `args.client_principal`.
418/// However, the client **won't** be notified
419/// until at least the next time it will try to send a message to the canister.
420#[derive(Clone, Debug, Default, PartialEq)]
421pub struct WsHandlers {
422    pub on_open: Option<OnOpenCallback>,
423    pub on_message: Option<OnMessageCallback>,
424    pub on_close: Option<OnCloseCallback>,
425}
426
427impl WsHandlers {
428    pub(crate) fn call_on_open(&self, args: OnOpenCallbackArgs) {
429        if let Some(on_open) = self.on_open {
430            // we don't have to recover from errors here,
431            // we just let the canister trap
432            on_open(args);
433        }
434    }
435
436    pub(crate) fn call_on_message(&self, args: OnMessageCallbackArgs) {
437        if let Some(on_message) = self.on_message {
438            // see call_on_open
439            on_message(args);
440        }
441    }
442
443    pub(crate) fn call_on_close(&self, args: OnCloseCallbackArgs) {
444        if let Some(on_close) = self.on_close {
445            // see call_on_open
446            on_close(args);
447        }
448    }
449}
450
451/// Parameters for the IC WebSocket CDK initialization. For default parameters and simpler initialization, use [`WsInitParams::new`].
452#[derive(Clone)]
453pub struct WsInitParams {
454    /// The callback handlers for the WebSocket.
455    pub handlers: WsHandlers,
456    /// The maximum number of messages to be returned in a polling iteration.
457    ///
458    /// Defaults to `50`.
459    pub max_number_of_returned_messages: usize,
460    /// The interval at which to send an acknowledgement message to the client,
461    /// so that the client knows that all the messages it sent have been received by the canister (in milliseconds).
462    ///
463    /// Must be greater than [`CLIENT_KEEP_ALIVE_TIMEOUT_MS`] (1 minute).
464    ///
465    /// Defaults to `300_000` (5 minutes).
466    pub send_ack_interval_ms: u64,
467}
468
469impl WsInitParams {
470    /// Creates a new instance of WsInitParams, with default interval values.
471    pub fn new(handlers: WsHandlers) -> Self {
472        Self {
473            handlers,
474            ..Default::default()
475        }
476    }
477
478    pub(crate) fn get_handlers(&self) -> WsHandlers {
479        self.handlers.clone()
480    }
481
482    /// Checks the validity of the timer parameters.
483    /// `send_ack_interval_ms` must be greater than [`CLIENT_KEEP_ALIVE_TIMEOUT_MS`].
484    ///
485    /// # Traps
486    /// If `send_ack_interval_ms` <= [`CLIENT_KEEP_ALIVE_TIMEOUT_MS`].
487    pub(crate) fn check_validity(&self) {
488        if self.send_ack_interval_ms <= CLIENT_KEEP_ALIVE_TIMEOUT_MS {
489            custom_trap!("send_ack_interval_ms must be greater than CLIENT_KEEP_ALIVE_TIMEOUT_MS");
490        }
491    }
492
493    pub fn with_max_number_of_returned_messages(
494        mut self,
495        max_number_of_returned_messages: usize,
496    ) -> Self {
497        self.max_number_of_returned_messages = max_number_of_returned_messages;
498        self
499    }
500
501    /// Sets the interval (in milliseconds) at which to send an acknowledgement message
502    /// to the connected clients.
503    ///
504    /// Must be greater than [`CLIENT_KEEP_ALIVE_TIMEOUT_MS`] (1 minute).
505    ///
506    /// # Traps
507    /// If `send_ack_interval_ms` <= [`CLIENT_KEEP_ALIVE_TIMEOUT_MS`]. See [WsInitParams::check_validity].
508    pub fn with_send_ack_interval_ms(mut self, send_ack_interval_ms: u64) -> Self {
509        self.send_ack_interval_ms = send_ack_interval_ms;
510        self.check_validity();
511        self
512    }
513}
514
515impl Default for WsInitParams {
516    fn default() -> Self {
517        Self {
518            handlers: WsHandlers::default(),
519            max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES,
520            send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS,
521        }
522    }
523}