Skip to main content

prax_sqlx/
row_ref.rs

1//! Bridge between SqlxRow and prax_query::row::RowRef.
2//!
3//! Decodes each column to a string-keyed snapshot so the prax-query
4//! `FromRow` pipeline works uniformly across SQLx's three backends
5//! (Postgres, MySQL, SQLite). Strings are materialized eagerly so
6//! `get_str` can hand back a borrowed slice.
7
8use std::collections::HashMap;
9
10use prax_query::row::{RowError, RowRef};
11use sqlx::{Column, Row};
12
13use crate::row::SqlxRow;
14
15enum Value {
16    Null,
17    Bool(bool),
18    I64(i64),
19    F64(f64),
20    Text(String),
21    Bytes(Vec<u8>),
22}
23
24/// A driver-agnostic decoded row produced by the sqlx engine. Holds owned
25/// values keyed by column name so callers can access them after the row
26/// itself has been dropped.
27pub struct SqlxRowRef {
28    values: HashMap<String, Value>,
29}
30
31impl SqlxRowRef {
32    /// Decode a raw sqlx row into a driver-agnostic [`SqlxRowRef`].
33    pub fn from_sqlx(row: &SqlxRow) -> Result<Self, RowError> {
34        let mut values = HashMap::new();
35        match row {
36            #[cfg(feature = "postgres")]
37            SqlxRow::Postgres(r) => {
38                for (i, col) in r.columns().iter().enumerate() {
39                    let name = col.name().to_string();
40                    let v = decode_pg_cell(r, i);
41                    values.insert(name, v);
42                }
43            }
44            #[cfg(feature = "mysql")]
45            SqlxRow::MySql(r) => {
46                for (i, col) in r.columns().iter().enumerate() {
47                    let name = col.name().to_string();
48                    let v = decode_generic_cell_mysql(r, i);
49                    values.insert(name, v);
50                }
51            }
52            #[cfg(feature = "sqlite")]
53            SqlxRow::Sqlite(r) => {
54                for (i, col) in r.columns().iter().enumerate() {
55                    let name = col.name().to_string();
56                    let v = decode_generic_cell_sqlite(r, i);
57                    values.insert(name, v);
58                }
59            }
60        }
61        Ok(Self { values })
62    }
63}
64
65fn tc(column: &str, msg: impl Into<String>) -> RowError {
66    RowError::TypeConversion {
67        column: column.into(),
68        message: msg.into(),
69    }
70}
71
72/// Probe a Postgres cell in width order (text → bool → i64 → i32 → f64
73/// → bytes), falling back to Null for everything we don't recognise.
74#[cfg(feature = "postgres")]
75fn decode_pg_cell(r: &sqlx::postgres::PgRow, i: usize) -> Value {
76    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
77        return Value::Text(s);
78    }
79    if let Ok(Some(b)) = r.try_get::<Option<bool>, _>(i) {
80        return Value::Bool(b);
81    }
82    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
83        return Value::I64(n);
84    }
85    if let Ok(Some(n)) = r.try_get::<Option<i32>, _>(i) {
86        return Value::I64(n as i64);
87    }
88    if let Ok(Some(n)) = r.try_get::<Option<i16>, _>(i) {
89        return Value::I64(n as i64);
90    }
91    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
92        return Value::F64(f);
93    }
94    if let Ok(Some(f)) = r.try_get::<Option<f32>, _>(i) {
95        return Value::F64(f as f64);
96    }
97    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
98        return Value::Bytes(b);
99    }
100    Value::Null
101}
102
103#[cfg(feature = "mysql")]
104fn decode_generic_cell_mysql(r: &sqlx::mysql::MySqlRow, i: usize) -> Value {
105    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
106        return Value::Text(s);
107    }
108    if let Ok(Some(b)) = r.try_get::<Option<bool>, _>(i) {
109        return Value::Bool(b);
110    }
111    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
112        return Value::I64(n);
113    }
114    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
115        return Value::F64(f);
116    }
117    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
118        return Value::Bytes(b);
119    }
120    Value::Null
121}
122
123#[cfg(feature = "sqlite")]
124fn decode_generic_cell_sqlite(r: &sqlx::sqlite::SqliteRow, i: usize) -> Value {
125    if let Ok(Some(s)) = r.try_get::<Option<String>, _>(i) {
126        return Value::Text(s);
127    }
128    if let Ok(Some(n)) = r.try_get::<Option<i64>, _>(i) {
129        return Value::I64(n);
130    }
131    if let Ok(Some(f)) = r.try_get::<Option<f64>, _>(i) {
132        return Value::F64(f);
133    }
134    if let Ok(Some(b)) = r.try_get::<Option<Vec<u8>>, _>(i) {
135        return Value::Bytes(b);
136    }
137    Value::Null
138}
139
140impl RowRef for SqlxRowRef {
141    fn get_i32(&self, c: &str) -> Result<i32, RowError> {
142        match self
143            .values
144            .get(c)
145            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
146        {
147            Value::I64(i) => i32::try_from(*i).map_err(|_| tc(c, "i64 overflow")),
148            Value::Null => Err(RowError::UnexpectedNull(c.into())),
149            _ => Err(tc(c, "not an integer")),
150        }
151    }
152    fn get_i32_opt(&self, c: &str) -> Result<Option<i32>, RowError> {
153        match self.values.get(c) {
154            None => Err(RowError::ColumnNotFound(c.into())),
155            Some(Value::Null) => Ok(None),
156            Some(Value::I64(i)) => i32::try_from(*i)
157                .map(Some)
158                .map_err(|_| tc(c, "i64 overflow")),
159            Some(_) => Err(tc(c, "not an integer")),
160        }
161    }
162    fn get_i64(&self, c: &str) -> Result<i64, RowError> {
163        match self
164            .values
165            .get(c)
166            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
167        {
168            Value::I64(i) => Ok(*i),
169            Value::Null => Err(RowError::UnexpectedNull(c.into())),
170            _ => Err(tc(c, "not an integer")),
171        }
172    }
173    fn get_i64_opt(&self, c: &str) -> Result<Option<i64>, RowError> {
174        match self.values.get(c) {
175            None => Err(RowError::ColumnNotFound(c.into())),
176            Some(Value::Null) => Ok(None),
177            Some(Value::I64(i)) => Ok(Some(*i)),
178            Some(_) => Err(tc(c, "not an integer")),
179        }
180    }
181    fn get_f64(&self, c: &str) -> Result<f64, RowError> {
182        match self
183            .values
184            .get(c)
185            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
186        {
187            Value::F64(f) => Ok(*f),
188            Value::I64(i) => Ok(*i as f64),
189            Value::Null => Err(RowError::UnexpectedNull(c.into())),
190            _ => Err(tc(c, "not a number")),
191        }
192    }
193    fn get_f64_opt(&self, c: &str) -> Result<Option<f64>, RowError> {
194        match self.values.get(c) {
195            None => Err(RowError::ColumnNotFound(c.into())),
196            Some(Value::Null) => Ok(None),
197            Some(Value::F64(f)) => Ok(Some(*f)),
198            Some(Value::I64(i)) => Ok(Some(*i as f64)),
199            Some(_) => Err(tc(c, "not a number")),
200        }
201    }
202    fn get_bool(&self, c: &str) -> Result<bool, RowError> {
203        match self
204            .values
205            .get(c)
206            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
207        {
208            Value::Bool(b) => Ok(*b),
209            Value::I64(i) => Ok(*i != 0),
210            Value::Null => Err(RowError::UnexpectedNull(c.into())),
211            _ => Err(tc(c, "not a boolean")),
212        }
213    }
214    fn get_bool_opt(&self, c: &str) -> Result<Option<bool>, RowError> {
215        match self.values.get(c) {
216            None => Err(RowError::ColumnNotFound(c.into())),
217            Some(Value::Null) => Ok(None),
218            Some(Value::Bool(b)) => Ok(Some(*b)),
219            Some(Value::I64(i)) => Ok(Some(*i != 0)),
220            Some(_) => Err(tc(c, "not a boolean")),
221        }
222    }
223    fn get_str(&self, c: &str) -> Result<&str, RowError> {
224        match self
225            .values
226            .get(c)
227            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
228        {
229            Value::Text(s) => Ok(s.as_str()),
230            Value::Null => Err(RowError::UnexpectedNull(c.into())),
231            _ => Err(tc(c, "not text")),
232        }
233    }
234    fn get_str_opt(&self, c: &str) -> Result<Option<&str>, RowError> {
235        match self.values.get(c) {
236            None => Err(RowError::ColumnNotFound(c.into())),
237            Some(Value::Null) => Ok(None),
238            Some(Value::Text(s)) => Ok(Some(s.as_str())),
239            Some(_) => Err(tc(c, "not text")),
240        }
241    }
242    fn get_bytes(&self, c: &str) -> Result<&[u8], RowError> {
243        match self
244            .values
245            .get(c)
246            .ok_or_else(|| RowError::ColumnNotFound(c.into()))?
247        {
248            Value::Bytes(b) => Ok(b.as_slice()),
249            Value::Text(s) => Ok(s.as_bytes()),
250            Value::Null => Err(RowError::UnexpectedNull(c.into())),
251            _ => Err(tc(c, "not bytes")),
252        }
253    }
254    fn get_bytes_opt(&self, c: &str) -> Result<Option<&[u8]>, RowError> {
255        match self.values.get(c) {
256            None => Err(RowError::ColumnNotFound(c.into())),
257            Some(Value::Null) => Ok(None),
258            Some(Value::Bytes(b)) => Ok(Some(b.as_slice())),
259            Some(Value::Text(s)) => Ok(Some(s.as_bytes())),
260            Some(_) => Err(tc(c, "not bytes")),
261        }
262    }
263    fn get_datetime_utc(&self, c: &str) -> Result<chrono::DateTime<chrono::Utc>, RowError> {
264        let s = self.get_str(c)?;
265        chrono::DateTime::parse_from_rfc3339(s)
266            .map(|d| d.with_timezone(&chrono::Utc))
267            .map_err(|e| tc(c, e.to_string()))
268    }
269    fn get_datetime_utc_opt(
270        &self,
271        c: &str,
272    ) -> Result<Option<chrono::DateTime<chrono::Utc>>, RowError> {
273        match self.get_str_opt(c)? {
274            None => Ok(None),
275            Some(s) => chrono::DateTime::parse_from_rfc3339(s)
276                .map(|d| Some(d.with_timezone(&chrono::Utc)))
277                .map_err(|e| tc(c, e.to_string())),
278        }
279    }
280    fn get_uuid(&self, c: &str) -> Result<uuid::Uuid, RowError> {
281        uuid::Uuid::parse_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
282    }
283    fn get_uuid_opt(&self, c: &str) -> Result<Option<uuid::Uuid>, RowError> {
284        match self.get_str_opt(c)? {
285            None => Ok(None),
286            Some(s) => uuid::Uuid::parse_str(s)
287                .map(Some)
288                .map_err(|e| tc(c, e.to_string())),
289        }
290    }
291    fn get_json(&self, c: &str) -> Result<serde_json::Value, RowError> {
292        serde_json::from_str(self.get_str(c)?).map_err(|e| tc(c, e.to_string()))
293    }
294    fn get_json_opt(&self, c: &str) -> Result<Option<serde_json::Value>, RowError> {
295        match self.get_str_opt(c)? {
296            None => Ok(None),
297            Some(s) => serde_json::from_str(s)
298                .map(Some)
299                .map_err(|e| tc(c, e.to_string())),
300        }
301    }
302}