1use std::str::FromStr;
4
5use prax_query::filter::FilterValue;
6use tokio_postgres::types::{IsNull, Kind, ToSql, Type};
7
8use crate::error::{PgError, PgResult};
9
10#[derive(Debug)]
18struct PgInt(i64);
19
20impl ToSql for PgInt {
21 fn to_sql(
22 &self,
23 ty: &Type,
24 out: &mut bytes::BytesMut,
25 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
26 match *ty {
27 Type::INT2 => {
28 let v: i16 = self
29 .0
30 .try_into()
31 .map_err(|_| format!("value {} overflows INT2", self.0))?;
32 v.to_sql(ty, out)
33 }
34 Type::INT4 => {
35 let v: i32 = self
36 .0
37 .try_into()
38 .map_err(|_| format!("value {} overflows INT4", self.0))?;
39 v.to_sql(ty, out)
40 }
41 Type::INT8 => self.0.to_sql(ty, out),
42 _ => Err(format!("cannot bind integer to postgres type {ty:?}").into()),
43 }
44 }
45
46 fn accepts(ty: &Type) -> bool {
47 matches!(*ty, Type::INT2 | Type::INT4 | Type::INT8)
48 }
49
50 tokio_postgres::types::to_sql_checked!();
51}
52
53#[derive(Debug)]
60struct PgString(String);
61
62impl ToSql for PgString {
63 fn to_sql(
64 &self,
65 ty: &Type,
66 out: &mut bytes::BytesMut,
67 ) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
68 match *ty {
69 Type::UUID => {
70 let parsed = ::uuid::Uuid::from_str(&self.0).map_err(|e| {
71 format!("FilterValue::String('{}') is not a valid UUID: {e}", self.0)
72 })?;
73 parsed.to_sql(ty, out)
74 }
75 Type::TIMESTAMPTZ => {
81 let parsed: ::chrono::DateTime<::chrono::Utc> =
82 ::chrono::DateTime::parse_from_rfc3339(&self.0)
83 .map_err(|e| {
84 format!(
85 "FilterValue::String('{}') is not a valid RFC3339 \
86 datetime for TIMESTAMPTZ: {e}",
87 self.0
88 )
89 })?
90 .with_timezone(&::chrono::Utc);
91 parsed.to_sql(ty, out)
92 }
93 Type::TIMESTAMP => {
94 let parsed =
95 ::chrono::NaiveDateTime::parse_from_str(&self.0, "%Y-%m-%dT%H:%M:%S%.f")
96 .map_err(|e| {
97 format!(
98 "FilterValue::String('{}') is not a valid \
99 ISO-8601 naive datetime for TIMESTAMP: {e}",
100 self.0
101 )
102 })?;
103 parsed.to_sql(ty, out)
104 }
105 Type::DATE => {
106 let parsed =
107 ::chrono::NaiveDate::parse_from_str(&self.0, "%Y-%m-%d").map_err(|e| {
108 format!(
109 "FilterValue::String('{}') is not a valid \
110 YYYY-MM-DD date for DATE: {e}",
111 self.0
112 )
113 })?;
114 parsed.to_sql(ty, out)
115 }
116 Type::TIME => {
117 let parsed =
118 ::chrono::NaiveTime::parse_from_str(&self.0, "%H:%M:%S%.f").map_err(|e| {
119 format!(
120 "FilterValue::String('{}') is not a valid \
121 HH:MM:SS time for TIME: {e}",
122 self.0
123 )
124 })?;
125 parsed.to_sql(ty, out)
126 }
127 _ => {
128 if matches!(ty.kind(), Kind::Enum(_)) {
133 out.extend_from_slice(self.0.as_bytes());
134 Ok(IsNull::No)
135 } else {
136 self.0.to_sql(ty, out)
137 }
138 }
139 }
140 }
141
142 fn accepts(_ty: &Type) -> bool {
143 true
144 }
145
146 tokio_postgres::types::to_sql_checked!();
147}
148
149pub fn filter_value_to_sql(value: &FilterValue) -> PgResult<Box<dyn ToSql + Sync + Send>> {
151 match value {
152 FilterValue::Null => Ok(Box::new(Option::<String>::None)),
153 FilterValue::Bool(b) => Ok(Box::new(*b)),
154 FilterValue::Int(i) => Ok(Box::new(PgInt(*i))),
155 FilterValue::Float(f) => Ok(Box::new(*f)),
156 FilterValue::String(s) => Ok(Box::new(PgString(s.clone()))),
157 FilterValue::Json(j) => Ok(Box::new(j.clone())),
158 FilterValue::List(_) => {
159 Err(PgError::type_conversion(
162 "list values should be handled specially",
163 ))
164 }
165 }
166}
167
168pub fn filter_values_to_params(
170 values: &[FilterValue],
171) -> PgResult<Vec<Box<dyn ToSql + Sync + Send>>> {
172 values.iter().map(filter_value_to_sql).collect()
173}
174
175pub mod pg_types {
177 use super::*;
178
179 pub fn rust_type_to_pg(rust_type: &str) -> Option<Type> {
181 match rust_type {
182 "i16" => Some(Type::INT2),
183 "i32" => Some(Type::INT4),
184 "i64" => Some(Type::INT8),
185 "f32" => Some(Type::FLOAT4),
186 "f64" => Some(Type::FLOAT8),
187 "bool" => Some(Type::BOOL),
188 "String" | "&str" => Some(Type::TEXT),
189 "Vec<u8>" | "&[u8]" => Some(Type::BYTEA),
190 "chrono::NaiveDate" => Some(Type::DATE),
191 "chrono::NaiveTime" => Some(Type::TIME),
192 "chrono::NaiveDateTime" => Some(Type::TIMESTAMP),
193 "chrono::DateTime<chrono::Utc>" => Some(Type::TIMESTAMPTZ),
194 "uuid::Uuid" => Some(Type::UUID),
195 "serde_json::Value" => Some(Type::JSONB),
196 _ => None,
197 }
198 }
199
200 pub fn pg_type_to_rust(pg_type: &Type) -> &'static str {
202 match *pg_type {
203 Type::BOOL => "bool",
204 Type::INT2 => "i16",
205 Type::INT4 => "i32",
206 Type::INT8 => "i64",
207 Type::FLOAT4 => "f32",
208 Type::FLOAT8 => "f64",
209 Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME => "String",
210 Type::BYTEA => "Vec<u8>",
211 Type::DATE => "chrono::NaiveDate",
212 Type::TIME => "chrono::NaiveTime",
213 Type::TIMESTAMP => "chrono::NaiveDateTime",
214 Type::TIMESTAMPTZ => "chrono::DateTime<chrono::Utc>",
215 Type::UUID => "uuid::Uuid",
216 Type::JSON | Type::JSONB => "serde_json::Value",
217 _ => "unknown",
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar filter values must convert into a bind parameter without error.
    #[test]
    fn test_filter_value_to_sql() {
        let samples = [
            FilterValue::Int(42),
            FilterValue::String("test".to_string()),
            FilterValue::Bool(true),
        ];

        for sample in &samples {
            assert!(filter_value_to_sql(sample).is_ok());
        }
    }

    /// Spot-checks both directions of the Rust <-> Postgres type tables.
    #[test]
    fn test_pg_type_mapping() {
        use pg_types::{pg_type_to_rust, rust_type_to_pg};

        let rust_to_pg = [
            ("i32", Type::INT4),
            ("String", Type::TEXT),
            ("bool", Type::BOOL),
        ];
        for (rust_name, expected) in rust_to_pg {
            assert_eq!(rust_type_to_pg(rust_name), Some(expected));
        }

        let pg_to_rust = [(Type::INT4, "i32"), (Type::TEXT, "String")];
        for (pg_type, expected) in &pg_to_rust {
            assert_eq!(pg_type_to_rust(pg_type), *expected);
        }
    }
}