Skip to main content

prax_sqlx/
row_ref.rs

//! Bridge between [`SqlxRow`] and [`prax_query::row::RowRef`].
//!
//! Every column of a row is decoded up front into a snapshot keyed by column
//! name, so the prax-query `FromRow` pipeline works uniformly across SQLx's
//! three backends (Postgres, MySQL, SQLite). Text is materialized eagerly so
//! that `get_str` can hand back a slice borrowed from the snapshot.
7
use std::collections::HashMap;

use prax_query::row::{RowError, RowRef};
use sqlx::{Column, Row, ValueRef};

use crate::row::SqlxRow;
14
15enum Value {
16    Null,
17    Bool(bool),
18    I64(i64),
19    F64(f64),
20    Text(String),
21    Bytes(Vec<u8>),
22}
23
24pub struct SqlxRowRef {
25    values: HashMap<String, Value>,
26}
27
28impl SqlxRowRef {
29    pub fn from_sqlx(row: &SqlxRow) -> Result<Self, RowError> {
30        let mut values = HashMap::new();
31        match row {
32            #[cfg(feature = "postgres")]
33            SqlxRow::Postgres(r) => {
34                for (i, col) in r.columns().iter().enumerate() {
35                    let name = col.name().to_string();
36                    let v = decode_pg_cell(r, i);
37                    values.insert(name, v);
38                }
39            }
40            #[cfg(feature = "mysql")]
41            SqlxRow::MySql(r) => {
42                for (i, col) in r.columns().iter().enumerate() {
43                    let name = col.name().to_string();
44                    let v = decode_generic_cell_mysql(r, i);
45                    values.insert(name, v);
46                }
47            }
48            #[cfg(feature = "sqlite")]
49            SqlxRow::Sqlite(r) => {
50                for (i, col) in r.columns().iter().enumerate() {
51                    let name = col.name().to_string();
52                    let v = decode_generic_cell_sqlite(r, i);
53                    values.insert(name, v);
54                }
55            }
56        }
57        Ok(Self { values })
58    }
59}
60
61fn tc(column: &str, msg: impl Into<String>) -> RowError {
62    RowError::TypeConversion {
63        column: column.into(),
64        message: msg.into(),
65    }
66}
67
/// Decodes cell `i` of a Postgres row into a [`Value`].
///
/// A SQL `NULL` short-circuits to [`Value::Null`]. Otherwise the cell is
/// probed in this order, and the first successful decode wins:
/// text → `bool` → `i64` → `i32` → `i16` → `f64` → `f32` → bytes.
/// `i32`/`i16` are widened into [`Value::I64`] and `f32` into [`Value::F64`].
///
/// A non-`NULL` cell that no probe accepts also yields [`Value::Null`].
/// NOTE(review): SQLx checks the column type before decoding, so Postgres
/// types outside this list (e.g. `uuid`, `timestamptz`, `jsonb`, `numeric`)
/// end up here. Callers that must distinguish this from a real `NULL` need
/// to inspect the raw cell themselves.
#[cfg(feature = "postgres")]
fn decode_pg_cell(r: &sqlx::postgres::PgRow, i: usize) -> Value {
    // Each `Option<T>` probe below returns `Ok(None)` for a NULL cell, so one
    // raw check spares all eight decode attempts. An unreadable raw cell
    // would fail every probe anyway, so it also maps to Null.
    if r.try_get_raw(i).map_or(true, |raw| raw.is_null()) {
        return Value::Null;
    }
    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
        return Value::Text(s);
    }
    if let Ok(Some(b)) = r.try_get::<Option<bool>, _>(i) {
        return Value::Bool(b);
    }
    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
        return Value::I64(n);
    }
    if let Ok(Some(n)) = r.try_get::<Option<i32>, _>(i) {
        return Value::I64(i64::from(n));
    }
    if let Ok(Some(n)) = r.try_get::<Option<i16>, _>(i) {
        return Value::I64(i64::from(n));
    }
    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
        return Value::F64(f);
    }
    if let Ok(Some(f)) = r.try_get::<Option<f32>, _>(i) {
        return Value::F64(f64::from(f));
    }
    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
        return Value::Bytes(b);
    }
    Value::Null
}
98
/// Decodes cell `i` of a MySQL row into a [`Value`].
///
/// A SQL `NULL` short-circuits to [`Value::Null`]. Otherwise the cell is
/// probed in this order: text → `i64` → `bool` → `f64` → bytes. A non-`NULL`
/// cell that no probe accepts also yields [`Value::Null`].
#[cfg(feature = "mysql")]
fn decode_generic_cell_mysql(r: &sqlx::mysql::MySqlRow, i: usize) -> Value {
    // Each `Option<T>` probe returns `Ok(None)` for NULL; check once instead.
    if r.try_get_raw(i).map_or(true, |raw| raw.is_null()) {
        return Value::Null;
    }
    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
        return Value::Text(s);
    }
    // Integers must be probed before `bool`. MySQL has no real boolean type
    // (`BOOLEAN` is `TINYINT(1)`), and SQLx's `bool` accepts integer columns
    // and decodes them as `value != 0`. Probing `bool` first turned e.g. a
    // `TINYINT` holding 5 into `true` and made `get_i32` fail. Keeping the
    // integer loses nothing, because `get_bool` maps non-zero integers to
    // `true`.
    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
        return Value::I64(n);
    }
    // Still catches boolean-compatible columns `i64` rejects
    // (NOTE(review): e.g. `BIT(1)`, depending on the SQLx version).
    if let Ok(Some(b)) = r.try_get::<Option<bool>, _>(i) {
        return Value::Bool(b);
    }
    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
        return Value::F64(f);
    }
    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
        return Value::Bytes(b);
    }
    Value::Null
}
118
/// Decodes cell `i` of a SQLite row into a [`Value`].
///
/// A SQL `NULL` short-circuits to [`Value::Null`]. Otherwise the cell is
/// probed in this order: text → `i64` → `f64` → bytes. There is no `bool`
/// probe: SQLite stores booleans as integers, and `get_bool` maps non-zero
/// integers to `true`. A non-`NULL` cell that no probe accepts also yields
/// [`Value::Null`].
#[cfg(feature = "sqlite")]
fn decode_generic_cell_sqlite(r: &sqlx::sqlite::SqliteRow, i: usize) -> Value {
    // Each `Option<T>` probe returns `Ok(None)` for NULL; check once instead.
    if r.try_get_raw(i).map_or(true, |raw| raw.is_null()) {
        return Value::Null;
    }
    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
        return Value::Text(s);
    }
    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
        return Value::I64(n);
    }
    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
        return Value::F64(f);
    }
    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
        return Value::Bytes(b);
    }
    Value::Null
}
135
136impl RowRef for SqlxRowRef {
137    fn get_i32(&self, c: &str) -> Result<i32, RowError> {
138        match self
139            .values
140            .get(c)
141            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
142        {
143            Value::I64(i) => i32::try_from(*i).map_err(|_| tc(c, "i64 overflow")),
144            Value::Null => Err(RowError::UnexpectedNull(c.into())),
145            _ => Err(tc(c, "not an integer")),
146        }
147    }
148    fn get_i32_opt(&self, c: &str) -> Result<Option<i32>, RowError> {
149        match self.values.get(c) {
150            None => Err(RowError::ColumnNotFound(c.into())),
151            Some(Value::Null) => Ok(None),
152            Some(Value::I64(i)) => i32::try_from(*i)
153                .map(Some)
154                .map_err(|_| tc(c, "i64 overflow")),
155            Some(_) => Err(tc(c, "not an integer")),
156        }
157    }
158    fn get_i64(&self, c: &str) -> Result<i64, RowError> {
159        match self
160            .values
161            .get(c)
162            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
163        {
164            Value::I64(i) => Ok(*i),
165            Value::Null => Err(RowError::UnexpectedNull(c.into())),
166            _ => Err(tc(c, "not an integer")),
167        }
168    }
169    fn get_i64_opt(&self, c: &str) -> Result<Option<i64>, RowError> {
170        match self.values.get(c) {
171            None => Err(RowError::ColumnNotFound(c.into())),
172            Some(Value::Null) => Ok(None),
173            Some(Value::I64(i)) => Ok(Some(*i)),
174            Some(_) => Err(tc(c, "not an integer")),
175        }
176    }
177    fn get_f64(&self, c: &str) -> Result<f64, RowError> {
178        match self
179            .values
180            .get(c)
181            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
182        {
183            Value::F64(f) => Ok(*f),
184            Value::I64(i) => Ok(*i as f64),
185            Value::Null => Err(RowError::UnexpectedNull(c.into())),
186            _ => Err(tc(c, "not a number")),
187        }
188    }
189    fn get_f64_opt(&self, c: &str) -> Result<Option<f64>, RowError> {
190        match self.values.get(c) {
191            None => Err(RowError::ColumnNotFound(c.into())),
192            Some(Value::Null) => Ok(None),
193            Some(Value::F64(f)) => Ok(Some(*f)),
194            Some(Value::I64(i)) => Ok(Some(*i as f64)),
195            Some(_) => Err(tc(c, "not a number")),
196        }
197    }
198    fn get_bool(&self, c: &str) -> Result<bool, RowError> {
199        match self
200            .values
201            .get(c)
202            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
203        {
204            Value::Bool(b) => Ok(*b),
205            Value::I64(i) => Ok(*i != 0),
206            Value::Null => Err(RowError::UnexpectedNull(c.into())),
207            _ => Err(tc(c, "not a boolean")),
208        }
209    }
210    fn get_bool_opt(&self, c: &str) -> Result<Option<bool>, RowError> {
211        match self.values.get(c) {
212            None => Err(RowError::ColumnNotFound(c.into())),
213            Some(Value::Null) => Ok(None),
214            Some(Value::Bool(b)) => Ok(Some(*b)),
215            Some(Value::I64(i)) => Ok(Some(*i != 0)),
216            Some(_) => Err(tc(c, "not a boolean")),
217        }
218    }
219    fn get_str(&self, c: &str) -> Result<&str, RowError> {
220        match self
221            .values
222            .get(c)
223            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
224        {
225            Value::Text(s) => Ok(s.as_str()),
226            Value::Null => Err(RowError::UnexpectedNull(c.into())),
227            _ => Err(tc(c, "not text")),
228        }
229    }
230    fn get_str_opt(&self, c: &str) -> Result<Option<&str>, RowError> {
231        match self.values.get(c) {
232            None => Err(RowError::ColumnNotFound(c.into())),
233            Some(Value::Null) => Ok(None),
234            Some(Value::Text(s)) => Ok(Some(s.as_str())),
235            Some(_) => Err(tc(c, "not text")),
236        }
237    }
238    fn get_bytes(&self, c: &str) -> Result<&[u8], RowError> {
239        match self
240            .values
241            .get(c)
242            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
243        {
244            Value::Bytes(b) => Ok(b.as_slice()),
245            Value::Text(s) => Ok(s.as_bytes()),
246            Value::Null => Err(RowError::UnexpectedNull(c.into())),
247            _ => Err(tc(c, "not bytes")),
248        }
249    }
250    fn get_bytes_opt(&self, c: &str) -> Result<Option<&[u8]>, RowError> {
251        match self.values.get(c) {
252            None => Err(RowError::ColumnNotFound(c.into())),
253            Some(Value::Null) => Ok(None),
254            Some(Value::Bytes(b)) => Ok(Some(b.as_slice())),
255            Some(Value::Text(s)) => Ok(Some(s.as_bytes())),
256            Some(_) => Err(tc(c, "not bytes")),
257        }
258    }
259    fn get_datetime_utc(&self, c: &str) -> Result<chrono::DateTime<chrono::Utc>, RowError> {
260        let s = self.get_str(c)?;
261        chrono::DateTime::parse_from_rfc3339(s)
262            .map(|d| d.with_timezone(&chrono::Utc))
263            .map_err(|e| tc(c, e.to_string()))
264    }
265    fn get_datetime_utc_opt(
266        &self,
267        c: &str,
268    ) -> Result<Option<chrono::DateTime<chrono::Utc>>, RowError> {
269        match self.get_str_opt(c)? {
270            None => Ok(None),
271            Some(s) => chrono::DateTime::parse_from_rfc3339(s)
272                .map(|d| Some(d.with_timezone(&chrono::Utc)))
273                .map_err(|e| tc(c, e.to_string())),
274        }
275    }
276    fn get_uuid(&self, c: &str) -> Result<uuid::Uuid, RowError> {
277        uuid::Uuid::parse_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
278    }
279    fn get_uuid_opt(&self, c: &str) -> Result<Option<uuid::Uuid>, RowError> {
280        match self.get_str_opt(c)? {
281            None => Ok(None),
282            Some(s) => uuid::Uuid::parse_str(s)
283                .map(Some)
284                .map_err(|e| tc(c, e.to_string())),
285        }
286    }
287    fn get_json(&self, c: &str) -> Result<serde_json::Value, RowError> {
288        serde_json::from_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
289    }
290    fn get_json_opt(&self, c: &str) -> Result<Option<serde_json::Value>, RowError> {
291        match self.get_str_opt(c)? {
292            None => Ok(None),
293            Some(s) => serde_json::from_str(s)
294                .map(Some)
295                .map_err(|e| tc(c, e.to_string())),
296        }
297    }
298}