use std::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::config::QueryConfig;
use crate::entity::SqlEntity;
use crate::error::SqlError;
use crate::params::{OdbcParam, ParamValue};
use crate::pool::{Pool, PooledConn};
use crate::pool::metrics::MetricsSnapshot;
use crate::row::OdbcRow;
/// Result-set wrapper holding at most one entity.
///
/// Its `FromResultSet` impl converts only the first row of the set and
/// yields `Single(None)` when the set has no rows.
#[derive(Debug, Clone)]
pub struct Single<T>(pub Option<T>);
/// Result-set wrapper holding exactly one entity.
///
/// Its `FromResultSet` impl converts the first row of the set and returns an
/// error when the set has no rows.
#[derive(Debug, Clone)]
pub struct Required<T>(pub T);
/// Result-set wrapper holding a single value parsed from text.
///
/// Its `FromResultSet` impl takes the first row, reads it through
/// `OdbcRow::get_first_string`, trims the text and parses it with `FromStr`.
/// The inner value is `None` when there is no row or the string cannot be read.
#[derive(Debug, Clone)]
pub struct Scalar<T>(pub Option<T>);
/// Conversion from the rows of one result set into a typed value.
///
/// Implemented in this file for `Vec<T>`, `Single<T>`, `Required<T>` and
/// `Scalar<S>`.
pub trait FromResultSet: Sized {
/// Consumes `rows` and builds `Self`.
///
/// # Errors
/// Returns a `SqlError` when the implementation cannot build a value from `rows`.
fn from_result_set(rows: Vec<OdbcRow>) -> Result<Self, SqlError>;
}
impl<T: SqlEntity> FromResultSet for Vec<T> {
    /// Converts every row to a `T`, preserving row order.
    ///
    /// # Errors
    /// Stops at, and returns, the first error produced by `T::from_row`.
    #[inline]
    fn from_result_set(rows: Vec<OdbcRow>) -> Result<Self, SqlError> {
        let mut entities = Vec::with_capacity(rows.len());
        for row in &rows {
            entities.push(T::from_row(row)?);
        }
        Ok(entities)
    }
}
impl<T: SqlEntity> FromResultSet for Single<T> {
    /// Converts only the first row; any further rows are dropped unread.
    ///
    /// An empty set produces `Single(None)`.
    ///
    /// # Errors
    /// Returns the error from `T::from_row` if the first row fails to convert.
    #[inline]
    fn from_result_set(rows: Vec<OdbcRow>) -> Result<Self, SqlError> {
        match rows.into_iter().next() {
            Some(first) => {
                let entity = T::from_row(&first)?;
                Ok(Single(Some(entity)))
            }
            None => Ok(Single(None)),
        }
    }
}
impl<T: SqlEntity> FromResultSet for Required<T> {
    /// Converts the first row; any further rows are dropped unread.
    ///
    /// # Errors
    /// Returns a `SqlError::config` error when the set has no rows, or the
    /// error from `T::from_row` if the first row fails to convert.
    #[inline]
    fn from_result_set(rows: Vec<OdbcRow>) -> Result<Self, SqlError> {
        match rows.into_iter().next() {
            Some(first) => {
                let entity = T::from_row(&first)?;
                Ok(Required(entity))
            }
            None => Err(SqlError::config("query_required: sproc returned no rows")),
        }
    }
}
impl<S> FromResultSet for Scalar<S>
where
    S: std::str::FromStr,
    S::Err: std::fmt::Display,
{
    /// Reads the first row through `get_first_string`, trims the text and
    /// parses it as `S`.
    ///
    /// Produces `Scalar(None)` when the set is empty or when
    /// `get_first_string` returns an error (that error is discarded).
    /// NOTE(review): presumably the discarded error covers SQL NULL — confirm
    /// against `OdbcRow::get_first_string`.
    ///
    /// # Errors
    /// Returns a `SqlError::config` error carrying the parse failure message
    /// when the trimmed text is not a valid `S`.
    #[inline]
    fn from_result_set(rows: Vec<OdbcRow>) -> Result<Self, SqlError> {
        let text = match rows.into_iter().next() {
            Some(row) => row.get_first_string().ok(),
            None => None,
        };
        match text {
            None => Ok(Scalar(None)),
            Some(raw) => match raw.trim().parse::<S>() {
                Ok(parsed) => Ok(Scalar(Some(parsed))),
                Err(e) => Err(SqlError::config(e.to_string())),
            },
        }
    }
}
/// Ordered list of named parameters for a stored-procedure call.
///
/// Built fluently with `add` / `add_nullable` and turned into an `EXEC`
/// statement by `into_exec`.
#[derive(Debug, Default, Clone)]
pub struct SprocParams {
/// `(name, value)` pairs in insertion order; names are stored after
/// `strip_at` has removed one leading `@`.
params: Vec<(String, ParamValue)>,
}
impl SprocParams {
    /// Creates an empty parameter list.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a parameter and returns the list for chaining.
    ///
    /// One leading `@` on `name` is removed before the name is stored.
    #[inline]
    pub fn add(mut self, name: &str, value: impl Into<ParamValue>) -> Self {
        self.params.push((strip_at(name), value.into()));
        self
    }

    /// Like [`add`](Self::add), but a `None` value is stored as `ParamValue::Null`.
    #[inline]
    pub fn add_nullable<V: Into<ParamValue>>(mut self, name: &str, value: Option<V>) -> Self {
        let pv = value.map_or(ParamValue::Null, Into::into);
        self.params.push((strip_at(name), pv));
        self
    }

    /// Consumes the list and produces the `EXEC` statement plus its bound parameters.
    ///
    /// The statement has the shape `EXEC <sproc_name> @a = @a, @b = @b`, with
    /// parameters in insertion order. The returned `OdbcParam`s are in the
    /// same order.
    ///
    /// NOTE(review): `sproc_name` and the parameter names are spliced into the
    /// SQL text verbatim, so they must come from trusted code rather than from
    /// user input.
    pub(crate) fn into_exec(self, sproc_name: &str) -> (String, Vec<OdbcParam>) {
        // Each parameter contributes its name twice plus a few separator bytes.
        let extra = self
            .params
            .iter()
            .map(|(n, _)| n.len() * 2 + 8)
            .sum::<usize>();
        let mut sql = String::with_capacity(5 + sproc_name.len() + extra);
        sql.push_str("EXEC ");
        sql.push_str(sproc_name);
        for (i, (n, _)) in self.params.iter().enumerate() {
            sql.push_str(if i == 0 { " " } else { ", " });
            sql.push('@');
            sql.push_str(n);
            sql.push_str(" = @");
            sql.push_str(n);
        }

        let odbc_params = self
            .params
            .into_iter()
            .map(|(name, value)| OdbcParam::new(Self::intern_name(name), value))
            .collect();
        (sql, odbc_params)
    }

    /// Returns a `&'static str` equal to `name`, leaking each distinct name at most once.
    ///
    /// `OdbcParam::new` is handed a `&'static str`. Leaking a fresh allocation
    /// on every call would grow memory without bound in a long-running
    /// process, so leaked names are remembered in a process-wide set and
    /// reused. Total leaked memory is therefore bounded by the number of
    /// distinct parameter names.
    fn intern_name(name: String) -> &'static str {
        static INTERNED: std::sync::Mutex<std::collections::BTreeSet<&'static str>> =
            std::sync::Mutex::new(std::collections::BTreeSet::new());

        // A poisoned lock still guards a usable set, so recover instead of panicking.
        let mut names = INTERNED
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(existing) = names.get(name.as_str()).copied() {
            return existing;
        }
        let leaked: &'static str = Box::leak(name.into_boxed_str());
        names.insert(leaked);
        leaked
    }
}
/// Returns `name` as an owned string with one leading `@` removed, if present.
///
/// Only the first `@` is stripped; any other `@` characters are kept.
fn strip_at(name: &str) -> String {
    let bare = match name.strip_prefix('@') {
        Some(rest) => rest,
        None => name,
    };
    String::from(bare)
}
/// Sequential reader over several result sets.
///
/// Each `read_*` call consumes the next set in order. Once every set has been
/// consumed, further reads behave as if the set were empty.
#[derive(Debug)]
pub struct MultiReader {
/// All result sets in order; a set is replaced by an empty `Vec` once read.
sets: Vec<Vec<OdbcRow>>,
/// Index of the next unread set in `sets`.
idx: usize,
}
impl MultiReader {
    /// Wraps `sets`, positioned at the first result set.
    pub(crate) fn new(sets: Vec<Vec<OdbcRow>>) -> Self {
        Self { sets, idx: 0 }
    }

    /// Moves the next unread set out and advances the cursor.
    ///
    /// Returns an empty `Vec` (without advancing) when every set has been read.
    fn next_set(&mut self) -> Vec<OdbcRow> {
        match self.sets.get_mut(self.idx) {
            Some(slot) => {
                let taken = std::mem::take(slot);
                self.idx += 1;
                taken
            }
            None => Vec::new(),
        }
    }

    /// Consumes the next set and converts every row to a `T`.
    ///
    /// # Errors
    /// Returns the first error produced by `T::from_row`.
    pub fn read_list<T: SqlEntity>(&mut self) -> Result<Vec<T>, SqlError> {
        let rows = self.next_set();
        let mut entities = Vec::with_capacity(rows.len());
        for row in &rows {
            entities.push(T::from_row(row)?);
        }
        Ok(entities)
    }

    /// Consumes the next set and converts its first row, if any.
    ///
    /// # Errors
    /// Returns the error from `T::from_row` if the first row fails to convert.
    pub fn read_single<T: SqlEntity>(&mut self) -> Result<Option<T>, SqlError> {
        match self.next_set().into_iter().next() {
            Some(first) => T::from_row(&first).map(Some),
            None => Ok(None),
        }
    }

    /// Consumes the next set and converts its first row, which must exist.
    ///
    /// # Errors
    /// Returns a `SqlError::config` error when the set is empty, or the error
    /// from `T::from_row` if the first row fails to convert.
    pub fn read_required<T: SqlEntity>(&mut self) -> Result<T, SqlError> {
        match self.next_set().into_iter().next() {
            Some(first) => T::from_row(&first),
            None => Err(SqlError::config("read_required: result set is empty")),
        }
    }

    /// Consumes the next set and parses the first row's string as `S`.
    ///
    /// Yields `Ok(None)` when the set is empty or when `get_first_string`
    /// returns an error (that error is discarded).
    ///
    /// # Errors
    /// Returns a `SqlError::config` error carrying the parse failure message
    /// when the trimmed text is not a valid `S`.
    pub fn read_scalar<S>(&mut self) -> Result<Option<S>, SqlError>
    where
        S: std::str::FromStr,
        S::Err: std::fmt::Display,
    {
        let text = match self.next_set().into_iter().next() {
            Some(row) => row.get_first_string().ok(),
            None => None,
        };
        match text {
            None => Ok(None),
            Some(raw) => match raw.trim().parse::<S>() {
                Ok(parsed) => Ok(Some(parsed)),
                Err(e) => Err(SqlError::config(e.to_string())),
            },
        }
    }

    /// Consumes the next set and returns its rows unconverted.
    pub fn read_raw(&mut self) -> Vec<OdbcRow> {
        self.next_set()
    }
}
/// Success/failure envelope with an optional payload.
///
/// Build it with [`SprocResult::ok`], [`SprocResult::fail`] or
/// [`SprocResult::ok_unit`]. The payload type defaults to `()`.
#[derive(Debug, Clone)]
pub struct SprocResult<T = ()> {
    /// `true` for values built by `ok` / `ok_unit`, `false` for `fail`.
    pub success: bool,
    /// Error code supplied to `fail`; `None` on success.
    pub error_code: Option<String>,
    /// Error message supplied to `fail`; `None` on success.
    pub error_message: Option<String>,
    /// Payload supplied to `ok`; `None` for `fail` and `ok_unit`.
    pub data: Option<T>,
}

impl<T> SprocResult<T> {
    /// Builds a successful result carrying `data`.
    pub fn ok(data: T) -> Self {
        let payload = Some(data);
        Self {
            data: payload,
            success: true,
            error_code: None,
            error_message: None,
        }
    }

    /// Builds a failed result with the given code and message and no payload.
    pub fn fail(
        error_code: Option<String>,
        error_message: Option<String>,
    ) -> Self {
        Self {
            data: None,
            error_message,
            error_code,
            success: false,
        }
    }

    /// Returns the `success` flag.
    pub fn is_success(&self) -> bool {
        let Self { success, .. } = self;
        *success
    }
}

impl SprocResult<()> {
    /// Builds a successful result with no payload (`data` is `None`).
    pub fn ok_unit() -> Self {
        // Start from the all-`None` failure value and flip only the flag.
        Self {
            success: true,
            ..Self::fail(None, None)
        }
    }
}
/// One page of items together with paging metadata.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this type; the field descriptions below follow the field names and should
/// be confirmed against the callers.
#[derive(Debug, Clone)]
pub struct SprocPagedResult<T> {
/// Items on this page.
pub items: Vec<T>,
/// Presumably the total number of items across all pages — TODO confirm.
pub total_count: i64,
/// Page this result represents; whether it is 0- or 1-based is not shown here.
pub page_number: i32,
/// Presumably the requested maximum number of items per page — TODO confirm.
pub page_size: i32,
}
/// Runs stored procedures on connections checked out from a `Pool`.
pub struct SprocService {
/// Pool every call checks its connection out of.
pool: Pool,
/// Query settings; this file reads `slow_query_threshold_ms` and `max_text_bytes`.
query_cfg: QueryConfig,
}
impl SprocService {
pub fn new(pool: Pool, query_cfg: QueryConfig) -> Self {
Self { pool, query_cfg }
}
pub fn pool_metrics(&self) -> MetricsSnapshot {
self.pool.metrics()
}
async fn checkout(&self, token: &CancellationToken) -> Result<PooledConn, SqlError> {
self.pool.checkout(token).await
}
fn slow_warn(&self, sql: &str, elapsed_ms: u64) {
if elapsed_ms >= self.query_cfg.slow_query_threshold_ms {
warn!(
elapsed_ms,
sql = &sql[..sql.len().min(120)],
"Slow sproc"
);
}
}
async fn run_multiple(
&self,
sql: &str,
params: &[OdbcParam],
token: &CancellationToken,
) -> Result<Vec<Vec<OdbcRow>>, SqlError> {
let mut conn = self.checkout(token).await?;
let start = Instant::now();
let sql_owned = sql.to_owned();
let params_owned: Vec<OdbcParam> = params.to_vec();
let max_text_bytes = self.query_cfg.max_text_bytes;
let result = tokio::select! {
biased;
_ = token.cancelled() => Err(SqlError::Cancelled),
res = tokio::task::spawn_blocking(move || {
conn.execute_multiple_query_sync(&sql_owned, ¶ms_owned, max_text_bytes)
}) => res.map_err(|e| SqlError::config(e.to_string()))?,
};
let elapsed = start.elapsed().as_millis() as u64;
self.slow_warn(sql, elapsed);
debug!(elapsed_ms = elapsed, sql = &sql[..sql.len().min(80)], "Sproc executed");
result
}
pub async fn query<R: FromResultSet>(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<R, SqlError> {
let (sql, odbc_params) = params.into_exec(name);
let sets = self.run_multiple(&sql, &odbc_params, token).await?;
let first = sets.into_iter().next().unwrap_or_default();
R::from_result_set(first)
}
pub async fn query2<R1: FromResultSet, R2: FromResultSet>(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<(R1, R2), SqlError> {
let (sql, odbc_params) = params.into_exec(name);
let sets = self.run_multiple(&sql, &odbc_params, token).await?;
let mut it = sets.into_iter();
let s0 = it.next().unwrap_or_default();
let s1 = it.next().unwrap_or_default();
Ok((R1::from_result_set(s0)?, R2::from_result_set(s1)?))
}
pub async fn query3<R1: FromResultSet, R2: FromResultSet, R3: FromResultSet>(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<(R1, R2, R3), SqlError> {
let (sql, odbc_params) = params.into_exec(name);
let sets = self.run_multiple(&sql, &odbc_params, token).await?;
let mut it = sets.into_iter();
let s0 = it.next().unwrap_or_default();
let s1 = it.next().unwrap_or_default();
let s2 = it.next().unwrap_or_default();
Ok((
R1::from_result_set(s0)?,
R2::from_result_set(s1)?,
R3::from_result_set(s2)?,
))
}
pub async fn query4<
R1: FromResultSet,
R2: FromResultSet,
R3: FromResultSet,
R4: FromResultSet,
>(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<(R1, R2, R3, R4), SqlError> {
let (sql, odbc_params) = params.into_exec(name);
let sets = self.run_multiple(&sql, &odbc_params, token).await?;
let mut it = sets.into_iter();
let s0 = it.next().unwrap_or_default();
let s1 = it.next().unwrap_or_default();
let s2 = it.next().unwrap_or_default();
let s3 = it.next().unwrap_or_default();
Ok((
R1::from_result_set(s0)?,
R2::from_result_set(s1)?,
R3::from_result_set(s2)?,
R4::from_result_set(s3)?,
))
}
pub async fn query_multiple(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<MultiReader, SqlError> {
let (sql, odbc_params) = params.into_exec(name);
let sets = self.run_multiple(&sql, &odbc_params, token).await?;
Ok(MultiReader::new(sets))
}
pub async fn execute(
&self,
name: &str,
params: SprocParams,
token: &CancellationToken,
) -> Result<usize, SqlError> {
let mut conn = self.checkout(token).await?;
let (sql, odbc_params) = params.into_exec(name);
let sql_owned = sql.clone();
let result = tokio::select! {
biased;
_ = token.cancelled() => Err(SqlError::Cancelled),
res = tokio::task::spawn_blocking(move || {
conn.execute_non_query_sync(&sql_owned, &odbc_params)
}) => res.map_err(|e| SqlError::config(e.to_string()))?,
};
self.slow_warn(&sql, 0);
result
}
}