bonsaidb_server/server.rs

use std::collections::{hash_map, HashMap};
use std::fmt::Debug;
use std::net::SocketAddr;
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use bonsaidb_core::admin::{Admin, ADMIN_DATABASE_NAME};
use bonsaidb_core::api;
use bonsaidb_core::api::ApiName;
use bonsaidb_core::arc_bytes::serde::Bytes;
use bonsaidb_core::connection::{
    self, AsyncConnection, AsyncStorageConnection, HasSession, IdentityReference, Session,
    SessionId,
};
use bonsaidb_core::networking::{self, Payload, CURRENT_PROTOCOL_VERSION};
use bonsaidb_core::permissions::bonsai::{bonsaidb_resource_name, BonsaiAction, ServerAction};
use bonsaidb_core::permissions::Permissions;
use bonsaidb_core::schema::{self, Nameable, NamedCollection, Schema, SchemaSummary};
use bonsaidb_local::config::Builder;
use bonsaidb_local::{AsyncStorage, Storage, StorageNonBlocking};
use bonsaidb_utils::fast_async_lock;
use derive_where::derive_where;
use fabruic::{self, CertificateChain, Endpoint, KeyPair, PrivateKey};
use flume::Sender;
use futures::{Future, StreamExt};
use parking_lot::{Mutex, RwLock};
use rustls::sign::CertifiedKey;
use schema::SchemaName;
#[cfg(not(windows))]
use signal_hook::consts::SIGQUIT;
use signal_hook::consts::{SIGINT, SIGTERM};
use tokio::sync::{oneshot, Notify};

use crate::api::{AnyHandler, HandlerSession};
use crate::backend::ConnectionHandling;
#[cfg(feature = "acme")]
use crate::config::AcmeConfiguration;
use crate::dispatch::{register_api_handlers, ServerDispatcher};
use crate::error::Error;
use crate::hosted::{Hosted, SerializablePrivateKey, TlsCertificate, TlsCertificatesByDomain};
use crate::server::shutdown::{Shutdown, ShutdownState, ShutdownStateWatcher};
use crate::{Backend, BackendError, BonsaiListenConfig, NoBackend, ServerConfiguration};

#[cfg(feature = "acme")]
pub mod acme;
mod connected_client;
mod database;

mod shutdown;
mod tcp;
#[cfg(feature = "websockets")]
mod websockets;

use self::connected_client::OwnedClient;
pub use self::connected_client::{ConnectedClient, LockedClientDataGuard, Transport};
pub use self::database::ServerDatabase;
pub use self::tcp::{ApplicationProtocols, HttpService, Peer, StandardTcpProtocols, TcpService};

static CONNECTED_CLIENT_ID_COUNTER: AtomicU32 = AtomicU32::new(0);

/// A BonsaiDb server.
#[derive(Debug)]
#[derive_where(Clone)]
pub struct CustomServer<B: Backend = NoBackend> {
    data: Arc<Data<B>>,
    pub(crate) storage: AsyncStorage,
}

impl<'a, B: Backend> From<&'a CustomServer<B>> for Storage {
    fn from(server: &'a CustomServer<B>) -> Self {
        Self::from(server.storage.clone())
    }
}

impl<B: Backend> From<CustomServer<B>> for Storage {
    fn from(server: CustomServer<B>) -> Self {
        Self::from(server.storage)
    }
}

/// A BonsaiDb server without a custom backend.
pub type Server = CustomServer<NoBackend>;

#[derive(Debug)]
struct Data<B: Backend = NoBackend> {
    backend: B,
    clients: RwLock<HashMap<u32, ConnectedClient<B>>>,
    request_processor: flume::Sender<ClientRequest<B>>,
    default_session: Session,
    client_simultaneous_request_limit: usize,
    primary_tls_key: CachedCertifiedKey,
    primary_domain: String,
    custom_apis: RwLock<HashMap<ApiName, Arc<dyn AnyHandler<B>>>>,
    #[cfg(feature = "acme")]
    acme: AcmeConfiguration,
    #[cfg(feature = "acme")]
    alpn_keys: AlpnKeys,
    shutdown: Shutdown,
}

#[derive(Default)]
struct CachedCertifiedKey(Mutex<Option<Arc<CertifiedKey>>>);

impl Debug for CachedCertifiedKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("CachedCertifiedKey").finish()
    }
}

impl Deref for CachedCertifiedKey {
    type Target = Mutex<Option<Arc<CertifiedKey>>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<B: Backend> CustomServer<B> {
    /// Opens a server using `directory` for storage.
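    ///
    /// A minimal usage sketch (not part of the original docs): it assumes the
    /// `ServerConfiguration::new(path)` constructor from the configuration
    /// builder re-exported by this crate.
    ///
    /// ```no_run
    /// # use bonsaidb_server::{Server, ServerConfiguration};
    /// # async fn open_example() -> Result<(), bonsaidb_server::BackendError> {
    /// // Open (or create) the server's storage in the given directory.
    /// let server = Server::open(ServerConfiguration::new("server-data.bonsaidb")).await?;
    /// # Ok(())
    /// # }
    /// ```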
    pub async fn open(
        configuration: ServerConfiguration<B>,
    ) -> Result<Self, BackendError<B::Error>> {
        let configuration = register_api_handlers(B::configure(configuration)?)?;
        let (request_sender, request_receiver) = flume::unbounded::<ClientRequest<B>>();
        for _ in 0..configuration.request_workers {
            let request_receiver = request_receiver.clone();
            tokio::task::spawn(async move {
                while let Ok(mut client_request) = request_receiver.recv_async().await {
                    let request = client_request.request.take().unwrap();
                    let session = client_request.session.clone();
                    // TODO we should be able to upgrade a session-less Storage to one with a Session.
                    // The Session needs to be looked up from the client based on the request's session id.
                    let result = match client_request.server.storage.assume_session(session) {
                        Ok(storage) => {
                            let client = HandlerSession {
                                server: &client_request.server,
                                client: &client_request.client,
                                as_client: Self {
                                    data: client_request.server.data.clone(),
                                    storage,
                                },
                            };
                            ServerDispatcher::dispatch_api_request(
                                client,
                                &request.name,
                                request.value.unwrap(),
                            )
                            .await
                            .map_err(bonsaidb_core::Error::from)
                        }
                        Err(err) => Err(err),
                    };
                    drop(client_request.result_sender.send((request.name, result)));
                }
            });
        }

        let storage = AsyncStorage::open(configuration.storage.with_schema::<Hosted>()?).await?;

        storage.create_database::<Hosted>("_hosted", true).await?;

        let default_permissions = Permissions::from(configuration.default_permissions);

        let server = Self {
            storage,
            data: Arc::new(Data {
                backend: configuration.backend,
                clients: parking_lot::RwLock::default(),
                request_processor: request_sender,
                default_session: Session {
                    permissions: default_permissions,
                    ..Session::default()
                },
                client_simultaneous_request_limit: configuration.client_simultaneous_request_limit,
                primary_tls_key: CachedCertifiedKey::default(),
                primary_domain: configuration.server_name,
                custom_apis: parking_lot::RwLock::new(configuration.custom_apis),
                #[cfg(feature = "acme")]
                acme: configuration.acme,
                #[cfg(feature = "acme")]
                alpn_keys: AlpnKeys::default(),
                shutdown: Shutdown::new(),
            }),
        };

        server.data.backend.initialize(&server).await?;
        Ok(server)
    }

    /// Returns the path to the public pinned certificate, if this server has
    /// one. Note: this function will always succeed, but the file may not
    /// exist.
    #[must_use]
    pub fn pinned_certificate_path(&self) -> PathBuf {
        self.path().join("pinned-certificate.der")
    }

    /// Returns the primary domain configured for this server.
    #[must_use]
    pub fn primary_domain(&self) -> &str {
        &self.data.primary_domain
    }

    /// Returns the [`Backend`] implementor for this server.
    #[must_use]
    pub fn backend(&self) -> &B {
        &self.data.backend
    }

    /// Returns the administration database.
    pub async fn admin(&self) -> ServerDatabase<B> {
        let db = self.storage.admin().await;
        ServerDatabase {
            server: self.clone(),
            db,
        }
    }

    pub(crate) async fn hosted(&self) -> ServerDatabase<B> {
        let db = self.storage.database::<Hosted>("_hosted").await.unwrap();
        ServerDatabase {
            server: self.clone(),
            db,
        }
    }

    pub(crate) fn custom_api_dispatcher(&self, name: &ApiName) -> Option<Arc<dyn AnyHandler<B>>> {
        let dispatchers = self.data.custom_apis.read();
        dispatchers.get(name).cloned()
    }

    /// Installs an X.509 certificate used for general purpose connections.
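    ///
    /// A short sketch (added here as an illustration): generate and install a
    /// self-signed certificate without overwriting one that is already
    /// installed.
    ///
    /// ```no_run
    /// # use bonsaidb_server::Server;
    /// # async fn example(server: Server) -> Result<(), bonsaidb_server::Error> {
    /// // Pass `true` instead to replace an existing certificate.
    /// server.install_self_signed_certificate(false).await?;
    /// # Ok(())
    /// # }
    /// ```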
    pub async fn install_self_signed_certificate(&self, overwrite: bool) -> Result<(), Error> {
        let keypair = KeyPair::new_self_signed(&self.data.primary_domain);

        if self.certificate_chain().await.is_ok() && !overwrite {
            return Err(Error::Core(bonsaidb_core::Error::other(
                "bonsaidb-server config",
                "Certificate already installed. Enable overwrite if you wish to replace the existing certificate.",
            )));
        }

        self.install_certificate(keypair.certificate_chain(), keypair.private_key())
            .await?;

        tokio::fs::write(
            self.pinned_certificate_path(),
            keypair.end_entity_certificate().as_ref(),
        )
        .await?;

        Ok(())
    }

    /// Installs a certificate chain and private key used for TLS connections.
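    ///
    /// A sketch of loading PEM files from disk and installing them. The file
    /// names are illustrative, and the `pem` feature must be enabled.
    ///
    /// ```no_run
    /// # use bonsaidb_server::Server;
    /// # async fn example(server: Server) -> Result<(), bonsaidb_server::Error> {
    /// let certificate_chain = tokio::fs::read("fullchain.pem").await?;
    /// let private_key = tokio::fs::read("privkey.pem").await?;
    /// server
    ///     .install_pem_certificate(&certificate_chain, &private_key)
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```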
    #[cfg(feature = "pem")]
    pub async fn install_pem_certificate(
        &self,
        certificate_chain: &[u8],
        private_key: &[u8],
    ) -> Result<(), Error> {
        let private_key = match pem::parse(private_key) {
            Ok(pem) => PrivateKey::unchecked_from_der(pem.contents()),
            Err(_) => PrivateKey::from_der(private_key)?,
        };
        let certificates = pem::parse_many(certificate_chain)?
            .into_iter()
            .map(|entry| fabruic::Certificate::unchecked_from_der(entry.contents()))
            .collect::<Vec<_>>();

        self.install_certificate(
            &CertificateChain::unchecked_from_certificates(certificates),
            &private_key,
        )
        .await
    }

    /// Installs a certificate chain and private key used for TLS connections.
    pub async fn install_certificate(
        &self,
        certificate_chain: &CertificateChain,
        private_key: &PrivateKey,
    ) -> Result<(), Error> {
        let db = self.hosted().await;

        TlsCertificate::entry_async(&self.data.primary_domain, &db)
            .update_with(|cert: &mut TlsCertificate| {
                cert.certificate_chain = certificate_chain.clone();
                cert.private_key = SerializablePrivateKey(private_key.clone());
            })
            .or_insert_with(|| TlsCertificate {
                domains: vec![self.data.primary_domain.clone()],
                private_key: SerializablePrivateKey(private_key.clone()),
                certificate_chain: certificate_chain.clone(),
            })
            .await?;

        self.refresh_certified_key().await?;

        let pinned_certificate_path = self.pinned_certificate_path();
        if pinned_certificate_path.exists() {
            tokio::fs::remove_file(&pinned_certificate_path).await?;
        }

        Ok(())
    }

    async fn refresh_certified_key(&self) -> Result<(), Error> {
        let certificate = self.tls_certificate().await?;

        let mut cached_key = self.data.primary_tls_key.lock();
        let private_key = rustls::PrivateKey(
            fabruic::dangerous::PrivateKey::as_ref(&certificate.private_key.0).to_vec(),
        );
        let private_key = rustls::sign::any_ecdsa_type(&Arc::new(private_key))?;

        let certificates = certificate
            .certificate_chain
            .iter()
            .map(|cert| rustls::Certificate(cert.as_ref().to_vec()))
            .collect::<Vec<_>>();

        let certified_key = Arc::new(CertifiedKey::new(certificates, private_key));
        *cached_key = Some(certified_key);
        Ok(())
    }

    async fn tls_certificate(&self) -> Result<TlsCertificate, Error> {
        let db = self.hosted().await;
        let (_, certificate) = db
            .view::<TlsCertificatesByDomain>()
            .with_key(&self.data.primary_domain)
            .query_with_collection_docs()
            .await?
            .documents
            .into_iter()
            .next()
            .ok_or_else(|| {
                Error::Core(bonsaidb_core::Error::other(
                    "bonsaidb-server config",
                    format!("no certificate found for {}", self.data.primary_domain),
                ))
            })?;
        Ok(certificate.contents)
    }

    /// Returns the current certificate chain.
    pub async fn certificate_chain(&self) -> Result<CertificateChain, Error> {
        let db = self.hosted().await;
        if let Some(mapping) = db
            .view::<TlsCertificatesByDomain>()
            .with_key(&self.data.primary_domain)
            .query()
            .await?
            .into_iter()
            .next()
        {
            Ok(mapping.value)
        } else {
            Err(Error::Core(bonsaidb_core::Error::other(
                "bonsaidb-server config",
                format!("no certificate found for {}", self.data.primary_domain),
            )))
        }
    }

    /// Listens for incoming client connections. Does not return until the
    /// server shuts down.
    ///
    /// ## Listening on a port
    ///
    /// When passing a `u16` to this function, the server will begin listening
    /// on an "unspecified" address. This address is typically accessible to
    /// other machines on the network/internet, so care should be taken to
    /// ensure this is what is intended.
    ///
    /// To ensure that the server only listens for local traffic, specify a
    /// local IP or localhost in addition to the port number.
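    ///
    /// A brief sketch of the port-only form described above (the port number
    /// is illustrative); a [`BonsaiListenConfig`] can be passed instead when
    /// the bind address needs to be pinned down.
    ///
    /// ```no_run
    /// # use bonsaidb_server::Server;
    /// # async fn example(server: Server) -> Result<(), bonsaidb_server::Error> {
    /// // Listen on the unspecified address until the server shuts down.
    /// server.listen_on(5645).await?;
    /// # Ok(())
    /// # }
    /// ```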
    pub async fn listen_on(&self, config: impl Into<BonsaiListenConfig>) -> Result<(), Error> {
        let config = config.into();
        let certificate = self.tls_certificate().await?;
        let keypair =
            KeyPair::from_parts(certificate.certificate_chain, certificate.private_key.0)?;
        let mut builder = Endpoint::builder();
        builder.set_protocols([CURRENT_PROTOCOL_VERSION.as_bytes().to_vec()]);
        builder.set_address(config.address);
        builder.set_max_idle_timeout(None)?;
        builder.set_server_key_pair(Some(keypair));
        builder.set_reuse_address(config.reuse_address);
        let mut server = builder.build()?;

        let mut shutdown_watcher = self
            .data
            .shutdown
            .watcher()
            .await
            .expect("server already shut down");

        while let Some(incoming) = tokio::select! {
            shutdown_state = shutdown_watcher.wait_for_shutdown() => {
                drop(server.close_incoming());
                if matches!(shutdown_state, ShutdownState::GracefulShutdown) {
                    server.wait_idle().await;
                }
                None
            },
            msg = server.next() => msg
        } {
            let address = incoming.remote_address();
            let connection = match incoming.accept::<()>().await {
                Ok(connection) => connection,
                Err(err) => {
                    log::error!("[server] error on incoming connection from {address}: {err:?}");
                    continue;
                }
            };
            let task_self = self.clone();
            tokio::spawn(async move {
                if let Err(err) = task_self.handle_bonsai_connection(connection).await {
                    log::error!("[server] closing connection {address}: {err:?}");
                }
            });
        }

        Ok(())
    }

    /// Returns all of the currently connected clients.
    #[must_use]
    pub fn connected_clients(&self) -> Vec<ConnectedClient<B>> {
        let clients = self.data.clients.read();
        clients.values().cloned().collect()
    }

    /// Sends a custom API response to all connected clients.
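    ///
    /// A hypothetical sketch, assuming a `Ping` type implementing [`api::Api`]
    /// with `Pong` as its response type is defined elsewhere:
    ///
    /// ```ignore
    /// // Notify every connected client's global session.
    /// server.broadcast::<Ping>(&Pong);
    /// ```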
    pub fn broadcast<Api: api::Api>(&self, response: &Api::Response) {
        let clients = self.data.clients.read();
        for client in clients.values() {
            // TODO should this broadcast to all sessions too rather than only the global session?
            drop(client.send::<Api>(None, response));
        }
    }

    async fn initialize_client(
        &self,
        transport: Transport,
        address: SocketAddr,
        sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
    ) -> Option<OwnedClient<B>> {
        if !self.data.default_session.allowed_to(
            bonsaidb_resource_name(),
            &BonsaiAction::Server(ServerAction::Connect),
        ) {
            return None;
        }

        let client = loop {
            let next_id = CONNECTED_CLIENT_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
            let mut clients = self.data.clients.write();
            if let hash_map::Entry::Vacant(e) = clients.entry(next_id) {
                let client = OwnedClient::new(
                    next_id,
                    address,
                    transport,
                    sender,
                    self.clone(),
                    self.data.default_session.clone(),
                );
                e.insert(client.clone());
                break client;
            }
        };

        match self.data.backend.client_connected(&client, self).await {
            Ok(ConnectionHandling::Accept) => Some(client),
            Ok(ConnectionHandling::Reject) => None,
            Err(err) => {
                log::error!(
                    "[server] Rejecting connection due to error in `client_connected`: {err:?}"
                );
                None
            }
        }
    }

    async fn disconnect_client(&self, id: u32) {
        let removed_client = {
            let mut clients = self.data.clients.write();
            clients.remove(&id)
        };

        if let Some(client) = removed_client {
            client.set_disconnected();
            for session in client.all_sessions::<Vec<_>>() {
                if let Err(err) = self
                    .data
                    .backend
                    .client_session_ended(session, &client, true, self)
                    .await
                {
                    log::error!("[server] Error in `client_session_ended`: {err:?}");
                }
            }

            if let Err(err) = self.data.backend.client_disconnected(client, self).await {
                log::error!("[server] Error in `client_disconnected`: {err:?}");
            }
        }
    }

    async fn handle_bonsai_connection(
        &self,
        mut connection: fabruic::Connection<()>,
    ) -> Result<(), Error> {
        if let Some(incoming) = connection.next().await {
            let incoming = match incoming {
                Ok(incoming) => incoming,
                Err(err) => {
                    log::error!("[server] Error establishing a stream: {err:?}");
                    return Ok(());
                }
            };

            match incoming
                .accept::<networking::Payload, networking::Payload>()
                .await
            {
                Ok((sender, receiver)) => {
                    let (api_response_sender, api_response_receiver) = flume::unbounded();
                    if let Some(disconnector) = self
                        .initialize_client(
                            Transport::Bonsai,
                            connection.remote_address(),
                            api_response_sender,
                        )
                        .await
                    {
                        let task_sender = sender.clone();
                        tokio::spawn(async move {
                            while let Ok((session_id, name, bytes)) =
                                api_response_receiver.recv_async().await
                            {
                                if task_sender
                                    .send(&Payload {
                                        id: None,
                                        session_id,
                                        name,
                                        value: Ok(bytes),
                                    })
                                    .is_err()
                                {
                                    break;
                                }
                            }
                        });

                        let task_self = self.clone();
                        let Some(shutdown) = self.data.shutdown.watcher().await else {
                            return Ok(());
                        };
                        tokio::spawn(async move {
                            if let Err(err) = task_self
                                .handle_stream(disconnector, sender, receiver, shutdown)
                                .await
                            {
                                log::error!("[server] Error handling stream: {err:?}");
                            }
                        });
                    } else {
                        log::error!("[server] Backend rejected connection.");
                        return Ok(());
                    }
                }
                Err(err) => {
                    log::error!("[server] Error accepting incoming stream: {err:?}");
                    return Ok(());
                }
            }
        }
        Ok(())
    }

    async fn handle_client_requests(
        &self,
        client: ConnectedClient<B>,
        request_receiver: flume::Receiver<Payload>,
        response_sender: flume::Sender<Payload>,
        mut shutdown: ShutdownStateWatcher,
    ) {
        let notify = Arc::new(Notify::new());
        let requests_in_queue = Arc::new(AtomicUsize::new(0));
        loop {
            let current_requests = requests_in_queue.load(Ordering::SeqCst);
            if current_requests == self.data.client_simultaneous_request_limit {
                // Wait for requests to finish.
                notify.notified().await;
            } else if requests_in_queue
                .compare_exchange(
                    current_requests,
                    current_requests + 1,
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_ok()
            {
                let payload = 'payload: loop {
                    tokio::select! {
                        payload = request_receiver.recv_async() => {
                            if let Ok(payload) = payload {
                                break 'payload payload
                            }

                            return
                        },
                        state = shutdown.wait_for_shutdown() => {
                            if matches!(state, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
                                return
                            }
                        }
                    }
                };
                let session_id = payload.session_id;
                let id = payload.id;
                let task_sender = response_sender.clone();

                let notify = notify.clone();
                let requests_in_queue = requests_in_queue.clone();
                self.handle_request_through_worker(
                    payload,
                    move |name, value| async move {
                        drop(task_sender.send(Payload {
                            session_id,
                            id,
                            name,
                            value,
                        }));

                        requests_in_queue.fetch_sub(1, Ordering::SeqCst);

                        notify.notify_one();

                        Ok(())
                    },
                    client.clone(),
                )
                .unwrap();
            }
        }
    }

    fn handle_request_through_worker<
        F: FnOnce(ApiName, Result<Bytes, bonsaidb_core::Error>) -> R + Send + 'static,
        R: Future<Output = Result<(), Error>> + Send,
    >(
        &self,
        request: Payload,
        callback: F,
        client: ConnectedClient<B>,
    ) -> Result<(), Error> {
        let (result_sender, result_receiver) = oneshot::channel();
        let session = client
            .session(request.session_id)
            .unwrap_or_else(|| self.data.default_session.clone());
        self.data
            .request_processor
            .send(ClientRequest::<B>::new(
                request,
                self.clone(),
                client,
                session,
                result_sender,
            ))
            .map_err(|_| Error::InternalCommunication)?;
        tokio::spawn(async move {
            let (name, result) = result_receiver.await?;
            // Map the error into a Response::Error. The jobs system supports
            // multiple receivers receiving output, and wraps Err to avoid
            // requiring the error to be cloneable. As such, we have to unwrap
            // it. Thankfully, we can guarantee that nothing other than the
            // original requestor is waiting on the response, so this can be
            // safely unwrapped.
            callback(name, result).await?;
            Result::<(), Error>::Ok(())
        });
        Ok(())
    }

    async fn handle_stream(
        &self,
        client: OwnedClient<B>,
        sender: fabruic::Sender<Payload>,
        mut receiver: fabruic::Receiver<Payload>,
        mut shutdown: ShutdownStateWatcher,
    ) -> Result<(), Error> {
        let (payload_sender, payload_receiver) = flume::unbounded();
        tokio::spawn({
            let mut shutdown = shutdown.clone();
            async move {
                'stream: loop {
                    let payload = loop {
                        tokio::select! {
                            payload = payload_receiver.recv_async() => {
                                if let Ok(payload) = payload {
                                    break payload
                                }
                                break 'stream
                            }
                            shutdown = shutdown.wait_for_shutdown() => {
                                if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
                                    break 'stream
                                }
                            }
                        }
                    };
                    if sender.send(&payload).is_err() {
                        break;
                    }
                }
            }
        });

        let (request_sender, request_receiver) =
            flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
        let task_self = self.clone();
        tokio::spawn({
            let shutdown = shutdown.clone();
            async move {
                task_self
                    .handle_client_requests(
                        client.clone(),
                        request_receiver,
                        payload_sender,
                        shutdown,
                    )
                    .await;
            }
        });

        loop {
            let payload = loop {
                tokio::select! {
                    payload = receiver.next() => {
                        if let Some(payload) = payload {
                            break payload
                        }

                        receiver.finish().await?;

                        return Ok(());
                    }
                    shutdown = shutdown.wait_for_shutdown() => {
                        if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
                            return Ok(());
                        }
                    }
                }
            };
            drop(request_sender.send_async(payload?).await);
        }
    }

    /// Shuts the server down. If a `timeout` is provided, the server will stop
    /// accepting new connections and attempt to respond to any outstanding
    /// requests already being processed. After the `timeout` has elapsed or if
    /// no `timeout` was provided, the server is forcefully shut down.
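    ///
    /// A small sketch: request a graceful shutdown and give outstanding
    /// requests up to 30 seconds to complete.
    ///
    /// ```no_run
    /// # use std::time::Duration;
    /// # use bonsaidb_server::Server;
    /// # async fn example(server: Server) -> Result<(), bonsaidb_server::Error> {
    /// server.shutdown(Some(Duration::from_secs(30))).await?;
    /// # Ok(())
    /// # }
    /// ```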
    pub async fn shutdown(&self, timeout: Option<Duration>) -> Result<(), Error> {
        if let Some(timeout) = timeout {
            self.data.shutdown.graceful_shutdown(timeout).await;
        } else {
            self.data.shutdown.shutdown().await;
        }

        Ok(())
    }

    /// Listens for signals from the operating system that the server should
    /// shut down and attempts to gracefully shut down.
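    ///
    /// A sketch of combining this with [`Self::listen_on`] (the port number is
    /// illustrative): serve connections on a background task while this call
    /// waits for an operating-system signal.
    ///
    /// ```no_run
    /// # use bonsaidb_server::Server;
    /// # async fn example(server: Server) -> Result<(), bonsaidb_server::Error> {
    /// let task_server = server.clone();
    /// tokio::spawn(async move { task_server.listen_on(5645).await });
    /// // Blocks until a shutdown signal has been handled.
    /// server.listen_for_shutdown().await?;
    /// # Ok(())
    /// # }
    /// ```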
    pub async fn listen_for_shutdown(&self) -> Result<(), Error> {
        const GRACEFUL_SHUTDOWN: usize = 1;
        const TERMINATE: usize = 2;

        enum SignalShutdownState {
            Running,
            ShuttingDown(flume::Receiver<()>),
        }

        let shutdown_state = Arc::new(async_lock::Mutex::new(SignalShutdownState::Running));
        let flag = Arc::new(AtomicUsize::default());
        let register_hook = |flag: &Arc<AtomicUsize>| {
            signal_hook::flag::register_usize(SIGINT, flag.clone(), GRACEFUL_SHUTDOWN)?;
            signal_hook::flag::register_usize(SIGTERM, flag.clone(), TERMINATE)?;
            #[cfg(not(windows))]
            signal_hook::flag::register_usize(SIGQUIT, flag.clone(), TERMINATE)?;
            Result::<(), std::io::Error>::Ok(())
        };
        if let Err(err) = register_hook(&flag) {
            log::error!("Error installing signals for graceful shutdown: {err:?}");
            tokio::time::sleep(Duration::MAX).await;
        } else {
            loop {
                match flag.load(Ordering::Relaxed) {
                    0 => {
                        // No signal
                    }
                    GRACEFUL_SHUTDOWN => {
                        let mut state = fast_async_lock!(shutdown_state);
                        match *state {
                            SignalShutdownState::Running => {
                                log::error!("Interrupt signal received. Shutting down gracefully.");
                                let task_server = self.clone();
                                let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
                                tokio::task::spawn(async move {
                                    task_server.shutdown(Some(Duration::from_secs(30))).await?;
                                    let _: Result<_, _> = shutdown_sender.send(());
                                    Result::<(), Error>::Ok(())
                                });
                                *state = SignalShutdownState::ShuttingDown(shutdown_receiver);
                            }
                            SignalShutdownState::ShuttingDown(_) => {
                                // Two interrupts, go ahead and force the shutdown
                                break;
                            }
                        }
                    }
                    TERMINATE => {
                        log::error!("Quit signal received. Shutting down.");
                        break;
                    }
                    _ => unreachable!(),
                }

                let state = fast_async_lock!(shutdown_state);
                if let SignalShutdownState::ShuttingDown(receiver) = &*state {
                    if receiver.try_recv().is_ok() {
                        // Fully shut down.
                        return Ok(());
                    }
                } else if self.data.shutdown.should_shutdown() {
                    return Ok(());
                }

                tokio::time::sleep(Duration::from_millis(300)).await;
            }
            self.shutdown(None).await?;
        }

        Ok(())
    }
}

impl<B: Backend> Deref for CustomServer<B> {
    type Target = AsyncStorage;

    fn deref(&self) -> &Self::Target {
        &self.storage
    }
}

#[derive(Debug)]
struct ClientRequest<B: Backend> {
    request: Option<Payload>,
    client: ConnectedClient<B>,
    session: Session,
    server: CustomServer<B>,
    result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
}

impl<B: Backend> ClientRequest<B> {
    pub fn new(
        request: Payload,
        server: CustomServer<B>,
        client: ConnectedClient<B>,
        session: Session,
        result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
    ) -> Self {
        Self {
            request: Some(request),
            server,
            client,
            session,
            result_sender,
        }
    }
}

impl<B: Backend> HasSession for CustomServer<B> {
    fn session(&self) -> Option<&Session> {
        self.storage.session()
    }
}

#[async_trait]
impl<B: Backend> AsyncStorageConnection for CustomServer<B> {
    type Authenticated = Self;
    type Database = ServerDatabase<B>;

    async fn admin(&self) -> Self::Database {
        self.database::<Admin>(ADMIN_DATABASE_NAME).await.unwrap()
    }

    async fn create_database_with_schema(
        &self,
        name: &str,
        schema: SchemaName,
        only_if_needed: bool,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage
            .create_database_with_schema(name, schema, only_if_needed)
            .await
    }

    async fn database<DB: Schema>(
        &self,
        name: &str,
    ) -> Result<Self::Database, bonsaidb_core::Error> {
        let db = self.storage.database::<DB>(name).await?;
        Ok(ServerDatabase {
            server: self.clone(),
            db,
        })
    }

    async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
        self.storage.delete_database(name).await
    }

    async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
        self.storage.list_databases().await
    }

    async fn list_available_schemas(&self) -> Result<Vec<SchemaSummary>, bonsaidb_core::Error> {
        self.storage.list_available_schemas().await
    }

    async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
        self.storage.create_user(username).await
    }

    async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
        &self,
        user: U,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage.delete_user(user).await
    }

    #[cfg(feature = "password-hashing")]
    async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
        &self,
        user: U,
        password: bonsaidb_core::connection::SensitiveString,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage.set_user_password(user, password).await
    }

    #[cfg(any(feature = "token-authentication", feature = "password-hashing"))]
    async fn authenticate(
        &self,
        authentication: bonsaidb_core::connection::Authentication,
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
        let storage = self.storage.authenticate(authentication).await?;
        Ok(Self {
            data: self.data.clone(),
            storage,
        })
    }

    async fn assume_identity(
        &self,
        identity: IdentityReference<'_>,
    ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
        let storage = self.storage.assume_identity(identity).await?;
        Ok(Self {
            data: self.data.clone(),
            storage,
        })
    }

    async fn add_permission_group_to_user<
        'user,
        'group,
        U: Nameable<'user, u64> + Send + Sync,
        G: Nameable<'group, u64> + Send + Sync,
    >(
        &self,
        user: U,
        permission_group: G,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage
            .add_permission_group_to_user(user, permission_group)
            .await
    }

    async fn remove_permission_group_from_user<
        'user,
        'group,
        U: Nameable<'user, u64> + Send + Sync,
        G: Nameable<'group, u64> + Send + Sync,
    >(
        &self,
        user: U,
        permission_group: G,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage
            .remove_permission_group_from_user(user, permission_group)
            .await
    }

    async fn add_role_to_user<
        'user,
        'group,
        U: Nameable<'user, u64> + Send + Sync,
        G: Nameable<'group, u64> + Send + Sync,
    >(
        &self,
        user: U,
        role: G,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage.add_role_to_user(user, role).await
    }

    async fn remove_role_from_user<
        'user,
        'group,
        U: Nameable<'user, u64> + Send + Sync,
        G: Nameable<'group, u64> + Send + Sync,
    >(
        &self,
        user: U,
        role: G,
    ) -> Result<(), bonsaidb_core::Error> {
        self.storage.remove_role_from_user(user, role).await
    }
}

#[derive(Default)]
struct AlpnKeys(Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>);

impl Debug for AlpnKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("AlpnKeys").finish()
    }
}

impl Deref for AlpnKeys {
    type Target = Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}