Skip to main content

nextsql_backend_rust_runtime/
tokio_postgres_impl.rs

1use crate::{Client, QueryExecutor, Row, ToSqlParam, Transaction};
2
3// ---- Helper: String wrapper that accepts any type (including custom enums) ----
4
5/// A String wrapper that implements `FromSql` for any PostgreSQL type.
6/// PostgreSQL enum values are transmitted as text on the wire, so reading
7/// them as String is safe. Standard `String::from_sql` only accepts TEXT/VARCHAR,
8/// rejecting custom enum types.
9struct AnyString(String);
10
11impl<'a> tokio_postgres::types::FromSql<'a> for AnyString {
12    fn from_sql(
13        _ty: &tokio_postgres::types::Type,
14        raw: &'a [u8],
15    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
16        Ok(AnyString(String::from_utf8(raw.to_vec())?))
17    }
18
19    fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
20        true
21    }
22}
23
24// ---- Row implementation wrapping tokio_postgres::Row ----
25
26pub struct PgRow(pub tokio_postgres::Row);
27
28impl Row for PgRow {
29    fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
30    fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
31    fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
32    fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
33    fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
34    fn get_string(&self, idx: usize) -> String { self.0.get::<_, AnyString>(idx).0 }
35    fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
36    fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
37    fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
38    fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
39    fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
40    fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
41    fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
42
43    fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
44    fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
45    fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
46    fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
47    fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
48    fn get_opt_string(&self, idx: usize) -> Option<String> {
49        self.0.get::<_, Option<AnyString>>(idx).map(|s| s.0)
50    }
51    fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
52    fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
53    fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
54    fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
55    fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
56    fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
57    fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
58
59    fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
60    fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
61    fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
62    fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
63    fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
64    fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
65    fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
66    fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
67}
68
69// ---- Owned parameter enum for crossing async boundaries ----
70
71/// An owned SQL parameter value. Used internally to convert borrowed `&dyn ToSqlParam`
72/// into owned values that can be moved into async futures.
73#[derive(Debug)]
74pub enum OwnedParam {
75    I16(i16),
76    I32(i32),
77    I64(i64),
78    F32(f32),
79    F64(f64),
80    Bool(bool),
81    String(String),
82    Uuid(uuid::Uuid),
83    NaiveDateTime(chrono::NaiveDateTime),
84    DateTimeUtc(chrono::DateTime<chrono::Utc>),
85    NaiveDate(chrono::NaiveDate),
86    Decimal(rust_decimal::Decimal),
87    Json(serde_json::Value),
88    OptI16(Option<i16>),
89    OptI32(Option<i32>),
90    OptI64(Option<i64>),
91    OptF32(Option<f32>),
92    OptF64(Option<f64>),
93    OptBool(Option<bool>),
94    OptString(Option<String>),
95    OptUuid(Option<uuid::Uuid>),
96    OptNaiveDateTime(Option<chrono::NaiveDateTime>),
97    OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
98    OptNaiveDate(Option<chrono::NaiveDate>),
99    OptDecimal(Option<rust_decimal::Decimal>),
100    OptJson(Option<serde_json::Value>),
101    VecI16(Vec<i16>),
102    VecI32(Vec<i32>),
103    VecI64(Vec<i64>),
104    VecF32(Vec<f32>),
105    VecF64(Vec<f64>),
106    VecBool(Vec<bool>),
107    VecString(Vec<String>),
108    VecUuid(Vec<uuid::Uuid>),
109}
110
111impl tokio_postgres::types::ToSql for OwnedParam {
112    fn to_sql(
113        &self,
114        ty: &tokio_postgres::types::Type,
115        out: &mut bytes::BytesMut,
116    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
117        match self {
118            OwnedParam::I16(v) => v.to_sql(ty, out),
119            OwnedParam::I32(v) => v.to_sql(ty, out),
120            OwnedParam::I64(v) => v.to_sql(ty, out),
121            OwnedParam::F32(v) => v.to_sql(ty, out),
122            OwnedParam::F64(v) => v.to_sql(ty, out),
123            OwnedParam::Bool(v) => v.to_sql(ty, out),
124            OwnedParam::String(v) => v.to_sql(ty, out),
125            OwnedParam::Uuid(v) => v.to_sql(ty, out),
126            OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
127            OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
128            OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
129            OwnedParam::Decimal(v) => v.to_sql(ty, out),
130            OwnedParam::Json(v) => v.to_sql(ty, out),
131            OwnedParam::OptI16(v) => v.to_sql(ty, out),
132            OwnedParam::OptI32(v) => v.to_sql(ty, out),
133            OwnedParam::OptI64(v) => v.to_sql(ty, out),
134            OwnedParam::OptF32(v) => v.to_sql(ty, out),
135            OwnedParam::OptF64(v) => v.to_sql(ty, out),
136            OwnedParam::OptBool(v) => v.to_sql(ty, out),
137            OwnedParam::OptString(v) => v.to_sql(ty, out),
138            OwnedParam::OptUuid(v) => v.to_sql(ty, out),
139            OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
140            OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
141            OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
142            OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
143            OwnedParam::OptJson(v) => v.to_sql(ty, out),
144            OwnedParam::VecI16(v) => v.to_sql(ty, out),
145            OwnedParam::VecI32(v) => v.to_sql(ty, out),
146            OwnedParam::VecI64(v) => v.to_sql(ty, out),
147            OwnedParam::VecF32(v) => v.to_sql(ty, out),
148            OwnedParam::VecF64(v) => v.to_sql(ty, out),
149            OwnedParam::VecBool(v) => v.to_sql(ty, out),
150            OwnedParam::VecString(v) => v.to_sql(ty, out),
151            OwnedParam::VecUuid(v) => v.to_sql(ty, out),
152        }
153    }
154
155    fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
156        // Accept all types unconditionally. OwnedParam is a dynamic wrapper that
157        // delegates serialization to the inner value's to_sql(). PostgreSQL handles
158        // type coercion (e.g. text -> enum) server-side when the value is sent as text.
159        // Rejecting custom/enum types here would break INSERT/UPDATE for enum columns.
160        true
161    }
162
163    tokio_postgres::types::to_sql_checked!();
164}
165
166/// Convert a `&dyn ToSqlParam` to an owned parameter value by downcasting.
167pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
168    let any = param.as_any();
169    // Non-optional types
170    if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
171    if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
172    if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
173    if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
174    if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
175    if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
176    if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
177    if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
178    if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
179    if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
180    if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
181    if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
182    if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
183    // Optional types
184    if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
185    if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
186    if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
187    if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
188    if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
189    if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
190    if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
191    if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
192    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
193    if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
194    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
195    if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
196    if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
197    // Vec types
198    if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
199    if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
200    if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
201    if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
202    if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
203    if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
204    if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
205    if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
206    panic!("Unsupported parameter type for tokio-postgres backend");
207}
208
209/// Convert a slice of borrowed params to a vec of owned params.
210pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
211    params.iter().map(|p| to_owned_param(*p)).collect()
212}
213
214// ---- Client wrapper ----
215
216/// A wrapper around `tokio_postgres::Client` implementing the NextSQL `Client` trait.
217pub struct PgClient {
218    inner: tokio_postgres::Client,
219}
220
221impl PgClient {
222    pub fn new(client: tokio_postgres::Client) -> Self {
223        Self { inner: client }
224    }
225
226    /// Get a reference to the underlying `tokio_postgres::Client`.
227    pub fn inner(&self) -> &tokio_postgres::Client {
228        &self.inner
229    }
230
231    /// Consume this wrapper and return the underlying `tokio_postgres::Client`.
232    pub fn into_inner(self) -> tokio_postgres::Client {
233        self.inner
234    }
235}
236
237impl QueryExecutor for PgClient {
238    type Error = tokio_postgres::Error;
239    type Row = PgRow;
240
241    fn query(
242        &self,
243        sql: &str,
244        params: &[&dyn ToSqlParam],
245    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
246        let owned_params = convert_params(params);
247        let sql = sql.to_owned();
248        let client = &self.inner;
249        async move {
250            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
251                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
252            let rows = client.query(&sql, &param_refs).await?;
253            Ok(rows.into_iter().map(PgRow).collect())
254        }
255    }
256
257    fn execute(
258        &self,
259        sql: &str,
260        params: &[&dyn ToSqlParam],
261    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
262        let owned_params = convert_params(params);
263        let sql = sql.to_owned();
264        let client = &self.inner;
265        async move {
266            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
267                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
268            client.execute(&sql, &param_refs).await
269        }
270    }
271}
272
273impl Client for PgClient {
274    type Transaction<'a> = PgTransaction<'a>;
275
276    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
277        async move {
278            let tx = self.inner.transaction().await?;
279            Ok(PgTransaction { inner: tx })
280        }
281    }
282}
283
284// ---- Transaction wrapper ----
285
286/// A wrapper around `tokio_postgres::Transaction` implementing the NextSQL `Transaction` trait.
287/// Supports nested transactions (savepoints) and commit/rollback.
288/// Drop without calling `commit()` will rollback the transaction.
289pub struct PgTransaction<'a> {
290    inner: tokio_postgres::Transaction<'a>,
291}
292
293impl<'a> PgTransaction<'a> {
294    pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
295        Self { inner: tx }
296    }
297}
298
299impl QueryExecutor for PgTransaction<'_> {
300    type Error = tokio_postgres::Error;
301    type Row = PgRow;
302
303    fn query(
304        &self,
305        sql: &str,
306        params: &[&dyn ToSqlParam],
307    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
308        let owned_params = convert_params(params);
309        let sql = sql.to_owned();
310        let client = &self.inner;
311        async move {
312            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
313                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
314            let rows = client.query(&sql, &param_refs).await?;
315            Ok(rows.into_iter().map(PgRow).collect())
316        }
317    }
318
319    fn execute(
320        &self,
321        sql: &str,
322        params: &[&dyn ToSqlParam],
323    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
324        let owned_params = convert_params(params);
325        let sql = sql.to_owned();
326        let client = &self.inner;
327        async move {
328            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
329                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
330            client.execute(&sql, &param_refs).await
331        }
332    }
333}
334
335impl Transaction for PgTransaction<'_> {
336    type Nested<'a> = PgTransaction<'a> where Self: 'a;
337
338    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
339        async move {
340            let tx = self.inner.transaction().await?;
341            Ok(PgTransaction { inner: tx })
342        }
343    }
344
345    fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
346        async move { self.inner.commit().await }
347    }
348
349    fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
350        async move { self.inner.rollback().await }
351    }
352}