nautilus-orm-connector 0.1.4

Database executors and connection management for Nautilus ORM
Documentation
//! PostgreSQL executor implementation.

use std::time::Duration;

use crate::error::{ConnectorError as Error, Result};
use crate::{Executor, PgRowStream, Row};
use nautilus_core::Value;
use nautilus_dialect::Sql;
use sqlx::postgres::{PgPool, PgPoolOptions};

/// PostgreSQL executor using sqlx.
///
/// Owns a sqlx [`PgPool`] and runs queries against a PostgreSQL database
/// through it. The pool is created and configured by [`PgExecutor::new`];
/// query execution is exposed through this type's [`Executor`] implementation
/// and its inherent helper methods.
///
/// ## Example
///
/// ```rust,ignore
/// use nautilus_connector::PgExecutor;
///
/// #[tokio::main]
/// async fn main() -> nautilus_core::Result<()> {
///     let executor = PgExecutor::new("postgres://user:pass@localhost/mydb").await?;
///     // Use executor to run queries...
///     Ok(())
/// }
/// ```
pub struct PgExecutor {
    /// Connection pool that every statement issued by this executor draws from.
    pool: PgPool,
}

impl PgExecutor {
    /// Create a new PostgreSQL executor with a connection pool.
    ///
    /// ## Parameters
    ///
    /// - `url`: PostgreSQL connection URL (e.g., `postgres://user:pass@localhost/dbname`)
    ///
    /// ## Errors
    ///
    /// Returns `ConnectorError::Connection` if the pool cannot be created or if
    /// an initial connection test fails.
    pub async fn new(url: &str) -> Result<Self> {
        let pool = PgPoolOptions::new()
            .max_connections(10)
            .min_connections(1)
            .acquire_timeout(Duration::from_secs(10))
            .idle_timeout(Duration::from_secs(300))
            .test_before_acquire(true)
            .connect(url)
            .await
            .map_err(|e| Error::connection(e, "Failed to connect to database"))?;

        Ok(Self { pool })
    }

    /// Get a reference to the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Execute a raw SQL statement with no result rows (e.g., DDL).
    pub async fn execute_raw(&self, sql: &str) -> Result<()> {
        sqlx::query(sql)
            .execute(&self.pool)
            .await
            .map(|_| ())
            .map_err(|e| Error::database(e, "DDL error"))
    }

    impl_execute_affected!();
}

/// [`Executor`] implementation backed by a PostgreSQL connection pool.
///
/// Both methods buffer the whole result set with `fetch_all`, hand the
/// connection back to the pool, and only then start yielding decoded rows.
impl Executor for PgExecutor {
    type Row<'conn>
        = Row
    where
        Self: 'conn;
    type RowStream<'conn>
        = PgRowStream
    where
        Self: 'conn;

    /// Run `sql` with its bound parameters and stream the decoded rows.
    ///
    /// Failures (acquiring a connection, binding a parameter, running the
    /// query, decoding a row) are delivered as `Err` items of the stream; any
    /// failure before decoding ends the stream after that single item.
    fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
        // Owned copies so the stream does not borrow from `self` or `sql`.
        let pool = self.pool.clone();
        let statement = sql.text.clone();
        let args = sql.params.clone();

        let rows = async_stream::stream! {
            let mut conn = match pool.acquire().await {
                Ok(conn) => conn,
                Err(e) => {
                    yield Err(Error::connection(e, "Failed to acquire connection"));
                    return;
                }
            };

            // Bind parameters in order, stopping at the first one that fails.
            let bound = args
                .iter()
                .try_fold(sqlx::query(&statement), |q, arg| bind_value(q, arg));
            let query = match bound {
                Ok(q) => q,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            // Buffer every row up front so the connection finishes the full
            // PostgreSQL extended-query cycle (portal close + ReadyForQuery)
            // before it goes back to the pool. An earlier streaming variant
            // (`query.fetch`) could leave a portal open when this generator was
            // dropped mid-iteration; sqlx then discarded the "dirty"
            // connection, which eventually exhausted the pool.
            let fetched = match query.fetch_all(&mut *conn).await {
                Ok(fetched) => fetched,
                Err(e) => {
                    yield Err(Error::database(e, "Query execution failed"));
                    return;
                }
            };

            // Release the connection before the consumer starts pulling rows.
            drop(conn);

            for raw in fetched {
                yield crate::postgres_stream::decode_row_internal(raw);
            }
        };

        PgRowStream::new_from_stream(Box::pin(rows))
    }

    /// Run `mutation`, then run `fetch` on the same pooled connection and
    /// stream the rows `fetch` returns.
    ///
    /// No explicit transaction is opened here. If any step before decoding
    /// fails, the stream yields that single error and ends; a failed mutation
    /// means `fetch` is never executed.
    fn execute_and_fetch<'conn>(
        &'conn self,
        mutation: &'conn Sql,
        fetch: &'conn Sql,
    ) -> Self::RowStream<'conn> {
        // Owned copies so the stream does not borrow from `self` or the inputs.
        let pool = self.pool.clone();
        let write_sql = mutation.text.clone();
        let write_args = mutation.params.clone();
        let read_sql = fetch.text.clone();
        let read_args = fetch.params.clone();

        let rows = async_stream::stream! {
            let mut conn = match pool.acquire().await {
                Ok(conn) => conn,
                Err(e) => {
                    yield Err(Error::connection(e, "Failed to acquire connection"));
                    return;
                }
            };

            // Step 1: bind and run the mutation.
            let bound_write = write_args
                .iter()
                .try_fold(sqlx::query(&write_sql), |q, arg| bind_value(q, arg));
            let write_query = match bound_write {
                Ok(q) => q,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            if let Err(e) = write_query.execute(&mut *conn).await {
                yield Err(Error::database(e, "Mutation failed"));
                return;
            }

            // Step 2: bind and run the follow-up fetch on the same connection.
            let bound_read = read_args
                .iter()
                .try_fold(sqlx::query(&read_sql), |q, arg| bind_value(q, arg));
            let read_query = match bound_read {
                Ok(q) => q,
                Err(e) => {
                    yield Err(e);
                    return;
                }
            };

            let fetched = match read_query.fetch_all(&mut *conn).await {
                Ok(fetched) => fetched,
                Err(e) => {
                    yield Err(Error::database(e, "Fetch failed"));
                    return;
                }
            };

            // Release the connection before the consumer starts pulling rows.
            drop(conn);

            for raw in fetched {
                yield crate::postgres_stream::decode_row_internal(raw);
            }
        };

        PgRowStream::new_from_stream(Box::pin(rows))
    }
}

/// Binds a [`Value`] to a PostgreSQL sqlx query as a typed parameter.
///
/// Uses native binding for `Decimal`, `DateTime`, and `Uuid` (PG-specific).
/// Array values are bound as typed slices when the element type is known; unknown
/// or mixed-type arrays fall back to JSON string serialization.
pub(crate) fn bind_value<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q Value,
) -> Result<sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>> {
    match value {
        Value::Null => Ok(query.bind(None::<String>)),
        Value::Bool(b) => Ok(query.bind(b)),
        Value::I32(i) => Ok(query.bind(i)),
        Value::I64(i) => Ok(query.bind(i)),
        Value::F64(f) => Ok(query.bind(f)),
        Value::Decimal(d) => Ok(query.bind(d)),
        Value::DateTime(dt) => Ok(query.bind(*dt)),
        Value::Uuid(u) => Ok(query.bind(*u)),
        Value::String(s) => Ok(query.bind(s.as_str())),
        Value::Bytes(b) => Ok(query.bind(b.as_slice())),
        Value::Json(j) => Ok(query.bind(j.to_string())),
        Value::Array(items) => {
            if items.is_empty() {
                Ok(query.bind(Vec::<String>::new()))
            } else {
                match &items[0] {
                    Value::String(_) => {
                        let strings: Vec<String> = items
                            .iter()
                            .filter_map(|v| {
                                if let Value::String(s) = v {
                                    Some(s.clone())
                                } else {
                                    None
                                }
                            })
                            .collect();
                        Ok(query.bind(strings))
                    }
                    Value::I32(_) => {
                        let ints: Vec<i32> = items
                            .iter()
                            .filter_map(|v| {
                                if let Value::I32(i) = v {
                                    Some(*i)
                                } else {
                                    None
                                }
                            })
                            .collect();
                        Ok(query.bind(ints))
                    }
                    Value::I64(_) => {
                        let bigints: Vec<i64> = items
                            .iter()
                            .filter_map(|v| {
                                if let Value::I64(i) = v {
                                    Some(*i)
                                } else {
                                    None
                                }
                            })
                            .collect();
                        Ok(query.bind(bigints))
                    }
                    Value::F64(_) => {
                        let floats: Vec<f64> = items
                            .iter()
                            .filter_map(|v| {
                                if let Value::F64(f) = v {
                                    Some(*f)
                                } else {
                                    None
                                }
                            })
                            .collect();
                        Ok(query.bind(floats))
                    }
                    Value::Bool(_) => {
                        let bools: Vec<bool> = items
                            .iter()
                            .filter_map(|v| {
                                if let Value::Bool(b) = v {
                                    Some(*b)
                                } else {
                                    None
                                }
                            })
                            .collect();
                        Ok(query.bind(bools))
                    }
                    _ => {
                        // For mixed/unknown element types, serialize each element as JSON text
                        let strings: Vec<String> = items
                            .iter()
                            .map(|v| crate::utils::value_to_json(v).to_string())
                            .collect();
                        Ok(query.bind(strings))
                    }
                }
            }
        }
        Value::Array2D(_) => {
            // Bind 2D arrays as a JSON string.
            // sqlx does not support multi-dimensional PostgreSQL arrays directly,
            // so we serialize to JSON and let the query cast if necessary.
            Ok(query.bind(crate::utils::value_to_json(value).to_string()))
        }
        // The PG dialect already appends `::type_name` to the placeholder, so
        // we only need to bind the underlying string value here.
        Value::Enum { value, .. } => Ok(query.bind(value.as_str())),
    }
}