use crate::helpers::IMPORT_PREFIX;
use crate::quoting::{AttemptedKeywordUsage, Quotable};
use crate::schema_reader::SchemaReader;
use crate::storage::postgres::connection_pool::ConnectionPool;
use crate::storage::postgres::postgres_instance_storage::PostgresInstanceStorage;
use crate::{
AsyncCleanup, CopyDestination, IdentifierQuoter, PostgresClientWrapper, PostgresDatabase,
PostgresSchema, PostgresTable, TableData,
};
use bytes::Bytes;
use futures::{pin_mut, SinkExt, Stream, StreamExt};
use itertools::Itertools;
use std::collections::HashSet;
use std::sync::Arc;
use tracing::{error, info, instrument};
/// A `CopyDestination` backed by a Postgres instance that runs data copies and
/// non-transactional statements on connections taken from a pool, while
/// transactional work is sent to a single borrowed main connection.
#[derive(Clone)]
pub struct ParallelSafePostgresInstanceCopyDestinationStorage<'a> {
/// Pool that idle worker connections are taken from and handed back to.
connection_pool: ConnectionPool,
/// Borrowed connection used for transactional statements, begin/commit,
/// introspection, and for opening additional connections when the pool is empty.
main_connection: &'a PostgresClientWrapper,
/// Shared quoter used when building copy commands and quoting schema/table names.
identifier_quoter: Arc<IdentifierQuoter>,
/// Texts of the non-transactional statements currently being executed; only
/// used for diagnostics in error logs. Held in an `Arc`, so clones of this
/// struct share the same set.
in_flight_statements: Arc<tokio::sync::Mutex<HashSet<String>>>,
}
impl<'a> ParallelSafePostgresInstanceCopyDestinationStorage<'a> {
    /// Builds a destination on top of `storage`.
    ///
    /// `IMPORT_PREFIX` is executed on the storage's connection before anything
    /// else; that connection then becomes the main connection of the returned
    /// value. The connection pool and the in-flight statement set start empty.
    ///
    /// # Errors
    ///
    /// Returns the error produced by executing `IMPORT_PREFIX`.
    pub async fn new(storage: &PostgresInstanceStorage<'a>) -> crate::Result<Self> {
        let primary = storage.connection;
        primary.execute_non_query(IMPORT_PREFIX).await?;

        let connection_pool = ConnectionPool::new();
        let identifier_quoter = storage.identifier_quoter.clone();
        let in_flight_statements = Arc::new(tokio::sync::Mutex::new(HashSet::new()));

        Ok(Self {
            connection_pool,
            main_connection: primary,
            identifier_quoter,
            in_flight_statements,
        })
    }

    /// Hands out a worker connection.
    ///
    /// A pooled connection is returned when the pool has one. Otherwise a new
    /// connection is opened from the main connection and `IMPORT_PREFIX` is
    /// executed on it before it is handed out.
    ///
    /// # Errors
    ///
    /// Fails if opening the new connection or executing `IMPORT_PREFIX` on it fails.
    async fn get_connection(&self) -> crate::Result<PostgresClientWrapper> {
        match self.connection_pool.get_connection().await {
            Some(pooled) => Ok(pooled),
            None => {
                let fresh = self.main_connection.create_another_connection().await?;
                fresh.execute_non_query(IMPORT_PREFIX).await?;
                Ok(fresh)
            }
        }
    }

    /// Gives `connection` back to the pool so a later caller can reuse it.
    async fn release_connection(&self, connection: PostgresClientWrapper) {
        self.connection_pool.release_connection(connection).await;
    }
}
impl<'a> CopyDestination for ParallelSafePostgresInstanceCopyDestinationStorage<'a> {
    /// Streams `data` into `table` using the table's copy-in command.
    ///
    /// The copy runs on a worker connection obtained via `get_connection`.
    /// Every chunk of the stream is fed to the copy sink, the sink is closed,
    /// the data's cleanup hook is run, and only then is the connection
    /// returned to the pool.
    ///
    /// # Errors
    ///
    /// Returns the first error from acquiring the connection, starting the
    /// copy, reading the stream, writing to the sink, closing the sink, or
    /// running the cleanup. On any of these errors the connection is not
    /// returned to the pool.
    async fn apply_data<S: Stream<Item = crate::Result<Bytes>> + Send, C: AsyncCleanup>(
        &mut self,
        schema: &PostgresSchema,
        table: &PostgresTable,
        data: TableData<S, C>,
    ) -> crate::Result<()> {
        let data_format = data.data_format;
        let copy_statement =
            table.get_copy_in_command(schema, &data_format, &self.identifier_quoter);

        let connection = self.get_connection().await?;
        let sink = connection.copy_in::<Bytes>(&copy_statement).await?;
        pin_mut!(sink);

        let stream = data.data;
        pin_mut!(stream);

        while let Some(item) = stream.next().await {
            let item = item?;
            sink.feed(item).await?;
        }
        sink.close().await?;

        data.cleanup.cleanup().await?;
        self.release_connection(connection).await;
        Ok(())
    }

    /// Executes `statement` on the main connection.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the main connection.
    #[instrument(skip(self))]
    async fn apply_transactional_statement(&mut self, statement: &str) -> crate::Result<()> {
        info!("Executing transactional statement");
        self.main_connection.execute_non_query(statement).await?;
        info!("Executed transactional statement");
        Ok(())
    }

    /// Executes `statement` on a worker connection rather than the main one.
    ///
    /// The statement text is recorded in the in-flight set for the duration
    /// of the call so that a failure can be logged together with the
    /// statements that were in flight at the time. The entry is removed again
    /// on every exit path, successful or not.
    ///
    /// # Errors
    ///
    /// Returns the error from acquiring a connection or from executing the
    /// statement. When execution fails, the connection is not returned to the
    /// pool.
    #[instrument(skip(self))]
    async fn apply_non_transactional_statement(&mut self, statement: &str) -> crate::Result<()> {
        let in_flight_when_started = {
            let mut in_flight_statements = self.in_flight_statements.lock().await;
            in_flight_statements.insert(statement.to_string());
            in_flight_statements.iter().cloned().collect_vec()
        };
        info!("Executing non-transactional statement");

        let connection = match self.get_connection().await {
            Ok(connection) => connection,
            Err(e) => {
                // The statement never started; do not leave it marked as in flight.
                self.in_flight_statements.lock().await.remove(statement);
                return Err(e);
            }
        };

        let result = connection.execute_non_query(statement).await;
        {
            let mut in_flight_statements = self.in_flight_statements.lock().await;
            if let Err(e) = result {
                // Log first so the failing statement is still part of the reported set.
                error!(
                    "Error occurred. In flight statements: {:?}. In flight when started: {:?}",
                    in_flight_statements, in_flight_when_started
                );
                // A failed statement is no longer in flight; without this removal the
                // entry would show up in every later error report.
                in_flight_statements.remove(statement);
                return Err(e);
            }
            in_flight_statements.remove(statement);
        }

        self.release_connection(connection).await;
        info!("Executed non-transactional statement");
        Ok(())
    }

    /// Opens a serializable read-write transaction on the main connection.
    #[instrument(skip(self))]
    async fn begin_transaction(&mut self) -> crate::Result<()> {
        self.main_connection
            .execute_non_query("begin transaction isolation level serializable read write;")
            .await?;
        Ok(())
    }

    /// Issues `commit;` on the main connection.
    #[instrument(skip(self))]
    async fn commit_transaction(&mut self) -> crate::Result<()> {
        self.main_connection.execute_non_query("commit;").await?;
        Ok(())
    }

    /// Returns a shared handle to the identifier quoter used by this destination.
    fn get_identifier_quoter(&self) -> Arc<IdentifierQuoter> {
        self.identifier_quoter.clone()
    }

    /// Introspects the database through the main connection.
    ///
    /// Always yields `Some(..)` on success; errors from the schema reader are
    /// passed through unchanged.
    async fn try_introspect(&self) -> crate::Result<Option<PostgresDatabase>> {
        let reader = SchemaReader::new(self.main_connection);
        reader.introspect_database().await.map(Some)
    }

    /// Reports whether `table` in `schema` contains at least one row.
    ///
    /// Both names are quoted with the identifier quoter before being placed
    /// into a `select exists(...)` query that runs on the main connection.
    async fn has_data_in_table(
        &self,
        schema: &PostgresSchema,
        table: &PostgresTable,
    ) -> crate::Result<bool> {
        let schema_name = schema.name.quote(
            &self.identifier_quoter,
            AttemptedKeywordUsage::TypeOrFunctionName,
        );
        let table_name = table.name.quote(
            &self.identifier_quoter,
            AttemptedKeywordUsage::TypeOrFunctionName,
        );
        let query = format!(
            "select exists(select 1 from {}.{} limit 1);",
            schema_name, table_name
        );
        let result = self
            .main_connection
            .get_single_result::<bool>(&query)
            .await?;
        Ok(result)
    }
}