use candid::{decode_one, encode_one, CandidType, Principal};
#[cfg(not(test))]
use ic_cdk::api::time;
use ic_cdk::api::{caller, data_certificate, set_certified_data};
use ic_cdk::trap;
use ic_cdk_timers::{clear_timer, set_timer, set_timer_interval, TimerId};
use ic_certified_map::{labeled, labeled_hash, AsHashTree, Hash as ICHash, RbTree};
use serde::{Deserialize, Serialize};
use serde_cbor::Serializer;
use sha2::{Digest, Sha256};
use std::fmt;
use std::panic;
use std::rc::Rc;
use std::time::Duration;
use std::{
cell::RefCell,
collections::VecDeque,
collections::{HashMap, HashSet},
convert::AsRef,
};
mod logger;
/// Label under which all websocket messages are stored in the certified tree.
const LABEL_WEBSOCKET: &[u8] = b"websocket";
/// Default upper bound on the number of messages returned by one `ws_get_messages` call.
const DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES: usize = 10;
/// Default period of the ack timer, in milliseconds.
const DEFAULT_SEND_ACK_INTERVAL_MS: u64 = 60_000;
/// Default delay after an ack round before keep-alives are checked, in milliseconds.
const DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS: u64 = 10_000;
/// Starting value of the global nonce used to build gateway message keys.
const INITIAL_OUTGOING_MESSAGE_NONCE: u64 = 0;
/// Sequence number expected on the first message received from a client.
const INITIAL_CLIENT_SEQUENCE_NUM: u64 = 1;
/// Starting value of the per-client outgoing counter (incremented before each send).
const INITIAL_CANISTER_SEQUENCE_NUM: u64 = 0;
/// Principal of a client that connects through the websocket gateway.
pub type ClientPrincipal = Principal;

/// Identifies a single connection: the client's principal plus the nonce it chose in `ws_open`.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)]
struct ClientKey {
    client_principal: ClientPrincipal,
    client_nonce: u64,
}

impl ClientKey {
    /// Builds a key from its principal and nonce components.
    fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self {
        ClientKey { client_principal, client_nonce }
    }
}

impl fmt::Display for ClientKey {
    /// Formats the key as `<principal>_<nonce>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ClientKey { client_principal, client_nonce } = self;
        write!(f, "{client_principal}_{client_nonce}")
    }
}
/// Result returned by `ws_open`.
pub type CanisterWsOpenResult = Result<(), String>;
/// Result returned by `ws_close`.
pub type CanisterWsCloseResult = Result<(), String>;
/// Result returned by `ws_message`.
pub type CanisterWsMessageResult = Result<(), String>;
/// Result returned by `ws_get_messages`: certified messages, or an error string.
pub type CanisterWsGetMessagesResult = Result<CanisterOutputCertifiedMessages, String>;
/// Result returned by `ws_send`.
pub type CanisterWsSendResult = Result<(), String>;

/// Arguments of `ws_open`: the nonce the client picked for this connection.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsOpenArguments {
    client_nonce: u64,
}

/// Arguments of `ws_close`: the key of the connection being closed.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsCloseArguments {
    client_key: ClientKey,
}

/// Arguments of `ws_message`: the message sent by the client.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsMessageArguments {
    msg: WebsocketMessage,
}

/// Arguments of `ws_get_messages`: the nonce the gateway wants to resume from.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
pub struct CanisterWsGetMessagesArguments {
    nonce: u64,
}
/// Envelope for every websocket message, both received from clients and queued for the gateway.
#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
struct WebsocketMessage {
    /// Connection the message belongs to.
    client_key: ClientKey,
    /// Per-direction sequence number of the message.
    sequence_num: u64,
    /// Value of `get_current_time` when the message was built.
    timestamp: u64,
    /// `true` when `content` is a Candid-encoded service message rather than application data.
    is_service_message: bool,
    /// Raw payload bytes.
    #[serde(with = "serde_bytes")]
    content: Vec<u8>,
}

impl WebsocketMessage {
    /// Serializes the message to CBOR, writing the self-describe marker first.
    ///
    /// # Errors
    /// Returns the serializer's error message as a `String`.
    fn cbor_serialize(&self) -> Result<Vec<u8>, String> {
        let mut buffer = vec![];
        {
            let mut serializer = Serializer::new(&mut buffer);
            serializer.self_describe().map_err(|e| e.to_string())?;
            self.serialize(&mut serializer).map_err(|e| e.to_string())?;
        }
        Ok(buffer)
    }
}
/// A message queued for the gateway, stored under `key` in the certified tree.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputMessage {
    /// Connection the message is addressed to.
    client_key: ClientKey,
    /// Key built by `get_message_for_gateway_key` (gateway principal + padded nonce).
    key: String,
    /// CBOR-serialized `WebsocketMessage`.
    #[serde(with = "serde_bytes")]
    content: Vec<u8>,
}

/// Batch of messages returned to the gateway, with the data certificate and witness tree.
#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq)]
pub struct CanisterOutputCertifiedMessages {
    /// Messages in queue order.
    messages: Vec<CanisterOutputMessage>,
    /// Data certificate; empty when `messages` is empty.
    #[serde(with = "serde_bytes")]
    cert: Vec<u8>,
    /// CBOR-serialized witness covering the returned key range; empty when `messages` is empty.
    #[serde(with = "serde_bytes")]
    tree: Vec<u8>,
}
/// The single gateway principal allowed to fetch messages and close connections.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct RegisteredGateway {
    gateway_principal: Principal,
}

impl RegisteredGateway {
    /// Wraps the given principal.
    fn new(gateway_principal: Principal) -> Self {
        RegisteredGateway { gateway_principal }
    }
}
/// Returns the current time.
///
/// Under `cfg(test)` this is always `0`, so unit tests are deterministic. Otherwise it is
/// `ic_cdk::api::time()`. Callers compare it against `ms * 1_000_000`, so it is presumably
/// nanoseconds (confirm against the ic_cdk docs).
fn get_current_time() -> u64 {
    #[cfg(test)]
    let now = 0u64;
    #[cfg(not(test))]
    let now = time();
    now
}
/// Metadata kept for each open connection.
#[derive(Clone, Debug, Eq, PartialEq)]
struct RegisteredClient {
    /// Time of registration, or of the latest keep-alive received (from `get_current_time`).
    last_keep_alive_timestamp: u64,
}

impl RegisteredClient {
    /// Creates client metadata stamped with the current time.
    fn new() -> Self {
        let now = get_current_time();
        RegisteredClient { last_keep_alive_timestamp: now }
    }

    /// Returns the time of the last keep-alive (or of registration).
    fn get_last_keep_alive_timestamp(&self) -> u64 {
        let RegisteredClient { last_keep_alive_timestamp } = self;
        *last_keep_alive_timestamp
    }

    /// Records that a keep-alive was received now.
    fn update_last_keep_alive_timestamp(&mut self) {
        let now = get_current_time();
        self.last_keep_alive_timestamp = now;
    }
}
// Canister-global state. Everything lives in thread-locals because canister code runs on a single thread.
thread_local! {
// Metadata for every open connection. Wrapped in `Rc` so callers can take a handle via `with(Rc::clone)`.
static REGISTERED_CLIENTS: Rc<RefCell<HashMap<ClientKey, RegisteredClient>>> = Rc::new(RefCell::new(HashMap::new()));
// Principal -> key of the most recently registered connection (overwritten by `insert_client`).
static CURRENT_CLIENT_KEY_MAP: RefCell<HashMap<ClientPrincipal, ClientKey>> = RefCell::new(HashMap::new());
// Clients that were sent an ack and are inspected by the keep-alive timer callback.
static CLIENTS_WAITING_FOR_KEEP_ALIVE: Rc<RefCell<HashSet<ClientKey>>> = Rc::new(RefCell::new(HashSet::new()));
// Per-client sequence number of the last message sent by the canister.
static OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP: RefCell<HashMap<ClientKey, u64>> = RefCell::new(HashMap::new());
// Per-client sequence number expected on the next incoming message.
static INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP: RefCell<HashMap<ClientKey, u64>> = RefCell::new(HashMap::new());
// Certified tree: gateway message key -> SHA-256 of the CBOR message content.
static CERT_TREE: RefCell<RbTree<String, ICHash>> = RefCell::new(RbTree::new());
// Gateway configured in `init`; `None` until then.
static REGISTERED_GATEWAY: RefCell<Option<RegisteredGateway>> = RefCell::new(None);
// Outgoing messages for the gateway, appended in increasing key order.
static MESSAGES_FOR_GATEWAY: RefCell<VecDeque<CanisterOutputMessage>> = RefCell::new(VecDeque::new());
// Global nonce used to build the next gateway message key.
static OUTGOING_MESSAGE_NONCE: RefCell<u64> = RefCell::new(INITIAL_OUTGOING_MESSAGE_NONCE);
// Parameters supplied to `init`.
static PARAMS: RefCell<WsInitParams> = RefCell::new(WsInitParams::default());
// Id of the periodic ack timer, kept so it can be cleared.
static ACK_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
// Id of the most recently scheduled keep-alive check timer.
static KEEP_ALIVE_TIMER: Rc<RefCell<Option<TimerId>>> = Rc::new(RefCell::new(None));
}
fn reset_internal_state() {
let client_keys_to_remove: Vec<ClientKey> = REGISTERED_CLIENTS.with(|state| {
let map = state.borrow();
map.keys().cloned().collect()
});
for client_key in client_keys_to_remove {
remove_client(&client_key);
}
CURRENT_CLIENT_KEY_MAP.with(|map| {
map.borrow_mut().clear();
});
CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
set.borrow_mut().clear();
});
OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
map.borrow_mut().clear();
});
INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
map.borrow_mut().clear();
});
CERT_TREE.with(|t| {
t.replace(RbTree::new());
});
MESSAGES_FOR_GATEWAY.with(|m| *m.borrow_mut() = VecDeque::new());
OUTGOING_MESSAGE_NONCE.with(|next_id| next_id.replace(INITIAL_OUTGOING_MESSAGE_NONCE));
}
pub fn wipe() {
reset_internal_state();
custom_print!("Internal state has been wiped!");
}
/// Returns the nonce that will be used for the next gateway message key.
fn get_outgoing_message_nonce() -> u64 {
    OUTGOING_MESSAGE_NONCE.with(|nonce| *nonce.borrow())
}

/// Advances the global outgoing message nonce by one.
fn increment_outgoing_message_nonce() {
    OUTGOING_MESSAGE_NONCE.with(|nonce| *nonce.borrow_mut() += 1);
}
/// Registers `new_client` under `client_key` and makes it the current key for its principal.
fn insert_client(client_key: ClientKey, new_client: RegisteredClient) {
    let principal = client_key.client_principal;
    CURRENT_CLIENT_KEY_MAP.with(|current| {
        current.borrow_mut().insert(principal, client_key.clone());
    });
    REGISTERED_CLIENTS.with(|clients| {
        clients.borrow_mut().insert(client_key, new_client);
    });
}

/// Returns whether `client_key` has an open connection.
fn is_client_registered(client_key: &ClientKey) -> bool {
    REGISTERED_CLIENTS.with(|clients| {
        let clients = clients.borrow();
        clients.contains_key(client_key)
    })
}
/// Looks up the key of the current connection opened by `client_principal`.
///
/// # Errors
/// Returns an error message if the principal has no open connection.
fn get_client_key_from_principal(client_principal: &ClientPrincipal) -> Result<ClientKey, String> {
    CURRENT_CLIENT_KEY_MAP.with(|map| {
        map.borrow()
            .get(client_principal)
            .cloned()
            // Build the message lazily: it is only needed on the failure path.
            .ok_or_else(|| {
                format!(
                    "client with principal {} doesn't have an open connection",
                    client_principal
                )
            })
    })
}

/// Succeeds only if `client_key` identifies a registered connection.
///
/// # Errors
/// Returns an error message naming the key when it is not registered.
fn check_registered_client(client_key: &ClientKey) -> Result<(), String> {
    if is_client_registered(client_key) {
        Ok(())
    } else {
        Err(format!(
            "client with key {} doesn't have an open connection",
            client_key
        ))
    }
}
/// Marks `client_key` as expected to answer the last ack with a keep-alive.
fn add_client_to_wait_for_keep_alive(client_key: &ClientKey) {
    let key = client_key.clone();
    CLIENTS_WAITING_FOR_KEEP_ALIVE.with(move |waiting| {
        waiting.borrow_mut().insert(key);
    });
}

/// Parses `gateway_principal` and stores it as the registered gateway.
///
/// # Panics
/// Panics with "invalid gateway principal" if the text is not a valid principal.
fn initialize_registered_gateway(gateway_principal: &str) {
    let principal = Principal::from_text(gateway_principal).expect("invalid gateway principal");
    REGISTERED_GATEWAY.with(|gateway| {
        *gateway.borrow_mut() = Some(RegisteredGateway::new(principal));
    });
}

/// Returns the registered gateway principal.
///
/// # Panics
/// Panics with "gateway should be initialized" before `initialize_registered_gateway` has run.
fn get_registered_gateway_principal() -> Principal {
    let gateway = REGISTERED_GATEWAY.with(|slot| *slot.borrow());
    gateway
        .expect("gateway should be initialized")
        .gateway_principal
}
/// Starts the outgoing sequence counter of `client_key` at `INITIAL_CANISTER_SEQUENCE_NUM`.
fn init_outgoing_message_to_client_num(client_key: ClientKey) {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut()
            .insert(client_key, INITIAL_CANISTER_SEQUENCE_NUM);
    });
}

/// Returns the sequence number of the last message sent to `client_key`.
///
/// # Errors
/// Returns an error if the counter was never initialized for this client.
fn get_outgoing_message_to_client_num(client_key: &ClientKey) -> Result<u64, String> {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow().get(client_key).copied().ok_or_else(|| {
            String::from("outgoing message to client num not initialized for client")
        })
    })
}

/// Increments the outgoing sequence counter of `client_key` in place (single lookup).
///
/// # Errors
/// Returns an error if the counter was never initialized for this client.
fn increment_outgoing_message_to_client_num(client_key: &ClientKey) -> Result<(), String> {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        let mut map = map.borrow_mut();
        match map.get_mut(client_key) {
            Some(num) => {
                *num += 1;
                Ok(())
            }
            None => Err(String::from(
                "outgoing message to client num not initialized for client",
            )),
        }
    })
}
/// Starts the expected incoming sequence number of `client_key` at `INITIAL_CLIENT_SEQUENCE_NUM`.
fn init_expected_incoming_message_from_client_num(client_key: ClientKey) {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut()
            .insert(client_key, INITIAL_CLIENT_SEQUENCE_NUM);
    });
}

/// Returns the sequence number expected on the next message from `client_key`.
///
/// # Errors
/// Returns an error if the counter was never initialized for this client.
fn get_expected_incoming_message_from_client_num(client_key: &ClientKey) -> Result<u64, String> {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow().get(client_key).copied().ok_or_else(|| {
            String::from("expected incoming message num not initialized for client")
        })
    })
}

/// Advances the expected incoming sequence number of `client_key` in place (single lookup).
///
/// # Errors
/// Returns an error if the counter was never initialized for this client.
fn increment_expected_incoming_message_from_client_num(
    client_key: &ClientKey,
) -> Result<(), String> {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        let mut map = map.borrow_mut();
        match map.get_mut(client_key) {
            Some(num) => {
                *num += 1;
                Ok(())
            }
            None => Err(String::from(
                "expected incoming message num not initialized for client",
            )),
        }
    })
}
/// Registers a new connection and initializes both of its sequence counters.
fn add_client(client_key: ClientKey, new_client: RegisteredClient) {
    let (registry_key, incoming_key) = (client_key.clone(), client_key.clone());
    insert_client(registry_key, new_client);
    init_expected_incoming_message_from_client_num(incoming_key);
    init_outgoing_message_to_client_num(client_key);
}
/// Removes all per-connection state for `client_key`, then invokes the `on_close` handler
/// with the client's principal.
///
/// The principal -> key mapping is removed only if it still points at `client_key`.
/// The same principal may have opened a newer connection with a different nonce, and
/// `insert_client` overwrote the mapping with that newer key. Removing a stale connection
/// (for example on keep-alive timeout) must not orphan the live one.
fn remove_client(client_key: &ClientKey) {
    CLIENTS_WAITING_FOR_KEEP_ALIVE.with(|set| {
        set.borrow_mut().remove(client_key);
    });
    CURRENT_CLIENT_KEY_MAP.with(|map| {
        let mut map = map.borrow_mut();
        if map.get(&client_key.client_principal) == Some(client_key) {
            map.remove(&client_key.client_principal);
        }
    });
    REGISTERED_CLIENTS.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().remove(client_key);
    });
    let handlers = get_handlers_from_params();
    handlers.call_on_close(OnCloseCallbackArgs {
        client_principal: client_key.client_principal,
    });
}
/// Builds the certified-tree key `<gateway principal>_<nonce zero-padded to 20 digits>`.
///
/// Twenty digits fit any `u64`, so keys for one gateway sort lexicographically in nonce order.
fn get_message_for_gateway_key(gateway_principal: Principal, nonce: u64) -> String {
    format!("{}_{:020}", gateway_principal, nonce)
}
/// Computes the half-open index range `[start, end)` of queued messages to return to the gateway.
///
/// When `nonce` is 0 and the queue is non-empty, the range covers the most recent
/// `max_number_of_returned_messages` messages. Otherwise it starts at the first message whose
/// key is not smaller than the key for `nonce`. In both cases the range holds at most
/// `max_number_of_returned_messages` entries.
fn get_messages_for_gateway_range(gateway_principal: Principal, nonce: u64) -> (usize, usize) {
    let max_number_of_returned_messages = get_params().max_number_of_returned_messages;
    MESSAGES_FOR_GATEWAY.with(|m| {
        let queue = m.borrow();
        let queue_len = queue.len();
        let start_index = if nonce == 0 && queue_len > 0 {
            queue_len.saturating_sub(max_number_of_returned_messages)
        } else {
            let smallest_key = get_message_for_gateway_key(gateway_principal, nonce);
            queue.partition_point(|message| message.key < smallest_key)
        };
        // start_index <= queue_len, so the subtraction cannot underflow.
        let window = (queue_len - start_index).min(max_number_of_returned_messages);
        (start_index, start_index + window)
    })
}
/// Clones the queued messages in the index range `[start_index, end_index)`.
///
/// # Panics
/// Panics if the range is inverted or extends past the end of the queue.
fn get_messages_for_gateway(start_index: usize, end_index: usize) -> Vec<CanisterOutputMessage> {
    MESSAGES_FOR_GATEWAY.with(|m| m.borrow().range(start_index..end_index).cloned().collect())
}
/// Returns the next batch of messages for the gateway with a certificate and witness for
/// their key range. Both are left empty when there are no messages.
fn get_cert_messages(gateway_principal: Principal, nonce: u64) -> CanisterWsGetMessagesResult {
    let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, nonce);
    let messages = get_messages_for_gateway(start_index, end_index);
    let (cert, tree) = match (messages.first(), messages.last()) {
        (Some(first), Some(last)) => get_cert_for_range(&first.key, &last.key),
        _ => (Vec::new(), Vec::new()),
    };
    Ok(CanisterOutputCertifiedMessages {
        messages,
        cert,
        tree,
    })
}
/// Returns whether `principal` is the registered gateway (panics if none is registered).
fn is_registered_gateway(principal: Principal) -> bool {
    get_registered_gateway_principal() == principal
}

/// Fails unless `input_principal` is the gateway registered during initialization.
fn check_is_registered_gateway(input_principal: Principal) -> Result<(), String> {
    if is_registered_gateway(input_principal) {
        Ok(())
    } else {
        Err(String::from(
            "caller is not the gateway that has been registered during CDK initialization",
        ))
    }
}
/// Inserts SHA-256(`value`) under `key` in the certified tree, then sets the canister's
/// certified data to the labeled root hash.
///
/// `value` is borrowed as a slice, so existing `&Vec<u8>` arguments still coerce. `key` is
/// already owned and is moved into the tree without cloning.
fn put_cert_for_message(key: String, value: &[u8]) {
    let root_hash = CERT_TREE.with(|tree| {
        let mut tree = tree.borrow_mut();
        tree.insert(key, Sha256::digest(value).into());
        labeled_hash(LABEL_WEBSOCKET, &tree.root_hash())
    });
    set_certified_data(&root_hash);
}
/// Returns `(data certificate, CBOR witness)` for the keys in `[first, last]` of the certified tree.
///
/// Takes `&str` so callers may pass `&String` (deref coercion) or string slices.
///
/// # Panics
/// Panics if no data certificate is available. `ic_cdk::api::data_certificate` returns `None`
/// outside query calls.
fn get_cert_for_range(first: &str, last: &str) -> (Vec<u8>, Vec<u8>) {
    CERT_TREE.with(|tree| {
        let tree = tree.borrow();
        let witness = tree.value_range(first.as_ref(), last.as_ref());
        let labeled_tree = labeled(LABEL_WEBSOCKET, witness);
        let mut data = vec![];
        let mut serializer = Serializer::new(&mut data);
        serializer
            .self_describe()
            .expect("writing the CBOR self-describe tag into memory should not fail");
        labeled_tree
            .serialize(&mut serializer)
            .expect("serializing the witness tree into memory should not fail");
        let certificate = data_certificate()
            .expect("data certificate is only available in query calls");
        (certificate, data)
    })
}
/// Remembers the id of the periodic ack timer.
fn put_ack_timer_id(timer_id: TimerId) {
    ACK_TIMER.with(|slot| {
        *slot.borrow_mut() = Some(timer_id);
    });
}

/// Clears the ack timer, if one is scheduled.
fn reset_ack_timer() {
    let previous = ACK_TIMER.with(|slot| slot.borrow_mut().take());
    if let Some(timer_id) = previous {
        clear_timer(timer_id);
    }
}

/// Remembers the id of the latest keep-alive check timer.
fn put_keep_alive_timer_id(timer_id: TimerId) {
    KEEP_ALIVE_TIMER.with(|slot| {
        *slot.borrow_mut() = Some(timer_id);
    });
}

/// Clears the keep-alive check timer, if one is scheduled.
fn reset_keep_alive_timer() {
    let previous = KEEP_ALIVE_TIMER.with(|slot| slot.borrow_mut().take());
    if let Some(timer_id) = previous {
        clear_timer(timer_id);
    }
}

/// Clears both the ack timer and the keep-alive timer.
fn reset_timers() {
    reset_ack_timer();
    reset_keep_alive_timer();
}
fn set_params(params: WsInitParams) {
PARAMS.with(|state| *state.borrow_mut() = params);
}
fn get_params() -> WsInitParams {
PARAMS.with(|state| state.borrow().clone())
}
fn get_handlers_from_params() -> WsHandlers {
get_params().get_handlers()
}
/// Service message sent to a client when its connection is opened.
#[derive(CandidType, Debug, Deserialize)]
struct CanisterOpenMessageContent {
    client_key: ClientKey,
}

/// Service message acknowledging client messages up to `last_incoming_sequence_num`.
#[derive(CandidType, Debug, Deserialize)]
struct CanisterAckMessageContent {
    last_incoming_sequence_num: u64,
}

/// Keep-alive service message sent by a client.
#[derive(CandidType, Debug, Deserialize)]
struct ClientKeepAliveMessageContent {
    last_incoming_sequence_num: u64,
}

/// All Candid-encoded service messages exchanged between canister and clients.
#[derive(CandidType, Debug, Deserialize)]
enum WebsocketServiceMessageContent {
    OpenMessage(CanisterOpenMessageContent),
    AckMessage(CanisterAckMessageContent),
    KeepAliveMessage(ClientKeepAliveMessageContent),
}

impl WebsocketServiceMessageContent {
    /// Decodes a service message from its Candid bytes.
    ///
    /// # Errors
    /// Returns the decoder error, prefixed with a fixed description.
    fn from_candid_bytes(bytes: &[u8]) -> Result<Self, String> {
        decode_one(bytes)
            .map_err(|e| format!("Error decoding service message content: {}", e))
    }
}
/// Candid-encodes `message` and queues it for `client_key` as a service message.
///
/// # Panics
/// Panics if Candid encoding fails.
fn send_service_message_to_client(
    client_key: &ClientKey,
    message: WebsocketServiceMessageContent,
) -> Result<(), String> {
    let encoded = encode_one(&message).unwrap();
    _ws_send(client_key, encoded, true)
}
/// Starts the periodic ack timer. Every tick sends acks and schedules a keep-alive check.
fn schedule_send_ack_to_clients() {
    let interval = Duration::from_millis(get_params().send_ack_interval_ms);
    let timer_id = set_timer_interval(interval, || {
        send_ack_to_clients_timer_callback();
        schedule_check_keep_alive();
    });
    put_ack_timer_id(timer_id);
}

/// Schedules a one-shot keep-alive check after `keep_alive_timeout_ms`.
fn schedule_check_keep_alive() {
    let keep_alive_timeout_ms = get_params().keep_alive_timeout_ms;
    let delay = Duration::from_millis(keep_alive_timeout_ms);
    let timer_id = set_timer(delay, move || {
        check_keep_alive_timer_callback(keep_alive_timeout_ms)
    });
    put_keep_alive_timer_id(timer_id);
}
fn send_ack_to_clients_timer_callback() {
for client_key in REGISTERED_CLIENTS.with(Rc::clone).borrow().keys() {
match get_expected_incoming_message_from_client_num(client_key) {
Ok(expected_incoming_sequence_num) => {
let ack_message = CanisterAckMessageContent {
last_incoming_sequence_num: expected_incoming_sequence_num - 1,
};
let message = WebsocketServiceMessageContent::AckMessage(ack_message);
if let Err(e) = send_service_message_to_client(client_key, message) {
custom_print!(
"[ack-to-clients-timer-cb]: Error sending ack message to client {}: {:?}",
client_key,
e
);
} else {
add_client_to_wait_for_keep_alive(client_key);
}
},
Err(e) => {
custom_print!(
"[ack-to-clients-timer-cb]: Error getting expected incoming sequence number for client {}: {:?}",
client_key,
e,
);
},
}
}
custom_print!("[ack-to-clients-timer-cb]: Sent ack messages to all clients");
}
/// Returns the clients waiting for a keep-alive whose last keep-alive is more than
/// `timeout_ns` older than `now`. Clients no longer registered are ignored.
fn collect_expired_keep_alive_clients(now: u64, timeout_ns: u64) -> Vec<ClientKey> {
    let registered = REGISTERED_CLIENTS.with(Rc::clone);
    let waiting = CLIENTS_WAITING_FOR_KEEP_ALIVE.with(Rc::clone);
    // Borrow the registry once instead of once per waiting client.
    let registered = registered.borrow();
    // Bound to a local (not returned directly) so the `waiting` borrow ends before the Rc drops.
    let expired: Vec<ClientKey> = waiting
        .borrow()
        .iter()
        .filter(|&client_key| match registered.get(client_key) {
            Some(client) => {
                now.saturating_sub(client.get_last_keep_alive_timestamp()) > timeout_ns
            },
            None => false,
        })
        .cloned()
        .collect();
    expired
}

/// Removes every waiting client that has not sent a keep-alive within `keep_alive_timeout_ms`.
///
/// Timestamps come from `get_current_time`, which this code compares against milliseconds
/// multiplied by 1_000_000. Saturating arithmetic keeps very large timeouts or clock skew
/// from overflowing or underflowing.
fn check_keep_alive_timer_callback(keep_alive_timeout_ms: u64) {
    let keep_alive_timeout_ns = keep_alive_timeout_ms.saturating_mul(1_000_000);
    let client_keys_to_remove =
        collect_expired_keep_alive_clients(get_current_time(), keep_alive_timeout_ns);
    for client_key in client_keys_to_remove {
        remove_client(&client_key);
        custom_print!(
            "[check-keep-alive-timer-cb]: Client {} has not sent a keep alive message in the last {} ms and has been removed",
            client_key,
            keep_alive_timeout_ms
        );
    }
    custom_print!("[check-keep-alive-timer-cb]: Checked keep alive messages for all clients");
}
/// Refreshes the keep-alive timestamp of `client_key`. An unknown client is silently ignored.
fn handle_keep_alive_client_message(
    client_key: &ClientKey,
    _keep_alive_message: ClientKeepAliveMessageContent,
) -> Result<(), String> {
    REGISTERED_CLIENTS.with(|clients| {
        let mut clients = clients.borrow_mut();
        if let Some(client) = clients.get_mut(client_key) {
            client.update_last_keep_alive_timestamp();
        }
    });
    Ok(())
}
/// Queues `msg_bytes` for delivery to `client_key` through the registered gateway.
///
/// The payload is wrapped in a `WebsocketMessage` carrying the client's next outgoing
/// sequence number and serialized to CBOR. It is then certified under a key built from the
/// global outgoing nonce and appended to `MESSAGES_FOR_GATEWAY`.
///
/// The nonce and the per-client sequence number are advanced only after serialization has
/// succeeded. A serialization error therefore no longer burns a nonce or leaves a gap in the
/// client's sequence.
///
/// # Errors
/// Fails if the client is not registered, its outgoing counter is missing, or serialization fails.
///
/// # Panics
/// Panics if the gateway has not been initialized.
fn _ws_send(
    client_key: &ClientKey,
    msg_bytes: Vec<u8>,
    is_service_message: bool,
) -> CanisterWsSendResult {
    check_registered_client(client_key)?;
    let gateway_principal = get_registered_gateway_principal();
    let outgoing_message_nonce = get_outgoing_message_nonce();
    let key = get_message_for_gateway_key(gateway_principal, outgoing_message_nonce);
    // The sequence number this message will carry once the counter is committed below.
    let sequence_num = get_outgoing_message_to_client_num(client_key)? + 1;
    let websocket_message = WebsocketMessage {
        client_key: client_key.clone(),
        sequence_num,
        timestamp: get_current_time(),
        is_service_message,
        content: msg_bytes,
    };
    let content = websocket_message.cbor_serialize()?;

    // Commit: every fallible step before this point has succeeded.
    increment_outgoing_message_nonce();
    increment_outgoing_message_to_client_num(client_key)?;
    put_cert_for_message(key.clone(), &content);
    MESSAGES_FOR_GATEWAY.with(|m| {
        m.borrow_mut().push_back(CanisterOutputMessage {
            client_key: client_key.clone(),
            content,
            key,
        });
    });
    Ok(())
}
/// Decodes and handles a service message from a client. Only keep-alive messages are
/// accepted; open and ack messages are canister-to-client only and are rejected.
fn handle_received_service_message(
    client_key: &ClientKey,
    content: &[u8],
) -> CanisterWsMessageResult {
    match WebsocketServiceMessageContent::from_candid_bytes(content)? {
        WebsocketServiceMessageContent::KeepAliveMessage(keep_alive_message) => {
            handle_keep_alive_client_message(client_key, keep_alive_message)
        },
        WebsocketServiceMessageContent::OpenMessage(_)
        | WebsocketServiceMessageContent::AckMessage(_) => {
            Err(String::from("Invalid received service message"))
        },
    }
}
/// Arguments passed to the `on_open` handler.
pub struct OnOpenCallbackArgs {
    pub client_principal: ClientPrincipal,
}
/// Handler invoked after a connection is opened.
type OnOpenCallback = fn(OnOpenCallbackArgs);

/// Arguments passed to the `on_message` handler.
pub struct OnMessageCallbackArgs {
    pub client_principal: ClientPrincipal,
    /// Raw application payload sent by the client.
    pub message: Vec<u8>,
}
/// Handler invoked for every non-service message received from a client.
type OnMessageCallback = fn(OnMessageCallbackArgs);

/// Arguments passed to the `on_close` handler.
pub struct OnCloseCallbackArgs {
    pub client_principal: ClientPrincipal,
}
/// Handler invoked when a connection is removed.
type OnCloseCallback = fn(OnCloseCallbackArgs);

/// Optional application callbacks for connection events.
#[derive(Clone, Default)]
pub struct WsHandlers {
    pub on_open: Option<OnOpenCallback>,
    pub on_message: Option<OnMessageCallback>,
    pub on_close: Option<OnCloseCallback>,
}
impl WsHandlers {
    /// Invokes `on_open`, if set, and logs any panic it raises instead of propagating it.
    ///
    /// NOTE(review): `catch_unwind` only catches panics when unwinding is enabled. On wasm
    /// targets panics usually abort, so confirm this guard is effective in a canister.
    fn call_on_open(&self, args: OnOpenCallbackArgs) {
        let outcome = self
            .on_open
            .map(|on_open| panic::catch_unwind(move || on_open(args)));
        if let Some(Err(e)) = outcome {
            custom_print!("Error calling on_open handler: {:?}", e);
        }
    }

    /// Invokes `on_message`, if set, and logs any panic it raises instead of propagating it.
    fn call_on_message(&self, args: OnMessageCallbackArgs) {
        let outcome = self
            .on_message
            .map(|on_message| panic::catch_unwind(move || on_message(args)));
        if let Some(Err(e)) = outcome {
            custom_print!("Error calling on_message handler: {:?}", e);
        }
    }

    /// Invokes `on_close`, if set, and logs any panic it raises instead of propagating it.
    fn call_on_close(&self, args: OnCloseCallbackArgs) {
        let outcome = self
            .on_close
            .map(|on_close| panic::catch_unwind(move || on_close(args)));
        if let Some(Err(e)) = outcome {
            custom_print!("Error calling on_close handler: {:?}", e);
        }
    }
}
/// Configuration passed to [`init`].
#[derive(Clone)]
pub struct WsInitParams {
    /// Callbacks invoked on open, message and close events.
    pub handlers: WsHandlers,
    /// Text form of the only principal allowed to act as the gateway.
    pub gateway_principal: String,
    /// Upper bound on the number of messages returned by one `ws_get_messages` call.
    pub max_number_of_returned_messages: usize,
    /// Period of the ack timer, in milliseconds.
    pub send_ack_interval_ms: u64,
    /// Delay after each ack round before keep-alives are checked, in milliseconds.
    pub keep_alive_timeout_ms: u64,
}

impl WsInitParams {
    /// Builds parameters with the given handlers and gateway, and defaults for the rest.
    pub fn new(handlers: WsHandlers, gateway_principal: String) -> Self {
        let defaults = Self::default();
        Self {
            handlers,
            gateway_principal,
            ..defaults
        }
    }

    /// Returns a copy of the configured handlers.
    fn get_handlers(&self) -> WsHandlers {
        WsHandlers::clone(&self.handlers)
    }

    /// Traps if the ack interval is shorter than the keep-alive timeout.
    ///
    /// NOTE(review): equal values are accepted, although the trap message says "greater than".
    fn check_validity(&self) {
        let ack_interval_too_short = self.send_ack_interval_ms < self.keep_alive_timeout_ms;
        if ack_interval_too_short {
            trap("send_ack_interval_ms must be greater than keep_alive_timeout_ms");
        }
    }
}

impl Default for WsInitParams {
    /// No handlers, an empty gateway principal, and the `DEFAULT_*` limits and intervals.
    fn default() -> Self {
        WsInitParams {
            handlers: WsHandlers::default(),
            gateway_principal: String::new(),
            max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES,
            send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS,
            keep_alive_timeout_ms: DEFAULT_CLIENT_KEEP_ALIVE_TIMEOUT_MS,
        }
    }
}
/// Initializes the CDK.
///
/// Validates and stores `params`, registers the gateway principal, and restarts the ack timer
/// (the keep-alive timer is cleared too).
///
/// # Panics
/// Traps if `params` are invalid. Panics if `gateway_principal` is not a valid principal.
pub fn init(params: WsInitParams) {
    params.check_validity();
    // Only the gateway text is needed after `params` is stored, so clone just that field.
    // This also repairs the mangled `¶ms` token that was here (it must be `&params`).
    let gateway_principal = params.gateway_principal.clone();
    set_params(params);
    initialize_registered_gateway(&gateway_principal);
    reset_timers();
    schedule_send_ack_to_clients();
}
/// Opens a new connection for the caller using the nonce in `args`.
///
/// Rejects the anonymous principal, the gateway itself, and keys that are already open.
/// On success it sends the client an open service message and then calls `on_open`.
pub fn ws_open(args: CanisterWsOpenArguments) -> CanisterWsOpenResult {
    let client_principal = caller();
    let rejection = if client_principal == ClientPrincipal::anonymous() {
        Some("anonymous principal cannot open a connection")
    } else if is_registered_gateway(client_principal) {
        Some("caller is the registered gateway which can't open a connection for itself")
    } else {
        None
    };
    if let Some(reason) = rejection {
        return Err(String::from(reason));
    }

    let client_key = ClientKey::new(client_principal, args.client_nonce);
    if is_client_registered(&client_key) {
        return Err(format!(
            "client with key {} already has an open connection",
            client_key,
        ));
    }

    add_client(client_key.clone(), RegisteredClient::new());
    let open_message = WebsocketServiceMessageContent::OpenMessage(CanisterOpenMessageContent {
        client_key: client_key.clone(),
    });
    send_service_message_to_client(&client_key, open_message)?;
    get_handlers_from_params().call_on_open(OnOpenCallbackArgs { client_principal });
    Ok(())
}
pub fn ws_close(args: CanisterWsCloseArguments) -> CanisterWsCloseResult {
check_is_registered_gateway(caller())?;
check_registered_client(&args.client_key)?;
remove_client(&args.client_key);
Ok(())
}
/// Handles a message sent by the caller over its current connection.
///
/// The message's client key must match the caller's registered key, and its sequence number
/// must be the expected one. On a sequence mismatch the client is removed. Service messages
/// are handled internally; all others are passed to `on_message`.
///
/// # Errors
/// Fails on an unknown caller, a key mismatch, a sequence mismatch, or an invalid service message.
pub fn ws_message<T: CandidType + for<'a> Deserialize<'a>>(
    args: CanisterWsMessageArguments,
    _message_type: Option<T>,
) -> CanisterWsMessageResult {
    let client_principal = caller();
    let registered_client_key = get_client_key_from_principal(&client_principal)?;
    let WebsocketMessage {
        client_key,
        sequence_num,
        timestamp: _,
        is_service_message,
        content,
    } = args.msg;
    if registered_client_key != client_key {
        return Err(format!(
            "client with principal {} has a different key than the one used in the message",
            client_principal
        ));
    }
    let expected_sequence_num = get_expected_incoming_message_from_client_num(&client_key)?;
    if sequence_num != expected_sequence_num {
        remove_client(&client_key);
        return Err(format!(
            "incoming client's message does not have the expected sequence number. Expected: {expected_sequence_num}, actual: {sequence_num}. Client removed.",
        ));
    }
    increment_expected_incoming_message_from_client_num(&client_key)?;
    if is_service_message {
        return handle_received_service_message(&client_key, &content);
    }
    get_handlers_from_params().call_on_message(OnMessageCallbackArgs {
        client_principal,
        message: content,
    });
    Ok(())
}
/// Returns the next certified batch of messages. Only the registered gateway may call it.
pub fn ws_get_messages(args: CanisterWsGetMessagesArguments) -> CanisterWsGetMessagesResult {
    let gateway_principal = caller();
    check_is_registered_gateway(gateway_principal)
        .and_then(|()| get_cert_messages(gateway_principal, args.nonce))
}

/// Sends application bytes to the current connection of `client_principal`.
pub fn ws_send(client_principal: ClientPrincipal, msg_bytes: Vec<u8>) -> CanisterWsSendResult {
    get_client_key_from_principal(&client_principal)
        .and_then(|client_key| _ws_send(&client_key, msg_bytes, false))
}
#[cfg(test)]
mod test {
use super::*;
use proptest::prelude::*;
mod test_utils {
use candid::Principal;
use ic_agent::{identity::BasicIdentity, Identity};
use ring::signature::Ed25519KeyPair;
use super::{
get_message_for_gateway_key, CanisterOutputMessage, ClientKey, RegisteredClient,
MESSAGES_FOR_GATEWAY,
};
fn generate_random_key_pair() -> Ed25519KeyPair {
let rng = ring::rand::SystemRandom::new();
let key_pair =
Ed25519KeyPair::generate_pkcs8(&rng).expect("Could not generate a key pair.");
Ed25519KeyPair::from_pkcs8(key_pair.as_ref()).expect("Could not read the key pair.")
}
pub fn generate_random_principal() -> candid::Principal {
let key_pair = generate_random_key_pair();
let identity = BasicIdentity::from_key_pair(key_pair);
candid::Principal::from_text(identity.sender().unwrap().to_text()).unwrap()
}
pub(super) fn generate_random_registered_client() -> RegisteredClient {
RegisteredClient::new()
}
pub fn get_static_principal() -> Principal {
Principal::from_text("wnkwv-wdqb5-7wlzr-azfpw-5e5n5-dyxrf-uug7x-qxb55-mkmpa-5jqik-tqe")
.unwrap() }
pub(super) fn get_random_client_key() -> ClientKey {
ClientKey::new(
generate_random_principal(),
rand::random(),
)
}
pub(super) fn add_messages_for_gateway(
client_key: ClientKey,
gateway_principal: Principal,
count: u64,
) {
MESSAGES_FOR_GATEWAY.with(|m| {
for i in 0..count {
m.borrow_mut().push_back(CanisterOutputMessage {
client_key: client_key.clone(),
key: get_message_for_gateway_key(gateway_principal.clone(), i),
content: vec![],
});
}
});
}
pub fn clean_messages_for_gateway() {
MESSAGES_FOR_GATEWAY.with(|m| m.borrow_mut().clear());
}
}
#[test]
#[should_panic = "gateway should be initialized"]
fn test_get_gateway_principal_not_set() {
get_registered_gateway_principal();
}
#[test]
fn test_ws_handlers_are_called() {
struct CustomState {
is_on_open_called: bool,
is_on_message_called: bool,
is_on_close_called: bool,
}
impl CustomState {
fn new() -> Self {
Self {
is_on_open_called: false,
is_on_message_called: false,
is_on_close_called: false,
}
}
}
thread_local! {
static CUSTOM_STATE : RefCell<CustomState> = RefCell::new(CustomState::new());
}
let mut h = WsHandlers {
on_open: None,
on_message: None,
on_close: None,
};
set_params(WsInitParams {
handlers: h.clone(),
..Default::default()
});
let handlers = get_handlers_from_params();
assert!(handlers.on_open.is_none());
assert!(handlers.on_message.is_none());
assert!(handlers.on_close.is_none());
handlers.call_on_open(OnOpenCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
handlers.call_on_message(OnMessageCallbackArgs {
client_principal: test_utils::generate_random_principal(),
message: vec![],
});
handlers.call_on_close(OnCloseCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
assert!(!CUSTOM_STATE.with(|h| h.borrow().is_on_open_called));
assert!(!CUSTOM_STATE.with(|h| h.borrow().is_on_message_called));
assert!(!CUSTOM_STATE.with(|h| h.borrow().is_on_close_called));
let on_open = |_| {
CUSTOM_STATE.with(|h| {
let mut h = h.borrow_mut();
h.is_on_open_called = true;
});
};
let on_message = |_| {
CUSTOM_STATE.with(|h| {
let mut h = h.borrow_mut();
h.is_on_message_called = true;
});
};
let on_close = |_| {
CUSTOM_STATE.with(|h| {
let mut h = h.borrow_mut();
h.is_on_close_called = true;
});
};
h = WsHandlers {
on_open: Some(on_open),
on_message: Some(on_message),
on_close: Some(on_close),
};
set_params(WsInitParams {
handlers: h.clone(),
..Default::default()
});
let handlers = get_handlers_from_params();
assert!(handlers.on_open.is_some());
assert!(handlers.on_message.is_some());
assert!(handlers.on_close.is_some());
handlers.call_on_open(OnOpenCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
handlers.call_on_message(OnMessageCallbackArgs {
client_principal: test_utils::generate_random_principal(),
message: vec![],
});
handlers.call_on_close(OnCloseCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_open_called));
assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_message_called));
assert!(CUSTOM_STATE.with(|h| h.borrow().is_on_close_called));
}
#[test]
fn test_ws_handlers_panic_is_handled() {
let h = WsHandlers {
on_open: Some(|_| {
panic!("on_open_panic");
}),
on_message: Some(|_| {
panic!("on_close_panic");
}),
on_close: Some(|_| {
panic!("on_close_panic");
}),
};
set_params(WsInitParams {
handlers: h.clone(),
..Default::default()
});
let handlers = get_handlers_from_params();
let res = panic::catch_unwind(|| {
handlers.call_on_open(OnOpenCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
});
assert!(res.is_ok());
let res = panic::catch_unwind(|| {
handlers.call_on_message(OnMessageCallbackArgs {
client_principal: test_utils::generate_random_principal(),
message: vec![],
});
});
assert!(res.is_ok());
let res = panic::catch_unwind(|| {
handlers.call_on_close(OnCloseCallbackArgs {
client_principal: test_utils::generate_random_principal(),
});
});
assert!(res.is_ok());
}
#[test]
fn test_current_time() {
assert_eq!(get_current_time(), 0u64);
}
proptest! {
#[test]
fn test_initialize_registered_gateway(test_gateway_principal in any::<u8>().prop_map(|_| test_utils::generate_random_principal())) {
initialize_registered_gateway(&test_gateway_principal.to_string());
REGISTERED_GATEWAY.with(|p| {
let p = p.borrow();
assert!(p.is_some());
assert_eq!(
p.unwrap(),
RegisteredGateway::new(test_gateway_principal)
);
});
}
#[test]
fn test_get_outgoing_message_nonce(test_nonce in any::<u64>()) {
OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow_mut() = test_nonce);
let actual_nonce = get_outgoing_message_nonce();
prop_assert_eq!(actual_nonce, test_nonce);
}
#[test]
fn test_increment_outgoing_message_nonce(test_nonce in any::<u64>()) {
OUTGOING_MESSAGE_NONCE.with(|n| *n.borrow_mut() = test_nonce);
increment_outgoing_message_nonce();
prop_assert_eq!(get_outgoing_message_nonce(), test_nonce + 1);
}
#[test]
fn test_insert_client(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
let registered_client = test_utils::generate_random_registered_client();
insert_client(test_client_key.clone(), registered_client.clone());
let actual_client_key = CURRENT_CLIENT_KEY_MAP.with(|map| map.borrow().get(&test_client_key.client_principal).unwrap().clone());
prop_assert_eq!(actual_client_key, test_client_key.clone());
let actual_client = REGISTERED_CLIENTS.with(|map| map.borrow().get(&test_client_key).unwrap().clone());
prop_assert_eq!(actual_client, registered_client);
}
#[test]
fn test_get_gateway_principal(test_gateway_principal in any::<u8>().prop_map(|_| test_utils::generate_random_principal())) {
REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal.clone())));
let actual_gateway_principal = get_registered_gateway_principal();
prop_assert_eq!(actual_gateway_principal, test_gateway_principal);
}
#[test]
fn test_is_client_registered_empty(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
let actual_result = is_client_registered(&test_client_key);
prop_assert_eq!(actual_result, false);
}
#[test]
fn test_is_client_registered(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
REGISTERED_CLIENTS.with(|map| {
map.borrow_mut().insert(test_client_key.clone(), test_utils::generate_random_registered_client());
});
let actual_result = is_client_registered(&test_client_key);
prop_assert_eq!(actual_result, true);
}
#[test]
fn test_get_client_key_from_principal_empty(test_client_principal in any::<u8>().prop_map(|_| test_utils::generate_random_principal())) {
let actual_result = get_client_key_from_principal(&test_client_principal);
prop_assert_eq!(actual_result.err(), Some(String::from(format!(
"client with principal {} doesn't have an open connection",
test_client_principal
))));
}
/// A key stored under its principal must be returned unchanged by the lookup.
#[test]
fn test_get_client_key_from_principal(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    let principal = test_client_key.client_principal;
    CURRENT_CLIENT_KEY_MAP.with(|keys| {
        keys.borrow_mut().insert(principal, test_client_key.clone());
    });
    prop_assert_eq!(get_client_key_from_principal(&principal), Ok(test_client_key));
}
/// With no registered clients, the check must fail with an error naming the key.
#[test]
fn test_check_registered_client_empty(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    let outcome = check_registered_client(&test_client_key);
    let expected_error = format!("client with key {} doesn't have an open connection", test_client_key);
    prop_assert_eq!(outcome.err(), Some(expected_error));
}
/// The check passes for a registered key and fails, naming the key, for an
/// unrelated freshly generated key.
#[test]
fn test_check_registered_client(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    REGISTERED_CLIENTS.with(|clients| {
        let client = test_utils::generate_random_registered_client();
        clients.borrow_mut().insert(test_client_key.clone(), client);
    });
    prop_assert!(check_registered_client(&test_client_key).is_ok());

    let unknown_client_key = test_utils::get_random_client_key();
    let outcome = check_registered_client(&unknown_client_key);
    let expected_error = format!("client with key {} doesn't have an open connection", unknown_client_key);
    prop_assert_eq!(outcome.err(), Some(expected_error));
}
/// Initialisation must seed the outgoing counter with `INITIAL_CANISTER_SEQUENCE_NUM`.
#[test]
fn test_init_outgoing_message_to_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    init_outgoing_message_to_client_num(test_client_key.clone());
    let stored_num = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP
        .with(|nums| nums.borrow().get(&test_client_key).copied());
    prop_assert_eq!(stored_num.unwrap(), INITIAL_CANISTER_SEQUENCE_NUM);
}
/// Incrementing must bump the stored outgoing sequence number by exactly one.
///
/// `test_num` is drawn from `0..u64::MAX`, which excludes `u64::MAX`. At that value the
/// `test_num + 1` expectation below would overflow. With overflow checks enabled (the
/// default for `cargo test`) that panics, so the test would fail for a reason unrelated
/// to the code under test.
#[test]
fn test_increment_outgoing_message_to_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key()), test_num in 0..u64::MAX) {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().insert(test_client_key.clone(), test_num);
    });
    let increment_result = increment_outgoing_message_to_client_num(&test_client_key);
    prop_assert!(increment_result.is_ok());
    let actual_result = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP
        .with(|map| map.borrow().get(&test_client_key).copied());
    prop_assert_eq!(actual_result, Some(test_num + 1));
}
/// The getter must return whatever value is stored for the key.
#[test]
fn test_get_outgoing_message_to_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key()), test_num in any::<u64>()) {
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|nums| {
        nums.borrow_mut().insert(test_client_key.clone(), test_num);
    });
    prop_assert_eq!(get_outgoing_message_to_client_num(&test_client_key), Ok(test_num));
}
/// Initialisation must seed the expected incoming counter with `INITIAL_CLIENT_SEQUENCE_NUM`.
#[test]
fn test_init_expected_incoming_message_from_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    init_expected_incoming_message_from_client_num(test_client_key.clone());
    let stored_num = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP
        .with(|nums| nums.borrow().get(&test_client_key).copied());
    prop_assert_eq!(stored_num.unwrap(), INITIAL_CLIENT_SEQUENCE_NUM);
}
/// The getter must return whatever expected incoming number is stored for the key.
#[test]
fn test_get_expected_incoming_message_from_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key()), test_num in any::<u64>()) {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|nums| {
        nums.borrow_mut().insert(test_client_key.clone(), test_num);
    });
    prop_assert_eq!(get_expected_incoming_message_from_client_num(&test_client_key), Ok(test_num));
}
/// Incrementing must bump the stored expected incoming sequence number by exactly one.
///
/// `test_num` is drawn from `0..u64::MAX`, which excludes `u64::MAX`. At that value the
/// `test_num + 1` expectation below would overflow. With overflow checks enabled (the
/// default for `cargo test`) that panics, so the test would fail for a reason unrelated
/// to the code under test.
#[test]
fn test_increment_expected_incoming_message_from_client_num(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key()), test_num in 0..u64::MAX) {
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|map| {
        map.borrow_mut().insert(test_client_key.clone(), test_num);
    });
    let increment_result = increment_expected_incoming_message_from_client_num(&test_client_key);
    prop_assert!(increment_result.is_ok());
    let actual_result = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP
        .with(|map| map.borrow().get(&test_client_key).copied());
    prop_assert_eq!(actual_result, Some(test_num + 1));
}
/// After the call, the key must be present in `CLIENTS_WAITING_FOR_KEEP_ALIVE`.
#[test]
fn test_add_client_to_wait_for_keep_alive(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    add_client_to_wait_for_keep_alive(&test_client_key);
    let is_waiting = CLIENTS_WAITING_FOR_KEEP_ALIVE
        .with(|waiting| waiting.borrow().get(&test_client_key).is_some());
    prop_assert!(is_waiting);
}
/// `add_client` must populate all four per-client maps. It stores the key under its
/// principal and the registered client under its key, and seeds both sequence counters
/// with their initial values.
#[test]
fn test_add_client(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    let registered_client = test_utils::generate_random_registered_client();
    add_client(test_client_key.clone(), registered_client.clone());

    let stored_key = CURRENT_CLIENT_KEY_MAP
        .with(|keys| keys.borrow().get(&test_client_key.client_principal).cloned());
    prop_assert_eq!(stored_key.unwrap(), test_client_key.clone());

    let stored_client = REGISTERED_CLIENTS
        .with(|clients| clients.borrow().get(&test_client_key).cloned());
    prop_assert_eq!(stored_client.unwrap(), registered_client);

    let incoming_num = INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP
        .with(|nums| nums.borrow().get(&test_client_key).copied());
    prop_assert_eq!(incoming_num.unwrap(), INITIAL_CLIENT_SEQUENCE_NUM);

    let outgoing_num = OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP
        .with(|nums| nums.borrow().get(&test_client_key).copied());
    prop_assert_eq!(outgoing_num.unwrap(), INITIAL_CANISTER_SEQUENCE_NUM);
}
/// `remove_client` must purge the client's entries from all four per-client maps.
#[test]
fn test_remove_client(test_client_key in any::<u8>().prop_map(|_| test_utils::get_random_client_key())) {
    let principal = test_client_key.client_principal;

    // Seed every map first, so each "is gone" assertion below is meaningful.
    CURRENT_CLIENT_KEY_MAP.with(|keys| {
        keys.borrow_mut().insert(principal, test_client_key.clone());
    });
    REGISTERED_CLIENTS.with(|clients| {
        let client = test_utils::generate_random_registered_client();
        clients.borrow_mut().insert(test_client_key.clone(), client);
    });
    INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|nums| {
        nums.borrow_mut().insert(test_client_key.clone(), INITIAL_CLIENT_SEQUENCE_NUM);
    });
    OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|nums| {
        nums.borrow_mut().insert(test_client_key.clone(), INITIAL_CANISTER_SEQUENCE_NUM);
    });

    remove_client(&test_client_key);

    prop_assert!(CURRENT_CLIENT_KEY_MAP.with(|keys| keys.borrow().get(&principal).is_none()));
    prop_assert!(REGISTERED_CLIENTS.with(|clients| clients.borrow().get(&test_client_key).is_none()));
    prop_assert!(INCOMING_MESSAGE_FROM_CLIENT_NUM_MAP.with(|nums| nums.borrow().get(&test_client_key).is_none()));
    prop_assert!(OUTGOING_MESSAGE_TO_CLIENT_NUM_MAP.with(|nums| nums.borrow().get(&test_client_key).is_none()));
}
/// Gateway message keys are `<gateway principal>_<nonce left-padded with zeros to 20 chars>`.
#[test]
fn test_get_message_for_gateway_key(test_gateway_principal in any::<u8>().prop_map(|_| test_utils::generate_random_principal()), test_nonce in any::<u64>()) {
    let expected_key = format!("{}_{:0>20}", test_gateway_principal, test_nonce);
    prop_assert_eq!(get_message_for_gateway_key(test_gateway_principal, test_nonce), expected_key);
}
/// With no queued gateway messages, every requested nonce maps to the empty range `(0, 0)`.
#[test]
fn test_get_messages_for_gateway_range_empty(messages_count in any::<u64>().prop_map(|c| c % 1000)) {
    let gateway_principal = test_utils::generate_random_principal();
    REGISTERED_GATEWAY.with(|slot| *slot.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)));
    for nonce in 0..messages_count {
        prop_assert_eq!(get_messages_for_gateway_range(gateway_principal, nonce), (0, 0));
    }
}
/// With fewer queued messages than the cap, every range starts at the requested nonce
/// and runs to the end of the queue.
#[test]
fn test_get_messages_for_gateway_range_smaller_than_max(gateway_principal in any::<u8>().prop_map(|_| test_utils::get_static_principal())) {
    REGISTERED_GATEWAY.with(|slot| *slot.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)));
    // 4 is below DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES (10). This assumes PARAMS
    // still holds its default in this thread; TODO confirm if PARAMS setup changes.
    let messages_count = 4;
    test_utils::add_messages_for_gateway(test_utils::get_random_client_key(), gateway_principal, messages_count);
    for nonce in 0..=messages_count {
        let range = get_messages_for_gateway_range(gateway_principal, nonce);
        prop_assert_eq!(range, (nonce as usize, messages_count as usize));
    }
    test_utils::clean_messages_for_gateway();
}
/// With twice as many queued messages as `max_number_of_returned_messages`, a range
/// starting at nonce `n` ends at `n + max`, clamped to the queue length.
#[test]
fn test_get_messages_for_gateway_range_larger_than_max(gateway_principal in any::<u8>().prop_map(|_| test_utils::get_static_principal()), max_number_of_returned_messages in any::<usize>().prop_map(|c| c % 1000)) {
    PARAMS.with(|params| {
        *params.borrow_mut() = WsInitParams {
            max_number_of_returned_messages,
            ..Default::default()
        }
    });
    REGISTERED_GATEWAY.with(|slot| *slot.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)));

    let messages_count: u64 = (2 * max_number_of_returned_messages).try_into().unwrap();
    test_utils::add_messages_for_gateway(test_utils::get_random_client_key(), gateway_principal, messages_count);

    let queue_len = messages_count as usize;
    for nonce in 1..=messages_count {
        let first = nonce as usize;
        let expected_end_index = (first + max_number_of_returned_messages).min(queue_len);
        prop_assert_eq!(get_messages_for_gateway_range(gateway_principal, nonce), (first, expected_end_index));
    }
    test_utils::clean_messages_for_gateway();
}
/// Nonce 0 requests the tail of the queue. The range covers at most
/// `max_number_of_returned_messages` entries and ends at the queue length.
#[test]
fn test_get_messages_for_gateway_initial_nonce(gateway_principal in any::<u8>().prop_map(|_| test_utils::get_static_principal()), messages_count in any::<u64>().prop_map(|c| c % 100), max_number_of_returned_messages in any::<usize>().prop_map(|c| c % 1000)) {
    PARAMS.with(|params| {
        *params.borrow_mut() = WsInitParams {
            max_number_of_returned_messages,
            ..Default::default()
        }
    });
    REGISTERED_GATEWAY.with(|slot| *slot.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)));
    test_utils::add_messages_for_gateway(test_utils::get_random_client_key(), gateway_principal, messages_count);

    let queue_len = messages_count as usize;
    let expected_start_index = queue_len.saturating_sub(max_number_of_returned_messages);
    prop_assert_eq!(get_messages_for_gateway_range(gateway_principal, 0), (expected_start_index, queue_len));
    test_utils::clean_messages_for_gateway();
}
/// For every starting nonce, `get_messages_for_gateway` must return exactly the messages
/// in `start_index..end_index`, in order. Each message's key must encode its queue position.
#[test]
fn test_get_messages_for_gateway(gateway_principal in any::<u8>().prop_map(|_| test_utils::get_static_principal()), messages_count in any::<u64>().prop_map(|c| c % 100)) {
    REGISTERED_GATEWAY.with(|p| *p.borrow_mut() = Some(RegisteredGateway::new(gateway_principal)));
    let test_client_key = test_utils::get_random_client_key();
    test_utils::add_messages_for_gateway(test_client_key, gateway_principal, messages_count);
    for i in 0..=messages_count {
        let (start_index, end_index) = get_messages_for_gateway_range(gateway_principal, i);
        let messages = get_messages_for_gateway(start_index, end_index);
        // Check the count first. Without it, an empty or truncated result would pass
        // the per-message key loop below vacuously.
        prop_assert_eq!(messages.len(), end_index - start_index);
        for (j, message) in messages.iter().enumerate() {
            let expected_key = get_message_for_gateway_key(gateway_principal, (start_index + j) as u64);
            prop_assert_eq!(&message.key, &expected_key);
        }
    }
    test_utils::clean_messages_for_gateway();
}
/// Only the principal stored in `REGISTERED_GATEWAY` passes the gateway check. Any
/// other principal gets the fixed error message.
#[test]
fn test_check_is_registered_gateway(test_gateway_principal in any::<u8>().prop_map(|_| test_utils::generate_random_principal())) {
    REGISTERED_GATEWAY.with(|slot| *slot.borrow_mut() = Some(RegisteredGateway::new(test_gateway_principal)));
    prop_assert!(check_is_registered_gateway(test_gateway_principal).is_ok());

    let stranger = test_utils::generate_random_principal();
    let expected_error = "caller is not the gateway that has been registered during CDK initialization".to_string();
    prop_assert_eq!(check_is_registered_gateway(stranger).err(), Some(expected_error));
}
/// A `WebsocketMessage` must CBOR-serialize without error for arbitrary content bytes,
/// sequence numbers, and timestamps.
#[test]
fn test_serialize_websocket_message(test_msg_bytes in any::<Vec<u8>>(), test_sequence_num in any::<u64>(), test_timestamp in any::<u64>()) {
    let websocket_message = WebsocketMessage {
        client_key: test_utils::get_random_client_key(),
        sequence_num: test_sequence_num,
        timestamp: test_timestamp,
        is_service_message: false,
        content: test_msg_bytes,
    };
    let serialized_message = websocket_message.cbor_serialize();
    // Use `prop_assert!` rather than a bare `assert!`. It returns a test-case failure
    // to proptest instead of panicking, matching the other assertions in this module.
    prop_assert!(serialized_message.is_ok());
}
}
}