1use std::collections::HashMap;
9
10use prax_query::row::{RowError, RowRef};
11use sqlx::{Column, Row};
12
13use crate::row::SqlxRow;
14
/// A single decoded column value, detached from the underlying sqlx row.
///
/// The per-backend decoders widen narrower integer and float types into
/// [`Value::I64`] / [`Value::F64`]; SQL `NULL` is represented by [`Value::Null`].
#[derive(Debug, Clone, PartialEq)]
enum Value {
    /// SQL `NULL` (also what a decoder yields when no probed type accepted the cell).
    Null,
    Bool(bool),
    I64(i64),
    F64(f64),
    Text(String),
    Bytes(Vec<u8>),
}

/// An owned snapshot of a sqlx row that implements [`RowRef`] by column name.
///
/// Every column is decoded up front into owned [`Value`]s, so reads do not
/// borrow the original sqlx row.
#[derive(Debug)]
pub struct SqlxRowRef {
    /// Decoded cell values keyed by column name.
    values: HashMap<String, Value>,
}
30
31impl SqlxRowRef {
32 pub fn from_sqlx(row: &SqlxRow) -> Result<Self, RowError> {
34 let mut values = HashMap::new();
35 match row {
36 #[cfg(feature = "postgres")]
37 SqlxRow::Postgres(r) => {
38 for (i, col) in r.columns().iter().enumerate() {
39 let name = col.name().to_string();
40 let v = decode_pg_cell(r, i);
41 values.insert(name, v);
42 }
43 }
44 #[cfg(feature = "mysql")]
45 SqlxRow::MySql(r) => {
46 for (i, col) in r.columns().iter().enumerate() {
47 let name = col.name().to_string();
48 let v = decode_generic_cell_mysql(r, i);
49 values.insert(name, v);
50 }
51 }
52 #[cfg(feature = "sqlite")]
53 SqlxRow::Sqlite(r) => {
54 for (i, col) in r.columns().iter().enumerate() {
55 let name = col.name().to_string();
56 let v = decode_generic_cell_sqlite(r, i);
57 values.insert(name, v);
58 }
59 }
60 }
61 Ok(Self { values })
62 }
63}
64
65fn tc(column: &str, msg: impl Into<String>) -> RowError {
66 RowError::TypeConversion {
67 column: column.into(),
68 message: msg.into(),
69 }
70}
71
/// Decode column `i` of a Postgres row into a backend-neutral [`Value`].
///
/// Candidate Rust types are probed in a fixed order: text, bool, `i64`, `i32`,
/// `i16`, `f64`, `f32`, then raw bytes. The first successful non-NULL decode
/// wins. `i32`/`i16` results are widened to [`Value::I64`], and `f32` to
/// [`Value::F64`].
///
/// If every probe fails or yields `NULL`, the cell becomes [`Value::Null`].
/// NOTE(review): this also covers non-NULL values whose SQL type matches none
/// of the probed Rust types (presumably e.g. `UUID`, `TIMESTAMPTZ`, `NUMERIC`,
/// `JSONB` — confirm against sqlx's Postgres type mappings). Those values are
/// then indistinguishable from a real SQL `NULL`.
#[cfg(feature = "postgres")]
fn decode_pg_cell(r: &sqlx::postgres::PgRow, i: usize) -> Value {
    r.try_get::<Option<String>, _>(i)
        .ok()
        .flatten()
        .map(Value::Text)
        .or_else(|| r.try_get::<Option<bool>, _>(i).ok().flatten().map(Value::Bool))
        .or_else(|| r.try_get::<Option<i64>, _>(i).ok().flatten().map(Value::I64))
        .or_else(|| {
            r.try_get::<Option<i32>, _>(i)
                .ok()
                .flatten()
                .map(|n| Value::I64(i64::from(n)))
        })
        .or_else(|| {
            r.try_get::<Option<i16>, _>(i)
                .ok()
                .flatten()
                .map(|n| Value::I64(i64::from(n)))
        })
        .or_else(|| r.try_get::<Option<f64>, _>(i).ok().flatten().map(Value::F64))
        .or_else(|| {
            r.try_get::<Option<f32>, _>(i)
                .ok()
                .flatten()
                .map(|f| Value::F64(f64::from(f)))
        })
        .or_else(|| r.try_get::<Option<Vec<u8>>, _>(i).ok().flatten().map(Value::Bytes))
        .unwrap_or(Value::Null)
}
102
/// Decode column `i` of a MySQL row into a backend-neutral [`Value`].
///
/// Candidate Rust types are probed in order and the first successful non-NULL
/// decode wins. If none succeeds, the cell becomes [`Value::Null`].
///
/// Integers are probed *before* `bool`. MySQL has no real boolean type
/// (`BOOLEAN` is `TINYINT(1)`), and sqlx's MySQL `bool` is type-compatible with
/// integer columns. Probing `bool` first turned small integers such as `5` into
/// `Bool(true)`, which made `get_i32`/`get_i64` fail with "not an integer".
/// `BOOLEAN` columns now decode as `I64(0 | 1)`, which `get_bool` accepts.
/// `bool` is still probed afterwards for column types the integer probes reject.
///
/// `u64` is probed too, because `UNSIGNED` columns may be rejected by the
/// signed decoder. Previously those values fell through to `Null`.
#[cfg(feature = "mysql")]
fn decode_generic_cell_mysql(r: &sqlx::mysql::MySqlRow, i: usize) -> Value {
    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
        return Value::Text(s);
    }
    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
        return Value::I64(n);
    }
    if let Ok(Some(n)) = r.try_get::<Option<u64>, _>(i) {
        // `BIGINT UNSIGNED` can exceed `i64::MAX`. Keep an approximate float
        // rather than losing the value entirely.
        return i64::try_from(n).map_or(Value::F64(n as f64), Value::I64);
    }
    if let Ok(Some(b)) = r.try_get::<Option<bool>, _>(i) {
        return Value::Bool(b);
    }
    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
        return Value::F64(f);
    }
    if let Ok(Some(f)) = r.try_get::<Option<f32>, _>(i) {
        return Value::F64(f64::from(f));
    }
    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
        return Value::Bytes(b);
    }
    Value::Null
}
122
/// Decode column `i` of a SQLite row into a backend-neutral [`Value`].
///
/// Probes text, then `i64`, then `f64`, then raw bytes, and keeps the first
/// successful non-NULL decode. When nothing matches, the cell becomes
/// [`Value::Null`].
#[cfg(feature = "sqlite")]
fn decode_generic_cell_sqlite(r: &sqlx::sqlite::SqliteRow, i: usize) -> Value {
    r.try_get::<Option<String>, _>(i)
        .ok()
        .flatten()
        .map(Value::Text)
        .or_else(|| r.try_get::<Option<i64>, _>(i).ok().flatten().map(Value::I64))
        .or_else(|| r.try_get::<Option<f64>, _>(i).ok().flatten().map(Value::F64))
        .or_else(|| r.try_get::<Option<Vec<u8>>, _>(i).ok().flatten().map(Value::Bytes))
        .unwrap_or(Value::Null)
}
139
140impl RowRef for SqlxRowRef {
141 fn get_i32(&self, c: &str) -> Result<i32, RowError> {
142 match self
143 .values
144 .get(c)
145 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
146 {
147 Value::I64(i) => i32::try_from(*i).map_err(|_| tc(c, "i64 overflow")),
148 Value::Null => Err(RowError::UnexpectedNull(c.into())),
149 _ => Err(tc(c, "not an integer")),
150 }
151 }
152 fn get_i32_opt(&self, c: &str) -> Result<Option<i32>, RowError> {
153 match self.values.get(c) {
154 None => Err(RowError::ColumnNotFound(c.into())),
155 Some(Value::Null) => Ok(None),
156 Some(Value::I64(i)) => i32::try_from(*i)
157 .map(Some)
158 .map_err(|_| tc(c, "i64 overflow")),
159 Some(_) => Err(tc(c, "not an integer")),
160 }
161 }
162 fn get_i64(&self, c: &str) -> Result<i64, RowError> {
163 match self
164 .values
165 .get(c)
166 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
167 {
168 Value::I64(i) => Ok(*i),
169 Value::Null => Err(RowError::UnexpectedNull(c.into())),
170 _ => Err(tc(c, "not an integer")),
171 }
172 }
173 fn get_i64_opt(&self, c: &str) -> Result<Option<i64>, RowError> {
174 match self.values.get(c) {
175 None => Err(RowError::ColumnNotFound(c.into())),
176 Some(Value::Null) => Ok(None),
177 Some(Value::I64(i)) => Ok(Some(*i)),
178 Some(_) => Err(tc(c, "not an integer")),
179 }
180 }
181 fn get_f64(&self, c: &str) -> Result<f64, RowError> {
182 match self
183 .values
184 .get(c)
185 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
186 {
187 Value::F64(f) => Ok(*f),
188 Value::I64(i) => Ok(*i as f64),
189 Value::Null => Err(RowError::UnexpectedNull(c.into())),
190 _ => Err(tc(c, "not a number")),
191 }
192 }
193 fn get_f64_opt(&self, c: &str) -> Result<Option<f64>, RowError> {
194 match self.values.get(c) {
195 None => Err(RowError::ColumnNotFound(c.into())),
196 Some(Value::Null) => Ok(None),
197 Some(Value::F64(f)) => Ok(Some(*f)),
198 Some(Value::I64(i)) => Ok(Some(*i as f64)),
199 Some(_) => Err(tc(c, "not a number")),
200 }
201 }
202 fn get_bool(&self, c: &str) -> Result<bool, RowError> {
203 match self
204 .values
205 .get(c)
206 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
207 {
208 Value::Bool(b) => Ok(*b),
209 Value::I64(i) => Ok(*i != 0),
210 Value::Null => Err(RowError::UnexpectedNull(c.into())),
211 _ => Err(tc(c, "not a boolean")),
212 }
213 }
214 fn get_bool_opt(&self, c: &str) -> Result<Option<bool>, RowError> {
215 match self.values.get(c) {
216 None => Err(RowError::ColumnNotFound(c.into())),
217 Some(Value::Null) => Ok(None),
218 Some(Value::Bool(b)) => Ok(Some(*b)),
219 Some(Value::I64(i)) => Ok(Some(*i != 0)),
220 Some(_) => Err(tc(c, "not a boolean")),
221 }
222 }
223 fn get_str(&self, c: &str) -> Result<&str, RowError> {
224 match self
225 .values
226 .get(c)
227 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
228 {
229 Value::Text(s) => Ok(s.as_str()),
230 Value::Null => Err(RowError::UnexpectedNull(c.into())),
231 _ => Err(tc(c, "not text")),
232 }
233 }
234 fn get_str_opt(&self, c: &str) -> Result<Option<&str>, RowError> {
235 match self.values.get(c) {
236 None => Err(RowError::ColumnNotFound(c.into())),
237 Some(Value::Null) => Ok(None),
238 Some(Value::Text(s)) => Ok(Some(s.as_str())),
239 Some(_) => Err(tc(c, "not text")),
240 }
241 }
242 fn get_bytes(&self, c: &str) -> Result<&[u8], RowError> {
243 match self
244 .values
245 .get(c)
246 .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
247 {
248 Value::Bytes(b) => Ok(b.as_slice()),
249 Value::Text(s) => Ok(s.as_bytes()),
250 Value::Null => Err(RowError::UnexpectedNull(c.into())),
251 _ => Err(tc(c, "not bytes")),
252 }
253 }
254 fn get_bytes_opt(&self, c: &str) -> Result<Option<&[u8]>, RowError> {
255 match self.values.get(c) {
256 None => Err(RowError::ColumnNotFound(c.into())),
257 Some(Value::Null) => Ok(None),
258 Some(Value::Bytes(b)) => Ok(Some(b.as_slice())),
259 Some(Value::Text(s)) => Ok(Some(s.as_bytes())),
260 Some(_) => Err(tc(c, "not bytes")),
261 }
262 }
263 fn get_datetime_utc(&self, c: &str) -> Result<chrono::DateTime<chrono::Utc>, RowError> {
264 let s = self.get_str(c)?;
265 chrono::DateTime::parse_from_rfc3339(s)
266 .map(|d| d.with_timezone(&chrono::Utc))
267 .map_err(|e| tc(c, e.to_string()))
268 }
269 fn get_datetime_utc_opt(
270 &self,
271 c: &str,
272 ) -> Result<Option<chrono::DateTime<chrono::Utc>>, RowError> {
273 match self.get_str_opt(c)? {
274 None => Ok(None),
275 Some(s) => chrono::DateTime::parse_from_rfc3339(s)
276 .map(|d| Some(d.with_timezone(&chrono::Utc)))
277 .map_err(|e| tc(c, e.to_string())),
278 }
279 }
280 fn get_uuid(&self, c: &str) -> Result<uuid::Uuid, RowError> {
281 uuid::Uuid::parse_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
282 }
283 fn get_uuid_opt(&self, c: &str) -> Result<Option<uuid::Uuid>, RowError> {
284 match self.get_str_opt(c)? {
285 None => Ok(None),
286 Some(s) => uuid::Uuid::parse_str(s)
287 .map(Some)
288 .map_err(|e| tc(c, e.to_string())),
289 }
290 }
291 fn get_json(&self, c: &str) -> Result<serde_json::Value, RowError> {
292 serde_json::from_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
293 }
294 fn get_json_opt(&self, c: &str) -> Result<Option<serde_json::Value>, RowError> {
295 match self.get_str_opt(c)? {
296 None => Ok(None),
297 Some(s) => serde_json::from_str(s)
298 .map(Some)
299 .map_err(|e| tc(c, e.to_string())),
300 }
301 }
302}