Skip to main content

omnia_postgres/
sql.rs

1use std::convert::TryFrom;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
7use deadpool_postgres::Object;
8use futures::future::FutureExt;
9use omnia_wasi_sql::{Connection, DataType, Field, FutureResult, Row, WasiSqlCtx};
10use tokio_postgres::row::Row as PgRow;
11
12use crate::Client;
13use crate::types::{Param, ParamRef, PgType};
14
15/// `wasi-sql` implementation backed by `deadpool-postgres` connection pools.
16impl WasiSqlCtx for Client {
17    fn open(&self, name: String) -> FutureResult<Arc<dyn Connection>> {
18        tracing::debug!("getting connection {name}");
19
20        let pool = match self.0.get(&name.to_ascii_uppercase()) {
21            Some(p) => p.clone(),
22            None => {
23                return futures::future::ready(Err(anyhow!("unknown postgres pool '{name}'")))
24                    .boxed();
25            }
26        };
27        async move {
28            let cnn = pool.get().await.context("issue getting connection")?;
29            Ok(Arc::new(PostgresConnection(Arc::new(cnn))) as Arc<dyn Connection>)
30        }
31        .boxed()
32    }
33}
34
35/// A pooled `PostgreSQL` connection implementing the `wasi-sql` `Connection` trait.
36#[derive(Debug)]
37pub struct PostgresConnection(Arc<Object>);
38
39impl Connection for PostgresConnection {
40    fn query(&self, query: String, params: Vec<DataType>) -> FutureResult<Vec<Row>> {
41        tracing::debug!("query: {query}, params: {params:?}");
42        let cnn = Arc::clone(&self.0);
43
44        async move {
45            let mut pg_params: Vec<Param> = Vec::new();
46            for p in &params {
47                pg_params.push(into_param(p)?);
48            }
49            let param_refs: Vec<ParamRef> =
50                pg_params.iter().map(|b| b.as_ref() as ParamRef).collect();
51
52            let pg_rows = cnn
53                .query(&query, &param_refs)
54                .await
55                .inspect_err(|e| {
56                    dbg!(e);
57                })
58                .context("query failed")?;
59            tracing::debug!("query returned {} rows", pg_rows.len());
60
61            let mut wasi_rows = Vec::new();
62            for (idx, r) in pg_rows.iter().enumerate() {
63                let row = match into_wasi_row(r, idx) {
64                    Ok(row) => row,
65                    Err(e) => {
66                        tracing::error!("failed to convert row: {e:?}");
67                        return Err(anyhow!("failed to convert row: {e:?}"));
68                    }
69                };
70                wasi_rows.push(row);
71            }
72
73            Ok(wasi_rows)
74        }
75        .boxed()
76    }
77
78    fn exec(&self, query: String, params: Vec<DataType>) -> FutureResult<u32> {
79        tracing::debug!("exec: {query}, params: {params:?}");
80        let cnn = Arc::clone(&self.0);
81
82        async move {
83            let mut pg_params: Vec<Param> = Vec::new();
84            for p in &params {
85                pg_params.push(into_param(p)?);
86            }
87            let param_refs: Vec<ParamRef> =
88                pg_params.iter().map(|b| b.as_ref() as ParamRef).collect();
89
90            let affected = match cnn.execute(&query, &param_refs).await {
91                Ok(count) => count,
92                Err(e) => {
93                    tracing::error!("exec failed: {e}");
94                    return Err(anyhow!("exec failed: {e}"));
95                }
96            };
97            #[allow(clippy::cast_possible_truncation)]
98            Ok(affected as u32)
99        }
100        .boxed()
101    }
102}
103
104fn parse_date(source: Option<&str>) -> anyhow::Result<Option<NaiveDate>> {
105    source.map(|s| NaiveDate::from_str(s).context("invalid date format")).transpose()
106}
107
108fn parse_time(source: Option<&str>) -> anyhow::Result<Option<NaiveTime>> {
109    source.map(|s| NaiveTime::from_str(s).context("invalid time format")).transpose()
110}
111
112fn parse_timestamp_naive(source: Option<&str>) -> anyhow::Result<Option<NaiveDateTime>> {
113    source
114        .map(|s| {
115            NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
116                .context("invalid naive timestamp format")
117        })
118        .transpose()
119}
120
/// Parses an optional RFC 3339 timestamp and normalises it to UTC (test-only helper).
///
/// `None` passes through as `Ok(None)`.
///
/// # Errors
/// Returns an error if the string is not valid RFC 3339.
#[cfg(test)]
fn parse_timestamp_tz(source: Option<&str>) -> anyhow::Result<Option<DateTime<Utc>>> {
    let Some(raw) = source else {
        return Ok(None);
    };
    let with_offset = DateTime::parse_from_rfc3339(raw).context("invalid RFC3339 timestamp")?;
    Ok(Some(with_offset.with_timezone(&Utc)))
}
131
132fn into_param(value: &DataType) -> anyhow::Result<Param> {
133    let pg_value = match value {
134        DataType::Int32(v) => PgType::Int32(*v),
135        DataType::Int64(v) => PgType::Int64(*v),
136        DataType::Uint32(v) => PgType::Uint32(*v),
137        DataType::Uint64(v) => {
138            // Postgres doesn't support u64, so clamping it to i64.
139            let converted = match v {
140                Some(raw) => {
141                    let clamped = i64::try_from(*raw).map_err(|err| {
142                        anyhow!("uint64 value {raw} exceeds i64::MAX and cannot be stored: {err}")
143                    })?;
144                    Some(clamped)
145                }
146                None => None,
147            };
148            PgType::Int64(converted)
149        }
150        DataType::Float(v) => PgType::Float(*v),
151        DataType::Double(v) => PgType::Double(*v),
152        DataType::Str(v) => PgType::Text(v.clone()),
153        DataType::Boolean(v) => PgType::Bool(*v),
154        DataType::Date(v) => PgType::Date(parse_date(v.as_deref())?),
155        DataType::Time(v) => PgType::Time(parse_time(v.as_deref())?),
156        DataType::Timestamp(v) => {
157            // Try RFC3339 format first (with timezone)
158            if let Some(s) = v.as_deref() {
159                if let Ok(ts) = DateTime::parse_from_rfc3339(s) {
160                    PgType::TimestampTz(Some(ts.with_timezone(&Utc)))
161                } else {
162                    // Fall back to naive timestamp format
163                    PgType::Timestamp(parse_timestamp_naive(v.as_deref())?)
164                }
165            } else {
166                PgType::Timestamp(None)
167            }
168        }
169        DataType::Binary(v) => PgType::Binary(v.clone()),
170    };
171
172    Ok(Box::new(pg_value) as Param)
173}
174
175/// Converts a ``PostgreSQL`` row to WASI SQL format.
176///
177/// # Testing
178/// This function will have to tested via integration tests with a real database
179/// due to the difficulty of mocking `tokio_postgres::Row`.
180fn into_wasi_row(pg_row: &PgRow, idx: usize) -> anyhow::Result<Row> {
181    let mut fields = Vec::new();
182    for (i, col) in pg_row.columns().iter().enumerate() {
183        let name = col.name().to_string();
184        tracing::debug!("attempting to convert column '{name}' with type '{:?}'", col.type_());
185        tracing::debug!("column type name: {}", col.type_().name());
186        let value = match col.type_().name() {
187            "int4" => {
188                let v: Option<i32> = pg_row.try_get(i)?;
189                DataType::Int32(v)
190            }
191            "int8" => {
192                let v: Option<i64> = pg_row.try_get(i)?;
193                DataType::Int64(v)
194            }
195            "oid" => {
196                let v: Option<u32> = pg_row.try_get(i)?;
197                DataType::Uint32(v)
198            }
199            "float4" => {
200                let v: Option<f32> = pg_row.try_get(i)?;
201                DataType::Float(v)
202            }
203            "float8" => {
204                let v: Option<f64> = pg_row.try_get(i)?;
205                DataType::Double(v)
206            }
207            "text" | "varchar" | "name" | "_char" => {
208                let v: Option<String> = pg_row.try_get(i)?;
209                DataType::Str(v)
210            }
211            "bool" => {
212                let v: Option<bool> = pg_row.try_get(i)?;
213                DataType::Boolean(v)
214            }
215            "date" => {
216                let v: Option<NaiveDate> = pg_row.try_get(i)?;
217                let formatted = v.map(|date| date.to_string());
218                DataType::Date(formatted)
219            }
220            "time" => {
221                let v: Option<NaiveTime> = pg_row.try_get(i)?;
222                let formatted = v.map(|time| time.to_string());
223                DataType::Time(formatted)
224            }
225            "timestamp" => {
226                let v: Option<NaiveDateTime> = pg_row.try_get(i)?;
227                let formatted = v.map(|dt| dt.to_string());
228                DataType::Timestamp(formatted)
229            }
230            "timestamptz" => {
231                let v: Option<DateTime<Utc>> = pg_row.try_get(i)?;
232                let formatted = v.map(|dtz| dtz.to_rfc3339());
233                DataType::Timestamp(formatted)
234            }
235            "json" | "jsonb" => {
236                let v: Option<tokio_postgres::types::Json<serde_json::Value>> =
237                    pg_row.try_get(i)?;
238                let as_str = v.map(|json| json.0.to_string());
239                DataType::Str(as_str)
240            }
241            "bytea" => {
242                let v: Option<Vec<u8>> = pg_row.try_get(i)?;
243                DataType::Binary(v)
244            }
245            other => {
246                bail!("unsupported column type: {other}");
247            }
248        };
249        tracing::debug!("converted column '{name}' to value '{:?}'", value);
250        fields.push(Field { name, value });
251    }
252
253    Ok(Row {
254        index: idx.to_string(),
255        fields,
256    })
257}
258
#[cfg(test)]
mod tests {
    use chrono::{Datelike, Timelike};

    use super::*;

    #[test]
    fn parse_date_valid() {
        let valid = parse_date(Some("2024-12-25")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 12);
        assert_eq!(valid.day(), 25);

        assert!(parse_date(None).unwrap().is_none());
        parse_date(Some("invalid-date")).unwrap_err();
        // Well-formed but not a real calendar date.
        parse_date(Some("2024-02-30")).unwrap_err();
    }

    #[test]
    fn parse_time_valid() {
        let valid = parse_time(Some("14:30:45.123456")).unwrap().unwrap();
        assert_eq!(valid.hour(), 14);
        assert_eq!(valid.minute(), 30);
        assert_eq!(valid.second(), 45);

        assert!(parse_time(None).unwrap().is_none());
        parse_time(Some("25:00:00")).unwrap_err();
    }

    #[test]
    fn parse_timestamp_tz_valid() {
        let valid = parse_timestamp_tz(Some("2024-01-20T15:30:45Z")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 1);
        assert_eq!(valid.day(), 20);

        let with_offset = parse_timestamp_tz(Some("2024-01-20T15:30:45+05:00")).unwrap().unwrap();
        assert_eq!(with_offset.hour(), 10);

        assert!(parse_timestamp_tz(None).unwrap().is_none());
        parse_timestamp_tz(Some("not a timestamp")).unwrap_err();
    }

    #[test]
    fn parse_timestamp_naive_valid() {
        let valid = parse_timestamp_naive(Some("2024-01-20 15:30:45.123")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 1);
        assert_eq!(valid.day(), 20);
        assert_eq!(valid.hour(), 15);

        assert!(parse_timestamp_naive(None).unwrap().is_none());
        parse_timestamp_naive(Some("2024/01/20 15:30:45")).unwrap_err();
    }

    #[test]
    fn into_param_valid_conversions() {
        // Basic types
        into_param(&DataType::Int32(Some(42))).unwrap();
        into_param(&DataType::Int64(Some(i64::MAX))).unwrap();
        into_param(&DataType::Uint32(Some(u32::MAX))).unwrap();
        into_param(&DataType::Uint64(Some(100))).unwrap(); // within i64 range
        into_param(&DataType::Float(Some(std::f32::consts::PI))).unwrap();
        into_param(&DataType::Double(Some(std::f64::consts::E))).unwrap();
        into_param(&DataType::Str(Some("test".to_string()))).unwrap();
        into_param(&DataType::Boolean(Some(true))).unwrap();
        into_param(&DataType::Binary(Some(vec![0x01, 0x02]))).unwrap();

        // None values
        into_param(&DataType::Int32(None)).unwrap();
        into_param(&DataType::Str(None)).unwrap();
    }

    #[test]
    fn into_param_invalid_conversions() {
        let err = into_param(&DataType::Uint64(Some(u64::MAX))).unwrap_err();
        assert!(err.to_string().contains("exceeds i64::MAX"));

        into_param(&DataType::Date(Some("invalid".to_string()))).unwrap_err();
        into_param(&DataType::Time(Some("25:00:00".to_string()))).unwrap_err();
        into_param(&DataType::Timestamp(Some("not a timestamp".to_string()))).unwrap_err();
    }

    #[test]
    fn into_param_uint64_boundaries() {
        // `i64::MAX` is the largest unsigned value that can be bound as a signed 64-bit int.
        let max_storable = i64::MAX.unsigned_abs();
        into_param(&DataType::Uint64(Some(max_storable))).unwrap();
        into_param(&DataType::Uint64(None)).unwrap();

        let err = into_param(&DataType::Uint64(Some(max_storable + 1))).unwrap_err();
        assert!(err.to_string().contains("exceeds i64::MAX"));
    }

    #[test]
    fn into_param_null_temporal_values() {
        into_param(&DataType::Date(None)).unwrap();
        into_param(&DataType::Time(None)).unwrap();
        into_param(&DataType::Timestamp(None)).unwrap();
    }

    #[test]
    fn into_param_timestamp_format_detection() {
        let with_tz = DataType::Timestamp(Some("2024-01-20T15:30:45Z".to_string()));
        into_param(&with_tz).unwrap();

        let with_tz_fraction = DataType::Timestamp(Some("2024-01-20T15:30:45.5+05:00".to_string()));
        into_param(&with_tz_fraction).unwrap();

        let naive = DataType::Timestamp(Some("2024-01-20 15:30:45".to_string()));
        into_param(&naive).unwrap();

        let naive_fraction = DataType::Timestamp(Some("2024-01-20 15:30:45.123456".to_string()));
        into_param(&naive_fraction).unwrap();
    }
}