qrlew/io/
postgresql.rs

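//! A PostgreSQL-backed test [`Database`]: it connects to an already running server or, when
//! none is reachable, starts a docker container running `postgres` (and can force-remove it on drop).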
3
4use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED};
5use crate::{
6    data_type::{
7        generator::Generator,
8        value::{self, Value},
9        DataTyped,
10    },
11    dialect_translation::postgresql::PostgreSqlTranslator,
12    namer,
13    relation::{Table, Variant as _},
14};
15use std::{env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time};
16
17use colored::Colorize;
18use postgres::{
19    self,
20    types::{FromSql, ToSql, Type},
21};
22use r2d2::Pool;
23use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager};
24use rand::{rngs::StdRng, SeedableRng};
25use rust_decimal::{prelude::ToPrimitive, Decimal};
26
/// Name of the test database built by [`test_database`]; also the base name of a started container.
const DB: &str = "qrlew-test";
/// Default server port (used when `POSTGRES_PORT` is unset or invalid) and base port for containers.
const PORT: usize = 5_432;
/// Default user name, used when `POSTGRES_USER` is unset and for container connections.
const USER: &str = "postgres";
/// Default password, used when `POSTGRES_PASSWORD` is unset and given to started containers.
const PASSWORD: &str = "qrlew-test";
31
32/// Converts sqlite errors to io errors
33impl From<postgres::Error> for Error {
34    fn from(err: postgres::Error) -> Self {
35        Error::Other(err.to_string())
36    }
37}
38
39pub struct Database {
40    name: String,
41    tables: Vec<Table>,
42    pool: Pool<PostgresConnectionManager<NoTls>>,
43    drop: bool,
44}
45
46/// Only one pool
47pub static POSTGRES_POOL: Mutex<Option<Pool<PostgresConnectionManager<NoTls>>>> = Mutex::new(None);
48/// Only one thread start a container
49pub static POSTGRES_CONTAINER: Mutex<bool> = Mutex::new(false);
50
51impl Database {
52    // fn db() -> String {
53    //     env::var("POSTGRES_DB").unwrap_or(DB.into())
54    // }
55
56    fn port() -> usize {
57        match env::var("POSTGRES_PORT") {
58            Ok(port) => usize::from_str(&port).unwrap_or(PORT),
59            Err(_) => PORT,
60        }
61    }
62
63    fn user() -> String {
64        env::var("POSTGRES_USER").unwrap_or(USER.into())
65    }
66
67    fn password() -> String {
68        env::var("POSTGRES_PASSWORD").unwrap_or(PASSWORD.into())
69    }
70
71    /// Try to build a pool from an existing DB
72    /// A postgresql instance must exist
73    /// `docker run --name qrlew-test -p 5432:5432 -e POSTGRES_PASSWORD=qrlew-test -d postgres`
74    fn build_pool_from_existing() -> Result<Pool<PostgresConnectionManager<NoTls>>> {
75        let manager = PostgresConnectionManager::new(
76            format!(
77                "host=localhost port={} user={} password={}",
78                Database::port(),
79                Database::user(),
80                Database::password()
81            )
82            .parse()?,
83            NoTls,
84        );
85        Ok(r2d2::Pool::builder().max_size(10).build(manager)?)
86    }
87
88    /// Try to build a pool from a DB in a container
89    fn build_pool_from_container(name: String) -> Result<Pool<PostgresConnectionManager<NoTls>>> {
90        let mut postgres_container = POSTGRES_CONTAINER.lock().unwrap();
91        if *postgres_container == false {
92            // A new container will be started
93            *postgres_container = true;
94            // Other threads will wait for this to be ready
95            let name = namer::new_name(name);
96            let port = PORT + namer::new_id("pg-port");
97            // Test the connexion and launch a test instance if necessary
98            if !Command::new("docker")
99                .arg("start")
100                .arg(&name)
101                .status()?
102                .success()
103            {
104                log::debug!("Starting the DB");
105                // If the container does not exist
106                // Start a new container
107                // Run: `docker run --name test-db -e POSTGRES_PASSWORD=test -d postgres`
108                let output = Command::new("docker")
109                    .arg("run")
110                    .arg("--name")
111                    .arg(&name)
112                    .arg("-d")
113                    .arg("--rm")
114                    .arg("-e")
115                    .arg(format!("POSTGRES_PASSWORD={PASSWORD}"))
116                    .arg("-p")
117                    .arg(format!("{}:5432", port))
118                    .arg("postgres")
119                    .output()?;
120                log::info!("{:?}", output);
121                log::info!("Waiting for the DB to start");
122                while !Command::new("docker")
123                    .arg("exec")
124                    .arg(&name)
125                    .arg("pg_isready")
126                    .status()?
127                    .success()
128                {
129                    thread::sleep(time::Duration::from_millis(200));
130                    log::info!("Waiting...");
131                }
132                log::info!("{}", "DB ready".red());
133            }
134            let manager = PostgresConnectionManager::new(
135                format!("host=localhost port={port} user={USER} password={PASSWORD}").parse()?,
136                NoTls,
137            );
138            Ok(r2d2::Pool::builder().max_size(10).build(manager)?)
139        } else {
140            Database::build_pool_from_existing()
141        }
142    }
143}
144
145impl fmt::Debug for Database {
146    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147        f.debug_struct("Database")
148            .field("name", &self.name)
149            .field("tables", &self.tables)
150            .finish()
151    }
152}
153
154impl DatabaseTrait for Database {
155    fn new(name: String, tables: Vec<Table>) -> Result<Self> {
156        let mut postgres_pool = POSTGRES_POOL.lock().unwrap();
157        if let None = *postgres_pool {
158            *postgres_pool = Some(
159                Database::build_pool_from_existing()
160                    .or_else(|_| Database::build_pool_from_container(name.clone()))?,
161            );
162        }
163        let pool = postgres_pool.as_ref().unwrap().clone();
164        let table_names: Vec<String> = pool
165            .get()?
166            .query(
167                "SELECT * FROM pg_catalog.pg_tables WHERE schemaname='public'",
168                &[],
169            )?
170            .into_iter()
171            .map(|row| row.get("tablename"))
172            .collect();
173        if table_names.is_empty() {
174            Database {
175                name,
176                tables: vec![],
177                pool,
178                drop: false,
179            }
180            .with_tables(tables)
181        } else {
182            Ok(Database {
183                name,
184                tables,
185                pool,
186                drop: false,
187            })
188        }
189    }
190
191    fn name(&self) -> &str {
192        &self.name
193    }
194
195    fn tables(&self) -> &[Table] {
196        &self.tables
197    }
198
199    fn tables_mut(&mut self) -> &mut Vec<Table> {
200        &mut self.tables
201    }
202
203    fn create_table(&mut self, table: &Table) -> Result<usize> {
204        let mut connection = self.pool.get()?;
205        let _qq = table.create(PostgreSqlTranslator).to_string();
206        Ok(connection.execute(&table.create(PostgreSqlTranslator).to_string(), &[])? as usize)
207    }
208
209    fn insert_data(&mut self, table: &Table) -> Result<()> {
210        let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED);
211        let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize);
212        let mut connection = self.pool.get()?;
213        let statement = connection.prepare(&table.insert("$", PostgreSqlTranslator).to_string())?;
214        for _ in 0..size {
215            let structured: value::Struct =
216                table.schema().data_type().generate(&mut rng).try_into()?;
217            let values: Result<Vec<SqlValue>> = structured
218                .into_iter()
219                .map(|(_, v)| (**v).clone().try_into())
220                .collect();
221            let values = values?;
222            let params: Vec<&(dyn ToSql + Sync)> =
223                values.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
224            connection.execute(&statement, &params)?;
225        }
226        Ok(())
227    }
228
229    fn query(&mut self, query: &str) -> Result<Vec<value::List>> {
230        let rows: Vec<_>;
231        {
232            let mut connection = self.pool.get()?;
233            let statement = connection.prepare(query)?;
234            rows = connection.query(&statement, &[])?;
235        }
236        Ok(rows
237            .into_iter()
238            .map(|r| {
239                let values: Vec<SqlValue> = (0..r.len()).into_iter().map(|i| r.get(i)).collect();
240                value::List::from_iter(values.into_iter().map(|v| v.try_into().expect("Convert")))
241            })
242            .collect())
243    }
244}
245
246impl Drop for Database {
247    fn drop(&mut self) {
248        if self.drop {
249            Command::new("docker")
250                .arg("rm")
251                .arg("--force")
252                .arg(self.name())
253                .status()
254                .expect("Deleted container");
255        }
256    }
257}
258
259#[derive(Debug, Clone)]
260enum SqlValue {
261    Boolean(value::Boolean),
262    Integer(value::Integer),
263    Float(value::Float),
264    Text(value::Text),
265    Optional(Option<Box<SqlValue>>),
266    Date(value::Date),
267    Time(value::Time),
268    DateTime(value::DateTime),
269    Id(value::Id),
270}
271
272impl TryFrom<Value> for SqlValue {
273    type Error = Error;
274
275    fn try_from(value: Value) -> Result<Self> {
276        match value {
277            Value::Boolean(b) => Ok(SqlValue::Boolean(b)),
278            Value::Integer(i) => Ok(SqlValue::Integer(i)),
279            Value::Float(f) => Ok(SqlValue::Float(f)),
280            Value::Text(t) => Ok(SqlValue::Text(t)),
281            Value::Optional(o) => o
282                .as_deref()
283                .map(|v| SqlValue::try_from(v.clone()))
284                .map_or(Ok(None), |r| r.map(|v| Some(Box::new(v))))
285                .map(|o| SqlValue::Optional(o)),
286            Value::Date(d) => Ok(SqlValue::Date(d)),
287            Value::Time(t) => Ok(SqlValue::Time(t)),
288            Value::DateTime(d) => Ok(SqlValue::DateTime(d)),
289            Value::Id(i) => Ok(SqlValue::Id(i)),
290            _ => Err(Error::other(value)),
291        }
292    }
293}
294
295impl TryFrom<SqlValue> for Value {
296    type Error = Error;
297
298    fn try_from(value: SqlValue) -> Result<Self> {
299        match value {
300            SqlValue::Boolean(b) => Ok(Value::Boolean(b)),
301            SqlValue::Integer(i) => Ok(Value::Integer(i)),
302            SqlValue::Float(f) => Ok(Value::Float(f)),
303            SqlValue::Text(t) => Ok(Value::Text(t)),
304            SqlValue::Optional(o) => o
305                .map(|v| Value::try_from(*v))
306                .map_or(Ok(None), |r| r.map(|v| Some(Arc::new(v))))
307                .map(|o| Value::from(o)),
308            SqlValue::Date(d) => Ok(Value::Date(d)),
309            SqlValue::Time(t) => Ok(Value::Time(t)),
310            SqlValue::DateTime(d) => Ok(Value::DateTime(d)),
311            SqlValue::Id(i) => Ok(Value::Id(i)),
312        }
313    }
314}
315
316impl ToSql for SqlValue {
317    fn to_sql(
318        &self,
319        ty: &Type,
320        out: &mut postgres::types::private::BytesMut,
321    ) -> std::result::Result<postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
322    where
323        Self: Sized,
324    {
325        match self {
326            SqlValue::Boolean(b) => b.to_sql(ty, out),
327            SqlValue::Integer(i) => i.to_sql(ty, out),
328            SqlValue::Float(f) => f.to_sql(ty, out),
329            SqlValue::Text(t) => t.to_sql(ty, out),
330            SqlValue::Optional(o) => o.as_deref().to_sql(ty, out),
331            SqlValue::Date(d) => d.to_sql(ty, out),
332            SqlValue::Time(t) => t.to_sql(ty, out),
333            SqlValue::DateTime(d) => d.to_sql(ty, out),
334            SqlValue::Id(i) => i.to_sql(ty, out),
335        }
336    }
337
338    postgres::types::accepts!(
339        BOOL, INT2, INT4, INT8, NUMERIC, FLOAT4, FLOAT8, NUMERIC, VARCHAR, TEXT, DATE, TIME,
340        TIMESTAMP
341    );
342
343    postgres::types::to_sql_checked!();
344}
345
346impl<'a> FromSql<'a> for SqlValue {
347    fn from_sql(
348        ty: &Type,
349        raw: &'a [u8],
350    ) -> std::result::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
351        match ty {
352            &Type::BOOL => bool::from_sql(ty, raw).map(|b| SqlValue::Boolean(b.into())),
353            // &Type::INT4 | &Type::INT8 => {
354            //     i64::from_sql(ty, raw).map(|i| SqlValue::Integer(i.into()))
355            // }
356            &Type::INT4 => i32::from_sql(ty, raw).map(|i| SqlValue::Integer((i as i64).into())),
357            &Type::INT8 => i64::from_sql(ty, raw).map(|i| SqlValue::Integer(i.into())),
358            &Type::FLOAT4 | &Type::FLOAT8 => {
359                f64::from_sql(ty, raw).map(|f| SqlValue::Float(f.into()))
360            }
361            &Type::NUMERIC => Decimal::from_sql(ty, raw)
362                .map(|d| SqlValue::Float(d.to_f64().unwrap_or_default().into())),
363            &Type::VARCHAR | &Type::TEXT => {
364                String::from_sql(ty, raw).map(|s| SqlValue::Text(s.into()))
365            }
366            &Type::DATE => chrono::NaiveDate::from_sql(ty, raw).map(|d| SqlValue::Date(d.into())),
367            &Type::TIME => chrono::NaiveTime::from_sql(ty, raw).map(|t| SqlValue::Time(t.into())),
368            &Type::TIMESTAMP => {
369                chrono::NaiveDateTime::from_sql(ty, raw).map(|d| SqlValue::DateTime(d.into()))
370            }
371            _ => todo!(),
372        }
373    }
374
375    fn from_sql_null(
376        _ty: &Type,
377    ) -> std::result::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
378        Ok(SqlValue::Optional(None))
379    }
380
381    postgres::types::accepts!(
382        BOOL, INT2, INT4, INT8, FLOAT4, FLOAT8, NUMERIC, VARCHAR, TEXT, DATE, TIME, TIMESTAMP
383    );
384}
385
386pub fn test_database() -> Database {
387    // Database::test()
388    Database::new(DB.into(), Database::test_tables()).expect("Database")
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Queries whose result rows `database_display` prints.
    const DISPLAY_QUERIES: [&str; 5] = [
        "SELECT count(a), 1+sum(a), d FROM table_1 group by d",
        "SELECT AVG(x) as a FROM table_2",
        "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2",
        "WITH cte AS (SELECT * FROM table_1) SELECT * FROM cte",
        "SELECT * FROM table_2",
    ];

    /// Print every row returned by each query in `DISPLAY_QUERIES`.
    #[test]
    fn database_display() -> Result<()> {
        let mut database = test_database();
        for query in DISPLAY_QUERIES {
            println!("\n{query}");
            database
                .query(query)?
                .iter()
                .for_each(|row| println!("{}", row));
        }
        Ok(())
    }

    /// `eq` tells apart different tables and matches a query with its CTE rewrite.
    #[test]
    fn database_test() -> Result<()> {
        let mut database = test_database();
        println!("Pool {}", database.pool.max_size());
        let with_cte = "WITH cte AS (SELECT * FROM table_1) SELECT * FROM cte";
        assert!(!database.eq("SELECT * FROM table_1", "SELECT * FROM table_2"));
        assert!(database.eq("SELECT * FROM table_1", with_cte));
        Ok(())
    }
}