use anyhow::Result;
use async_trait::async_trait;
use deadpool_postgres::tokio_postgres::{self, binary_copy::BinaryCopyInWriter};
use futures::{StreamExt, stream};
use tracing::{error, trace};
use super::traits::{PostgreSQL, SqlTypes};
/// Bulk-loading helpers for PostgreSQL connections.
///
/// This file implements the trait for `tokio_postgres::Client` and for
/// `deadpool_postgres::Pool`.
#[async_trait]
pub trait PgLoadExt {
/// Runs the SQL statement `stmt` for the items of `collection`.
///
/// The `tokio_postgres::Client` implementation in this file executes the
/// statement once per item, passing `item.sql_map()` as the parameters.
/// NOTE(review): `sql_map` comes from `super::traits`; its exact return
/// type is not visible here.
async fn insert_iter<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
where
I: Iterator<Item = T> + Send + Sync,
T: PostgreSQL + Send + Sync;
/// Streams the items of `collection` through the copy-in statement `stmt`.
///
/// The `tokio_postgres::Client` implementation in this file feeds the items
/// to a `BinaryCopyInWriter` built with `T::sql_types()` (presumably
/// provided by `SqlTypes` — confirm in `super::traits`).
async fn copy<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
where
I: Iterator<Item = T> + Send + Sync,
T: SqlTypes + PostgreSQL + Send + Sync;
}
#[async_trait]
impl PgLoadExt for tokio_postgres::Client {
async fn insert_iter<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
where
I: Iterator<Item = T> + Send + Sync,
T: PostgreSQL + Send + Sync,
{
let stmt = self.prepare(stmt).await?;
let tx = self.transaction().await?;
let mut stream = stream::iter(collection.into_iter());
while let Some(item) = stream.next().await {
let stmt = &stmt;
let tx = &tx;
tx.execute(stmt, &item.sql_map()).await?;
}
trace!("{stmt:?} executed successfully");
tx.commit().await?;
Ok(())
}
async fn copy<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
where
I: Iterator<Item = T> + Send + Sync,
T: SqlTypes + PostgreSQL + Send + Sync,
{
let tx = self.transaction().await?;
let sink = tx.copy_in(stmt).await?;
let writer = BinaryCopyInWriter::new(sink, T::sql_types());
futures::pin_mut!(writer);
for item in collection {
match writer.as_mut().write(&item.sql_map()).await {
Ok(_) => {}
Err(e) => error!("Failed to copy {stmt:#?}: {e})"),
}
}
trace!("{stmt:?} executed successfully");
writer.finish().await?;
tx.commit().await?;
Ok(())
}
}
/// Pool-backed loader: each call checks one connection out of the pool and
/// forwards the work to that connection's own `PgLoadExt` implementation.
#[async_trait]
impl PgLoadExt for deadpool_postgres::Pool {
    /// Obtains a pooled connection and delegates to its `insert_iter`.
    ///
    /// Fails if the pool cannot provide a connection or if the delegated call
    /// returns an error.
    async fn insert_iter<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
    where
        I: Iterator<Item = T> + Send + Sync,
        T: PostgreSQL + Send + Sync,
    {
        let mut conn = self.get().await?;
        conn.insert_iter(stmt, collection).await
    }

    /// Obtains a pooled connection and delegates to its `copy`.
    ///
    /// Fails if the pool cannot provide a connection or if the delegated call
    /// returns an error.
    async fn copy<'a, I, T>(&mut self, stmt: &'a str, collection: I) -> Result<()>
    where
        I: Iterator<Item = T> + Send + Sync,
        T: SqlTypes + PostgreSQL + Send + Sync,
    {
        let mut conn = self.get().await?;
        conn.copy(stmt, collection).await
    }
}