Skip to main content

spin_sdk/
pg.rs

1//! Postgres relational database storage.
2//!
3//! You can use the [`into()`](std::convert::Into) method to convert
4//! a Rust value into a [`ParameterValue`]. You can use the
5//! [`Decode`] trait to convert a [`DbValue`] to a suitable Rust type.
6//! The following table shows available conversions.
7//!
8//! # Types
9//!
10//! | Rust type               | WIT (db-value)                                | Postgres type(s)             |
11//! |-------------------------|-----------------------------------------------|----------------------------- |
12//! | `bool`                  | boolean(bool)                                 | BOOL                         |
13//! | `i16`                   | int16(s16)                                    | SMALLINT, SMALLSERIAL, INT2  |
14//! | `i32`                   | int32(s32)                                    | INT, SERIAL, INT4            |
15//! | `i64`                   | int64(s64)                                    | BIGINT, BIGSERIAL, INT8      |
16//! | `f32`                   | floating32(float32)                           | REAL, FLOAT4                 |
17//! | `f64`                   | floating64(float64)                           | DOUBLE PRECISION, FLOAT8     |
18//! | `String`                | str(string)                                   | VARCHAR, CHAR(N), TEXT       |
19//! | `Vec<u8>`               | binary(list\<u8\>)                            | BYTEA                        |
20//! | `chrono::NaiveDate`     | date(tuple<s32, u8, u8>)                      | DATE                         |
21//! | `chrono::NaiveTime`     | time(tuple<u8, u8, u8, u32>)                  | TIME                         |
22//! | `chrono::NaiveDateTime` | datetime(tuple<s32, u8, u8, u8, u8, u8, u32>) | TIMESTAMP                    |
23//! | `chrono::Duration`      | timestamp(s64)                                | BIGINT                       |
24//! | `uuid::Uuid`            | uuid(string)                                  | UUID                         |
25//! | `serde_json::Value`     | jsonb(list\<u8\>)                             | JSONB                        |
26//! | `serde::De/Serialize`   | jsonb(list\<u8\>)                             | JSONB                        |
27//! | `rust_decimal::Decimal` | decimal(string)                               | NUMERIC                      |
28//! | `postgres_range`        | range-int32(...), range-int64(...)            | INT4RANGE, INT8RANGE         |
29//! | lower/upper tuple       | range-decimal(...)                            | NUMERICRANGE                 |
30//! | `Vec<Option<...>>`      | array-int32(...), array-int64(...), array-str(...), array-decimal(...) | INT4[], INT8[], TEXT[], NUMERIC[] |
31//! | `pg4::Interval`         | interval(interval)                            | INTERVAL                     |
32
33// pg4 errors can be large, because they now include a breakdown of the PostgreSQL
34// error fields instead of just a string
35#![allow(clippy::result_large_err)]
36
37use crate::wit_bindgen;
38use std::sync::Arc;
39
40#[doc(hidden)]
41/// Module containing wit bindgen generated code.
42///
43/// This is only meant for internal consumption.
44pub mod wit {
45    #![allow(missing_docs)]
46    use crate::wit_bindgen;
47
48    wit_bindgen::generate!({
49        runtime_path: "crate::wit_bindgen::rt",
50        world: "spin-sdk-pg",
51        path: "wit",
52        generate_all,
53    });
54
55    pub use spin::postgres::postgres;
56}
57
58#[doc(inline)]
59pub use wit::postgres::{
60    Column, DbDataType, DbError, DbValue, Error as PgError, ParameterValue, QueryError,
61    RangeBoundKind,
62};
63
64/// The PostgreSQL INTERVAL data type.
65pub use wit::postgres::Interval;
66
67use chrono::{Datelike, Timelike};
68
69/// An open connection to a PostgreSQL database.
70///
71/// # Examples
72///
73/// Load a set of rows from a local PostgreSQL database, and iterate over them.
74///
75/// ```no_run
76/// use spin_sdk::pg::{Connection, Decode};
77///
78/// # async fn run() -> anyhow::Result<()> {
79/// # let min_age = 0;
80/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb").await?;
81///
82/// let mut query_result = db.query(
83///     "SELECT * FROM users WHERE age >= $1",
84///     &[min_age.into()]
85/// ).await?;
86///
87/// while let Some(row) = query_result.next().await {
88///     let name = row.get::<String>("name").unwrap();
89///     println!("Found user {name}");
90/// }
91///
92/// query_result.result().await?;
93/// # Ok(())
94/// # }
95/// ```
96///
97/// Perform an aggregate (scalar) operation over a table. The result set
98/// contains a single column, with a single row.
99///
100/// ```no_run
101/// use spin_sdk::pg::{Connection, Decode};
102///
103/// # async fn run() -> anyhow::Result<()> {
104/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb").await?;
105///
106/// let query_result = db.query("SELECT COUNT(*) FROM users", &[]).await?;
107///
108/// assert_eq!(1, query_result.columns().len());
109/// assert_eq!("count", query_result.columns()[0].name);
110///
111/// let rows = query_result.collect().await?;
112///
113/// assert_eq!(1, rows.len());
114///
115/// let count = &rows[0][0];
116/// # Ok(())
117/// # }
118/// ```
119///
120/// Delete rows from a PostgreSQL table. This uses [Connection::execute()]
121/// instead of the `query` method.
122///
123/// ```no_run
124/// use spin_sdk::pg::Connection;
125///
126/// # async fn run() -> anyhow::Result<()> {
127/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb").await?;
128///
129/// let rows_affected = db.execute(
130///     "DELETE FROM users WHERE name = $1",
131///     &["Baldrick".to_owned().into()]
132/// ).await?;
133/// # Ok(())
134/// # }
135/// ```
136pub struct Connection(wit::postgres::Connection);
137
138/// Options for opening a [`Connection`].
139#[derive(Default)]
140pub struct OpenOptions {
141    /// A certificate for the root certificate authority to use in TLS
142    /// for the connection.
143    pub ca_root: Option<Certificate>,
144}
145
/// A TLS certificate in text form (a document beginning with
/// `-----BEGIN CERTIFICATE-----`), supplied either inline or by file path.
pub enum Certificate {
    /// Path of a file, mounted in the guest, that holds the certificate text.
    FilePath(String),
    /// The certificate text itself.
    Text(String),
}
153
154impl Certificate {
155    fn load(self) -> Result<String, Error> {
156        match self {
157            Certificate::FilePath(path) => std::fs::read_to_string(path)
158                .map_err(|e| Error::PgError(PgError::Other(e.to_string()))),
159            Certificate::Text(text) => Ok(text),
160        }
161    }
162}
163
164impl Connection {
165    /// Open a connection to a PostgreSQL database.
166    ///
167    /// The address may be in connection string form (`"host=... dbname=..."`)
168    /// or in URL form (`"postgres://<host>/<dbname>?..."`).
169    ///
170    /// This constructor does not support options such as custom CA roots.
171    /// See [`Connection::open_with_options`] for a more flexible constructor.
172    pub async fn open(address: impl Into<String>) -> Result<Self, Error> {
173        let inner = wit::postgres::Connection::open_async(address.into()).await?;
174        Ok(Self(inner))
175    }
176
177    /// Open a connection to a PostgreSQL database.
178    ///
179    /// The address may be in connection string form (`"host=... dbname=..."`)
180    /// or in URL form (`"postgres://<host>/<dbname>?..."`).
181    ///
182    /// The `options` parameter allows for passing options not available in the address string.
183    pub async fn open_with_options(
184        address: impl AsRef<str>,
185        options: OpenOptions,
186    ) -> Result<Self, Error> {
187        let builder = wit::postgres::ConnectionBuilder::new(address.as_ref());
188        let OpenOptions { ca_root } = options;
189
190        if let Some(ca_root) = ca_root {
191            let ca_root_text = ca_root.load()?;
192            builder.set_ca_root(&ca_root_text)?;
193        }
194
195        let inner = builder.build_async().await?;
196        Ok(Self(inner))
197    }
198
199    /// Query the database.
200    ///
201    /// Use this function for queries that return rows (typically `SELECT` queries).
202    /// For side-effectful queries, see [`Connection::execute`].
203    pub async fn query(
204        &self,
205        statement: impl Into<String>,
206        params: impl Into<Vec<ParameterValue>>,
207    ) -> Result<QueryResult, Error> {
208        let (columns, rows, result) = self.0.query_async(statement.into(), params.into()).await?;
209        // let result = result.into_future();
210        Ok(QueryResult {
211            columns: Arc::new(columns),
212            rows,
213            result,
214        })
215    }
216
217    /// Execute a command against the database.
218    ///
219    /// Use this function for side-effectful queries (such as `INSERT` or `DELETE` queries).
220    /// For queries that return row data, see [`Connection::query`].
221    pub async fn execute(
222        &self,
223        statement: impl Into<String>,
224        params: impl Into<Vec<ParameterValue>>,
225    ) -> Result<u64, Error> {
226        self.0
227            .execute_async(statement.into(), params.into())
228            .await
229            .map_err(Error::PgError)
230    }
231
232    /// Extracts the underlying Wasm Component Model resource for the connection.
233    pub fn into_inner(self) -> wit::postgres::Connection {
234        self.0
235    }
236}
237
238/// The result of a [`Connection::query`] operation.
239pub struct QueryResult {
240    columns: Arc<Vec<Column>>,
241    rows: wit_bindgen::StreamReader<Vec<DbValue>>,
242    result: wit_bindgen::FutureReader<Result<(), PgError>>,
243}
244
245impl QueryResult {
246    /// The columns in the query result.
247    pub fn columns(&self) -> &[Column] {
248        &self.columns
249    }
250
251    // TODO: should this return Result<Option<Row>> so users
252    // could write `q.next().await?` instead of checking `result`
253    // separately???
254
255    /// Gets the next row in the result set.
256    ///
257    /// If this is `None`, there are no more rows available. You _must_
258    /// await [`QueryResult::result()`] to determine if all rows
259    /// were read successfully.
260    pub async fn next(&mut self) -> Option<Row> {
261        self.rows.next().await.map(|r| Row {
262            columns: self.columns.clone(),
263            result: r,
264        })
265    }
266
267    /// Whether the query completed successfully or with an error.
268    pub async fn result(self) -> Result<(), Error> {
269        self.result.await.map_err(Error::PgError)
270    }
271
272    /// Collect all rows in the result set.
273    ///
274    /// This is provided for when the result set is small enough to fit in
275    /// memory and you do not require streaming behaviour.
276    pub async fn collect(mut self) -> Result<Vec<Row>, Error> {
277        let mut rows = vec![];
278        while let Some(row) = self.next().await {
279            rows.push(row);
280        }
281        self.result.await.map_err(Error::PgError)?;
282        Ok(rows)
283    }
284
285    /// An asynchronous reader for the rows of the query result. Call
286    /// `.next().await` to iterate over the rows. When this returns `None`,
287    /// you have read all available rows. At this point you _must_ check
288    /// [`QueryResult::result()`] to determine if the read completed
289    /// successfully.
290    ///
291    /// This provides each row as a plain vector of database values.
292    /// [`QueryResult::next()`] provides a more ergonomic wrapper.
293    ///
294    /// To collect all rows into a vector, see [`QueryResult::collect`].
295    pub fn rows(&mut self) -> &mut wit_bindgen::StreamReader<Vec<DbValue>> {
296        &mut self.rows
297    }
298
299    /// Extracts the underlying Wasm Component Model results of the query.
300    #[allow(
301        clippy::type_complexity,
302        reason = "sorry clippy that's just what the inner bits are"
303    )]
304    pub fn into_inner(
305        self,
306    ) -> (
307        Vec<Column>,
308        wit_bindgen::StreamReader<Vec<DbValue>>,
309        wit_bindgen::FutureReader<Result<(), PgError>>,
310    ) {
311        ((*self.columns).clone(), self.rows, self.result)
312    }
313}
314
315/// A database row result.
316///
317/// There are two representations of a SQLite row in the SDK.  This type is useful for
318/// addressing elements by column name, and is obtained from the [QueryResult::next()] function.
319/// The [DbValue] vector representation is obtained from the [QueryResult::rows()] function, and provides
320/// index-based lookup or low-level access to row values via a vector.
321pub struct Row {
322    columns: Arc<Vec<wit::postgres::Column>>,
323    result: Vec<DbValue>,
324}
325
326impl Row {
327    /// Get a value by its column name. The value is converted to the target type as per the
328    /// conversion table shown in the module documentation.
329    ///
330    /// This function returns None for both no such column _and_ failed conversion. You should use
331    /// it only if you do not need to address errors (that is, if you know that conversion should
332    /// never fail). If your code does not know the type in advance, use the raw [QueryResult::rows()] function
333    /// instead of the [`QueryResult::next()`] or [`QueryResult::collect()`] wrappers to access
334    /// the underlying [DbValue] enum: this will allow you to
335    /// determine the type and process it accordingly.
336    ///
337    /// Additionally, this function performs a name lookup each time it is called. If you are iterating
338    /// over a large number of rows, it's more efficient to use column indexes, either calculated or
339    /// statically known from the column order in the SQL.
340    ///
341    /// # Examples
342    ///
343    /// ```no_run
344    /// use spin_sdk::pg::{Connection, DbValue};
345    ///
346    /// # async fn run() -> anyhow::Result<()> {
347    /// # let user_id = 0;
348    /// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb").await?;
349    /// let mut query_result = db.query(
350    ///     "SELECT * FROM users WHERE id = $1",
351    ///     &[user_id.into()]
352    /// ).await?;
353    /// let user_row = query_result.next().await.unwrap();
354    ///
355    /// let name = user_row.get::<String>("name").unwrap();
356    /// let age = user_row.get::<i16>("age").unwrap();
357    /// # Ok(())
358    /// # }
359    /// ```
360    pub fn get<T: Decode>(&self, column: &str) -> Option<T> {
361        let i = self.columns.iter().position(|c| c.name == column)?;
362        let db_value = self.result.get(i)?;
363        Decode::decode(db_value).ok()
364    }
365}
366
367impl std::ops::Index<usize> for Row {
368    type Output = DbValue;
369
370    fn index(&self, index: usize) -> &Self::Output {
371        &self.result[index]
372    }
373}
374
375/// A Postgres error
376#[derive(Debug, thiserror::Error)]
377pub enum Error {
378    /// Failed to deserialize [`DbValue`]
379    #[error("error value decoding: {0}")]
380    Decode(String),
381    /// Postgres query failed with an error
382    #[error(transparent)]
383    PgError(#[from] PgError),
384}
385
386/// A type that can be decoded from the database.
387pub trait Decode: Sized {
388    /// Decode a new value of this type using a [`DbValue`].
389    fn decode(value: &DbValue) -> Result<Self, Error>;
390}
391
392impl<T> Decode for Option<T>
393where
394    T: Decode,
395{
396    fn decode(value: &DbValue) -> Result<Self, Error> {
397        match value {
398            DbValue::DbNull => Ok(None),
399            v => Ok(Some(T::decode(v)?)),
400        }
401    }
402}
403
404impl Decode for bool {
405    fn decode(value: &DbValue) -> Result<Self, Error> {
406        match value {
407            DbValue::Boolean(boolean) => Ok(*boolean),
408            _ => Err(Error::Decode(format_decode_err("BOOL", value))),
409        }
410    }
411}
412
413impl Decode for i16 {
414    fn decode(value: &DbValue) -> Result<Self, Error> {
415        match value {
416            DbValue::Int16(n) => Ok(*n),
417            _ => Err(Error::Decode(format_decode_err("SMALLINT", value))),
418        }
419    }
420}
421
422impl Decode for i32 {
423    fn decode(value: &DbValue) -> Result<Self, Error> {
424        match value {
425            DbValue::Int32(n) => Ok(*n),
426            _ => Err(Error::Decode(format_decode_err("INT", value))),
427        }
428    }
429}
430
431impl Decode for i64 {
432    fn decode(value: &DbValue) -> Result<Self, Error> {
433        match value {
434            DbValue::Int64(n) => Ok(*n),
435            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
436        }
437    }
438}
439
440impl Decode for f32 {
441    fn decode(value: &DbValue) -> Result<Self, Error> {
442        match value {
443            DbValue::Floating32(n) => Ok(*n),
444            _ => Err(Error::Decode(format_decode_err("REAL", value))),
445        }
446    }
447}
448
449impl Decode for f64 {
450    fn decode(value: &DbValue) -> Result<Self, Error> {
451        match value {
452            DbValue::Floating64(n) => Ok(*n),
453            _ => Err(Error::Decode(format_decode_err("DOUBLE PRECISION", value))),
454        }
455    }
456}
457
458impl Decode for Vec<u8> {
459    fn decode(value: &DbValue) -> Result<Self, Error> {
460        match value {
461            DbValue::Binary(n) => Ok(n.to_owned()),
462            _ => Err(Error::Decode(format_decode_err("BYTEA", value))),
463        }
464    }
465}
466
467impl Decode for String {
468    fn decode(value: &DbValue) -> Result<Self, Error> {
469        match value {
470            DbValue::Str(s) => Ok(s.to_owned()),
471            _ => Err(Error::Decode(format_decode_err(
472                "CHAR, VARCHAR, TEXT",
473                value,
474            ))),
475        }
476    }
477}
478
479impl Decode for chrono::NaiveDate {
480    fn decode(value: &DbValue) -> Result<Self, Error> {
481        match value {
482            DbValue::Date((year, month, day)) => {
483                let naive_date =
484                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
485                        .ok_or_else(|| {
486                            Error::Decode(format!(
487                                "invalid date y={}, m={}, d={}",
488                                year, month, day
489                            ))
490                        })?;
491                Ok(naive_date)
492            }
493            _ => Err(Error::Decode(format_decode_err("DATE", value))),
494        }
495    }
496}
497
498impl Decode for chrono::NaiveTime {
499    fn decode(value: &DbValue) -> Result<Self, Error> {
500        match value {
501            DbValue::Time((hour, minute, second, nanosecond)) => {
502                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
503                    (*hour).into(),
504                    (*minute).into(),
505                    (*second).into(),
506                    *nanosecond,
507                )
508                .ok_or_else(|| {
509                    Error::Decode(format!(
510                        "invalid time {}:{}:{}:{}",
511                        hour, minute, second, nanosecond
512                    ))
513                })?;
514                Ok(naive_time)
515            }
516            _ => Err(Error::Decode(format_decode_err("TIME", value))),
517        }
518    }
519}
520
521impl Decode for chrono::NaiveDateTime {
522    fn decode(value: &DbValue) -> Result<Self, Error> {
523        match value {
524            DbValue::Datetime((year, month, day, hour, minute, second, nanosecond)) => {
525                let naive_date =
526                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
527                        .ok_or_else(|| {
528                            Error::Decode(format!(
529                                "invalid date y={}, m={}, d={}",
530                                year, month, day
531                            ))
532                        })?;
533                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
534                    (*hour).into(),
535                    (*minute).into(),
536                    (*second).into(),
537                    *nanosecond,
538                )
539                .ok_or_else(|| {
540                    Error::Decode(format!(
541                        "invalid time {}:{}:{}:{}",
542                        hour, minute, second, nanosecond
543                    ))
544                })?;
545                let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
546                Ok(dt)
547            }
548            _ => Err(Error::Decode(format_decode_err("DATETIME", value))),
549        }
550    }
551}
552
553impl Decode for chrono::Duration {
554    fn decode(value: &DbValue) -> Result<Self, Error> {
555        match value {
556            DbValue::Timestamp(n) => Ok(chrono::Duration::seconds(*n)),
557            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
558        }
559    }
560}
561
/// Decodes a UUID value by parsing its string form.
///
/// Fails with [`Error::Decode`] if the value is not a UUID or does not parse.
#[cfg(feature = "postgres4-types")]
impl Decode for uuid::Uuid {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Uuid(text) = value else {
            return Err(Error::Decode(format_decode_err("UUID", value)));
        };
        uuid::Uuid::parse_str(text).map_err(|e| Error::Decode(e.to_string()))
    }
}
571
/// Decodes a JSONB value into an untyped JSON tree; delegates to [`from_jsonb`].
#[cfg(feature = "json")]
impl Decode for serde_json::Value {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        from_jsonb::<Self>(value)
    }
}
578
/// Deserializes a Postgres JSONB value into any `Deserialize`-able type.
///
/// # Errors
///
/// Returns [`Error::Decode`] if `value` is not JSONB or if its bytes do not
/// deserialize into `T`.
#[cfg(feature = "json")]
pub fn from_jsonb<'a, T: serde::Deserialize<'a>>(value: &'a DbValue) -> Result<T, Error> {
    let DbValue::Jsonb(bytes) = value else {
        return Err(Error::Decode(format_decode_err("JSONB", value)));
    };
    serde_json::from_slice(bytes).map_err(|e| Error::Decode(e.to_string()))
}
587
/// Decodes a decimal value by parsing its string form exactly.
///
/// Fails with [`Error::Decode`] if the value is not a decimal or does not parse.
#[cfg(feature = "postgres4-types")]
impl Decode for rust_decimal::Decimal {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Decimal(text) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC", value)));
        };
        rust_decimal::Decimal::from_str_exact(text).map_err(|e| Error::Decode(e.to_string()))
    }
}
599
/// Translates the WIT range-bound kind into the `postgres_range` equivalent.
#[cfg(feature = "postgres4-types")]
fn bound_type_from_wit(kind: RangeBoundKind) -> postgres_range::BoundType {
    use postgres_range::BoundType;

    match kind {
        RangeBoundKind::Inclusive => BoundType::Inclusive,
        RangeBoundKind::Exclusive => BoundType::Exclusive,
    }
}
607
/// Decodes a 32-bit integer range; a missing bound stays `None` (unbounded).
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i32> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        use postgres_range::RangeBound;

        let DbValue::RangeInt32((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT4RANGE", value)));
        };
        let lower = lbound.map(|(limit, kind)| RangeBound::new(limit, bound_type_from_wit(kind)));
        let upper = ubound.map(|(limit, kind)| RangeBound::new(limit, bound_type_from_wit(kind)));
        Ok(postgres_range::Range::new(lower, upper))
    }
}
625
/// Decodes a 64-bit integer range; a missing bound stays `None` (unbounded).
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i64> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        use postgres_range::RangeBound;

        let DbValue::RangeInt64((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT8RANGE", value)));
        };
        let lower = lbound.map(|(limit, kind)| RangeBound::new(limit, bound_type_from_wit(kind)));
        let upper = ubound.map(|(limit, kind)| RangeBound::new(limit, bound_type_from_wit(kind)));
        Ok(postgres_range::Range::new(lower, upper))
    }
}
643
// `postgres_range::Range` is not usable for decimals because
// `rust_decimal::Decimal` is not Normalizable, so a decimal range is
// decoded to a plain (lower, upper) tuple of optional bounds instead.
#[cfg(feature = "postgres4-types")]
impl Decode
    for (
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
    )
{
    fn decode(value: &DbValue) -> Result<Self, Error> {
        // Parses one bound's textual number, keeping its inclusivity.
        fn parse(
            value: &str,
            kind: RangeBoundKind,
        ) -> Result<(rust_decimal::Decimal, RangeBoundKind), Error> {
            rust_decimal::Decimal::from_str_exact(value)
                .map(|dec| (dec, kind))
                .map_err(|e| Error::Decode(e.to_string()))
        }

        let DbValue::RangeDecimal((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("NUMERICRANGE", value)));
        };

        // The lower bound is parsed first, so its error wins if both are bad.
        let lower = match lbound {
            Some((text, kind)) => Some(parse(text, *kind)?),
            None => None,
        };
        let upper = match ubound {
            Some((text, kind)) => Some(parse(text, *kind)?),
            None => None,
        };
        Ok((lower, upper))
    }
}
679
680// TODO: can we return a slice here? It seems like it should be possible but
681// I wasn't able to get the lifetimes to work with the trait
682impl Decode for Vec<Option<i32>> {
683    fn decode(value: &DbValue) -> Result<Self, Error> {
684        match value {
685            DbValue::ArrayInt32(a) => Ok(a.to_vec()),
686            _ => Err(Error::Decode(format_decode_err("INT4[]", value))),
687        }
688    }
689}
690
691impl Decode for Vec<Option<i64>> {
692    fn decode(value: &DbValue) -> Result<Self, Error> {
693        match value {
694            DbValue::ArrayInt64(a) => Ok(a.to_vec()),
695            _ => Err(Error::Decode(format_decode_err("INT8[]", value))),
696        }
697    }
698}
699
700impl Decode for Vec<Option<String>> {
701    fn decode(value: &DbValue) -> Result<Self, Error> {
702        match value {
703            DbValue::ArrayStr(a) => Ok(a.to_vec()),
704            _ => Err(Error::Decode(format_decode_err("TEXT[]", value))),
705        }
706    }
707}
708
/// Parses one optional array element: `None` stays `None`, `Some(text)` is
/// parsed exactly as a decimal, and a parse failure becomes [`Error::Decode`].
#[cfg(feature = "postgres4-types")]
fn map_decimal(s: &Option<String>) -> Result<Option<rust_decimal::Decimal>, Error> {
    match s {
        None => Ok(None),
        Some(text) => rust_decimal::Decimal::from_str_exact(text)
            .map(Some)
            .map_err(|e| Error::Decode(e.to_string())),
    }
}
716
/// Decodes an array of nullable decimals, stopping at the first element
/// that fails to parse.
#[cfg(feature = "postgres4-types")]
impl Decode for Vec<Option<rust_decimal::Decimal>> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::ArrayDecimal(items) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC[]", value)));
        };
        items.iter().map(map_decimal).collect()
    }
}
729
730impl Decode for Interval {
731    fn decode(value: &DbValue) -> Result<Self, Error> {
732        match value {
733            DbValue::Interval(i) => Ok(*i),
734            _ => Err(Error::Decode(format_decode_err("INTERVAL", value))),
735        }
736    }
737}
738
// Generates `From<$source> for ParameterValue` for each `type => Variant`
// pair, wrapping the value unchanged in the named variant.
macro_rules! impl_parameter_value_conversions {
    ($($source:ty => $variant:ident),*) => {
        $(
            impl From<$source> for ParameterValue {
                fn from(value: $source) -> Self {
                    Self::$variant(value)
                }
            }
        )*
    };
}
750
751impl_parameter_value_conversions! {
752    i8 => Int8,
753    i16 => Int16,
754    i32 => Int32,
755    i64 => Int64,
756    f32 => Floating32,
757    f64 => Floating64,
758    bool => Boolean,
759    String => Str,
760    Vec<u8> => Binary,
761    Vec<Option<i32>> => ArrayInt32,
762    Vec<Option<i64>> => ArrayInt64,
763    Vec<Option<String>> => ArrayStr
764}
765
766impl From<chrono::NaiveDateTime> for ParameterValue {
767    fn from(v: chrono::NaiveDateTime) -> ParameterValue {
768        ParameterValue::Datetime((
769            v.year(),
770            v.month() as u8,
771            v.day() as u8,
772            v.hour() as u8,
773            v.minute() as u8,
774            v.second() as u8,
775            v.nanosecond(),
776        ))
777    }
778}
779
780impl From<chrono::NaiveTime> for ParameterValue {
781    fn from(v: chrono::NaiveTime) -> ParameterValue {
782        ParameterValue::Time((
783            v.hour() as u8,
784            v.minute() as u8,
785            v.second() as u8,
786            v.nanosecond(),
787        ))
788    }
789}
790
791impl From<chrono::NaiveDate> for ParameterValue {
792    fn from(v: chrono::NaiveDate) -> ParameterValue {
793        ParameterValue::Date((v.year(), v.month() as u8, v.day() as u8))
794    }
795}
796
797impl From<chrono::TimeDelta> for ParameterValue {
798    fn from(v: chrono::TimeDelta) -> ParameterValue {
799        ParameterValue::Timestamp(v.num_seconds())
800    }
801}
802
/// Converts a UUID to a parameter carrying its string form.
#[cfg(feature = "postgres4-types")]
impl From<uuid::Uuid> for ParameterValue {
    fn from(v: uuid::Uuid) -> Self {
        let text = v.to_string();
        Self::Uuid(text)
    }
}
809
/// Converts a JSON tree to a JSONB parameter; delegates to [`jsonb`] and
/// fails only if serialization fails.
#[cfg(feature = "json")]
impl TryFrom<serde_json::Value> for ParameterValue {
    type Error = serde_json::Error;

    fn try_from(v: serde_json::Value) -> Result<Self, Self::Error> {
        let tree = &v;
        jsonb(tree)
    }
}
818
/// Serializes any `Serialize` value into a Postgres JSONB SQL parameter.
///
/// # Errors
///
/// Returns the `serde_json` error if `value` cannot be serialized.
#[cfg(feature = "json")]
pub fn jsonb<T: serde::Serialize>(value: &T) -> Result<ParameterValue, serde_json::Error> {
    serde_json::to_vec(value).map(ParameterValue::Jsonb)
}
825
/// Converts a decimal to a parameter carrying its string form.
#[cfg(feature = "postgres4-types")]
impl From<rust_decimal::Decimal> for ParameterValue {
    fn from(v: rust_decimal::Decimal) -> Self {
        let text = v.to_string();
        Self::Decimal(text)
    }
}
832
833// We cannot impl From<T: RangeBounds<...>> because Rust fears that some future
834// knave or rogue might one day add RangeBounds to NaiveDateTime. The best we can
835// do is therefore a helper function we can call from range Froms.
836#[allow(
837    clippy::type_complexity,
838    reason = "I sure hope 'blame Alex' works here too"
839)]
840fn range_bounds_to_wit<T, U>(
841    range: impl std::ops::RangeBounds<T>,
842    f: impl Fn(&T) -> U,
843) -> (Option<(U, RangeBoundKind)>, Option<(U, RangeBoundKind)>) {
844    (
845        range_bound_to_wit(range.start_bound(), &f),
846        range_bound_to_wit(range.end_bound(), &f),
847    )
848}
849
850fn range_bound_to_wit<T, U>(
851    bound: std::ops::Bound<&T>,
852    f: &dyn Fn(&T) -> U,
853) -> Option<(U, RangeBoundKind)> {
854    match bound {
855        std::ops::Bound::Included(v) => Some((f(v), RangeBoundKind::Inclusive)),
856        std::ops::Bound::Excluded(v) => Some((f(v), RangeBoundKind::Exclusive)),
857        std::ops::Bound::Unbounded => None,
858    }
859}
860
/// Converts a `postgres_range` bound to the WIT `(value, kind)` pair.
#[cfg(feature = "postgres4-types")]
fn pg_range_bound_to_wit<S: postgres_range::BoundSided, T: Copy>(
    bound: &postgres_range::RangeBound<S, T>,
) -> (T, RangeBoundKind) {
    use postgres_range::BoundType;

    match &bound.type_ {
        BoundType::Inclusive => (bound.value, RangeBoundKind::Inclusive),
        BoundType::Exclusive => (bound.value, RangeBoundKind::Exclusive),
    }
}
871
872impl From<std::ops::Range<i32>> for ParameterValue {
873    fn from(v: std::ops::Range<i32>) -> ParameterValue {
874        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
875    }
876}
877
878impl From<std::ops::RangeInclusive<i32>> for ParameterValue {
879    fn from(v: std::ops::RangeInclusive<i32>) -> ParameterValue {
880        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
881    }
882}
883
884impl From<std::ops::RangeFrom<i32>> for ParameterValue {
885    fn from(v: std::ops::RangeFrom<i32>) -> ParameterValue {
886        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
887    }
888}
889
890impl From<std::ops::RangeTo<i32>> for ParameterValue {
891    fn from(v: std::ops::RangeTo<i32>) -> ParameterValue {
892        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
893    }
894}
895
896impl From<std::ops::RangeToInclusive<i32>> for ParameterValue {
897    fn from(v: std::ops::RangeToInclusive<i32>) -> ParameterValue {
898        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
899    }
900}
901
/// Converts a `postgres_range` range of `i32` into a `RangeInt32`
/// parameter. A side for which the range reports no bound becomes `None`.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i32>> for ParameterValue {
    fn from(v: postgres_range::Range<i32>) -> Self {
        Self::RangeInt32((
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        ))
    }
}
910
911impl From<std::ops::Range<i64>> for ParameterValue {
912    fn from(v: std::ops::Range<i64>) -> ParameterValue {
913        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
914    }
915}
916
917impl From<std::ops::RangeInclusive<i64>> for ParameterValue {
918    fn from(v: std::ops::RangeInclusive<i64>) -> ParameterValue {
919        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
920    }
921}
922
923impl From<std::ops::RangeFrom<i64>> for ParameterValue {
924    fn from(v: std::ops::RangeFrom<i64>) -> ParameterValue {
925        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
926    }
927}
928
929impl From<std::ops::RangeTo<i64>> for ParameterValue {
930    fn from(v: std::ops::RangeTo<i64>) -> ParameterValue {
931        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
932    }
933}
934
935impl From<std::ops::RangeToInclusive<i64>> for ParameterValue {
936    fn from(v: std::ops::RangeToInclusive<i64>) -> ParameterValue {
937        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
938    }
939}
940
/// Converts a `postgres_range` range of `i64` into a `RangeInt64`
/// parameter. A side for which the range reports no bound becomes `None`.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i64>> for ParameterValue {
    fn from(v: postgres_range::Range<i64>) -> Self {
        Self::RangeInt64((
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        ))
    }
}
949
/// Converts a half-open `start..end` range of decimals into a
/// `RangeDecimal` parameter. Each bound value is carried in its string form
/// (as produced by `to_string`).
#[cfg(feature = "postgres4-types")]
impl From<std::ops::Range<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: std::ops::Range<rust_decimal::Decimal>) -> Self {
        let bounds = range_bounds_to_wit(v, |bound| bound.to_string());
        Self::RangeDecimal(bounds)
    }
}
956
957impl From<Vec<i32>> for ParameterValue {
958    fn from(v: Vec<i32>) -> ParameterValue {
959        ParameterValue::ArrayInt32(v.into_iter().map(Some).collect())
960    }
961}
962
963impl From<Vec<i64>> for ParameterValue {
964    fn from(v: Vec<i64>) -> ParameterValue {
965        ParameterValue::ArrayInt64(v.into_iter().map(Some).collect())
966    }
967}
968
969impl From<Vec<String>> for ParameterValue {
970    fn from(v: Vec<String>) -> ParameterValue {
971        ParameterValue::ArrayStr(v.into_iter().map(Some).collect())
972    }
973}
974
/// Converts a vector of optional decimals into an `ArrayDecimal` parameter.
///
/// `Some` entries are carried in their string form (as produced by
/// `to_string`); `None` entries stay `None`.
#[cfg(feature = "postgres4-types")]
impl From<Vec<Option<rust_decimal::Decimal>>> for ParameterValue {
    fn from(v: Vec<Option<rust_decimal::Decimal>>) -> Self {
        Self::ArrayDecimal(
            v.into_iter()
                .map(|entry| entry.map(|dec| dec.to_string()))
                .collect(),
        )
    }
}
985
/// Converts a vector of decimals into an `ArrayDecimal` parameter. Each
/// element is carried in its string form and wrapped in `Some`, so the
/// resulting array contains no `None` entries.
#[cfg(feature = "postgres4-types")]
impl From<Vec<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: Vec<rust_decimal::Decimal>) -> Self {
        Self::ArrayDecimal(v.into_iter().map(|dec| Some(dec.to_string())).collect())
    }
}
993
994impl From<Interval> for ParameterValue {
995    fn from(v: Interval) -> ParameterValue {
996        ParameterValue::Interval(v)
997    }
998}
999
1000impl<T: Into<ParameterValue>> From<Option<T>> for ParameterValue {
1001    fn from(o: Option<T>) -> ParameterValue {
1002        match o {
1003            Some(v) => v.into(),
1004            None => ParameterValue::DbNull,
1005        }
1006    }
1007}
1008
/// Builds the message used when a database value cannot be decoded into the
/// requested Rust type.
///
/// `types` is the human-readable description of what was expected (rendered
/// with `Display`); `value` is the database value actually received
/// (rendered with `Debug`).
fn format_decode_err(types: &str, value: &DbValue) -> String {
    format!("Expected {} from the DB but got {:?}", types, value)
}
1012
/// Unit tests for decoding `DbValue`s into Rust types.
///
/// Most tests follow the same shape: a successful decode from the matching
/// variant, a failed decode from a mismatched variant (where checked), and
/// `DbNull` decoding to `None` through `Option<T>`.
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;

    use super::*;

    /// `bool` decodes from `Boolean` and rejects `Int32`.
    #[test]
    fn boolean() {
        assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
        assert!(bool::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `i16` decodes from `Int16` and rejects `Int32`.
    #[test]
    fn int16() {
        assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
        assert!(i16::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `i32` decodes from `Int32` and rejects `Boolean`.
    #[test]
    fn int32() {
        assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
        assert!(i32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `i64` decodes from `Int64` and rejects `Boolean`.
    #[test]
    fn int64() {
        assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
        assert!(i64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `f32` decodes from `Floating32` and rejects `Boolean`.
    #[test]
    fn floating32() {
        assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
        assert!(f32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `f64` decodes from `Floating64` and rejects `Boolean`.
    #[test]
    fn floating64() {
        assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
        assert!(f64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    /// `String` decodes from `Str` and rejects `Int32`.
    #[test]
    fn str() {
        assert_eq!(
            String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
            String::from("foo")
        );

        assert!(String::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<String>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `Vec<u8>` decodes from `Binary` and rejects `Boolean`.
    #[test]
    fn binary() {
        assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
        assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `NaiveDate` decodes from a `(year, month, day)` tuple.
    #[test]
    fn date() {
        assert_eq!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveDate>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `NaiveTime` decodes from an `(hour, minute, second, nanosecond)` tuple.
    #[test]
    fn time() {
        assert_eq!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 3, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 4, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `NaiveDateTime` decodes from the seven-field date/time tuple.
    #[test]
    fn datetime() {
        let db_value = DbValue::Datetime((1, 2, 3, 4, 5, 6, 7));
        let date = chrono::NaiveDate::from_ymd_opt(1, 2, 3).unwrap();

        let matching_time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 7).unwrap();
        assert_eq!(
            NaiveDateTime::decode(&db_value).unwrap(),
            NaiveDateTime::new(date, matching_time)
        );

        // Differs from the encoded value only in the nanosecond field.
        let other_time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 8).unwrap();
        assert_ne!(
            NaiveDateTime::decode(&db_value).unwrap(),
            NaiveDateTime::new(date, other_time)
        );
        assert!(Option::<NaiveDateTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `Duration` decodes from `Timestamp`, interpreted as whole seconds.
    #[test]
    fn timestamp() {
        assert_eq!(
            chrono::Duration::decode(&DbValue::Timestamp(1)).unwrap(),
            chrono::Duration::seconds(1),
        );
        assert_ne!(
            chrono::Duration::decode(&DbValue::Timestamp(2)).unwrap(),
            chrono::Duration::seconds(1)
        );
        assert!(Option::<chrono::Duration>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    /// `Uuid` decodes from its hyphenated string form.
    #[test]
    #[cfg(feature = "postgres4-types")]
    fn uuid() {
        let uuid_str = "12341234-1234-1234-1234-123412341234";
        assert_eq!(
            uuid::Uuid::try_parse(uuid_str).unwrap(),
            uuid::Uuid::decode(&DbValue::Uuid(uuid_str.to_owned())).unwrap(),
        );
        assert!(Option::<uuid::Uuid>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    // Only the `jsonb` test uses this type, so it carries the same feature
    // gate; without it the struct would be compiled (and its serde derive
    // resolved) even when the `json` feature is disabled.
    #[cfg(feature = "json")]
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct JsonTest {
        hello: String,
    }

    /// JSONB decodes both to a dynamic `serde_json::Value` and, via
    /// `from_jsonb`, to a typed struct.
    #[test]
    #[cfg(feature = "json")]
    fn jsonb() {
        let json_val = serde_json::json!({
            "hello": "world"
        });
        let dbval = DbValue::Jsonb(r#"{"hello":"world"}"#.into());

        assert_eq!(json_val, serde_json::Value::decode(&dbval).unwrap());

        let json_struct = JsonTest {
            hello: "world".to_owned(),
        };
        assert_eq!(json_struct, from_jsonb(&dbval).unwrap());
    }

    /// Range values decode with their bound values and bound kinds intact.
    #[test]
    #[cfg(feature = "postgres4-types")]
    fn ranges() {
        // Local alias keeps the decimal-range destructuring below readable.
        type DecimalBound = Option<(rust_decimal::Decimal, RangeBoundKind)>;

        let i32_range = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            Some((89, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(45, i32_range.lower().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Inclusive,
            i32_range.lower().unwrap().type_
        );
        assert_eq!(89, i32_range.upper().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Exclusive,
            i32_range.upper().unwrap().type_
        );

        // A missing upper bound decodes to a range with no upper bound.
        let i32_range_from = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            None,
        )))
        .unwrap();
        assert!(i32_range_from.upper().is_none());

        let i64_range = postgres_range::Range::<i64>::decode(&DbValue::RangeInt64((
            Some((4567456745674567, RangeBoundKind::Inclusive)),
            Some((890189018901890189, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(4567456745674567, i64_range.lower().unwrap().value);
        assert_eq!(890189018901890189, i64_range.upper().unwrap().value);

        // Decimal ranges decode to a tuple of optional bounds.
        let (dec_lbound, dec_ubound): (DecimalBound, DecimalBound) =
            Decode::decode(&DbValue::RangeDecimal((
                Some(("4567.8901".to_owned(), RangeBoundKind::Inclusive)),
                Some(("8901.2345678901".to_owned(), RangeBoundKind::Exclusive)),
            )))
            .unwrap();
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(45678901, 4),
            dec_lbound.unwrap().0
        );
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(89012345678901, 10),
            dec_ubound.unwrap().0
        );
    }

    /// Array values decode element-wise, preserving `None` entries.
    #[test]
    #[cfg(feature = "postgres4-types")]
    fn arrays() {
        let v32 = vec![Some(123), None, Some(456)];
        let i32_arr = Vec::<Option<i32>>::decode(&DbValue::ArrayInt32(v32.clone())).unwrap();
        assert_eq!(v32, i32_arr);

        let v64 = vec![Some(123), None, Some(456)];
        let i64_arr = Vec::<Option<i64>>::decode(&DbValue::ArrayInt64(v64.clone())).unwrap();
        assert_eq!(v64, i64_arr);

        let vdec = vec![Some("1.23".to_owned()), None];
        let dec_arr =
            Vec::<Option<rust_decimal::Decimal>>::decode(&DbValue::ArrayDecimal(vdec)).unwrap();
        assert_eq!(
            vec![
                Some(rust_decimal::Decimal::from_i128_with_scale(123, 2)),
                None
            ],
            dec_arr
        );

        let vstr = vec![Some("alice".to_owned()), None, Some("bob".to_owned())];
        let str_arr = Vec::<Option<String>>::decode(&DbValue::ArrayStr(vstr.clone())).unwrap();
        assert_eq!(vstr, str_arr);
    }
}