1use std::{
5 cell::RefCell, collections::HashMap, future::Future, marker::PhantomData, sync::Arc,
6 time::Duration,
7};
8
9use anyhow::Context as _;
10use futures::stream::FuturesUnordered;
11use opentalk_roomserver_types::{
12 client_parameters::{ClientKind, Role},
13 connection_id::ConnectionId,
14 error::{self, SignalingError},
15 room_kind::RoomKind,
16 shared_raw_json::SharedRawJson,
17 signaling::module_error::FatalError,
18};
19use opentalk_types_common::{rooms::RoomId, time::Timestamp, users::UserId};
20use opentalk_types_signaling::ParticipantId;
21use serde_json::value::RawValue;
22use tokio::{
23 select,
24 sync::oneshot::{self, Receiver, Sender},
25};
26use tracing::{Instrument as _, debug_span};
27
28use crate::{
29 banned_participant::BannedParticipant,
30 event_origin::EventOrigin,
31 instruction::Instruction,
32 internal_module_message::InterModuleMessage,
33 loopback::{LoopbackFuture, LoopbackMessage},
34 participant_state::{ParticipantState, Participants},
35 room_info::RoomTaskInfo,
36 signaling_event::SignalingEvent,
37 signaling_module::SignalingModule,
38 storage::{
39 StorageContext,
40 assets::{
41 AssetMetaData, ModuleAssetStorage, UploadResult,
42 provider::{AssetStorageProvider, AssetStream},
43 },
44 module_resources::{ModuleResourceStorage, provider::ModuleResourceProvider},
45 },
46 waiting_participant::WaitingParticipant,
47};
48
/// A message queued by a signaling module through [`ModuleContext`] while handling an event.
///
/// The context only appends these to a shared buffer; acting on them is left to whoever owns
/// that buffer (presumably the room task — confirm against the caller).
#[derive(Debug)]
pub enum ModuleMessage {
    /// Serialized JSON destined for one websocket connection of a participant in the room.
    Websocket {
        connection_id: ConnectionId,
        message: SharedRawJson,
    },
    /// Serialized JSON destined for one websocket connection of a waiting-room participant.
    WaitingRoomWebsocket {
        connection_id: ConnectionId,
        message: SharedRawJson,
    },
    /// A command from one signaling module addressed to another signaling module.
    InternalCommand(InterModuleMessage),
    /// A room-level instruction (e.g. kick, ban, move to waiting room).
    Instruction(Instruction),
}
62
/// Context handed to signaling module `M` while it handles a single event.
///
/// It exposes the room's participant state. It also lets the module queue the following:
/// - outgoing websocket messages, inter-module commands and instructions (into `messages`);
/// - background futures (into `loopback_futures`).
///
/// All borrowed state lives for `'ctx`.
pub struct ModuleContext<'ctx, M>
where
    M: SignalingModule,
{
    /// Id of the room the current event belongs to.
    pub room_id: RoomId,
    /// The room's [`RoomKind`]; copied into loopback messages spawned from this context.
    pub room: RoomKind,
    /// Origin of the event being handled: a participant or an internal source.
    pub event_origin: EventOrigin,
    /// Room task information; provides e.g. the room creator and the associated event.
    pub room_task_info: &'ctx mut RoomTaskInfo,
    /// Queue of messages produced by the module. Wrapped in a `RefCell` so that methods
    /// taking `&self` can still push to it.
    messages: &'ctx mut RefCell<Vec<ModuleMessage>>,
    /// Participants of the room.
    pub participants: &'ctx mut Participants,
    /// Participants in the waiting room, keyed by participant id.
    pub waiting_participants: &'ctx mut HashMap<ParticipantId, WaitingParticipant>,
    /// Banned participants, keyed by participant id.
    pub banned_participants: &'ctx mut HashMap<ParticipantId, BannedParticipant>,
    /// Timestamp of the event being handled; copied into spawned loopback messages.
    pub timestamp: Timestamp,
    /// Pending loopback futures; each resolves to an optional `LoopbackMessage`.
    loopback_futures: &'ctx mut FuturesUnordered<LoopbackFuture>,
    /// Backend used for module asset storage.
    assets: Arc<dyn AssetStorageProvider>,
    /// Backend used for module resource storage.
    module_resources: Arc<dyn ModuleResourceProvider>,

    /// Marks the context as belonging to module `M` without storing an `M`. Using
    /// `fn() -> M` keeps the struct's auto traits (e.g. `Send`) independent of `M`.
    m: PhantomData<fn() -> M>,
}
85
86impl<'ctx, M> ModuleContext<'ctx, M>
87where
88 M: SignalingModule,
89{
90 #[allow(clippy::too_many_arguments)]
91 pub fn new(
92 room_id: RoomId,
93 room: RoomKind,
94 event_origin: EventOrigin,
95 room_task_info: &'ctx mut RoomTaskInfo,
96 messages: &'ctx mut RefCell<Vec<ModuleMessage>>,
97 participants: &'ctx mut Participants,
98 waiting_participants: &'ctx mut HashMap<ParticipantId, WaitingParticipant>,
99 banned_participants: &'ctx mut HashMap<ParticipantId, BannedParticipant>,
100 timestamp: Timestamp,
101 loopback_futures: &'ctx mut FuturesUnordered<LoopbackFuture>,
102 assets: Arc<dyn AssetStorageProvider>,
103 module_resources: Arc<dyn ModuleResourceProvider>,
104 ) -> ModuleContext<'ctx, M> {
105 Self {
106 room_id,
107 room,
108 event_origin,
109 room_task_info,
110 messages,
111 participants,
112 waiting_participants,
113 banned_participants,
114 timestamp,
115 loopback_futures,
116 assets,
117 module_resources,
118 m: PhantomData,
119 }
120 }
121
122 pub fn reborrow<M2: SignalingModule>(&mut self) -> ModuleContext<'_, M2> {
123 ModuleContext {
124 room_id: self.room_id,
125 room: self.room,
126 event_origin: self.event_origin,
127 room_task_info: self.room_task_info,
128 messages: self.messages,
129 participants: self.participants,
130 waiting_participants: self.waiting_participants,
131 banned_participants: self.banned_participants,
132 timestamp: self.timestamp,
133 loopback_futures: self.loopback_futures,
134 assets: Arc::clone(&self.assets),
135 module_resources: Arc::clone(&self.module_resources),
136 m: PhantomData,
137 }
138 }
139
140 pub fn send_ws_message(
149 &self,
150 participant_ids: impl IntoIterator<Item = ParticipantId>,
151 msg: M::Outgoing,
152 ) -> Result<(), FatalError> {
153 let event = SignalingEvent {
154 namespace: M::NAMESPACE,
155 transaction_id: self.event_origin.transaction_id(),
156 timestamp: Timestamp::now(),
157 payload: msg,
158 };
159 let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
160 .context("Failed to serialize internal websocket payload type")
161 .map_err(FatalError)?
162 .into();
163
164 for participant_id in participant_ids {
165 let Some(state) = self.participants.connected().get(&participant_id) else {
166 tracing::error!(
167 "Module '{}' attempted to send a websocket message to unknown participant {participant_id}",
168 M::NAMESPACE
169 );
170 return Ok(());
171 };
172 let mut messages = self.messages.borrow_mut();
173
174 for (connection_id, ..) in &state.connections {
175 messages.push(ModuleMessage::Websocket {
176 connection_id: *connection_id,
177 message: shared_json.clone(),
178 });
179 }
180 }
181
182 Ok(())
183 }
184
185 pub fn send_ws_message_to_connections(
193 &self,
194 connection_ids: impl IntoIterator<Item = ConnectionId>,
195 msg: M::Outgoing,
196 ) -> Result<(), FatalError> {
197 let event = SignalingEvent {
198 namespace: M::NAMESPACE,
199 transaction_id: self.event_origin.transaction_id(),
200 timestamp: Timestamp::now(),
201 payload: msg,
202 };
203 let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
204 .context("Failed to serialize internal websocket payload type")
205 .map_err(FatalError)?
206 .into();
207
208 let mut messages = self.messages.borrow_mut();
209 for connection_id in connection_ids {
210 messages.push(ModuleMessage::Websocket {
211 connection_id,
212 message: shared_json.clone(),
213 });
214 }
215
216 Ok(())
217 }
218
219 pub fn send_ws_message_to_waiting_room(
220 &self,
221 participant_ids: impl IntoIterator<Item = ParticipantId>,
222 msg: M::Outgoing,
223 ) -> Result<(), FatalError> {
224 let event = SignalingEvent {
225 namespace: M::NAMESPACE,
226 transaction_id: self.event_origin.transaction_id(),
227 timestamp: Timestamp::now(),
228 payload: msg,
229 };
230 let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
231 .context("Failed to serialize internal websocket payload type")
232 .map_err(FatalError)?
233 .into();
234
235 for participant_id in participant_ids {
236 let Some(waiting_participant) = self.waiting_participants.get(&participant_id) else {
237 tracing::error!(
238 "Module '{}' attempted to send a websocket message to unknown participant {participant_id}",
239 M::NAMESPACE
240 );
241 return Ok(());
242 };
243 let mut messages = self.messages.borrow_mut();
244
245 for (connection_id, ..) in &waiting_participant.connections {
246 messages.push(ModuleMessage::WaitingRoomWebsocket {
247 connection_id: *connection_id,
248 message: shared_json.clone(),
249 });
250 }
251 }
252
253 Ok(())
254 }
255
256 pub fn send_replica(
265 &self,
266 sender: ParticipantId,
267 source_connection: ConnectionId,
268 replication_event: M::Outgoing,
269 ) -> Result<(), FatalError> {
270 let event = SignalingEvent {
271 namespace: M::NAMESPACE,
272 transaction_id: self.event_origin.transaction_id(),
273 timestamp: Timestamp::now(),
274 payload: replication_event,
275 };
276
277 let shared_json: SharedRawJson = serde_json::value::to_raw_value(&event)
278 .context("Failed to serialize internal websocket payload type")
279 .map_err(FatalError)?
280 .into();
281
282 let Some(state) = self.participants.connected().get(&sender) else {
283 tracing::error!(
284 "Module '{}' attempted to replicate a command to unknown participant {sender}",
285 M::NAMESPACE
286 );
287 return Ok(());
288 };
289 let mut messages = self.messages.borrow_mut();
290
291 for connection_id in state.connections.keys().copied() {
292 if connection_id != source_connection {
293 messages.push(ModuleMessage::Websocket {
294 connection_id,
295 message: shared_json.clone(),
296 });
297 }
298 }
299
300 Ok(())
301 }
302
303 pub fn send_internal_command<R>(&mut self, command: R::Internal)
308 where
309 R: SignalingModule,
310 {
311 let command = InterModuleMessage {
312 sender: M::NAMESPACE,
313 receiver: R::NAMESPACE,
314 command: Box::new(command),
315 };
316 self.messages
317 .get_mut()
318 .push(ModuleMessage::InternalCommand(command));
319 }
320
321 pub fn kick_participants(&mut self, participants: Vec<ParticipantId>) {
323 let command = ModuleMessage::Instruction(Instruction::Kick { participants });
324 self.messages.get_mut().push(command);
325 }
326
327 pub fn ban_participant(&mut self, participant: ParticipantId) {
328 let command = ModuleMessage::Instruction(Instruction::Ban { participant });
329 self.messages.get_mut().push(command);
330 }
331
332 pub fn ban_waiting_participant(&mut self, participant: ParticipantId) {
333 let command: ModuleMessage =
334 ModuleMessage::Instruction(Instruction::BanWaiting { participant });
335 self.messages.get_mut().push(command);
336 }
337
338 pub fn move_to_waiting_room(&mut self, participant: ParticipantId) {
340 let command = ModuleMessage::Instruction(Instruction::MoveToWaitingRoom { participant });
341 self.messages.get_mut().push(command);
342 }
343
344 pub fn handle_error(&self, error: SignalingError) {
350 let participant_id = match self.event_origin {
351 EventOrigin::Participant(participant_origin) => participant_origin.id,
352 EventOrigin::Internal => {
353 tracing::error!(
354 "Signaling module '{}' returned an error on an event with internal origin: {error:?} ",
355 M::NAMESPACE
356 );
357 return;
358 }
359 };
360
361 let event = SignalingEvent {
362 namespace: error::ERROR_MODULE_ID,
363 transaction_id: self.event_origin.transaction_id(),
364 timestamp: Timestamp::now(),
365 payload: error,
366 };
367
368 let shared_json: SharedRawJson = match serde_json::value::to_raw_value(&event) {
369 Ok(value) => value.into(),
370 Err(err) => {
371 tracing::error!("Failed to serialize SignalingError type: {err}");
372 RawValue::from_string(r#"{"error": "internal"}"#.into())
373 .unwrap()
374 .into()
375 }
376 };
377
378 let mut messages = self.messages.borrow_mut();
379 if let Some(state) = self.participants.connected().get(&participant_id) {
380 for connection_id in state.connections() {
381 messages.push(ModuleMessage::Websocket {
382 connection_id,
383 message: shared_json.clone(),
384 });
385 }
386 } else if let Some(waiting_participant) = self.waiting_participants.get(&participant_id) {
387 let connections = waiting_participant.connections.keys();
388 for &connection_id in connections {
389 messages.push(ModuleMessage::WaitingRoomWebsocket {
390 connection_id,
391 message: shared_json.clone(),
392 });
393 }
394 } else {
395 tracing::error!(
396 "Module '{}' attempted to send a websocket error message to unknown participant {}",
397 M::NAMESPACE,
398 participant_id,
399 );
400 }
401 }
402
403 pub fn spawn<F>(&self, future: F)
409 where
410 F: Future<Output = M::Loopback> + Send + 'static,
411 {
412 let origin = self.event_origin;
413 let room = self.room;
414 let timestamp = self.timestamp;
415 let span = debug_span!("spawn");
416
417 let future = future.instrument(span.clone());
418 let future = Box::pin(async move {
419 Some(LoopbackMessage {
420 namespace: M::NAMESPACE,
421 origin,
422 room,
423 timestamp,
424 span,
425 value: Box::new(future.await),
426 })
427 });
428
429 self.loopback_futures.push(future);
430 }
431
432 pub fn spawn_optional<F>(&self, future: F)
438 where
439 F: Future<Output = Option<M::Loopback>> + Send + 'static,
440 {
441 let origin = self.event_origin;
442 let room = self.room;
443 let timestamp = self.timestamp;
444 let span = debug_span!("spawn_optional");
445
446 let future = future.instrument(span.clone());
447 let future = Box::pin(async move {
448 future.await.map(|value| LoopbackMessage {
449 namespace: M::NAMESPACE,
450 origin,
451 room,
452 timestamp,
453 span,
454 value: Box::new(value),
455 })
456 });
457
458 self.loopback_futures.push(future);
459 }
460
461 pub fn spawn_blocking<F>(&self, blocking_function: F)
468 where
469 F: FnOnce() -> M::Loopback + Send + 'static,
470 {
471 let span = debug_span!("spawn_blocking");
472 let origin = self.event_origin;
473 let room = self.room;
474 let join_handle = {
475 let span = span.clone();
476 tokio::task::spawn_blocking(move || span.in_scope(blocking_function))
477 };
478 let timestamp = self.timestamp;
479
480 let future = Box::pin(async move {
481 let Ok(value) = join_handle.await else {
482 tracing::error!("module {} panicked in loopback task", M::NAMESPACE);
483 return None;
484 };
485
486 Some(LoopbackMessage {
487 namespace: M::NAMESPACE,
488 origin,
489 room,
490 timestamp,
491 span,
492 value: Box::new(value),
493 })
494 });
495
496 self.loopback_futures.push(future);
497 }
498
499 pub fn loopback_after<F>(&self, duration: Duration, create_result: F) -> Sender<M::Loopback>
505 where
506 M::Loopback: From<ChannelDroppedError> + Send + Sync + 'static,
507 F: FnOnce() -> M::Loopback + Send + Sync + 'static,
508 {
509 let (tx_cancel, rx_cancel) = oneshot::channel();
510 self.spawn(handle_loopback_after(duration, rx_cancel, create_result));
511 tx_cancel
512 }
513
514 pub fn recv_loopback<F, R>(&self, receiver: oneshot::Receiver<R>, create_result: F)
515 where
516 M::Loopback: From<ChannelDroppedError> + Send + Sync + 'static,
517 F: FnOnce(R) -> M::Loopback + Send + Sync + 'static,
518 R: Send + Sync + 'static,
519 {
520 self.spawn(async move {
521 match receiver.await {
522 Ok(result) => create_result(result),
523 Err(_) => ChannelDroppedError.into(),
524 }
525 });
526 }
527
528 pub fn participant_state(&self, participant_id: ParticipantId) -> Option<&ParticipantState> {
529 self.participants.all_unfiltered.get(&participant_id)
530 }
531
532 pub fn participant_role(&self, participant_id: ParticipantId) -> Option<Role> {
533 self.participant_state(participant_id).map(|p| p.role)
534 }
535
536 pub fn user_id(&self, participant_id: ParticipantId) -> Option<UserId> {
537 self.participant_state(participant_id)
538 .and_then(|state| state.kind.user_id())
539 }
540
541 pub fn is_moderator(&self, participant_id: ParticipantId) -> bool {
542 self.participant_role(participant_id)
543 .is_some_and(|r| r == Role::Moderator)
544 }
545
546 pub fn get_client_kind(&self, participant_id: ParticipantId) -> Option<&ClientKind> {
547 self.participant_state(participant_id)
548 .map(|state| &state.kind)
549 }
550
551 pub fn is_room_owner(&self, participant_id: ParticipantId) -> bool {
552 let user_id = self
553 .participants
554 .all_unfiltered
555 .get(&participant_id)
556 .and_then(|state| state.kind.user_id());
557 let Some(user_id) = user_id else {
558 return false;
559 };
560 user_id == self.room_task_info.room.created_by.id
561 }
562
563 pub fn assets(&self) -> ModuleAssetStorage {
564 ModuleAssetStorage::new(Arc::clone(&self.assets), self.storage_context())
565 }
566
567 pub fn module_resources(&self) -> ModuleResourceStorage {
568 ModuleResourceStorage::new(Arc::clone(&self.module_resources), self.storage_context())
569 }
570
571 fn storage_context(&self) -> StorageContext {
572 StorageContext {
573 room_id: self.room_id,
574 namespace: M::NAMESPACE,
575 event: self.room_task_info.room.event.clone(),
576 }
577 }
578}
579
/// Marker error for a loopback whose `oneshot` sender was dropped before a value was sent.
///
/// Loopback types used with [`ModuleContext::loopback_after`] or
/// [`ModuleContext::recv_loopback`] must implement `From<ChannelDroppedError>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelDroppedError;
581
582impl<M> ModuleContext<'_, M>
583where
584 M: SignalingModule,
585 M::Loopback: From<UploadResult>,
586{
587 pub fn upload_file(&self, asset: AssetStream, metadata: AssetMetaData) {
588 let storage_context = self.storage_context();
589 let assets = Arc::clone(&self.assets);
590
591 self.spawn(async move {
592 assets
593 .upload_asset(asset, metadata, &storage_context)
594 .await
595 .into()
596 });
597 }
598}
599
600async fn handle_loopback_after<F, L>(
601 duration: Duration,
602 rx_cancel: Receiver<L>,
603 create_result: F,
604) -> L
605where
606 F: FnOnce() -> L + 'static,
607 L: From<ChannelDroppedError> + Send + Sync + 'static,
608{
609 select! {
610 result = rx_cancel => {
611 match result {
612 Ok(value) => value,
613 Err(_) => ChannelDroppedError.into(),
614 }
615 },
616 () = tokio::time::sleep(duration) => {
617 create_result()
618 }
619 }
620}