ic_websocket_cdk/lib.rs
1use candid::{CandidType, Principal};
2use errors::WsError;
3
4use ic_cdk::api::msg_caller;
5use serde::Deserialize;
6
7mod errors;
8mod state;
9mod tests;
10mod timers;
11pub mod types;
12mod utils;
13
14use state::*;
15use timers::*;
16#[allow(deprecated)]
17pub use types::CanisterWsSendResult;
18use types::*;
19pub use types::{
20 CanisterCloseResult, CanisterSendResult, CanisterWsCloseArguments, CanisterWsCloseResult,
21 CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments,
22 CanisterWsMessageResult, CanisterWsOpenArguments, CanisterWsOpenResult, ClientPrincipal,
23 OnCloseCallbackArgs, OnMessageCallbackArgs, OnOpenCallbackArgs, WsHandlers, WsInitParams,
24};
25
/// Label used when building the certification tree.
const LABEL_WEBSOCKET: &[u8] = b"websocket";

/// Default cap on how many messages [ws_get_messages] hands back in a single poll.
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 50;
/// Default period between two acknowledgement rounds sent to clients: 5 minutes, in milliseconds.
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 5 * 60 * 1_000;
/// Maximum communication latency tolerated between a client and the canister: 30 seconds, in milliseconds.
const COMMUNICATION_LATENCY_BOUND_MS: u64 = 30 * 1_000;
/// How long a client has to answer an acknowledgement with a keep alive.
/// It is twice the latency bound, which covers one round trip.
const CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = COMMUNICATION_LATENCY_BOUND_MS * 2;
/// [CLIENT_KEEP_ALIVE_TIMEOUT_MS] converted to nanoseconds.
const CLIENT_KEEP_ALIVE_TIMEOUT_NS: u64 = CLIENT_KEEP_ALIVE_TIMEOUT_MS * 1_000_000;

/// Nonce assigned to the first outgoing message.
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
/// Sequence number expected on the first message received from a client.
///
/// It starts at `1`, not `0`, because the client increments its counter *before* sending.
const INITIAL_CLIENT_SEQUENCE_NUM: u64 = 1;
/// Sequence number assigned to the first outgoing message.
const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0;

/// How many messages are deleted from the outgoing messages queue each time a new message is added.
const MESSAGES_TO_DELETE_COUNT: usize = 5;
50
51/// Initialize the CDK.
52///
53/// **Note**: Restarts the acknowledgement timers under the hood.
54///
55/// # Traps
56/// If the parameters are invalid.
57pub fn init(params: WsInitParams) {
58 // check if the parameters are valid
59 params.check_validity();
60
61 // set the handlers specified by the canister that the CDK uses to manage the IC WebSocket connection
62 set_params(params.clone());
63
64 // cancel possibly running timers
65 cancel_timers();
66
67 // schedule a timer that will send an acknowledgement message to clients
68 schedule_send_ack_to_clients();
69}
70
71/// Handles the WS connection open event sent by the client and relayed by the Gateway.
72pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
73 let caller = msg_caller();
74 // anonymous clients cannot open a connection
75 caller
76 .ne(&Principal::anonymous())
77 .then_some(())
78 .ok_or_else(|| WsError::AnonymousPrincipalNotAllowed.to_string())?;
79
80 let client_key = ClientKey::new(caller, args.client_nonce);
81 // check if client is not registered yet
82 // by swapping the result of the check_registered_client_exists function
83 check_registered_client_exists(&client_key).map_or(Ok(()), |_| {
84 WsError::ClientKeyAlreadyConnected {
85 client_key: &client_key,
86 }
87 .to_string_result()
88 })?;
89
90 // check if there's a client already registered with the same principal
91 // and remove it if there is
92 match get_client_key_from_principal(&client_key.client_principal) {
93 Err(_) => {
94 // Do nothing
95 },
96 Ok(old_client_key) => {
97 remove_client(&old_client_key, None);
98 },
99 };
100
101 // initialize client maps
102 let new_client = RegisteredClient::new(args.gateway_principal);
103 add_client(client_key.clone(), new_client);
104
105 // send the open message
106 let open_message = CanisterOpenMessageContent {
107 client_key: client_key.clone(),
108 };
109 let message = WebsocketServiceMessageContent::OpenMessage(open_message);
110 send_service_message_to_client(&client_key, &message)?;
111
112 // call the on_open handler initialized in init()
113 get_handlers_from_params().call_on_open(OnOpenCallbackArgs {
114 client_principal: client_key.client_principal,
115 });
116
117 Ok(())
118}
119
120/// Handles the WS connection close event received from the WS Gateway.
121///
122/// If you want to close the connection with the client in your logic,
123/// use the [close] function instead.
124pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
125 let gateway_principal = msg_caller();
126
127 // check if the gateway is registered
128 check_is_gateway_registered(&gateway_principal)?;
129
130 // check if client registered itself by calling ws_open
131 check_registered_client_exists(&args.client_key)?;
132
133 // check if the client is registered to the gateway that is closing the connection
134 check_client_registered_to_gateway(&args.client_key, &gateway_principal)?;
135
136 remove_client(&args.client_key, None);
137
138 Ok(())
139}
140
141/// Handles the WS messages received either directly from the client or relayed by the WS Gateway.
142///
143/// The second argument is only needed to expose the type of the message on the canister Candid interface and get automatic types generation on the client side.
144/// This way, on the client you have the same types and you don't have to care about serializing and deserializing the messages sent through IC WebSocket.
145///
146/// # Example
147/// ```rust
148/// use ic_cdk::{update};
149/// use candid::{CandidType};
150/// use ic_websocket_cdk::{CanisterWsMessageArguments, CanisterWsMessageResult};
151/// use serde::Deserialize;
152///
153/// #[derive(CandidType, Deserialize)]
154/// struct MyMessage {
155/// some_field: String,
156/// }
157///
158/// // method called by the WS Gateway to send a message of type GatewayMessage to the canister
159/// #[update]
160/// fn ws_message(
161/// args: CanisterWsMessageArguments,
162/// msg_type: Option<MyMessage>,
163/// ) -> CanisterWsMessageResult {
164/// ic_websocket_cdk::ws_message(args, msg_type)
165/// }
166/// ```
167pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
168 args: CanisterWsMessageArguments,
169 _message_type: Option<T>,
170) -> CanisterWsMessageResult {
171 let client_principal = msg_caller();
172 let registered_client_key = get_client_key_from_principal(&client_principal)?;
173
174 let WebsocketMessage {
175 client_key,
176 sequence_num,
177 timestamp: _,
178 is_service_message,
179 content,
180 } = args.msg;
181
182 // check if the client key is correct
183 client_key
184 .eq(®istered_client_key)
185 .then_some(())
186 .ok_or_else(|| {
187 WsError::ClientKeyMessageMismatch {
188 client_key: &client_key,
189 }
190 .to_string()
191 })?;
192
193 let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?;
194
195 // check if the incoming message has the expected sequence number
196 sequence_num
197 .eq(&expected_sequence_num)
198 .then_some(())
199 .ok_or_else(|| {
200 remove_client(&client_key, Some(CloseMessageReason::WrongSequenceNumber));
201
202 WsError::IncomingSequenceNumberWrong {
203 expected_sequence_num,
204 actual_sequence_num: sequence_num,
205 }
206 .to_string()
207 })?;
208 // increase the expected sequence number by 1
209 increment_expected_incoming_message_from_client_num(&client_key)?;
210
211 if is_service_message {
212 return handle_received_service_message(&client_key, &content);
213 }
214
215 // call the on_message handler initialized in init()
216 get_handlers_from_params().call_on_message(OnMessageCallbackArgs {
217 client_principal,
218 message: content,
219 });
220 Ok(())
221}
222
223/// Returns messages to the WS Gateway in response of a polling iteration.
224pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
225 let gateway_principal = msg_caller();
226 if !is_registered_gateway(&gateway_principal) {
227 return get_cert_messages_empty();
228 }
229
230 get_cert_messages(&gateway_principal, args.nonce)
231}
232
233/// Sends a message to the client. The message must already be serialized **using Candid**.
234/// Use [candid::encode_one] to serialize the message.
235///
236/// Under the hood, the message is certified and added to the queue of messages
237/// that the WS Gateway will poll in the next iteration.
238///
239/// # Example
240/// This example is the serialize equivalent of the [OnMessageCallbackArgs's example](struct.OnMessageCallbackArgs.html#example) deserialize one.
241/// ```rust
242/// use candid::{encode_one, CandidType, Principal};
243/// use ic_websocket_cdk::send;
244/// use serde::Deserialize;
245///
246/// #[derive(CandidType, Deserialize)]
247/// struct MyMessage {
248/// some_field: String,
249/// }
250///
251/// // obtained when the on_open callback was fired
252/// let my_client_principal = Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap();
253///
254/// let my_message = MyMessage {
255/// some_field: "Hello, World!".to_string(),
256/// };
257///
258/// let msg_bytes = encode_one(&my_message).unwrap();
259/// send(my_client_principal, msg_bytes);
260/// ```
261pub fn send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterSendResult {
262 let client_key = get_client_key_from_principal(&client_principal)?;
263 _ws_send(&client_key, msg_bytes, false)
264}
265
266/// Closes the connection with the client.
267///
268/// This function **must not** be called in the `on_close` callback.
269pub fn close(client_principal: ClientPrincipal) -> CanisterCloseResult {
270 let client_key = get_client_key_from_principal(&client_principal)?;
271
272 remove_client(&client_key, Some(CloseMessageReason::ClosedByApplication));
273
274 Ok(())
275}
276
277/// Resets the internal state of the IC WebSocket CDK.
278///
279/// **Note:** You should only call this function in tests.
280pub fn wipe() {
281 reset_internal_state();
282
283 custom_print!("Internal state has been wiped!");
284}