Skip to main content

prax_postgres/
types.rs

1//! Type conversions for PostgreSQL.
2
3use prax_query::filter::FilterValue;
4use tokio_postgres::types::{IsNull, ToSql, Type};
5
6use crate::error::{PgError, PgResult};
7
8/// Polymorphic integer binding. `FilterValue::Int` always carries an i64
9/// (the widest scalar variant), but Postgres strictly validates client
10/// bindings against column types: binding an i64 to an `INT4` column
11/// fails with `WrongType { postgres: Int4, rust: "i64" }`. This wrapper
12/// inspects the target column type at bind time and narrows to i16 /
13/// i32 / i64 with a bounds check before forwarding to tokio-postgres'
14/// own impls.
15#[derive(Debug)]
16struct PgInt(i64);
17
18impl ToSql for PgInt {
19    fn to_sql(
20        &self,
21        ty: &Type,
22        out: &mut bytes::BytesMut,
23    ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
24        match *ty {
25            Type::INT2 => {
26                let v: i16 = self
27                    .0
28                    .try_into()
29                    .map_err(|_| format!("value {} overflows INT2", self.0))?;
30                v.to_sql(ty, out)
31            }
32            Type::INT4 => {
33                let v: i32 = self
34                    .0
35                    .try_into()
36                    .map_err(|_| format!("value {} overflows INT4", self.0))?;
37                v.to_sql(ty, out)
38            }
39            Type::INT8 => self.0.to_sql(ty, out),
40            _ => Err(format!("cannot bind integer to postgres type {ty:?}").into()),
41        }
42    }
43
44    fn accepts(ty: &Type) -> bool {
45        matches!(*ty, Type::INT2 | Type::INT4 | Type::INT8)
46    }
47
48    tokio_postgres::types::to_sql_checked!();
49}
50
51/// Convert a FilterValue to a type that can be used as a PostgreSQL parameter.
52pub fn filter_value_to_sql(value: &FilterValue) -> PgResult<Box<dyn ToSql + Sync + Send>> {
53    match value {
54        FilterValue::Null => Ok(Box::new(Option::<String>::None)),
55        FilterValue::Bool(b) => Ok(Box::new(*b)),
56        FilterValue::Int(i) => Ok(Box::new(PgInt(*i))),
57        FilterValue::Float(f) => Ok(Box::new(*f)),
58        FilterValue::String(s) => Ok(Box::new(s.clone())),
59        FilterValue::Json(j) => Ok(Box::new(j.clone())),
60        FilterValue::List(_) => {
61            // Lists need special handling - they should be converted to arrays
62            // For now, return an error and handle lists specially in the engine
63            Err(PgError::type_conversion(
64                "list values should be handled specially",
65            ))
66        }
67    }
68}
69
70/// Convert filter values to PostgreSQL parameters.
71pub fn filter_values_to_params(
72    values: &[FilterValue],
73) -> PgResult<Vec<Box<dyn ToSql + Sync + Send>>> {
74    values.iter().map(filter_value_to_sql).collect()
75}
76
77/// PostgreSQL type mapping utilities.
78pub mod pg_types {
79    use super::*;
80
81    /// Get the PostgreSQL type for a Rust type name.
82    pub fn rust_type_to_pg(rust_type: &str) -> Option<Type> {
83        match rust_type {
84            "i16" => Some(Type::INT2),
85            "i32" => Some(Type::INT4),
86            "i64" => Some(Type::INT8),
87            "f32" => Some(Type::FLOAT4),
88            "f64" => Some(Type::FLOAT8),
89            "bool" => Some(Type::BOOL),
90            "String" | "&str" => Some(Type::TEXT),
91            "Vec<u8>" | "&[u8]" => Some(Type::BYTEA),
92            "chrono::NaiveDate" => Some(Type::DATE),
93            "chrono::NaiveTime" => Some(Type::TIME),
94            "chrono::NaiveDateTime" => Some(Type::TIMESTAMP),
95            "chrono::DateTime<chrono::Utc>" => Some(Type::TIMESTAMPTZ),
96            "uuid::Uuid" => Some(Type::UUID),
97            "serde_json::Value" => Some(Type::JSONB),
98            _ => None,
99        }
100    }
101
102    /// Get the Rust type for a PostgreSQL type.
103    pub fn pg_type_to_rust(pg_type: &Type) -> &'static str {
104        match *pg_type {
105            Type::BOOL => "bool",
106            Type::INT2 => "i16",
107            Type::INT4 => "i32",
108            Type::INT8 => "i64",
109            Type::FLOAT4 => "f32",
110            Type::FLOAT8 => "f64",
111            Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => "String",
112            Type::BYTEA => "Vec<u8>",
113            Type::DATE => "chrono::NaiveDate",
114            Type::TIME => "chrono::NaiveTime",
115            Type::TIMESTAMP => "chrono::NaiveDateTime",
116            Type::TIMESTAMPTZ => "chrono::DateTime<chrono::Utc>",
117            Type::UUID => "uuid::Uuid",
118            Type::JSON | Type::JSONB => "serde_json::Value",
119            _ => "unknown",
120        }
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Every scalar variant converts successfully.
    #[test]
    fn test_filter_value_to_sql() {
        let values = [
            FilterValue::Null,
            FilterValue::Bool(true),
            FilterValue::Int(42),
            FilterValue::Float(1.5),
            FilterValue::String("test".to_string()),
        ];
        for value in &values {
            assert!(filter_value_to_sql(value).is_ok());
        }
    }

    /// `PgInt` encodes at the column's width and rejects out-of-range values
    /// and non-integer types.
    #[test]
    fn test_pg_int_narrowing() {
        let mut buf = bytes::BytesMut::new();

        PgInt(42).to_sql(&Type::INT2, &mut buf).unwrap();
        assert_eq!(&buf[..], &42i16.to_be_bytes()[..]);
        buf.clear();

        PgInt(42).to_sql(&Type::INT4, &mut buf).unwrap();
        assert_eq!(&buf[..], &42i32.to_be_bytes()[..]);
        buf.clear();

        PgInt(i64::MAX).to_sql(&Type::INT8, &mut buf).unwrap();
        assert_eq!(&buf[..], &i64::MAX.to_be_bytes()[..]);
        buf.clear();

        assert!(PgInt(i64::from(i16::MAX) + 1)
            .to_sql(&Type::INT2, &mut buf)
            .is_err());
        assert!(PgInt(i64::from(i32::MIN) - 1)
            .to_sql(&Type::INT4, &mut buf)
            .is_err());

        assert!(<PgInt as ToSql>::accepts(&Type::INT4));
        assert!(!<PgInt as ToSql>::accepts(&Type::TEXT));
    }

    #[test]
    fn test_pg_type_mapping() {
        use pg_types::*;

        assert_eq!(rust_type_to_pg("i32"), Some(Type::INT4));
        assert_eq!(rust_type_to_pg("String"), Some(Type::TEXT));
        assert_eq!(rust_type_to_pg("bool"), Some(Type::BOOL));
        assert_eq!(rust_type_to_pg("not_a_type"), None);

        assert_eq!(pg_type_to_rust(&Type::INT4), "i32");
        assert_eq!(pg_type_to_rust(&Type::TEXT), "String");
        assert_eq!(pg_type_to_rust(&Type::POINT), "unknown");

        // Every owned Rust type name round-trips through its Postgres type.
        for name in [
            "i16",
            "i32",
            "i64",
            "f32",
            "f64",
            "bool",
            "String",
            "Vec<u8>",
            "chrono::NaiveDate",
            "chrono::NaiveTime",
            "chrono::NaiveDateTime",
            "chrono::DateTime<chrono::Utc>",
            "uuid::Uuid",
            "serde_json::Value",
        ] {
            let ty = rust_type_to_pg(name).unwrap_or_else(|| panic!("no mapping for {name}"));
            assert_eq!(pg_type_to_rust(&ty), name);
        }
    }
}