Skip to main content

nextsql_backend_rust_runtime/
tokio_postgres_impl.rs

1use crate::{Client, Row, ToSqlParam};
2
3// ---- ToSqlParam implementations for primitive types ----
4
5macro_rules! impl_to_sql_param {
6    ($($ty:ty),*) => {
7        $(
8            impl ToSqlParam for $ty {
9                fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
10                    self
11                }
12            }
13        )*
14    };
15}
16
17impl_to_sql_param!(
18    i16, i32, i64, f32, f64, bool, String,
19    uuid::Uuid, chrono::NaiveDateTime, chrono::NaiveDate
20);
21
22impl ToSqlParam for chrono::DateTime<chrono::Utc> {
23    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
24        self
25    }
26}
27
28impl ToSqlParam for &'static str {
29    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
30        self
31    }
32}
33
34impl<T: ToSqlParam + 'static> ToSqlParam for Option<T> {
35    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
36        self
37    }
38}
39
40impl<T: ToSqlParam + 'static> ToSqlParam for Vec<T> {
41    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
42        self
43    }
44}
45
46// ---- Row implementation wrapping tokio_postgres::Row ----
47
48pub struct PgRow(pub tokio_postgres::Row);
49
50impl Row for PgRow {
51    fn get_i16(&self, idx: usize) -> i16 { self.0.get(idx) }
52    fn get_i32(&self, idx: usize) -> i32 { self.0.get(idx) }
53    fn get_i64(&self, idx: usize) -> i64 { self.0.get(idx) }
54    fn get_f32(&self, idx: usize) -> f32 { self.0.get(idx) }
55    fn get_f64(&self, idx: usize) -> f64 { self.0.get(idx) }
56    fn get_string(&self, idx: usize) -> String { self.0.get(idx) }
57    fn get_bool(&self, idx: usize) -> bool { self.0.get(idx) }
58    fn get_uuid(&self, idx: usize) -> uuid::Uuid { self.0.get(idx) }
59    fn get_timestamp(&self, idx: usize) -> chrono::NaiveDateTime { self.0.get(idx) }
60    fn get_timestamptz(&self, idx: usize) -> chrono::DateTime<chrono::Utc> { self.0.get(idx) }
61    fn get_date(&self, idx: usize) -> chrono::NaiveDate { self.0.get(idx) }
62
63    fn get_opt_i16(&self, idx: usize) -> Option<i16> { self.0.get(idx) }
64    fn get_opt_i32(&self, idx: usize) -> Option<i32> { self.0.get(idx) }
65    fn get_opt_i64(&self, idx: usize) -> Option<i64> { self.0.get(idx) }
66    fn get_opt_f32(&self, idx: usize) -> Option<f32> { self.0.get(idx) }
67    fn get_opt_f64(&self, idx: usize) -> Option<f64> { self.0.get(idx) }
68    fn get_opt_string(&self, idx: usize) -> Option<String> { self.0.get(idx) }
69    fn get_opt_bool(&self, idx: usize) -> Option<bool> { self.0.get(idx) }
70    fn get_opt_uuid(&self, idx: usize) -> Option<uuid::Uuid> { self.0.get(idx) }
71    fn get_opt_timestamp(&self, idx: usize) -> Option<chrono::NaiveDateTime> { self.0.get(idx) }
72    fn get_opt_timestamptz(&self, idx: usize) -> Option<chrono::DateTime<chrono::Utc>> { self.0.get(idx) }
73    fn get_opt_date(&self, idx: usize) -> Option<chrono::NaiveDate> { self.0.get(idx) }
74
75    fn get_vec_i16(&self, idx: usize) -> Vec<i16> { self.0.get(idx) }
76    fn get_vec_i32(&self, idx: usize) -> Vec<i32> { self.0.get(idx) }
77    fn get_vec_i64(&self, idx: usize) -> Vec<i64> { self.0.get(idx) }
78    fn get_vec_f32(&self, idx: usize) -> Vec<f32> { self.0.get(idx) }
79    fn get_vec_f64(&self, idx: usize) -> Vec<f64> { self.0.get(idx) }
80    fn get_vec_string(&self, idx: usize) -> Vec<String> { self.0.get(idx) }
81    fn get_vec_bool(&self, idx: usize) -> Vec<bool> { self.0.get(idx) }
82    fn get_vec_uuid(&self, idx: usize) -> Vec<uuid::Uuid> { self.0.get(idx) }
83}
84
85// ---- Owned parameter enum for crossing async boundaries ----
86
87/// An owned SQL parameter value. Used internally to convert borrowed `&dyn ToSqlParam`
88/// into owned values that can be moved into async futures.
89#[derive(Debug)]
90enum OwnedParam {
91    I16(i16),
92    I32(i32),
93    I64(i64),
94    F32(f32),
95    F64(f64),
96    Bool(bool),
97    String(String),
98    Uuid(uuid::Uuid),
99    NaiveDateTime(chrono::NaiveDateTime),
100    DateTimeUtc(chrono::DateTime<chrono::Utc>),
101    NaiveDate(chrono::NaiveDate),
102    OptI16(Option<i16>),
103    OptI32(Option<i32>),
104    OptI64(Option<i64>),
105    OptF32(Option<f32>),
106    OptF64(Option<f64>),
107    OptBool(Option<bool>),
108    OptString(Option<String>),
109    OptUuid(Option<uuid::Uuid>),
110    OptNaiveDateTime(Option<chrono::NaiveDateTime>),
111    OptDateTimeUtc(Option<chrono::DateTime<chrono::Utc>>),
112    OptNaiveDate(Option<chrono::NaiveDate>),
113    VecI16(Vec<i16>),
114    VecI32(Vec<i32>),
115    VecI64(Vec<i64>),
116    VecF32(Vec<f32>),
117    VecF64(Vec<f64>),
118    VecBool(Vec<bool>),
119    VecString(Vec<String>),
120    VecUuid(Vec<uuid::Uuid>),
121}
122
123impl tokio_postgres::types::ToSql for OwnedParam {
124    fn to_sql(
125        &self,
126        ty: &tokio_postgres::types::Type,
127        out: &mut bytes::BytesMut,
128    ) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
129        match self {
130            OwnedParam::I16(v) => v.to_sql(ty, out),
131            OwnedParam::I32(v) => v.to_sql(ty, out),
132            OwnedParam::I64(v) => v.to_sql(ty, out),
133            OwnedParam::F32(v) => v.to_sql(ty, out),
134            OwnedParam::F64(v) => v.to_sql(ty, out),
135            OwnedParam::Bool(v) => v.to_sql(ty, out),
136            OwnedParam::String(v) => v.to_sql(ty, out),
137            OwnedParam::Uuid(v) => v.to_sql(ty, out),
138            OwnedParam::NaiveDateTime(v) => v.to_sql(ty, out),
139            OwnedParam::DateTimeUtc(v) => v.to_sql(ty, out),
140            OwnedParam::NaiveDate(v) => v.to_sql(ty, out),
141            OwnedParam::OptI16(v) => v.to_sql(ty, out),
142            OwnedParam::OptI32(v) => v.to_sql(ty, out),
143            OwnedParam::OptI64(v) => v.to_sql(ty, out),
144            OwnedParam::OptF32(v) => v.to_sql(ty, out),
145            OwnedParam::OptF64(v) => v.to_sql(ty, out),
146            OwnedParam::OptBool(v) => v.to_sql(ty, out),
147            OwnedParam::OptString(v) => v.to_sql(ty, out),
148            OwnedParam::OptUuid(v) => v.to_sql(ty, out),
149            OwnedParam::OptNaiveDateTime(v) => v.to_sql(ty, out),
150            OwnedParam::OptDateTimeUtc(v) => v.to_sql(ty, out),
151            OwnedParam::OptNaiveDate(v) => v.to_sql(ty, out),
152            OwnedParam::VecI16(v) => v.to_sql(ty, out),
153            OwnedParam::VecI32(v) => v.to_sql(ty, out),
154            OwnedParam::VecI64(v) => v.to_sql(ty, out),
155            OwnedParam::VecF32(v) => v.to_sql(ty, out),
156            OwnedParam::VecF64(v) => v.to_sql(ty, out),
157            OwnedParam::VecBool(v) => v.to_sql(ty, out),
158            OwnedParam::VecString(v) => v.to_sql(ty, out),
159            OwnedParam::VecUuid(v) => v.to_sql(ty, out),
160        }
161    }
162
163    fn accepts(ty: &tokio_postgres::types::Type) -> bool {
164        i16::accepts(ty)
165            || i32::accepts(ty)
166            || i64::accepts(ty)
167            || f32::accepts(ty)
168            || f64::accepts(ty)
169            || bool::accepts(ty)
170            || String::accepts(ty)
171            || uuid::Uuid::accepts(ty)
172            || chrono::NaiveDateTime::accepts(ty)
173            || chrono::DateTime::<chrono::Utc>::accepts(ty)
174            || chrono::NaiveDate::accepts(ty)
175            || <Option<i16>>::accepts(ty)
176            || <Option<i32>>::accepts(ty)
177            || <Option<i64>>::accepts(ty)
178            || <Option<f32>>::accepts(ty)
179            || <Option<f64>>::accepts(ty)
180            || <Option<bool>>::accepts(ty)
181            || <Option<String>>::accepts(ty)
182            || <Option<uuid::Uuid>>::accepts(ty)
183            || <Option<chrono::NaiveDateTime>>::accepts(ty)
184            || <Option<chrono::DateTime<chrono::Utc>>>::accepts(ty)
185            || <Option<chrono::NaiveDate>>::accepts(ty)
186            || <Vec<i16>>::accepts(ty)
187            || <Vec<i32>>::accepts(ty)
188            || <Vec<i64>>::accepts(ty)
189            || <Vec<f32>>::accepts(ty)
190            || <Vec<f64>>::accepts(ty)
191            || <Vec<bool>>::accepts(ty)
192            || <Vec<String>>::accepts(ty)
193            || <Vec<uuid::Uuid>>::accepts(ty)
194    }
195
196    tokio_postgres::types::to_sql_checked!();
197}
198
199/// Convert a `&dyn ToSqlParam` to an owned parameter value by downcasting.
200fn to_owned_param(param: &dyn ToSqlParam) -> OwnedParam {
201    let any = param.as_any();
202    // Non-optional types
203    if let Some(v) = any.downcast_ref::<i16>() { return OwnedParam::I16(*v); }
204    if let Some(v) = any.downcast_ref::<i32>() { return OwnedParam::I32(*v); }
205    if let Some(v) = any.downcast_ref::<i64>() { return OwnedParam::I64(*v); }
206    if let Some(v) = any.downcast_ref::<f32>() { return OwnedParam::F32(*v); }
207    if let Some(v) = any.downcast_ref::<f64>() { return OwnedParam::F64(*v); }
208    if let Some(v) = any.downcast_ref::<bool>() { return OwnedParam::Bool(*v); }
209    if let Some(v) = any.downcast_ref::<String>() { return OwnedParam::String(v.clone()); }
210    if let Some(v) = any.downcast_ref::<uuid::Uuid>() { return OwnedParam::Uuid(*v); }
211    if let Some(v) = any.downcast_ref::<chrono::NaiveDateTime>() { return OwnedParam::NaiveDateTime(*v); }
212    if let Some(v) = any.downcast_ref::<chrono::DateTime<chrono::Utc>>() { return OwnedParam::DateTimeUtc(*v); }
213    if let Some(v) = any.downcast_ref::<chrono::NaiveDate>() { return OwnedParam::NaiveDate(*v); }
214    // Optional types
215    if let Some(v) = any.downcast_ref::<Option<i16>>() { return OwnedParam::OptI16(*v); }
216    if let Some(v) = any.downcast_ref::<Option<i32>>() { return OwnedParam::OptI32(*v); }
217    if let Some(v) = any.downcast_ref::<Option<i64>>() { return OwnedParam::OptI64(*v); }
218    if let Some(v) = any.downcast_ref::<Option<f32>>() { return OwnedParam::OptF32(*v); }
219    if let Some(v) = any.downcast_ref::<Option<f64>>() { return OwnedParam::OptF64(*v); }
220    if let Some(v) = any.downcast_ref::<Option<bool>>() { return OwnedParam::OptBool(*v); }
221    if let Some(v) = any.downcast_ref::<Option<String>>() { return OwnedParam::OptString(v.clone()); }
222    if let Some(v) = any.downcast_ref::<Option<uuid::Uuid>>() { return OwnedParam::OptUuid(*v); }
223    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDateTime>>() { return OwnedParam::OptNaiveDateTime(*v); }
224    if let Some(v) = any.downcast_ref::<Option<chrono::DateTime<chrono::Utc>>>() { return OwnedParam::OptDateTimeUtc(*v); }
225    if let Some(v) = any.downcast_ref::<Option<chrono::NaiveDate>>() { return OwnedParam::OptNaiveDate(*v); }
226    // Vec types
227    if let Some(v) = any.downcast_ref::<Vec<i16>>() { return OwnedParam::VecI16(v.clone()); }
228    if let Some(v) = any.downcast_ref::<Vec<i32>>() { return OwnedParam::VecI32(v.clone()); }
229    if let Some(v) = any.downcast_ref::<Vec<i64>>() { return OwnedParam::VecI64(v.clone()); }
230    if let Some(v) = any.downcast_ref::<Vec<f32>>() { return OwnedParam::VecF32(v.clone()); }
231    if let Some(v) = any.downcast_ref::<Vec<f64>>() { return OwnedParam::VecF64(v.clone()); }
232    if let Some(v) = any.downcast_ref::<Vec<bool>>() { return OwnedParam::VecBool(v.clone()); }
233    if let Some(v) = any.downcast_ref::<Vec<String>>() { return OwnedParam::VecString(v.clone()); }
234    if let Some(v) = any.downcast_ref::<Vec<uuid::Uuid>>() { return OwnedParam::VecUuid(v.clone()); }
235    panic!("Unsupported parameter type for tokio-postgres backend");
236}
237
238/// Convert a slice of borrowed params to a vec of owned params.
239fn convert_params(params: &[&dyn ToSqlParam]) -> Vec<OwnedParam> {
240    params.iter().map(|p| to_owned_param(*p)).collect()
241}
242
243// ---- Client wrapper ----
244
245/// A wrapper around `tokio_postgres::Client` implementing the NextSQL `Client` trait.
246pub struct PgClient {
247    inner: tokio_postgres::Client,
248}
249
250impl PgClient {
251    pub fn new(client: tokio_postgres::Client) -> Self {
252        Self { inner: client }
253    }
254
255    /// Get a reference to the underlying `tokio_postgres::Client`.
256    pub fn inner(&self) -> &tokio_postgres::Client {
257        &self.inner
258    }
259
260    /// Consume this wrapper and return the underlying `tokio_postgres::Client`.
261    pub fn into_inner(self) -> tokio_postgres::Client {
262        self.inner
263    }
264}
265
266impl Client for PgClient {
267    type Error = tokio_postgres::Error;
268    type Row = PgRow;
269
270    fn query(
271        &self,
272        sql: &str,
273        params: &[&dyn ToSqlParam],
274    ) -> impl std::future::Future<Output = Result<Vec<Self::Row>, Self::Error>> + Send {
275        let owned_params = convert_params(params);
276        let sql = sql.to_owned();
277        let client = &self.inner;
278        async move {
279            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
280                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
281            let rows = client.query(&sql, &param_refs).await?;
282            Ok(rows.into_iter().map(PgRow).collect())
283        }
284    }
285
286    fn execute(
287        &self,
288        sql: &str,
289        params: &[&dyn ToSqlParam],
290    ) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send {
291        let owned_params = convert_params(params);
292        let sql = sql.to_owned();
293        let client = &self.inner;
294        async move {
295            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
296                owned_params.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
297            client.execute(&sql, &param_refs).await
298        }
299    }
300}