bonsaidb-client 0.5.0

Client for accessing BonsaiDb servers.
Documentation
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Duration;

use bonsaidb_core::api::ApiName;
use bonsaidb_core::networking::Payload;
use bonsaidb_utils::fast_async_lock;
use fabruic::{self, Certificate, Endpoint};
use flume::Receiver;
use futures::StreamExt;
use url::Url;

use super::PendingRequest;
use crate::client::{
    disconnect_pending_requests, AnyApiCallback, ConnectionInfo, OutstandingRequestMapHandle,
};
use crate::Error;

/// This function will establish a connection and try to keep it active. If an
/// error occurs, any queries that come in while reconnecting will have the
/// error replayed to them.
pub(super) async fn reconnecting_client_loop(
    mut server: ConnectionInfo,
    protocol_version: &'static str,
    certificate: Option<Certificate>,
    request_receiver: Receiver<PendingRequest>,
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
    connection_counter: Arc<AtomicU32>,
) -> Result<(), Error> {
    if server.url.port().is_none() && server.url.scheme() == "bonsaidb" {
        let _: Result<_, _> = server.url.set_port(Some(5645));
    }

    server.subscribers.clear();
    let mut pending_error = None;
    while let Ok(request) = request_receiver.recv_async().await {
        if let Some(pending_error) = pending_error.take() {
            drop(request.responder.send(Err(pending_error)));
            continue;
        }
        connection_counter.fetch_add(1, Ordering::SeqCst);
        if let Err((failed_request, Some(err))) = connect_and_process(
            &server.url,
            protocol_version,
            certificate.as_ref(),
            request,
            &request_receiver,
            custom_apis.clone(),
            server.connect_timeout,
        )
        .await
        {
            if let Some(failed_request) = failed_request {
                drop(failed_request.responder.send(Err(err)));
            } else {
                pending_error = Some(err);
            }
        }
    }

    Ok(())
}

/// Connects to `url`, sends `initial_request`, and then services further requests from
/// `request_receiver` until either the connection or the request channel ends.
///
/// Responses are read by a spawned task running `process`, while the current task forwards
/// requests with `process_requests`. The spawned reader is aborted before this function
/// returns, so it never outlives this connection attempt.
///
/// # Errors
///
/// - `(Some(initial_request), Some(err))` when connecting fails, times out after
///   `connect_timeout`, or the initial request cannot be sent. The caller gets the request
///   back so it can answer it directly.
/// - `(None, err)` once an established connection ends. Outstanding requests have already
///   been notified through `disconnect_pending_requests`, which receives the error by
///   `&mut` and may have consumed it (leaving `None`).
///
/// Both `process` and `process_requests` only ever complete with an error, so the trailing
/// `Ok(())` is not expected to be reached.
async fn connect_and_process(
    url: &Url,
    protocol_version: &str,
    certificate: Option<&Certificate>,
    initial_request: PendingRequest,
    request_receiver: &Receiver<PendingRequest>,
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
    connect_timeout: Duration,
) -> Result<(), (Option<PendingRequest>, Option<Error>)> {
    let (_connection, payload_sender, payload_receiver) =
        match tokio::time::timeout(connect_timeout, connect(url, certificate, protocol_version))
            .await
        {
            Ok(Ok(result)) => result,
            Ok(Err(err)) => return Err((Some(initial_request), Some(err))),
            Err(_) => return Err((Some(initial_request), Some(Error::connect_timeout()))),
        };

    let outstanding_requests = OutstandingRequestMapHandle::default();
    let mut request_processor = tokio::spawn(process(
        outstanding_requests.clone(),
        payload_receiver,
        custom_apis,
    ));

    {
        // Hold the map lock across the send and the insert, exactly as `process_requests`
        // does. The reader task is already running with this map; sending before
        // registering the request (without the lock) would presumably let a fast response
        // be processed before its request id is in the map.
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
        if let Err(err) = payload_sender.send(&initial_request.request) {
            // Nothing was registered, so stop the reader rather than leaving it detached.
            request_processor.abort();
            return Err((Some(initial_request), Some(Error::from(err))));
        }
        outstanding_requests.insert(
            initial_request
                .request
                .id
                .expect("all requests require ids"),
            initial_request,
        );
    }

    let result = futures::try_join!(
        process_requests(
            outstanding_requests.clone(),
            request_receiver,
            payload_sender
        ),
        // Await through a `&mut` borrow (JoinHandle is Unpin) so the handle is still
        // available for `abort` below.
        async { (&mut request_processor).await.map_err(|_| Error::disconnected())? }
    );

    // Whichever half finished first, the reader task must not outlive this attempt (e.g.
    // keep dispatching to `custom_apis` after a reconnect). Aborting an already-finished
    // task has no effect.
    request_processor.abort();

    if let Err(err) = result {
        let mut pending_error = Some(err);
        // Our socket was disconnected, clear the outstanding requests before returning.
        disconnect_pending_requests(&outstanding_requests, &mut pending_error).await;
        return Err((None, pending_error));
    }

    Ok(())
}

async fn process_requests(
    outstanding_requests: OutstandingRequestMapHandle,
    request_receiver: &Receiver<PendingRequest>,
    payload_sender: fabruic::Sender<Payload>,
) -> Result<(), Error> {
    while let Ok(client_request) = request_receiver.recv_async().await {
        let mut outstanding_requests = fast_async_lock!(outstanding_requests);
        payload_sender.send(&client_request.request)?;
        outstanding_requests.insert(
            client_request.request.id.expect("all requests require ids"),
            client_request,
        );
    }

    drop(payload_sender.finish());

    // Return an error to make sure try_join returns.
    Err(Error::disconnected())
}

pub async fn process(
    outstanding_requests: OutstandingRequestMapHandle,
    mut payload_receiver: fabruic::Receiver<Payload>,
    custom_apis: Arc<HashMap<ApiName, Option<Arc<dyn AnyApiCallback>>>>,
) -> Result<(), Error> {
    while let Some(payload) = payload_receiver.next().await {
        let payload = payload?;
        super::process_response_payload(payload, &outstanding_requests, &custom_apis).await;
    }

    Err(Error::disconnected())
}

async fn connect(
    url: &Url,
    certificate: Option<&Certificate>,
    protocol_version: &str,
) -> Result<
    (
        fabruic::Connection<()>,
        fabruic::Sender<Payload>,
        fabruic::Receiver<Payload>,
    ),
    Error,
> {
    let mut endpoint = Endpoint::builder();
    endpoint
        .set_max_idle_timeout(None)
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
    endpoint.set_protocols([protocol_version.as_bytes().to_vec()]);
    let endpoint = endpoint
        .build()
        .map_err(|err| Error::Core(bonsaidb_core::Error::other("quic", err)))?;
    let connecting = if let Some(certificate) = certificate {
        endpoint.connect_pinned(url, certificate, None).await?
    } else {
        endpoint.connect(url).await?
    };

    let connection = connecting.accept::<()>().await.map_err(|err| {
        if matches!(err, fabruic::error::Connecting::ProtocolMismatch) {
            Error::ProtocolVersionMismatch
        } else {
            Error::from(err)
        }
    })?;
    let (sender, receiver) = connection.open_stream(&()).await?;

    Ok((connection, sender, receiver))
}