1use std::{collections::VecDeque, fmt, time::Duration};
2
3use candid::{decode_one, CandidType, Principal};
4use serde::{Deserialize, Serialize};
5use serde_cbor::Serializer;
6
7use crate::{
8 custom_trap, errors::WsError, utils::get_current_time, CLIENT_KEEP_ALIVE_TIMEOUT_MS,
9 DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES, DEFAULT_SEND_ACK_INTERVAL_MS,
10 INITIAL_OUTGOING_MESSAGE_NONCE,
11};
12
13pub type ClientPrincipal = Principal;
14#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug, Hash)]
15pub struct ClientKey {
16 pub client_principal: ClientPrincipal,
17 pub client_nonce: u64,
18}
19
20impl ClientKey {
21 pub fn new(client_principal: ClientPrincipal, client_nonce: u64) -> Self {
23 Self {
24 client_principal,
25 client_nonce,
26 }
27 }
28}
29
30impl fmt::Display for ClientKey {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 write!(f, "{}_{}", self.client_principal, self.client_nonce)
33 }
34}
35
36pub type CanisterWsOpenResult = Result<(), String>;
38pub type CanisterWsCloseResult = Result<(), String>;
40pub type CanisterWsMessageResult = Result<(), String>;
42pub type CanisterWsGetMessagesResult = Result<CanisterOutputCertifiedMessages, String>;
44pub type CanisterSendResult = Result<(), String>;
46#[deprecated(since = "0.3.2", note = "use `CanisterSendResult` instead")]
47pub type CanisterWsSendResult = Result<(), String>;
48pub type CanisterCloseResult = Result<(), String>;
50
51#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
53pub struct CanisterWsOpenArguments {
54 pub client_nonce: u64,
55 pub gateway_principal: GatewayPrincipal,
56}
57
58impl CanisterWsOpenArguments {
59 pub fn new(client_nonce: u64, gateway_principal: GatewayPrincipal) -> Self {
60 Self {
61 client_nonce,
62 gateway_principal,
63 }
64 }
65}
66#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
68pub struct CanisterWsCloseArguments {
69 pub client_key: ClientKey,
70}
71
72impl CanisterWsCloseArguments {
73 pub fn new(client_key: ClientKey) -> Self {
74 Self { client_key }
75 }
76}
77
78#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
80pub struct CanisterWsMessageArguments {
81 pub msg: WebsocketMessage,
82}
83
84impl CanisterWsMessageArguments {
85 pub fn new(msg: WebsocketMessage) -> Self {
86 Self { msg }
87 }
88}
89
90#[derive(CandidType, Clone, Deserialize, Serialize, Eq, PartialEq, Debug)]
92pub struct CanisterWsGetMessagesArguments {
93 pub nonce: u64,
94}
95
96impl CanisterWsGetMessagesArguments {
97 pub fn new(nonce: u64) -> Self {
98 Self { nonce }
99 }
100}
101
102#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
106pub struct WebsocketMessage {
107 pub client_key: ClientKey, pub sequence_num: u64, pub timestamp: TimestampNs, pub is_service_message: bool, #[serde(with = "serde_bytes")]
112 pub content: Vec<u8>, }
114
115impl WebsocketMessage {
116 pub fn new(
117 client_key: ClientKey,
118 sequence_num: u64,
119 timestamp: TimestampNs,
120 is_service_message: bool,
121 content: Vec<u8>,
122 ) -> Self {
123 Self {
124 client_key,
125 sequence_num,
126 timestamp,
127 is_service_message,
128 content,
129 }
130 }
131
132 pub fn cbor_serialize(&self) -> Result<Vec<u8>, String> {
134 let mut data = vec![];
135 let mut serializer = Serializer::new(&mut data);
136 serializer.self_describe().map_err(|e| e.to_string())?;
137 self.serialize(&mut serializer).map_err(|e| e.to_string())?;
138 Ok(data)
139 }
140}
141
142#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
146pub struct CanisterOutputMessage {
147 pub client_key: ClientKey, pub key: String, #[serde(with = "serde_bytes")]
150 pub content: Vec<u8>, }
152
153#[derive(CandidType, Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
157pub struct CanisterOutputCertifiedMessages {
158 pub messages: Vec<CanisterOutputMessage>, #[serde(with = "serde_bytes")]
160 pub cert: Vec<u8>, #[serde(with = "serde_bytes")]
162 pub tree: Vec<u8>, pub is_end_of_queue: bool, }
165
166impl CanisterOutputCertifiedMessages {
167 pub fn empty() -> Self {
168 Self {
169 messages: vec![],
170 cert: vec![],
171 tree: vec![],
172 is_end_of_queue: true,
173 }
174 }
175}
176
/// A timestamp expressed in nanoseconds.
pub(crate) type TimestampNs = u64;

/// Index bounds of a run of queued messages selected for a gateway
/// (presumably half-open `start_index..end_index` — confirm against caller),
/// plus whether that run reaches the end of the queue.
pub(crate) struct MessagesForGatewayRange {
    pub start_index: usize,
    pub end_index: usize,
    pub is_end_of_queue: bool,
}

/// Expiry record kept alongside each queued outgoing message.
#[derive(Default, Clone, PartialEq, Eq, Debug)]
pub(crate) struct MessageToDelete {
    /// Timestamp recorded when the paired message was enqueued.
    timestamp: TimestampNs,
}
189
190pub(crate) type GatewayPrincipal = Principal;
191
192#[derive(Clone, Debug, Default, Eq, PartialEq)]
194pub(crate) struct RegisteredGateway {
195 pub(crate) messages_queue: VecDeque<CanisterOutputMessage>,
197 pub(crate) messages_to_delete: VecDeque<MessageToDelete>,
199 pub(crate) outgoing_message_nonce: u64,
203 pub(crate) connected_clients_count: u64,
205}
206
207impl RegisteredGateway {
208 pub(crate) fn new() -> Self {
210 Self {
211 messages_queue: VecDeque::new(),
212 messages_to_delete: VecDeque::new(),
213 outgoing_message_nonce: INITIAL_OUTGOING_MESSAGE_NONCE,
214 connected_clients_count: 0,
215 }
216 }
217
218 pub(crate) fn increment_nonce(&mut self) {
220 self.outgoing_message_nonce += 1;
221 }
222
223 pub(crate) fn increment_clients_count(&mut self) {
225 self.connected_clients_count += 1;
226 }
227
228 pub(crate) fn decrement_clients_count(&mut self) -> u64 {
230 self.connected_clients_count = self.connected_clients_count.saturating_sub(1);
231 self.connected_clients_count
232 }
233
234 pub(crate) fn add_message_to_queue(
236 &mut self,
237 message: CanisterOutputMessage,
238 message_timestamp: TimestampNs,
239 ) {
240 self.messages_queue.push_back(message.clone());
241 self.messages_to_delete.push_back(MessageToDelete {
242 timestamp: message_timestamp,
243 });
244 }
245
246 pub(crate) fn delete_old_messages(&mut self, n: usize, message_max_age_ms: u64) -> Vec<String> {
250 let time = get_current_time();
251 let mut deleted_keys = vec![];
252
253 for _ in 0..n {
254 if let Some(message_to_delete) = self.messages_to_delete.front() {
255 if Duration::from_nanos(time - message_to_delete.timestamp)
256 > Duration::from_millis(message_max_age_ms)
257 {
258 let deleted_message = self.messages_queue.pop_front().unwrap();
261 deleted_keys.push(deleted_message.key.clone());
262 self.messages_to_delete.pop_front();
263 } else {
264 break;
267 }
268 } else {
269 break;
271 }
272 }
273
274 deleted_keys
275 }
276}
277
278#[derive(Clone, Debug, Eq, PartialEq)]
280pub(crate) struct RegisteredClient {
281 pub(crate) last_keep_alive_timestamp: TimestampNs,
282 pub(crate) gateway_principal: GatewayPrincipal,
283}
284
285impl RegisteredClient {
286 pub(crate) fn new(gateway_principal: GatewayPrincipal) -> Self {
288 Self {
289 last_keep_alive_timestamp: get_current_time(),
290 gateway_principal,
291 }
292 }
293
294 pub(crate) fn get_last_keep_alive_timestamp(&self) -> TimestampNs {
296 self.last_keep_alive_timestamp
297 }
298
299 pub(crate) fn update_last_keep_alive_timestamp(&mut self) {
301 self.last_keep_alive_timestamp = get_current_time();
302 }
303}
304
305#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
307pub struct CanisterOpenMessageContent {
308 pub client_key: ClientKey,
309}
310
311#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
313pub struct CanisterAckMessageContent {
314 pub last_incoming_sequence_num: u64,
315}
316
317#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
319pub struct ClientKeepAliveMessageContent {
320 pub last_incoming_sequence_num: u64,
321}
322
323#[derive(CandidType, Clone, Debug, Deserialize, PartialEq, Eq)]
325pub enum CloseMessageReason {
326 WrongSequenceNumber,
328 InvalidServiceMessage,
330 KeepAliveTimeout,
332 ClosedByApplication,
334}
335
336#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
338pub struct CanisterCloseMessageContent {
339 pub reason: CloseMessageReason,
340}
341
342#[derive(CandidType, Debug, Deserialize, PartialEq, Eq)]
346pub enum WebsocketServiceMessageContent {
347 OpenMessage(CanisterOpenMessageContent),
349 AckMessage(CanisterAckMessageContent),
351 KeepAliveMessage(ClientKeepAliveMessageContent),
353 CloseMessage(CanisterCloseMessageContent),
355}
356
357impl WebsocketServiceMessageContent {
358 pub fn from_candid_bytes(bytes: &[u8]) -> Result<Self, String> {
359 decode_one(&bytes).map_err(|err| WsError::DecodeServiceMessageContent { err }.to_string())
360 }
361}
362
363pub struct OnOpenCallbackArgs {
365 pub client_principal: ClientPrincipal,
366}
367type OnOpenCallback = fn(OnOpenCallbackArgs);
370
371pub struct OnMessageCallbackArgs {
394 pub client_principal: ClientPrincipal,
396 pub message: Vec<u8>,
398}
399type OnMessageCallback = fn(OnMessageCallbackArgs);
402
403pub struct OnCloseCallbackArgs {
405 pub client_principal: ClientPrincipal,
406}
407type OnCloseCallback = fn(OnCloseCallbackArgs);
413
414#[derive(Clone, Debug, Default, PartialEq)]
421pub struct WsHandlers {
422 pub on_open: Option<OnOpenCallback>,
423 pub on_message: Option<OnMessageCallback>,
424 pub on_close: Option<OnCloseCallback>,
425}
426
427impl WsHandlers {
428 pub(crate) fn call_on_open(&self, args: OnOpenCallbackArgs) {
429 if let Some(on_open) = self.on_open {
430 on_open(args);
433 }
434 }
435
436 pub(crate) fn call_on_message(&self, args: OnMessageCallbackArgs) {
437 if let Some(on_message) = self.on_message {
438 on_message(args);
440 }
441 }
442
443 pub(crate) fn call_on_close(&self, args: OnCloseCallbackArgs) {
444 if let Some(on_close) = self.on_close {
445 on_close(args);
447 }
448 }
449}
450
451#[derive(Clone)]
453pub struct WsInitParams {
454 pub handlers: WsHandlers,
456 pub max_number_of_returned_messages: usize,
460 pub send_ack_interval_ms: u64,
467}
468
469impl WsInitParams {
470 pub fn new(handlers: WsHandlers) -> Self {
472 Self {
473 handlers,
474 ..Default::default()
475 }
476 }
477
478 pub(crate) fn get_handlers(&self) -> WsHandlers {
479 self.handlers.clone()
480 }
481
482 pub(crate) fn check_validity(&self) {
488 if self.send_ack_interval_ms <= CLIENT_KEEP_ALIVE_TIMEOUT_MS {
489 custom_trap!("send_ack_interval_ms must be greater than CLIENT_KEEP_ALIVE_TIMEOUT_MS");
490 }
491 }
492
493 pub fn with_max_number_of_returned_messages(
494 mut self,
495 max_number_of_returned_messages: usize,
496 ) -> Self {
497 self.max_number_of_returned_messages = max_number_of_returned_messages;
498 self
499 }
500
501 pub fn with_send_ack_interval_ms(mut self, send_ack_interval_ms: u64) -> Self {
509 self.send_ack_interval_ms = send_ack_interval_ms;
510 self.check_validity();
511 self
512 }
513}
514
515impl Default for WsInitParams {
516 fn default() -> Self {
517 Self {
518 handlers: WsHandlers::default(),
519 max_number_of_returned_messages: DEFAULT_MAX_NUMBER_OF_RETURNED_MESSAGES,
520 send_ack_interval_ms: DEFAULT_SEND_ACK_INTERVAL_MS,
521 }
522 }
523}