use candid::{decode_one, encode_one, CandidType, Principal};
#[cfg(not(test))]
use ic_cdk::api::time;
use ic_cdk::api::{caller, data_certificate, set_certified_data};
use ic_cdk::trap;
use ic_cdk_timers::{clear_timer, set_timer, set_timer_interval, TimerId};
use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree};
use serde::{Deserialize, Serialize};
use serde_cbor::Serializer;
use sha2::{Digest, Sha256};
use std::fmt;
use std::panic;
use std::rc::Rc;
use std::time::Duration;
use std::{
cell::RefCell,
collections::VecDeque,
collections::{HashMap, HashSet},
convert::AsRef,
};
mod logger;
mod tests;
/// Label under which all outgoing message hashes live in the certified tree.
const LABEL_WEBSOCKET: &[u8] = b"websocket";
/// Default cap on how many messages a single `ws_get_messages` call returns.
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10;
/// Default period (ms) between two rounds of ack messages sent to clients.
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 60_000;
/// Default time (ms) after an ack round within which a client must send a keep-alive.
const DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = 10_000;
/// Starting value of the global outgoing message nonce (also restored on wipe).
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
/// Sequence number expected on the first message received from a client.
const INITIAL_CLIENT_SEQUENCE_NUM: u64 = 1;
/// Initial value of a client's outgoing counter; the first message sent therefore
/// carries sequence number 1.
const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0;
/// Principal of a client talking to the canister through the gateway.
pub type ClientPrincipal = Principal;

/// Identifies one client connection: the caller's principal plus the nonce the
/// client supplied when opening the connection (see `ws_open`).
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)]
pub(crate) struct ClientKey {
    pub client_principal: ClientPrincipal,
    pub client_nonce: u64,
}

impl ClientKey {
    /// Builds a key from its principal and nonce components.
    pub(crate) fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self {
        ClientKey { client_principal, client_nonce }
    }
}

impl fmt::Display for ClientKey {
    /// Renders the key as `<principal>_<nonce>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { client_principal, client_nonce } = self;
        write!(f, "{client_principal}_{client_nonce}")
    }
}
/// Result of `ws_open`; `Err` carries a human-readable reason.
pub type CanisterWsOpenResult = Result<(), String>;
/// Result of `ws_close`; `Err` carries a human-readable reason.
pub type CanisterWsCloseResult = Result<(), String>;
/// Result of `ws_message`; `Err` carries a human-readable reason.
pub type CanisterWsMessageResult = Result<(), String>;
/// Result of `ws_get_messages`: the certified batch of messages for the gateway.
pub type CanisterWsGetMessagesResult = Result<CanisterOutputCertifiedMessages, String>;
/// Result of `ws_send`; `Err` carries a human-readable reason.
pub type CanisterWsSendResult = Result<(), String>;
/// Arguments of `ws_open`.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsOpenArguments {
    /// Nonce chosen by the client; combined with the caller principal into a `ClientKey`.
    pub(crate) client_nonce: u64,
}
/// Arguments of `ws_close` (only accepted from the registered gateway).
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsCloseArguments {
    /// Key of the connection to close.
    pub(crate) client_key: ClientKey,
}
/// Arguments of `ws_message`.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsMessageArguments {
    /// The message sent by the client.
    pub(crate) msg: WebsocketMessage,
}
/// Arguments of `ws_get_messages` (only accepted from the registered gateway).
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsGetMessagesArguments {
    /// Polling position: `0` asks for the most recent messages; any other value asks
    /// for messages whose key is not smaller than the key built from this nonce.
    pub(crate) nonce: u64,
}
/// A message exchanged between the canister and a client.
///
/// Outgoing messages are CBOR-serialized (see `cbor_serialize`) and the hash of
/// those bytes is certified, so reordering fields would change the certified bytes.
#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub(crate) struct WebsocketMessage {
    // Connection this message belongs to.
    pub client_key: ClientKey,
    // Per-client, per-direction sequence number.
    pub sequence_num: u64,
    // Set from `get_current_time()` when the canister sends; presumably nanoseconds
    // (the keep-alive check scales ms by 1_000_000) — TODO confirm.
    pub timestamp: u64,
    // `true` for protocol messages (open / ack / keep-alive), `false` for app payloads.
    pub is_service_message: bool,
    // App payload, or a Candid-encoded `WebsocketServiceMessageContent`.
    #[serde(with = "serde_bytes")]
    pub content: Vec<u8>,
}
impl WebsocketMessage {
    /// Encodes the message as self-described CBOR (via serde_cbor's `self_describe`
    /// marker followed by the serialized struct).
    ///
    /// # Errors
    /// Returns the serializer's error message if writing either part fails.
    fn cbor_serialize(&self) -> Result<Vec<u8>, String> {
        let mut bytes: Vec<u8> = Vec::new();
        {
            let mut ser = Serializer::new(&mut bytes);
            ser.self_describe()
                .and_then(|_| self.serialize(&mut ser))
                .map_err(|err| err.to_string())?;
        }
        Ok(bytes)
    }
}
/// A message queued for the gateway, as returned by `ws_get_messages`.
#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputMessage {
    /// Connection the message is addressed to.
    pub(crate) client_key: ClientKey,
    /// Key under which the message hash is stored in the certified tree
    /// (built by `get_message_for_gateway_key`).
    pub(crate) key: String,
    /// CBOR-serialized `WebsocketMessage`.
    #[serde(with = "serde_bytes")]
    pub(crate) content: Vec<u8>,
}
/// Response of `ws_get_messages`: a batch of messages plus the data needed to verify them.
#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputCertifiedMessages {
    /// The messages, in queue order.
    pub(crate) messages: Vec<CanisterOutputMessage>,
    /// Data certificate (`data_certificate()`); empty when `messages` is empty.
    #[serde(with = "serde_bytes")]
    pub(crate) cert: Vec<u8>,
    /// CBOR-serialized witness, labeled `LABEL_WEBSOCKET`, covering the keys of
    /// `messages`; empty when `messages` is empty.
    #[serde(with = "serde_bytes")]
    pub(crate) tree: Vec<u8>,
}
/// The one gateway that is allowed to relay messages for this canister.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct RegisteredGateway {
    /// Principal the gateway uses when calling the canister.
    gateway_principal: Principal,
}

impl RegisteredGateway {
    /// Wraps the gateway's principal.
    fn new(gateway_principal: Principal) -> Self {
        RegisteredGateway { gateway_principal }
    }
}
/// Current time for unit-test builds: a fixed `0`, presumably because the IC system
/// API (`ic_cdk::api::time`) is not available outside a canister.
#[cfg(test)]
fn get_current_time() -> u64 {
    0
}

/// Current time as reported by `ic_cdk::api::time`.
#[cfg(not(test))]
fn get_current_time() -> u64 {
    time()
}
/// Metadata kept for every registered client connection.
#[derive(Clone, Debug, Eq, PartialEq)]
struct RegisteredClient {
    /// When the client last proved it was alive (initially: registration time).
    last_keep_alive_timestamp: u64,
}

impl RegisteredClient {
    /// Metadata for a client registering right now.
    fn new() -> Self {
        let now = get_current_time();
        Self { last_keep_alive_timestamp: now }
    }

    /// Returns the timestamp of the latest keep-alive.
    fn get_last_keep_alive_timestamp(&self) -> u64 {
        let Self { last_keep_alive_timestamp } = self;
        *last_keep_alive_timestamp
    }

    /// Records that the client is alive as of now.
    fn update_last_keep_alive_timestamp(&mut self) {
        let now = get_current_time();
        self.last_keep_alive_timestamp = now;
    }
}
// Canister-global state. Entries wrapped in `Rc` are accessed elsewhere via
// `.with(Rc::clone)` so the handle can be used outside the `with` closure.
thread_local! {
    // Every open connection, keyed by client key, with its keep-alive metadata.
    static REGISTERED_CLIENTS: Rc<RefCell<HashMap<ClientKey, RegisteredClient>>> = Rc::new(RefCell::new(HashMap::new()));
    // Principal -> key of that principal's most recently opened connection
    // (`insert_client` overwrites any previous entry).
    static CURRENT_CLIENT_KEY_MAP: RefCell<HashMap<ClientPrincipal, ClientKey>> = RefCell::new(HashMap::new());
    // Clients that were sent an ack; inspected by the keep-alive timer callback.
    static CLIENTS_WAITING_FOR_KEEP_ALIVE: Rc<RefCell<HashSet<ClientKey>>> = Rc::new(RefCell::new(HashSet::new()));
    // Per client: sequence number of the last message the canister sent to it.
    static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell<HashMap<ClientKey, u64>> = RefCell::new(HashMap::new());
    // Per client: sequence number expected on the next message received from it.
    static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell<HashMap<ClientKey, u64>> = RefCell::new(HashMap::new());
    // Message key -> SHA-256 of the serialized message; its labeled root hash is the certified data.
    static CERT_TREE: RefCell<RbTree<String, ICHash>> = RefCell::new(RbTree::new());
    // Gateway registered in `init`; `None` until then.
    static REGISTERED_GATEWAY: RefCell<Option<RegisteredGateway>> = RefCell::new(None);
    // Outgoing messages in increasing key order (appended with increasing nonces).
    static MESSAGES_FOR_GATEWAY: RefCell<VecDeque<CanisterOutputMessage>> = RefCell::new(VecDeque::new());
    // Nonce used to build the key of the next outgoing message.
    static OUTGOING_MESSAGE_NONCE: RefCell<u64> = RefCell::new(INITIAL_OUTGOING_MESSAGE_NONCE);
    // Parameters passed to `init`.
    static PARAMS: RefCell<WsInitParams> = RefCell::new(WsInitParams::default());
    // Id of the periodic ack timer, so it can be cancelled on re-init.
    static ACK_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
    // Id of the most recently scheduled keep-alive check timer.
    static KEEP_ALIVE_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
}
fn reset_internal_state() {
let client_keys_to_remove: Vec<ClientKey> = REGISTERED_CLIENTS.with(|state| {
let map = state.borrow();
map.keys().cloned().collect()
});
for client_key in client_keys_to_remove {
remove_client(&client_key);
}
CURRENT_CLIENT_KEY_MAP.with(|map| {
map.borrow_mut().clear();
});
CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
set.borrow_mut().clear();
});
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
map.borrow_mut().clear();
});
INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
map.borrow_mut().clear();
});
CERT_TREE.with(|t| {
t.replace(RbTree::new());
});
MESSAGES_FOR_GATEWAY.with(|m| *m.borrow_mut() = VecDeque::new());
OUTGOING_MESSAGE_NONCE.with(|next_id| next_id.replace(INITIAL_OUTGOING_MESSAGE_NONCE));
}
pub fn wipe() {
reset_internal_state();
custom_print!("Internal state has been wiped!");
}
/// Returns the nonce under which the next outgoing message will be keyed.
fn get_outgoing_message_nonce() -> u64 {
    // `u64` is `Copy`: dereference the borrow instead of cloning it.
    OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow())
}

/// Advances the global outgoing message nonce by one.
fn increment_outgoing_message_nonce() {
    OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow_mut() += 1);
}
/// Stores `new_client` under `client_key` and makes that key the current connection
/// of its principal (overwriting any earlier mapping for the principal).
fn insert_client(client_key: ClientKey, new_client: RegisteredClient) {
    let principal = client_key.client_principal;
    CURRENT_CLIENT_KEY_MAP.with(|current| {
        current.borrow_mut().insert(principal, client_key.clone());
    });
    REGISTERED_CLIENTS.with(|clients| {
        clients.borrow_mut().insert(client_key, new_client);
    });
}

/// Tells whether `client_key` identifies a registered connection.
fn is_client_registered(client_key: &ClientKey) -> bool {
    REGISTERED_CLIENTS.with(|clients| {
        let clients = clients.borrow();
        clients.contains_key(client_key)
    })
}
/// Looks up the key of the connection currently associated with `client_principal`.
///
/// # Errors
/// Returns an error if the principal has no open connection.
fn get_client_key_from_principal(client_principal: &ClientPrincipal) -> Result<ClientKey, String> {
    CURRENT_CLIENT_KEY_MAP.with(|map| {
        map.borrow()
            .get(client_principal)
            .cloned()
            // Build the error lazily: it is only needed on the failure path.
            .ok_or_else(|| {
                format!(
                    "client with principal {} doesn't have an open connection",
                    client_principal
                )
            })
    })
}

/// Succeeds only if `client_key` identifies a registered connection.
///
/// # Errors
/// Returns an error naming the key when it is not registered.
fn check_registered_client(client_key: &ClientKey) -> Result<(), String> {
    if is_client_registered(client_key) {
        Ok(())
    } else {
        Err(format!(
            "client with key {} doesn't have an open connection",
            client_key
        ))
    }
}
/// Marks `client_key` as owing a keep-alive reply; the keep-alive timer checks it later.
fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) {
    let key = client_key.clone();
    CLIENTS_WAITING_FOR_KEEP_ALIVE.with(move |waiting| {
        waiting.borrow_mut().insert(key);
    });
}
/// Parses `gateway_principal` and stores it as the registered gateway.
///
/// # Panics
/// If `gateway_principal` is not a valid textual principal.
fn initialize_registered_gateway(gateway_principal: &str) {
    let parsed = Principal::from_text(gateway_principal).expect("invalid gateway principal");
    REGISTERED_GATEWAY.with(|slot| {
        slot.replace(Some(RegisteredGateway::new(parsed)));
    });
}

/// Returns the principal of the registered gateway.
///
/// # Panics
/// If no gateway has been registered yet (`init` does that).
fn get_registered_gateway_principal() -> Principal {
    let gateway = REGISTERED_GATEWAY.with(|slot| *slot.borrow());
    gateway.expect("gateway should be initialized").gateway_principal
}
/// Starts the outgoing (canister -> client) counter of a new client at
/// `INITIAL_CANISTER_SEQUENCE_NUM`.
fn init_outgoing_message_to_client_num(client_key: ClientKey) {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut()
            .insert(client_key, INITIAL_CANISTER_SEQUENCE_NUM);
    });
}

/// Returns the sequence number of the last message sent to the client
/// (`INITIAL_CANISTER_SEQUENCE_NUM` while nothing has been sent).
///
/// # Errors
/// If the client's counter was never initialized.
fn get_outgoing_message_to_client_num(client_key: &ClientKey) -> Result<u64, String> {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow()
            .get(client_key)
            .copied()
            // Allocate the error only when it is actually returned.
            .ok_or_else(|| {
                String::from("outgoing message to client num not initialized for client")
            })
    })
}

/// Advances the client's outgoing counter by one, in place (single map lookup).
///
/// # Errors
/// If the client's counter was never initialized.
fn increment_outgoing_message_to_client_num(client_key: &ClientKey) -> Result<(), String> {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        let mut map = map.borrow_mut();
        let num = map.get_mut(client_key).ok_or_else(|| {
            String::from("outgoing message to client num not initialized for client")
        })?;
        *num += 1;
        Ok(())
    })
}

/// Sets the sequence number expected on the first message from a new client to
/// `INITIAL_CLIENT_SEQUENCE_NUM`.
fn init_expected_incoming_message_from_client_num(client_key: ClientKey) {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut()
            .insert(client_key, INITIAL_CLIENT_SEQUENCE_NUM);
    });
}

/// Returns the sequence number expected on the next message from the client.
///
/// # Errors
/// If the client's counter was never initialized.
fn get_expected_incoming_message_from_client_num(client_key: &ClientKey) -> Result<u64, String> {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow()
            .get(client_key)
            .copied()
            .ok_or_else(|| {
                String::from("expected incoming message num not initialized for client")
            })
    })
}

/// Advances the client's expected incoming sequence number by one, in place.
///
/// # Errors
/// If the client's counter was never initialized.
fn increment_expected_incoming_message_from_client_num(
    client_key: &ClientKey,
) -> Result<(), String> {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        let mut map = map.borrow_mut();
        let num = map.get_mut(client_key).ok_or_else(|| {
            String::from("expected incoming message num not initialized for client")
        })?;
        *num += 1;
        Ok(())
    })
}
/// Registers a new connection and initializes both of its sequence counters.
fn add_client(client_key: ClientKey, new_client: RegisteredClient) {
    insert_client(client_key.clone(), new_client);
    init_expected_incoming_message_from_client_num(client_key.clone());
    init_outgoing_message_to_client_num(client_key);
}

/// Drops all state tied to `client_key`, then invokes the application's `on_close`
/// handler with the client's principal.
fn remove_client(client_key: &ClientKey) {
    CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
        set.borrow_mut().remove(client_key);
    });
    CURRENT_CLIENT_KEY_MAP.with(|map| {
        let mut map = map.borrow_mut();
        // A principal that re-opened with a different nonce has had its mapping
        // overwritten by `insert_client`. Only drop the mapping when it still points
        // at the connection being removed; otherwise removing the stale connection
        // would orphan the principal's live one.
        if map.get(&client_key.client_principal) == Some(client_key) {
            map.remove(&client_key.client_principal);
        }
    });
    REGISTERED_CLIENTS.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    let handlers = get_handlers_from_params();
    handlers.call_on_close(OnCloseCallbackArgs {
        client_principal: client_key.client_principal,
    });
}
/// Builds the queue / cert-tree key of a message: `<gateway principal>_<nonce>`, with
/// the nonce left-padded with zeros to 20 digits so that string order follows
/// numeric nonce order.
fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> String {
    format!("{}_{:0>20}", gateway_principal, nonce)
}
/// Computes the half-open range `[start, end)` of `MESSAGES_FOR_GATEWAY` to hand to
/// the gateway, capped at `max_number_of_returned_messages` entries.
///
/// With `nonce == 0` and a non-empty queue, the most recent messages are selected;
/// otherwise the range starts at the first message whose key is not smaller than
/// the key built from `nonce`.
fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) {
    let limit = get_params().max_number_of_returned_messages;
    MESSAGES_FOR_GATEWAY.with(|queue| {
        let queue = queue.borrow();
        let len = queue.len();
        if nonce == 0 && len > 0 {
            return (len.saturating_sub(limit), len);
        }
        let lower_bound = get_message_for_gateway_key(gateway_principal, nonce);
        let start = queue.partition_point(|msg| msg.key < lower_bound);
        let end = start + (len - start).min(limit);
        (start, end)
    })
}

/// Clones the queued messages in `[start_index, end_index)`.
///
/// # Panics
/// If the range does not lie within the queue.
fn get_messages_for_gateway(start_index: usize, end_index: usize) -> Vec<CanisterOutputMessage> {
    MESSAGES_FOR_GATEWAY.with(|queue| {
        queue
            .borrow()
            .range(start_index..end_index)
            .cloned()
            .collect()
    })
}
/// Returns the gateway's messages starting at `nonce`, plus the certificate and the
/// serialized witness covering their keys. Both are empty when nothing is returned.
fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetMessagesResult {
    let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce);
    let messages = get_messages_for_gateway(start_index, end_index);
    let (cert, tree) = match (messages.first(), messages.last()) {
        (Some(first), Some(last)) => get_cert_for_range(&first.key, &last.key),
        _ => (Vec::new(), Vec::new()),
    };
    Ok(CanisterOutputCertifiedMessages {
        messages,
        cert,
        tree,
    })
}
/// Tells whether `principal` is the gateway registered during `init`.
fn is_registered_gateway(principal: Principal) -> bool {
    principal == get_registered_gateway_principal()
}

/// Succeeds only when `input_principal` is the registered gateway.
///
/// # Errors
/// Returns an error for any other principal.
fn check_is_registered_gateway(input_principal: Principal) -> Result<(), String> {
    if input_principal == get_registered_gateway_principal() {
        Ok(())
    } else {
        Err(String::from(
            "caller is not the gateway that has been registered during CDK initialization",
        ))
    }
}
/// Stores `SHA-256(value)` under `key` in the certified tree and sets the canister's
/// certified data to the new root hash labeled with `LABEL_WEBSOCKET`.
fn put_cert_for_message(key: String, value: &[u8]) {
    let root_hash = CERT_TREE.with(|tree| {
        let mut tree = tree.borrow_mut();
        // `key` is owned, so it can be moved into the tree without a clone.
        tree.insert(key, Sha256::digest(value).into());
        labeled_hash(LABEL_WEBSOCKET, &tree.root_hash())
    });
    set_certified_data(&root_hash);
}

/// Returns `(certificate, witness)` for the keys in `[first, last]`: the canister's
/// data certificate and the CBOR-serialized (self-described) witness labeled with
/// `LABEL_WEBSOCKET`.
///
/// # Panics
/// If no data certificate is available. NOTE(review): ic_cdk only provides one in
/// (non-replicated) query calls — confirm every caller runs in that context.
fn get_cert_for_range(first: &str, last: &str) -> (Vec<u8>, Vec<u8>) {
    CERT_TREE.with(|tree| {
        let tree = tree.borrow();
        let witness = tree.value_range(first.as_bytes(), last.as_bytes());
        let tree = labeled(LABEL_WEBSOCKET, witness);
        let mut data = vec![];
        let mut serializer = Serializer::new(&mut data);
        serializer
            .self_describe()
            .expect("writing the CBOR self-describe tag to a Vec should not fail");
        tree.serialize(&mut serializer)
            .expect("serializing the witness to a Vec should not fail");
        let cert = data_certificate()
            .expect("data certificate should be available (query call context)");
        (cert, data)
    })
}
/// Remembers the id of the periodic ack timer so it can be cancelled later.
fn put_ack_timer_id(timer_id: TimerId) {
    ACK_TIMER.with(|slot| *slot.borrow_mut() = Some(timer_id));
}

/// Cancels the stored ack timer, if any, and forgets its id.
fn reset_ack_timer() {
    let previous = ACK_TIMER.with(|slot| slot.borrow_mut().take());
    if let Some(timer_id) = previous {
        clear_timer(timer_id);
    }
}

/// Remembers the id of the latest keep-alive check timer.
fn put_keep_alive_timer_id(timer_id: TimerId) {
    KEEP_ALIVE_TIMER.with(|slot| *slot.borrow_mut() = Some(timer_id));
}

/// Cancels the stored keep-alive timer, if any, and forgets its id.
fn reset_keep_alive_timer() {
    let previous = KEEP_ALIVE_TIMER.with(|slot| slot.borrow_mut().take());
    if let Some(timer_id) = previous {
        clear_timer(timer_id);
    }
}

/// Cancels both the ack timer and the keep-alive timer.
fn reset_timers() {
    reset_ack_timer();
    reset_keep_alive_timer();
}
/// Replaces the stored init parameters.
fn set_params(params: WsInitParams) {
    PARAMS.with(|slot| {
        slot.replace(params);
    });
}

/// Returns a copy of the stored init parameters.
fn get_params() -> WsInitParams {
    PARAMS.with(|slot| slot.borrow().to_owned())
}

/// Returns a copy of the application handlers from the stored parameters.
fn get_handlers_from_params() -> WsHandlers {
    PARAMS.with(|slot| slot.borrow().get_handlers())
}
/// Service message sent to a client when its connection is opened (see `ws_open`).
#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
pub(crate) struct CanisterOpenMessageContent {
    /// Key assigned to the new connection.
    pub client_key: ClientKey,
}
/// Service message periodically sent to every client by the ack timer.
#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
pub(crate) struct CanisterAckMessageContent {
    /// Sequence number of the last message the canister received from the client.
    pub last_incoming_sequence_num: u64,
}
/// Service message a client sends in reply to an ack.
#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
pub(crate) struct ClientKeepAliveMessageContent {
    /// Presumably the last sequence number the client received; not currently read
    /// by `handle_keep_alive_client_message`.
    pub last_incoming_sequence_num: u64,
}
/// All service (protocol) messages, Candid-encoded in `WebsocketMessage::content`.
/// Only `KeepAliveMessage` is accepted from clients (see `handle_received_service_message`).
#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
pub(crate) enum WebsocketServiceMessageContent {
    OpenMessage(CanisterOpenMessageContent),
    AckMessage(CanisterAckMessageContent),
    KeepAliveMessage(ClientKeepAliveMessageContent),
}
impl WebsocketServiceMessageContent {
    /// Decodes a Candid-encoded service message.
    ///
    /// # Errors
    /// Returns the decoder's error, prefixed with context.
    fn from_candid_bytes(bytes: &[u8]) -> Result<Self, String> {
        decode_one(bytes)
            .map_err(|e| format!("Error decoding service message content: {}", e))
    }
}
/// Candid-encodes `message` and queues it for `client_key` as a service message.
///
/// # Errors
/// If encoding fails (instead of trapping, as `unwrap` did), or if `_ws_send` fails.
fn send_service_message_to_client(
    client_key: &ClientKey,
    message: WebsocketServiceMessageContent,
) -> Result<(), String> {
    let message_bytes = encode_one(&message)
        .map_err(|e| format!("Error encoding service message content: {}", e))?;
    _ws_send(client_key, message_bytes, true)
}
/// Starts the periodic ack timer. Each tick sends acks to all clients and then
/// schedules a one-shot keep-alive check.
fn schedule_send_ack_to_clients() {
    let interval = Duration::from_millis(get_params().send_ack_interval_ms);
    let timer_id = set_timer_interval(interval, || {
        send_ack_to_clients_timer_callback();
        schedule_check_keep_alive();
    });
    put_ack_timer_id(timer_id);
}

/// Schedules a one-shot check, `keep_alive_timeout_ms` from now, that removes
/// clients which did not send a keep-alive in time.
fn schedule_check_keep_alive() {
    let timeout_ms = get_params().keep_alive_timeout_ms;
    let delay = Duration::from_millis(timeout_ms);
    let timer_id = set_timer(delay, move || check_keep_alive_timer_callback(timeout_ms));
    put_keep_alive_timer_id(timer_id);
}
/// Ack timer callback (see `schedule_send_ack_to_clients`).
///
/// Sends every registered client an ack carrying the sequence number of the last
/// message received from it. Each client whose ack was queued successfully is then
/// marked as waiting for a keep-alive. Failures are logged, not returned.
fn send_ack_to_clients_timer_callback() {
    // NOTE(review): the `REGISTERED_CLIENTS` borrow created here lives for the whole
    // loop; anything called below that mutably borrows that map would panic.
    for client_key in REGISTERED_CLIENTS.with(Rc::clone).borrow().keys() {
        match get_expected_incoming_message_from_client_num(client_key) {
            Ok(expected_incoming_sequence_num) => {
                // The expected number is one past the last number actually received.
                let ack_message = CanisterAckMessageContent {
                    last_incoming_sequence_num: expected_incoming_sequence_num - 1,
                };
                let message = WebsocketServiceMessageContent::AckMessage(ack_message);
                if let Err(e) = send_service_message_to_client(client_key, message) {
                    custom_print!(
                        "[ack-to-clients-timer-cb]: Error sending ack message to client {}: {:?}",
                        client_key,
                        e
                    );
                } else {
                    add_client_to_wait_for_keep_alive(client_key);
                }
            },
            Err(e) => {
                custom_print!(
                    "[ack-to-clients-timer-cb]: Error getting expected incoming sequence number for client {}: {:?}",
                    client_key,
                    e,
                );
            },
        }
    }
    custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients");
}
/// Keep-alive timer callback: removes every client that is waiting for a keep-alive
/// and whose last keep-alive is older than `keep_alive_timeout_ms`.
///
/// Timestamps come from `get_current_time`; the ms timeout is scaled by 1_000_000 to
/// match them. Saturating arithmetic keeps an out-of-order timestamp or a huge
/// timeout from wrapping (release WASM builds do not trap on overflow), which would
/// otherwise evict clients spuriously.
fn check_keep_alive_timer_callback(keep_alive_timeout_ms: u64) {
    let keep_alive_timeout = keep_alive_timeout_ms.saturating_mul(1_000_000);
    let client_keys_to_remove: Vec<ClientKey> = CLIENTS_WAITING_FOR_KEEP_ALIVE
        .with(Rc::clone)
        .borrow()
        .iter()
        .filter_map(|client_key| {
            // Clients that are no longer registered are simply skipped.
            let last_keep_alive = REGISTERED_CLIENTS
                .with(Rc::clone)
                .borrow()
                .get(client_key)?
                .get_last_keep_alive_timestamp();
            let elapsed = get_current_time().saturating_sub(last_keep_alive);
            if elapsed > keep_alive_timeout {
                Some(client_key.clone())
            } else {
                None
            }
        })
        .collect();
    for client_key in client_keys_to_remove {
        remove_client(&client_key);
        custom_print!(
            "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed",
            client_key,
            keep_alive_timeout_ms
        );
    }
    custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients");
}
/// Records a keep-alive from `client_key` by refreshing its timestamp. Unknown
/// clients are ignored; the message payload is not inspected. Always succeeds.
fn handle_keep_alive_client_message(
    client_key: &ClientKey,
    _keep_alive_message: ClientKeepAliveMessageContent,
) -> Result<(), String> {
    REGISTERED_CLIENTS.with(|clients| {
        if let Some(metadata) = clients.borrow_mut().get_mut(client_key) {
            metadata.update_last_keep_alive_timestamp();
        }
    });
    Ok(())
}
/// Queues `msg_bytes` for delivery to `client_key` through the gateway.
///
/// The payload is wrapped in a `WebsocketMessage` carrying the client's next
/// outgoing sequence number and then CBOR-serialized. The result is certified under
/// a fresh gateway key and appended to `MESSAGES_FOR_GATEWAY`.
///
/// # Errors
/// If the client is not registered, its counter is missing, or serialization fails.
/// All state changes happen after serialization, so on error no counter is advanced
/// and nothing is queued (no nonce or sequence-number gaps).
fn _ws_send(
    client_key: &ClientKey,
    msg_bytes: Vec<u8>,
    is_service_message: bool,
) -> CanisterWsSendResult {
    check_registered_client(client_key)?;
    let gateway_principal = get_registered_gateway_principal();
    let outgoing_message_nonce = get_outgoing_message_nonce();
    let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce);
    // The number this message carries once the client's counter is advanced below.
    let sequence_num = get_outgoing_message_to_client_num(client_key)? + 1;
    let websocket_message = WebsocketMessage {
        client_key: client_key.clone(),
        sequence_num,
        timestamp: get_current_time(),
        is_service_message,
        content: msg_bytes,
    };
    let content = websocket_message.cbor_serialize()?;

    // Commit: advance counters, certify, and enqueue.
    increment_outgoing_message_nonce();
    increment_outgoing_message_to_client_num(client_key)?;
    put_cert_for_message(key.clone(), &content);
    MESSAGES_FOR_GATEWAY.with(|m| {
        m.borrow_mut().push_back(CanisterOutputMessage {
            client_key: client_key.clone(),
            content,
            key,
        });
    });
    Ok(())
}
/// Dispatches a service message received from a client. Only keep-alives are valid
/// in this direction; open and ack messages are rejected.
///
/// # Errors
/// If the content cannot be decoded or is not a keep-alive.
fn handle_received_service_message(
    client_key: &ClientKey,
    content: &[u8],
) -> CanisterWsMessageResult {
    match WebsocketServiceMessageContent::from_candid_bytes(content)? {
        WebsocketServiceMessageContent::KeepAliveMessage(keep_alive) => {
            handle_keep_alive_client_message(client_key, keep_alive)
        },
        WebsocketServiceMessageContent::OpenMessage(_)
        | WebsocketServiceMessageContent::AckMessage(_) => {
            Err(String::from("Invalid received service message"))
        },
    }
}
/// Arguments passed to the `on_open` handler.
pub struct OnOpenCallbackArgs {
    /// Principal of the client that opened the connection.
    pub client_principal: ClientPrincipal,
}
/// Signature of the `on_open` handler.
type OnOpenCallback = fn(OnOpenCallbackArgs);
/// Arguments passed to the `on_message` handler.
pub struct OnMessageCallbackArgs {
    /// Principal of the client that sent the message.
    pub client_principal: ClientPrincipal,
    /// Raw application payload.
    pub message: Vec<u8>,
}
/// Signature of the `on_message` handler.
type OnMessageCallback = fn(OnMessageCallbackArgs);
/// Arguments passed to the `on_close` handler.
pub struct OnCloseCallbackArgs {
    /// Principal of the client whose connection was removed.
    pub client_principal: ClientPrincipal,
}
/// Signature of the `on_close` handler.
type OnCloseCallback = fn(OnCloseCallbackArgs);
/// Optional application callbacks invoked on connection open, incoming app
/// messages, and connection removal. Unset handlers are skipped.
#[derive(Clone, Default)]
pub struct WsHandlers {
    pub on_open: Option<OnOpenCallback>,
    pub on_message: Option<OnMessageCallback>,
    pub on_close: Option<OnCloseCallback>,
}
impl WsHandlers {
fn call_on_open(&self, args: OnOpenCallbackArgs) {
if let Some(on_open) = self.on_open {
let res = panic::catch_unwind(|| {
on_open(args);
});
if let Err(e) = res {
custom_print!("Error calling on_open handler: {:?}", e);
}
}
}
fn call_on_message(&self, args: OnMessageCallbackArgs) {
if let Some(on_message) = self.on_message {
let res = panic::catch_unwind(|| {
on_message(args);
});
if let Err(e) = res {
custom_print!("Error calling on_message handler: {:?}", e);
}
}
}
fn call_on_close(&self, args: OnCloseCallbackArgs) {
if let Some(on_close) = self.on_close {
let res = panic::catch_unwind(|| {
on_close(args);
});
if let Err(e) = res {
custom_print!("Error calling on_close handler: {:?}", e);
}
}
}
}
/// Configuration passed to `init`.
#[derive(Clone)]
pub struct WsInitParams {
    /// Application callbacks.
    pub handlers: WsHandlers,
    /// Textual principal of the only gateway allowed to relay messages.
    pub gateway_principal: String,
    /// Cap on the number of messages returned by one `ws_get_messages` call.
    pub max_number_of_returned_messages: usize,
    /// Period (ms) of the ack timer.
    pub send_ack_interval_ms: u64,
    /// Time (ms) after each ack round within which clients must send a keep-alive.
    pub keep_alive_timeout_ms: u64,
}

impl WsInitParams {
    /// Params with the given handlers and gateway; every other field takes its default.
    pub fn new(handlers: WsHandlers, gateway_principal: String) -> Self {
        let defaults = Self::default();
        Self {
            handlers,
            gateway_principal,
            ..defaults
        }
    }

    /// Returns a copy of the handlers.
    fn get_handlers(&self) -> WsHandlers {
        WsHandlers::clone(&self.handlers)
    }

    /// Traps unless `keep_alive_timeout_ms <= send_ack_interval_ms`.
    ///
    /// NOTE(review): the trap message says "greater than", but equal values are
    /// accepted — confirm which relation is intended.
    fn check_validity(&self) {
        let timers_consistent = self.keep_alive_timeout_ms <= self.send_ack_interval_ms;
        if !timers_consistent {
            trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms");
        }
    }
}

impl Default for WsInitParams {
    /// No handlers, an empty gateway principal, and the `DEFAULT_*` limit and timer values.
    fn default() -> Self {
        WsInitParams {
            handlers: Default::default(),
            gateway_principal: Default::default(),
            max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES,
            send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS,
            keep_alive_timeout_ms: DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS,
        }
    }
}
/// Initializes the CDK: validates and stores `params`, registers the gateway, and
/// (re)starts the periodic ack timer after cancelling any existing timers.
///
/// Traps if the timer parameters are inconsistent (`check_validity`). Panics if
/// `params.gateway_principal` is not a valid principal.
pub fn init(params: WsInitParams) {
    params.check_validity();
    set_params(params.clone());
    initialize_registered_gateway(&params.gateway_principal);
    reset_timers();
    schedule_send_ack_to_clients();
}
/// Opens a connection for the caller with the given nonce.
///
/// The client is registered, an open service message carrying its key is queued for
/// it, and `on_open` is then invoked.
///
/// # Errors
/// If the caller is anonymous, is the registered gateway, or already has an open
/// connection with this nonce, or if queuing the open message fails.
pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
    let client_principal = caller();

    // Reject callers that can never be legitimate clients.
    if client_principal == ClientPrincipal::anonymous() {
        return Err(String::from("anonymous principal cannot open a connection"));
    }
    if is_registered_gateway(client_principal) {
        return Err(String::from(
            "caller is the registered gateway which can't open a connection for itself",
        ));
    }

    let client_key = ClientKey::new(client_principal, args.client_nonce);
    if is_client_registered(&client_key) {
        return Err(format!(
            "client with key {} already has an open connection",
            client_key,
        ));
    }

    add_client(client_key.clone(), RegisteredClient::new());

    let open_message =
        WebsocketServiceMessageContent::OpenMessage(CanisterOpenMessageContent {
            client_key: client_key.clone(),
        });
    send_service_message_to_client(&client_key, open_message)?;

    get_handlers_from_params().call_on_open(OnOpenCallbackArgs { client_principal });
    Ok(())
}
pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
check_is_registered_gateway(caller())?;
check_registered_client(&args.client_key)?;
remove_client(&args.client_key);
Ok(())
}
/// Handles a message sent by a client.
///
/// The message's key must match the caller's current connection and its sequence
/// number must be the expected one. On a sequence mismatch the client is removed.
/// Service messages are handled internally; app messages go to `on_message`.
///
/// # Errors
/// If the caller has no connection, the key does not match, the sequence number is
/// wrong, or a service message is invalid.
pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
    args: CanisterWsMessageArguments,
    _message_type: Option<T>,
) -> CanisterWsMessageResult {
    let client_principal = caller();
    let registered_client_key = get_client_key_from_principal(&client_principal)?;
    let WebsocketMessage {
        client_key,
        sequence_num,
        timestamp: _,
        is_service_message,
        content,
    } = args.msg;

    if registered_client_key != client_key {
        return Err(format!(
            "client with principal {} has a different key than the one used in the message",
            client_principal
        ));
    }

    let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?;
    if sequence_num != expected_sequence_num {
        remove_client(&client_key);
        return Err(format!(
            "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Client removed.",
        ));
    }
    increment_expected_incoming_message_from_client_num(&client_key)?;

    if is_service_message {
        return handle_received_service_message(&client_key, &content);
    }

    get_handlers_from_params().call_on_message(OnMessageCallbackArgs {
        client_principal,
        message: content,
    });
    Ok(())
}
/// Returns the certified batch of messages for the calling gateway.
///
/// # Errors
/// If the caller is not the registered gateway.
pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
    let gateway_principal = caller();
    check_is_registered_gateway(gateway_principal)
        .and_then(|()| get_cert_messages(gateway_principal, args.nonce))
}

/// Sends an application message to the current connection of `client_principal`.
///
/// # Errors
/// If the principal has no open connection or queuing the message fails.
pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterWsSendResult {
    get_client_key_from_principal(&client_principal)
        .and_then(|client_key| _ws_send(&client_key, msg_bytes, false))
}