Skip to main content

lutra_runner_postgres/
lib.rs

1//! PostgreSQL Lutra runner
2
3#[cfg(not(any(feature = "postgres", feature = "tokio-postgres")))]
4compile_error!("At least one of 'postgres' or 'tokio-postgres' features has to be enabled.");
5
6mod case;
7mod params;
8mod result;
9
10#[cfg(feature = "tokio-postgres")]
11mod schema;
12
13pub use lutra_runner::Run;
14
15use std::collections::HashMap;
16use thiserror::Error;
17
18#[cfg(feature = "postgres")]
19use postgres::Error as PgError;
20#[cfg(not(feature = "postgres"))]
21use tokio_postgres::Error as PgError;
22
23use lutra_bin::{ir, rr};
24
25#[derive(Error, Debug)]
26pub enum Error {
27    #[error("bad result: {}", .0)]
28    BadDatabaseResponse(&'static str),
29    #[error("postgres: {:?}", .0)]
30    Postgres(#[from] PgError),
31}
32
/// Runs a PostgreSQL SQL `program` on a blocking `postgres` client.
///
/// The program's SQL is prepared on `client`, `input` is packed into query
/// arguments, the statement is executed, and the returned rows are converted
/// back into the program's binary output.
///
/// # Errors
/// Returns [Error::Postgres] if preparing or executing the statement fails,
/// and propagates any error from converting the returned rows.
#[cfg(feature = "postgres")]
pub fn execute(
    client: &mut impl postgres::GenericClient,
    program: &rr::SqlProgram,
    input: &[u8],
) -> Result<Vec<u8>, Error> {
    // Named-type lookup table shared by argument packing and row conversion.
    let ctx = Context::new(&program.defs);

    let statement = client.prepare(&program.sql)?;

    let query_args = params::to_sql(program, input, &ctx);
    let rows = client.query(&statement, &query_args.as_refs())?;

    result::from_sql(program, &rows, &ctx)
}
53
/// Async runner that executes Lutra programs through a `tokio_postgres` client.
///
/// `C` is the client used to talk to the database (by default a plain
/// `tokio_postgres::Client`). The runner's methods require it to implement
/// `tokio_postgres::GenericClient`; the bound lives on those impls rather than
/// on the struct itself.
#[cfg(feature = "tokio-postgres")]
pub struct RunnerAsync<C = tokio_postgres::Client> {
    client: C,
}
58
// Gated like the `RunnerAsync` struct itself; without this, builds with only
// the `postgres` feature fail to resolve `RunnerAsync` / `tokio_postgres`.
#[cfg(feature = "tokio-postgres")]
impl<C> RunnerAsync<C>
where
    C: tokio_postgres::GenericClient,
{
    /// Wraps an existing client (anything implementing
    /// `tokio_postgres::GenericClient`) in a runner.
    pub fn new(client: C) -> Self {
        RunnerAsync { client }
    }

    /// Consumes the runner and returns the wrapped client.
    pub fn into_inner(self) -> C {
        self.client
    }
}
71
// Gated like the `RunnerAsync` struct itself; this impl also depends on the
// `tokio` and `tokio_postgres` crates, which only the async feature provides.
#[cfg(feature = "tokio-postgres")]
impl RunnerAsync<tokio_postgres::Client> {
    /// Helper for [tokio_postgres::connect] and [RunnerAsync::new].
    ///
    /// Connects without TLS using the connection string `config`. The
    /// connection future returned by `tokio_postgres::connect` is driven on a
    /// detached tokio task; if it finishes with an error, that error is printed
    /// to stderr.
    ///
    /// # Errors
    /// Returns [Error::Postgres] if the connection cannot be established.
    ///
    /// # Panics
    /// The connection is spawned with `tokio::task::spawn`, which panics when
    /// called outside of a tokio runtime.
    pub async fn connect_no_tls(config: &str) -> Result<Self, Error> {
        let (client, conn) = tokio_postgres::connect(config, tokio_postgres::NoTls).await?;

        // The client only makes progress while `conn` is being polled, so it
        // must keep running in the background for the client's lifetime.
        tokio::task::spawn(async move {
            if let Err(e) = conn.await {
                eprintln!("{e}");
            }
        });

        Ok(Self::new(client))
    }
}
85
/// A program prepared by [RunnerAsync]: the PostgreSQL form of the program
/// together with the statement prepared from its SQL.
// Gated because it names `tokio_postgres::Statement`, which is unavailable
// when only the blocking `postgres` feature is enabled.
#[cfg(feature = "tokio-postgres")]
#[derive(Clone)]
pub struct PreparedProgram {
    program: rr::SqlProgram,
    stmt: tokio_postgres::Statement,
}
91
#[cfg(feature = "tokio-postgres")]
impl<C> lutra_runner::Run for RunnerAsync<C>
where
    C: tokio_postgres::GenericClient,
{
    type Error = Error;
    type Prepared = PreparedProgram;

    /// Converts `program` to its PostgreSQL SQL form and prepares the
    /// statement on the connected database.
    ///
    /// # Errors
    /// Returns [Error::Postgres] if the database rejects the statement.
    ///
    /// # Panics
    /// Panics if `program.into_sql_postgres()` does not yield a program
    /// (presumably because `program` was not compiled for PostgreSQL).
    async fn prepare(&self, program: rr::Program) -> Result<Self::Prepared, Self::Error> {
        let program = *program
            .into_sql_postgres()
            .expect("RunnerAsync can only prepare programs compiled for PostgreSQL");

        let stmt = self.client.prepare(&program.sql).await?;

        Ok(PreparedProgram { program, stmt })
    }

    /// Executes a prepared program with the binary-encoded `input` and returns
    /// the binary-encoded result.
    ///
    /// # Errors
    /// Returns [Error::Postgres] if the query fails, and propagates any error
    /// from converting the returned rows.
    async fn execute(
        &self,
        handle: &Self::Prepared,
        input: &[u8],
    ) -> Result<Vec<u8>, Self::Error> {
        let ctx = Context::new(&handle.program.defs);

        // pack input into query args
        let args = params::to_sql(&handle.program, input, &ctx);

        let rows = self.client.query(&handle.stmt, &args.as_refs()).await?;

        // convert result from sql
        result::from_sql(&handle.program, &rows, &ctx)
    }

    /// Pulls the interface of the connected database (delegates to the
    /// crate's `schema` module).
    async fn get_interface(&self) -> Result<String, Self::Error> {
        Ok(crate::schema::pull_interface(self).await?)
    }
}
128
129struct Context<'a> {
130    pub types: HashMap<&'a ir::Path, &'a ir::Ty>,
131}
132
133impl<'a> Context<'a> {
134    fn new(ty_defs: &'a [ir::TyDef]) -> Self {
135        Context {
136            types: ty_defs.iter().map(|def| (&def.name, &def.ty)).collect(),
137        }
138    }
139
140    fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
141        match &ty.kind {
142            ir::TyKind::Ident(path) => self.types.get(path).unwrap(),
143            _ => ty,
144        }
145    }
146
147    /// Checks if an enum is an "option" enum. Must match [lutra_compiler::sql::utils::is_option].
148    fn is_option(&self, variants: &[ir::TyEnumVariant]) -> bool {
149        if variants.len() != 2 || !variants[0].ty.is_unit() {
150            return false;
151        }
152        let some_ty = self.get_ty_mat(&variants[1].ty);
153        some_ty.kind.is_primitive() || some_ty.kind.is_array()
154    }
155}