use std::path::PathBuf;
use std::str::FromStr;
use std::time::Duration;
use futures_core::future::BoxFuture;
use log::LevelFilter;
use sqlx_core::connection::{ConnectOptions, LogSettings};
use sqlx_core::error::Error;
use crate::connection::SpgConnection;
/// Connection options for an `spg` embedded database.
///
/// `path: None` selects an in-memory database; `Some(path)` opens that file.
/// Deriving `Clone` clones the `Arc` in `shared`, so every clone of one set of
/// options reaches the same database handle once `connect` has opened it.
#[derive(Debug, Clone)]
pub struct SpgConnectOptions {
/// File to open; `None` means an in-memory database.
pub path: Option<PathBuf>,
/// Statement-logging configuration, adjusted via `log_statements` / `log_slow_statements`.
pub log_settings: LogSettings,
/// Lazily initialized database handle shared by all clones of these options.
pub(crate) shared: std::sync::Arc<tokio::sync::OnceCell<spg_embedded_tokio::AsyncDatabase>>,
}
impl Default for SpgConnectOptions {
    /// Builds options for an in-memory database with default logging.
    ///
    /// Every call allocates its own empty shared cell, so options produced by
    /// separate `default()` calls never share a database handle; only clones
    /// of the returned value do.
    fn default() -> Self {
        let shared = std::sync::Arc::new(tokio::sync::OnceCell::new());
        let log_settings = LogSettings::default();
        Self {
            path: None,
            log_settings,
            shared,
        }
    }
}
impl SpgConnectOptions {
#[must_use]
pub fn in_memory() -> Self {
Self::default()
}
#[must_use]
pub fn file(path: impl Into<PathBuf>) -> Self {
Self {
path: Some(path.into()),
log_settings: LogSettings::default(),
shared: std::sync::Arc::new(tokio::sync::OnceCell::new()),
}
}
}
impl FromStr for SpgConnectOptions {
    type Err = Error;

    /// Parses a connection string into options.
    ///
    /// An optional leading `spg://` (checked first) or `spg:` is removed. What
    /// remains selects an in-memory database when it is empty or equals
    /// `memory` (ASCII case-insensitive); otherwise it is used as a file path.
    /// This implementation never returns an error.
    fn from_str(s: &str) -> Result<Self, Error> {
        let target = ["spg://", "spg:"]
            .iter()
            .find_map(|&prefix| s.strip_prefix(prefix))
            .unwrap_or(s);

        let wants_memory = target.is_empty() || target.eq_ignore_ascii_case("memory");
        let options = if wants_memory {
            Self::in_memory()
        } else {
            Self::file(target)
        };
        Ok(options)
    }
}
impl ConnectOptions for SpgConnectOptions {
type Connection = SpgConnection;
fn from_url(url: &sqlx_core::Url) -> Result<Self, Error> {
if url.scheme() != "spg" {
return Err(Error::Configuration(
format!("expected spg:// scheme, got {:?}", url.scheme()).into(),
));
}
let host = url.host_str().unwrap_or("");
let path = url.path();
let combined = match (host, path) {
("", "") | ("", "/") => String::new(),
("", p) => p.to_string(),
(h, "") | (h, "/") => h.to_string(),
(h, p) => format!("{h}{p}"),
};
SpgConnectOptions::from_str(&combined)
}
fn connect(&self) -> BoxFuture<'_, Result<SpgConnection, Error>> {
let path = self.path.clone();
let shared = std::sync::Arc::clone(&self.shared);
Box::pin(async move {
let inner = shared
.get_or_try_init(|| async {
match path {
None => Ok::<_, Error>(spg_embedded_tokio::AsyncDatabase::open_in_memory()),
Some(p) => spg_embedded_tokio::AsyncDatabase::open_path(p)
.await
.map_err(crate::error::engine_to_sqlx),
}
})
.await?
.clone();
Ok(SpgConnection::new(inner))
})
}
fn log_statements(mut self, level: LevelFilter) -> Self {
self.log_settings.log_statements(level);
self
}
fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
self.log_settings.log_slow_statements(level, duration);
self
}
}