lutra_runner_postgres/
// Build-time guard: when neither the `postgres` nor the `tokio-postgres`
// feature is enabled, stop compilation with an explicit message.
#[cfg(not(any(feature = "postgres", feature = "tokio-postgres")))]
compile_error!("At least one of 'postgres' or 'tokio-postgres' features has to be enabled.");
5
6mod case;
7mod params;
8mod result;
9
10#[cfg(feature = "tokio-postgres")]
11mod schema;
12
13pub use lutra_runner::Run;
14
15use std::collections::HashMap;
16use thiserror::Error;
17
18#[cfg(feature = "postgres")]
19use postgres::Error as PgError;
20#[cfg(not(feature = "postgres"))]
21use tokio_postgres::Error as PgError;
22
23use lutra_bin::{ir, rr};
24
/// Error type returned by the runner entry points in this crate.
#[derive(Error, Debug)]
pub enum Error {
    /// The database answered, but the result was not usable. The payload is a
    /// static description that is appended to the `bad result: ` message.
    // NOTE(review): the sites that raise this variant are not in this file;
    // presumably `result::from_sql` - confirm.
    #[error("bad result: {}", .0)]
    BadDatabaseResponse(&'static str),
    /// An error reported by the PostgreSQL client (`postgres::Error` when the
    /// `postgres` feature is on, `tokio_postgres::Error` otherwise).
    ///
    /// `#[from]` generates the `From<PgError>` conversion, which is what lets
    /// `?` be used directly on client calls. The message uses the `Debug`
    /// rendering of the inner error.
    #[error("postgres: {:?}", .0)]
    Postgres(#[from] PgError),
}
32
/// Runs a SQL program once on a blocking `postgres` client.
///
/// The program's SQL text is prepared on `client`, `input` is converted into
/// statement arguments by `params::to_sql`, the statement is executed, and
/// the returned rows are converted back into bytes by `result::from_sql`.
///
/// # Errors
///
/// Returns [`Error::Postgres`] when preparing or executing the statement
/// fails, and whatever error `result::from_sql` reports for the rows.
#[cfg(feature = "postgres")]
pub fn execute(
    client: &mut impl postgres::GenericClient,
    program: &rr::SqlProgram,
    input: &[u8],
) -> Result<Vec<u8>, Error> {
    // Type lookup shared by argument encoding and row decoding.
    let ctx = Context::new(&program.defs);

    let statement = client.prepare(&program.sql)?;
    let sql_args = params::to_sql(program, input, &ctx);
    let rows = client.query(&statement, &sql_args.as_refs())?;

    result::from_sql(program, &rows, &ctx)
}
53
/// Asynchronous runner that owns a `tokio_postgres` client.
///
/// `C` is any `tokio_postgres::GenericClient` and defaults to
/// `tokio_postgres::Client`. Only available with the `tokio-postgres` feature.
#[cfg(feature = "tokio-postgres")]
pub struct RunnerAsync<C: tokio_postgres::GenericClient = tokio_postgres::Client> {
    /// The client every statement is prepared and executed on.
    client: C,
}
58
// Gated like the `RunnerAsync` struct itself: without this attribute a build
// that enables only the `postgres` feature fails, because the type this impl
// refers to does not exist in that configuration.
#[cfg(feature = "tokio-postgres")]
impl<C> RunnerAsync<C>
where
    C: tokio_postgres::GenericClient,
{
    /// Wraps an already-connected client in a runner.
    pub fn new(client: C) -> Self {
        Self { client }
    }

    /// Consumes the runner and hands back the wrapped client.
    pub fn into_inner(self) -> C {
        self.client
    }
}
71
// Gated like the `RunnerAsync` struct itself: this impl names `RunnerAsync`,
// `tokio_postgres` and `tokio`, none of which may be available when only the
// `postgres` feature is enabled.
#[cfg(feature = "tokio-postgres")]
impl RunnerAsync<tokio_postgres::Client> {
    /// Connects to PostgreSQL without TLS and returns a runner for the new client.
    ///
    /// `config` is passed unchanged to `tokio_postgres::connect`. The
    /// connection half returned by that call is spawned onto the current
    /// Tokio runtime; if it ends with an error, the error is printed to stderr.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Postgres`] when the connection cannot be established.
    pub async fn connect_no_tls(config: &str) -> Result<Self, Error> {
        let (client, conn) = tokio_postgres::connect(config, tokio_postgres::NoTls).await?;

        // The connection future has to be polled for the client to make
        // progress, so it is driven by a detached background task.
        tokio::task::spawn(async {
            if let Err(e) = conn.await {
                eprintln!("{e}");
            }
        });

        Ok(Self::new(client))
    }
}
85
/// A SQL program together with the statement prepared for it.
///
/// Produced by `RunnerAsync`'s `Run::prepare` and consumed by its
/// `Run::execute`. Gated on `tokio-postgres` because it stores a
/// `tokio_postgres::Statement` and is only used by the async runner, which is
/// gated the same way.
#[cfg(feature = "tokio-postgres")]
#[derive(Clone)]
pub struct PreparedProgram {
    /// The program whose SQL text was prepared.
    program: rr::SqlProgram,
    /// The server-side prepared statement for `program.sql`.
    stmt: tokio_postgres::Statement,
}
91
/// `lutra_runner::Run` implementation backed by an async PostgreSQL client.
#[cfg(feature = "tokio-postgres")]
impl<C> lutra_runner::Run for RunnerAsync<C>
where
    C: tokio_postgres::GenericClient,
{
    type Error = Error;
    type Prepared = PreparedProgram;

    /// Prepares the program's SQL on the wrapped client.
    ///
    /// # Panics
    ///
    /// Panics when `program.into_sql_postgres()` does not yield a SQL
    /// program (the result is unwrapped).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Postgres`] when the client fails to prepare the statement.
    async fn prepare(&self, program: rr::Program) -> Result<Self::Prepared, Self::Error> {
        let sql_program = *program.into_sql_postgres().unwrap();
        let stmt = self.client.prepare(&sql_program.sql).await?;

        Ok(PreparedProgram {
            program: sql_program,
            stmt,
        })
    }

    /// Executes a previously prepared program with `input` and returns the
    /// bytes produced by `result::from_sql` for the returned rows.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Postgres`] when the query fails, and whatever error
    /// `result::from_sql` reports for the rows.
    async fn execute(
        &self,
        handle: &Self::Prepared,
        input: &[u8],
    ) -> Result<Vec<u8>, Self::Error> {
        let PreparedProgram { program, stmt } = handle;

        // Type lookup shared by argument encoding and row decoding.
        let ctx = Context::new(&program.defs);
        let sql_args = params::to_sql(program, input, &ctx);

        let rows = self.client.query(stmt, &sql_args.as_refs()).await?;

        result::from_sql(program, &rows, &ctx)
    }

    /// Returns the interface text produced by `schema::pull_interface` for this runner.
    async fn get_interface(&self) -> Result<String, Self::Error> {
        let interface = crate::schema::pull_interface(self).await?;
        Ok(interface)
    }
}
128
129struct Context<'a> {
130 pub types: HashMap<&'a ir::Path, &'a ir::Ty>,
131}
132
133impl<'a> Context<'a> {
134 fn new(ty_defs: &'a [ir::TyDef]) -> Self {
135 Context {
136 types: ty_defs.iter().map(|def| (&def.name, &def.ty)).collect(),
137 }
138 }
139
140 fn get_ty_mat(&self, ty: &'a ir::Ty) -> &'a ir::Ty {
141 match &ty.kind {
142 ir::TyKind::Ident(path) => self.types.get(path).unwrap(),
143 _ => ty,
144 }
145 }
146
147 fn is_option(&self, variants: &[ir::TyEnumVariant]) -> bool {
149 if variants.len() != 2 || !variants[0].ty.is_unit() {
150 return false;
151 }
152 let some_ty = self.get_ty_mat(&variants[1].ty);
153 some_ty.kind.is_primitive() || some_ty.kind.is_array()
154 }
155}