use futures::pin_mut;
use itertools::Itertools;
use std::{collections::HashSet, io::prelude::*, str};
use super::{csv_to_binary::copy_csv_to_pg_binary, Client, PostgresLocator};
use crate::drivers::postgres_shared::{
connect, CheckCatalog, Ident, PgCreateTable, PgSchema,
};
use crate::tokio_glue::try_forward;
use crate::transform::spawn_sync_transform;
use crate::{common::*, drivers::postgres_shared::PgCreateType};
/// Drop `table` from the destination database if it already exists.
#[instrument(level = "trace", skip(client, table), fields(table.name = ?table.name))]
async fn drop_table_if_exists(
    client: &mut Client,
    table: &PgCreateTable,
) -> Result<()> {
    debug!("deleting table {} if exists", table.name.quoted());
    let sql = format!("DROP TABLE IF EXISTS {}", &table.name.quoted());
    let stmt = client.prepare(&sql).await?;
    client
        .execute(&stmt, &[])
        .await
        .with_context(|| format!("error deleting existing {}", table.name.quoted()))?;
    Ok(())
}
/// Make sure every named type used by `schema`'s table exists in the
/// destination database, creating any that are missing. Types that already
/// exist are assumed compatible.
#[instrument(level = "trace", skip(client))]
async fn prepare_types(client: &mut Client, schema: &PgSchema) -> Result<()> {
    let required = schema.table()?.named_type_names();
    for ty in &schema.types {
        // Skip any declared type that the table doesn't actually reference.
        if !required.contains(&ty.name) {
            continue;
        }
        if PgCreateType::from_database(client, &ty.name).await?.is_some() {
            debug!(
                "assuming existing {} type in destination is compatible",
                ty.name.quoted()
            );
        } else {
            let create_sql = ty.to_string();
            debug!("creating type: {}", create_sql);
            let stmt = client.prepare(&create_sql).await?;
            client.execute(&stmt, &[]).await?;
        }
    }
    Ok(())
}
/// Create the table described by `schema` in the destination database,
/// first ensuring any named types it depends on exist.
#[instrument(level = "trace", skip(client, schema))]
async fn create_table(client: &mut Client, schema: &PgSchema) -> Result<()> {
    prepare_types(client, schema).await?;
    let table = schema.table()?;
    debug!("create table {}", table.name.quoted());
    // `PgCreateTable`'s `Display` impl renders the CREATE TABLE statement.
    let sql = table.to_string();
    debug!("CREATE TABLE SQL: {}", sql);
    let stmt = client.prepare(&sql).await?;
    client
        .execute(&stmt, &[])
        .await
        .with_context(|| format!("error creating {}", &table.name.quoted()))?;
    Ok(())
}
/// Create a temporary table mirroring `schema`'s table, and return the
/// `PgCreateTable` describing the freshly created temp table.
pub(crate) async fn create_temp_table_for(
    client: &mut Client,
    schema: &PgSchema,
) -> Result<PgCreateTable> {
    let table = schema.table()?;
    // Clone the destination table definition and rewrite it as a temp table.
    let mut temp_table = table.to_owned();
    temp_table.name = table.name.temporary_table_name()?;
    temp_table.if_not_exists = false;
    temp_table.temporary = true;
    // Keep the rest of the schema (e.g. named types) intact.
    let temp_schema = PgSchema {
        tables: vec![temp_table],
        ..schema.to_owned()
    };
    create_table(client, &temp_schema).await?;
    Ok(temp_schema.table()?.to_owned())
}
/// Prepare the destination table according to the `if_exists` policy, then
/// create it (honoring `IF NOT EXISTS` where appropriate).
#[instrument(level = "trace", skip(client))]
pub(crate) async fn prepare_table(
    client: &mut Client,
    mut schema: PgSchema,
    if_exists: &IfExists,
) -> Result<()> {
    let table = schema.table_mut()?;
    match if_exists {
        IfExists::Overwrite => {
            // Remove any previous copy before recreating from scratch.
            drop_table_if_exists(client, table).await?;
            table.if_not_exists = false;
        }
        // Appending and upserting both tolerate a pre-existing table.
        IfExists::Append | IfExists::Upsert(_) => {
            table.if_not_exists = true;
        }
        // Let `CREATE TABLE` fail loudly if the table already exists.
        IfExists::Error => {
            table.if_not_exists = false;
        }
    }
    create_table(client, &schema).await
}
/// Generate a `COPY ... FROM STDIN` statement naming every column of
/// `table`, using the given PostgreSQL `data_format` (e.g. `"BINARY"`).
fn copy_from_sql(table: &PgCreateTable, data_format: &str) -> Result<String> {
    let mut copy_sql_buff = vec![];
    writeln!(&mut copy_sql_buff, "COPY {} (", table.name.quoted())?;
    for (idx, col) in table.columns.iter().enumerate() {
        // No trailing comma after the final column name.
        if idx + 1 == table.columns.len() {
            writeln!(&mut copy_sql_buff, " {}", Ident(&col.name))?;
        } else {
            writeln!(&mut copy_sql_buff, " {},", Ident(&col.name))?;
        }
    }
    writeln!(&mut copy_sql_buff, ") FROM STDIN WITH {}", data_format)?;
    // BUG FIX: `©_sql_buff` was an HTML-entity mangling of
    // `&copy_sql_buff` (`&copy;` → `©`); restored the reference so the
    // generated bytes round-trip back to a `String`.
    let copy_sql = str::from_utf8(&copy_sql_buff)
        .expect("generated SQL should always be UTF-8")
        .to_owned();
    Ok(copy_sql)
}
/// Stream PostgreSQL `BINARY`-format data from `stream` into the table
/// `dest` using `COPY ... FROM STDIN`.
#[instrument(level = "trace", skip_all, fields(dest.name = ?dest.name))]
async fn copy_from_stream<'a>(
    client: &'a mut Client,
    dest: &'a PgCreateTable,
    stream: BoxStream<BytesMut>,
) -> Result<()> {
    debug!("copying data into {:?}", dest.name);
    let copy_from_sql = copy_from_sql(dest, "BINARY")?;
    // BUG FIX: `©_from_sql` was an HTML-entity mangling of
    // `&copy_from_sql` (`&copy;` → `©`); restored the reference.
    let stmt = client.prepare(&copy_from_sql).await?;
    let sink = client
        .copy_in::<_, BytesMut>(&stmt)
        .await
        .with_context(|| format!("error copying data into {}", dest.name.quoted()))?;
    // `try_forward` needs a pinned sink.
    pin_mut!(sink);
    try_forward(stream, sink).await?;
    Ok(())
}
/// Return the columns of `dest_table` that an upsert should overwrite, i.e.
/// every column that is not one of `upsert_keys`. Fails if any upsert key
/// column is declared nullable.
pub(crate) fn columns_to_update_for_upsert<'a>(
    dest_table: &'a PgCreateTable,
    upsert_keys: &[String],
) -> Result<Vec<&'a str>> {
    let key_set = upsert_keys
        .iter()
        .map(String::as_str)
        .collect::<HashSet<&str>>();
    let mut value_columns = vec![];
    for col in &dest_table.columns {
        if !key_set.contains(&col.name[..]) {
            // Non-key column: it gets updated on conflict.
            value_columns.push(&col.name[..]);
        } else if col.is_nullable {
            // Upsert key columns are required to be NOT NULL.
            return Err(format_err!(
                "cannot upsert on column {} because it isn't declared NOT NULL",
                Ident(&col.name),
            ));
        }
    }
    Ok(value_columns)
}
/// Build the `INSERT ... ON CONFLICT ... DO UPDATE` statement that merges
/// rows from `src_table` into `dest_table`, matching on `upsert_keys`.
///
/// Non-key columns are overwritten from the incoming row via
/// `EXCLUDED.<col>`. Errors if any upsert key column is nullable
/// (enforced by `columns_to_update_for_upsert`).
fn upsert_sql(
    src_table: &PgCreateTable,
    dest_table: &PgCreateTable,
    upsert_keys: &[String],
) -> Result<String> {
    // The non-key columns to overwrite on conflict.
    let value_keys = columns_to_update_for_upsert(dest_table, upsert_keys)?;
    Ok(format!(
        r#"
INSERT INTO {dest_table} ({all_columns}) (
SELECT {all_columns} FROM {src_table}
)
ON CONFLICT ({key_columns})
DO UPDATE SET
{value_updates}
"#,
        dest_table = dest_table.name.quoted(),
        src_table = src_table.name.quoted(),
        all_columns = dest_table.columns.iter().map(|c| Ident(&c.name)).join(", "),
        key_columns = upsert_keys.iter().map(|k| Ident(k)).join(", "),
        value_updates = value_keys
            .iter()
            .map(|vk| format!("{name} = EXCLUDED.{name}", name = Ident(vk)))
            .join(",\n "),
    ))
}
/// Execute an upsert that merges all rows of `src_table` into `dest_table`,
/// matching rows on `upsert_keys`.
#[instrument(
    level = "trace",
    skip(client, src_table, dest_table),
    fields(src_table.name = ?src_table.name, dest_table.name = ?dest_table.name),
)]
pub(crate) async fn upsert_from(
    client: &mut Client,
    src_table: &PgCreateTable,
    dest_table: &PgCreateTable,
    upsert_keys: &[String],
) -> Result<()> {
    let sql = upsert_sql(src_table, dest_table, upsert_keys)?;
    debug!(
        "upserting from {} to {} with {}",
        src_table.name.quoted(),
        dest_table.name.quoted(),
        sql,
    );
    let stmt = client.prepare(&sql).await?;
    let exec_result = client.execute(&stmt, &[]).await;
    exec_result.with_context(|| {
        format!(
            "error upserting from {} to {}",
            src_table.name.quoted(),
            dest_table.name.quoted(),
        )
    })?;
    Ok(())
}
/// Write a stream of CSV streams to a PostgreSQL table.
///
/// Verifies `shared_args`/`dest_args`, determines the destination schema,
/// prepares the destination table according to `if_exists`, and returns a
/// one-element stream containing a boxed future that drains `data`, loading
/// each `CsvStream` via `load_stream`.
#[instrument(
    level = "debug",
    name = "postgres::write_local_data",
    skip_all,
    fields(dest = %dest)
)]
pub(crate) async fn write_local_data_helper(
    ctx: Context,
    dest: PostgresLocator,
    mut data: BoxStream<CsvStream>,
    shared_args: SharedArguments<Unverified>,
    dest_args: DestinationArguments<Unverified>,
) -> Result<BoxStream<BoxFuture<BoxLocator>>> {
    // Validate the arguments against the features this driver supports.
    let shared_args = shared_args.verify(PostgresLocator::features())?;
    let dest_args = dest_args.verify(PostgresLocator::features())?;
    let schema = shared_args.schema();
    let if_exists = dest_args.if_exists().to_owned();
    let url = dest.url.clone();
    let table_name = dest.table_name.clone();
    debug!(
        "writing data streams to {} table {}",
        url,
        table_name.quoted(),
    );
    // Pick the schema to use — consults the destination's pg_catalog when
    // `if_exists` calls for it (see `from_pg_catalog_or_default`), otherwise
    // presumably falls back to the supplied `schema`.
    let dest_schema = PgSchema::from_pg_catalog_or_default(
        &ctx,
        CheckCatalog::from(&if_exists),
        dest.url(),
        dest.table_name(),
        schema,
    )
    .await?;
    let mut client = connect(&ctx, &url).await?;
    prepare_table(&mut client, dest_schema.clone(), &if_exists).await?;
    // The actual copying happens lazily inside this future, which takes
    // ownership of `ctx`, `client`, `data`, and the verified arguments.
    let fut = async move {
        while let Some(result) = data.next().await {
            match result {
                Err(err) => {
                    // A failure in the stream-of-streams aborts the whole write.
                    debug!("error reading stream of streams: {}", err);
                    return Err(err);
                }
                Ok(csv_stream) => {
                    load_stream(
                        &ctx,
                        &mut client,
                        csv_stream,
                        &dest_schema,
                        &if_exists,
                    )
                    .await?;
                }
            }
        }
        // On success, yield the destination locator back to the caller.
        Ok(dest.boxed())
    };
    Ok(box_stream_once(Ok(fut.boxed())))
}
/// Convert one incoming CSV stream to PostgreSQL `BINARY` format and load it
/// into the destination table — directly, or via a temporary staging table
/// when upserting.
#[instrument(level = "debug", skip_all, fields(stream = %csv_stream.name))]
async fn load_stream(
    ctx: &Context,
    client: &mut Client,
    csv_stream: CsvStream,
    dest_schema: &PgSchema,
    if_exists: &IfExists,
) -> Result<()> {
    // Transcode CSV -> PostgreSQL BINARY via a synchronous transform.
    let schema_for_transform = dest_schema.clone();
    let binary_stream = spawn_sync_transform(
        ctx.clone(),
        "copy_csv_to_pg_binary".to_owned(),
        csv_stream.data,
        move |_ctx, rdr, wtr| copy_csv_to_pg_binary(&schema_for_transform, rdr, wtr),
    )?;
    match if_exists {
        IfExists::Upsert(cols) => {
            // Stage into a temp table, merge it into the destination,
            // then clean up the staging table.
            let temp_table = create_temp_table_for(client, dest_schema).await?;
            copy_from_stream(client, &temp_table, binary_stream).await?;
            upsert_from(client, &temp_table, dest_schema.table()?, cols).await?;
            drop_table_if_exists(client, &temp_table).await?;
        }
        _ => {
            // Plain load straight into the destination table.
            copy_from_stream(client, dest_schema.table()?, binary_stream).await?;
        }
    }
    Ok(())
}