Skip to main content

nextsql_backend_rust_runtime/
tokio_postgres_impl.rs

1use crate::{Client, QueryExecutor, Row, ToSqlParam, Transaction};
2
3// ---- Row implementation wrapping tokio_postgres::Row ----
4
5pub struct PgRow(pub tokio_postgres::Row);
6
7impl Row for PgRow {
8    fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
9    fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
10    fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
11    fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
12    fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
13    fn get_string(&self, idx: usize) -> String { self.0.get(idx) }
14    fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
15    fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
16    fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
17    fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
18    fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
19    fn get_decimal(&self, idx: usize) -> rust_decimal::Decimal { self.0.get(idx) }
20    fn get_json(&self, idx: usize) -> serde_json::Value { self.0.get(idx) }
21
22    fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
23    fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
24    fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
25    fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
26    fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
27    fn get_opt_string(&self, idx: usize) -> Option<String> { self.0.get(idx) }
28    fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
29    fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
30    fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
31    fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
32    fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
33    fn get_opt_decimal(&self, idx: usize) -> Option<rust_decimal::Decimal> { self.0.get(idx) }
34    fn get_opt_json(&self, idx: usize) -> Option<serde_json::Value> { self.0.get(idx) }
35
36    fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
37    fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
38    fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
39    fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
40    fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
41    fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
42    fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
43    fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
44}
45
46// ---- Owned parameter enum for crossing async boundaries ----
47
48/// An owned SQL parameter value. Used internally to convert borrowed `&dyn ToSqlParam`
49/// into owned values that can be moved into async futures.
50#[derive(Debug)]
51pub enum OwnedParam {
52    I16(i16),
53    I32(i32),
54    I64(i64),
55    F32(f32),
56    F64(f64),
57    Bool(bool),
58    String(String),
59    Uuid(uuid::Uuid),
60    NaiveDateTime(chrono::NaiveDateTime),
61    DateTimeUtc(chrono::DateTime<chrono::Utc>),
62    NaiveDate(chrono::NaiveDate),
63    Decimal(rust_decimal::Decimal),
64    Json(serde_json::Value),
65    OptI16(Option<i16>),
66    OptI32(Option<i32>),
67    OptI64(Option<i64>),
68    OptF32(Option<f32>),
69    OptF64(Option<f64>),
70    OptBool(Option<bool>),
71    OptString(Option<String>),
72    OptUuid(Option<uuid::Uuid>),
73    OptNaiveDateTime(Option<chrono::NaiveDateTime>),
74    OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
75    OptNaiveDate(Option<chrono::NaiveDate>),
76    OptDecimal(Option<rust_decimal::Decimal>),
77    OptJson(Option<serde_json::Value>),
78    VecI16(Vec<i16>),
79    VecI32(Vec<i32>),
80    VecI64(Vec<i64>),
81    VecF32(Vec<f32>),
82    VecF64(Vec<f64>),
83    VecBool(Vec<bool>),
84    VecString(Vec<String>),
85    VecUuid(Vec<uuid::Uuid>),
86}
87
88impl tokio_postgres::types::ToSql for OwnedParam {
89    fn to_sql(
90        &self,
91        ty: &tokio_postgres::types::Type,
92        out: &mut bytes::BytesMut,
93    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
94        match self {
95            OwnedParam::I16(v) => v.to_sql(ty, out),
96            OwnedParam::I32(v) => v.to_sql(ty, out),
97            OwnedParam::I64(v) => v.to_sql(ty, out),
98            OwnedParam::F32(v) => v.to_sql(ty, out),
99            OwnedParam::F64(v) => v.to_sql(ty, out),
100            OwnedParam::Bool(v) => v.to_sql(ty, out),
101            OwnedParam::String(v) => v.to_sql(ty, out),
102            OwnedParam::Uuid(v) => v.to_sql(ty, out),
103            OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
104            OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
105            OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
106            OwnedParam::Decimal(v) => v.to_sql(ty, out),
107            OwnedParam::Json(v) => v.to_sql(ty, out),
108            OwnedParam::OptI16(v) => v.to_sql(ty, out),
109            OwnedParam::OptI32(v) => v.to_sql(ty, out),
110            OwnedParam::OptI64(v) => v.to_sql(ty, out),
111            OwnedParam::OptF32(v) => v.to_sql(ty, out),
112            OwnedParam::OptF64(v) => v.to_sql(ty, out),
113            OwnedParam::OptBool(v) => v.to_sql(ty, out),
114            OwnedParam::OptString(v) => v.to_sql(ty, out),
115            OwnedParam::OptUuid(v) => v.to_sql(ty, out),
116            OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
117            OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
118            OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
119            OwnedParam::OptDecimal(v) => v.to_sql(ty, out),
120            OwnedParam::OptJson(v) => v.to_sql(ty, out),
121            OwnedParam::VecI16(v) => v.to_sql(ty, out),
122            OwnedParam::VecI32(v) => v.to_sql(ty, out),
123            OwnedParam::VecI64(v) => v.to_sql(ty, out),
124            OwnedParam::VecF32(v) => v.to_sql(ty, out),
125            OwnedParam::VecF64(v) => v.to_sql(ty, out),
126            OwnedParam::VecBool(v) => v.to_sql(ty, out),
127            OwnedParam::VecString(v) => v.to_sql(ty, out),
128            OwnedParam::VecUuid(v) => v.to_sql(ty, out),
129        }
130    }
131
132    fn accepts(ty: &tokio_postgres::types::Type) -> bool {
133        i16::accepts(ty)
134            || i32::accepts(ty)
135            || i64::accepts(ty)
136            || f32::accepts(ty)
137            || f64::accepts(ty)
138            || bool::accepts(ty)
139            || String::accepts(ty)
140            || uuid::Uuid::accepts(ty)
141            || chrono::NaiveDateTime::accepts(ty)
142            || chrono::DateTime::<chrono::Utc>::accepts(ty)
143            || chrono::NaiveDate::accepts(ty)
144            || rust_decimal::Decimal::accepts(ty)
145            || serde_json::Value::accepts(ty)
146            || <Option<i16>>::accepts(ty)
147            || <Option<i32>>::accepts(ty)
148            || <Option<i64>>::accepts(ty)
149            || <Option<f32>>::accepts(ty)
150            || <Option<f64>>::accepts(ty)
151            || <Option<bool>>::accepts(ty)
152            || <Option<String>>::accepts(ty)
153            || <Option<uuid::Uuid>>::accepts(ty)
154            || <Option<chrono::NaiveDateTime>>::accepts(ty)
155            || <Option<chrono::DateTime<chrono::Utc>>>::accepts(ty)
156            || <Option<chrono::NaiveDate>>::accepts(ty)
157            || <Option<rust_decimal::Decimal>>::accepts(ty)
158            || <Option<serde_json::Value>>::accepts(ty)
159            || <Vec<i16>>::accepts(ty)
160            || <Vec<i32>>::accepts(ty)
161            || <Vec<i64>>::accepts(ty)
162            || <Vec<f32>>::accepts(ty)
163            || <Vec<f64>>::accepts(ty)
164            || <Vec<bool>>::accepts(ty)
165            || <Vec<String>>::accepts(ty)
166            || <Vec<uuid::Uuid>>::accepts(ty)
167    }
168
169    tokio_postgres::types::to_sql_checked!();
170}
171
172/// Convert a `&dyn ToSqlParam` to an owned parameter value by downcasting.
173pub fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
174    let any = param.as_any();
175    // Non-optional types
176    if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
177    if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
178    if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
179    if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
180    if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
181    if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
182    if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
183    if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
184    if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
185    if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
186    if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
187    if let Some(v) = any.downcast_ref::<rust_decimal::Decimal>() { return OwnedParam::Decimal(*v); }
188    if let Some(v) = any.downcast_ref::<serde_json::Value>() { return OwnedParam::Json(v.clone()); }
189    // Optional types
190    if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
191    if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
192    if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
193    if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
194    if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
195    if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
196    if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
197    if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
198    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
199    if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
200    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
201    if let Some(v) = any.downcast_ref::<Option<rust_decimal::Decimal>>() { return OwnedParam::OptDecimal(*v); }
202    if let Some(v) = any.downcast_ref::<Option<serde_json::Value>>() { return OwnedParam::OptJson(v.clone()); }
203    // Vec types
204    if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
205    if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
206    if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
207    if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
208    if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
209    if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
210    if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
211    if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
212    panic!("Unsupported parameter type for tokio-postgres backend");
213}
214
215/// Convert a slice of borrowed params to a vec of owned params.
216pub fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
217    params.iter().map(|p| to_owned_param(*p)).collect()
218}
219
220// ---- Client wrapper ----
221
222/// A wrapper around `tokio_postgres::Client` implementing the NextSQL `Client` trait.
223pub struct PgClient {
224    inner: tokio_postgres::Client,
225}
226
227impl PgClient {
228    pub fn new(client: tokio_postgres::Client) -> Self {
229        Self { inner: client }
230    }
231
232    /// Get a reference to the underlying `tokio_postgres::Client`.
233    pub fn inner(&self) -> &tokio_postgres::Client {
234        &self.inner
235    }
236
237    /// Consume this wrapper and return the underlying `tokio_postgres::Client`.
238    pub fn into_inner(self) -> tokio_postgres::Client {
239        self.inner
240    }
241}
242
243impl QueryExecutor for PgClient {
244    type Error = tokio_postgres::Error;
245    type Row = PgRow;
246
247    fn query(
248        &self,
249        sql: &str,
250        params: &[&dyn ToSqlParam],
251    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
252        let owned_params = convert_params(params);
253        let sql = sql.to_owned();
254        let client = &self.inner;
255        async move {
256            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
257                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
258            let rows = client.query(&sql, &param_refs).await?;
259            Ok(rows.into_iter().map(PgRow).collect())
260        }
261    }
262
263    fn execute(
264        &self,
265        sql: &str,
266        params: &[&dyn ToSqlParam],
267    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
268        let owned_params = convert_params(params);
269        let sql = sql.to_owned();
270        let client = &self.inner;
271        async move {
272            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
273                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
274            client.execute(&sql, &param_refs).await
275        }
276    }
277}
278
279impl Client for PgClient {
280    type Transaction<'a> = PgTransaction<'a>;
281
282    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Transaction<'_>, Self::Error>> + Send {
283        async move {
284            let tx = self.inner.transaction().await?;
285            Ok(PgTransaction { inner: tx })
286        }
287    }
288}
289
290// ---- Transaction wrapper ----
291
292/// A wrapper around `tokio_postgres::Transaction` implementing the NextSQL `Transaction` trait.
293/// Supports nested transactions (savepoints) and commit/rollback.
294/// Drop without calling `commit()` will rollback the transaction.
295pub struct PgTransaction<'a> {
296    inner: tokio_postgres::Transaction<'a>,
297}
298
299impl<'a> PgTransaction<'a> {
300    pub fn new(tx: tokio_postgres::Transaction<'a>) -> Self {
301        Self { inner: tx }
302    }
303}
304
305impl QueryExecutor for PgTransaction<'_> {
306    type Error = tokio_postgres::Error;
307    type Row = PgRow;
308
309    fn query(
310        &self,
311        sql: &str,
312        params: &[&dyn ToSqlParam],
313    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
314        let owned_params = convert_params(params);
315        let sql = sql.to_owned();
316        let client = &self.inner;
317        async move {
318            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
319                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
320            let rows = client.query(&sql, &param_refs).await?;
321            Ok(rows.into_iter().map(PgRow).collect())
322        }
323    }
324
325    fn execute(
326        &self,
327        sql: &str,
328        params: &[&dyn ToSqlParam],
329    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
330        let owned_params = convert_params(params);
331        let sql = sql.to_owned();
332        let client = &self.inner;
333        async move {
334            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
335                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
336            client.execute(&sql, &param_refs).await
337        }
338    }
339}
340
341impl Transaction for PgTransaction<'_> {
342    type Nested<'a> = PgTransaction<'a> where Self: 'a;
343
344    fn transaction(&mut self) -> impl std::future::Future<Output = Result<Self::Nested<'_>, Self::Error>> + Send {
345        async move {
346            let tx = self.inner.transaction().await?;
347            Ok(PgTransaction { inner: tx })
348        }
349    }
350
351    fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
352        async move { self.inner.commit().await }
353    }
354
355    fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send {
356        async move { self.inner.rollback().await }
357    }
358}