1use std::convert::TryFrom;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use anyhow::{Context, anyhow, bail};
6use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
7use deadpool_postgres::Object;
8use futures::future::FutureExt;
9use omnia_wasi_sql::{Connection, DataType, Field, FutureResult, Row, WasiSqlCtx};
10use tokio_postgres::row::Row as PgRow;
11
12use crate::Client;
13use crate::types::{Param, ParamRef, PgType};
14
15impl WasiSqlCtx for Client {
17 fn open(&self, name: String) -> FutureResult<Arc<dyn Connection>> {
18 tracing::debug!("getting connection {name}");
19
20 let pool = match self.0.get(&name.to_ascii_uppercase()) {
21 Some(p) => p.clone(),
22 None => {
23 return futures::future::ready(Err(anyhow!("unknown postgres pool '{name}'")))
24 .boxed();
25 }
26 };
27 async move {
28 let cnn = pool.get().await.context("issue getting connection")?;
29 Ok(Arc::new(PostgresConnection(Arc::new(cnn))) as Arc<dyn Connection>)
30 }
31 .boxed()
32 }
33}
34
35#[derive(Debug)]
37pub struct PostgresConnection(Arc<Object>);
38
39impl Connection for PostgresConnection {
40 fn query(&self, query: String, params: Vec<DataType>) -> FutureResult<Vec<Row>> {
41 tracing::debug!("query: {query}, params: {params:?}");
42 let cnn = Arc::clone(&self.0);
43
44 async move {
45 let mut pg_params: Vec<Param> = Vec::new();
46 for p in ¶ms {
47 pg_params.push(into_param(p)?);
48 }
49 let param_refs: Vec<ParamRef> =
50 pg_params.iter().map(|b| b.as_ref() as ParamRef).collect();
51
52 let pg_rows = cnn
53 .query(&query, ¶m_refs)
54 .await
55 .inspect_err(|e| {
56 dbg!(e);
57 })
58 .context("query failed")?;
59 tracing::debug!("query returned {} rows", pg_rows.len());
60
61 let mut wasi_rows = Vec::new();
62 for (idx, r) in pg_rows.iter().enumerate() {
63 let row = match into_wasi_row(r, idx) {
64 Ok(row) => row,
65 Err(e) => {
66 tracing::error!("failed to convert row: {e:?}");
67 return Err(anyhow!("failed to convert row: {e:?}"));
68 }
69 };
70 wasi_rows.push(row);
71 }
72
73 Ok(wasi_rows)
74 }
75 .boxed()
76 }
77
78 fn exec(&self, query: String, params: Vec<DataType>) -> FutureResult<u32> {
79 tracing::debug!("exec: {query}, params: {params:?}");
80 let cnn = Arc::clone(&self.0);
81
82 async move {
83 let mut pg_params: Vec<Param> = Vec::new();
84 for p in ¶ms {
85 pg_params.push(into_param(p)?);
86 }
87 let param_refs: Vec<ParamRef> =
88 pg_params.iter().map(|b| b.as_ref() as ParamRef).collect();
89
90 let affected = match cnn.execute(&query, ¶m_refs).await {
91 Ok(count) => count,
92 Err(e) => {
93 tracing::error!("exec failed: {e}");
94 return Err(anyhow!("exec failed: {e}"));
95 }
96 };
97 #[allow(clippy::cast_possible_truncation)]
98 Ok(affected as u32)
99 }
100 .boxed()
101 }
102}
103
104fn parse_date(source: Option<&str>) -> anyhow::Result<Option<NaiveDate>> {
105 source.map(|s| NaiveDate::from_str(s).context("invalid date format")).transpose()
106}
107
108fn parse_time(source: Option<&str>) -> anyhow::Result<Option<NaiveTime>> {
109 source.map(|s| NaiveTime::from_str(s).context("invalid time format")).transpose()
110}
111
112fn parse_timestamp_naive(source: Option<&str>) -> anyhow::Result<Option<NaiveDateTime>> {
113 source
114 .map(|s| {
115 NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
116 .context("invalid naive timestamp format")
117 })
118 .transpose()
119}
120
/// Parses an optional RFC 3339 timestamp and normalizes it to UTC.
///
/// This is compiled only for tests. It mirrors the RFC 3339 branch of `into_param`'s
/// timestamp handling. `None` passes through as `Ok(None)`.
///
/// # Errors
///
/// Returns an "invalid RFC3339 timestamp" error when the text is not valid RFC 3339.
#[cfg(test)]
fn parse_timestamp_tz(source: Option<&str>) -> anyhow::Result<Option<DateTime<Utc>>> {
    let Some(text) = source else {
        return Ok(None);
    };
    let with_offset = DateTime::parse_from_rfc3339(text).context("invalid RFC3339 timestamp")?;
    Ok(Some(with_offset.with_timezone(&Utc)))
}
131
132fn into_param(value: &DataType) -> anyhow::Result<Param> {
133 let pg_value = match value {
134 DataType::Int32(v) => PgType::Int32(*v),
135 DataType::Int64(v) => PgType::Int64(*v),
136 DataType::Uint32(v) => PgType::Uint32(*v),
137 DataType::Uint64(v) => {
138 let converted = match v {
140 Some(raw) => {
141 let clamped = i64::try_from(*raw).map_err(|err| {
142 anyhow!("uint64 value {raw} exceeds i64::MAX and cannot be stored: {err}")
143 })?;
144 Some(clamped)
145 }
146 None => None,
147 };
148 PgType::Int64(converted)
149 }
150 DataType::Float(v) => PgType::Float(*v),
151 DataType::Double(v) => PgType::Double(*v),
152 DataType::Str(v) => PgType::Text(v.clone()),
153 DataType::Boolean(v) => PgType::Bool(*v),
154 DataType::Date(v) => PgType::Date(parse_date(v.as_deref())?),
155 DataType::Time(v) => PgType::Time(parse_time(v.as_deref())?),
156 DataType::Timestamp(v) => {
157 if let Some(s) = v.as_deref() {
159 if let Ok(ts) = DateTime::parse_from_rfc3339(s) {
160 PgType::TimestampTz(Some(ts.with_timezone(&Utc)))
161 } else {
162 PgType::Timestamp(parse_timestamp_naive(v.as_deref())?)
164 }
165 } else {
166 PgType::Timestamp(None)
167 }
168 }
169 DataType::Binary(v) => PgType::Binary(v.clone()),
170 };
171
172 Ok(Box::new(pg_value) as Param)
173}
174
175fn into_wasi_row(pg_row: &PgRow, idx: usize) -> anyhow::Result<Row> {
181 let mut fields = Vec::new();
182 for (i, col) in pg_row.columns().iter().enumerate() {
183 let name = col.name().to_string();
184 tracing::debug!("attempting to convert column '{name}' with type '{:?}'", col.type_());
185 tracing::debug!("column type name: {}", col.type_().name());
186 let value = match col.type_().name() {
187 "int4" => {
188 let v: Option<i32> = pg_row.try_get(i)?;
189 DataType::Int32(v)
190 }
191 "int8" => {
192 let v: Option<i64> = pg_row.try_get(i)?;
193 DataType::Int64(v)
194 }
195 "oid" => {
196 let v: Option<u32> = pg_row.try_get(i)?;
197 DataType::Uint32(v)
198 }
199 "float4" => {
200 let v: Option<f32> = pg_row.try_get(i)?;
201 DataType::Float(v)
202 }
203 "float8" => {
204 let v: Option<f64> = pg_row.try_get(i)?;
205 DataType::Double(v)
206 }
207 "text" | "varchar" | "name" | "_char" => {
208 let v: Option<String> = pg_row.try_get(i)?;
209 DataType::Str(v)
210 }
211 "bool" => {
212 let v: Option<bool> = pg_row.try_get(i)?;
213 DataType::Boolean(v)
214 }
215 "date" => {
216 let v: Option<NaiveDate> = pg_row.try_get(i)?;
217 let formatted = v.map(|date| date.to_string());
218 DataType::Date(formatted)
219 }
220 "time" => {
221 let v: Option<NaiveTime> = pg_row.try_get(i)?;
222 let formatted = v.map(|time| time.to_string());
223 DataType::Time(formatted)
224 }
225 "timestamp" => {
226 let v: Option<NaiveDateTime> = pg_row.try_get(i)?;
227 let formatted = v.map(|dt| dt.to_string());
228 DataType::Timestamp(formatted)
229 }
230 "timestamptz" => {
231 let v: Option<DateTime<Utc>> = pg_row.try_get(i)?;
232 let formatted = v.map(|dtz| dtz.to_rfc3339());
233 DataType::Timestamp(formatted)
234 }
235 "json" | "jsonb" => {
236 let v: Option<tokio_postgres::types::Json<serde_json::Value>> =
237 pg_row.try_get(i)?;
238 let as_str = v.map(|json| json.0.to_string());
239 DataType::Str(as_str)
240 }
241 "bytea" => {
242 let v: Option<Vec<u8>> = pg_row.try_get(i)?;
243 DataType::Binary(v)
244 }
245 other => {
246 bail!("unsupported column type: {other}");
247 }
248 };
249 tracing::debug!("converted column '{name}' to value '{:?}'", value);
250 fields.push(Field { name, value });
251 }
252
253 Ok(Row {
254 index: idx.to_string(),
255 fields,
256 })
257}
258
/// Unit tests for the string parsers and the `DataType` → Postgres parameter conversion.
#[cfg(test)]
mod tests {
    use chrono::{Datelike, Timelike};

    use super::*;

    #[test]
    fn parse_date_valid() {
        let valid = parse_date(Some("2024-12-25")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 12);
        assert_eq!(valid.day(), 25);

        assert!(parse_date(None).unwrap().is_none());
        parse_date(Some("invalid-date")).unwrap_err();
    }

    #[test]
    fn parse_time_valid() {
        let valid = parse_time(Some("14:30:45.123456")).unwrap().unwrap();
        assert_eq!(valid.hour(), 14);
        assert_eq!(valid.minute(), 30);
        assert_eq!(valid.second(), 45);

        assert!(parse_time(None).unwrap().is_none());
        parse_time(Some("25:00:00")).unwrap_err();
    }

    #[test]
    fn parse_timestamp_tz_valid() {
        let valid = parse_timestamp_tz(Some("2024-01-20T15:30:45Z")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 1);
        assert_eq!(valid.day(), 20);

        // A +05:00 offset must be normalized back to UTC.
        let with_offset = parse_timestamp_tz(Some("2024-01-20T15:30:45+05:00")).unwrap().unwrap();
        assert_eq!(with_offset.hour(), 10);

        assert!(parse_timestamp_tz(None).unwrap().is_none());
        parse_timestamp_tz(Some("not a timestamp")).unwrap_err();
    }

    #[test]
    fn parse_timestamp_naive_valid() {
        let valid = parse_timestamp_naive(Some("2024-01-20 15:30:45.123")).unwrap().unwrap();
        assert_eq!(valid.year(), 2024);
        assert_eq!(valid.month(), 1);
        assert_eq!(valid.day(), 20);
        assert_eq!(valid.hour(), 15);

        // The fractional part is optional.
        let whole_seconds = parse_timestamp_naive(Some("2024-01-20 15:30:45")).unwrap().unwrap();
        assert_eq!(whole_seconds.second(), 45);

        assert!(parse_timestamp_naive(None).unwrap().is_none());
        parse_timestamp_naive(Some("2024/01/20 15:30:45")).unwrap_err();
    }

    #[test]
    fn into_param_valid_conversions() {
        into_param(&DataType::Int32(Some(42))).unwrap();
        into_param(&DataType::Int64(Some(i64::MAX))).unwrap();
        into_param(&DataType::Uint32(Some(u32::MAX))).unwrap();
        into_param(&DataType::Uint64(Some(100))).unwrap();
        into_param(&DataType::Float(Some(std::f32::consts::PI))).unwrap();
        into_param(&DataType::Double(Some(std::f64::consts::E))).unwrap();
        into_param(&DataType::Str(Some("test".to_string()))).unwrap();
        into_param(&DataType::Boolean(Some(true))).unwrap();
        into_param(&DataType::Date(Some("2024-12-25".to_string()))).unwrap();
        into_param(&DataType::Time(Some("14:30:45".to_string()))).unwrap();
        into_param(&DataType::Binary(Some(vec![0x01, 0x02]))).unwrap();

        // NULLs of every kind must convert without parsing anything.
        into_param(&DataType::Int32(None)).unwrap();
        into_param(&DataType::Uint64(None)).unwrap();
        into_param(&DataType::Str(None)).unwrap();
        into_param(&DataType::Date(None)).unwrap();
        into_param(&DataType::Time(None)).unwrap();
        into_param(&DataType::Timestamp(None)).unwrap();
        into_param(&DataType::Binary(None)).unwrap();
    }

    #[test]
    fn into_param_invalid_conversions() {
        let uint64_overflow = DataType::Uint64(Some(u64::MAX));
        let result = into_param(&uint64_overflow);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds i64::MAX"));

        into_param(&DataType::Date(Some("invalid".to_string()))).unwrap_err();
        into_param(&DataType::Time(Some("25:00:00".to_string()))).unwrap_err();
        into_param(&DataType::Timestamp(Some("not a timestamp".to_string()))).unwrap_err();
    }

    #[test]
    fn into_param_uint64_boundary() {
        let largest_storable = i64::MAX.unsigned_abs();
        into_param(&DataType::Uint64(Some(largest_storable))).unwrap();
        into_param(&DataType::Uint64(Some(largest_storable + 1))).unwrap_err();
    }

    #[test]
    fn into_param_timestamp_format_detection() {
        let with_tz = DataType::Timestamp(Some("2024-01-20T15:30:45Z".to_string()));
        into_param(&with_tz).unwrap();

        let with_offset = DataType::Timestamp(Some("2024-01-20T15:30:45+05:00".to_string()));
        into_param(&with_offset).unwrap();

        let naive = DataType::Timestamp(Some("2024-01-20 15:30:45".to_string()));
        into_param(&naive).unwrap();
    }
}