1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291
use candid::{CandidType, Principal};
use errors::WsError;
use ic_cdk::api::caller;
use serde::Deserialize;
mod errors;
mod state;
mod tests;
mod timers;
pub mod types;
mod utils;
use state::*;
use timers::*;
#[allow(deprecated)]
pub use types::CanisterWsSendResult;
use types::*;
pub use types::{
CanisterCloseResult, CanisterSendResult, CanisterWsCloseArguments, CanisterWsCloseResult,
CanisterWsGetMessagesArguments, CanisterWsGetMessagesResult, CanisterWsMessageArguments,
CanisterWsMessageResult, CanisterWsOpenArguments, CanisterWsOpenResult, ClientPrincipal,
OnCloseCallbackArgs, OnMessageCallbackArgs, OnOpenCallbackArgs, WsHandlers, WsInitParams,
};
/// Label under which the WebSocket data is placed in the certification tree.
const LABEL_WEBSOCKET: &[u8] = b"websocket";
/// Upper bound on how many messages a single [ws_get_messages] poll returns by default.
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 50;
/// Default period between two acknowledgement rounds sent to clients (5 minutes).
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 5 * 60 * 1_000;
/// Upper bound on the communication latency tolerated between a client and the canister (30 seconds).
const COMMUNICATION_LATENCY_BOUND_MS: u64 = 30 * 1_000;
/// How long to wait for a client's keep alive after an acknowledgement was sent to it:
/// twice the latency bound, i.e. one full round trip.
const CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = COMMUNICATION_LATENCY_BOUND_MS * 2;
/// [CLIENT_KEEP_ALIVE_TIMEOUT_MS] expressed in nanoseconds (1 ms = 1_000_000 ns).
const CLIENT_KEEP_ALIVE_TIMEOUT_NS: u64 = 1_000_000 * CLIENT_KEEP_ALIVE_TIMEOUT_MS;
/// Nonce assigned to the first outgoing message.
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
/// Sequence number expected on the first message received from a client.
///
/// It starts at `1` rather than `0` because clients increment their sequence number
/// *before* sending each message.
const INITIAL_CLIENT_SEQUENCE_NUM: u64 = 1;
/// Sequence number assigned to the first outgoing message.
const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0;
/// How many messages are pruned from the outgoing queue each time a new message is enqueued.
const MESSAGES_TO_DELETE_COUNT: usize = 5;
/// Initialize the CDK.
///
/// Validates `params`, stores them so the CDK can later invoke the canister's
/// handlers, and (re)starts the acknowledgement timer.
///
/// **Note**: Restarts the acknowledgement timers under the hood.
///
/// # Traps
/// If the parameters are invalid.
pub fn init(params: WsInitParams) {
    // check if the parameters are valid (traps otherwise)
    params.check_validity();
    // store the parameters (including the handlers) used by the CDK to manage the
    // IC WebSocket connection; `params` is not used afterwards, so move it instead of cloning
    set_params(params);
    // cancel possibly running timers, so that calling `init` again does not leave
    // a previously scheduled timer running alongside the new one
    cancel_timers();
    // schedule a timer that will send an acknowledgement message to clients
    schedule_send_ack_to_clients();
}
/// Handles the WS connection open event sent by the client and relayed by the Gateway.
///
/// Registers the caller as a new client (keyed by its principal and the nonce in `args`),
/// replacing any connection previously registered for the same principal, sends the
/// open service message to the client and finally invokes the `on_open` handler.
///
/// # Errors
/// - the caller is the anonymous principal;
/// - a client with the same client key is already connected;
/// - the open message could not be sent.
pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
    let client_principal = caller();

    // Anonymous principals are not allowed to open a connection.
    if client_principal == Principal::anonymous() {
        return Err(WsError::AnonymousPrincipalNotAllowed.to_string());
    }

    let client_key = ClientKey::new(client_principal, args.client_nonce);

    // The existence check succeeds only when the key is already registered,
    // which is exactly the case that must be rejected here.
    if check_registered_client_exists(&client_key).is_ok() {
        return WsError::ClientKeyAlreadyConnected {
            client_key: &client_key,
        }
        .to_string_result();
    }

    // A principal holds at most one connection: drop the previous one, if any.
    if let Ok(old_client_key) = get_client_key_from_principal(&client_key.client_principal) {
        remove_client(&old_client_key, None);
    }

    // Register the client together with the gateway it connected through.
    let registered_client = RegisteredClient::new(args.gateway_principal);
    add_client(client_key.clone(), registered_client);

    // Notify the client that the connection has been opened.
    let open_message = WebsocketServiceMessageContent::OpenMessage(CanisterOpenMessageContent {
        client_key: client_key.clone(),
    });
    send_service_message_to_client(&client_key, &open_message)?;

    // Invoke the canister's on_open handler set in init().
    get_handlers_from_params().call_on_open(OnOpenCallbackArgs {
        client_principal: client_key.client_principal,
    });

    Ok(())
}
/// Handles the WS connection close event received from the WS Gateway.
///
/// The caller must be a registered gateway, the client in `args` must be registered,
/// and it must be registered to that same gateway; only then is the client removed.
///
/// If you want to close the connection with the client in your logic,
/// use the [close] function instead.
pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
    let gateway = caller();
    let client_key = &args.client_key;

    // only a registered gateway may report a closed connection
    check_is_gateway_registered(&gateway)?;
    // the client must have opened its connection through ws_open
    check_registered_client_exists(client_key)?;
    // and that connection must go through the gateway reporting the close
    check_client_registered_to_gateway(client_key, &gateway)?;

    remove_client(client_key, None);
    Ok(())
}
/// Handles the WS messages received either directly from the client or relayed by the WS Gateway.
///
/// The caller must have registered itself via `ws_open`, the client key carried by the
/// message must match the registered one, and the message must carry the expected
/// sequence number. A wrong sequence number also removes the client. Service messages
/// are handled internally; all other messages are forwarded to the `on_message` handler.
///
/// The second argument is only needed to expose the type of the message on the canister Candid interface and get automatic types generation on the client side.
/// This way, on the client you have the same types and you don't have to care about serializing and deserializing the messages sent through IC WebSocket.
///
/// # Example
/// ```rust
/// use ic_cdk_macros::*;
/// use candid::CandidType;
/// use ic_websocket_cdk::{CanisterWsMessageArguments, CanisterWsMessageResult};
/// use serde::Deserialize;
///
/// #[derive(CandidType, Deserialize)]
/// struct MyMessage {
///     some_field: String,
/// }
///
/// // method called to send a message of type MyMessage to the canister
/// #[update]
/// fn ws_message(
///     args: CanisterWsMessageArguments,
///     msg_type: Option<MyMessage>,
/// ) -> CanisterWsMessageResult {
///     ic_websocket_cdk::ws_message(args, msg_type)
/// }
/// ```
pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
    args: CanisterWsMessageArguments,
    _message_type: Option<T>,
) -> CanisterWsMessageResult {
    let client_principal = caller();
    // the caller must have registered a connection via ws_open
    let registered_client_key = get_client_key_from_principal(&client_principal)?;
    let WebsocketMessage {
        client_key,
        sequence_num,
        timestamp: _,
        is_service_message,
        content,
    } = args.msg;

    // check if the client key carried by the message is the one registered for the caller
    if client_key != registered_client_key {
        return Err(WsError::ClientKeyMessageMismatch {
            client_key: &client_key,
        }
        .to_string());
    }

    let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?;
    // check if the incoming message has the expected sequence number;
    // an out-of-order message closes the connection before the error is returned
    if sequence_num != expected_sequence_num {
        remove_client(&client_key, Some(CloseMessageReason::WrongSequenceNumber));
        return Err(WsError::IncomingSequenceNumberWrong {
            expected_sequence_num,
            actual_sequence_num: sequence_num,
        }
        .to_string());
    }

    // increase the expected sequence number by 1
    increment_expected_incoming_message_from_client_num(&client_key)?;

    if is_service_message {
        return handle_received_service_message(&client_key, &content);
    }

    // call the on_message handler initialized in init()
    get_handlers_from_params().call_on_message(OnMessageCallbackArgs {
        client_principal,
        message: content,
    });
    Ok(())
}
/// Returns messages to the WS Gateway in response to a polling iteration.
///
/// Callers that are not registered gateways receive the empty response produced by
/// `get_cert_messages_empty` instead of an error.
pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
    let gateway_principal = caller();
    if is_registered_gateway(&gateway_principal) {
        // serve the messages starting from the nonce requested by the gateway
        get_cert_messages(&gateway_principal, args.nonce)
    } else {
        get_cert_messages_empty()
    }
}
/// Sends a message to the client identified by `client_principal`.
/// The message must already be serialized **using Candid**; use [candid::encode_one] for that.
///
/// Under the hood, the message is certified and added to the queue of messages
/// that the WS Gateway will poll in the next iteration.
///
/// # Errors
/// Returns an error if no client key is registered for `client_principal`,
/// or if enqueuing the message fails.
///
/// # Example
/// This example is the serialize equivalent of the [OnMessageCallbackArgs's example](struct.OnMessageCallbackArgs.html#example) deserialize one.
/// ```rust
/// use candid::{encode_one, CandidType, Principal};
/// use ic_websocket_cdk::send;
/// use serde::Deserialize;
///
/// #[derive(CandidType, Deserialize)]
/// struct MyMessage {
///     some_field: String,
/// }
///
/// // obtained when the on_open callback was fired
/// let my_client_principal = Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe").unwrap();
///
/// let my_message = MyMessage {
///     some_field: "Hello, World!".to_string(),
/// };
///
/// let msg_bytes = encode_one(&my_message).unwrap();
/// send(my_client_principal, msg_bytes);
/// ```
pub fn send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterSendResult {
    // resolve the principal to the key of its currently registered connection
    let key = get_client_key_from_principal(&client_principal)?;
    // NOTE(review): `false` presumably flags this as an application (non-service) message — confirm in `_ws_send`
    _ws_send(&key, msg_bytes, false)
}
/// Deprecated: use [send] instead.
///
/// Forwards both arguments unchanged to [send] and returns its result.
#[deprecated(since = "0.3.2", note = "use `ic_websocket_cdk::send` instead")]
// the `CanisterWsSendResult` return type is itself a deprecated re-export
#[allow(deprecated)]
pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterWsSendResult {
    send(client_principal, msg_bytes)
}
/// Closes the connection with the client identified by `client_principal`.
///
/// The client is removed with [`CloseMessageReason::ClosedByApplication`] as the reason.
///
/// This function **must not** be called in the `on_close` callback.
///
/// # Errors
/// Returns an error if no client key is registered for `client_principal`.
pub fn close(client_principal: ClientPrincipal) -> CanisterCloseResult {
    let key = get_client_key_from_principal(&client_principal)?;
    remove_client(&key, Some(CloseMessageReason::ClosedByApplication));
    Ok(())
}
/// Resets the internal state of the IC WebSocket CDK and prints a confirmation line.
///
/// **Note:** You should only call this function in tests.
pub fn wipe() {
    // clear everything the CDK keeps internally
    reset_internal_state();
    // confirm the reset in the canister output
    custom_print!("Internal state has been wiped!");
}