use super::PostgresLocator;
use crate::common::*;
use crate::drivers::postgres_shared::{connect, CheckCatalog, PgSchema};
/// Count the rows that a Postgres source would produce.
///
/// Verifies `shared_args` and `source_args` against the features supported by
/// [`PostgresLocator`]. It then builds a `PgSchema` for the locator's table,
/// passing `CheckCatalog::No`, and asks that schema to generate the count SQL.
/// Finally it runs the query and returns the single `count` value it yields.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - argument verification, schema construction, SQL generation, connecting,
///   or preparing/running the query fails;
/// - the query does not return exactly one row;
/// - that row has no `count` column readable as an `i64`;
/// - the count does not fit in a `usize` (for example, it is negative).
#[instrument(
    level = "trace",
    name = "postgres::count",
    skip(ctx, shared_args, source_args)
)]
pub(crate) async fn count_helper(
    ctx: Context,
    locator: PostgresLocator,
    shared_args: SharedArguments<Unverified>,
    source_args: SourceArguments<Unverified>,
) -> Result<usize> {
    let shared_args = shared_args.verify(PostgresLocator::features())?;
    let source_args = source_args.verify(PostgresLocator::features())?;

    // We own `locator` and only ever need references to these fields, so
    // borrow them instead of cloning.
    let url = &locator.url;
    let table_name = &locator.table_name;
    let schema = shared_args.schema();
    let pg_schema = PgSchema::from_pg_catalog_or_default(
        &ctx,
        CheckCatalog::No,
        url,
        table_name,
        schema,
    )
    .await?;

    let mut sql_bytes: Vec<u8> = vec![];
    pg_schema.write_count_sql(&mut sql_bytes, &source_args)?;
    // Non-UTF-8 output from the SQL writer is treated as a bug, not a runtime
    // error.
    let sql = String::from_utf8(sql_bytes).expect("should always be UTF-8");
    debug!("count SQL: {}", sql);

    let conn = connect(&ctx, url).await?;
    let stmt = conn
        .prepare(&sql)
        .await
        .context("error preparing count query")?;
    let rows = conn
        .query(&stmt, &[])
        .await
        .context("error running count query")?;

    if rows.len() != 1 {
        return Err(format_err!(
            "expected 1 row of count output, got {}",
            rows.len(),
        ));
    }

    // `try_get` reports a missing or mistyped `count` column as an error.
    // `get` would panic instead.
    let count: i64 = rows[0]
        .try_get("count")
        .context("could not read count column")?;
    Ok(usize::try_from(count).context("count out of range")?)
}