1use super::{Database as DatabaseTrait, Error, Result, DATA_GENERATION_SEED};
5use crate::{
6 data_type::{
7 generator::Generator,
8 value::{self, Value},
9 DataTyped,
10 },
11 dialect_translation::postgresql::PostgreSqlTranslator,
12 namer,
13 relation::{Table, Variant as _},
14};
15use std::{env, fmt, process::Command, str::FromStr, sync::Arc, sync::Mutex, thread, time};
16
17use colored::Colorize;
18use postgres::{
19 self,
20 types::{FromSql, ToSql, Type},
21};
22use r2d2::Pool;
23use r2d2_postgres::{postgres::NoTls, PostgresConnectionManager};
24use rand::{rngs::StdRng, SeedableRng};
25use rust_decimal::{prelude::ToPrimitive, Decimal};
26
/// Database name that `test_database` hands to `Database::new`; the container name used by
/// `Database::build_pool_from_container` is derived from it through `namer::new_name`.
const DB: &str = "qrlew-test";
/// Fallback port for an existing server (see `Database::port`) and base from which container
/// host ports are offset.
const PORT: usize = 5_432;
/// Fallback user for an existing server and the user connected as for started containers.
const USER: &str = "postgres";
/// Fallback password for an existing server; also exported as `POSTGRES_PASSWORD` to started
/// containers.
const PASSWORD: &str = "qrlew-test";
31
32impl From<postgres::Error> for Error {
34 fn from(err: postgres::Error) -> Self {
35 Error::Other(err.to_string())
36 }
37}
38
39pub struct Database {
40 name: String,
41 tables: Vec<Table>,
42 pool: Pool<PostgresConnectionManager<NoTls>>,
43 drop: bool,
44}
45
46pub static POSTGRES_POOL: Mutex<Option<Pool<PostgresConnectionManager<NoTls>>>> = Mutex::new(None);
48pub static POSTGRES_CONTAINER: Mutex<bool> = Mutex::new(false);
50
51impl Database {
52 fn port() -> usize {
57 match env::var("POSTGRES_PORT") {
58 Ok(port) => usize::from_str(&port).unwrap_or(PORT),
59 Err(_) => PORT,
60 }
61 }
62
63 fn user() -> String {
64 env::var("POSTGRES_USER").unwrap_or(USER.into())
65 }
66
67 fn password() -> String {
68 env::var("POSTGRES_PASSWORD").unwrap_or(PASSWORD.into())
69 }
70
71 fn build_pool_from_existing() -> Result<Pool<PostgresConnectionManager<NoTls>>> {
75 let manager = PostgresConnectionManager::new(
76 format!(
77 "host=localhost port={} user={} password={}",
78 Database::port(),
79 Database::user(),
80 Database::password()
81 )
82 .parse()?,
83 NoTls,
84 );
85 Ok(r2d2::Pool::builder().max_size(10).build(manager)?)
86 }
87
88 fn build_pool_from_container(name: String) -> Result<Pool<PostgresConnectionManager<NoTls>>> {
90 let mut postgres_container = POSTGRES_CONTAINER.lock().unwrap();
91 if *postgres_container == false {
92 *postgres_container = true;
94 let name = namer::new_name(name);
96 let port = PORT + namer::new_id("pg-port");
97 if !Command::new("docker")
99 .arg("start")
100 .arg(&name)
101 .status()?
102 .success()
103 {
104 log::debug!("Starting the DB");
105 let output = Command::new("docker")
109 .arg("run")
110 .arg("--name")
111 .arg(&name)
112 .arg("-d")
113 .arg("--rm")
114 .arg("-e")
115 .arg(format!("POSTGRES_PASSWORD={PASSWORD}"))
116 .arg("-p")
117 .arg(format!("{}:5432", port))
118 .arg("postgres")
119 .output()?;
120 log::info!("{:?}", output);
121 log::info!("Waiting for the DB to start");
122 while !Command::new("docker")
123 .arg("exec")
124 .arg(&name)
125 .arg("pg_isready")
126 .status()?
127 .success()
128 {
129 thread::sleep(time::Duration::from_millis(200));
130 log::info!("Waiting...");
131 }
132 log::info!("{}", "DB ready".red());
133 }
134 let manager = PostgresConnectionManager::new(
135 format!("host=localhost port={port} user={USER} password={PASSWORD}").parse()?,
136 NoTls,
137 );
138 Ok(r2d2::Pool::builder().max_size(10).build(manager)?)
139 } else {
140 Database::build_pool_from_existing()
141 }
142 }
143}
144
145impl fmt::Debug for Database {
146 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
147 f.debug_struct("Database")
148 .field("name", &self.name)
149 .field("tables", &self.tables)
150 .finish()
151 }
152}
153
154impl DatabaseTrait for Database {
155 fn new(name: String, tables: Vec<Table>) -> Result<Self> {
156 let mut postgres_pool = POSTGRES_POOL.lock().unwrap();
157 if let None = *postgres_pool {
158 *postgres_pool = Some(
159 Database::build_pool_from_existing()
160 .or_else(|_| Database::build_pool_from_container(name.clone()))?,
161 );
162 }
163 let pool = postgres_pool.as_ref().unwrap().clone();
164 let table_names: Vec<String> = pool
165 .get()?
166 .query(
167 "SELECT * FROM pg_catalog.pg_tables WHERE schemaname='public'",
168 &[],
169 )?
170 .into_iter()
171 .map(|row| row.get("tablename"))
172 .collect();
173 if table_names.is_empty() {
174 Database {
175 name,
176 tables: vec![],
177 pool,
178 drop: false,
179 }
180 .with_tables(tables)
181 } else {
182 Ok(Database {
183 name,
184 tables,
185 pool,
186 drop: false,
187 })
188 }
189 }
190
191 fn name(&self) -> &str {
192 &self.name
193 }
194
195 fn tables(&self) -> &[Table] {
196 &self.tables
197 }
198
199 fn tables_mut(&mut self) -> &mut Vec<Table> {
200 &mut self.tables
201 }
202
203 fn create_table(&mut self, table: &Table) -> Result<usize> {
204 let mut connection = self.pool.get()?;
205 let _qq = table.create(PostgreSqlTranslator).to_string();
206 Ok(connection.execute(&table.create(PostgreSqlTranslator).to_string(), &[])? as usize)
207 }
208
209 fn insert_data(&mut self, table: &Table) -> Result<()> {
210 let mut rng = StdRng::seed_from_u64(DATA_GENERATION_SEED);
211 let size = Database::MAX_SIZE.min(table.size().generate(&mut rng) as usize);
212 let mut connection = self.pool.get()?;
213 let statement = connection.prepare(&table.insert("$", PostgreSqlTranslator).to_string())?;
214 for _ in 0..size {
215 let structured: value::Struct =
216 table.schema().data_type().generate(&mut rng).try_into()?;
217 let values: Result<Vec<SqlValue>> = structured
218 .into_iter()
219 .map(|(_, v)| (**v).clone().try_into())
220 .collect();
221 let values = values?;
222 let params: Vec<&(dyn ToSql + Sync)> =
223 values.iter().map(|v| v as &(dyn ToSql + Sync)).collect();
224 connection.execute(&statement, ¶ms)?;
225 }
226 Ok(())
227 }
228
229 fn query(&mut self, query: &str) -> Result<Vec<value::List>> {
230 let rows: Vec<_>;
231 {
232 let mut connection = self.pool.get()?;
233 let statement = connection.prepare(query)?;
234 rows = connection.query(&statement, &[])?;
235 }
236 Ok(rows
237 .into_iter()
238 .map(|r| {
239 let values: Vec<SqlValue> = (0..r.len()).into_iter().map(|i| r.get(i)).collect();
240 value::List::from_iter(values.into_iter().map(|v| v.try_into().expect("Convert")))
241 })
242 .collect())
243 }
244}
245
246impl Drop for Database {
247 fn drop(&mut self) {
248 if self.drop {
249 Command::new("docker")
250 .arg("rm")
251 .arg("--force")
252 .arg(self.name())
253 .status()
254 .expect("Deleted container");
255 }
256 }
257}
258
259#[derive(Debug, Clone)]
260enum SqlValue {
261 Boolean(value::Boolean),
262 Integer(value::Integer),
263 Float(value::Float),
264 Text(value::Text),
265 Optional(Option<Box<SqlValue>>),
266 Date(value::Date),
267 Time(value::Time),
268 DateTime(value::DateTime),
269 Id(value::Id),
270}
271
272impl TryFrom<Value> for SqlValue {
273 type Error = Error;
274
275 fn try_from(value: Value) -> Result<Self> {
276 match value {
277 Value::Boolean(b) => Ok(SqlValue::Boolean(b)),
278 Value::Integer(i) => Ok(SqlValue::Integer(i)),
279 Value::Float(f) => Ok(SqlValue::Float(f)),
280 Value::Text(t) => Ok(SqlValue::Text(t)),
281 Value::Optional(o) => o
282 .as_deref()
283 .map(|v| SqlValue::try_from(v.clone()))
284 .map_or(Ok(None), |r| r.map(|v| Some(Box::new(v))))
285 .map(|o| SqlValue::Optional(o)),
286 Value::Date(d) => Ok(SqlValue::Date(d)),
287 Value::Time(t) => Ok(SqlValue::Time(t)),
288 Value::DateTime(d) => Ok(SqlValue::DateTime(d)),
289 Value::Id(i) => Ok(SqlValue::Id(i)),
290 _ => Err(Error::other(value)),
291 }
292 }
293}
294
295impl TryFrom<SqlValue> for Value {
296 type Error = Error;
297
298 fn try_from(value: SqlValue) -> Result<Self> {
299 match value {
300 SqlValue::Boolean(b) => Ok(Value::Boolean(b)),
301 SqlValue::Integer(i) => Ok(Value::Integer(i)),
302 SqlValue::Float(f) => Ok(Value::Float(f)),
303 SqlValue::Text(t) => Ok(Value::Text(t)),
304 SqlValue::Optional(o) => o
305 .map(|v| Value::try_from(*v))
306 .map_or(Ok(None), |r| r.map(|v| Some(Arc::new(v))))
307 .map(|o| Value::from(o)),
308 SqlValue::Date(d) => Ok(Value::Date(d)),
309 SqlValue::Time(t) => Ok(Value::Time(t)),
310 SqlValue::DateTime(d) => Ok(Value::DateTime(d)),
311 SqlValue::Id(i) => Ok(Value::Id(i)),
312 }
313 }
314}
315
316impl ToSql for SqlValue {
317 fn to_sql(
318 &self,
319 ty: &Type,
320 out: &mut postgres::types::private::BytesMut,
321 ) -> std::result::Result<postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
322 where
323 Self: Sized,
324 {
325 match self {
326 SqlValue::Boolean(b) => b.to_sql(ty, out),
327 SqlValue::Integer(i) => i.to_sql(ty, out),
328 SqlValue::Float(f) => f.to_sql(ty, out),
329 SqlValue::Text(t) => t.to_sql(ty, out),
330 SqlValue::Optional(o) => o.as_deref().to_sql(ty, out),
331 SqlValue::Date(d) => d.to_sql(ty, out),
332 SqlValue::Time(t) => t.to_sql(ty, out),
333 SqlValue::DateTime(d) => d.to_sql(ty, out),
334 SqlValue::Id(i) => i.to_sql(ty, out),
335 }
336 }
337
338 postgres::types::accepts!(
339 BOOL, INT2, INT4, INT8, NUMERIC, FLOAT4, FLOAT8, NUMERIC, VARCHAR, TEXT, DATE, TIME,
340 TIMESTAMP
341 );
342
343 postgres::types::to_sql_checked!();
344}
345
346impl<'a> FromSql<'a> for SqlValue {
347 fn from_sql(
348 ty: &Type,
349 raw: &'a [u8],
350 ) -> std::result::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
351 match ty {
352 &Type::BOOL => bool::from_sql(ty, raw).map(|b| SqlValue::Boolean(b.into())),
353 &Type::INT4 => i32::from_sql(ty, raw).map(|i| SqlValue::Integer((i as i64).into())),
357 &Type::INT8 => i64::from_sql(ty, raw).map(|i| SqlValue::Integer(i.into())),
358 &Type::FLOAT4 | &Type::FLOAT8 => {
359 f64::from_sql(ty, raw).map(|f| SqlValue::Float(f.into()))
360 }
361 &Type::NUMERIC => Decimal::from_sql(ty, raw)
362 .map(|d| SqlValue::Float(d.to_f64().unwrap_or_default().into())),
363 &Type::VARCHAR | &Type::TEXT => {
364 String::from_sql(ty, raw).map(|s| SqlValue::Text(s.into()))
365 }
366 &Type::DATE => chrono::NaiveDate::from_sql(ty, raw).map(|d| SqlValue::Date(d.into())),
367 &Type::TIME => chrono::NaiveTime::from_sql(ty, raw).map(|t| SqlValue::Time(t.into())),
368 &Type::TIMESTAMP => {
369 chrono::NaiveDateTime::from_sql(ty, raw).map(|d| SqlValue::DateTime(d.into()))
370 }
371 _ => todo!(),
372 }
373 }
374
375 fn from_sql_null(
376 _ty: &Type,
377 ) -> std::result::Result<Self, Box<dyn std::error::Error + Sync + Send>> {
378 Ok(SqlValue::Optional(None))
379 }
380
381 postgres::types::accepts!(
382 BOOL, INT2, INT4, INT8, FLOAT4, FLOAT8, NUMERIC, VARCHAR, TEXT, DATE, TIME, TIMESTAMP
383 );
384}
385
386pub fn test_database() -> Database {
387 Database::new(DB.into(), Database::test_tables()).expect("Database")
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Queries run by `database_display`, in execution order.
    const DISPLAY_QUERIES: [&str; 5] = [
        "SELECT count(a), 1+sum(a), d FROM table_1 group by d",
        "SELECT AVG(x) as a FROM table_2",
        "SELECT 1+count(y) as a, sum(1+x) as b FROM table_2",
        "WITH cte AS (SELECT * FROM table_1) SELECT * FROM cte",
        "SELECT * FROM table_2",
    ];

    /// Prints each query in `DISPLAY_QUERIES` followed by the rows it returns.
    #[test]
    fn database_display() -> Result<()> {
        let mut database = test_database();
        for query in DISPLAY_QUERIES {
            println!("\n{query}");
            let rows = database.query(query)?;
            rows.iter().for_each(|row| println!("{}", row));
        }
        Ok(())
    }

    /// Checks two things about `eq`: it reports `table_1` and `table_2` as different, and it
    /// reports a CTE wrapping `table_1` as equal to `table_1`.
    #[test]
    fn database_test() -> Result<()> {
        let mut database = test_database();
        let pool_size = database.pool.max_size();
        println!("Pool {pool_size}");
        assert!(!database.eq("SELECT * FROM table_1", "SELECT * FROM table_2"));
        assert!(database.eq(
            "SELECT * FROM table_1",
            "WITH cte AS (SELECT * FROM table_1) SELECT * FROM cte"
        ));
        Ok(())
    }
}