1use std::time::Duration;
4
5use crate::error::{ConnectorError as Error, Result};
6use crate::single_row::{fetch_single_row, SingleRowExpectation};
7use crate::{ConnectorPoolOptions, Executor, PgRowStream, Row};
8use futures::future::BoxFuture;
9use nautilus_core::Value;
10use nautilus_dialect::Sql;
11use sqlx::postgres::types::PgHstore;
12use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
13
14pub struct PgExecutor {
31 pool: PgPool,
32}
33
34impl PgExecutor {
35 pub async fn new(url: &str) -> Result<Self> {
46 Self::new_with_options(url, ConnectorPoolOptions::default()).await
47 }
48
49 pub async fn new_with_options(url: &str, pool_options: ConnectorPoolOptions) -> Result<Self> {
53 let connect_options = pool_options.apply_to_postgres_connect_options(
54 url.parse::<PgConnectOptions>()
55 .map_err(|e| Error::connection(e, "Invalid PostgreSQL connection options"))?,
56 );
57 let pool = pool_options
58 .apply_to(
59 PgPoolOptions::new()
60 .max_connections(10)
61 .min_connections(1)
62 .acquire_timeout(Duration::from_secs(10))
63 .idle_timeout(Duration::from_secs(300))
64 .test_before_acquire(true),
65 )
66 .connect_with(connect_options)
67 .await
68 .map_err(|e| Error::connection(e, "Failed to connect to database"))?;
69
70 Ok(Self { pool })
71 }
72
73 pub fn pool(&self) -> &PgPool {
75 &self.pool
76 }
77
78 pub async fn execute_raw(&self, sql: &str) -> Result<()> {
80 sqlx::query(sql)
81 .persistent(false)
82 .execute(&self.pool)
83 .await
84 .map(|_| ())
85 .map_err(|e| Error::database(e, "DDL error"))
86 }
87
88 fn execute_collect_internal_with_persistence<'conn>(
89 &'conn self,
90 sql: &'conn Sql,
91 persistent: bool,
92 ) -> BoxFuture<'conn, Result<Vec<Row>>> {
93 Box::pin(async move {
94 let mut conn = self
95 .pool
96 .acquire()
97 .await
98 .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
99
100 let mut query = sqlx::query(&sql.text).persistent(persistent);
101 for param in &sql.params {
102 query = bind_value(query, param)?;
103 }
104
105 let pg_rows = query
113 .fetch_all(&mut *conn)
114 .await
115 .map_err(|e| Error::database(e, "Query execution failed"))?;
116
117 drop(conn);
118
119 crate::postgres_stream::decode_rows(&pg_rows)
120 })
121 }
122
123 fn execute_collect_internal<'conn>(
124 &'conn self,
125 sql: &'conn Sql,
126 ) -> BoxFuture<'conn, Result<Vec<Row>>> {
127 self.execute_collect_internal_with_persistence(sql, true)
128 }
129
130 pub async fn execute_collect_unprepared(&self, sql: &Sql) -> Result<Vec<Row>> {
135 self.execute_collect_internal_with_persistence(sql, false)
136 .await
137 }
138
139 fn execute_and_fetch_collect_internal<'conn>(
140 &'conn self,
141 mutation: &'conn Sql,
142 fetch: &'conn Sql,
143 ) -> BoxFuture<'conn, Result<Vec<Row>>> {
144 Box::pin(async move {
145 use sqlx::Executor as _;
146
147 let mut conn = self
148 .pool
149 .acquire()
150 .await
151 .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
152
153 let mut mutation_query = sqlx::query(&mutation.text);
154 for param in &mutation.params {
155 mutation_query = bind_value(mutation_query, param)?;
156 }
157
158 (&mut *conn)
159 .execute(mutation_query)
160 .await
161 .map_err(|e| Error::database(e, "Mutation failed"))?;
162
163 let mut fetch_query = sqlx::query(&fetch.text);
164 for param in &fetch.params {
165 fetch_query = bind_value(fetch_query, param)?;
166 }
167
168 let pg_rows = fetch_query
169 .fetch_all(&mut *conn)
170 .await
171 .map_err(|e| Error::database(e, "Fetch failed"))?;
172
173 drop(conn);
174
175 crate::postgres_stream::decode_rows(&pg_rows)
176 })
177 }
178
179 impl_execute_affected!();
180}
181
182impl Executor for PgExecutor {
184 type Row<'conn>
185 = Row
186 where
187 Self: 'conn;
188 type RowStream<'conn>
189 = PgRowStream<'conn>
190 where
191 Self: 'conn;
192
193 fn execute<'conn>(&'conn self, sql: &'conn Sql) -> Self::RowStream<'conn> {
194 crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
195 sqlx::Postgres,
196 _,
197 _,
198 > {
199 pool: self.pool.clone(),
200 sql_text: sql.text.clone(),
201 params: sql.params.clone(),
202 bind: bind_value,
203 decode: crate::postgres_stream::streaming_decoder(),
204 query_context: "Query execution failed",
205 persistent: true,
206 })
207 }
208
209 fn execute_owned(&self, sql: Sql) -> crate::row_stream::RowStream<'static> {
210 crate::streaming::spawn_streaming_query(crate::streaming::StreamingQuery::<
211 sqlx::Postgres,
212 _,
213 _,
214 > {
215 pool: self.pool.clone(),
216 sql_text: sql.text,
217 params: sql.params,
218 bind: bind_value,
219 decode: crate::postgres_stream::streaming_decoder(),
220 query_context: "Query execution failed",
221 persistent: true,
222 })
223 }
224
225 fn execute_and_fetch<'conn>(
226 &'conn self,
227 mutation: &'conn Sql,
228 fetch: &'conn Sql,
229 ) -> Self::RowStream<'conn> {
230 PgRowStream::from_rows_future(self.execute_and_fetch_collect_internal(mutation, fetch))
231 }
232
233 fn execute_collect<'conn>(
234 &'conn self,
235 sql: &'conn Sql,
236 ) -> BoxFuture<'conn, Result<Vec<Self::Row<'conn>>>>
237 where
238 Self: 'conn,
239 {
240 self.execute_collect_internal(sql)
241 }
242
243 fn execute_one<'conn>(
244 &'conn self,
245 sql: &'conn Sql,
246 ) -> BoxFuture<'conn, Result<Self::Row<'conn>>>
247 where
248 Self: 'conn,
249 {
250 Box::pin(async move {
251 let mut conn = self
252 .pool
253 .acquire()
254 .await
255 .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
256
257 let row = fetch_single_row::<sqlx::Postgres, _, _, _>(
258 &mut *conn,
259 &sql.text,
260 &sql.params,
261 bind_value,
262 crate::postgres_stream::decode_row_internal,
263 "Query execution failed",
264 SingleRowExpectation::ExactlyOne,
265 )
266 .await?;
267
268 drop(conn);
269 row.ok_or_else(|| Error::database_msg("Expected exactly one row, got 0"))
272 })
273 }
274
275 fn execute_optional<'conn>(
276 &'conn self,
277 sql: &'conn Sql,
278 ) -> BoxFuture<'conn, Result<Option<Self::Row<'conn>>>>
279 where
280 Self: 'conn,
281 {
282 Box::pin(async move {
283 let mut conn = self
284 .pool
285 .acquire()
286 .await
287 .map_err(|e| Error::connection(e, "Failed to acquire connection"))?;
288
289 let row = fetch_single_row::<sqlx::Postgres, _, _, _>(
290 &mut *conn,
291 &sql.text,
292 &sql.params,
293 bind_value,
294 crate::postgres_stream::decode_row_internal,
295 "Query execution failed",
296 SingleRowExpectation::ZeroOrOne,
297 )
298 .await?;
299
300 drop(conn);
301 Ok(row)
302 })
303 }
304}
305
306#[derive(Debug, Clone, PartialEq)]
307enum PgArrayBinding {
308 Strings(Vec<String>),
309 Hstores(Vec<PgHstore>),
310 Geometries(Vec<String>),
311 Geographies(Vec<String>),
312 I32s(Vec<i32>),
313 I64s(Vec<i64>),
314 F64s(Vec<f64>),
315 Bools(Vec<bool>),
316}
317
318macro_rules! collect_pg_array {
324 ($items:expr, $variant:ident, $elem:pat => $map:expr, $expected:literal) => {{
325 let mut values = Vec::with_capacity($items.len());
326 for (idx, item) in $items.iter().enumerate() {
327 match item {
328 Value::$variant($elem) => values.push($map),
329 Value::Null => {
330 return Err(Error::database_msg(format!(
331 "PostgreSQL typed array binding does not support NULL element at index {}",
332 idx
333 )));
334 }
335 other => {
336 return Err(Error::database_msg(format!(
337 "PostgreSQL array element at index {} has type {:?}; expected {}",
338 idx, other, $expected
339 )));
340 }
341 }
342 }
343 values
344 }};
345}
346
347fn bindable_pg_array(items: &[Value]) -> Result<Option<PgArrayBinding>> {
348 let Some(first) = items.first() else {
349 return Ok(Some(PgArrayBinding::Strings(Vec::new())));
350 };
351
352 let binding = match first {
353 Value::String(_) => {
354 PgArrayBinding::Strings(collect_pg_array!(items, String, v => v.clone(), "String"))
355 }
356 Value::Hstore(_) => PgArrayBinding::Hstores(
357 collect_pg_array!(items, Hstore, v => PgHstore(v.clone()), "Hstore"),
358 ),
359 Value::Geometry(_) => PgArrayBinding::Geometries(
360 collect_pg_array!(items, Geometry, v => v.clone(), "Geometry"),
361 ),
362 Value::Geography(_) => PgArrayBinding::Geographies(
363 collect_pg_array!(items, Geography, v => v.clone(), "Geography"),
364 ),
365 Value::I32(_) => PgArrayBinding::I32s(collect_pg_array!(items, I32, v => *v, "I32")),
366 Value::I64(_) => PgArrayBinding::I64s(collect_pg_array!(items, I64, v => *v, "I64")),
367 Value::F64(_) => PgArrayBinding::F64s(collect_pg_array!(items, F64, v => *v, "F64")),
368 Value::Bool(_) => PgArrayBinding::Bools(collect_pg_array!(items, Bool, v => *v, "Bool")),
369 _ => return Ok(None),
370 };
371
372 Ok(Some(binding))
373}
374
375pub(crate) fn bind_value<'q>(
381 query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
382 value: &'q Value,
383) -> Result<sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>> {
384 match value {
385 Value::Null => Ok(query.bind(None::<String>)),
386 Value::Bool(b) => Ok(query.bind(b)),
387 Value::I32(i) => Ok(query.bind(i)),
388 Value::I64(i) => Ok(query.bind(i)),
389 Value::F64(f) => Ok(query.bind(f)),
390 Value::Decimal(d) => Ok(query.bind(d)),
391 Value::DateTime(dt) => Ok(query.bind(*dt)),
392 Value::Uuid(u) => Ok(query.bind(*u)),
393 Value::String(s) => Ok(query.bind(s.as_str())),
394 Value::Hstore(map) => Ok(query.bind(PgHstore(map.clone()))),
395 Value::Geometry(raw) | Value::Geography(raw) => Ok(query.bind(raw.as_str())),
396 Value::Vector(values) => Ok(query.bind(format_pg_vector(values)?)),
397 Value::Bytes(b) => Ok(query.bind(b.as_slice())),
398 Value::Json(j) => Ok(query.bind(j.to_string())),
399 Value::Array(items) => match bindable_pg_array(items)? {
400 Some(PgArrayBinding::Strings(values)) => Ok(query.bind(values)),
401 Some(PgArrayBinding::Hstores(values)) => Ok(query.bind(values)),
402 Some(PgArrayBinding::Geometries(values)) => Ok(query.bind(values)),
403 Some(PgArrayBinding::Geographies(values)) => Ok(query.bind(values)),
404 Some(PgArrayBinding::I32s(values)) => Ok(query.bind(values)),
405 Some(PgArrayBinding::I64s(values)) => Ok(query.bind(values)),
406 Some(PgArrayBinding::F64s(values)) => Ok(query.bind(values)),
407 Some(PgArrayBinding::Bools(values)) => Ok(query.bind(values)),
408 None => {
409 let strings: Vec<String> = items
410 .iter()
411 .map(|v| crate::utils::value_to_json(v).to_string())
412 .collect();
413 Ok(query.bind(strings))
414 }
415 },
416 Value::Array2D(_) => {
417 Ok(query.bind(crate::utils::value_to_json(value).to_string()))
421 }
422 Value::Enum { value, .. } => Ok(query.bind(value.as_str())),
425 Value::Composite { fields, .. } => Ok(query.bind(encode_pg_composite_literal(fields)?)),
428 }
429}
430
431fn encode_pg_composite_literal(fields: &[Value]) -> Result<String> {
437 let mut out = String::with_capacity(fields.len().saturating_mul(8) + 2);
438 out.push('(');
439 for (idx, field) in fields.iter().enumerate() {
440 if idx > 0 {
441 out.push(',');
442 }
443 if let Some(text) = composite_field_text(field)? {
444 push_quoted_composite_field(&mut out, &text);
445 }
446 }
448 out.push(')');
449 Ok(out)
450}
451
452fn composite_field_text(value: &Value) -> Result<Option<String>> {
455 let text = match value {
456 Value::Null => return Ok(None),
457 Value::Bool(b) => if *b { "t" } else { "f" }.to_string(),
458 Value::I32(i) => i.to_string(),
459 Value::I64(i) => i.to_string(),
460 Value::F64(f) => f.to_string(),
461 Value::Decimal(d) => d.to_string(),
462 Value::DateTime(dt) => dt.format("%Y-%m-%d %H:%M:%S%.f").to_string(),
463 Value::Uuid(u) => u.to_string(),
464 Value::String(s) => s.clone(),
465 Value::Enum { value, .. } => value.clone(),
466 Value::Geometry(raw) | Value::Geography(raw) => raw.clone(),
467 Value::Vector(values) => format_pg_vector(values)?,
468 Value::Json(j) => j.to_string(),
469 Value::Composite { fields, .. } => encode_pg_composite_literal(fields)?,
470 other => crate::utils::value_to_json(other).to_string(),
471 };
472 Ok(Some(text))
473}
474
475fn push_quoted_composite_field(out: &mut String, text: &str) {
477 out.push('"');
478 for ch in text.chars() {
479 match ch {
480 '"' => out.push_str("\"\""),
481 '\\' => out.push_str("\\\\"),
482 _ => out.push(ch),
483 }
484 }
485 out.push('"');
486}
487
488fn format_pg_vector(values: &[f32]) -> Result<String> {
489 let mut out = String::with_capacity(values.len().saturating_mul(8) + 2);
490 out.push('[');
491 for (idx, value) in values.iter().enumerate() {
492 if !value.is_finite() {
493 return Err(Error::database_msg(format!(
494 "PostgreSQL vector element at index {} is not finite",
495 idx
496 )));
497 }
498 if idx > 0 {
499 out.push(',');
500 }
501 out.push_str(&value.to_string());
502 }
503 out.push(']');
504 Ok(out)
505}
506
507#[cfg(test)]
508mod tests {
509 use super::*;
510
511 #[test]
512 fn bindable_pg_array_keeps_homogeneous_strings() {
513 let binding = bindable_pg_array(&[
514 Value::String("a".to_string()),
515 Value::String("b".to_string()),
516 ])
517 .expect("string array should bind");
518
519 assert_eq!(
520 binding,
521 Some(PgArrayBinding::Strings(vec![
522 "a".to_string(),
523 "b".to_string()
524 ]))
525 );
526 }
527
528 #[test]
529 fn bindable_pg_array_rejects_nulls_in_typed_arrays() {
530 let err = bindable_pg_array(&[Value::I32(1), Value::Null]).unwrap_err();
531 assert!(err.to_string().contains("NULL element"));
532 }
533
534 #[test]
535 fn composite_literal_encodes_scalar_fields() {
536 let literal = encode_pg_composite_literal(&[
537 Value::I32(0),
538 Value::I32(3),
539 Value::F64(1.5),
540 Value::Bool(true),
541 ])
542 .expect("composite should encode");
543
544 assert_eq!(literal, "(\"0\",\"3\",\"1.5\",\"t\")");
545 }
546
547 #[test]
548 fn composite_literal_emits_empty_slot_for_null() {
549 let literal =
550 encode_pg_composite_literal(&[Value::I32(7), Value::Null, Value::String("x".into())])
551 .expect("composite should encode");
552
553 assert_eq!(literal, "(\"7\",,\"x\")");
554 }
555
556 #[test]
557 fn composite_literal_escapes_quotes_and_backslashes() {
558 let literal =
559 encode_pg_composite_literal(&[Value::String("a\"b\\c".into())]).expect("should encode");
560
561 assert_eq!(literal, "(\"a\"\"b\\\\c\")");
562 }
563
564 #[test]
565 fn bindable_pg_array_keeps_homogeneous_hstores() {
566 let binding = bindable_pg_array(&[
567 Value::Hstore(std::collections::BTreeMap::from([(
568 "display_name".to_string(),
569 Some("Bob".to_string()),
570 )])),
571 Value::Hstore(std::collections::BTreeMap::from([(
572 "nickname".to_string(),
573 None,
574 )])),
575 ])
576 .expect("hstore array should bind");
577
578 assert_eq!(
579 binding,
580 Some(PgArrayBinding::Hstores(vec![
581 PgHstore(std::collections::BTreeMap::from([(
582 "display_name".to_string(),
583 Some("Bob".to_string()),
584 )])),
585 PgHstore(std::collections::BTreeMap::from([(
586 "nickname".to_string(),
587 None,
588 )])),
589 ]))
590 );
591 }
592
593 #[test]
594 fn bindable_pg_array_rejects_mixed_typed_arrays() {
595 let err =
596 bindable_pg_array(&[Value::Bool(true), Value::String("nope".to_string())]).unwrap_err();
597 assert!(err.to_string().contains("expected Bool"));
598 }
599
600 #[test]
601 fn bindable_pg_array_falls_back_for_unsupported_types() {
602 let binding = bindable_pg_array(&[Value::Decimal(rust_decimal::Decimal::new(123, 2))])
603 .expect("unsupported arrays should fall back");
604 assert_eq!(binding, None);
605 }
606
607 #[test]
608 fn format_pg_vector_uses_pgvector_text_literal() {
609 assert_eq!(format_pg_vector(&[1.0, 2.5, 3.25]).unwrap(), "[1,2.5,3.25]");
610 }
611
612 #[test]
613 fn format_pg_vector_rejects_non_finite_values() {
614 let err = format_pg_vector(&[1.0, f32::NAN]).unwrap_err();
615 assert!(err.to_string().contains("not finite"));
616 }
617}