ic_websocket_cdk/
lib.rs

1use candid::{CandidType, Principal};
2use errors::WsError;
3
4use ic_cdk::api::msg_caller;
5use serde::Deserialize;
6
7mod errors;
8mod state;
9mod tests;
10mod timers;
11pub mod types;
12mod utils;
13
14use state::*;
15use timers::*;
16#[allow(deprecated)]
17pub use types::CanisterWsSendResult;
18use types::*;
19pub use types::{
20    CanisterCloseResult, CanisterSendResult, CanisterWsCloseArguments, CanisterWsCloseResult,
21    CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments,
22    CanisterWsMessageResult, CanisterWsOpenArguments, CanisterWsOpenResult, ClientPrincipal,
23    OnCloseCallbackArgs, OnMessageCallbackArgs, OnOpenCallbackArgs, WsHandlers, WsInitParams,
24};
25
/// Label under which the CDK stores data when building the certification tree.
const LABEL_WEBSOCKET: &[u8] = b"websocket";

/// Default upper bound on the number of messages handed back by [ws_get_messages] per poll.
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 50;
/// Default period between two rounds of acknowledgements sent to clients (5 minutes).
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 5 * 60 * 1_000;
/// Upper bound on the communication latency tolerated between a client and the canister (30 seconds).
const COMMUNICATION_LATENCY_BOUND_MS: u64 = 30 * 1_000;
/// Default time a client is given to answer an acknowledgement with a keep alive:
/// twice [COMMUNICATION_LATENCY_BOUND_MS], i.e. one round trip.
const CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = COMMUNICATION_LATENCY_BOUND_MS * 2;
/// [CLIENT_KEEP_ALIVE_TIMEOUT_MS] expressed in nanoseconds.
const CLIENT_KEEP_ALIVE_TIMEOUT_NS: u64 = CLIENT_KEEP_ALIVE_TIMEOUT_MS * 1_000_000;

/// Nonce assigned to the first outgoing message.
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
/// Sequence number expected on the first message received from a client.
///
/// It is `1` rather than `0` because the client increments its counter *before*
/// sending each message.
const INITIAL_CLIENT_SEQUENCE_NUM: u64 = 1;
/// Sequence number assigned to the first outgoing message.
const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0;

/// How many messages are pruned from the outgoing queue each time a new one is enqueued.
const MESSAGES_TO_DELETE_COUNT: usize = 5;
50
51/// Initialize the CDK.
52///
53/// **Note**: Restarts the acknowledgement timers under the hood.
54///
55/// # Traps
56/// If the parameters are invalid.
57pub fn init(params: WsInitParams) {
58    // check if the parameters are valid
59    params.check_validity();
60
61    // set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection
62    set_params(params.clone());
63
64    // cancel possibly running timers
65    cancel_timers();
66
67    // schedule a timer that will send an acknowledgement message to clients
68    schedule_send_ack_to_clients();
69}
70
71/// Handles the WS connection open event sent by the client and relayed by the Gateway.
72pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
73    let caller = msg_caller();
74    // anonymous clients cannot open a connection
75    caller
76        .ne(&Principal::anonymous())
77        .then_some(())
78        .ok_or_else(|| WsError::AnonymousPrincipalNotAllowed.to_string())?;
79
80    let client_key = ClientKey::new(caller, args.client_nonce);
81    // check if client is not registered yet
82    // by swapping the result of the check_registered_client_exists function
83    check_registered_client_exists(&client_key).map_or(Ok(()), |_| {
84        WsError::ClientKeyAlreadyConnected {
85            client_key: &client_key,
86        }
87        .to_string_result()
88    })?;
89
90    // check if there's a client already registered with the same principal
91    // and remove it if there is
92    match get_client_key_from_principal(&client_key.client_principal) {
93        Err(_) => {
94            // Do nothing
95        },
96        Ok(old_client_key) => {
97            remove_client(&old_client_key, None);
98        },
99    };
100
101    // initialize client maps
102    let new_client = RegisteredClient::new(args.gateway_principal);
103    add_client(client_key.clone(), new_client);
104
105    // send the open message
106    let open_message = CanisterOpenMessageContent {
107        client_key: client_key.clone(),
108    };
109    let message = WebsocketServiceMessageContent::OpenMessage(open_message);
110    send_service_message_to_client(&client_key, &message)?;
111
112    // call the on_open handler initialized in init()
113    get_handlers_from_params().call_on_open(OnOpenCallbackArgs {
114        client_principal: client_key.client_principal,
115    });
116
117    Ok(())
118}
119
120/// Handles the WS connection close event received from the WS Gateway.
121///
122/// If you want to close the connection with the client in your logic,
123/// use the [close] function instead.
124pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
125    let gateway_principal = msg_caller();
126
127    // check if the gateway is registered
128    check_is_gateway_registered(&gateway_principal)?;
129
130    // check if client registered itself by calling ws_open
131    check_registered_client_exists(&args.client_key)?;
132
133    // check if the client is registered to the gateway that is closing the connection
134    check_client_registered_to_gateway(&args.client_key, &gateway_principal)?;
135
136    remove_client(&args.client_key, None);
137
138    Ok(())
139}
140
141/// Handles the WS messages received either directly from the client or relayed by the WS Gateway.
142///
143/// The second argument is only needed to expose the type of the message on the canister Candid interface and get automatic types generation on the client side.
144/// This way, on the client you have the same types and you don't have to care about serializing and deserializing the messages sent through IC WebSocket.
145///
146/// # Example
147/// ```rust
148/// use ic_cdk::{update};
149/// use candid::{CandidType};
150/// use ic_websocket_cdk::{CanisterWsMessageArguments, CanisterWsMessageResult};
151/// use serde::Deserialize;
152///
153/// #[derive(CandidType, Deserialize)]
154/// struct MyMessage {
155///     some_field: String,
156/// }
157///
158/// // method called by the WS Gateway to send a message of type GatewayMessage to the canister
159/// #[update]
160/// fn ws_message(
161///     args: CanisterWsMessageArguments,
162///     msg_type: Option<MyMessage>,
163/// ) -> CanisterWsMessageResult {
164///     ic_websocket_cdk::ws_message(args, msg_type)
165/// }
166/// ```
167pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
168    args: CanisterWsMessageArguments,
169    _message_type: Option<T>,
170) -> CanisterWsMessageResult {
171    let client_principal = msg_caller();
172    let registered_client_key = get_client_key_from_principal(&client_principal)?;
173
174    let WebsocketMessage {
175        client_key,
176        sequence_num,
177        timestamp: _,
178        is_service_message,
179        content,
180    } = args.msg;
181
182    // check if the client key is correct
183    client_key
184        .eq(&registered_client_key)
185        .then_some(())
186        .ok_or_else(|| {
187            WsError::ClientKeyMessageMismatch {
188                client_key: &client_key,
189            }
190            .to_string()
191        })?;
192
193    let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?;
194
195    // check if the incoming message has the expected sequence number
196    sequence_num
197        .eq(&expected_sequence_num)
198        .then_some(())
199        .ok_or_else(|| {
200            remove_client(&client_key, Some(CloseMessageReason::WrongSequenceNumber));
201
202            WsError::IncomingSequenceNumberWrong {
203                expected_sequence_num,
204                actual_sequence_num: sequence_num,
205            }
206            .to_string()
207        })?;
208    // increase the expected sequence number by 1
209    increment_expected_incoming_message_from_client_num(&client_key)?;
210
211    if is_service_message {
212        return handle_received_service_message(&client_key, &content);
213    }
214
215    // call the on_message handler initialized in init()
216    get_handlers_from_params().call_on_message(OnMessageCallbackArgs {
217        client_principal,
218        message: content,
219    });
220    Ok(())
221}
222
223/// Returns messages to the WS Gateway in response of a polling iteration.
224pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
225    let gateway_principal = msg_caller();
226    if !is_registered_gateway(&gateway_principal) {
227        return get_cert_messages_empty();
228    }
229
230    get_cert_messages(&gateway_principal, args.nonce)
231}
232
233/// Sends a message to the client. The message must already be serialized **using Candid**.
234/// Use [candid::encode_one] to serialize the message.
235///
236/// Under the hood, the message is certified and added to the queue of messages
237/// that the WS Gateway will poll in the next iteration.
238///
239/// # Example
240/// This example is the serialize equivalent of the [OnMessageCallbackArgs's example](struct.OnMessageCallbackArgs.html#example) deserialize one.
241/// ```rust
242/// use candid::{encode_one, CandidType, Principal};
243/// use ic_websocket_cdk::send;
244/// use serde::Deserialize;
245///
246/// #[derive(CandidType, Deserialize)]
247/// struct MyMessage {
248///     some_field: String,
249/// }
250///
251/// // obtained when the on_open callback was fired
252/// let my_client_principal = Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap();
253///
254/// let my_message = MyMessage {
255///     some_field: "Hello, World!".to_string(),
256/// };
257///
258/// let msg_bytes = encode_one(&my_message).unwrap();
259/// send(my_client_principal, msg_bytes);
260/// ```
261pub fn send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterSendResult {
262    let client_key = get_client_key_from_principal(&client_principal)?;
263    _ws_send(&client_key, msg_bytes, false)
264}
265
266/// Closes the connection with the client.
267///
268/// This function **must not** be called in the `on_close` callback.
269pub fn close(client_principal: ClientPrincipal) -> CanisterCloseResult {
270    let client_key = get_client_key_from_principal(&client_principal)?;
271
272    remove_client(&client_key, Some(CloseMessageReason::ClosedByApplication));
273
274    Ok(())
275}
276
277/// Resets the internal state of the IC WebSocket CDK.
278///
279/// **Note:** You should only call this function in tests.
280pub fn wipe() {
281    reset_internal_state();
282
283    custom_print!("Internal state has been wiped!");
284}