use std::collections::HashSet;
use std::sync::Arc;
use async_channel::Receiver;
use reqwest::ClientBuilder;
use surrealdb_core::cnf::SURREALDB_USER_AGENT;
use tokio::sync::watch;
use url::Url;
use super::{Client, RouterState};
use crate::conn::{Route, Router};
use crate::engine::{SessionError, session_error_to_error};
use crate::method::BoxFuture;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use crate::opt::Tls;
use crate::opt::{Endpoint, WaitFor};
use crate::{Error, ExtraFeatures, Result, SessionClone, SessionId, Surreal, conn};
/// Builds the `reqwest` client used by the HTTP engine and checks that the
/// server answers on its `health` endpoint.
///
/// The port is taken from the URL, falling back to the scheme's well-known
/// port, or `8000` when the scheme has none.
///
/// When the host is a domain name, or the URL has no host (in which case
/// `localhost` is assumed), the name is resolved once. Each resulting address
/// is then tried in turn, pinning the client to that address with
/// `ClientBuilder::resolve`. IP-literal hosts are contacted directly without a
/// DNS lookup. The first client whose health check succeeds is returned.
///
/// # Errors
///
/// Returns an error if:
/// - DNS resolution fails;
/// - the `health` URL cannot be built;
/// - no candidate works. In that case the error comes from the last failed
///   client build or health check, or is "No addresses available" when the
///   lookup yielded no addresses.
pub(crate) async fn create_client(
	base_url: &Url,
	#[cfg(any(feature = "native-tls", feature = "rustls"))] tls_config: Option<&Tls>,
) -> Result<reqwest::Client> {
	let headers = super::default_headers();
	let port = base_url.port_or_known_default().unwrap_or(8000);
	// `Url::domain` is `None` both for IP literals and for host-less URLs. Only
	// the latter should fall back to `localhost`. An IP literal must be
	// contacted as-is rather than replaced by whatever `localhost` resolves to.
	let hostname = match base_url.domain() {
		Some(domain) => Some(domain),
		None if base_url.has_host() => None,
		None => Some("localhost"),
	};
	// Each candidate optionally pins `hostname` to one resolved address.
	let candidates: Vec<Option<(&str, std::net::SocketAddr)>> = match hostname {
		Some(hostname) => tokio::net::lookup_host((hostname, port))
			.await
			.map_err(|error| {
				Error::internal(format!("DNS resolution failed for {hostname}:{port}; {error}"))
			})?
			.map(|addr| Some((hostname, addr)))
			.collect(),
		None => vec![None],
	};
	// The health endpoint is the same for every candidate, so build it once.
	let health_url = base_url.join("health").map_err(crate::std_error_to_types_error)?;
	let mut last_error = None;
	for pin in candidates {
		let mut builder = ClientBuilder::new().default_headers(headers.clone());
		if let Some((domain, addr)) = pin {
			builder = builder.resolve(domain, addr);
		}
		#[cfg(any(feature = "native-tls", feature = "rustls"))]
		if let Some(tls) = tls_config {
			builder = match tls {
				#[cfg(feature = "native-tls")]
				Tls::Native(config) => builder.use_preconfigured_tls(config.clone()),
				#[cfg(feature = "rustls")]
				Tls::Rust(config) => builder.use_preconfigured_tls(config.clone()),
			};
		}
		let client = match builder.build() {
			Ok(client) => client,
			Err(error) => {
				last_error = Some(Error::internal(error.to_string()));
				continue;
			}
		};
		let req = client
			.get(health_url.clone())
			.header(reqwest::header::USER_AGENT, &*SURREALDB_USER_AGENT);
		match super::health(req).await {
			Ok(()) => return Ok(client),
			// Remember the failure and move on to the next candidate address.
			Err(error) => last_error = Some(error),
		}
	}
	Err(last_error.unwrap_or_else(|| Error::internal("No addresses available".to_string())))
}
/// Registers the HTTP `Client` as a connection type. The impl body is empty,
/// so it relies only on the trait's provided items. Connection setup is done
/// by the `conn::Sealed` impl.
impl crate::Connection for Client {}
impl conn::Sealed for Client {
	/// Opens an HTTP connection to `address`.
	///
	/// 1. Builds and health-checks the `reqwest` client via `create_client`.
	/// 2. Creates the request channel. `capacity` bounds it, and `0` makes it
	///    unbounded.
	/// 3. Spawns `run_router` as a background task to serve requests sent
	///    through the returned router.
	///
	/// A fresh `SessionClone` is created when none is supplied.
	// NOTE(review): presumably needed because a type in this signature (e.g.
	// `crate::SessionClone`) is less visible than the trait method — confirm.
	#[allow(private_interfaces)]
	fn connect(
		address: Endpoint,
		capacity: usize,
		session_clone: Option<crate::SessionClone>,
	) -> BoxFuture<'static, Result<Surreal<Self>>> {
		Box::pin(async move {
			let base_url = address.url;
			#[cfg(any(feature = "native-tls", feature = "rustls"))]
			let client = create_client(&base_url, address.config.tls_config.as_ref()).await?;
			#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
			let client = create_client(&base_url).await?;
			// The TLS settings are no longer borrowed, so the endpoint config can be
			// moved into the router instead of being cloned up front.
			let config = address.config;
			let (route_tx, route_rx) = match capacity {
				0 => async_channel::unbounded(),
				capacity => async_channel::bounded(capacity),
			};
			let session_clone = session_clone.unwrap_or_else(SessionClone::new);
			tokio::spawn(run_router(client, base_url, route_rx, session_clone.receiver.clone()));
			let router = Router {
				features: HashSet::from([ExtraFeatures::Backup]),
				config,
				sender: route_tx,
			};
			let waiter = watch::channel(Some(WaitFor::Connection));
			Ok((router, waiter, session_clone).into())
		})
	}
}
pub(crate) async fn run_router(
client: reqwest::Client,
base_url: url::Url,
route_rx: Receiver<Route>,
session_rx: Receiver<SessionId>,
) {
let state = Arc::new(RouterState::new(client, base_url));
loop {
tokio::select! {
biased;
session = session_rx.recv() => {
let Ok(session_id) = session else {
break
};
match session_id {
SessionId::Initial(session_id) => {
state.handle_session_initial(session_id).await;
}
SessionId::Clone { old, new } => {
state.handle_session_clone(old, new).await;
}
SessionId::Drop(session_id) => {
state.handle_session_drop(session_id).await;
}
}
}
route = route_rx.recv() => {
let Ok(route) = route else {
break
};
let session_id = route.request.session_id;
let command = route.request.command.clone();
let session_state = match state.sessions.get(&session_id) {
Some(Ok(state)) => state,
Some(Err(error)) => {
route.response.send(Err(session_error_to_error(error))).await.ok();
continue;
}
None => {
let error = session_error_to_error(SessionError::NotFound(session_id));
route.response.send(Err(error)).await.ok();
continue;
}
};
let router_state = state.clone();
tokio::spawn(async move {
let result = super::router(
route.request,
&router_state.base_url,
&router_state.client,
&session_state,
)
.await;
if result.is_ok() && command.replayable() {
session_state.replay.push(command);
}
let db_result = result;
route.response.send(db_result).await.ok();
});
}
}
}
}