lamprey_backend/
sync.rs

1use std::time::Duration;
2use std::{collections::VecDeque, sync::Arc};
3
4use axum::extract::ws::{Message, WebSocket};
5use common::v1::types::emoji::EmojiOwner;
6use common::v1::types::user_status::Status;
7use common::v1::types::util::Time;
8use common::v1::types::voice::{SfuCommand, SfuPermissions, SignallingMessage, VoiceState};
9use common::v1::types::{self, SERVER_ROOM_ID};
10use common::v1::types::{
11    ChannelId, InviteTarget, InviteTargetId, MemberListGroup, MemberListGroupId, MemberListOp,
12    MessageClient, MessageEnvelope, MessageSync, Permission, RoomId, Session, UserId,
13};
14use tokio::time::Instant;
15use tracing::{debug, error, trace};
16
17use crate::error::{Error, Result};
18use crate::ServerState;
19
20type WsMessage = axum::extract::ws::Message;
21
/// Heartbeat interval: each client `Pong` moves the ping deadline this far into the future.
pub const HEARTBEAT_TIME: Duration = Duration::from_secs(30);

/// Duration associated with [`Timeout::Close`].
/// NOTE(review): presumably the grace period before an unresponsive socket is closed —
/// confirm against the websocket driver loop.
pub const CLOSE_TIME: Duration = Duration::from_secs(10);

/// How far (`seq_server` vs `seq_client`) a disconnected connection may fall behind before
/// `queue_message` drops it from `syncers`.
const MAX_QUEUE_LEN: usize = 256;
25
26pub enum Timeout {
27    Ping(Instant),
28    Close(Instant),
29}
30
31pub struct Connection {
32    state: ConnectionState,
33    s: Arc<ServerState>,
34    queue: VecDeque<(Option<u64>, MessageEnvelope)>,
35    seq_server: u64,
36    seq_client: u64,
37    id: String,
38    member_list_sub: Option<MemberListSub>,
39    member_list_cache: Vec<(
40        Option<types::RoomMember>,
41        Option<types::ThreadMember>,
42        types::User,
43    )>,
44}
45
46#[derive(Debug, Clone)]
47struct MemberListSub {
48    target: MemberListTarget,
49    ranges: Vec<(u64, u64)>,
50}
51
52#[derive(Debug, Clone, PartialEq, Eq)]
53enum MemberListTarget {
54    Room(RoomId),
55    Channel(ChannelId),
56}
57
58#[derive(Debug, Clone)]
59enum ConnectionState {
60    Unauthed,
61    Authenticated { session: Session },
62    Disconnected { session: Session },
63}
64
65#[derive(Debug)]
66enum AuthCheck {
67    Custom(bool),
68    Room(RoomId),
69    RoomPerm(RoomId, Permission),
70    RoomOrUser(RoomId, UserId),
71    ChannelOrUser(ChannelId, UserId),
72    User(UserId),
73    UserMutual(UserId),
74    Channel(ChannelId),
75    EitherChannel(ChannelId, ChannelId),
76}
77
78impl Connection {
79    pub fn new(s: Arc<ServerState>) -> Self {
80        Self {
81            state: ConnectionState::Unauthed,
82            queue: VecDeque::new(),
83            seq_server: 0,
84            seq_client: 0,
85            id: format!("{}", uuid::Uuid::new_v4().hyphenated()),
86            member_list_sub: None,
87            member_list_cache: Vec::new(),
88            s,
89        }
90    }
91
92    pub fn disconnect(&mut self) {
93        // surely there's a way to do this with zero copies
94        self.state = match &self.state {
95            ConnectionState::Authenticated { session } => ConnectionState::Disconnected {
96                session: session.clone(),
97            },
98            s => s.to_owned(),
99        };
100    }
101
102    pub fn rewind(&mut self, seq: u64) -> Result<()> {
103        let is_still_valid = self
104            .queue
105            .iter()
106            .any(|(seq, _)| seq.is_some_and(|s| s <= self.seq_client));
107        if is_still_valid {
108            self.seq_client = seq;
109            Ok(())
110        } else {
111            Err(Error::BadStatic("too old"))
112        }
113    }
114
115    pub async fn handle_message(
116        &mut self,
117        ws_msg: Message,
118        ws: &mut WebSocket,
119        timeout: &mut Timeout,
120    ) -> Result<()> {
121        match ws_msg {
122            Message::Text(utf8_bytes) => {
123                let msg: MessageClient = serde_json::from_str(&utf8_bytes)?;
124                self.handle_message_client(msg, ws, timeout).await
125            }
126            Message::Binary(_) => Err(Error::BadStatic("doesn't support binary sorry")),
127            _ => Ok(()),
128        }
129    }
130
131    #[tracing::instrument(level = "debug", skip(self, ws, timeout), fields(id = self.get_id()))]
132    pub async fn handle_message_client(
133        &mut self,
134        msg: MessageClient,
135        ws: &mut WebSocket,
136        timeout: &mut Timeout,
137    ) -> Result<()> {
138        trace!("{:#?}", msg);
139        match msg {
140            MessageClient::Hello {
141                token,
142                resume: reconnect,
143                status,
144            } => {
145                let srv = self.s.services();
146                let session = srv
147                    .sessions
148                    .get_by_token(token)
149                    .await
150                    .map_err(|err| match err {
151                        Error::NotFound => Error::MissingAuth,
152                        other => other,
153                    })?;
154
155                // TODO: more forgiving reconnections
156                if let Some(r) = reconnect {
157                    debug!("attempting to resume");
158                    if let Some((_, mut conn)) = self.s.syncers.remove(&r.conn) {
159                        debug!("resume conn exists");
160                        if let Some(recon_session) = conn.state.session() {
161                            debug!("resume session exists");
162                            if session.id == recon_session.id {
163                                debug!("session id matches, resuming");
164                                conn.rewind(r.seq)?;
165                                conn.push(
166                                    MessageEnvelope {
167                                        payload: types::MessagePayload::Resumed,
168                                    },
169                                    None,
170                                );
171                                std::mem::swap(self, &mut conn);
172                                return Ok(());
173                            }
174                        }
175                    }
176                    return Err(Error::BadStatic("bad or expired reconnection info"));
177                }
178
179                let user = if let Some(user_id) = session.user_id() {
180                    let srv = self.s.services();
181                    let mut user = srv.users.get(user_id, Some(user_id)).await?;
182                    if user.is_suspended() {
183                        Some(user)
184                    } else {
185                        let user_with_new_status = srv
186                            .users
187                            .status_set(
188                                user_id,
189                                status
190                                    .map(|s| s.apply(Status::offline()))
191                                    .unwrap_or(Status::online()),
192                            )
193                            .await?;
194                        user.status = user_with_new_status.status;
195                        Some(user)
196                    }
197                } else {
198                    None
199                };
200
201                let msg = MessageEnvelope {
202                    payload: types::MessagePayload::Ready {
203                        user: Box::new(user),
204                        session: session.clone(),
205                        conn: self.get_id().to_owned(),
206                        seq: 0,
207                    },
208                };
209
210                ws.send(WsMessage::text(serde_json::to_string(&msg)?))
211                    .await?;
212
213                self.seq_server += 1;
214
215                if let Some(user_id) = session.user_id() {
216                    // Send typing states
217                    let typing_states = srv.channels.typing_list();
218                    for (channel_id, typing_user_id, until) in typing_states {
219                        if let Ok(perms) = srv.perms.for_channel(user_id, channel_id).await {
220                            if perms.has(Permission::ViewChannel) {
221                                self.push_sync(MessageSync::ChannelTyping {
222                                    channel_id,
223                                    user_id: typing_user_id,
224                                    until: until.into(),
225                                });
226                            }
227                        }
228                    }
229
230                    // Send voice states
231                    let voice_states = srv.users.voice_states_list();
232                    for voice_state in voice_states {
233                        if let Ok(perms) =
234                            srv.perms.for_channel(user_id, voice_state.thread_id).await
235                        {
236                            let is_ours = self.state.session().and_then(|s| s.user_id())
237                                == Some(voice_state.user_id);
238                            if perms.has(Permission::ViewChannel) || is_ours {
239                                let mut voice_state = voice_state.clone();
240                                if !is_ours {
241                                    voice_state.session_id = None;
242                                }
243                                self.push_sync(MessageSync::VoiceState {
244                                    user_id: voice_state.user_id,
245                                    state: Some(voice_state),
246                                    old_state: None,
247                                });
248                            }
249                        }
250                    }
251                }
252
253                self.state = ConnectionState::Authenticated { session };
254            }
255            MessageClient::Status { status } => {
256                let session = match &self.state {
257                    ConnectionState::Unauthed => return Err(Error::MissingAuth),
258                    ConnectionState::Authenticated { session } => session,
259                    ConnectionState::Disconnected { .. } => {
260                        panic!("somehow recv msg while disconnected?")
261                    }
262                };
263                let srv = self.s.services();
264                let user_id = session.user_id().ok_or(Error::UnauthSession)?;
265                let user = srv.users.get(user_id, None).await?;
266                user.ensure_unsuspended()?;
267                srv.users
268                    .status_set(user_id, status.apply(Status::offline()))
269                    .await?;
270            }
271            MessageClient::Pong => {
272                let session = match &self.state {
273                    ConnectionState::Unauthed => return Err(Error::MissingAuth),
274                    ConnectionState::Authenticated { session } => session,
275                    ConnectionState::Disconnected { .. } => {
276                        panic!("somehow recv msg while disconnected?")
277                    }
278                };
279                let srv = self.s.services();
280                let user_id = session.user_id().ok_or(Error::UnauthSession)?;
281                srv.users.status_ping(user_id).await?;
282                *timeout = Timeout::Ping(Instant::now() + HEARTBEAT_TIME);
283            }
284            MessageClient::MemberListSubscribe {
285                room_id,
286                thread_id,
287                ranges,
288            } => {
289                let session = self.state.session().ok_or(Error::MissingAuth)?;
290                let user_id = session.user_id().ok_or(Error::UnauthSession)?;
291                let srv = self.s.services();
292
293                let target = if let Some(room_id) = room_id {
294                    let _perms = srv.perms.for_room(user_id, room_id).await?;
295                    MemberListTarget::Room(room_id)
296                } else if let Some(thread_id) = thread_id {
297                    let perms = srv.perms.for_channel(user_id, thread_id).await?;
298                    perms.ensure(Permission::ViewChannel)?;
299                    MemberListTarget::Channel(thread_id)
300                } else {
301                    return Err(Error::BadStatic("room_id or thread_id must be provided"));
302                };
303
304                if self.member_list_sub.as_ref().map(|s| &s.target) != Some(&target) {
305                    self.member_list_cache.clear();
306                }
307
308                self.member_list_sub = Some(MemberListSub {
309                    target: target.clone(),
310                    ranges: ranges.clone(),
311                });
312
313                self.resync_member_list().await?;
314            }
315            MessageClient::VoiceDispatch {
316                user_id: _,
317                payload,
318            } => {
319                let Some(session) = self.state.session() else {
320                    return Err(Error::BadStatic("no session"));
321                };
322                let Some(user_id) = session.user_id() else {
323                    return Err(Error::BadStatic("no user"));
324                };
325
326                let srv = self.s.services();
327                let user = srv.users.get(user_id, Some(user_id)).await?;
328                user.ensure_unsuspended()?;
329
330                match &payload {
331                    SignallingMessage::VoiceState { state: Some(state) } => {
332                        let perms = srv.perms.for_channel(user_id, state.thread_id).await?;
333                        perms.ensure(Permission::ViewChannel)?;
334                        perms.ensure(Permission::VoiceConnect)?;
335                        let thread = srv.channels.get(state.thread_id, Some(user_id)).await?;
336                        if thread.archived_at.is_some() {
337                            return Err(Error::BadStatic("thread is archived"));
338                        }
339                        if thread.deleted_at.is_some() {
340                            return Err(Error::BadStatic("thread is removed"));
341                        }
342                        if thread.locked {
343                            perms.ensure(Permission::ThreadLock)?;
344                        }
345                        let mut state = VoiceState {
346                            user_id,
347                            thread_id: state.thread_id,
348                            session_id: Some(session.id),
349                            joined_at: Time::now_utc(),
350                            mute: false,
351                            deaf: false,
352                            self_deaf: state.self_deaf,
353                            self_mute: state.self_mute,
354                            self_video: state.self_video,
355                            self_screen: state.self_screen,
356                        };
357                        if let Some(room_id) = thread.room_id {
358                            let rm = self.s.data().room_member_get(room_id, user_id).await?;
359                            state.mute = rm.mute;
360                            state.deaf = rm.deaf;
361                        }
362                        self.s.alloc_sfu(state.thread_id).await?;
363                        if let Err(err) = self.s.sushi_sfu.send(SfuCommand::VoiceState {
364                            user_id,
365                            state: Some(state),
366                            permissions: SfuPermissions {
367                                speak: perms.has(Permission::VoiceSpeak),
368                                video: perms.has(Permission::VoiceVideo),
369                                priority: perms.has(Permission::VoicePriority),
370                            },
371                        }) {
372                            error!("failed to send to sushi_sfu: {err}");
373                        }
374                        return Ok(());
375                    }
376                    SignallingMessage::VoiceState { state: None } => {
377                        if let Err(err) = self.s.sushi_sfu.send(SfuCommand::VoiceState {
378                            user_id,
379                            state: None,
380                            permissions: SfuPermissions {
381                                speak: false,
382                                video: false,
383                                priority: false,
384                            },
385                        }) {
386                            error!("failed to send to sushi_sfu: {err}");
387                        }
388                        return Ok(());
389                    }
390                    SignallingMessage::Offer { .. } => {
391                        // TODO: also verify sdp and/or send permissions to sfu instead of only parsing tracks
392                        // let perms = srv.perms.for_thread(user_id, voice_state.thread_id).await?;
393                        // if tracks.iter().any(|t| t.kind == MediaKindSerde::Audio) {
394                        //     perms.ensure(Permission::VoiceSpeak)?;
395                        // }
396                        // if tracks.iter().any(|t| t.kind == MediaKindSerde::Video) {
397                        //     perms.ensure(Permission::VoiceVideo)?;
398                        // }
399                    }
400                    _ => {}
401                }
402
403                if let Err(err) = self.s.sushi_sfu.send(SfuCommand::Signalling {
404                    user_id,
405                    inner: payload,
406                }) {
407                    error!("failed to send to sushi_sfu: {err}");
408                }
409            }
410        }
411        Ok(())
412    }
413
414    #[tracing::instrument(level = "debug", skip(self), fields(id = self.get_id()))]
415    pub async fn queue_message(&mut self, msg: Box<MessageSync>) -> Result<()> {
416        let mut session = match &self.state {
417            ConnectionState::Authenticated { session }
418            | ConnectionState::Disconnected { session } => session.clone(),
419            _ => return Ok(()),
420        };
421
422        match &self.state {
423            ConnectionState::Disconnected { .. }
424                if self.seq_server > self.seq_client + MAX_QUEUE_LEN as u64 =>
425            {
426                self.s.syncers.remove(&self.id);
427                return Err(Error::BadStatic("expired session"));
428            }
429            _ => {}
430        }
431
432        let auth_check = match &*msg {
433            MessageSync::RoomCreate { room } => AuthCheck::Room(room.id),
434            MessageSync::RoomUpdate { room } => AuthCheck::Room(room.id),
435            MessageSync::RoomDelete { room_id } => AuthCheck::Room(*room_id),
436            MessageSync::ChannelCreate { channel } => AuthCheck::Channel(channel.id),
437            MessageSync::ChannelUpdate { channel } => AuthCheck::Channel(channel.id),
438            MessageSync::MessageCreate { message } => AuthCheck::Channel(message.channel_id),
439            MessageSync::MessageUpdate { message } => AuthCheck::Channel(message.channel_id),
440            MessageSync::UserCreate { user } => AuthCheck::UserMutual(user.id),
441            MessageSync::UserUpdate { user } => {
442                if self.member_list_sub.is_some() {
443                    if let Some((_, _, old_user)) = self
444                        .member_list_cache
445                        .iter()
446                        .find(|(_, _, u)| u.id == user.id)
447                    {
448                        let old_online = old_user.status.status.is_online();
449                        let new_online = user.status.status.is_online();
450
451                        if old_online != new_online {
452                            self.diff_sync_member_list().await?;
453                        }
454                    }
455                }
456                AuthCheck::UserMutual(user.id)
457            }
458            MessageSync::UserConfigGlobal { user_id, .. } => AuthCheck::User(*user_id),
459            MessageSync::UserConfigRoom { user_id, .. } => AuthCheck::User(*user_id),
460            MessageSync::UserConfigChannel { user_id, .. } => AuthCheck::User(*user_id),
461            MessageSync::UserConfigUser { user_id, .. } => AuthCheck::User(*user_id),
462            MessageSync::RoomMemberUpsert { member } => {
463                if self
464                    .member_list_sub
465                    .as_ref()
466                    .is_some_and(|s| s.target == MemberListTarget::Room(member.room_id))
467                {
468                    self.diff_sync_member_list().await?;
469                }
470                AuthCheck::RoomOrUser(member.room_id, member.user_id)
471            }
472            MessageSync::ThreadMemberUpsert { member } => {
473                if self
474                    .member_list_sub
475                    .as_ref()
476                    .is_some_and(|s| s.target == MemberListTarget::Channel(member.thread_id))
477                {
478                    self.diff_sync_member_list().await?;
479                }
480                AuthCheck::ChannelOrUser(member.thread_id, member.user_id)
481            }
482            MessageSync::SessionCreate {
483                session: upserted_session,
484            } => {
485                if session.id == upserted_session.id {
486                    session = upserted_session.to_owned();
487                    self.state = ConnectionState::Authenticated {
488                        session: upserted_session.to_owned(),
489                    };
490                }
491                AuthCheck::Custom(session.can_see(upserted_session))
492            }
493            MessageSync::SessionUpdate {
494                session: upserted_session,
495            } => {
496                if session.id == upserted_session.id {
497                    session = upserted_session.to_owned();
498                    self.state = ConnectionState::Authenticated {
499                        session: upserted_session.to_owned(),
500                    };
501                }
502                AuthCheck::Custom(session.can_see(upserted_session))
503            }
504            MessageSync::RoleCreate { role } => AuthCheck::Room(role.room_id),
505            MessageSync::RoleUpdate { role } => AuthCheck::Room(role.room_id),
506            // FIXME(#612): only return invite events to creator and members with InviteManage
507            MessageSync::InviteCreate { invite } => match &invite.invite.target {
508                InviteTarget::Room { room, channel: _ } => AuthCheck::Room(room.id),
509                InviteTarget::Gdm { channel, .. } => AuthCheck::Channel(channel.id),
510                InviteTarget::Server => {
511                    AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
512                }
513                InviteTarget::User { user, .. } => AuthCheck::User(user.id),
514            },
515            MessageSync::InviteUpdate { invite } => match &invite.invite.target {
516                InviteTarget::Room { room, .. } => AuthCheck::Room(room.id),
517                InviteTarget::Gdm { channel, .. } => AuthCheck::Channel(channel.id),
518                InviteTarget::Server => {
519                    AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
520                }
521                InviteTarget::User { user, .. } => AuthCheck::User(user.id),
522            },
523            MessageSync::MessageDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
524            MessageSync::MessageVersionDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
525            MessageSync::UserDelete { id } => AuthCheck::UserMutual(*id),
526            MessageSync::SessionDelete { id, user_id } => {
527                // TODO: send message when other sessions from the same user are deleted
528                if *id == session.id {
529                    self.state = ConnectionState::Unauthed;
530                    AuthCheck::Custom(true)
531                } else if let Some(user_id) = user_id {
532                    AuthCheck::User(*user_id)
533                } else {
534                    AuthCheck::Custom(false)
535                }
536            }
537            MessageSync::RoleDelete { room_id, .. } => AuthCheck::Room(*room_id),
538            MessageSync::RoleReorder { room_id, .. } => AuthCheck::Room(*room_id),
539            MessageSync::InviteDelete { target, .. } => match target {
540                InviteTargetId::Room { room_id, .. } => AuthCheck::Room(*room_id),
541                InviteTargetId::Gdm { channel_id, .. } => AuthCheck::Channel(*channel_id),
542                InviteTargetId::Server => {
543                    AuthCheck::RoomPerm(SERVER_ROOM_ID, Permission::ServerOversee)
544                }
545                InviteTargetId::User { user_id, .. } => AuthCheck::User(*user_id),
546            },
547            MessageSync::ChannelTyping { channel_id, .. } => AuthCheck::Channel(*channel_id),
548            MessageSync::ChannelAck { user_id, .. } => AuthCheck::User(*user_id),
549            MessageSync::RelationshipUpsert { user_id, .. } => AuthCheck::User(*user_id),
550            MessageSync::RelationshipDelete { user_id, .. } => AuthCheck::User(*user_id),
551            MessageSync::ReactionCreate { channel_id, .. } => AuthCheck::Channel(*channel_id),
552            MessageSync::ReactionDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
553            MessageSync::ReactionPurge { channel_id, .. } => AuthCheck::Channel(*channel_id),
554            MessageSync::MessageDeleteBulk { channel_id, .. } => AuthCheck::Channel(*channel_id),
555            MessageSync::MessageRemove { channel_id, .. } => AuthCheck::Channel(*channel_id),
556            MessageSync::MessageRestore { channel_id, .. } => AuthCheck::Channel(*channel_id),
557            MessageSync::VoiceDispatch { user_id, .. } => AuthCheck::User(*user_id),
558            MessageSync::VoiceState {
559                state,
560                user_id,
561                old_state,
562            } => match (state, old_state) {
563                (None, None) => AuthCheck::User(*user_id),
564                (None, Some(o)) => AuthCheck::Channel(o.thread_id),
565                (Some(s), None) => AuthCheck::Channel(s.thread_id),
566                (Some(s), Some(o)) => AuthCheck::EitherChannel(s.thread_id, o.thread_id),
567            },
568            MessageSync::EmojiCreate { emoji } => match emoji.owner {
569                EmojiOwner::Room { room_id } => AuthCheck::Room(room_id),
570                EmojiOwner::User => AuthCheck::User(emoji.creator_id),
571            },
572            MessageSync::EmojiUpdate { emoji } => match emoji.owner {
573                EmojiOwner::Room { room_id } => AuthCheck::Room(room_id),
574                EmojiOwner::User => AuthCheck::User(emoji.creator_id),
575            },
576            MessageSync::EmojiDelete {
577                room_id,
578                emoji_id: _,
579            } => AuthCheck::Room(*room_id),
580            MessageSync::ConnectionCreate { user_id, .. } => AuthCheck::User(*user_id),
581            MessageSync::ConnectionDelete { user_id, .. } => AuthCheck::User(*user_id),
582            MessageSync::AuditLogEntryCreate { entry } => {
583                AuthCheck::RoomPerm(entry.room_id, Permission::ViewAuditLog)
584            }
585            MessageSync::BanCreate { room_id, .. } => {
586                AuthCheck::RoomPerm(*room_id, Permission::MemberBan)
587            }
588            MessageSync::BanDelete { room_id, .. } => {
589                AuthCheck::RoomPerm(*room_id, Permission::MemberBan)
590            }
591            MessageSync::MemberListSync { user_id, .. } => AuthCheck::User(*user_id),
592            MessageSync::InboxNotificationCreate { user_id, .. } => AuthCheck::User(*user_id),
593            MessageSync::InboxMarkRead { user_id, .. } => AuthCheck::User(*user_id),
594            MessageSync::InboxMarkUnread { user_id, .. } => AuthCheck::User(*user_id),
595            MessageSync::InboxFlush { user_id, .. } => AuthCheck::User(*user_id),
596            MessageSync::CalendarEventCreate { event } => AuthCheck::Channel(event.channel_id),
597            MessageSync::CalendarEventUpdate { event } => AuthCheck::Channel(event.channel_id),
598            MessageSync::CalendarEventDelete { channel_id, .. } => AuthCheck::Channel(*channel_id),
599        };
600        let should_send = match (session.user_id(), auth_check) {
601            (Some(user_id), AuthCheck::Room(room_id)) => {
602                let _perms = self.s.services().perms.for_room(user_id, room_id).await?;
603                true
604            }
605            (Some(user_id), AuthCheck::RoomPerm(room_id, perm)) => {
606                let perms = self.s.services().perms.for_room(user_id, room_id).await?;
607                perms.has(perm)
608            }
609            (Some(auth_user_id), AuthCheck::RoomOrUser(room_id, target_user_id)) => {
610                if auth_user_id == target_user_id {
611                    true
612                } else {
613                    let _perms = self
614                        .s
615                        .services()
616                        .perms
617                        .for_room(auth_user_id, room_id)
618                        .await?;
619                    true
620                }
621            }
622            (Some(user_id), AuthCheck::Channel(thread_id)) => {
623                let perms = self
624                    .s
625                    .services()
626                    .perms
627                    .for_channel(user_id, thread_id)
628                    .await?;
629                perms.has(Permission::ViewChannel)
630            }
631            (Some(user_id), AuthCheck::EitherChannel(thread_id_0, thread_id_1)) => {
632                let perms0 = self
633                    .s
634                    .services()
635                    .perms
636                    .for_channel(user_id, thread_id_0)
637                    .await?;
638                let perms1 = self
639                    .s
640                    .services()
641                    .perms
642                    .for_channel(user_id, thread_id_1)
643                    .await?;
644                perms0.has(Permission::ViewChannel) || perms1.has(Permission::ViewChannel)
645            }
646            (Some(auth_user_id), AuthCheck::ChannelOrUser(thread_id, target_user_id)) => {
647                if auth_user_id == target_user_id {
648                    true
649                } else {
650                    let perms = self
651                        .s
652                        .services()
653                        .perms
654                        .for_channel(auth_user_id, thread_id)
655                        .await?;
656                    perms.has(Permission::ViewChannel)
657                }
658            }
659            (Some(auth_user_id), AuthCheck::User(target_user_id)) => auth_user_id == target_user_id,
660            (Some(auth_user_id), AuthCheck::UserMutual(target_user_id)) => {
661                if auth_user_id == target_user_id {
662                    true
663                } else {
664                    self.s
665                        .services()
666                        .perms
667                        .is_mutual(auth_user_id, target_user_id)
668                        .await?
669                }
670            }
671            (_, AuthCheck::Custom(b)) => b,
672            (None, _) => false,
673        };
674        if should_send {
675            let d = self.s.data();
676            let srv = self.s.services();
677            let msg = match *msg {
678                MessageSync::ChannelCreate { channel } => MessageSync::ChannelCreate {
679                    channel: Box::new(srv.channels.get(channel.id, session.user_id()).await?),
680                },
681                MessageSync::ChannelUpdate { channel } => MessageSync::ChannelUpdate {
682                    channel: Box::new(srv.channels.get(channel.id, session.user_id()).await?),
683                },
684                MessageSync::MessageCreate { message } => MessageSync::MessageCreate {
685                    message: {
686                        let mut m = d
687                            .message_get(message.channel_id, message.id, session.user_id().unwrap())
688                            .await?;
689                        self.s.presign_message(&mut m).await?;
690                        m.nonce = message.nonce;
691                        m
692                    },
693                },
694                MessageSync::MessageUpdate { message } => MessageSync::MessageUpdate {
695                    message: {
696                        let mut m = d
697                            .message_get(message.channel_id, message.id, session.user_id().unwrap())
698                            .await?;
699                        self.s.presign_message(&mut m).await?;
700                        m.nonce = message.nonce;
701                        m
702                    },
703                },
704                MessageSync::VoiceState {
705                    user_id,
706                    mut state,
707                    mut old_state,
708                } => {
709                    // strip session_id for voice states that aren't ours
710                    let is_ours = self.state.session().and_then(|s| s.user_id()) == Some(user_id);
711                    if !is_ours {
712                        if let Some(s) = &mut state {
713                            s.session_id = None;
714                        }
715
716                        if let Some(s) = &mut old_state {
717                            s.session_id = None;
718                        }
719                    }
720
721                    // if we don't have view perms in the new thread, treat it like a disconnect
722                    if let Some(s) = &state {
723                        let perms = self
724                            .s
725                            .services()
726                            .perms
727                            .for_channel(user_id, s.thread_id)
728                            .await?;
729                        if !perms.has(Permission::ViewChannel) {
730                            state = None;
731                        }
732                    }
733
734                    MessageSync::VoiceState {
735                        user_id,
736                        state,
737                        old_state,
738                    }
739                }
740                m => m,
741            };
742            self.push_sync(msg);
743        }
744        Ok(())
745    }
746
747    fn push_sync(&mut self, sync: MessageSync) {
748        let seq = self.seq_server;
749        let msg = MessageEnvelope {
750            payload: types::MessagePayload::Sync {
751                data: Box::new(sync),
752                seq,
753            },
754        };
755        self.push(msg, Some(seq));
756        self.seq_server += 1;
757    }
758
759    fn push(&mut self, msg: MessageEnvelope, seq: Option<u64>) {
760        self.queue.push_front((seq, msg));
761        self.queue.truncate(MAX_QUEUE_LEN);
762    }
763
764    #[tracing::instrument(level = "debug", skip(self, ws), fields(id = self.get_id()))]
765    pub async fn drain(&mut self, ws: &mut WebSocket) -> Result<()> {
766        let last_seen = self.seq_client;
767        let mut high_water_mark = last_seen;
768        for (seq, msg) in self.queue.iter().rev() {
769            if seq.is_none_or(|s| s > last_seen) {
770                let json = serde_json::to_string(&msg)?;
771                ws.send(WsMessage::text(json)).await?;
772                if let Some(seq) = *seq {
773                    high_water_mark = high_water_mark.max(seq);
774                }
775            }
776        }
777        self.seq_client = high_water_mark;
778        self.queue.retain(|(seq, _)| seq.is_some());
779        Ok(())
780    }
781
782    pub fn get_id(&self) -> &str {
783        &self.id
784    }
785
786    async fn get_member_list(
787        &self,
788    ) -> Result<
789        Vec<(
790            Option<types::RoomMember>,
791            Option<types::ThreadMember>,
792            types::User,
793        )>,
794    > {
795        let sub = match &self.member_list_sub {
796            Some(sub) => sub.clone(),
797            None => return Ok(Vec::new()),
798        };
799
800        let session = self.state.session().ok_or(Error::MissingAuth)?;
801        let user_id = session.user_id().ok_or(Error::UnauthSession)?;
802        let srv = self.s.services();
803        let data = self.s.data();
804
805        let (room_members, thread_members, users) = match &sub.target {
806            MemberListTarget::Room(room_id) => {
807                let members = data.room_member_list_all(*room_id).await?;
808                let user_ids: Vec<_> = members.iter().map(|m| m.user_id).collect();
809                let users = futures::future::try_join_all(
810                    user_ids
811                        .into_iter()
812                        .map(|id| srv.users.get(id, Some(user_id))),
813                )
814                .await?;
815                (Some(members), None, users)
816            }
817            MemberListTarget::Channel(thread_id) => {
818                let thread = srv.channels.get(*thread_id, Some(user_id)).await?;
819                let thread_members = data.thread_member_list_all(*thread_id).await?;
820                let room_members = if let Some(room_id) = thread.room_id {
821                    Some(data.room_member_list_all(room_id).await?)
822                } else {
823                    None
824                };
825                let user_ids: Vec<_> = thread_members.iter().map(|m| m.user_id).collect();
826                let users = futures::future::try_join_all(
827                    user_ids
828                        .into_iter()
829                        .map(|id| srv.users.get(id, Some(user_id))),
830                )
831                .await?;
832                (room_members, Some(thread_members), users)
833            }
834        };
835
836        // this is a bit cursed
837        let mut members: Vec<(Option<_>, Option<_>, _)> = if let Some(t) = thread_members {
838            let mut users_map: std::collections::HashMap<_, _> =
839                users.into_iter().map(|u| (u.id, u)).collect();
840            t.into_iter()
841                .enumerate()
842                .map(|(idx, m)| {
843                    (
844                        room_members.as_ref().and_then(|m| m.get(idx).cloned()),
845                        Some(m.clone()),
846                        users_map.remove(&m.user_id).unwrap(),
847                    )
848                })
849                .collect()
850        } else if let Some(r) = room_members {
851            let mut users_map: std::collections::HashMap<_, _> =
852                users.into_iter().map(|u| (u.id, u)).collect();
853            r.into_iter()
854                .map(|m| (Some(m.clone()), None, users_map.remove(&m.user_id).unwrap()))
855                .collect()
856        } else {
857            unreachable!()
858        };
859
860        members.sort_by(|(_, _, a), (_, _, b)| {
861            let a_online = srv.users.is_online(a.id);
862            let b_online = srv.users.is_online(b.id);
863            a_online
864                .cmp(&b_online)
865                .reverse()
866                .then_with(|| a.name.cmp(&b.name))
867        });
868
869        Ok(members)
870    }
871
872    pub async fn diff_sync_member_list(&mut self) -> Result<()> {
873        let sub = match &self.member_list_sub {
874            Some(sub) => sub.clone(),
875            None => return Ok(()),
876        };
877        let session = self.state.session().ok_or(Error::MissingAuth)?;
878        let user_id = session.user_id().ok_or(Error::UnauthSession)?;
879        let srv = self.s.services();
880
881        let new_members = self.get_member_list().await?;
882
883        let old_ids: Vec<_> = self
884            .member_list_cache
885            .iter()
886            .map(|(_, _, u)| u.id)
887            .collect();
888        let new_ids: Vec<_> = new_members.iter().map(|(_, _, u)| u.id).collect();
889
890        let mut ops = Vec::new();
891
892        if self.member_list_cache.is_empty() {
893            // initial sync, just send sync ops
894        } else {
895            let mut new_idx = 0;
896            let mut consecutive_deletes = 0;
897
898            let diff_result = diff::slice(&old_ids, &new_ids);
899
900            for result in diff_result {
901                match result {
902                    diff::Result::Left(_) => {
903                        consecutive_deletes += 1;
904                    }
905                    diff::Result::Right(user_id) => {
906                        if consecutive_deletes > 0 {
907                            ops.push(MemberListOp::Delete {
908                                position: new_idx,
909                                count: consecutive_deletes,
910                            });
911                            consecutive_deletes = 0;
912                        }
913                        let (room_member, thread_member, user) = new_members
914                            .iter()
915                            .find(|(_, _, u)| u.id == *user_id)
916                            .unwrap()
917                            .clone();
918                        ops.push(MemberListOp::Insert {
919                            position: new_idx,
920                            room_member,
921                            thread_member,
922                            user: Box::new(user),
923                        });
924                        new_idx += 1;
925                    }
926                    diff::Result::Both(_, _) => {
927                        if consecutive_deletes > 0 {
928                            ops.push(MemberListOp::Delete {
929                                position: new_idx,
930                                count: consecutive_deletes,
931                            });
932                            consecutive_deletes = 0;
933                        }
934                        new_idx += 1;
935                    }
936                }
937            }
938
939            if consecutive_deletes > 0 {
940                ops.push(MemberListOp::Delete {
941                    position: new_idx,
942                    count: consecutive_deletes,
943                });
944            }
945        }
946
947        let online_count = new_members
948            .iter()
949            .filter(|(_, _, u)| srv.users.is_online(u.id))
950            .count() as u64;
951        let offline_count = new_members.len() as u64 - online_count;
952
953        let groups = vec![
954            MemberListGroup {
955                id: MemberListGroupId::Online,
956                count: online_count,
957            },
958            MemberListGroup {
959                id: MemberListGroupId::Offline,
960                count: offline_count,
961            },
962        ];
963
964        self.push_sync(MessageSync::MemberListSync {
965            user_id,
966            room_id: if let MemberListTarget::Room(id) = sub.target {
967                Some(id)
968            } else {
969                None
970            },
971            channel_id: if let MemberListTarget::Channel(id) = sub.target {
972                Some(id)
973            } else {
974                None
975            },
976            ops,
977            groups,
978        });
979
980        self.member_list_cache = new_members;
981
982        Ok(())
983    }
984
985    pub async fn resync_member_list(&mut self) -> Result<()> {
986        let sub = match &self.member_list_sub {
987            Some(sub) => sub.clone(),
988            None => return Ok(()),
989        };
990
991        let session = self.state.session().ok_or(Error::MissingAuth)?;
992        let user_id = session.user_id().ok_or(Error::UnauthSession)?;
993        let srv = self.s.services();
994
995        let members = self.get_member_list().await?;
996
997        let online_count = members
998            .iter()
999            .filter(|(_, _, u)| srv.users.is_online(u.id))
1000            .count() as u64;
1001        let offline_count = members.len() as u64 - online_count;
1002
1003        let groups = vec![
1004            MemberListGroup {
1005                id: MemberListGroupId::Online,
1006                count: online_count,
1007            },
1008            MemberListGroup {
1009                id: MemberListGroupId::Offline,
1010                count: offline_count,
1011            },
1012        ];
1013
1014        let mut ops = vec![];
1015
1016        for (start, end) in sub.ranges {
1017            let end = end.min(members.len() as u64);
1018            if start >= end {
1019                continue;
1020            }
1021            let slice = &members[start as usize..end as usize];
1022            let mut room_members = Vec::with_capacity(slice.len());
1023            let mut thread_members = Vec::with_capacity(slice.len());
1024            let mut users = Vec::with_capacity(slice.len());
1025            for (rm, tm, u) in slice.iter().cloned() {
1026                room_members.push(rm);
1027                thread_members.push(tm);
1028                users.push(u);
1029            }
1030
1031            ops.push(MemberListOp::Sync {
1032                position: start,
1033                room_members: if room_members.iter().all(|m| m.is_some()) {
1034                    Some(room_members.into_iter().map(|m| m.unwrap()).collect())
1035                } else {
1036                    None
1037                },
1038                thread_members: if thread_members.iter().all(|m| m.is_some()) {
1039                    Some(thread_members.into_iter().map(|m| m.unwrap()).collect())
1040                } else {
1041                    None
1042                },
1043                users,
1044            });
1045        }
1046
1047        self.push_sync(MessageSync::MemberListSync {
1048            user_id,
1049            room_id: if let MemberListTarget::Room(id) = sub.target {
1050                Some(id)
1051            } else {
1052                None
1053            },
1054            channel_id: if let MemberListTarget::Channel(id) = sub.target {
1055                Some(id)
1056            } else {
1057                None
1058            },
1059            ops,
1060            groups,
1061        });
1062
1063        self.member_list_cache = members;
1064
1065        Ok(())
1066    }
1067}
1068
1069impl ConnectionState {
1070    pub fn session(&self) -> Option<&Session> {
1071        match self {
1072            ConnectionState::Unauthed => None,
1073            ConnectionState::Authenticated { session } => Some(session),
1074            ConnectionState::Disconnected { session } => Some(session),
1075        }
1076    }
1077}
1078
1079impl Timeout {
1080    pub fn for_ping() -> Self {
1081        Timeout::Ping(Instant::now() + HEARTBEAT_TIME)
1082    }
1083
1084    pub fn for_close() -> Self {
1085        Timeout::Close(Instant::now() + CLOSE_TIME)
1086    }
1087
1088    pub fn get_instant(&self) -> Instant {
1089        match self {
1090            Timeout::Ping(instant) => *instant,
1091            Timeout::Close(instant) => *instant,
1092        }
1093    }
1094}