use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use sqlx_core::HashMap;
use sqlx_core::connection::Connection;
use sqlx_core::error::Error;
use sqlx_core::executor::Executor;
use sqlx_core::transaction::Transaction;
use spg_embedded::QueryResult as EngineQueryResult;
use spg_embedded_tokio::AsyncDatabase;
use crate::column::SpgColumn;
use crate::database::Spg;
use crate::error::engine_to_sqlx;
use crate::options::SpgConnectOptions;
use crate::query_result::SpgQueryResult;
use crate::row::SpgRow;
use crate::type_info::SpgTypeInfo;
#[derive(Debug, Clone)]
pub(crate) struct CachedStmt {
pub(crate) readonly: bool,
pub(crate) stmt: std::sync::Arc<spg_embedded::Statement>,
}
/// Maximum number of entries held in a connection's statement cache before it is flushed.
const STMT_CACHE_CAP: usize = 256;
/// An sqlx connection backed by the embedded engine's async handle.
#[derive(Debug)]
pub struct SpgConnection {
    /// Async handle to the embedded database.
    pub(crate) inner: AsyncDatabase,
    /// Prepared statements keyed by their exact SQL text (never more than
    /// `STMT_CACHE_CAP` entries; the whole map is cleared when it fills up).
    pub(crate) stmt_cache: HashMap<String, CachedStmt>,
    /// Transaction depth (presumably the nesting level); any value above `0` is treated as
    /// "inside a transaction" by the executor.
    pub(crate) tx_depth: usize,
    /// NOTE(review): not read anywhere in this file — presumably set by the transaction
    /// manager when an outer transaction must be rolled back; confirm.
    pub(crate) pending_rollback: bool,
}
impl Clone for SpgConnection {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
stmt_cache: HashMap::new(),
tx_depth: self.tx_depth,
pending_rollback: self.pending_rollback,
}
}
}
impl SpgConnection {
pub fn new(inner: AsyncDatabase) -> Self {
Self {
inner,
stmt_cache: HashMap::new(),
tx_depth: 0,
pending_rollback: false,
}
}
#[must_use]
pub const fn engine(&self) -> &AsyncDatabase {
&self.inner
}
pub(crate) async fn cached_stmt(
&mut self,
sql: &str,
) -> Result<CachedStmt, spg_embedded::EngineError> {
if let Some(c) = self.stmt_cache.get(sql) {
return Ok(c.clone());
}
let readonly = spg_embedded::Engine::is_readonly_sql(sql);
let snap = self.inner.clone_snapshot_inline().await;
let stmt = spg_embedded::Database::prepare_on_snapshot(&snap, sql)?;
let cached = CachedStmt {
readonly,
stmt: std::sync::Arc::new(stmt),
};
if self.stmt_cache.len() >= STMT_CACHE_CAP {
self.stmt_cache.clear();
}
self.stmt_cache.insert(sql.to_string(), cached.clone());
Ok(cached)
}
}
/// sqlx `Connection` implementation.
///
/// `close`, `close_hard` and `flush` complete immediately. `shrink_buffers` does nothing.
/// `ping` runs `SELECT 1` through the engine handle.
impl Connection for SpgConnection {
    type Database = Spg;
    type Options = SpgConnectOptions;

    /// Consumes the connection; nothing happens beyond dropping `self`.
    fn close(self) -> BoxFuture<'static, Result<(), Error>> {
        Box::pin(std::future::ready(Ok(())))
    }

    /// Same as [`close`](Self::close): consumes the connection and completes immediately.
    fn close_hard(self) -> BoxFuture<'static, Result<(), Error>> {
        Box::pin(std::future::ready(Ok(())))
    }

    /// Liveness check: runs `SELECT 1` and discards the result.
    fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(async move {
            self.inner
                .execute("SELECT 1")
                .await
                .map(drop)
                .map_err(engine_to_sqlx)
        })
    }

    /// Opens a transaction using sqlx's generic `Transaction::begin` with the default
    /// `BEGIN` statement (no custom statement supplied).
    fn begin(&mut self) -> BoxFuture<'_, Result<Transaction<'_, Self::Database>, Error>>
    where
        Self: Sized,
    {
        Transaction::begin(self, None)
    }

    /// No-op: this connection keeps no I/O buffers to shrink.
    fn shrink_buffers(&mut self) {}

    /// No-op: nothing is ever buffered for writing.
    fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> {
        Box::pin(std::future::ready(Ok(())))
    }

    /// Always `false`, since there is never pending output.
    fn should_flush(&self) -> bool {
        false
    }
}
impl<'c> Executor<'c> for &'c mut SpgConnection {
type Database = Spg;
fn fetch_many<'e, 'q: 'e, E>(
self,
mut query: E,
) -> BoxStream<
'e,
Result<
either::Either<
<Self::Database as sqlx_core::database::Database>::QueryResult,
crate::SpgRow,
>,
Error,
>,
>
where
'c: 'e,
E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
{
use futures_util::stream::{self, StreamExt};
let sql = query.sql().to_string();
let arguments = match query.take_arguments() {
Ok(args) => args,
Err(e) => {
return Box::pin(stream::iter(std::iter::once(Err(Error::Encode(e)))));
}
};
let outcome_fut = async move {
match arguments {
Some(args) => run_one(self, &sql, Some(args)).await.map(|o| vec![o]),
None => Ok(self
.inner
.execute_script(&sql)
.await
.map_err(engine_to_sqlx)?
.into_iter()
.map(outcome_from)
.collect()),
}
};
Box::pin(stream::once(outcome_fut).flat_map(|outcome| {
let items: Vec<Result<either::Either<SpgQueryResult, SpgRow>, Error>> = match outcome {
Ok(outcomes) => outcomes
.into_iter()
.flat_map(|o| match o {
Outcome::Affected(qr) => vec![Ok(either::Either::Left(qr))],
Outcome::Rows(rows) => rows
.into_iter()
.map(|r| Ok(either::Either::Right(r)))
.collect::<Vec<_>>(),
})
.collect(),
Err(e) => vec![Err(e)],
};
stream::iter(items)
}))
}
fn fetch_optional<'e, 'q: 'e, E>(
self,
mut query: E,
) -> BoxFuture<'e, Result<Option<crate::SpgRow>, Error>>
where
'c: 'e,
E: 'q + sqlx_core::executor::Execute<'q, Self::Database>,
{
let sql = query.sql().to_string();
let arguments = query.take_arguments();
Box::pin(async move {
let args = arguments.map_err(Error::Encode)?;
match run_one(self, &sql, args).await? {
Outcome::Rows(mut rows) => Ok(rows.drain(..).next()),
Outcome::Affected(_) => Ok(None),
}
})
}
fn prepare_with<'e, 'q: 'e>(
self,
sql: &'q str,
_parameters: &'e [<Self::Database as sqlx_core::database::Database>::TypeInfo],
) -> BoxFuture<
'e,
Result<<Self::Database as sqlx_core::database::Database>::Statement<'q>, Error>,
>
where
'c: 'e,
{
let inner = self.inner.clone();
let sql_str = sql.to_string();
Box::pin(async move {
let stmt = inner.prepare(&sql_str).await.map_err(engine_to_sqlx)?;
let inner_stmt = spg_embedded_tokio::async_statement_inner(&stmt);
Ok(crate::SpgStatement {
sql: std::borrow::Cow::Owned(sql_str),
inner: Some(inner_stmt),
columns: std::sync::Arc::new(Vec::new()),
by_name: std::sync::Arc::new(sqlx_core::HashMap::new()),
})
})
}
fn describe<'e, 'q: 'e>(
self,
sql: &'q str,
) -> BoxFuture<'e, Result<sqlx_core::describe::Describe<Self::Database>, Error>>
where
'c: 'e,
{
let inner = self.inner.clone();
let sql_str = sql.to_string();
Box::pin(async move {
let (params, cols) = inner.describe(&sql_str).await.map_err(engine_to_sqlx)?;
let nullable: Vec<Option<bool>> = cols.iter().map(|c| Some(c.nullable)).collect();
let columns: Vec<SpgColumn> = cols
.iter()
.enumerate()
.map(|(i, c)| {
let ti = SpgTypeInfo::from_data_type(c.ty);
SpgColumn::new(i, c.name.clone(), ti)
})
.collect();
let parameters = if params.is_empty() {
None
} else {
Some(either::Either::Right(params.len()))
};
Ok(sqlx_core::describe::Describe {
columns,
parameters,
nullable,
})
})
}
}
/// One statement's result, normalized for the executor.
enum Outcome {
    /// A non-row result carrying an affected-row count.
    Affected(SpgQueryResult),
    /// The fully materialized rows of a row-producing statement.
    Rows(Vec<SpgRow>),
}
async fn run_one(
conn: &mut SpgConnection,
sql: &str,
arguments: Option<crate::SpgArguments<'_>>,
) -> Result<Outcome, Error> {
let in_tx = conn.tx_depth > 0;
let cached = if in_tx {
None
} else {
conn.cached_stmt(sql).await.ok()
};
let result: EngineQueryResult = if let Some(c) = cached.filter(|c| c.readonly) {
let snap = conn.inner.clone_snapshot_inline().await;
let params = arguments.map(crate::SpgArguments::into_engine_values);
let budget_ms: u64 = std::env::var("SPG_SQLX_INLINE_BUDGET_MS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(1000);
let started = std::time::Instant::now();
let inline = spg_embedded::Database::execute_prepared_on_snapshot_with_budget(
&snap,
&c.stmt,
params.as_deref().unwrap_or(&[]),
budget_ms.saturating_mul(1_000),
);
match inline {
Ok(r) => r,
Err(spg_embedded::EngineError::Cancelled) => {
let stmt = c.stmt.clone();
let params_owned: Vec<spg_embedded::Value> =
params.as_deref().unwrap_or(&[]).to_vec();
let result = tokio::task::spawn_blocking(move || {
spg_embedded::Database::execute_prepared_on_snapshot(
&snap,
&stmt,
¶ms_owned,
)
})
.await
.map_err(|e| Error::Protocol(format!("blocking-pool join: {e}")))?
.map_err(engine_to_sqlx)?;
let elapsed_ms = started.elapsed().as_millis();
eprintln!(
"spg-sqlx: readonly query exceeded the {budget_ms} ms inline budget; \
continuing on the blocking pool: elapsed_ms={elapsed_ms} sql={}",
&sql[..sql.len().min(120)]
);
result
}
Err(e) => return Err(engine_to_sqlx(e)),
}
} else {
let db = &conn.inner;
if let Some(args) = arguments {
let stmt = db.prepare(sql).await.map_err(engine_to_sqlx)?;
db.execute_prepared(&stmt, args.into_engine_values())
.await
.map_err(engine_to_sqlx)?
} else {
db.execute(sql).await.map_err(engine_to_sqlx)?
}
};
Ok(outcome_from(result))
}
/// Converts an engine result into an [`Outcome`].
///
/// - `Rows` results become materialized [`SpgRow`]s.
/// - `CommandOk` reports its affected count, falling back to `0` if the count does not fit
///   in a `u64`.
/// - Any other variant maps to `SpgQueryResult::default()`.
fn outcome_from(result: EngineQueryResult) -> Outcome {
    match result {
        EngineQueryResult::Rows { columns, rows } => {
            let values = rows.into_iter().map(|row| row.values).collect();
            Outcome::Rows(build_rows(&columns, values))
        }
        EngineQueryResult::CommandOk { affected, .. } => {
            let count = u64::try_from(affected).unwrap_or(0);
            Outcome::Affected(SpgQueryResult::new(count))
        }
        _ => Outcome::Affected(SpgQueryResult::default()),
    }
}
/// Number of rows a result touched.
///
/// This is the affected count for `CommandOk`, the row count for `Rows`, and `0` for every
/// other variant. It is also `0` if a count does not fit in a `u64`.
// The function is private and not called from the visible code, hence `dead_code` is allowed.
#[allow(dead_code)]
fn affected_from(qr: &EngineQueryResult) -> u64 {
    let count = match qr {
        EngineQueryResult::CommandOk { affected, .. } => u64::try_from(*affected).ok(),
        EngineQueryResult::Rows { rows, .. } => u64::try_from(rows.len()).ok(),
        _ => return 0,
    };
    count.unwrap_or(0)
}
/// Wraps each value vector in `rows` as an [`SpgRow`].
///
/// All rows share a single column list and a single name-to-index map. If two columns have
/// the same name, the map points at the later column.
fn build_rows(
    cols: &[spg_embedded::ColumnSchema],
    rows: Vec<Vec<spg_embedded::Value>>,
) -> Vec<SpgRow> {
    let column_list: Vec<SpgColumn> = cols
        .iter()
        .enumerate()
        .map(|(idx, col)| {
            let type_info = SpgTypeInfo::from_data_type(col.ty);
            SpgColumn::new(idx, col.name.clone(), type_info)
        })
        .collect();
    let columns: Arc<Vec<SpgColumn>> = Arc::new(column_list);

    let mut name_index: HashMap<String, usize> = HashMap::new();
    name_index.extend(cols.iter().enumerate().map(|(idx, col)| (col.name.clone(), idx)));
    let by_name = Arc::new(name_index);

    let mut out = Vec::with_capacity(rows.len());
    for values in rows {
        out.push(SpgRow::new(Arc::clone(&columns), Arc::clone(&by_name), values));
    }
    out
}