1use std::collections::{hash_map, HashMap};
2use std::fmt::Debug;
3use std::net::SocketAddr;
4use std::ops::Deref;
5use std::path::PathBuf;
6use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_trait::async_trait;
11use bonsaidb_core::admin::{Admin, ADMIN_DATABASE_NAME};
12use bonsaidb_core::api;
13use bonsaidb_core::api::ApiName;
14use bonsaidb_core::arc_bytes::serde::Bytes;
15use bonsaidb_core::connection::{
16 self, AsyncConnection, AsyncStorageConnection, HasSession, IdentityReference, Session,
17 SessionId,
18};
19use bonsaidb_core::networking::{self, Payload, CURRENT_PROTOCOL_VERSION};
20use bonsaidb_core::permissions::bonsai::{bonsaidb_resource_name, BonsaiAction, ServerAction};
21use bonsaidb_core::permissions::Permissions;
22use bonsaidb_core::schema::{self, Nameable, NamedCollection, Schema, SchemaSummary};
23use bonsaidb_local::config::Builder;
24use bonsaidb_local::{AsyncStorage, Storage, StorageNonBlocking};
25use bonsaidb_utils::fast_async_lock;
26use derive_where::derive_where;
27use fabruic::{self, CertificateChain, Endpoint, KeyPair, PrivateKey};
28use flume::Sender;
29use futures::{Future, StreamExt};
30use parking_lot::{Mutex, RwLock};
31use rustls::sign::CertifiedKey;
32use schema::SchemaName;
33#[cfg(not(windows))]
34use signal_hook::consts::SIGQUIT;
35use signal_hook::consts::{SIGINT, SIGTERM};
36use tokio::sync::{oneshot, Notify};
37
38use crate::api::{AnyHandler, HandlerSession};
39use crate::backend::ConnectionHandling;
40#[cfg(feature = "acme")]
41use crate::config::AcmeConfiguration;
42use crate::dispatch::{register_api_handlers, ServerDispatcher};
43use crate::error::Error;
44use crate::hosted::{Hosted, SerializablePrivateKey, TlsCertificate, TlsCertificatesByDomain};
45use crate::server::shutdown::{Shutdown, ShutdownState, ShutdownStateWatcher};
46use crate::{Backend, BackendError, BonsaiListenConfig, NoBackend, ServerConfiguration};
47
48#[cfg(feature = "acme")]
49pub mod acme;
50mod connected_client;
51mod database;
52
53mod shutdown;
54mod tcp;
55#[cfg(feature = "websockets")]
56mod websockets;
57
58use self::connected_client::OwnedClient;
59pub use self::connected_client::{ConnectedClient, LockedClientDataGuard, Transport};
60pub use self::database::ServerDatabase;
61pub use self::tcp::{ApplicationProtocols, HttpService, Peer, StandardTcpProtocols, TcpService};
62
/// Source of ids handed to newly connected clients.
///
/// `fetch_add` on an `AtomicU32` wraps on overflow; `initialize_client` keeps drawing
/// ids until it finds one that is not currently present in the connected-clients map.
static CONNECTED_CLIENT_ID_COUNTER: AtomicU32 = AtomicU32::new(0);
64
/// A BonsaiDb server wrapping an [`AsyncStorage`] and accepting client connections,
/// with its behavior customized by the [`Backend`] `B`.
///
/// Clones share the same server state (`data` is behind an `Arc`).
#[derive(Debug)]
#[derive_where(Clone)]
pub struct CustomServer<B: Backend = NoBackend> {
    /// State shared by every clone of this server (clients, backend, TLS keys,
    /// shutdown coordination, ...).
    data: Arc<Data<B>>,
    /// The storage exposed by this server. Handles created via `assume_session`,
    /// `authenticate`, or `assume_identity` carry a narrowed storage here.
    pub(crate) storage: AsyncStorage,
}
72
73impl<'a, B: Backend> From<&'a CustomServer<B>> for Storage {
74 fn from(server: &'a CustomServer<B>) -> Self {
75 Self::from(server.storage.clone())
76 }
77}
78
79impl<B: Backend> From<CustomServer<B>> for Storage {
80 fn from(server: CustomServer<B>) -> Self {
81 Self::from(server.storage)
82 }
83}
84
/// A [`CustomServer`] using [`NoBackend`], i.e. without custom backend behavior.
pub type Server = CustomServer<NoBackend>;
87
/// State shared (via `Arc`) by every clone of a [`CustomServer`].
#[derive(Debug)]
struct Data<B: Backend = NoBackend> {
    /// The backend customizing this server.
    backend: B,
    /// Currently connected clients, keyed by the id drawn from
    /// `CONNECTED_CLIENT_ID_COUNTER`.
    clients: RwLock<HashMap<u32, ConnectedClient<B>>>,
    /// Queue feeding the request worker tasks spawned in `CustomServer::open`.
    request_processor: flume::Sender<ClientRequest<B>>,
    /// Session built from the configured default permissions. Used when a client has
    /// no session matching a request's session id, and as the initial session of
    /// new clients.
    default_session: Session,
    /// Maximum number of requests a single client may have in flight at once.
    client_simultaneous_request_limit: usize,
    /// Cached rustls key for the primary domain, filled by `refresh_certified_key`.
    primary_tls_key: CachedCertifiedKey,
    /// Domain name certificates are looked up / generated for (from `server_name`).
    primary_domain: String,
    /// Registered custom API handlers, keyed by API name.
    custom_apis: RwLock<HashMap<ApiName, Arc<dyn AnyHandler<B>>>>,
    /// ACME configuration (only with the `acme` feature).
    #[cfg(feature = "acme")]
    acme: AcmeConfiguration,
    /// Certified keys keyed by `String`, only with the `acme` feature.
    /// NOTE(review): presumably TLS-ALPN challenge certificates keyed by domain — confirm
    /// in the `acme` module.
    #[cfg(feature = "acme")]
    alpn_keys: AlpnKeys,
    /// Coordinates shutdown of listeners and client connections.
    shutdown: Shutdown,
}
104
/// A mutex-guarded, optional cached rustls [`CertifiedKey`].
///
/// The default value holds `None` until a key is stored.
#[derive(Default)]
struct CachedCertifiedKey(Mutex<Option<Arc<CertifiedKey>>>);
107
108impl Debug for CachedCertifiedKey {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 f.debug_tuple("CachedCertifiedKey").finish()
111 }
112}
113
114impl Deref for CachedCertifiedKey {
115 type Target = Mutex<Option<Arc<CertifiedKey>>>;
116
117 fn deref(&self) -> &Self::Target {
118 &self.0
119 }
120}
121
122impl<B: Backend> CustomServer<B> {
123 pub async fn open(
125 configuration: ServerConfiguration<B>,
126 ) -> Result<Self, BackendError<B::Error>> {
127 let configuration = register_api_handlers(B::configure(configuration)?)?;
128 let (request_sender, request_receiver) = flume::unbounded::<ClientRequest<B>>();
129 for _ in 0..configuration.request_workers {
130 let request_receiver = request_receiver.clone();
131 tokio::task::spawn(async move {
132 while let Ok(mut client_request) = request_receiver.recv_async().await {
133 let request = client_request.request.take().unwrap();
134 let session = client_request.session.clone();
135 let result = match client_request.server.storage.assume_session(session) {
138 Ok(storage) => {
139 let client = HandlerSession {
140 server: &client_request.server,
141 client: &client_request.client,
142 as_client: Self {
143 data: client_request.server.data.clone(),
144 storage,
145 },
146 };
147 ServerDispatcher::dispatch_api_request(
148 client,
149 &request.name,
150 request.value.unwrap(),
151 )
152 .await
153 .map_err(bonsaidb_core::Error::from)
154 }
155 Err(err) => Err(err),
156 };
157 drop(client_request.result_sender.send((request.name, result)));
158 }
159 });
160 }
161
162 let storage = AsyncStorage::open(configuration.storage.with_schema::<Hosted>()?).await?;
163
164 storage.create_database::<Hosted>("_hosted", true).await?;
165
166 let default_permissions = Permissions::from(configuration.default_permissions);
167
168 let server = Self {
169 storage,
170 data: Arc::new(Data {
171 backend: configuration.backend,
172 clients: parking_lot::RwLock::default(),
173 request_processor: request_sender,
174 default_session: Session {
175 permissions: default_permissions,
176 ..Session::default()
177 },
178 client_simultaneous_request_limit: configuration.client_simultaneous_request_limit,
179 primary_tls_key: CachedCertifiedKey::default(),
180 primary_domain: configuration.server_name,
181 custom_apis: parking_lot::RwLock::new(configuration.custom_apis),
182 #[cfg(feature = "acme")]
183 acme: configuration.acme,
184 #[cfg(feature = "acme")]
185 alpn_keys: AlpnKeys::default(),
186 shutdown: Shutdown::new(),
187 }),
188 };
189
190 server.data.backend.initialize(&server).await?;
191 Ok(server)
192 }
193
194 #[must_use]
198 pub fn pinned_certificate_path(&self) -> PathBuf {
199 self.path().join("pinned-certificate.der")
200 }
201
202 #[must_use]
204 pub fn primary_domain(&self) -> &str {
205 &self.data.primary_domain
206 }
207
208 #[must_use]
210 pub fn backend(&self) -> &B {
211 &self.data.backend
212 }
213
214 pub async fn admin(&self) -> ServerDatabase<B> {
216 let db = self.storage.admin().await;
217 ServerDatabase {
218 server: self.clone(),
219 db,
220 }
221 }
222
223 pub(crate) async fn hosted(&self) -> ServerDatabase<B> {
224 let db = self.storage.database::<Hosted>("_hosted").await.unwrap();
225 ServerDatabase {
226 server: self.clone(),
227 db,
228 }
229 }
230
231 pub(crate) fn custom_api_dispatcher(&self, name: &ApiName) -> Option<Arc<dyn AnyHandler<B>>> {
232 let dispatchers = self.data.custom_apis.read();
233 dispatchers.get(name).cloned()
234 }
235
236 pub async fn install_self_signed_certificate(&self, overwrite: bool) -> Result<(), Error> {
238 let keypair = KeyPair::new_self_signed(&self.data.primary_domain);
239
240 if self.certificate_chain().await.is_ok() && !overwrite {
241 return Err(Error::Core(bonsaidb_core::Error::other("bonsaidb-server config", "Certificate already installed. Enable overwrite if you wish to replace the existing certificate.")));
242 }
243
244 self.install_certificate(keypair.certificate_chain(), keypair.private_key())
245 .await?;
246
247 tokio::fs::write(
248 self.pinned_certificate_path(),
249 keypair.end_entity_certificate().as_ref(),
250 )
251 .await?;
252
253 Ok(())
254 }
255
256 #[cfg(feature = "pem")]
258 pub async fn install_pem_certificate(
259 &self,
260 certificate_chain: &[u8],
261 private_key: &[u8],
262 ) -> Result<(), Error> {
263 let private_key = match pem::parse(private_key) {
264 Ok(pem) => PrivateKey::unchecked_from_der(pem.contents()),
265 Err(_) => PrivateKey::from_der(private_key)?,
266 };
267 let certificates = pem::parse_many(certificate_chain)?
268 .into_iter()
269 .map(|entry| fabruic::Certificate::unchecked_from_der(entry.contents()))
270 .collect::<Vec<_>>();
271
272 self.install_certificate(
273 &CertificateChain::unchecked_from_certificates(certificates),
274 &private_key,
275 )
276 .await
277 }
278
279 pub async fn install_certificate(
281 &self,
282 certificate_chain: &CertificateChain,
283 private_key: &PrivateKey,
284 ) -> Result<(), Error> {
285 let db = self.hosted().await;
286
287 TlsCertificate::entry_async(&self.data.primary_domain, &db)
288 .update_with(|cert: &mut TlsCertificate| {
289 cert.certificate_chain = certificate_chain.clone();
290 cert.private_key = SerializablePrivateKey(private_key.clone());
291 })
292 .or_insert_with(|| TlsCertificate {
293 domains: vec![self.data.primary_domain.clone()],
294 private_key: SerializablePrivateKey(private_key.clone()),
295 certificate_chain: certificate_chain.clone(),
296 })
297 .await?;
298
299 self.refresh_certified_key().await?;
300
301 let pinned_certificate_path = self.pinned_certificate_path();
302 if pinned_certificate_path.exists() {
303 tokio::fs::remove_file(&pinned_certificate_path).await?;
304 }
305
306 Ok(())
307 }
308
309 async fn refresh_certified_key(&self) -> Result<(), Error> {
310 let certificate = self.tls_certificate().await?;
311
312 let mut cached_key = self.data.primary_tls_key.lock();
313 let private_key = rustls::PrivateKey(
314 fabruic::dangerous::PrivateKey::as_ref(&certificate.private_key.0).to_vec(),
315 );
316 let private_key = rustls::sign::any_ecdsa_type(&Arc::new(private_key))?;
317
318 let certificates = certificate
319 .certificate_chain
320 .iter()
321 .map(|cert| rustls::Certificate(cert.as_ref().to_vec()))
322 .collect::<Vec<_>>();
323
324 let certified_key = Arc::new(CertifiedKey::new(certificates, private_key));
325 *cached_key = Some(certified_key);
326 Ok(())
327 }
328
329 async fn tls_certificate(&self) -> Result<TlsCertificate, Error> {
330 let db = self.hosted().await;
331 let (_, certificate) = db
332 .view::<TlsCertificatesByDomain>()
333 .with_key(&self.data.primary_domain)
334 .query_with_collection_docs()
335 .await?
336 .documents
337 .into_iter()
338 .next()
339 .ok_or_else(|| {
340 Error::Core(bonsaidb_core::Error::other(
341 "bonsaidb-server config",
342 format!("no certificate found for {}", self.data.primary_domain),
343 ))
344 })?;
345 Ok(certificate.contents)
346 }
347
348 pub async fn certificate_chain(&self) -> Result<CertificateChain, Error> {
350 let db = self.hosted().await;
351 if let Some(mapping) = db
352 .view::<TlsCertificatesByDomain>()
353 .with_key(&self.data.primary_domain)
354 .query()
355 .await?
356 .into_iter()
357 .next()
358 {
359 Ok(mapping.value)
360 } else {
361 Err(Error::Core(bonsaidb_core::Error::other(
362 "bonsaidb-server config",
363 format!("no certificate found for {}", self.data.primary_domain),
364 )))
365 }
366 }
367
368 pub async fn listen_on(&self, config: impl Into<BonsaiListenConfig>) -> Result<(), Error> {
381 let config = config.into();
382 let certificate = self.tls_certificate().await?;
383 let keypair =
384 KeyPair::from_parts(certificate.certificate_chain, certificate.private_key.0)?;
385 let mut builder = Endpoint::builder();
386 builder.set_protocols([CURRENT_PROTOCOL_VERSION.as_bytes().to_vec()]);
387 builder.set_address(config.address);
388 builder.set_max_idle_timeout(None)?;
389 builder.set_server_key_pair(Some(keypair));
390 builder.set_reuse_address(config.reuse_address);
391 let mut server = builder.build()?;
392
393 let mut shutdown_watcher = self
394 .data
395 .shutdown
396 .watcher()
397 .await
398 .expect("server already shut down");
399
400 while let Some(incoming) = tokio::select! {
401 shutdown_state = shutdown_watcher.wait_for_shutdown() => {
402 drop(server.close_incoming());
403 if matches!(shutdown_state, ShutdownState::GracefulShutdown) {
404 server.wait_idle().await;
405 }
406 None
407 },
408 msg = server.next() => msg
409 } {
410 let address = incoming.remote_address();
411 let connection = match incoming.accept::<()>().await {
412 Ok(connection) => connection,
413 Err(err) => {
414 log::error!("[server] error on incoming connection from {address}: {err:?}");
415 continue;
416 }
417 };
418 let task_self = self.clone();
419 tokio::spawn(async move {
420 if let Err(err) = task_self.handle_bonsai_connection(connection).await {
421 log::error!("[server] closing connection {address}: {err:?}");
422 }
423 });
424 }
425
426 Ok(())
427 }
428
429 #[must_use]
431 pub fn connected_clients(&self) -> Vec<ConnectedClient<B>> {
432 let clients = self.data.clients.read();
433 clients.values().cloned().collect()
434 }
435
436 pub fn broadcast<Api: api::Api>(&self, response: &Api::Response) {
438 let clients = self.data.clients.read();
439 for client in clients.values() {
440 drop(client.send::<Api>(None, response));
442 }
443 }
444
445 async fn initialize_client(
446 &self,
447 transport: Transport,
448 address: SocketAddr,
449 sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
450 ) -> Option<OwnedClient<B>> {
451 if !self.data.default_session.allowed_to(
452 bonsaidb_resource_name(),
453 &BonsaiAction::Server(ServerAction::Connect),
454 ) {
455 return None;
456 }
457
458 let client = loop {
459 let next_id = CONNECTED_CLIENT_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
460 let mut clients = self.data.clients.write();
461 if let hash_map::Entry::Vacant(e) = clients.entry(next_id) {
462 let client = OwnedClient::new(
463 next_id,
464 address,
465 transport,
466 sender,
467 self.clone(),
468 self.data.default_session.clone(),
469 );
470 e.insert(client.clone());
471 break client;
472 }
473 };
474
475 match self.data.backend.client_connected(&client, self).await {
476 Ok(ConnectionHandling::Accept) => Some(client),
477 Ok(ConnectionHandling::Reject) => None,
478 Err(err) => {
479 log::error!(
480 "[server] Rejecting connection due to error in `client_connected`: {err:?}"
481 );
482 None
483 }
484 }
485 }
486
487 async fn disconnect_client(&self, id: u32) {
488 let removed_client = {
489 let mut clients = self.data.clients.write();
490 clients.remove(&id)
491 };
492
493 if let Some(client) = removed_client {
494 client.set_disconnected();
495 for session in client.all_sessions::<Vec<_>>() {
496 if let Err(err) = self
497 .data
498 .backend
499 .client_session_ended(session, &client, true, self)
500 .await
501 {
502 log::error!("[server] Error in `client_session_ended`: {err:?}");
503 }
504 }
505
506 if let Err(err) = self.data.backend.client_disconnected(client, self).await {
507 log::error!("[server] Error in `client_disconnected`: {err:?}");
508 }
509 }
510 }
511
512 async fn handle_bonsai_connection(
513 &self,
514 mut connection: fabruic::Connection<()>,
515 ) -> Result<(), Error> {
516 if let Some(incoming) = connection.next().await {
517 let incoming = match incoming {
518 Ok(incoming) => incoming,
519 Err(err) => {
520 log::error!("[server] Error establishing a stream: {err:?}");
521 return Ok(());
522 }
523 };
524
525 match incoming
526 .accept::<networking::Payload, networking::Payload>()
527 .await
528 {
529 Ok((sender, receiver)) => {
530 let (api_response_sender, api_response_receiver) = flume::unbounded();
531 if let Some(disconnector) = self
532 .initialize_client(
533 Transport::Bonsai,
534 connection.remote_address(),
535 api_response_sender,
536 )
537 .await
538 {
539 let task_sender = sender.clone();
540 tokio::spawn(async move {
541 while let Ok((session_id, name, bytes)) =
542 api_response_receiver.recv_async().await
543 {
544 if task_sender
545 .send(&Payload {
546 id: None,
547 session_id,
548 name,
549 value: Ok(bytes),
550 })
551 .is_err()
552 {
553 break;
554 }
555 }
556 });
557
558 let task_self = self.clone();
559 let Some(shutdown) = self.data.shutdown.watcher().await else {
560 return Ok(());
561 };
562 tokio::spawn(async move {
563 if let Err(err) = task_self
564 .handle_stream(disconnector, sender, receiver, shutdown)
565 .await
566 {
567 log::error!("[server] Error handling stream: {err:?}");
568 }
569 });
570 } else {
571 log::error!("[server] Backend rejected connection.");
572 return Ok(());
573 }
574 }
575 Err(err) => {
576 log::error!("[server] Error accepting incoming stream: {err:?}");
577 return Ok(());
578 }
579 }
580 }
581 Ok(())
582 }
583
584 async fn handle_client_requests(
585 &self,
586 client: ConnectedClient<B>,
587 request_receiver: flume::Receiver<Payload>,
588 response_sender: flume::Sender<Payload>,
589 mut shutdown: ShutdownStateWatcher,
590 ) {
591 let notify = Arc::new(Notify::new());
592 let requests_in_queue = Arc::new(AtomicUsize::new(0));
593 loop {
594 let current_requests = requests_in_queue.load(Ordering::SeqCst);
595 if current_requests == self.data.client_simultaneous_request_limit {
596 notify.notified().await;
598 } else if requests_in_queue
599 .compare_exchange(
600 current_requests,
601 current_requests + 1,
602 Ordering::SeqCst,
603 Ordering::SeqCst,
604 )
605 .is_ok()
606 {
607 let payload = 'payload: loop {
608 tokio::select! {
609 payload = request_receiver.recv_async() => {
610 if let Ok(payload) = payload {
611 break 'payload payload
612 }
613
614 return
615 },
616 state = shutdown.wait_for_shutdown() => {
617 if matches!(state, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
618 return
619 }
620 }
621 }
622 };
623 let session_id = payload.session_id;
624 let id = payload.id;
625 let task_sender = response_sender.clone();
626
627 let notify = notify.clone();
628 let requests_in_queue = requests_in_queue.clone();
629 self.handle_request_through_worker(
630 payload,
631 move |name, value| async move {
632 drop(task_sender.send(Payload {
633 session_id,
634 id,
635 name,
636 value,
637 }));
638
639 requests_in_queue.fetch_sub(1, Ordering::SeqCst);
640
641 notify.notify_one();
642
643 Ok(())
644 },
645 client.clone(),
646 )
647 .unwrap();
648 }
649 }
650 }
651
652 fn handle_request_through_worker<
653 F: FnOnce(ApiName, Result<Bytes, bonsaidb_core::Error>) -> R + Send + 'static,
654 R: Future<Output = Result<(), Error>> + Send,
655 >(
656 &self,
657 request: Payload,
658 callback: F,
659 client: ConnectedClient<B>,
660 ) -> Result<(), Error> {
661 let (result_sender, result_receiver) = oneshot::channel();
662 let session = client
663 .session(request.session_id)
664 .unwrap_or_else(|| self.data.default_session.clone());
665 self.data
666 .request_processor
667 .send(ClientRequest::<B>::new(
668 request,
669 self.clone(),
670 client,
671 session,
672 result_sender,
673 ))
674 .map_err(|_| Error::InternalCommunication)?;
675 tokio::spawn(async move {
676 let (name, result) = result_receiver.await?;
677 callback(name, result).await?;
684 Result::<(), Error>::Ok(())
685 });
686 Ok(())
687 }
688
689 async fn handle_stream(
690 &self,
691 client: OwnedClient<B>,
692 sender: fabruic::Sender<Payload>,
693 mut receiver: fabruic::Receiver<Payload>,
694 mut shutdown: ShutdownStateWatcher,
695 ) -> Result<(), Error> {
696 let (payload_sender, payload_receiver) = flume::unbounded();
697 tokio::spawn({
698 let mut shutdown = shutdown.clone();
699 async move {
700 'stream: loop {
701 let payload = loop {
702 tokio::select! {
703 payload = payload_receiver.recv_async() => {
704 if let Ok(payload) = payload {
705 break payload
706 }
707 break 'stream
708 }
709 shutdown = shutdown.wait_for_shutdown() => {
710 if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
711 break 'stream
712 }
713 }
714 }
715 };
716 if sender.send(&payload).is_err() {
717 break;
718 }
719 }
720 }
721 });
722
723 let (request_sender, request_receiver) =
724 flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
725 let task_self = self.clone();
726 tokio::spawn({
727 let shutdown = shutdown.clone();
728 async move {
729 task_self
730 .handle_client_requests(
731 client.clone(),
732 request_receiver,
733 payload_sender,
734 shutdown,
735 )
736 .await;
737 }
738 });
739
740 loop {
741 let payload = loop {
742 tokio::select! {
743 payload = receiver.next() => {
744 if let Some(payload) = payload {
745 break payload
746 }
747
748 receiver.finish().await?;
749
750 return Ok(());
751 }
752 shutdown = shutdown.wait_for_shutdown() => {
753 if matches!(shutdown, ShutdownState::Shutdown | ShutdownState::GracefulShutdown) {
754 return Ok(());
755 }
756 }
757 }
758 };
759 drop(request_sender.send_async(payload?).await);
760 }
761 }
762
763 pub async fn shutdown(&self, timeout: Option<Duration>) -> Result<(), Error> {
768 if let Some(timeout) = timeout {
769 self.data.shutdown.graceful_shutdown(timeout).await;
770 } else {
771 self.data.shutdown.shutdown().await;
772 }
773
774 Ok(())
775 }
776
777 pub async fn listen_for_shutdown(&self) -> Result<(), Error> {
780 const GRACEFUL_SHUTDOWN: usize = 1;
781 const TERMINATE: usize = 2;
782
783 enum SignalShutdownState {
784 Running,
785 ShuttingDown(flume::Receiver<()>),
786 }
787
788 let shutdown_state = Arc::new(async_lock::Mutex::new(SignalShutdownState::Running));
789 let flag = Arc::new(AtomicUsize::default());
790 let register_hook = |flag: &Arc<AtomicUsize>| {
791 signal_hook::flag::register_usize(SIGINT, flag.clone(), GRACEFUL_SHUTDOWN)?;
792 signal_hook::flag::register_usize(SIGTERM, flag.clone(), TERMINATE)?;
793 #[cfg(not(windows))]
794 signal_hook::flag::register_usize(SIGQUIT, flag.clone(), TERMINATE)?;
795 Result::<(), std::io::Error>::Ok(())
796 };
797 if let Err(err) = register_hook(&flag) {
798 log::error!("Error installing signals for graceful shutdown: {err:?}");
799 tokio::time::sleep(Duration::MAX).await;
800 } else {
801 loop {
802 match flag.load(Ordering::Relaxed) {
803 0 => {
804 }
806 GRACEFUL_SHUTDOWN => {
807 let mut state = fast_async_lock!(shutdown_state);
808 match *state {
809 SignalShutdownState::Running => {
810 log::error!("Interrupt signal received. Shutting down gracefully.");
811 let task_server = self.clone();
812 let (shutdown_sender, shutdown_receiver) = flume::bounded(1);
813 tokio::task::spawn(async move {
814 task_server.shutdown(Some(Duration::from_secs(30))).await?;
815 let _: Result<_, _> = shutdown_sender.send(());
816 Result::<(), Error>::Ok(())
817 });
818 *state = SignalShutdownState::ShuttingDown(shutdown_receiver);
819 }
820 SignalShutdownState::ShuttingDown(_) => {
821 break;
823 }
824 }
825 }
826 TERMINATE => {
827 log::error!("Quit signal received. Shutting down.");
828 break;
829 }
830 _ => unreachable!(),
831 }
832
833 let state = fast_async_lock!(shutdown_state);
834 if let SignalShutdownState::ShuttingDown(receiver) = &*state {
835 if receiver.try_recv().is_ok() {
836 return Ok(());
838 }
839 } else if self.data.shutdown.should_shutdown() {
840 return Ok(());
841 }
842
843 tokio::time::sleep(Duration::from_millis(300)).await;
844 }
845 self.shutdown(None).await?;
846 }
847
848 Ok(())
849 }
850}
851
852impl<B: Backend> Deref for CustomServer<B> {
853 type Target = AsyncStorage;
854
855 fn deref(&self) -> &Self::Target {
856 &self.storage
857 }
858}
859
/// A client request queued for the worker tasks spawned in `CustomServer::open`.
#[derive(Debug)]
struct ClientRequest<B: Backend> {
    /// The request payload. `ClientRequest::new` stores `Some`; the worker handling
    /// the request `take`s it.
    request: Option<Payload>,
    /// The client that sent the request.
    client: ConnectedClient<B>,
    /// The session the request is executed as.
    session: Session,
    /// The server that received the request.
    server: CustomServer<B>,
    /// Receives the request's API name and result once a worker has handled it.
    result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
}
868
869impl<B: Backend> ClientRequest<B> {
870 pub fn new(
871 request: Payload,
872 server: CustomServer<B>,
873 client: ConnectedClient<B>,
874 session: Session,
875 result_sender: oneshot::Sender<(ApiName, Result<Bytes, bonsaidb_core::Error>)>,
876 ) -> Self {
877 Self {
878 request: Some(request),
879 server,
880 client,
881 session,
882 result_sender,
883 }
884 }
885}
886
887impl<B: Backend> HasSession for CustomServer<B> {
888 fn session(&self) -> Option<&Session> {
889 self.storage.session()
890 }
891}
892
893#[async_trait]
894impl<B: Backend> AsyncStorageConnection for CustomServer<B> {
895 type Authenticated = Self;
896 type Database = ServerDatabase<B>;
897
898 async fn admin(&self) -> Self::Database {
899 self.database::<Admin>(ADMIN_DATABASE_NAME).await.unwrap()
900 }
901
902 async fn create_database_with_schema(
903 &self,
904 name: &str,
905 schema: SchemaName,
906 only_if_needed: bool,
907 ) -> Result<(), bonsaidb_core::Error> {
908 self.storage
909 .create_database_with_schema(name, schema, only_if_needed)
910 .await
911 }
912
913 async fn database<DB: Schema>(
914 &self,
915 name: &str,
916 ) -> Result<Self::Database, bonsaidb_core::Error> {
917 let db = self.storage.database::<DB>(name).await?;
918 Ok(ServerDatabase {
919 server: self.clone(),
920 db,
921 })
922 }
923
924 async fn delete_database(&self, name: &str) -> Result<(), bonsaidb_core::Error> {
925 self.storage.delete_database(name).await
926 }
927
928 async fn list_databases(&self) -> Result<Vec<connection::Database>, bonsaidb_core::Error> {
929 self.storage.list_databases().await
930 }
931
932 async fn list_available_schemas(&self) -> Result<Vec<SchemaSummary>, bonsaidb_core::Error> {
933 self.storage.list_available_schemas().await
934 }
935
936 async fn create_user(&self, username: &str) -> Result<u64, bonsaidb_core::Error> {
937 self.storage.create_user(username).await
938 }
939
940 async fn delete_user<'user, U: Nameable<'user, u64> + Send + Sync>(
941 &self,
942 user: U,
943 ) -> Result<(), bonsaidb_core::Error> {
944 self.storage.delete_user(user).await
945 }
946
947 #[cfg(feature = "password-hashing")]
948 async fn set_user_password<'user, U: Nameable<'user, u64> + Send + Sync>(
949 &self,
950 user: U,
951 password: bonsaidb_core::connection::SensitiveString,
952 ) -> Result<(), bonsaidb_core::Error> {
953 self.storage.set_user_password(user, password).await
954 }
955
956 #[cfg(any(feature = "token-authentication", feature = "password-hashing"))]
957 async fn authenticate(
958 &self,
959 authentication: bonsaidb_core::connection::Authentication,
960 ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
961 let storage = self.storage.authenticate(authentication).await?;
962 Ok(Self {
963 data: self.data.clone(),
964 storage,
965 })
966 }
967
968 async fn assume_identity(
969 &self,
970 identity: IdentityReference<'_>,
971 ) -> Result<Self::Authenticated, bonsaidb_core::Error> {
972 let storage = self.storage.assume_identity(identity).await?;
973 Ok(Self {
974 data: self.data.clone(),
975 storage,
976 })
977 }
978
979 async fn add_permission_group_to_user<
980 'user,
981 'group,
982 U: Nameable<'user, u64> + Send + Sync,
983 G: Nameable<'group, u64> + Send + Sync,
984 >(
985 &self,
986 user: U,
987 permission_group: G,
988 ) -> Result<(), bonsaidb_core::Error> {
989 self.storage
990 .add_permission_group_to_user(user, permission_group)
991 .await
992 }
993
994 async fn remove_permission_group_from_user<
995 'user,
996 'group,
997 U: Nameable<'user, u64> + Send + Sync,
998 G: Nameable<'group, u64> + Send + Sync,
999 >(
1000 &self,
1001 user: U,
1002 permission_group: G,
1003 ) -> Result<(), bonsaidb_core::Error> {
1004 self.storage
1005 .remove_permission_group_from_user(user, permission_group)
1006 .await
1007 }
1008
1009 async fn add_role_to_user<
1010 'user,
1011 'group,
1012 U: Nameable<'user, u64> + Send + Sync,
1013 G: Nameable<'group, u64> + Send + Sync,
1014 >(
1015 &self,
1016 user: U,
1017 role: G,
1018 ) -> Result<(), bonsaidb_core::Error> {
1019 self.storage.add_role_to_user(user, role).await
1020 }
1021
1022 async fn remove_role_from_user<
1023 'user,
1024 'group,
1025 U: Nameable<'user, u64> + Send + Sync,
1026 G: Nameable<'group, u64> + Send + Sync,
1027 >(
1028 &self,
1029 user: U,
1030 role: G,
1031 ) -> Result<(), bonsaidb_core::Error> {
1032 self.storage.remove_role_from_user(user, role).await
1033 }
1034}
1035
/// Shared, mutex-guarded map from `String` to rustls certified keys.
///
/// NOTE(review): in this file it is only stored on `Data` under the `acme` feature;
/// presumably it holds TLS-ALPN challenge certificates keyed by domain. Confirm in
/// the `acme` module.
#[derive(Default)]
struct AlpnKeys(Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>);
1038
1039impl Debug for AlpnKeys {
1040 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1041 f.debug_tuple("AlpnKeys").finish()
1042 }
1043}
1044
1045impl Deref for AlpnKeys {
1046 type Target = Arc<Mutex<HashMap<String, Arc<rustls::sign::CertifiedKey>>>>;
1047
1048 fn deref(&self) -> &Self::Target {
1049 &self.0
1050 }
1051}