1use prax_query::filter::FilterValue;
4use tokio_postgres::types::{IsNull, ToSql, Type};
5
6use crate::error::{PgError, PgResult};
7
8#[derive(Debug)]
16struct PgInt(i64);
17
18impl ToSql for PgInt {
19 fn to_sql(
20 &self,
21 ty: &Type,
22 out: &mut bytes::BytesMut,
23 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
24 match *ty {
25 Type::INT2 => {
26 let v: i16 = self
27 .0
28 .try_into()
29 .map_err(|_| format!("value {} overflows INT2", self.0))?;
30 v.to_sql(ty, out)
31 }
32 Type::INT4 => {
33 let v: i32 = self
34 .0
35 .try_into()
36 .map_err(|_| format!("value {} overflows INT4", self.0))?;
37 v.to_sql(ty, out)
38 }
39 Type::INT8 => self.0.to_sql(ty, out),
40 _ => Err(format!("cannot bind integer to postgres type {ty:?}").into()),
41 }
42 }
43
44 fn accepts(ty: &Type) -> bool {
45 matches!(*ty, Type::INT2 | Type::INT4 | Type::INT8)
46 }
47
48 tokio_postgres::types::to_sql_checked!();
49}
50
51pub fn filter_value_to_sql(value: &FilterValue) -> PgResult<Box<dyn ToSql + Sync + Send>> {
53 match value {
54 FilterValue::Null => Ok(Box::new(Option::<String>::None)),
55 FilterValue::Bool(b) => Ok(Box::new(*b)),
56 FilterValue::Int(i) => Ok(Box::new(PgInt(*i))),
57 FilterValue::Float(f) => Ok(Box::new(*f)),
58 FilterValue::String(s) => Ok(Box::new(s.clone())),
59 FilterValue::Json(j) => Ok(Box::new(j.clone())),
60 FilterValue::List(_) => {
61 Err(PgError::type_conversion(
64 "list values should be handled specially",
65 ))
66 }
67 }
68}
69
70pub fn filter_values_to_params(
72 values: &[FilterValue],
73) -> PgResult<Vec<Box<dyn ToSql + Sync + Send>>> {
74 values.iter().map(filter_value_to_sql).collect()
75}
76
77pub mod pg_types {
79 use super::*;
80
81 pub fn rust_type_to_pg(rust_type: &str) -> Option<Type> {
83 match rust_type {
84 "i16" => Some(Type::INT2),
85 "i32" => Some(Type::INT4),
86 "i64" => Some(Type::INT8),
87 "f32" => Some(Type::FLOAT4),
88 "f64" => Some(Type::FLOAT8),
89 "bool" => Some(Type::BOOL),
90 "String" | "&str" => Some(Type::TEXT),
91 "Vec<u8>" | "&[u8]" => Some(Type::BYTEA),
92 "chrono::NaiveDate" => Some(Type::DATE),
93 "chrono::NaiveTime" => Some(Type::TIME),
94 "chrono::NaiveDateTime" => Some(Type::TIMESTAMP),
95 "chrono::DateTime<chrono::Utc>" => Some(Type::TIMESTAMPTZ),
96 "uuid::Uuid" => Some(Type::UUID),
97 "serde_json::Value" => Some(Type::JSONB),
98 _ => None,
99 }
100 }
101
102 pub fn pg_type_to_rust(pg_type: &Type) -> &'static str {
104 match *pg_type {
105 Type::BOOL => "bool",
106 Type::INT2 => "i16",
107 Type::INT4 => "i32",
108 Type::INT8 => "i64",
109 Type::FLOAT4 => "f32",
110 Type::FLOAT8 => "f64",
111 Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => "String",
112 Type::BYTEA => "Vec<u8>",
113 Type::DATE => "chrono::NaiveDate",
114 Type::TIME => "chrono::NaiveTime",
115 Type::TIMESTAMP => "chrono::NaiveDateTime",
116 Type::TIMESTAMPTZ => "chrono::DateTime<chrono::Utc>",
117 Type::UUID => "uuid::Uuid",
118 Type::JSON | Type::JSONB => "serde_json::Value",
119 _ => "unknown",
120 }
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Every scalar filter value (including NULL) converts to a parameter.
    #[test]
    fn test_filter_value_to_sql() {
        let values = [
            FilterValue::Null,
            FilterValue::Int(42),
            FilterValue::String("test".to_string()),
            FilterValue::Bool(true),
        ];
        for value in values.iter() {
            assert!(filter_value_to_sql(value).is_ok());
        }
    }

    /// `PgInt` encodes big-endian at the width the placeholder asks for.
    #[test]
    fn test_pg_int_encodes_each_integer_width() {
        let mut buf = bytes::BytesMut::new();
        assert!(matches!(PgInt(42).to_sql(&Type::INT2, &mut buf), Ok(IsNull::No)));
        assert_eq!(&buf[..], &[0u8, 42][..]);

        buf.clear();
        assert!(matches!(PgInt(42).to_sql(&Type::INT4, &mut buf), Ok(IsNull::No)));
        assert_eq!(&buf[..], &[0u8, 0, 0, 42][..]);

        buf.clear();
        assert!(matches!(PgInt(42).to_sql(&Type::INT8, &mut buf), Ok(IsNull::No)));
        assert_eq!(&buf[..], &[0u8, 0, 0, 0, 0, 0, 0, 42][..]);
    }

    /// Out-of-range narrowing and non-integer targets are errors, not truncations.
    #[test]
    fn test_pg_int_rejects_overflow_and_non_integer_types() {
        let mut buf = bytes::BytesMut::new();
        assert!(PgInt(i64::from(i16::MAX) + 1)
            .to_sql(&Type::INT2, &mut buf)
            .is_err());
        assert!(PgInt(i64::from(i32::MIN) - 1)
            .to_sql(&Type::INT4, &mut buf)
            .is_err());
        assert!(PgInt(1).to_sql(&Type::TEXT, &mut buf).is_err());
        // Nothing may be written when conversion fails.
        assert!(buf.is_empty());

        assert!(<PgInt as ToSql>::accepts(&Type::INT4));
        assert!(!<PgInt as ToSql>::accepts(&Type::TEXT));
    }

    #[test]
    fn test_pg_type_mapping() {
        use pg_types::*;

        assert_eq!(rust_type_to_pg("i32"), Some(Type::INT4));
        assert_eq!(rust_type_to_pg("String"), Some(Type::TEXT));
        assert_eq!(rust_type_to_pg("bool"), Some(Type::BOOL));

        assert_eq!(pg_type_to_rust(&Type::INT4), "i32");
        assert_eq!(pg_type_to_rust(&Type::TEXT), "String");
    }
}