1use std::time::Duration;
2use std::{collections::VecDeque, sync::Arc};
3
4use axum::extract::ws::{Message, WebSocket};
5use common::v1::types::emoji::EmojiOwner;
6use common::v1::types::user_status::Status;
7use common::v1::types::util::Time;
8use common::v1::types::voice::{SfuCommand, SfuPermissions, SignallingMessage, VoiceState};
9use common::v1::types::{self, SERVER_ROOM_ID};
10use common::v1::types::{
11 ChannelId, InviteTarget, InviteTargetId, MemberListGroup, MemberListGroupId, MemberListOp,
12 MessageClient, MessageEnvelope, MessageSync, Permission, RoomId, Session, UserId,
13};
14use tokio::time::Instant;
15use tracing::{debug, error, trace};
16
17use crate::error::{Error, Result};
18use crate::ServerState;
19
/// Shorthand for the axum WebSocket message type used when writing frames.
type WsMessage = axum::extract::ws::Message;

/// Delay used for ping deadlines; `MessageClient::Pong` moves the deadline to
/// `now + HEARTBEAT_TIME`.
pub const HEARTBEAT_TIME: Duration = Duration::from_secs(30);
/// Delay used to build `Timeout::Close` deadlines.
pub const CLOSE_TIME: Duration = Duration::from_secs(10);
/// Maximum number of envelopes kept in a connection's replay queue; also the
/// backlog size after which a disconnected connection is expired.
const MAX_QUEUE_LEN: usize = 256;
25
/// Deadline tracked for a connection; the variant says what the deadline is
/// for, and `get_instant` returns the wrapped instant for either variant.
pub enum Timeout {
    /// Ping deadline; reset to `now + HEARTBEAT_TIME` on every client `Pong`.
    /// NOTE(review): what happens when it elapses is decided by the caller.
    Ping(Instant),
    /// Close deadline (see `Timeout::for_close`).
    Close(Instant),
}
30
/// Per-socket sync state: authentication, the outgoing replay queue and the
/// member list subscription.
pub struct Connection {
    /// Authentication / liveness state of this connection.
    state: ConnectionState,
    /// Shared server state (services, data access, `syncers`).
    s: Arc<ServerState>,
    /// Outgoing envelopes, newest at the front, each paired with its sync
    /// sequence number (`None` for unsequenced envelopes such as `Resumed`).
    /// `push` truncates it to `MAX_QUEUE_LEN`.
    queue: VecDeque<(Option<u64>, MessageEnvelope)>,
    /// Next sequence number to assign to a queued sync message.
    seq_server: u64,
    /// Highest sequence number already written to the socket by `drain`, or the
    /// client-reported position after `rewind`.
    seq_client: u64,
    /// Connection id (hyphenated UUID); sent to the client in `Ready` and used
    /// as the key into `ServerState::syncers`.
    id: String,
    /// Active member list subscription, if any.
    member_list_sub: Option<MemberListSub>,
    /// Member list last sent for the subscription (room member, thread member,
    /// user); the baseline for `diff_sync_member_list`.
    member_list_cache: Vec<(
        Option<types::RoomMember>,
        Option<types::ThreadMember>,
        types::User,
    )>,
}
45
/// A connection's member list subscription.
#[derive(Debug, Clone)]
struct MemberListSub {
    /// Which room or channel the list is for.
    target: MemberListTarget,
    /// Requested `(start, end)` windows into the sorted list; `end` is exclusive
    /// and clamped to the list length by `resync_member_list`.
    ranges: Vec<(u64, u64)>,
}
51
/// Subject of a member list subscription.
#[derive(Debug, Clone, PartialEq, Eq)]
enum MemberListTarget {
    /// All members of a room.
    Room(RoomId),
    /// Thread members of a channel.
    Channel(ChannelId),
}
57
/// Authentication / liveness state of a [`Connection`].
#[derive(Debug, Clone)]
enum ConnectionState {
    /// No session: before `Hello`, or after this connection's session was deleted.
    Unauthed,
    /// Authenticated with a live socket.
    Authenticated { session: Session },
    /// Socket gone (set by `Connection::disconnect`); sync messages keep being
    /// queued until the connection is resumed or expires.
    Disconnected { session: Session },
}
64
/// Visibility rule deciding whether a sync event is delivered to a connection.
#[derive(Debug)]
enum AuthCheck {
    /// Decision already computed by the caller.
    Custom(bool),
    /// Recipient must be able to load room permissions (`perms.for_room`).
    Room(RoomId),
    /// Recipient must hold the given permission in the room.
    RoomPerm(RoomId, Permission),
    /// Recipient is the given user, or passes the `Room` check.
    RoomOrUser(RoomId, UserId),
    /// Recipient is the given user, or has `ViewChannel` in the channel.
    ChannelOrUser(ChannelId, UserId),
    /// Recipient must be exactly the given user.
    User(UserId),
    /// Recipient is the given user, or `perms.is_mutual` holds for the pair.
    UserMutual(UserId),
    /// Recipient must have `ViewChannel` in the channel.
    Channel(ChannelId),
    /// Recipient must have `ViewChannel` in at least one of the two channels.
    EitherChannel(ChannelId, ChannelId),
}
77
78impl Connection {
79 pub fn new(s: Arc<ServerState>) -> Self {
80 Self {
81 state: ConnectionState::Unauthed,
82 queue: VecDeque::new(),
83 seq_server: 0,
84 seq_client: 0,
85 id: format!("{}", uuid::Uuid::new_v4().hyphenated()),
86 member_list_sub: None,
87 member_list_cache: Vec::new(),
88 s,
89 }
90 }
91
92 pub fn disconnect(&mut self) {
93 self.state = match &self.state {
95 ConnectionState::Authenticated { session } => ConnectionState::Disconnected {
96 session: session.clone(),
97 },
98 s => s.to_owned(),
99 };
100 }
101
102 pub fn rewind(&mut self, seq: u64) -> Result<()> {
103 let is_still_valid = self
104 .queue
105 .iter()
106 .any(|(seq, _)| seq.is_some_and(|s| s <= self.seq_client));
107 if is_still_valid {
108 self.seq_client = seq;
109 Ok(())
110 } else {
111 Err(Error::BadStatic("too old"))
112 }
113 }
114
115 pub async fn handle_message(
116 &mut self,
117 ws_msg: Message,
118 ws: &mut WebSocket,
119 timeout: &mut Timeout,
120 ) -> Result<()> {
121 match ws_msg {
122 Message::Text(utf8_bytes) => {
123 let msg: MessageClient = serde_json::from_str(&utf8_bytes)?;
124 self.handle_message_client(msg, ws, timeout).await
125 }
126 Message::Binary(_) => Err(Error::BadStatic("doesn't support binary sorry")),
127 _ => Ok(()),
128 }
129 }
130
131 #[tracing::instrument(level = "debug", skip(self, ws, timeout), fields(id = self.get_id()))]
132 pub async fn handle_message_client(
133 &mut self,
134 msg: MessageClient,
135 ws: &mut WebSocket,
136 timeout: &mut Timeout,
137 ) -> Result<()> {
138 trace!("{:#?}", msg);
139 match msg {
140 MessageClient::Hello {
141 token,
142 resume: reconnect,
143 status,
144 } => {
145 let srv = self.s.services();
146 let session = srv
147 .sessions
148 .get_by_token(token)
149 .await
150 .map_err(|err| match err {
151 Error::NotFound => Error::MissingAuth,
152 other => other,
153 })?;
154
155 if let Some(r) = reconnect {
157 debug!("attempting to resume");
158 if let Some((_, mut conn)) = self.s.syncers.remove(&r.conn) {
159 debug!("resume conn exists");
160 if let Some(recon_session) = conn.state.session() {
161 debug!("resume session exists");
162 if session.id == recon_session.id {
163 debug!("session id matches, resuming");
164 conn.rewind(r.seq)?;
165 conn.push(
166 MessageEnvelope {
167 payload: types::MessagePayload::Resumed,
168 },
169 None,
170 );
171 std::mem::swap(self, &mut conn);
172 return Ok(());
173 }
174 }
175 }
176 return Err(Error::BadStatic("bad or expired reconnection info"));
177 }
178
179 let user = if let Some(user_id) = session.user_id() {
180 let srv = self.s.services();
181 let mut user = srv.users.get(user_id, Some(user_id)).await?;
182 if user.is_suspended() {
183 Some(user)
184 } else {
185 let user_with_new_status = srv
186 .users
187 .status_set(
188 user_id,
189 status
190 .map(|s| s.apply(Status::offline()))
191 .unwrap_or(Status::online()),
192 )
193 .await?;
194 user.status = user_with_new_status.status;
195 Some(user)
196 }
197 } else {
198 None
199 };
200
201 let msg = MessageEnvelope {
202 payload: types::MessagePayload::Ready {
203 user: Box::new(user),
204 session: session.clone(),
205 conn: self.get_id().to_owned(),
206 seq: 0,
207 },
208 };
209
210 ws.send(WsMessage::text(serde_json::to_string(&msg)?))
211 .await?;
212
213 self.seq_server += 1;
214
215 if let Some(user_id) = session.user_id() {
216 let typing_states = srv.channels.typing_list();
218 for (channel_id, typing_user_id, until) in typing_states {
219 if let Ok(perms) = srv.perms.for_channel(user_id, channel_id).await {
220 if perms.has(Permission::ViewChannel) {
221 self.push_sync(MessageSync::ChannelTyping {
222 channel_id,
223 user_id: typing_user_id,
224 until: until.into(),
225 });
226 }
227 }
228 }
229
230 let voice_states = srv.users.voice_states_list();
232 for voice_state in voice_states {
233 if let Ok(perms) =
234 srv.perms.for_channel(user_id, voice_state.thread_id).await
235 {
236 let is_ours = self.state.session().and_then(|s| s.user_id())
237 == Some(voice_state.user_id);
238 if perms.has(Permission::ViewChannel) || is_ours {
239 let mut voice_state = voice_state.clone();
240 if !is_ours {
241 voice_state.session_id = None;
242 }
243 self.push_sync(MessageSync::VoiceState {
244 user_id: voice_state.user_id,
245 state: Some(voice_state),
246 old_state: None,
247 });
248 }
249 }
250 }
251 }
252
253 self.state = ConnectionState::Authenticated { session };
254 }
255 MessageClient::Status { status } => {
256 let session = match &self.state {
257 ConnectionState::Unauthed => return Err(Error::MissingAuth),
258 ConnectionState::Authenticated { session } => session,
259 ConnectionState::Disconnected { .. } => {
260 panic!("somehow recv msg while disconnected?")
261 }
262 };
263 let srv = self.s.services();
264 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
265 let user = srv.users.get(user_id, None).await?;
266 user.ensure_unsuspended()?;
267 srv.users
268 .status_set(user_id, status.apply(Status::offline()))
269 .await?;
270 }
271 MessageClient::Pong => {
272 let session = match &self.state {
273 ConnectionState::Unauthed => return Err(Error::MissingAuth),
274 ConnectionState::Authenticated { session } => session,
275 ConnectionState::Disconnected { .. } => {
276 panic!("somehow recv msg while disconnected?")
277 }
278 };
279 let srv = self.s.services();
280 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
281 srv.users.status_ping(user_id).await?;
282 *timeout = Timeout::Ping(Instant::now() + HEARTBEAT_TIME);
283 }
284 MessageClient::MemberListSubscribe {
285 room_id,
286 thread_id,
287 ranges,
288 } => {
289 let session = self.state.session().ok_or(Error::MissingAuth)?;
290 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
291 let srv = self.s.services();
292
293 let target = if let Some(room_id) = room_id {
294 let _perms = srv.perms.for_room(user_id, room_id).await?;
295 MemberListTarget::Room(room_id)
296 } else if let Some(thread_id) = thread_id {
297 let perms = srv.perms.for_channel(user_id, thread_id).await?;
298 perms.ensure(Permission::ViewChannel)?;
299 MemberListTarget::Channel(thread_id)
300 } else {
301 return Err(Error::BadStatic("room_id or thread_id must be provided"));
302 };
303
304 if self.member_list_sub.as_ref().map(|s| &s.target) != Some(&target) {
305 self.member_list_cache.clear();
306 }
307
308 self.member_list_sub = Some(MemberListSub {
309 target: target.clone(),
310 ranges: ranges.clone(),
311 });
312
313 self.resync_member_list().await?;
314 }
315 MessageClient::VoiceDispatch {
316 user_id: _,
317 payload,
318 } => {
319 let Some(session) = self.state.session() else {
320 return Err(Error::BadStatic("no session"));
321 };
322 let Some(user_id) = session.user_id() else {
323 return Err(Error::BadStatic("no user"));
324 };
325
326 let srv = self.s.services();
327 let user = srv.users.get(user_id, Some(user_id)).await?;
328 user.ensure_unsuspended()?;
329
330 match &payload {
331 SignallingMessage::VoiceState { state: Some(state) } => {
332 let perms = srv.perms.for_channel(user_id, state.thread_id).await?;
333 perms.ensure(Permission::ViewChannel)?;
334 perms.ensure(Permission::VoiceConnect)?;
335 let thread = srv.channels.get(state.thread_id, Some(user_id)).await?;
336 if thread.archived_at.is_some() {
337 return Err(Error::BadStatic("thread is archived"));
338 }
339 if thread.deleted_at.is_some() {
340 return Err(Error::BadStatic("thread is removed"));
341 }
342 if thread.locked {
343 perms.ensure(Permission::ThreadLock)?;
344 }
345 let mut state = VoiceState {
346 user_id,
347 thread_id: state.thread_id,
348 session_id: Some(session.id),
349 joined_at: Time::now_utc(),
350 mute: false,
351 deaf: false,
352 self_deaf: state.self_deaf,
353 self_mute: state.self_mute,
354 self_video: state.self_video,
355 self_screen: state.self_screen,
356 };
357 if let Some(room_id) = thread.room_id {
358 let rm = self.s.data().room_member_get(room_id, user_id).await?;
359 state.mute = rm.mute;
360 state.deaf = rm.deaf;
361 }
362 self.s.alloc_sfu(state.thread_id).await?;
363 if let Err(err) = self.s.sushi_sfu.send(SfuCommand::VoiceState {
364 user_id,
365 state: Some(state),
366 permissions: SfuPermissions {
367 speak: perms.has(Permission::VoiceSpeak),
368 video: perms.has(Permission::VoiceVideo),
369 priority: perms.has(Permission::VoicePriority),
370 },
371 }) {
372 error!("failed to send to sushi_sfu: {err}");
373 }
374 return Ok(());
375 }
376 SignallingMessage::VoiceState { state: None } => {
377 if let Err(err) = self.s.sushi_sfu.send(SfuCommand::VoiceState {
378 user_id,
379 state: None,
380 permissions: SfuPermissions {
381 speak: false,
382 video: false,
383 priority: false,
384 },
385 }) {
386 error!("failed to send to sushi_sfu: {err}");
387 }
388 return Ok(());
389 }
390 SignallingMessage::Offer { .. } => {
391 }
400 _ => {}
401 }
402
403 if let Err(err) = self.s.sushi_sfu.send(SfuCommand::Signalling {
404 user_id,
405 inner: payload,
406 }) {
407 error!("failed to send to sushi_sfu: {err}");
408 }
409 }
410 }
411 Ok(())
412 }
413
414 #[tracing::instrument(level = "debug", skip(self), fields(id = self.get_id()))]
415 pub async fn queue_message(&mut self, msg: Box<MessageSync>) -> Result<()> {
416 let mut session = match &self.state {
417 ConnectionState::Authenticated { session }
418 | ConnectionState::Disconnected { session } => session.clone(),
419 _ => return Ok(()),
420 };
421
422 match &self.state {
423 ConnectionState::Disconnected { .. }
424 if self.seq_server > self.seq_client + MAX_QUEUE_LEN as u64 =>
425 {
426 self.s.syncers.remove(&self.id);
427 return Err(Error::BadStatic("expired session"));
428 }
429 _ => {}
430 }
431
432 let auth_check = match &*msg {
433 MessageSync::RoomCreate { room } => AuthCheck::Room(room.id),
434 MessageSync::RoomUpdate { room } => AuthCheck::Room(room.id),
435 MessageSync::RoomDelete { room_id } => AuthCheck::Room(*room_id),
436 MessageSync::ChannelCreate { channel } => AuthCheck::Channel(channel.id),
437 MessageSync::ChannelUpdate { channel } => AuthCheck::Channel(channel.id),
438 MessageSync::MessageCreate { message } => AuthCheck::Channel(message.channel_id),
439 MessageSync::MessageUpdate { message } => AuthCheck::Channel(message.channel_id),
440 MessageSync::UserCreate { user } => AuthCheck::UserMutual(user.id),
441 MessageSync::UserUpdate { user } => {
442 if self.member_list_sub.is_some() {
443 if let Some((_, _, old_user)) = self
444 .member_list_cache
445 .iter()
446 .find(|(_, _, u)| u.id == user.id)
447 {
448 let old_online = old_user.status.status.is_online();
449 let new_online = user.status.status.is_online();
450
451 if old_online != new_online {
452 self.diff_sync_member_list().await?;
453 }
454 }
455 }
456 AuthCheck::UserMutual(user.id)
457 }
458 MessageSync::UserConfigGlobal { user_id, .. } => AuthCheck::User(*user_id),
459 MessageSync::UserConfigRoom { user_id, .. } => AuthCheck::User(*user_id),
460 MessageSync::UserConfigChannel { user_id, .. } => AuthCheck::User(*user_id),
461 MessageSync::UserConfigUser { user_id, .. } => AuthCheck::User(*user_id),
462 MessageSync::RoomMemberUpsert { member } => {
463 if self
464 .member_list_sub
465 .as_ref()
466 .is_some_and(|s| s.target == MemberListTarget::Room(member.room_id))
467 {
468 self.diff_sync_member_list().await?;
469 }
470 AuthCheck::RoomOrUser(member.room_id, member.user_id)
471 }
472 MessageSync::ThreadMemberUpsert { member } => {
473 if self
474 .member_list_sub
475 .as_ref()
476 .is_some_and(|s| s.target == MemberListTarget::Channel(member.thread_id))
477 {
478 self.diff_sync_member_list().await?;
479 }
480 AuthCheck::ChannelOrUser(member.thread_id, member.user_id)
481 }
482 MessageSync::SessionCreate {
483 session: upserted_session,
484 } => {
485 if session.id == upserted_session.id {
486 session = upserted_session.to_owned();
487 self.state = ConnectionState::Authenticated {
488 session: upserted_session.to_owned(),
489 };
490 }
491 AuthCheck::Custom(session.can_see(upserted_session))
492 }
493 MessageSync::SessionUpdate {
494 session: upserted_session,
495 } => {
496 if session.id == upserted_session.id {
497 session = upserted_session.to_owned();
498 self.state = ConnectionState::Authenticated {
499 session: upserted_session.to_owned(),
500 };
501 }
502 AuthCheck::Custom(session.can_see(upserted_session))
503 }
504 MessageSync::RoleCreate { role } => AuthCheck::Room(role.room_id),
505 MessageSync::RoleUpdate { role } => AuthCheck::Room(role.room_id),
506 MessageSync::InviteCreate { invite } => match &invite.invite.target {
508 InviteTarget::Room { room, channel: _ } => AuthCheck::Room(room.id),
509 InviteTarget::Gdm { channel, .. } => AuthCheck::Channel(channel.id),
510 InviteTarget::Server => {
511 AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
512 }
513 InviteTarget::User { user, .. } => AuthCheck::User(user.id),
514 },
515 MessageSync::InviteUpdate { invite } => match &invite.invite.target {
516 InviteTarget::Room { room, .. } => AuthCheck::Room(room.id),
517 InviteTarget::Gdm { channel, .. } => AuthCheck::Channel(channel.id),
518 InviteTarget::Server => {
519 AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
520 }
521 InviteTarget::User { user, .. } => AuthCheck::User(user.id),
522 },
523 MessageSync::MessageDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
524 MessageSync::MessageVersionDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
525 MessageSync::UserDelete { id } => AuthCheck::UserMutual(*id),
526 MessageSync::SessionDelete { id, user_id } => {
527 if *id == session.id {
529 self.state = ConnectionState::Unauthed;
530 AuthCheck::Custom(true)
531 } else if let Some(user_id) = user_id {
532 AuthCheck::User(*user_id)
533 } else {
534 AuthCheck::Custom(false)
535 }
536 }
537 MessageSync::RoleDelete { room_id, .. } => AuthCheck::Room(*room_id),
538 MessageSync::RoleReorder { room_id, .. } => AuthCheck::Room(*room_id),
539 MessageSync::InviteDelete { target, .. } => match target {
540 InviteTargetId::Room { room_id, .. } => AuthCheck::Room(*room_id),
541 InviteTargetId::Gdm { channel_id, .. } => AuthCheck::Channel(*channel_id),
542 InviteTargetId::Server => {
543 AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
544 }
545 InviteTargetId::User { user_id, .. } => AuthCheck::User(*user_id),
546 },
547 MessageSync::ChannelTyping { channel_id, .. } => AuthCheck::Channel(*channel_id),
548 MessageSync::ChannelAck { user_id, .. } => AuthCheck::User(*user_id),
549 MessageSync::RelationshipUpsert { user_id, .. } => AuthCheck::User(*user_id),
550 MessageSync::RelationshipDelete { user_id, .. } => AuthCheck::User(*user_id),
551 MessageSync::ReactionCreate { channel_id, .. } => AuthCheck::Channel(*channel_id),
552 MessageSync::ReactionDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
553 MessageSync::ReactionPurge { channel_id, .. } => AuthCheck::Channel(*channel_id),
554 MessageSync::MessageDeleteBulk { channel_id, .. } => AuthCheck::Channel(*channel_id),
555 MessageSync::MessageRemove { channel_id, .. } => AuthCheck::Channel(*channel_id),
556 MessageSync::MessageRestore { channel_id, .. } => AuthCheck::Channel(*channel_id),
557 MessageSync::VoiceDispatch { user_id, .. } => AuthCheck::User(*user_id),
558 MessageSync::VoiceState {
559 state,
560 user_id,
561 old_state,
562 } => match (state, old_state) {
563 (None, None) => AuthCheck::User(*user_id),
564 (None, Some(o)) => AuthCheck::Channel(o.thread_id),
565 (Some(s), None) => AuthCheck::Channel(s.thread_id),
566 (Some(s), Some(o)) => AuthCheck::EitherChannel(s.thread_id, o.thread_id),
567 },
568 MessageSync::EmojiCreate { emoji } => match emoji.owner {
569 EmojiOwner::Room { room_id } => AuthCheck::Room(room_id),
570 EmojiOwner::User => AuthCheck::User(emoji.creator_id),
571 },
572 MessageSync::EmojiUpdate { emoji } => match emoji.owner {
573 EmojiOwner::Room { room_id } => AuthCheck::Room(room_id),
574 EmojiOwner::User => AuthCheck::User(emoji.creator_id),
575 },
576 MessageSync::EmojiDelete {
577 room_id,
578 emoji_id: _,
579 } => AuthCheck::Room(*room_id),
580 MessageSync::ConnectionCreate { user_id, .. } => AuthCheck::User(*user_id),
581 MessageSync::ConnectionDelete { user_id, .. } => AuthCheck::User(*user_id),
582 MessageSync::AuditLogEntryCreate { entry } => {
583 AuthCheck::RoomPerm(entry.room_id, Permission::ViewAuditLog)
584 }
585 MessageSync::BanCreate { room_id, .. } => {
586 AuthCheck::RoomPerm(*room_id, Permission::MemberBan)
587 }
588 MessageSync::BanDelete { room_id, .. } => {
589 AuthCheck::RoomPerm(*room_id, Permission::MemberBan)
590 }
591 MessageSync::MemberListSync { user_id, .. } => AuthCheck::User(*user_id),
592 MessageSync::InboxNotificationCreate { user_id, .. } => AuthCheck::User(*user_id),
593 MessageSync::InboxMarkRead { user_id, .. } => AuthCheck::User(*user_id),
594 MessageSync::InboxMarkUnread { user_id, .. } => AuthCheck::User(*user_id),
595 MessageSync::InboxFlush { user_id, .. } => AuthCheck::User(*user_id),
596 MessageSync::CalendarEventCreate { event } => AuthCheck::Channel(event.channel_id),
597 MessageSync::CalendarEventUpdate { event } => AuthCheck::Channel(event.channel_id),
598 MessageSync::CalendarEventDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
599 };
600 let should_send = match (session.user_id(), auth_check) {
601 (Some(user_id), AuthCheck::Room(room_id)) => {
602 let _perms = self.s.services().perms.for_room(user_id, room_id).await?;
603 true
604 }
605 (Some(user_id), AuthCheck::RoomPerm(room_id, perm)) => {
606 let perms = self.s.services().perms.for_room(user_id, room_id).await?;
607 perms.has(perm)
608 }
609 (Some(auth_user_id), AuthCheck::RoomOrUser(room_id, target_user_id)) => {
610 if auth_user_id == target_user_id {
611 true
612 } else {
613 let _perms = self
614 .s
615 .services()
616 .perms
617 .for_room(auth_user_id, room_id)
618 .await?;
619 true
620 }
621 }
622 (Some(user_id), AuthCheck::Channel(thread_id)) => {
623 let perms = self
624 .s
625 .services()
626 .perms
627 .for_channel(user_id, thread_id)
628 .await?;
629 perms.has(Permission::ViewChannel)
630 }
631 (Some(user_id), AuthCheck::EitherChannel(thread_id_0, thread_id_1)) => {
632 let perms0 = self
633 .s
634 .services()
635 .perms
636 .for_channel(user_id, thread_id_0)
637 .await?;
638 let perms1 = self
639 .s
640 .services()
641 .perms
642 .for_channel(user_id, thread_id_1)
643 .await?;
644 perms0.has(Permission::ViewChannel) || perms1.has(Permission::ViewChannel)
645 }
646 (Some(auth_user_id), AuthCheck::ChannelOrUser(thread_id, target_user_id)) => {
647 if auth_user_id == target_user_id {
648 true
649 } else {
650 let perms = self
651 .s
652 .services()
653 .perms
654 .for_channel(auth_user_id, thread_id)
655 .await?;
656 perms.has(Permission::ViewChannel)
657 }
658 }
659 (Some(auth_user_id), AuthCheck::User(target_user_id)) => auth_user_id == target_user_id,
660 (Some(auth_user_id), AuthCheck::UserMutual(target_user_id)) => {
661 if auth_user_id == target_user_id {
662 true
663 } else {
664 self.s
665 .services()
666 .perms
667 .is_mutual(auth_user_id, target_user_id)
668 .await?
669 }
670 }
671 (_, AuthCheck::Custom(b)) => b,
672 (None, _) => false,
673 };
674 if should_send {
675 let d = self.s.data();
676 let srv = self.s.services();
677 let msg = match *msg {
678 MessageSync::ChannelCreate { channel } => MessageSync::ChannelCreate {
679 channel: Box::new(srv.channels.get(channel.id, session.user_id()).await?),
680 },
681 MessageSync::ChannelUpdate { channel } => MessageSync::ChannelUpdate {
682 channel: Box::new(srv.channels.get(channel.id, session.user_id()).await?),
683 },
684 MessageSync::MessageCreate { message } => MessageSync::MessageCreate {
685 message: {
686 let mut m = d
687 .message_get(message.channel_id, message.id, session.user_id().unwrap())
688 .await?;
689 self.s.presign_message(&mut m).await?;
690 m.nonce = message.nonce;
691 m
692 },
693 },
694 MessageSync::MessageUpdate { message } => MessageSync::MessageUpdate {
695 message: {
696 let mut m = d
697 .message_get(message.channel_id, message.id, session.user_id().unwrap())
698 .await?;
699 self.s.presign_message(&mut m).await?;
700 m.nonce = message.nonce;
701 m
702 },
703 },
704 MessageSync::VoiceState {
705 user_id,
706 mut state,
707 mut old_state,
708 } => {
709 let is_ours = self.state.session().and_then(|s| s.user_id()) == Some(user_id);
711 if !is_ours {
712 if let Some(s) = &mut state {
713 s.session_id = None;
714 }
715
716 if let Some(s) = &mut old_state {
717 s.session_id = None;
718 }
719 }
720
721 if let Some(s) = &state {
723 let perms = self
724 .s
725 .services()
726 .perms
727 .for_channel(user_id, s.thread_id)
728 .await?;
729 if !perms.has(Permission::ViewChannel) {
730 state = None;
731 }
732 }
733
734 MessageSync::VoiceState {
735 user_id,
736 state,
737 old_state,
738 }
739 }
740 m => m,
741 };
742 self.push_sync(msg);
743 }
744 Ok(())
745 }
746
747 fn push_sync(&mut self, sync: MessageSync) {
748 let seq = self.seq_server;
749 let msg = MessageEnvelope {
750 payload: types::MessagePayload::Sync {
751 data: Box::new(sync),
752 seq,
753 },
754 };
755 self.push(msg, Some(seq));
756 self.seq_server += 1;
757 }
758
759 fn push(&mut self, msg: MessageEnvelope, seq: Option<u64>) {
760 self.queue.push_front((seq, msg));
761 self.queue.truncate(MAX_QUEUE_LEN);
762 }
763
764 #[tracing::instrument(level = "debug", skip(self, ws), fields(id = self.get_id()))]
765 pub async fn drain(&mut self, ws: &mut WebSocket) -> Result<()> {
766 let last_seen = self.seq_client;
767 let mut high_water_mark = last_seen;
768 for (seq, msg) in self.queue.iter().rev() {
769 if seq.is_none_or(|s| s > last_seen) {
770 let json = serde_json::to_string(&msg)?;
771 ws.send(WsMessage::text(json)).await?;
772 if let Some(seq) = *seq {
773 high_water_mark = high_water_mark.max(seq);
774 }
775 }
776 }
777 self.seq_client = high_water_mark;
778 self.queue.retain(|(seq, _)| seq.is_some());
779 Ok(())
780 }
781
782 pub fn get_id(&self) -> &str {
783 &self.id
784 }
785
786 async fn get_member_list(
787 &self,
788 ) -> Result<
789 Vec<(
790 Option<types::RoomMember>,
791 Option<types::ThreadMember>,
792 types::User,
793 )>,
794 > {
795 let sub = match &self.member_list_sub {
796 Some(sub) => sub.clone(),
797 None => return Ok(Vec::new()),
798 };
799
800 let session = self.state.session().ok_or(Error::MissingAuth)?;
801 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
802 let srv = self.s.services();
803 let data = self.s.data();
804
805 let (room_members, thread_members, users) = match &sub.target {
806 MemberListTarget::Room(room_id) => {
807 let members = data.room_member_list_all(*room_id).await?;
808 let user_ids: Vec<_> = members.iter().map(|m| m.user_id).collect();
809 let users = futures::future::try_join_all(
810 user_ids
811 .into_iter()
812 .map(|id| srv.users.get(id, Some(user_id))),
813 )
814 .await?;
815 (Some(members), None, users)
816 }
817 MemberListTarget::Channel(thread_id) => {
818 let thread = srv.channels.get(*thread_id, Some(user_id)).await?;
819 let thread_members = data.thread_member_list_all(*thread_id).await?;
820 let room_members = if let Some(room_id) = thread.room_id {
821 Some(data.room_member_list_all(room_id).await?)
822 } else {
823 None
824 };
825 let user_ids: Vec<_> = thread_members.iter().map(|m| m.user_id).collect();
826 let users = futures::future::try_join_all(
827 user_ids
828 .into_iter()
829 .map(|id| srv.users.get(id, Some(user_id))),
830 )
831 .await?;
832 (room_members, Some(thread_members), users)
833 }
834 };
835
836 let mut members: Vec<(Option<_>, Option<_>, _)> = if let Some(t) = thread_members {
838 let mut users_map: std::collections::HashMap<_, _> =
839 users.into_iter().map(|u| (u.id, u)).collect();
840 t.into_iter()
841 .enumerate()
842 .map(|(idx, m)| {
843 (
844 room_members.as_ref().and_then(|m| m.get(idx).cloned()),
845 Some(m.clone()),
846 users_map.remove(&m.user_id).unwrap(),
847 )
848 })
849 .collect()
850 } else if let Some(r) = room_members {
851 let mut users_map: std::collections::HashMap<_, _> =
852 users.into_iter().map(|u| (u.id, u)).collect();
853 r.into_iter()
854 .map(|m| (Some(m.clone()), None, users_map.remove(&m.user_id).unwrap()))
855 .collect()
856 } else {
857 unreachable!()
858 };
859
860 members.sort_by(|(_, _, a), (_, _, b)| {
861 let a_online = srv.users.is_online(a.id);
862 let b_online = srv.users.is_online(b.id);
863 a_online
864 .cmp(&b_online)
865 .reverse()
866 .then_with(|| a.name.cmp(&b.name))
867 });
868
869 Ok(members)
870 }
871
872 pub async fn diff_sync_member_list(&mut self) -> Result<()> {
873 let sub = match &self.member_list_sub {
874 Some(sub) => sub.clone(),
875 None => return Ok(()),
876 };
877 let session = self.state.session().ok_or(Error::MissingAuth)?;
878 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
879 let srv = self.s.services();
880
881 let new_members = self.get_member_list().await?;
882
883 let old_ids: Vec<_> = self
884 .member_list_cache
885 .iter()
886 .map(|(_, _, u)| u.id)
887 .collect();
888 let new_ids: Vec<_> = new_members.iter().map(|(_, _, u)| u.id).collect();
889
890 let mut ops = Vec::new();
891
892 if self.member_list_cache.is_empty() {
893 } else {
895 let mut new_idx = 0;
896 let mut consecutive_deletes = 0;
897
898 let diff_result = diff::slice(&old_ids, &new_ids);
899
900 for result in diff_result {
901 match result {
902 diff::Result::Left(_) => {
903 consecutive_deletes += 1;
904 }
905 diff::Result::Right(user_id) => {
906 if consecutive_deletes > 0 {
907 ops.push(MemberListOp::Delete {
908 position: new_idx,
909 count: consecutive_deletes,
910 });
911 consecutive_deletes = 0;
912 }
913 let (room_member, thread_member, user) = new_members
914 .iter()
915 .find(|(_, _, u)| u.id == *user_id)
916 .unwrap()
917 .clone();
918 ops.push(MemberListOp::Insert {
919 position: new_idx,
920 room_member,
921 thread_member,
922 user: Box::new(user),
923 });
924 new_idx += 1;
925 }
926 diff::Result::Both(_, _) => {
927 if consecutive_deletes > 0 {
928 ops.push(MemberListOp::Delete {
929 position: new_idx,
930 count: consecutive_deletes,
931 });
932 consecutive_deletes = 0;
933 }
934 new_idx += 1;
935 }
936 }
937 }
938
939 if consecutive_deletes > 0 {
940 ops.push(MemberListOp::Delete {
941 position: new_idx,
942 count: consecutive_deletes,
943 });
944 }
945 }
946
947 let online_count = new_members
948 .iter()
949 .filter(|(_, _, u)| srv.users.is_online(u.id))
950 .count() as u64;
951 let offline_count = new_members.len() as u64 - online_count;
952
953 let groups = vec![
954 MemberListGroup {
955 id: MemberListGroupId::Online,
956 count: online_count,
957 },
958 MemberListGroup {
959 id: MemberListGroupId::Offline,
960 count: offline_count,
961 },
962 ];
963
964 self.push_sync(MessageSync::MemberListSync {
965 user_id,
966 room_id: if let MemberListTarget::Room(id) = sub.target {
967 Some(id)
968 } else {
969 None
970 },
971 channel_id: if let MemberListTarget::Channel(id) = sub.target {
972 Some(id)
973 } else {
974 None
975 },
976 ops,
977 groups,
978 });
979
980 self.member_list_cache = new_members;
981
982 Ok(())
983 }
984
985 pub async fn resync_member_list(&mut self) -> Result<()> {
986 let sub = match &self.member_list_sub {
987 Some(sub) => sub.clone(),
988 None => return Ok(()),
989 };
990
991 let session = self.state.session().ok_or(Error::MissingAuth)?;
992 let user_id = session.user_id().ok_or(Error::UnauthSession)?;
993 let srv = self.s.services();
994
995 let members = self.get_member_list().await?;
996
997 let online_count = members
998 .iter()
999 .filter(|(_, _, u)| srv.users.is_online(u.id))
1000 .count() as u64;
1001 let offline_count = members.len() as u64 - online_count;
1002
1003 let groups = vec![
1004 MemberListGroup {
1005 id: MemberListGroupId::Online,
1006 count: online_count,
1007 },
1008 MemberListGroup {
1009 id: MemberListGroupId::Offline,
1010 count: offline_count,
1011 },
1012 ];
1013
1014 let mut ops = vec![];
1015
1016 for (start, end) in sub.ranges {
1017 let end = end.min(members.len() as u64);
1018 if start >= end {
1019 continue;
1020 }
1021 let slice = &members[start as usize..end as usize];
1022 let mut room_members = Vec::with_capacity(slice.len());
1023 let mut thread_members = Vec::with_capacity(slice.len());
1024 let mut users = Vec::with_capacity(slice.len());
1025 for (rm, tm, u) in slice.iter().cloned() {
1026 room_members.push(rm);
1027 thread_members.push(tm);
1028 users.push(u);
1029 }
1030
1031 ops.push(MemberListOp::Sync {
1032 position: start,
1033 room_members: if room_members.iter().all(|m| m.is_some()) {
1034 Some(room_members.into_iter().map(|m| m.unwrap()).collect())
1035 } else {
1036 None
1037 },
1038 thread_members: if thread_members.iter().all(|m| m.is_some()) {
1039 Some(thread_members.into_iter().map(|m| m.unwrap()).collect())
1040 } else {
1041 None
1042 },
1043 users,
1044 });
1045 }
1046
1047 self.push_sync(MessageSync::MemberListSync {
1048 user_id,
1049 room_id: if let MemberListTarget::Room(id) = sub.target {
1050 Some(id)
1051 } else {
1052 None
1053 },
1054 channel_id: if let MemberListTarget::Channel(id) = sub.target {
1055 Some(id)
1056 } else {
1057 None
1058 },
1059 ops,
1060 groups,
1061 });
1062
1063 self.member_list_cache = members;
1064
1065 Ok(())
1066 }
1067}
1068
1069impl ConnectionState {
1070 pub fn session(&self) -> Option<&Session> {
1071 match self {
1072 ConnectionState::Unauthed => None,
1073 ConnectionState::Authenticated { session } => Some(session),
1074 ConnectionState::Disconnected { session } => Some(session),
1075 }
1076 }
1077}
1078
1079impl Timeout {
1080 pub fn for_ping() -> Self {
1081 Timeout::Ping(Instant::now() + HEARTBEAT_TIME)
1082 }
1083
1084 pub fn for_close() -> Self {
1085 Timeout::Close(Instant::now() + CLOSE_TIME)
1086 }
1087
1088 pub fn get_instant(&self) -> Instant {
1089 match self {
1090 Timeout::Ping(instant) => *instant,
1091 Timeout::Close(instant) => *instant,
1092 }
1093 }
1094}