Skip to main content

opentalk_roomserver_signaling/
module_context.rs

1// SPDX-License-Identifier: EUPL-1.2
2// SPDX-FileCopyrightText: OpenTalk Team <mail@opentalk.eu>
3
4use std::{
5    cell::RefCell, collections::HashMap, future::Future, marker::PhantomData, sync::Arc,
6    time::Duration,
7};
8
9use anyhow::Context as _;
10use futures::stream::FuturesUnordered;
11use opentalk_roomserver_types::{
12    client_parameters::{ClientKind, Role},
13    connection_id::ConnectionId,
14    error::{self, SignalingError},
15    room_kind::RoomKind,
16    shared_raw_json::SharedRawJson,
17    signaling::module_error::FatalError,
18};
19use opentalk_types_common::{rooms::RoomId, time::Timestamp, users::UserId};
20use opentalk_types_signaling::ParticipantId;
21use serde_json::value::RawValue;
22use tokio::{
23    select,
24    sync::oneshot::{self, Receiver, Sender},
25};
26use tracing::{Instrument as _, debug_span};
27
28use crate::{
29    banned_participant::BannedParticipant,
30    event_origin::EventOrigin,
31    instruction::Instruction,
32    internal_module_message::InterModuleMessage,
33    loopback::{LoopbackFuture, LoopbackMessage},
34    participant_state::{ParticipantState, Participants},
35    room_info::RoomTaskInfo,
36    signaling_event::SignalingEvent,
37    signaling_module::SignalingModule,
38    storage::{
39        StorageContext,
40        assets::{
41            AssetMetaData, ModuleAssetStorage, UploadResult,
42            provider::{AssetStorageProvider, AssetStream},
43        },
44        module_resources::{ModuleResourceStorage, provider::ModuleResourceProvider},
45    },
46    waiting_participant::WaitingParticipant,
47};
48
49#[derive(Debug)]
50pub enum ModuleMessage {
51    Websocket {
52        connection_id: ConnectionId,
53        message: SharedRawJson,
54    },
55    WaitingRoomWebsocket {
56        connection_id: ConnectionId,
57        message: SharedRawJson,
58    },
59    InternalCommand(InterModuleMessage),
60    Instruction(Instruction),
61}
62
63/// Contains the room state and provides an interface to send websocket messages.
64pub struct ModuleContext<'ctx, M>
65where
66    M: SignalingModule,
67{
68    pub room_id: RoomId,
69    pub room: RoomKind,
70    pub event_origin: EventOrigin,
71    pub room_task_info: &'ctx mut RoomTaskInfo,
72    /// The websocket messages that are sent out after the module finished its event handling
73    messages: &'ctx mut RefCell<Vec<ModuleMessage>>,
74    /// Contains all participants including disconnected ones
75    pub participants: &'ctx mut Participants,
76    pub waiting_participants: &'ctx mut HashMap<ParticipantId, WaitingParticipant>,
77    pub banned_participants: &'ctx mut HashMap<ParticipantId, BannedParticipant>,
78    pub timestamp: Timestamp,
79    loopback_futures: &'ctx mut FuturesUnordered<LoopbackFuture>,
80    assets: Arc<dyn AssetStorageProvider>,
81    module_resources: Arc<dyn ModuleResourceProvider>,
82
83    m: PhantomData<fn() -> M>,
84}
85
86impl<'ctx, M> ModuleContext<'ctx, M>
87where
88    M: SignalingModule,
89{
90    #[allow(clippy::too_many_arguments)]
91    pub fn new(
92        room_id: RoomId,
93        room: RoomKind,
94        event_origin: EventOrigin,
95        room_task_info: &'ctx mut RoomTaskInfo,
96        messages: &'ctx mut RefCell<Vec<ModuleMessage>>,
97        participants: &'ctx mut Participants,
98        waiting_participants: &'ctx mut HashMap<ParticipantId, WaitingParticipant>,
99        banned_participants: &'ctx mut HashMap<ParticipantId, BannedParticipant>,
100        timestamp: Timestamp,
101        loopback_futures: &'ctx mut FuturesUnordered<LoopbackFuture>,
102        assets: Arc<dyn AssetStorageProvider>,
103        module_resources: Arc<dyn ModuleResourceProvider>,
104    ) -> ModuleContext<'ctx, M> {
105        Self {
106            room_id,
107            room,
108            event_origin,
109            room_task_info,
110            messages,
111            participants,
112            waiting_participants,
113            banned_participants,
114            timestamp,
115            loopback_futures,
116            assets,
117            module_resources,
118            m: PhantomData,
119        }
120    }
121
122    pub fn reborrow<M2: SignalingModule>(&mut self) -> ModuleContext<'_, M2> {
123        ModuleContext {
124            room_id: self.room_id,
125            room: self.room,
126            event_origin: self.event_origin,
127            room_task_info: self.room_task_info,
128            messages: self.messages,
129            participants: self.participants,
130            waiting_participants: self.waiting_participants,
131            banned_participants: self.banned_participants,
132            timestamp: self.timestamp,
133            loopback_futures: self.loopback_futures,
134            assets: Arc::clone(&self.assets),
135            module_resources: Arc::clone(&self.module_resources),
136            m: PhantomData,
137        }
138    }
139
140    /// Send a websocket message of type [`SignalingModule::Outgoing`] to the given
141    /// `participant_ids`
142    ///
143    /// The message is always scoped to the [`SignalingModule::NAMESPACE`]
144    ///
145    /// # Errors
146    ///
147    /// Returns `Err` when the [`SignalingModule::Outgoing`] type failed to be serialized.
148    pub fn send_ws_message(
149        &self,
150        participant_ids: impl IntoIterator<Item = ParticipantId>,
151        msg: M::Outgoing,
152    ) -> Result<(), FatalError> {
153        let event = SignalingEvent {
154            namespace: M::NAMESPACE,
155            transaction_id: self.event_origin.transaction_id(),
156            timestamp: Timestamp::now(),
157            payload: msg,
158        };
159        let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
160            .context("Failed to serialize internal websocket payload type")
161            .map_err(FatalError)?
162            .into();
163
164        for participant_id in participant_ids {
165            let Some(state) = self.participants.connected().get(&participant_id) else {
166                tracing::error!(
167                    "Module '{}' attempted to send a websocket message to unknown participant {participant_id}",
168                    M::NAMESPACE
169                );
170                return Ok(());
171            };
172            let mut messages = self.messages.borrow_mut();
173
174            for (connection_id, ..) in &state.connections {
175                messages.push(ModuleMessage::Websocket {
176                    connection_id: *connection_id,
177                    message: shared_json.clone(),
178                });
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Send a websocket message of type [`SignalingModule::Outgoing`] to the given `connection_ids`
186    ///
187    /// The message is always scoped to the [`SignalingModule::NAMESPACE`]
188    ///
189    /// # Errors
190    ///
191    /// Returns `Err` when the [`SignalingModule::Outgoing`] type failed to be serialized.
192    pub fn send_ws_message_to_connections(
193        &self,
194        connection_ids: impl IntoIterator<Item = ConnectionId>,
195        msg: M::Outgoing,
196    ) -> Result<(), FatalError> {
197        let event = SignalingEvent {
198            namespace: M::NAMESPACE,
199            transaction_id: self.event_origin.transaction_id(),
200            timestamp: Timestamp::now(),
201            payload: msg,
202        };
203        let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
204            .context("Failed to serialize internal websocket payload type")
205            .map_err(FatalError)?
206            .into();
207
208        let mut messages = self.messages.borrow_mut();
209        for connection_id in connection_ids {
210            messages.push(ModuleMessage::Websocket {
211                connection_id,
212                message: shared_json.clone(),
213            });
214        }
215
216        Ok(())
217    }
218
219    pub fn send_ws_message_to_waiting_room(
220        &self,
221        participant_ids: impl IntoIterator<Item = ParticipantId>,
222        msg: M::Outgoing,
223    ) -> Result<(), FatalError> {
224        let event = SignalingEvent {
225            namespace: M::NAMESPACE,
226            transaction_id: self.event_origin.transaction_id(),
227            timestamp: Timestamp::now(),
228            payload: msg,
229        };
230        let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
231            .context("Failed to serialize internal websocket payload type")
232            .map_err(FatalError)?
233            .into();
234
235        for participant_id in participant_ids {
236            let Some(waiting_participant) = self.waiting_participants.get(&participant_id) else {
237                tracing::error!(
238                    "Module '{}' attempted to send a websocket message to unknown participant {participant_id}",
239                    M::NAMESPACE
240                );
241                return Ok(());
242            };
243            let mut messages = self.messages.borrow_mut();
244
245            for (connection_id, ..) in &waiting_participant.connections {
246                messages.push(ModuleMessage::WaitingRoomWebsocket {
247                    connection_id: *connection_id,
248                    message: shared_json.clone(),
249                });
250            }
251        }
252
253        Ok(())
254    }
255
256    /// Send a websocket command received from one `source_connection` to all
257    /// other connections of the same participant.
258    ///
259    /// The message is always scoped to the [`SignalingModule::NAMESPACE`]
260    ///
261    /// # Errors
262    ///
263    /// Returns [`FatalError`] when the [`SignalingEvent`] type failed to be serialized.
264    pub fn send_replica(
265        &self,
266        sender: ParticipantId,
267        source_connection: ConnectionId,
268        replication_event: M::Outgoing,
269    ) -> Result<(), FatalError> {
270        let event = SignalingEvent {
271            namespace: M::NAMESPACE,
272            transaction_id: self.event_origin.transaction_id(),
273            timestamp: Timestamp::now(),
274            payload: replication_event,
275        };
276
277        let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
278            .context("Failed to serialize internal websocket payload type")
279            .map_err(FatalError)?
280            .into();
281
282        let Some(state) = self.participants.connected().get(&sender) else {
283            tracing::error!(
284                "Module '{}' attempted to replicate a command to unknown participant {sender}",
285                M::NAMESPACE
286            );
287            return Ok(());
288        };
289        let mut messages = self.messages.borrow_mut();
290
291        for connection_id in state.connections.keys().copied() {
292            if connection_id != source_connection {
293                messages.push(ModuleMessage::Websocket {
294                    connection_id,
295                    message: shared_json.clone(),
296                });
297            }
298        }
299
300        Ok(())
301    }
302
303    /// Send a command to another [`SignalingModule`]
304    ///
305    /// * `command` - The command to be sent. The type is defined by the receiving module.
306    /// * `handle_result` - Closure that receives the result of the command.
307    pub fn send_internal_command<R>(&mut self, command: R::Internal)
308    where
309        R: SignalingModule,
310    {
311        let command = InterModuleMessage {
312            sender: M::NAMESPACE,
313            receiver: R::NAMESPACE,
314            command: Box::new(command),
315        };
316        self.messages
317            .get_mut()
318            .push(ModuleMessage::InternalCommand(command));
319    }
320
321    /// Kick the specified participants
322    pub fn kick_participants(&mut self, participants: Vec<ParticipantId>) {
323        let command = ModuleMessage::Instruction(Instruction::Kick { participants });
324        self.messages.get_mut().push(command);
325    }
326
327    pub fn ban_participant(&mut self, participant: ParticipantId) {
328        let command = ModuleMessage::Instruction(Instruction::Ban { participant });
329        self.messages.get_mut().push(command);
330    }
331
332    pub fn ban_waiting_participant(&mut self, participant: ParticipantId) {
333        let command: ModuleMessage =
334            ModuleMessage::Instruction(Instruction::BanWaiting { participant });
335        self.messages.get_mut().push(command);
336    }
337
338    /// Move the specified participant to the waiting room
339    pub fn move_to_waiting_room(&mut self, participant: ParticipantId) {
340        let command = ModuleMessage::Instruction(Instruction::MoveToWaitingRoom { participant });
341        self.messages.get_mut().push(command);
342    }
343
344    /// Invoke an error message of type [`SignalingError`]
345    ///
346    /// If the event origin is a signaling connection, the error will be sent to the participant.
347    ///
348    /// The message is always scoped to the [`error::ERROR_MODULE_ID`]
349    pub fn handle_error(&self, error: SignalingError) {
350        let participant_id = match self.event_origin {
351            EventOrigin::Participant(participant_origin) => participant_origin.id,
352            EventOrigin::Internal => {
353                tracing::error!(
354                    "Signaling module '{}' returned an error on an event with internal origin: {error:?} ",
355                    M::NAMESPACE
356                );
357                return;
358            }
359        };
360
361        let event = SignalingEvent {
362            namespace: error::ERROR_MODULE_ID,
363            transaction_id: self.event_origin.transaction_id(),
364            timestamp: Timestamp::now(),
365            payload: error,
366        };
367
368        let shared_json: SharedRawJson = match serde_json::value::to_raw_value(&event) {
369            Ok(value) => value.into(),
370            Err(err) => {
371                tracing::error!("Failed to serialize SignalingError type: {err}");
372                RawValue::from_string(r#"{"error": "internal"}"#.into())
373                    .unwrap()
374                    .into()
375            }
376        };
377
378        let mut messages = self.messages.borrow_mut();
379        if let Some(state) = self.participants.connected().get(&participant_id) {
380            for connection_id in state.connections() {
381                messages.push(ModuleMessage::Websocket {
382                    connection_id,
383                    message: shared_json.clone(),
384                });
385            }
386        } else if let Some(waiting_participant) = self.waiting_participants.get(&participant_id) {
387            let connections = waiting_participant.connections.keys();
388            for &connection_id in connections {
389                messages.push(ModuleMessage::WaitingRoomWebsocket {
390                    connection_id,
391                    message: shared_json.clone(),
392                });
393            }
394        } else {
395            tracing::error!(
396                "Module '{}' attempted to send a websocket error message to unknown participant {}",
397                M::NAMESPACE,
398                participant_id,
399            );
400        }
401    }
402
403    /// Spawns a new task that completes the given `future` and sends the result
404    /// back to the calling module as [`SignalingModule::Loopback`] in the
405    /// [`SignalingModule::on_loopback_event`] method.
406    ///
407    /// The room task will panic if the provided future panics.
408    pub fn spawn<F>(&self, future: F)
409    where
410        F: Future<Output = M::Loopback> + Send + 'static,
411    {
412        let origin = self.event_origin;
413        let room = self.room;
414        let timestamp = self.timestamp;
415        let span = debug_span!("spawn");
416
417        let future = future.instrument(span.clone());
418        let future = Box::pin(async move {
419            Some(LoopbackMessage {
420                namespace: M::NAMESPACE,
421                origin,
422                room,
423                timestamp,
424                span,
425                value: Box::new(future.await),
426            })
427        });
428
429        self.loopback_futures.push(future);
430    }
431
432    /// Spawns a new task that completes the given `future` and sends the result
433    /// back to the calling module as [`SignalingModule::Loopback`] in the
434    /// [`SignalingModule::on_loopback_event`] method when the result is [`Some`].
435    ///
436    /// The room task will panic if the provided future panics.
437    pub fn spawn_optional<F>(&self, future: F)
438    where
439        F: Future<Output = Option<M::Loopback>> + Send + 'static,
440    {
441        let origin = self.event_origin;
442        let room = self.room;
443        let timestamp = self.timestamp;
444        let span = debug_span!("spawn_optional");
445
446        let future = future.instrument(span.clone());
447        let future = Box::pin(async move {
448            future.await.map(|value| LoopbackMessage {
449                namespace: M::NAMESPACE,
450                origin,
451                room,
452                timestamp,
453                span,
454                value: Box::new(value),
455            })
456        });
457
458        self.loopback_futures.push(future);
459    }
460
461    /// Spawns a blocking function as a asynchronous task and sends the result
462    /// back to the calling module as [`SignalingModule::Loopback`] in the
463    /// [`SignalingModule::on_loopback_event`] method.
464    ///
465    /// If the provided function panics, any results will be discarded and the module won't be
466    /// notified.
467    pub fn spawn_blocking<F>(&self, blocking_function: F)
468    where
469        F: FnOnce() -> M::Loopback + Send + 'static,
470    {
471        let span = debug_span!("spawn_blocking");
472        let origin = self.event_origin;
473        let room = self.room;
474        let join_handle = {
475            let span = span.clone();
476            tokio::task::spawn_blocking(move || span.in_scope(blocking_function))
477        };
478        let timestamp = self.timestamp;
479
480        let future = Box::pin(async move {
481            let Ok(value) = join_handle.await else {
482                tracing::error!("module {} panicked in loopback task", M::NAMESPACE);
483                return None;
484            };
485
486            Some(LoopbackMessage {
487                namespace: M::NAMESPACE,
488                origin,
489                room,
490                timestamp,
491                span,
492                value: Box::new(value),
493            })
494        });
495
496        self.loopback_futures.push(future);
497    }
498
499    /// Creates a loopback future that resolves after the `duration`.
500    ///
501    /// When `duration` has passed, `create_result` is invoked and the return
502    /// value is sent as a loopback event.
503    /// Can be cancelled by sending a result into the `rx_cancel` [`Receiver`].
504    pub fn loopback_after<F>(&self, duration: Duration, create_result: F) -> Sender<M::Loopback>
505    where
506        M::Loopback: From<ChannelDroppedError> + Send + Sync + 'static,
507        F: FnOnce() -> M::Loopback + Send + Sync + 'static,
508    {
509        let (tx_cancel, rx_cancel) = oneshot::channel();
510        self.spawn(handle_loopback_after(duration, rx_cancel, create_result));
511        tx_cancel
512    }
513
514    pub fn recv_loopback<F, R>(&self, receiver: oneshot::Receiver<R>, create_result: F)
515    where
516        M::Loopback: From<ChannelDroppedError> + Send + Sync + 'static,
517        F: FnOnce(R) -> M::Loopback + Send + Sync + 'static,
518        R: Send + Sync + 'static,
519    {
520        self.spawn(async move {
521            match receiver.await {
522                Ok(result) => create_result(result),
523                Err(_) => ChannelDroppedError.into(),
524            }
525        });
526    }
527
528    pub fn participant_state(&self, participant_id: ParticipantId) -> Option<&ParticipantState> {
529        self.participants.all_unfiltered.get(&participant_id)
530    }
531
532    pub fn participant_role(&self, participant_id: ParticipantId) -> Option<Role> {
533        self.participant_state(participant_id).map(|p| p.role)
534    }
535
536    pub fn user_id(&self, participant_id: ParticipantId) -> Option<UserId> {
537        self.participant_state(participant_id)
538            .and_then(|state| state.kind.user_id())
539    }
540
541    pub fn is_moderator(&self, participant_id: ParticipantId) -> bool {
542        self.participant_role(participant_id)
543            .is_some_and(|r| r == Role::Moderator)
544    }
545
546    pub fn get_client_kind(&self, participant_id: ParticipantId) -> Option<&ClientKind> {
547        self.participant_state(participant_id)
548            .map(|state| &state.kind)
549    }
550
551    pub fn is_room_owner(&self, participant_id: ParticipantId) -> bool {
552        let user_id = self
553            .participants
554            .all_unfiltered
555            .get(&participant_id)
556            .and_then(|state| state.kind.user_id());
557        let Some(user_id) = user_id else {
558            return false;
559        };
560        user_id == self.room_task_info.room.created_by.id
561    }
562
563    pub fn assets(&self) -> ModuleAssetStorage {
564        ModuleAssetStorage::new(Arc::clone(&self.assets), self.storage_context())
565    }
566
567    pub fn module_resources(&self) -> ModuleResourceStorage {
568        ModuleResourceStorage::new(Arc::clone(&self.module_resources), self.storage_context())
569    }
570
571    fn storage_context(&self) -> StorageContext {
572        StorageContext {
573            room_id: self.room_id,
574            namespace: M::NAMESPACE,
575            event: self.room_task_info.room.event.clone(),
576        }
577    }
578}
579
/// Marker error used when the sending half of a oneshot channel feeding a loopback is dropped
/// before a value was sent.
///
/// Modules using [`ModuleContext::loopback_after`] or [`ModuleContext::recv_loopback`] convert
/// it into their loopback type via `From<ChannelDroppedError>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelDroppedError;
581
582impl<M> ModuleContext<'_, M>
583where
584    M: SignalingModule,
585    M::Loopback: From<UploadResult>,
586{
587    pub fn upload_file(&self, asset: AssetStream, metadata: AssetMetaData) {
588        let storage_context = self.storage_context();
589        let assets = Arc::clone(&self.assets);
590
591        self.spawn(async move {
592            assets
593                .upload_asset(asset, metadata, &storage_context)
594                .await
595                .into()
596        });
597    }
598}
599
600async fn handle_loopback_after<F, L>(
601    duration: Duration,
602    rx_cancel: Receiver<L>,
603    create_result: F,
604) -> L
605where
606    F: FnOnce() -> L + 'static,
607    L: From<ChannelDroppedError> + Send + Sync + 'static,
608{
609    select! {
610        result = rx_cancel => {
611            match result {
612                Ok(value) => value,
613                Err(_) => ChannelDroppedError.into(),
614            }
615        },
616        () = tokio::time::sleep(duration) => {
617            create_result()
618        }
619    }
620}