use std::time::Duration;
use crate::error::{ConnectorError as Error, Result};
use crate::{Executor, PgRowStream, Row};
use nautilus_core::Value;
use nautilus_dialect::Sql;
use sqlx::postgres::{PgPool, PgPoolOptions};
/// PostgreSQL executor that runs queries through a `sqlx` connection pool.
pub struct PgExecutor {
    /// Pool built in `PgExecutor::new`; exposed read-only via `PgExecutor::pool`.
    pool: PgPool,
}
impl PgExecutor {
pub async fn new(url: &str) -> Result<Self> {
let pool = PgPoolOptions::new()
.max_connections(10)
.min_connections(1)
.acquire_timeout(Duration::from_secs(10))
.idle_timeout(Duration::from_secs(300))
.test_before_acquire(true)
.connect(url)
.await
.map_err(|e| Error::connection(e, "Failed to connect to database"))?;
Ok(Self { pool })
}
pub fn pool(&self) -> &PgPool {
&self.pool
}
pub async fn execute_raw(&self, sql: &str) -> Result<()> {
sqlx::query(sql)
.execute(&self.pool)
.await
.map(|_| ())
.map_err(|e| Error::database(e, "DDL error"))
}
impl_execute_affected!();
}
impl Executor for PgExecutor {
type Row<'conn>
= Row
where
Self: 'conn;
type RowStream<'conn>
= PgRowStream
where
Self: 'conn;
fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
let pool = self.pool.clone();
let sql_text = sql.text.clone();
let params = sql.params.clone();
let stream = async_stream::stream! {
let mut conn = match pool.acquire().await {
Ok(c) => c,
Err(e) => {
yield Err(Error::connection(e, "Failed to acquire connection"));
return;
}
};
let mut query = sqlx::query(&sql_text);
for param in ¶ms {
query = match bind_value(query, param) {
Ok(q) => q,
Err(e) => {
yield Err(e);
return;
}
};
}
let pg_rows = match query.fetch_all(&mut *conn).await {
Ok(rows) => rows,
Err(e) => {
yield Err(Error::database(e, "Query execution failed"));
return;
}
};
drop(conn);
for pg_row in pg_rows {
match crate::postgres_stream::decode_row_internal(pg_row) {
Ok(row) => yield Ok(row),
Err(e) => yield Err(e),
}
}
};
PgRowStream::new_from_stream(Box::pin(stream))
}
fn execute_and_fetch<'conn>(
&'conn self,
mutation: &'conn Sql,
fetch: &'conn Sql,
) -> Self::RowStream<'conn> {
let pool = self.pool.clone();
let mutation_text = mutation.text.clone();
let mutation_params = mutation.params.clone();
let fetch_text = fetch.text.clone();
let fetch_params = fetch.params.clone();
let stream = async_stream::stream! {
use sqlx::Executor as _;
let mut conn = match pool.acquire().await {
Ok(c) => c,
Err(e) => {
yield Err(Error::connection(e, "Failed to acquire connection"));
return;
}
};
let mut mutation_query = sqlx::query(&mutation_text);
for param in &mutation_params {
mutation_query = match bind_value(mutation_query, param) {
Ok(q) => q,
Err(e) => {
yield Err(e);
return;
}
};
}
if let Err(e) = (&mut *conn).execute(mutation_query).await {
yield Err(Error::database(e, "Mutation failed"));
return;
}
let mut fetch_query = sqlx::query(&fetch_text);
for param in &fetch_params {
fetch_query = match bind_value(fetch_query, param) {
Ok(q) => q,
Err(e) => {
yield Err(e);
return;
}
};
}
let pg_rows = match fetch_query.fetch_all(&mut *conn).await {
Ok(rows) => rows,
Err(e) => {
yield Err(Error::database(e, "Fetch failed"));
return;
}
};
drop(conn);
for pg_row in pg_rows {
match crate::postgres_stream::decode_row_internal(pg_row) {
Ok(row) => yield Ok(row),
Err(e) => yield Err(e),
}
}
};
PgRowStream::new_from_stream(Box::pin(stream))
}
}
/// Element vectors for an array that `bindable_pg_array` accepted as a
/// homogeneous, NULL-free typed array, ready to be bound as one parameter.
#[derive(Clone, Debug, PartialEq)]
enum PgArrayBinding {
    /// Elements taken from `Value::String`; also used for the empty array.
    Strings(Vec<String>),
    /// Elements taken from `Value::I32`.
    I32s(Vec<i32>),
    /// Elements taken from `Value::I64`.
    I64s(Vec<i64>),
    /// Elements taken from `Value::F64`.
    F64s(Vec<f64>),
    /// Elements taken from `Value::Bool`.
    Bools(Vec<bool>),
}
fn bindable_pg_array(items: &[Value]) -> Result<Option<PgArrayBinding>> {
let Some(first) = items.first() else {
return Ok(Some(PgArrayBinding::Strings(Vec::new())));
};
match first {
Value::String(_) => {
let mut values = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
match item {
Value::String(value) => values.push(value.clone()),
Value::Null => {
return Err(Error::database_msg(format!(
"PostgreSQL typed array binding does not support NULL element at index {}",
idx
)));
}
other => {
return Err(Error::database_msg(format!(
"PostgreSQL array element at index {} has type {:?}; expected String",
idx, other
)));
}
}
}
Ok(Some(PgArrayBinding::Strings(values)))
}
Value::I32(_) => {
let mut values = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
match item {
Value::I32(value) => values.push(*value),
Value::Null => {
return Err(Error::database_msg(format!(
"PostgreSQL typed array binding does not support NULL element at index {}",
idx
)));
}
other => {
return Err(Error::database_msg(format!(
"PostgreSQL array element at index {} has type {:?}; expected I32",
idx, other
)));
}
}
}
Ok(Some(PgArrayBinding::I32s(values)))
}
Value::I64(_) => {
let mut values = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
match item {
Value::I64(value) => values.push(*value),
Value::Null => {
return Err(Error::database_msg(format!(
"PostgreSQL typed array binding does not support NULL element at index {}",
idx
)));
}
other => {
return Err(Error::database_msg(format!(
"PostgreSQL array element at index {} has type {:?}; expected I64",
idx, other
)));
}
}
}
Ok(Some(PgArrayBinding::I64s(values)))
}
Value::F64(_) => {
let mut values = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
match item {
Value::F64(value) => values.push(*value),
Value::Null => {
return Err(Error::database_msg(format!(
"PostgreSQL typed array binding does not support NULL element at index {}",
idx
)));
}
other => {
return Err(Error::database_msg(format!(
"PostgreSQL array element at index {} has type {:?}; expected F64",
idx, other
)));
}
}
}
Ok(Some(PgArrayBinding::F64s(values)))
}
Value::Bool(_) => {
let mut values = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
match item {
Value::Bool(value) => values.push(*value),
Value::Null => {
return Err(Error::database_msg(format!(
"PostgreSQL typed array binding does not support NULL element at index {}",
idx
)));
}
other => {
return Err(Error::database_msg(format!(
"PostgreSQL array element at index {} has type {:?}; expected Bool",
idx, other
)));
}
}
}
Ok(Some(PgArrayBinding::Bools(values)))
}
_ => Ok(None),
}
}
/// Binds one [`Value`] as the next positional parameter of `query`.
///
/// - `Null` is bound as `None::<String>`.
/// - Scalars, bytes, decimals, timestamps and UUIDs are bound directly.
/// - `Json` is bound as its `to_string()` text.
/// - Arrays accepted by `bindable_pg_array` are bound as typed vectors. Other
///   arrays become a `Vec<String>` built from `crate::utils::value_to_json` of
///   each element.
/// - `Array2D` is bound as the `to_string()` of `value_to_json` for the whole value.
/// - `Enum` binds the string held in its `value` field.
///
/// # Errors
///
/// Propagates any error returned by `bindable_pg_array` for `Value::Array`.
pub(crate) fn bind_value<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q Value,
) -> Result<sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>> {
    let bound = match value {
        Value::Null => query.bind(None::<String>),
        Value::Bool(flag) => query.bind(flag),
        Value::I32(n) => query.bind(n),
        Value::I64(n) => query.bind(n),
        Value::F64(x) => query.bind(x),
        Value::Decimal(dec) => query.bind(dec),
        Value::DateTime(ts) => query.bind(*ts),
        Value::Uuid(id) => query.bind(*id),
        Value::String(text) => query.bind(text.as_str()),
        Value::Bytes(bytes) => query.bind(bytes.as_slice()),
        Value::Json(json) => query.bind(json.to_string()),
        Value::Array(items) => match bindable_pg_array(items)? {
            Some(PgArrayBinding::Strings(typed)) => query.bind(typed),
            Some(PgArrayBinding::I32s(typed)) => query.bind(typed),
            Some(PgArrayBinding::I64s(typed)) => query.bind(typed),
            Some(PgArrayBinding::F64s(typed)) => query.bind(typed),
            Some(PgArrayBinding::Bools(typed)) => query.bind(typed),
            None => {
                let mut encoded: Vec<String> = Vec::new();
                for item in items.iter() {
                    encoded.push(crate::utils::value_to_json(item).to_string());
                }
                query.bind(encoded)
            }
        },
        Value::Array2D(_) => query.bind(crate::utils::value_to_json(value).to_string()),
        Value::Enum { value: variant, .. } => query.bind(variant.as_str()),
    };
    Ok(bound)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bindable_pg_array_keeps_homogeneous_strings() {
        let binding = bindable_pg_array(&[
            Value::String("a".to_string()),
            Value::String("b".to_string()),
        ])
        .expect("string array should bind");
        assert_eq!(
            binding,
            Some(PgArrayBinding::Strings(vec![
                "a".to_string(),
                "b".to_string()
            ]))
        );
    }

    #[test]
    fn bindable_pg_array_keeps_homogeneous_scalars() {
        assert_eq!(
            bindable_pg_array(&[Value::I32(1), Value::I32(2)]).expect("i32 array should bind"),
            Some(PgArrayBinding::I32s(vec![1, 2]))
        );
        assert_eq!(
            bindable_pg_array(&[Value::I64(3), Value::I64(4)]).expect("i64 array should bind"),
            Some(PgArrayBinding::I64s(vec![3, 4]))
        );
        assert_eq!(
            bindable_pg_array(&[Value::F64(0.5), Value::F64(1.5)])
                .expect("f64 array should bind"),
            Some(PgArrayBinding::F64s(vec![0.5, 1.5]))
        );
        assert_eq!(
            bindable_pg_array(&[Value::Bool(true), Value::Bool(false)])
                .expect("bool array should bind"),
            Some(PgArrayBinding::Bools(vec![true, false]))
        );
    }

    #[test]
    fn bindable_pg_array_binds_empty_array_as_strings() {
        let binding = bindable_pg_array(&[]).expect("empty array should bind");
        assert_eq!(binding, Some(PgArrayBinding::Strings(Vec::new())));
    }

    #[test]
    fn bindable_pg_array_rejects_nulls_in_typed_arrays() {
        let err = bindable_pg_array(&[Value::I32(1), Value::Null]).unwrap_err();
        assert!(err.to_string().contains("NULL element"));
    }

    #[test]
    fn bindable_pg_array_reports_index_of_rejected_element() {
        let err = bindable_pg_array(&[Value::I64(1), Value::I64(2), Value::Null]).unwrap_err();
        assert!(err.to_string().contains("index 2"));

        let err =
            bindable_pg_array(&[Value::String("a".to_string()), Value::I32(7)]).unwrap_err();
        let message = err.to_string();
        assert!(message.contains("index 1"));
        assert!(message.contains("expected String"));
    }

    #[test]
    fn bindable_pg_array_rejects_mixed_typed_arrays() {
        let err =
            bindable_pg_array(&[Value::Bool(true), Value::String("nope".to_string())]).unwrap_err();
        assert!(err.to_string().contains("expected Bool"));
    }

    #[test]
    fn bindable_pg_array_falls_back_for_unsupported_types() {
        let binding = bindable_pg_array(&[Value::Decimal(rust_decimal::Decimal::new(123, 2))])
            .expect("unsupported arrays should fall back");
        assert_eq!(binding, None);
    }
}