bonsaidb_server/server/connected_client.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::{Deref, DerefMut};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_lock::{Mutex, MutexGuard};
8use bonsaidb_core::api;
9use bonsaidb_core::api::ApiName;
10use bonsaidb_core::arc_bytes::serde::Bytes;
11use bonsaidb_core::connection::{Session, SessionId};
12use bonsaidb_core::networking::MessageReceived;
13use bonsaidb_core::pubsub::{Receiver, Subscriber as _};
14use bonsaidb_local::Subscriber;
15use bonsaidb_utils::fast_async_lock;
16use derive_where::derive_where;
17use flume::Sender;
18use parking_lot::RwLock;
19
20use crate::{Backend, CustomServer, Error, NoBackend};
21
/// The ways a client can be connected to the server.
///
/// This is a fieldless enum, so it is cheap to copy, can be compared, and can
/// be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A connection over BonsaiDb's QUIC-based protocol.
    Bonsai,
    /// A connection over WebSockets.
    #[cfg(feature = "websockets")]
    WebSocket,
}
31
32/// A connected database client.
33#[derive(Debug)]
34#[derive_where(Clone)]
35pub struct ConnectedClient<B: Backend = NoBackend> {
36    data: Arc<Data<B>>,
37}
38
39#[derive(Debug)]
40struct Data<B: Backend = NoBackend> {
41    id: u32,
42    sessions: RwLock<HashMap<Option<SessionId>, ClientSession>>,
43    address: SocketAddr,
44    transport: Transport,
45    response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
46    client_data: Mutex<Option<B::ClientData>>,
47    connected: AtomicBool,
48}
49
50#[derive(Debug)]
51struct ClientSession {
52    session: Session,
53    subscribers: HashMap<u64, Subscriber>,
54}
55
56impl<B: Backend> ConnectedClient<B> {
57    /// Returns the address of the connected client.
58    #[must_use]
59    pub fn address(&self) -> &SocketAddr {
60        &self.data.address
61    }
62
63    /// Returns the transport method the client is connected via.
64    #[must_use]
65    pub fn transport(&self) -> &Transport {
66        &self.data.transport
67    }
68
69    /// Returns true if the server still believes the client is connected.
70    #[must_use]
71    pub fn connected(&self) -> bool {
72        self.data.connected.load(Ordering::Relaxed)
73    }
74
75    pub(crate) fn set_disconnected(&self) {
76        self.data.connected.store(false, Ordering::Relaxed);
77    }
78
79    pub(crate) fn logged_in_as(&self, session: Session) {
80        let mut sessions = self.data.sessions.write();
81        sessions.insert(
82            session.id,
83            ClientSession {
84                session,
85                subscribers: HashMap::default(),
86            },
87        );
88    }
89
90    pub(crate) fn log_out(&self, session: SessionId) -> Option<Session> {
91        let mut sessions = self.data.sessions.write();
92        sessions.remove(&Some(session)).map(|cs| cs.session)
93    }
94
95    /// Sends a custom API response to the client.
96    pub fn send<Api: api::Api>(
97        &self,
98        session: Option<&Session>,
99        response: &Api::Response,
100    ) -> Result<(), Error> {
101        let encoded = pot::to_vec(&Result::<&Api::Response, Api::Error>::Ok(response))?;
102        self.data.response_sender.send((
103            session.and_then(|session| session.id),
104            Api::name(),
105            Bytes::from(encoded),
106        ))?;
107        Ok(())
108    }
109
110    /// Returns a locked reference to the stored client data.
111    pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
112        LockedClientDataGuard(fast_async_lock!(self.data.client_data))
113    }
114
115    /// Looks up an active authentication session by its unique id. `None`
116    /// represents the unauthenticated session, and the result can be used to
117    /// check what permissions are allowed by default.
118    #[must_use]
119    pub fn session(&self, session_id: Option<SessionId>) -> Option<Session> {
120        let sessions = self.data.sessions.read();
121        sessions.get(&session_id).map(|data| data.session.clone())
122    }
123
124    /// Returns a collection of all active [`Session`]s for this client.
125    #[must_use]
126    pub fn all_sessions<C: FromIterator<Session>>(&self) -> C {
127        let sessions = self.data.sessions.read();
128        sessions.values().map(|s| s.session.clone()).collect()
129    }
130
131    pub(crate) fn register_subscriber(
132        &self,
133        subscriber: Subscriber,
134        session_id: Option<SessionId>,
135    ) {
136        let subscriber_id = subscriber.id();
137        let receiver = subscriber.receiver().clone();
138        {
139            let mut sessions = self.data.sessions.write();
140            if let Some(client_session) = sessions.get_mut(&session_id) {
141                client_session
142                    .subscribers
143                    .insert(subscriber.id(), subscriber);
144            } else {
145                // TODO return error for session not found.
146                return;
147            }
148        }
149        let task_self = self.clone();
150        tokio::task::spawn(async move {
151            task_self
152                .forward_notifications_for(session_id, subscriber_id, receiver)
153                .await;
154        });
155    }
156
157    /// Sets the associated data for this client.
158    pub async fn set_client_data(&self, data: B::ClientData) {
159        let mut client_data = fast_async_lock!(self.data.client_data);
160        *client_data = Some(data);
161    }
162
163    async fn forward_notifications_for(
164        &self,
165        session_id: Option<SessionId>,
166        subscriber_id: u64,
167        receiver: Receiver,
168    ) {
169        let session = self.session(session_id);
170        while let Ok(message) = receiver.receive_async().await {
171            if self
172                .send::<MessageReceived>(
173                    session.as_ref(),
174                    &MessageReceived {
175                        subscriber_id,
176                        topic: Bytes::from(message.topic.0.into_vec()),
177                        payload: Bytes::from(&message.payload[..]),
178                    },
179                )
180                .is_err()
181            {
182                break;
183            }
184        }
185    }
186
187    pub(crate) fn subscribe_by_id(
188        &self,
189        subscriber_id: u64,
190        topic: Bytes,
191        check_session_id: Option<SessionId>,
192    ) -> Result<(), crate::Error> {
193        let mut sessions = self.data.sessions.write();
194        if let Some(client_session) = sessions.get_mut(&check_session_id) {
195            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
196                subscriber.subscribe_to_bytes(topic.0)?;
197                Ok(())
198            } else {
199                Err(Error::other(
200                    "bonsaidb-server pubsub",
201                    "invalid subscriber id",
202                ))
203            }
204        } else {
205            Err(Error::other("bonsaidb-server auth", "invalid session id"))
206        }
207    }
208
209    pub(crate) fn unsubscribe_by_id(
210        &self,
211        subscriber_id: u64,
212        topic: &[u8],
213        check_session_id: Option<SessionId>,
214    ) -> Result<(), crate::Error> {
215        let mut sessions = self.data.sessions.write();
216        if let Some(client_session) = sessions.get_mut(&check_session_id) {
217            if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
218                subscriber.unsubscribe_from_bytes(topic)?;
219                Ok(())
220            } else {
221                Err(Error::other(
222                    "bonsaidb-server pubsub",
223                    "invalid subscriber id",
224                ))
225            }
226        } else {
227            Err(Error::other("bonsaidb-server auth", "invalid session id"))
228        }
229    }
230
231    pub(crate) fn unregister_subscriber_by_id(
232        &self,
233        subscriber_id: u64,
234        check_session_id: Option<SessionId>,
235    ) -> Result<(), crate::Error> {
236        let mut sessions = self.data.sessions.write();
237        if let Some(client_session) = sessions.get_mut(&check_session_id) {
238            if client_session.subscribers.remove(&subscriber_id).is_some() {
239                Ok(())
240            } else {
241                Err(Error::other(
242                    "bonsaidb-server pubsub",
243                    "invalid subscriber id",
244                ))
245            }
246        } else {
247            Err(Error::other("bonsaidb-server auth", "invalid session id"))
248        }
249    }
250}
251
252/// A locked reference to associated client data.
253pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
254
255impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
256    type Target = Option<ClientData>;
257
258    fn deref(&self) -> &Self::Target {
259        &self.0
260    }
261}
262
263impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
264    fn deref_mut(&mut self) -> &mut Self::Target {
265        &mut self.0
266    }
267}
268
269#[derive(Debug)]
270pub struct OwnedClient<B: Backend> {
271    client: ConnectedClient<B>,
272    runtime: Arc<tokio::runtime::Handle>,
273    server: Option<CustomServer<B>>,
274}
275
276impl<B: Backend> OwnedClient<B> {
277    pub(crate) fn new(
278        id: u32,
279        address: SocketAddr,
280        transport: Transport,
281        response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
282        server: CustomServer<B>,
283        default_session: Session,
284    ) -> Self {
285        let mut session = HashMap::new();
286        session.insert(
287            None,
288            ClientSession {
289                session: default_session,
290                subscribers: HashMap::default(),
291            },
292        );
293        Self {
294            client: ConnectedClient {
295                data: Arc::new(Data {
296                    id,
297                    address,
298                    transport,
299                    response_sender,
300                    sessions: RwLock::new(session),
301                    client_data: Mutex::default(),
302                    connected: AtomicBool::new(true),
303                }),
304            },
305            runtime: Arc::new(tokio::runtime::Handle::current()),
306            server: Some(server),
307        }
308    }
309
310    pub fn clone(&self) -> ConnectedClient<B> {
311        self.client.clone()
312    }
313}
314
315impl<B: Backend> Drop for OwnedClient<B> {
316    fn drop(&mut self) {
317        let id = self.client.data.id;
318        let server = self.server.take().unwrap();
319        self.runtime
320            .spawn(async move { server.disconnect_client(id).await });
321    }
322}
323
324impl<B: Backend> Deref for OwnedClient<B> {
325    type Target = ConnectedClient<B>;
326
327    fn deref(&self) -> &Self::Target {
328        &self.client
329    }
330}