1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::ops::{Deref, DerefMut};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::Arc;
6
7use async_lock::{Mutex, MutexGuard};
8use bonsaidb_core::api;
9use bonsaidb_core::api::ApiName;
10use bonsaidb_core::arc_bytes::serde::Bytes;
11use bonsaidb_core::connection::{Session, SessionId};
12use bonsaidb_core::networking::MessageReceived;
13use bonsaidb_core::pubsub::{Receiver, Subscriber as _};
14use bonsaidb_local::Subscriber;
15use bonsaidb_utils::fast_async_lock;
16use derive_where::derive_where;
17use flume::Sender;
18use parking_lot::RwLock;
19
20use crate::{Backend, CustomServer, Error, NoBackend};
21
/// The network transport a client is connected over.
///
/// This is a field-less enum, so it is cheap to copy and can be used as a
/// hash-map / hash-set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// The `Bonsai` transport.
    Bonsai,
    /// The WebSocket transport; only available with the `websockets` feature.
    #[cfg(feature = "websockets")]
    WebSocket,
}
31
/// A handle to a client connected to the server.
///
/// All clones of a `ConnectedClient` share the same underlying state through
/// an `Arc`, so cloning only bumps a reference count.
#[derive(Debug)]
#[derive_where(Clone)]
pub struct ConnectedClient<B: Backend = NoBackend> {
    data: Arc<Data<B>>,
}
38
39#[derive(Debug)]
40struct Data<B: Backend = NoBackend> {
41 id: u32,
42 sessions: RwLock<HashMap<Option<SessionId>, ClientSession>>,
43 address: SocketAddr,
44 transport: Transport,
45 response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
46 client_data: Mutex<Option<B::ClientData>>,
47 connected: AtomicBool,
48}
49
/// One session on a client connection, together with the pub/sub subscribers
/// registered under it.
#[derive(Debug)]
struct ClientSession {
    session: Session,
    /// Subscribers keyed by their subscriber id.
    subscribers: HashMap<u64, Subscriber>,
}
55
56impl<B: Backend> ConnectedClient<B> {
57 #[must_use]
59 pub fn address(&self) -> &SocketAddr {
60 &self.data.address
61 }
62
63 #[must_use]
65 pub fn transport(&self) -> &Transport {
66 &self.data.transport
67 }
68
69 #[must_use]
71 pub fn connected(&self) -> bool {
72 self.data.connected.load(Ordering::Relaxed)
73 }
74
75 pub(crate) fn set_disconnected(&self) {
76 self.data.connected.store(false, Ordering::Relaxed);
77 }
78
79 pub(crate) fn logged_in_as(&self, session: Session) {
80 let mut sessions = self.data.sessions.write();
81 sessions.insert(
82 session.id,
83 ClientSession {
84 session,
85 subscribers: HashMap::default(),
86 },
87 );
88 }
89
90 pub(crate) fn log_out(&self, session: SessionId) -> Option<Session> {
91 let mut sessions = self.data.sessions.write();
92 sessions.remove(&Some(session)).map(|cs| cs.session)
93 }
94
95 pub fn send<Api: api::Api>(
97 &self,
98 session: Option<&Session>,
99 response: &Api::Response,
100 ) -> Result<(), Error> {
101 let encoded = pot::to_vec(&Result::<&Api::Response, Api::Error>::Ok(response))?;
102 self.data.response_sender.send((
103 session.and_then(|session| session.id),
104 Api::name(),
105 Bytes::from(encoded),
106 ))?;
107 Ok(())
108 }
109
110 pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
112 LockedClientDataGuard(fast_async_lock!(self.data.client_data))
113 }
114
115 #[must_use]
119 pub fn session(&self, session_id: Option<SessionId>) -> Option<Session> {
120 let sessions = self.data.sessions.read();
121 sessions.get(&session_id).map(|data| data.session.clone())
122 }
123
124 #[must_use]
126 pub fn all_sessions<C: FromIterator<Session>>(&self) -> C {
127 let sessions = self.data.sessions.read();
128 sessions.values().map(|s| s.session.clone()).collect()
129 }
130
131 pub(crate) fn register_subscriber(
132 &self,
133 subscriber: Subscriber,
134 session_id: Option<SessionId>,
135 ) {
136 let subscriber_id = subscriber.id();
137 let receiver = subscriber.receiver().clone();
138 {
139 let mut sessions = self.data.sessions.write();
140 if let Some(client_session) = sessions.get_mut(&session_id) {
141 client_session
142 .subscribers
143 .insert(subscriber.id(), subscriber);
144 } else {
145 return;
147 }
148 }
149 let task_self = self.clone();
150 tokio::task::spawn(async move {
151 task_self
152 .forward_notifications_for(session_id, subscriber_id, receiver)
153 .await;
154 });
155 }
156
157 pub async fn set_client_data(&self, data: B::ClientData) {
159 let mut client_data = fast_async_lock!(self.data.client_data);
160 *client_data = Some(data);
161 }
162
163 async fn forward_notifications_for(
164 &self,
165 session_id: Option<SessionId>,
166 subscriber_id: u64,
167 receiver: Receiver,
168 ) {
169 let session = self.session(session_id);
170 while let Ok(message) = receiver.receive_async().await {
171 if self
172 .send::<MessageReceived>(
173 session.as_ref(),
174 &MessageReceived {
175 subscriber_id,
176 topic: Bytes::from(message.topic.0.into_vec()),
177 payload: Bytes::from(&message.payload[..]),
178 },
179 )
180 .is_err()
181 {
182 break;
183 }
184 }
185 }
186
187 pub(crate) fn subscribe_by_id(
188 &self,
189 subscriber_id: u64,
190 topic: Bytes,
191 check_session_id: Option<SessionId>,
192 ) -> Result<(), crate::Error> {
193 let mut sessions = self.data.sessions.write();
194 if let Some(client_session) = sessions.get_mut(&check_session_id) {
195 if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
196 subscriber.subscribe_to_bytes(topic.0)?;
197 Ok(())
198 } else {
199 Err(Error::other(
200 "bonsaidb-server pubsub",
201 "invalid subscriber id",
202 ))
203 }
204 } else {
205 Err(Error::other("bonsaidb-server auth", "invalid session id"))
206 }
207 }
208
209 pub(crate) fn unsubscribe_by_id(
210 &self,
211 subscriber_id: u64,
212 topic: &[u8],
213 check_session_id: Option<SessionId>,
214 ) -> Result<(), crate::Error> {
215 let mut sessions = self.data.sessions.write();
216 if let Some(client_session) = sessions.get_mut(&check_session_id) {
217 if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
218 subscriber.unsubscribe_from_bytes(topic)?;
219 Ok(())
220 } else {
221 Err(Error::other(
222 "bonsaidb-server pubsub",
223 "invalid subscriber id",
224 ))
225 }
226 } else {
227 Err(Error::other("bonsaidb-server auth", "invalid session id"))
228 }
229 }
230
231 pub(crate) fn unregister_subscriber_by_id(
232 &self,
233 subscriber_id: u64,
234 check_session_id: Option<SessionId>,
235 ) -> Result<(), crate::Error> {
236 let mut sessions = self.data.sessions.write();
237 if let Some(client_session) = sessions.get_mut(&check_session_id) {
238 if client_session.subscribers.remove(&subscriber_id).is_some() {
239 Ok(())
240 } else {
241 Err(Error::other(
242 "bonsaidb-server pubsub",
243 "invalid subscriber id",
244 ))
245 }
246 } else {
247 Err(Error::other("bonsaidb-server auth", "invalid session id"))
248 }
249 }
250}
251
252pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
254
255impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
256 type Target = Option<ClientData>;
257
258 fn deref(&self) -> &Self::Target {
259 &self.0
260 }
261}
262
263impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
264 fn deref_mut(&mut self) -> &mut Self::Target {
265 &mut self.0
266 }
267}
268
269#[derive(Debug)]
270pub struct OwnedClient<B: Backend> {
271 client: ConnectedClient<B>,
272 runtime: Arc<tokio::runtime::Handle>,
273 server: Option<CustomServer<B>>,
274}
275
276impl<B: Backend> OwnedClient<B> {
277 pub(crate) fn new(
278 id: u32,
279 address: SocketAddr,
280 transport: Transport,
281 response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
282 server: CustomServer<B>,
283 default_session: Session,
284 ) -> Self {
285 let mut session = HashMap::new();
286 session.insert(
287 None,
288 ClientSession {
289 session: default_session,
290 subscribers: HashMap::default(),
291 },
292 );
293 Self {
294 client: ConnectedClient {
295 data: Arc::new(Data {
296 id,
297 address,
298 transport,
299 response_sender,
300 sessions: RwLock::new(session),
301 client_data: Mutex::default(),
302 connected: AtomicBool::new(true),
303 }),
304 },
305 runtime: Arc::new(tokio::runtime::Handle::current()),
306 server: Some(server),
307 }
308 }
309
310 pub fn clone(&self) -> ConnectedClient<B> {
311 self.client.clone()
312 }
313}
314
315impl<B: Backend> Drop for OwnedClient<B> {
316 fn drop(&mut self) {
317 let id = self.client.data.id;
318 let server = self.server.take().unwrap();
319 self.runtime
320 .spawn(async move { server.disconnect_client(id).await });
321 }
322}
323
324impl<B: Backend> Deref for OwnedClient<B> {
325 type Target = ConnectedClient<B>;
326
327 fn deref(&self) -> &Self::Target {
328 &self.client
329 }
330}