use anyhow::Result;
use async_trait::async_trait;
use deadpool_postgres::Pool;
use crate::diff::table::query::input::{
QueryHashDataInput, QueryPrimaryKeysInput, QueryTableCountInput, QueryTableNamesInput,
};
use crate::diff::table::query::table_query::TableQuery;
use crate::diff::table::query::table_types::{IncludedExcludedTables, TableName};
#[cfg(test)]
use mockall::automock;
#[cfg_attr(test, automock)]
#[async_trait]
pub trait TableSingleSourceQueryExecutor {
async fn query_table_names(&self, input: QueryTableNamesInput) -> Vec<String>;
async fn query_primary_keys(&self, input: QueryPrimaryKeysInput) -> Vec<String>;
}
pub struct TableSingleSourceQueryExecutorImpl {
db_pool: Pool,
}
impl TableSingleSourceQueryExecutorImpl {
pub fn new(db_pool: Pool) -> Self {
Self { db_pool }
}
}
#[async_trait]
impl TableSingleSourceQueryExecutor for TableSingleSourceQueryExecutorImpl {
async fn query_table_names(&self, input: QueryTableNamesInput) -> Vec<String> {
let client = self.db_pool.get().await.unwrap();
let all_tables_query = TableQuery::AllTablesForSchema(
input.schema_name().to_owned(),
IncludedExcludedTables::new(input.included_tables(), input.excluded_tables()),
);
let query_result = client
.query(&all_tables_query.to_string(), &[])
.await
.unwrap();
query_result
.iter()
.map(|row| row.get("table_name"))
.collect::<Vec<String>>()
}
async fn query_primary_keys(&self, input: QueryPrimaryKeysInput) -> Vec<String> {
let client = self.db_pool.get().await.unwrap();
let find_primary_key_query =
TableQuery::FindPrimaryKeyForTable(TableName::new(input.table_name()));
let query_result = client
.query(&find_primary_key_query.to_string(), &[])
.await
.unwrap();
query_result
.iter()
.map(|row| row.get("attname"))
.collect::<Vec<String>>()
}
}
#[cfg_attr(test, automock)]
#[async_trait]
pub trait TableDualSourceQueryExecutor {
async fn query_table_count(&self, input: QueryTableCountInput) -> (Result<i64>, Result<i64>);
async fn query_hash_data(&self, input: QueryHashDataInput) -> (String, String);
}
pub struct TableDualSourceQueryExecutorImpl {
first_db_client: Pool,
second_db_client: Pool,
}
impl TableDualSourceQueryExecutorImpl {
pub fn new(first_db_client: Pool, second_db_client: Pool) -> Self {
Self {
first_db_client,
second_db_client,
}
}
}
#[async_trait]
impl TableDualSourceQueryExecutor for TableDualSourceQueryExecutorImpl {
async fn query_table_count(&self, input: QueryTableCountInput) -> (Result<i64>, Result<i64>) {
let first_client = self.first_db_client.get().await.unwrap();
let second_client = self.second_db_client.get().await.unwrap();
let count_rows_query = TableQuery::CountRowsForTable(
input.schema_name().to_owned(),
input.table_name().to_owned(),
);
let count_query_binding = count_rows_query.to_string();
let first_count = first_client.query_one(&count_query_binding, &[]);
let second_count = second_client.query_one(&count_query_binding, &[]);
let count_fetch_futures = futures::future::join_all(vec![first_count, second_count]).await;
let first_count = count_fetch_futures.first().unwrap();
let second_count = count_fetch_futures.get(1).unwrap();
let first_count: Result<i64> = match first_count {
Ok(pg_row) => Ok(pg_row.get("count")),
Err(_e) => Err(anyhow::anyhow!("Failed to fetch count for first table")),
};
let second_count: Result<i64> = match second_count {
Ok(pg_row) => Ok(pg_row.get("count")),
Err(_e) => Err(anyhow::anyhow!("Failed to fetch count for second table")),
};
(first_count, second_count)
}
async fn query_hash_data(&self, input: QueryHashDataInput) -> (String, String) {
let first_client = self.first_db_client.get().await.unwrap();
let second_client = self.second_db_client.get().await.unwrap();
let hash_query = TableQuery::HashQuery(
input.schema_name(),
input.table_name(),
input.primary_keys(),
input.position(),
input.offset(),
);
let hash_query_binding = hash_query.to_string();
let first_hash = first_client.query_one(&hash_query_binding, &[]);
let second_hash = second_client.query_one(&hash_query_binding, &[]);
let hash_fetch_futures = futures::future::join_all(vec![first_hash, second_hash]).await;
let first_hash = hash_fetch_futures.first().unwrap();
let second_hash = hash_fetch_futures.get(1).unwrap();
let first_hash = match first_hash {
Ok(pg_row) => pg_row.try_get("md5").unwrap_or("not_available".to_string()),
Err(e) => e.to_string(),
};
let second_hash = match second_hash {
Ok(pg_row) => pg_row.try_get("md5").unwrap_or("not_available".to_string()),
Err(e) => e.to_string(),
};
(first_hash, second_hash)
}
}