Skip to main content

prax_postgres/
types.rs

1//! Type conversions for PostgreSQL.
2
3use std::str::FromStr;
4
5use prax_query::filter::FilterValue;
6use tokio_postgres::types::{IsNull, Kind, ToSql, Type};
7
8use crate::error::{PgError, PgResult};
9
10/// Polymorphic integer binding. `FilterValue::Int` always carries an i64
11/// (the widest scalar variant), but Postgres strictly validates client
12/// bindings against column types: binding an i64 to an `INT4` column
13/// fails with `WrongType { postgres: Int4, rust: "i64" }`. This wrapper
14/// inspects the target column type at bind time and narrows to i16 /
15/// i32 / i64 with a bounds check before forwarding to tokio-postgres'
16/// own impls.
17#[derive(Debug)]
18struct PgInt(i64);
19
20impl ToSql for PgInt {
21    fn to_sql(
22        &self,
23        ty: &Type,
24        out: &mut bytes::BytesMut,
25    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
26        match *ty {
27            Type::INT2 => {
28                let v: i16 = self
29                    .0
30                    .try_into()
31                    .map_err(|_| format!("value {} overflows INT2", self.0))?;
32                v.to_sql(ty, out)
33            }
34            Type::INT4 => {
35                let v: i32 = self
36                    .0
37                    .try_into()
38                    .map_err(|_| format!("value {} overflows INT4", self.0))?;
39                v.to_sql(ty, out)
40            }
41            Type::INT8 => self.0.to_sql(ty, out),
42            _ => Err(format!("cannot bind integer to postgres type {ty:?}").into()),
43        }
44    }
45
46    fn accepts(ty: &Type) -> bool {
47        matches!(*ty, Type::INT2 | Type::INT4 | Type::INT8)
48    }
49
50    tokio_postgres::types::to_sql_checked!();
51}
52
53/// Polymorphic string binding. `FilterValue::String` always carries a
54/// Rust `String`, but Postgres rejects a String bound to a UUID column
55/// (`WrongType { postgres: Uuid, rust: "alloc::string::String" }`).
56/// This wrapper inspects the target column type at bind time and
57/// converts as appropriate; for plain TEXT/VARCHAR/CHAR/NAME and other
58/// types it forwards the String directly.
59#[derive(Debug)]
60struct PgString(String);
61
62impl ToSql for PgString {
63    fn to_sql(
64        &self,
65        ty: &Type,
66        out: &mut bytes::BytesMut,
67    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
68        match *ty {
69            Type::UUID => {
70                let parsed = ::uuid::Uuid::from_str(&self.0).map_err(|e| {
71                    format!("FilterValue::String('{}') is not a valid UUID: {e}", self.0)
72                })?;
73                parsed.to_sql(ty, out)
74            }
75            // chrono types round-trip through `FilterValue::String`
76            // (see `prax-query/src/filter.rs`). Postgres rejects a String
77            // bound to TIMESTAMPTZ/TIMESTAMP/DATE/TIME with WrongType,
78            // so re-parse and forward through tokio-postgres' chrono
79            // FromSql/ToSql impls.
80            Type::TIMESTAMPTZ => {
81                let parsed: ::chrono::DateTime<::chrono::Utc> =
82                    ::chrono::DateTime::parse_from_rfc3339(&self.0)
83                        .map_err(|e| {
84                            format!(
85                                "FilterValue::String('{}') is not a valid RFC3339 \
86                                 datetime for TIMESTAMPTZ: {e}",
87                                self.0
88                            )
89                        })?
90                        .with_timezone(&::chrono::Utc);
91                parsed.to_sql(ty, out)
92            }
93            Type::TIMESTAMP => {
94                let parsed =
95                    ::chrono::NaiveDateTime::parse_from_str(&self.0, "%Y-%m-%dT%H:%M:%S%.f")
96                        .map_err(|e| {
97                            format!(
98                                "FilterValue::String('{}') is not a valid \
99                         ISO-8601 naive datetime for TIMESTAMP: {e}",
100                                self.0
101                            )
102                        })?;
103                parsed.to_sql(ty, out)
104            }
105            Type::DATE => {
106                let parsed =
107                    ::chrono::NaiveDate::parse_from_str(&self.0, "%Y-%m-%d").map_err(|e| {
108                        format!(
109                            "FilterValue::String('{}') is not a valid \
110                             YYYY-MM-DD date for DATE: {e}",
111                            self.0
112                        )
113                    })?;
114                parsed.to_sql(ty, out)
115            }
116            Type::TIME => {
117                let parsed =
118                    ::chrono::NaiveTime::parse_from_str(&self.0, "%H:%M:%S%.f").map_err(|e| {
119                        format!(
120                            "FilterValue::String('{}') is not a valid \
121                                 HH:MM:SS time for TIME: {e}",
122                            self.0
123                        )
124                    })?;
125                parsed.to_sql(ty, out)
126            }
127            _ => {
128                // User-defined `ENUM` columns reach this arm; their
129                // wire format is just utf-8 bytes, so write the
130                // string body directly. Plain TEXT/VARCHAR/CHAR/NAME
131                // also takes this path through `String: ToSql`.
132                if matches!(ty.kind(), Kind::Enum(_)) {
133                    out.extend_from_slice(self.0.as_bytes());
134                    Ok(IsNull::No)
135                } else {
136                    self.0.to_sql(ty, out)
137                }
138            }
139        }
140    }
141
142    fn accepts(_ty: &Type) -> bool {
143        true
144    }
145
146    tokio_postgres::types::to_sql_checked!();
147}
148
149/// Convert a FilterValue to a type that can be used as a PostgreSQL parameter.
150pub fn filter_value_to_sql(value: &FilterValue) -> PgResult<Box<dyn ToSql + Sync + Send>> {
151    match value {
152        FilterValue::Null => Ok(Box::new(Option::<String>::None)),
153        FilterValue::Bool(b) => Ok(Box::new(*b)),
154        FilterValue::Int(i) => Ok(Box::new(PgInt(*i))),
155        FilterValue::Float(f) => Ok(Box::new(*f)),
156        FilterValue::String(s) => Ok(Box::new(PgString(s.clone()))),
157        FilterValue::Json(j) => Ok(Box::new(j.clone())),
158        FilterValue::List(_) => {
159            // Lists need special handling - they should be converted to arrays
160            // For now, return an error and handle lists specially in the engine
161            Err(PgError::type_conversion(
162                "list values should be handled specially",
163            ))
164        }
165    }
166}
167
168/// Convert filter values to PostgreSQL parameters.
169pub fn filter_values_to_params(
170    values: &[FilterValue],
171) -> PgResult<Vec<Box<dyn ToSql + Sync + Send>>> {
172    values.iter().map(filter_value_to_sql).collect()
173}
174
175/// PostgreSQL type mapping utilities.
176pub mod pg_types {
177    use super::*;
178
179    /// Get the PostgreSQL type for a Rust type name.
180    pub fn rust_type_to_pg(rust_type: &str) -> Option<Type> {
181        match rust_type {
182            "i16" => Some(Type::INT2),
183            "i32" => Some(Type::INT4),
184            "i64" => Some(Type::INT8),
185            "f32" => Some(Type::FLOAT4),
186            "f64" => Some(Type::FLOAT8),
187            "bool" => Some(Type::BOOL),
188            "String" | "&str" => Some(Type::TEXT),
189            "Vec<u8>" | "&[u8]" => Some(Type::BYTEA),
190            "chrono::NaiveDate" => Some(Type::DATE),
191            "chrono::NaiveTime" => Some(Type::TIME),
192            "chrono::NaiveDateTime" => Some(Type::TIMESTAMP),
193            "chrono::DateTime<chrono::Utc>" => Some(Type::TIMESTAMPTZ),
194            "uuid::Uuid" => Some(Type::UUID),
195            "serde_json::Value" => Some(Type::JSONB),
196            _ => None,
197        }
198    }
199
200    /// Get the Rust type for a PostgreSQL type.
201    pub fn pg_type_to_rust(pg_type: &Type) -> &'static str {
202        match *pg_type {
203            Type::BOOL => "bool",
204            Type::INT2 => "i16",
205            Type::INT4 => "i32",
206            Type::INT8 => "i64",
207            Type::FLOAT4 => "f32",
208            Type::FLOAT8 => "f64",
209            Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => "String",
210            Type::BYTEA => "Vec<u8>",
211            Type::DATE => "chrono::NaiveDate",
212            Type::TIME => "chrono::NaiveTime",
213            Type::TIMESTAMP => "chrono::NaiveDateTime",
214            Type::TIMESTAMPTZ => "chrono::DateTime<chrono::Utc>",
215            Type::UUID => "uuid::Uuid",
216            Type::JSON | Type::JSONB => "serde_json::Value",
217            _ => "unknown",
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_value_to_sql() {
        // Every scalar variant converts into a parameter.
        for value in &[
            FilterValue::Null,
            FilterValue::Bool(true),
            FilterValue::Int(42),
            FilterValue::Float(1.5),
            FilterValue::String("test".to_string()),
        ] {
            assert!(filter_value_to_sql(value).is_ok());
        }
    }

    #[test]
    fn test_pg_int_narrows_to_column_width() {
        let mut buf = bytes::BytesMut::new();
        PgInt(42).to_sql(&Type::INT2, &mut buf).unwrap();
        assert_eq!(&buf[..], &[0u8, 42][..]);

        let mut buf = bytes::BytesMut::new();
        PgInt(42).to_sql(&Type::INT4, &mut buf).unwrap();
        assert_eq!(buf.len(), 4);

        let mut buf = bytes::BytesMut::new();
        PgInt(42).to_sql(&Type::INT8, &mut buf).unwrap();
        assert_eq!(buf.len(), 8);
    }

    #[test]
    fn test_pg_int_rejects_overflow_and_foreign_types() {
        let err = PgInt(i64::from(i16::MAX) + 1)
            .to_sql(&Type::INT2, &mut bytes::BytesMut::new())
            .err()
            .expect("value above i16::MAX must not bind to INT2");
        assert!(err.to_string().contains("overflows INT2"));

        let err = PgInt(i64::from(i32::MAX) + 1)
            .to_sql(&Type::INT4, &mut bytes::BytesMut::new())
            .err()
            .expect("value above i32::MAX must not bind to INT4");
        assert!(err.to_string().contains("overflows INT4"));

        assert!(PgInt(1)
            .to_sql(&Type::TEXT, &mut bytes::BytesMut::new())
            .is_err());
        assert!(!<PgInt as ToSql>::accepts(&Type::TEXT));
    }

    #[test]
    fn test_pg_string_converts_by_column_type() {
        let mut buf = bytes::BytesMut::new();
        PgString("67e55044-10b1-426f-9247-bb680e5fe0c8".to_string())
            .to_sql(&Type::UUID, &mut buf)
            .unwrap();
        assert_eq!(buf.len(), 16);

        assert!(PgString("not-a-uuid".to_string())
            .to_sql(&Type::UUID, &mut bytes::BytesMut::new())
            .is_err());

        assert!(PgString("2024-01-02".to_string())
            .to_sql(&Type::DATE, &mut bytes::BytesMut::new())
            .is_ok());
        assert!(PgString("02/01/2024".to_string())
            .to_sql(&Type::DATE, &mut bytes::BytesMut::new())
            .is_err());

        assert!(PgString("2024-01-02T03:04:05".to_string())
            .to_sql(&Type::TIMESTAMP, &mut bytes::BytesMut::new())
            .is_ok());
    }

    #[test]
    fn test_pg_string_writes_enum_labels_verbatim() {
        let mood = Type::new(
            "mood".to_string(),
            0,
            Kind::Enum(vec!["happy".to_string()]),
            "public".to_string(),
        );
        let mut buf = bytes::BytesMut::new();
        let result = PgString("happy".to_string()).to_sql(&mood, &mut buf);
        assert!(matches!(result, Ok(IsNull::No)));
        assert_eq!(&buf[..], b"happy");
    }

    #[test]
    fn test_pg_type_mapping() {
        use pg_types::*;

        assert_eq!(rust_type_to_pg("i32"), Some(Type::INT4));
        assert_eq!(rust_type_to_pg("String"), Some(Type::TEXT));
        assert_eq!(rust_type_to_pg("bool"), Some(Type::BOOL));
        assert_eq!(rust_type_to_pg("no::such::Type"), None);

        assert_eq!(pg_type_to_rust(&Type::INT4), "i32");
        assert_eq!(pg_type_to_rust(&Type::TEXT), "String");
    }

    #[test]
    fn test_pg_type_mapping_round_trips() {
        use pg_types::*;

        for &name in &[
            "i16",
            "i32",
            "i64",
            "f32",
            "f64",
            "bool",
            "String",
            "Vec<u8>",
            "chrono::NaiveDate",
            "chrono::NaiveTime",
            "chrono::NaiveDateTime",
            "chrono::DateTime<chrono::Utc>",
            "uuid::Uuid",
            "serde_json::Value",
        ] {
            let pg = rust_type_to_pg(name).unwrap_or_else(|| panic!("no PG type for {name}"));
            assert_eq!(pg_type_to_rust(&pg), name);
        }
    }
}
249}