Skip to main content

nautilus_connector/
postgres.rs

1//! PostgreSQL executor implementation.
2
3use std::time::Duration;
4
5use crate::error::{ConnectorError as Error, Result};
6use crate::single_row::{fetch_single_row, SingleRowExpectation};
7use crate::{ConnectorPoolOptions, Executor, PgRowStream, Row};
8use futures::future::BoxFuture;
9use nautilus_core::Value;
10use nautilus_dialect::Sql;
11use sqlx::postgres::types::PgHstore;
12use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
13
14/// PostgreSQL executor using sqlx.
15///
16/// Manages a connection pool and executes queries against PostgreSQL databases.
17///
18/// ## Example
19///
20/// ```rust,ignore
21/// use nautilus_connector::PgExecutor;
22///
23/// #[tokio::main]
24/// async fn main() -> nautilus_core::Result<()> {
25///     let executor = PgExecutor::new("postgres://user:pass@localhost/mydb").await?;
26///     // Use executor to run queries...
27///     Ok(())
28/// }
29/// ```
30pub struct PgExecutor {
31    pool: PgPool,
32}
33
34impl PgExecutor {
35    /// Create a new PostgreSQL executor with a connection pool.
36    ///
37    /// ## Parameters
38    ///
39    /// - `url`: PostgreSQL connection URL (e.g., `postgres://user:pass@localhost/dbname`)
40    ///
41    /// ## Errors
42    ///
43    /// Returns `ConnectorError::Connection` if the pool cannot be created or if
44    /// an initial connection test fails.
45    pub async fn new(url: &str) -> Result<Self> {
46        Self::new_with_options(url, ConnectorPoolOptions::default()).await
47    }
48
49    /// Create a new PostgreSQL executor with explicit pool overrides.
50    ///
51    /// Any override not provided keeps the same default used by [`Self::new`].
52    pub async fn new_with_options(url: &str, pool_options: ConnectorPoolOptions) -> Result<Self> {
53        let connect_options = pool_options.apply_to_postgres_connect_options(
54            url.parse::<PgConnectOptions>()
55                .map_err(|e| Error::connection(e, "Invalid PostgreSQL connection options"))?,
56        );
57        let pool = pool_options
58            .apply_to(
59                PgPoolOptions::new()
60                    .max_connections(10)
61                    .min_connections(1)
62                    .acquire_timeout(Duration::from_secs(10))
63                    .idle_timeout(Duration::from_secs(300))
64                    .test_before_acquire(true),
65            )
66            .connect_with(connect_options)
67            .await
68            .map_err(|e| Error::connection(e, "Failed to connect to database"))?;
69
70        Ok(Self { pool })
71    }
72
73    /// Get a reference to the underlying connection pool.
74    pub fn pool(&self) -> &PgPool {
75        &self.pool
76    }
77
78    /// Execute a raw SQL statement with no result rows (e.g., DDL).
79    pub async fn execute_raw(&self, sql: &str) -> Result<()> {
80        sqlx::query(sql)
81            .persistent(false)
82            .execute(&self.pool)
83            .await
84            .map(|_| ())
85            .map_err(|e| Error::database(e, "DDL error"))
86    }
87
88    fn execute_collect_internal_with_persistence<'conn>(
89        &'conn self,
90        sql: &'conn Sql,
91        persistent: bool,
92    ) -> BoxFuture<'conn, Result<Vec<Row>>> {
93        Box::pin(async move {
94            let mut conn = self
95                .pool
96                .acquire()
97                .await
98                .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
99
100            let mut query = sqlx::query(&sql.text).persistent(persistent);
101            for param in &sql.params {
102                query = bind_value(query, param)?;
103            }
104
105            // Fetch ALL rows at once so the connection completes the full
106            // PostgreSQL extended-query cycle (portal close + ReadyForQuery)
107            // before being returned to the pool. The previous streaming
108            // approach (`query.fetch`) could leave the connection with an
109            // open portal when the stream was dropped mid-iteration, causing
110            // sqlx to discard the "dirty" connection and eventually exhaust
111            // the pool.
112            let pg_rows = query
113                .fetch_all(&mut *conn)
114                .await
115                .map_err(|e| Error::database(e, "Query execution failed"))?;
116
117            drop(conn);
118
119            crate::postgres_stream::decode_rows(&pg_rows)
120        })
121    }
122
123    fn execute_collect_internal<'conn>(
124        &'conn self,
125        sql: &'conn Sql,
126    ) -> BoxFuture<'conn, Result<Vec<Row>>> {
127        self.execute_collect_internal_with_persistence(sql, true)
128    }
129
130    /// Execute a SQL query with sqlx statement persistence disabled.
131    ///
132    /// This is reserved for raw/direct query paths that must stay compatible
133    /// with poolers such as PgBouncer transaction pooling.
134    pub async fn execute_collect_unprepared(&self, sql: &Sql) -> Result<Vec<Row>> {
135        self.execute_collect_internal_with_persistence(sql, false)
136            .await
137    }
138
139    fn execute_and_fetch_collect_internal<'conn>(
140        &'conn self,
141        mutation: &'conn Sql,
142        fetch: &'conn Sql,
143    ) -> BoxFuture<'conn, Result<Vec<Row>>> {
144        Box::pin(async move {
145            use sqlx::Executor as _;
146
147            let mut conn = self
148                .pool
149                .acquire()
150                .await
151                .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
152
153            let mut mutation_query = sqlx::query(&mutation.text);
154            for param in &mutation.params {
155                mutation_query = bind_value(mutation_query, param)?;
156            }
157
158            (&mut *conn)
159                .execute(mutation_query)
160                .await
161                .map_err(|e| Error::database(e, "Mutation failed"))?;
162
163            let mut fetch_query = sqlx::query(&fetch.text);
164            for param in &fetch.params {
165                fetch_query = bind_value(fetch_query, param)?;
166            }
167
168            let pg_rows = fetch_query
169                .fetch_all(&mut *conn)
170                .await
171                .map_err(|e| Error::database(e, "Fetch failed"))?;
172
173            drop(conn);
174
175            crate::postgres_stream::decode_rows(&pg_rows)
176        })
177    }
178
179    impl_execute_affected!();
180}
181
182/// [`Executor`] implementation backed by a PostgreSQL connection pool.
183impl Executor for PgExecutor {
184    type Row<'conn>
185        = Row
186    where
187        Self: 'conn;
188    type RowStream<'conn>
189        = PgRowStream<'conn>
190    where
191        Self: 'conn;
192
193    fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
194        crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
195            sqlx::Postgres,
196            _,
197            _,
198        > {
199            pool: self.pool.clone(),
200            sql_text: sql.text.clone(),
201            params: sql.params.clone(),
202            bind: bind_value,
203            decode: crate::postgres_stream::streaming_decoder(),
204            query_context: "Query execution failed",
205            persistent: true,
206        })
207    }
208
209    fn execute_owned(&self, sql: Sql) -> crate::row_stream::RowStream<'static> {
210        crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
211            sqlx::Postgres,
212            _,
213            _,
214        > {
215            pool: self.pool.clone(),
216            sql_text: sql.text,
217            params: sql.params,
218            bind: bind_value,
219            decode: crate::postgres_stream::streaming_decoder(),
220            query_context: "Query execution failed",
221            persistent: true,
222        })
223    }
224
225    fn execute_and_fetch<'conn>(
226        &'conn self,
227        mutation: &'conn Sql,
228        fetch: &'conn Sql,
229    ) -> Self::RowStream<'conn> {
230        PgRowStream::from_rows_future(self.execute_and_fetch_collect_internal(mutation, fetch))
231    }
232
233    fn execute_collect<'conn>(
234        &'conn self,
235        sql: &'conn Sql,
236    ) -> BoxFuture<'conn, Result<Vec<Self::Row<'conn>>>>
237    where
238        Self: 'conn,
239    {
240        self.execute_collect_internal(sql)
241    }
242
243    fn execute_one<'conn>(
244        &'conn self,
245        sql: &'conn Sql,
246    ) -> BoxFuture<'conn, Result<Self::Row<'conn>>>
247    where
248        Self: 'conn,
249    {
250        Box::pin(async move {
251            let mut conn = self
252                .pool
253                .acquire()
254                .await
255                .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
256
257            let row = fetch_single_row::<sqlx::Postgres, _, _, _>(
258                &mut *conn,
259                &sql.text,
260                &sql.params,
261                bind_value,
262                crate::postgres_stream::decode_row_internal,
263                "Query execution failed",
264                SingleRowExpectation::ExactlyOne,
265            )
266            .await?;
267
268            drop(conn);
269            // `ExactlyOne` already validated row_count == 1, so `row` is always
270            // `Some` here; the fallback keeps this a graceful error, never a panic.
271            row.ok_or_else(|| Error::database_msg("Expected exactly one row, got 0"))
272        })
273    }
274
275    fn execute_optional<'conn>(
276        &'conn self,
277        sql: &'conn Sql,
278    ) -> BoxFuture<'conn, Result<Option<Self::Row<'conn>>>>
279    where
280        Self: 'conn,
281    {
282        Box::pin(async move {
283            let mut conn = self
284                .pool
285                .acquire()
286                .await
287                .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
288
289            let row = fetch_single_row::<sqlx::Postgres, _, _, _>(
290                &mut *conn,
291                &sql.text,
292                &sql.params,
293                bind_value,
294                crate::postgres_stream::decode_row_internal,
295                "Query execution failed",
296                SingleRowExpectation::ZeroOrOne,
297            )
298            .await?;
299
300            drop(conn);
301            Ok(row)
302        })
303    }
304}
305
306#[derive(Debug, Clone, PartialEq)]
307enum PgArrayBinding {
308    Strings(Vec<String>),
309    Hstores(Vec<PgHstore>),
310    Geometries(Vec<String>),
311    Geographies(Vec<String>),
312    I32s(Vec<i32>),
313    I64s(Vec<i64>),
314    F64s(Vec<f64>),
315    Bools(Vec<bool>),
316}
317
/// Gather a homogeneous slice of [`Value`]s into a typed `Vec` for array binding.
///
/// Each element must match `Value::$variant($elem)`; the match is converted with
/// `$map`. On the first element that is `Value::Null` or a different variant, the
/// expansion `return`s an `Err` from the *enclosing function*. The error names
/// the element's index and, for a type mismatch, the `$expected` type.
macro_rules! collect_pg_array {
    ($items:expr, $variant:ident, $elem:pat => $map:expr, $expected:literal) => {{
        let source = $items;
        let mut collected = Vec::with_capacity(source.len());
        for (index, element) in source.iter().enumerate() {
            let converted = match element {
                Value::$variant($elem) => $map,
                Value::Null => {
                    return Err(Error::database_msg(format!(
                        "PostgreSQL typed array binding does not support NULL element at index {}",
                        index
                    )));
                }
                mismatched => {
                    return Err(Error::database_msg(format!(
                        "PostgreSQL array element at index {} has type {:?}; expected {}",
                        index, mismatched, $expected
                    )));
                }
            };
            collected.push(converted);
        }
        collected
    }};
}
346
347fn bindable_pg_array(items: &[Value]) -> Result<Option<PgArrayBinding>> {
348    let Some(first) = items.first() else {
349        return Ok(Some(PgArrayBinding::Strings(Vec::new())));
350    };
351
352    let binding = match first {
353        Value::String(_) => {
354            PgArrayBinding::Strings(collect_pg_array!(items, String, v => v.clone(), "String"))
355        }
356        Value::Hstore(_) => PgArrayBinding::Hstores(
357            collect_pg_array!(items, Hstore, v => PgHstore(v.clone()), "Hstore"),
358        ),
359        Value::Geometry(_) => PgArrayBinding::Geometries(
360            collect_pg_array!(items, Geometry, v => v.clone(), "Geometry"),
361        ),
362        Value::Geography(_) => PgArrayBinding::Geographies(
363            collect_pg_array!(items, Geography, v => v.clone(), "Geography"),
364        ),
365        Value::I32(_) => PgArrayBinding::I32s(collect_pg_array!(items, I32, v => *v, "I32")),
366        Value::I64(_) => PgArrayBinding::I64s(collect_pg_array!(items, I64, v => *v, "I64")),
367        Value::F64(_) => PgArrayBinding::F64s(collect_pg_array!(items, F64, v => *v, "F64")),
368        Value::Bool(_) => PgArrayBinding::Bools(collect_pg_array!(items, Bool, v => *v, "Bool")),
369        _ => return Ok(None),
370    };
371
372    Ok(Some(binding))
373}
374
375/// Binds a [`Value`] to a PostgreSQL sqlx query as a typed parameter.
376///
377/// Uses native binding for `Decimal`, `DateTime`, and `Uuid` (PG-specific).
378/// Array values are bound as typed slices when the element type is known; unknown
379/// or mixed-type arrays fall back to JSON string serialization.
380pub(crate) fn bind_value<'q>(
381    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
382    value: &'q Value,
383) -> Result<sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>> {
384    match value {
385        Value::Null => Ok(query.bind(None::<String>)),
386        Value::Bool(b) => Ok(query.bind(b)),
387        Value::I32(i) => Ok(query.bind(i)),
388        Value::I64(i) => Ok(query.bind(i)),
389        Value::F64(f) => Ok(query.bind(f)),
390        Value::Decimal(d) => Ok(query.bind(d)),
391        Value::DateTime(dt) => Ok(query.bind(*dt)),
392        Value::Uuid(u) => Ok(query.bind(*u)),
393        Value::String(s) => Ok(query.bind(s.as_str())),
394        Value::Hstore(map) => Ok(query.bind(PgHstore(map.clone()))),
395        Value::Geometry(raw) | Value::Geography(raw) => Ok(query.bind(raw.as_str())),
396        Value::Vector(values) => Ok(query.bind(format_pg_vector(values)?)),
397        Value::Bytes(b) => Ok(query.bind(b.as_slice())),
398        Value::Json(j) => Ok(query.bind(j.to_string())),
399        Value::Array(items) => match bindable_pg_array(items)? {
400            Some(PgArrayBinding::Strings(values)) => Ok(query.bind(values)),
401            Some(PgArrayBinding::Hstores(values)) => Ok(query.bind(values)),
402            Some(PgArrayBinding::Geometries(values)) => Ok(query.bind(values)),
403            Some(PgArrayBinding::Geographies(values)) => Ok(query.bind(values)),
404            Some(PgArrayBinding::I32s(values)) => Ok(query.bind(values)),
405            Some(PgArrayBinding::I64s(values)) => Ok(query.bind(values)),
406            Some(PgArrayBinding::F64s(values)) => Ok(query.bind(values)),
407            Some(PgArrayBinding::Bools(values)) => Ok(query.bind(values)),
408            None => {
409                let strings: Vec<String> = items
410                    .iter()
411                    .map(|v| crate::utils::value_to_json(v).to_string())
412                    .collect();
413                Ok(query.bind(strings))
414            }
415        },
416        Value::Array2D(_) => {
417            // Bind 2D arrays as a JSON string.
418            // sqlx does not support multi-dimensional PostgreSQL arrays directly,
419            // so we serialize to JSON and let the query cast if necessary.
420            Ok(query.bind(crate::utils::value_to_json(value).to_string()))
421        }
422        // The PG dialect already appends `::type_name` to the placeholder, so
423        // we only need to bind the underlying string value here.
424        Value::Enum { value, .. } => Ok(query.bind(value.as_str())),
425        // The PG dialect appends `::type_name`; we bind the composite as its
426        // record-literal text form and let PostgreSQL parse and cast it.
427        Value::Composite { fields, .. } => Ok(query.bind(encode_pg_composite_literal(fields)?)),
428    }
429}
430
431/// Encode composite-type field values as a PostgreSQL record literal, e.g.
432/// `("0","0","")`. Every non-NULL field is double-quoted (PostgreSQL strips the
433/// quotes and re-parses each field with the target column's input function), and
434/// NULL fields are emitted as an empty slot. This keeps the encoder free of
435/// per-type quoting heuristics.
436fn encode_pg_composite_literal(fields: &[Value]) -> Result<String> {
437    let mut out = String::with_capacity(fields.len().saturating_mul(8) + 2);
438    out.push('(');
439    for (idx, field) in fields.iter().enumerate() {
440        if idx > 0 {
441            out.push(',');
442        }
443        if let Some(text) = composite_field_text(field)? {
444            push_quoted_composite_field(&mut out, &text);
445        }
446        // `None` => SQL NULL => empty slot.
447    }
448    out.push(')');
449    Ok(out)
450}
451
452/// Render a single composite field value to the text PostgreSQL expects inside a
453/// record literal. Returns `None` for NULL fields.
454fn composite_field_text(value: &Value) -> Result<Option<String>> {
455    let text = match value {
456        Value::Null => return Ok(None),
457        Value::Bool(b) => if *b { "t" } else { "f" }.to_string(),
458        Value::I32(i) => i.to_string(),
459        Value::I64(i) => i.to_string(),
460        Value::F64(f) => f.to_string(),
461        Value::Decimal(d) => d.to_string(),
462        Value::DateTime(dt) => dt.format("%Y-%m-%d %H:%M:%S%.f").to_string(),
463        Value::Uuid(u) => u.to_string(),
464        Value::String(s) => s.clone(),
465        Value::Enum { value, .. } => value.clone(),
466        Value::Geometry(raw) | Value::Geography(raw) => raw.clone(),
467        Value::Vector(values) => format_pg_vector(values)?,
468        Value::Json(j) => j.to_string(),
469        Value::Composite { fields, .. } => encode_pg_composite_literal(fields)?,
470        other => crate::utils::value_to_json(other).to_string(),
471    };
472    Ok(Some(text))
473}
474
/// Append `text` to `out` as one double-quoted record-literal field.
///
/// Inside a quoted composite field both `"` and `\` are escaped by doubling
/// them, so each occurrence is simply written twice.
fn push_quoted_composite_field(out: &mut String, text: &str) {
    out.reserve(text.len() + 2);
    out.push('"');
    for ch in text.chars() {
        if matches!(ch, '"' | '\\') {
            out.push(ch);
        }
        out.push(ch);
    }
    out.push('"');
}
487
488fn format_pg_vector(values: &[f32]) -> Result<String> {
489    let mut out = String::with_capacity(values.len().saturating_mul(8) + 2);
490    out.push('[');
491    for (idx, value) in values.iter().enumerate() {
492        if !value.is_finite() {
493            return Err(Error::database_msg(format!(
494                "PostgreSQL vector element at index {} is not finite",
495                idx
496            )));
497        }
498        if idx > 0 {
499            out.push(',');
500        }
501        out.push_str(&value.to_string());
502    }
503    out.push(']');
504    Ok(out)
505}
506
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bindable_pg_array_keeps_homogeneous_strings() {
        let binding = bindable_pg_array(&[
            Value::String("a".to_string()),
            Value::String("b".to_string()),
        ])
        .expect("string array should bind");

        assert_eq!(
            binding,
            Some(PgArrayBinding::Strings(vec![
                "a".to_string(),
                "b".to_string()
            ]))
        );
    }

    #[test]
    fn bindable_pg_array_rejects_nulls_in_typed_arrays() {
        let err = bindable_pg_array(&[Value::I32(1), Value::Null]).unwrap_err();
        assert!(err.to_string().contains("NULL element"));
    }

    #[test]
    fn composite_literal_encodes_scalar_fields() {
        let literal = encode_pg_composite_literal(&[
            Value::I32(0),
            Value::I32(3),
            Value::F64(1.5),
            Value::Bool(true),
        ])
        .expect("composite should encode");

        assert_eq!(literal, "(\"0\",\"3\",\"1.5\",\"t\")");
    }

    #[test]
    fn composite_literal_emits_empty_slot_for_null() {
        let literal =
            encode_pg_composite_literal(&[Value::I32(7), Value::Null, Value::String("x".into())])
                .expect("composite should encode");

        assert_eq!(literal, "(\"7\",,\"x\")");
    }

    #[test]
    fn composite_literal_escapes_quotes_and_backslashes() {
        let literal =
            encode_pg_composite_literal(&[Value::String("a\"b\\c".into())]).expect("should encode");

        assert_eq!(literal, "(\"a\"\"b\\\\c\")");
    }

    #[test]
    fn bindable_pg_array_keeps_homogeneous_hstores() {
        let binding = bindable_pg_array(&[
            Value::Hstore(std::collections::BTreeMap::from([(
                "display_name".to_string(),
                Some("Bob".to_string()),
            )])),
            Value::Hstore(std::collections::BTreeMap::from([(
                "nickname".to_string(),
                None,
            )])),
        ])
        .expect("hstore array should bind");

        assert_eq!(
            binding,
            Some(PgArrayBinding::Hstores(vec![
                PgHstore(std::collections::BTreeMap::from([(
                    "display_name".to_string(),
                    Some("Bob".to_string()),
                )])),
                PgHstore(std::collections::BTreeMap::from([(
                    "nickname".to_string(),
                    None,
                )])),
            ]))
        );
    }

    #[test]
    fn bindable_pg_array_rejects_mixed_typed_arrays() {
        let err =
            bindable_pg_array(&[Value::Bool(true), Value::String("nope".to_string())]).unwrap_err();
        assert!(err.to_string().contains("expected Bool"));
    }

    #[test]
    fn bindable_pg_array_falls_back_for_unsupported_types() {
        let binding = bindable_pg_array(&[Value::Decimal(rust_decimal::Decimal::new(123, 2))])
            .expect("unsupported arrays should fall back");
        assert_eq!(binding, None);
    }

    #[test]
    fn format_pg_vector_uses_pgvector_text_literal() {
        assert_eq!(format_pg_vector(&[1.0, 2.5, 3.25]).unwrap(), "[1,2.5,3.25]");
    }

    #[test]
    fn format_pg_vector_rejects_non_finite_values() {
        let err = format_pg_vector(&[1.0, f32::NAN]).unwrap_err();
        assert!(err.to_string().contains("not finite"));
    }

    // Edge cases: empty inputs, all-NULL arrays, and error index reporting.

    #[test]
    fn bindable_pg_array_binds_empty_array_as_empty_strings() {
        let binding = bindable_pg_array(&[]).expect("empty array should bind");
        assert_eq!(binding, Some(PgArrayBinding::Strings(Vec::new())));
    }

    #[test]
    fn bindable_pg_array_falls_back_for_all_null_arrays() {
        let binding = bindable_pg_array(&[Value::Null, Value::Null])
            .expect("all-NULL array should fall back, not error");
        assert_eq!(binding, None);
    }

    #[test]
    fn composite_literal_of_no_fields_is_empty_parens() {
        let literal = encode_pg_composite_literal(&[]).expect("empty composite should encode");
        assert_eq!(literal, "()");
    }

    #[test]
    fn format_pg_vector_of_empty_slice_is_empty_brackets() {
        assert_eq!(format_pg_vector(&[]).unwrap(), "[]");
    }

    #[test]
    fn format_pg_vector_reports_first_non_finite_index() {
        let err = format_pg_vector(&[0.5, f32::INFINITY, f32::NAN]).unwrap_err();
        assert!(err.to_string().contains("index 1"));
    }
}