use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::{stream, TryStreamExt};
use futures_util::lock::Mutex;
use crate::column::ColumnOrigin;
use crate::error::{firebird_err, Error};
use crate::options::FirebirdConnectOptions;
use crate::statement::FirebirdStatementMetadata;
use crate::type_info::FirebirdSqlType;
use crate::{
Firebird, FirebirdArguments, FirebirdColumn, FirebirdQueryResult, FirebirdRow,
FirebirdStatement, FirebirdTypeInfo,
};
use sqlx_core::connection::{Connection, LogSettings};
#[cfg(feature = "offline")]
use sqlx_core::describe::Describe;
use sqlx_core::executor::{Execute, Executor};
use sqlx_core::sql_str::SqlStr;
use sqlx_core::transaction::Transaction;
/// Newtype around `firebirust::ConnectionAsync` that is force-marked `Send`
/// and `Sync` so it can be stored in the `Arc<Mutex<SendConn>>` held by
/// `FirebirdConnection::inner`.
pub(crate) struct SendConn(firebirust::ConnectionAsync);
// SAFETY: NOTE(review): these impls override whatever auto-trait status
// `ConnectionAsync` has on its own (presumably it is not `Send`/`Sync`,
// otherwise they would be unnecessary — TODO confirm). Every access visible in
// this module goes through `Mutex<SendConn>`, which serialises use of the
// connection; that does NOT make it sound if `ConnectionAsync` holds
// thread-affine state (e.g. `Rc`, thread-locals). Verify against firebirust.
unsafe impl Send for SendConn {}
unsafe impl Sync for SendConn {}
/// Lets a `SendConn` (or a lock guard over one) be used as a
/// `&ConnectionAsync`, so driver methods can be called on it directly.
impl std::ops::Deref for SendConn {
type Target = firebirust::ConnectionAsync;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Mutable counterpart of the `Deref` impl; needed for driver calls that take
/// `&mut ConnectionAsync` (e.g. `prepare` through a `MutexGuard`).
impl std::ops::DerefMut for SendConn {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// Wrapper that unconditionally marks a future as `Send`.
///
/// Used in this module to pass futures that lock the shared connection and
/// keep driver objects alive across `.await` points to APIs requiring `Send`
/// futures (`BoxFuture` returns and `impl Future + Send` returns).
pub(crate) struct AssertSend<F>(pub(crate) F);
// SAFETY: NOTE(review): this asserts `Send` for ANY wrapped future, bypassing
// the compiler's check. It relies on the same assumption as `SendConn` — that
// the firebirust state captured by these futures may move between threads.
// Confirm before wrapping futures that capture anything else non-`Send`.
unsafe impl<F: Future> Send for AssertSend<F> {}
impl<F: Future> Future for AssertSend<F> {
type Output = F::Output;
fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
// SAFETY: structural pin projection to the only field. `AssertSend` has no
// `Drop` impl and never moves `.0` out of a pinned `Self` here, so the inner
// future stays pinned for as long as the wrapper is.
unsafe { self.map_unchecked_mut(|s| &mut s.0).poll(cx) }
}
}
/// A single connection to a Firebird database.
///
/// The driver connection lives behind `Arc<Mutex<_>>` so that futures created
/// from `&mut self` methods can take their own handle to it.
pub struct FirebirdConnection {
    /// Shared, lock-protected driver connection.
    pub(crate) inner: Arc<Mutex<SendConn>>,
    /// Transaction nesting counter (presumably maintained by the transaction
    /// manager); initialised to 0 when the connection is established.
    pub(crate) transaction_depth: usize,
    /// Log settings copied from the connect options.
    // Not read by this module yet; kept so the settings travel with the connection.
    #[allow(dead_code)]
    log_settings: LogSettings,
}
impl Debug for FirebirdConnection {
    /// Formats the connection as `FirebirdConnection { transaction_depth: N }`.
    ///
    /// The driver handle and log settings are not included.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            transaction_depth, ..
        } = self;
        let mut builder = f.debug_struct("FirebirdConnection");
        builder.field("transaction_depth", transaction_depth);
        builder.finish()
    }
}
impl FirebirdConnection {
pub(crate) async fn establish(options: &FirebirdConnectOptions) -> Result<Self, Error> {
let db_name = options
.database
.as_deref()
.ok_or_else(|| Error::Configuration("database is required".into()))?;
let conn = firebirust::ConnectionAsync::connect(
&options.host,
options.port,
db_name,
&options.username,
&options.password,
&options.conn_options,
)
.await
.map_err(firebird_err)?;
Ok(Self {
inner: Arc::new(Mutex::new(SendConn(conn))),
transaction_depth: 0,
log_settings: options.log_settings.clone(),
})
}
#[allow(dead_code)]
pub(crate) fn run(
&mut self,
sql: SqlStr,
arguments: Option<FirebirdArguments>,
_persistent: bool,
) -> BoxFuture<
'_,
Result<BoxStream<'_, Result<Either<FirebirdQueryResult, FirebirdRow>, Error>>, Error>,
> {
let inner = self.inner.clone();
Box::pin(AssertSend(async move {
let items = run_query_inner(&inner, sql.as_str(), arguments).await?;
Ok(
Box::pin(stream::iter(items.into_iter().map(Ok)))
as BoxStream<'_, Result<Either<FirebirdQueryResult, FirebirdRow>, Error>>,
)
}))
}
}
impl Connection for FirebirdConnection {
type Database = Firebird;
type Options = FirebirdConnectOptions;
fn close(self) -> impl Future<Output = Result<(), Error>> + Send + 'static {
async move {
drop(self);
Ok(())
}
}
fn close_hard(self) -> impl Future<Output = Result<(), Error>> + Send + 'static {
self.close()
}
fn ping(&mut self) -> impl Future<Output = Result<(), Error>> + Send + '_ {
let inner = self.inner.clone();
AssertSend(async move {
let mut conn = inner.lock().await;
let mut stmt = conn
.prepare("SELECT 1 FROM RDB$DATABASE")
.await
.map_err(firebird_err)?;
let _ = stmt.query(()).await.map_err(firebird_err)?;
Ok(())
})
}
fn begin(
&mut self,
) -> impl Future<Output = Result<Transaction<'_, Firebird>, Error>> + Send + '_ {
Transaction::begin(self, None)
}
fn cached_statements_size(&self) -> usize {
0
}
fn clear_cached_statements(
&mut self,
) -> impl Future<Output = Result<(), Error>> + Send + '_ {
async { Ok(()) }
}
fn shrink_buffers(&mut self) {
}
fn flush(&mut self) -> impl Future<Output = Result<(), Error>> + Send + '_ {
async { Ok(()) }
}
fn should_flush(&self) -> bool {
false
}
}
fn map_sqltype(sqltype: u32, sqlscale: i32, sqlsubtype: i32) -> FirebirdTypeInfo {
let fb_type = FirebirdSqlType::from_sqltype(sqltype).unwrap_or(FirebirdSqlType::Varying);
FirebirdTypeInfo::with_scale(fb_type, sqlscale, sqlsubtype)
}
impl<'c> Executor<'c> for &'c mut FirebirdConnection {
type Database = Firebird;
fn fetch_many<'e, 'q: 'e, E>(
self,
mut query: E,
) -> BoxStream<'e, Result<Either<FirebirdQueryResult, FirebirdRow>, Error>>
where
'c: 'e,
E: 'q + Execute<'q, Self::Database>,
{
let arguments = query.take_arguments().map_err(Error::Encode);
let sql = query.sql().as_str().to_string();
let inner = self.inner.clone();
Box::pin(
stream::once(AssertSend(async move {
let arguments = arguments?;
let items = run_query_inner(&inner, &sql, arguments).await?;
Ok::<_, Error>(stream::iter(items.into_iter().map(Ok)))
}))
.try_flatten(),
)
}
fn fetch_optional<'e, 'q: 'e, E>(
self,
query: E,
) -> BoxFuture<'e, Result<Option<FirebirdRow>, Error>>
where
'c: 'e,
E: 'q + Execute<'q, Self::Database>,
{
let mut s = self.fetch_many(query);
Box::pin(async move {
while let Some(v) = s.try_next().await? {
if let Either::Right(r) = v {
return Ok(Some(r));
}
}
Ok(None)
})
}
fn prepare_with<'e>(
self,
sql: SqlStr,
_parameters: &'e [FirebirdTypeInfo],
) -> BoxFuture<'e, Result<FirebirdStatement, Error>>
where
'c: 'e,
{
let inner = self.inner.clone();
Box::pin(AssertSend(async move {
let mut conn = inner.lock().await;
let stmt = conn.prepare(sql.as_str()).await.map_err(firebird_err)?;
let col_count = stmt.column_count();
let names = stmt.column_names();
let mut columns = Vec::with_capacity(col_count);
let mut column_names = HashMap::with_capacity(col_count);
let mut nullable = Vec::with_capacity(col_count);
for i in 0..col_count {
if let Some((
sqltype, sqlscale, sqlsubtype, _sqllen, null_ok, _fieldname, _relname, _ownname,
)) = stmt.column_metadata(i)
{
let name = names.get(i).map(|s| s.to_string()).unwrap_or_default();
column_names.insert(name.clone(), i);
columns.push(FirebirdColumn {
ordinal: i,
name,
type_info: map_sqltype(sqltype, sqlscale, sqlsubtype),
origin: ColumnOrigin::default(),
});
nullable.push(Some(null_ok));
}
}
let parameters = sql.as_str().chars().filter(|&c| c == '?').count();
Ok(FirebirdStatement {
sql,
metadata: FirebirdStatementMetadata {
columns: Arc::new(columns),
column_names: Arc::new(column_names),
parameters,
nullable,
},
})
}))
}
#[cfg(feature = "offline")]
fn describe<'e>(
self,
sql: SqlStr,
) -> BoxFuture<'e, Result<Describe<Firebird>, Error>>
where
'c: 'e,
{
Box::pin(async move {
let statement = Executor::prepare_with(self, sql, &[]).await?;
Ok(Describe {
columns: (*statement.metadata.columns).clone(),
parameters: Some(Either::Right(statement.metadata.parameters)),
nullable: statement.metadata.nullable,
})
})
}
}
async fn run_query_inner(
conn: &Arc<Mutex<SendConn>>,
sql: &str,
arguments: Option<FirebirdArguments>,
) -> Result<Vec<Either<FirebirdQueryResult, FirebirdRow>>, Error> {
let mut conn = conn.lock().await;
let mut stmt = conn.prepare(sql).await.map_err(firebird_err)?;
let col_count = stmt.column_count();
let names = stmt.column_names();
let mut columns = Vec::with_capacity(col_count);
let mut column_names = HashMap::with_capacity(col_count);
for i in 0..col_count {
if let Some((
sqltype, sqlscale, sqlsubtype, _sqllen, _null_ok, _fieldname, _relname, _ownname,
)) = stmt.column_metadata(i)
{
let name = names.get(i).map(|s| s.to_string()).unwrap_or_default();
column_names.insert(name.clone(), i);
columns.push(FirebirdColumn {
ordinal: i,
name,
type_info: map_sqltype(sqltype, sqlscale, sqlsubtype),
origin: ColumnOrigin::default(),
});
}
}
let columns = Arc::new(columns);
let column_names = Arc::new(column_names);
let params: Vec<firebirust::Param> = match arguments {
Some(args) => args.params,
None => vec![],
};
let param_refs: Vec<&dyn firebirust::ToSqlParam> = params
.iter()
.map(|p| p as &dyn firebirust::ToSqlParam)
.collect();
let result = stmt
.query(param_refs.as_slice())
.await
.map_err(firebird_err)?;
let mut items: Vec<Either<FirebirdQueryResult, FirebirdRow>> = Vec::new();
let mut row_count: u64 = 0;
for fb_row in result {
let mut values = Vec::with_capacity(col_count);
for i in 0..col_count {
let bytes: Option<Bytes> = match columns[i].type_info.r#type {
FirebirdSqlType::Short => fb_row
.get::<Option<i16>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes())),
FirebirdSqlType::Long => fb_row
.get::<Option<i32>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes())),
FirebirdSqlType::Int64 => {
#[cfg(feature = "rust_decimal")]
if columns[i].type_info.sqlscale < 0 {
fb_row
.get::<Option<rust_decimal::Decimal>>(i)
.ok()
.flatten()
.map(|v: rust_decimal::Decimal| Bytes::from(v.to_string()))
} else {
fb_row
.get::<Option<i64>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes()))
}
#[cfg(not(feature = "rust_decimal"))]
{
fb_row
.get::<Option<i64>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes()))
}
}
FirebirdSqlType::Int128 => {
fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from)
}
FirebirdSqlType::Float => fb_row
.get::<Option<f32>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes())),
FirebirdSqlType::Double => fb_row
.get::<Option<f64>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&v.to_le_bytes())),
FirebirdSqlType::Boolean => fb_row
.get::<Option<bool>>(i)
.ok()
.flatten()
.map(|v| Bytes::copy_from_slice(&[v as u8])),
FirebirdSqlType::Text | FirebirdSqlType::Varying => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
FirebirdSqlType::Blob => fb_row
.get::<Option<Vec<u8>>>(i)
.ok()
.flatten()
.map(Bytes::from),
#[cfg(feature = "chrono")]
FirebirdSqlType::Date => fb_row
.get::<Option<chrono::NaiveDate>>(i)
.ok()
.flatten()
.map(|v: chrono::NaiveDate| Bytes::from(v.to_string())),
#[cfg(not(feature = "chrono"))]
FirebirdSqlType::Date => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
#[cfg(feature = "chrono")]
FirebirdSqlType::Time => fb_row
.get::<Option<chrono::NaiveTime>>(i)
.ok()
.flatten()
.map(|v: chrono::NaiveTime| Bytes::from(v.to_string())),
#[cfg(not(feature = "chrono"))]
FirebirdSqlType::Time => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
#[cfg(feature = "chrono")]
FirebirdSqlType::Timestamp => fb_row
.get::<Option<chrono::NaiveDateTime>>(i)
.ok()
.flatten()
.map(|v: chrono::NaiveDateTime| Bytes::from(v.to_string())),
#[cfg(not(feature = "chrono"))]
FirebirdSqlType::Timestamp => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
#[cfg(feature = "chrono")]
FirebirdSqlType::TimestampTz => fb_row
.get::<Option<chrono::NaiveDateTime>>(i)
.ok()
.flatten()
.map(|v: chrono::NaiveDateTime| Bytes::from(v.to_string())),
#[cfg(not(feature = "chrono"))]
FirebirdSqlType::TimestampTz => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
#[cfg(feature = "rust_decimal")]
FirebirdSqlType::DecFixed | FirebirdSqlType::Dec64 | FirebirdSqlType::Dec128 => fb_row
.get::<Option<rust_decimal::Decimal>>(i)
.ok()
.flatten()
.map(|v: rust_decimal::Decimal| Bytes::from(v.to_string())),
#[cfg(not(feature = "rust_decimal"))]
FirebirdSqlType::DecFixed | FirebirdSqlType::Dec64 | FirebirdSqlType::Dec128 => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
_ => fb_row
.get::<Option<String>>(i)
.ok()
.flatten()
.map(Bytes::from),
};
values.push(bytes);
}
items.push(Either::Right(FirebirdRow {
values,
columns: columns.clone(),
column_names: column_names.clone(),
}));
row_count += 1;
}
items.insert(
0,
Either::Left(FirebirdQueryResult {
rows_affected: row_count,
}),
);
Ok(items)
}