/// Expands to a table name derived from the name of the enclosing function.
///
/// The name produced by `function_name!()` has its leading `integration::` prefix removed,
/// every `::` separator replaced by `_`, and any `_{{closure}}` segment (presumably present
/// when invoked from within a closure) removed. The expansion is an owned `String`.
///
/// # Panics
///
/// Panics if the function name does not start with `integration::`.
macro_rules! table_name {
    () => {{
        let qualified = function_name!();
        let within_crate = qualified.strip_prefix("integration::").unwrap();
        let flattened = within_crate.replace("::", "_");
        flattened.replace("_{{closure}}", "")
    }};
}
use std::iter::repeat;
use odbc_api::{
Connection, ConnectionOptions, Cursor, Error, RowSetBuffer, TruncationInfo, buffers,
environment,
handles::{CDataMut, Statement, StatementRef},
};
use super::connection_strings::{
DUCKDB_CONNECTION, MARIADB_CONNECTION, MSSQL_CONNECTION, POSTGRES_CONNECTION,
SQLITE_3_CONNECTION,
};
/// Profile for Microsoft SQL Server.
pub const MSSQL: &Profile = &Profile {
    blob_type: "Image",
    index_type: "int IDENTITY(1,1)",
    connection_string: MSSQL_CONNECTION,
};

/// Profile for SQLite 3.
///
/// NOTE(review): reuses the `int IDENTITY(1,1)` id type of the MSSQL profile; confirm this is
/// intended for SQLite.
pub const SQLITE_3: &Profile = &Profile {
    blob_type: "BLOB",
    index_type: "int IDENTITY(1,1)",
    connection_string: SQLITE_3_CONNECTION,
};

/// Profile for MariaDB.
pub const MARIADB: &Profile = &Profile {
    blob_type: "BLOB",
    index_type: "INTEGER AUTO_INCREMENT PRIMARY KEY",
    connection_string: MARIADB_CONNECTION,
};

/// Profile for PostgreSQL.
pub const POSTGRES: &Profile = &Profile {
    blob_type: "BYTEA",
    index_type: "SERIAL PRIMARY KEY",
    connection_string: POSTGRES_CONNECTION,
};

/// Profile for DuckDB.
///
/// `index_type` is empty, so the `id` column is emitted without any type or constraint.
pub const DUCKDB: &Profile = &Profile {
    blob_type: "BLOB",
    index_type: "",
    connection_string: DUCKDB_CONNECTION,
};
/// Builder describing a test table together with the values it should be populated with.
///
/// Call [`Given::build`] to (re)create the table for a [`Profile`] and insert the values.
pub struct Given<'a> {
    /// Name of the table to create.
    table_name: &'a str,
    /// SQL type of each data column, in order.
    column_types: &'a [&'a str],
    /// Names of the data columns; only the first `column_types.len()` entries are used.
    column_names: &'a [&'a str],
    /// Values to insert, one slice per column. `None` cells are inserted as NULL.
    values: &'a [&'a [Option<&'a str>]],
}

impl<'a> Given<'a> {
    /// Starts describing a table named `table_name` without columns or values.
    ///
    /// Column names default to `a`, `b`, …, `l`.
    pub fn new(table_name: &'a str) -> Self {
        Given {
            table_name,
            column_types: &[],
            column_names: &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"],
            values: &[],
        }
    }

    /// Sets the SQL type of each data column. The number of types is the number of data
    /// columns created.
    pub fn column_types(&mut self, column_types: &'a [&'a str]) -> &mut Self {
        self.column_types = column_types;
        self
    }

    /// Overrides the default column names.
    pub fn column_names(&mut self, column_names: &'a [&'a str]) -> &mut Self {
        self.column_names = column_names;
        self
    }

    /// Sets the values to insert, given column by column. Every column must contain the same
    /// number of rows.
    pub fn values_by_column(&mut self, values: &'a [&'a [Option<&'a str>]]) -> &mut Self {
        self.values = values;
        self
    }

    /// Creates the table using `profile` and inserts the configured values.
    ///
    /// Returns the connection used to create the table together with a [`Table`] describing it.
    /// If no values (or only empty columns) were configured, the table is left empty.
    ///
    /// # Errors
    ///
    /// Returns any error reported while connecting, (re)creating the table or inserting rows.
    ///
    /// # Panics
    ///
    /// Panics if the columns passed to [`Given::values_by_column`] differ in length.
    pub fn build(
        &self,
        profile: &Profile,
    ) -> Result<(Connection<'static>, Table<'a>), odbc_api::Error> {
        let (conn, table) =
            profile.create_table(self.table_name, self.column_types, self.column_names)?;
        let num_rows = self.values.first().map_or(0, |column| column.len());
        // Ragged input would otherwise end in an opaque index-out-of-bounds panic (or silently
        // drop the surplus values of longer columns).
        assert!(
            self.values.iter().all(|column| column.len() == num_rows),
            "All columns passed to `values_by_column` must contain the same number of rows."
        );
        // With zero rows there is nothing to insert. The inserter lives inside this block so its
        // borrow of `conn` ends before `conn` is returned.
        if num_rows != 0 {
            // Capacity of each text column: the byte length of its longest value (NULL counts
            // as zero).
            let max_str_len = self.values.iter().map(|column| {
                column
                    .iter()
                    .map(|cell| cell.map_or(0, str::len))
                    .max()
                    .unwrap_or(0)
            });
            let mut inserter = conn
                .prepare(&table.sql_insert())?
                .into_text_inserter(num_rows, max_str_len)?;
            inserter.set_num_rows(num_rows);
            for (col_index, column) in self.values.iter().enumerate() {
                for (row_index, cell) in column.iter().enumerate() {
                    inserter
                        .column_mut(col_index)
                        .set_cell(row_index, cell.map(str::as_bytes));
                }
            }
            inserter.execute()?;
        }
        Ok((conn, table))
    }
}
/// Connection details and database specific SQL type names for one tested database.
#[derive(Clone, Copy, Debug)]
pub struct Profile {
    /// ODBC connection string used by [`Profile::connection`].
    pub connection_string: &'static str,
    /// Type (and constraints) of the `id` column added to every table created via this profile.
    pub index_type: &'static str,
    /// SQL type name for binary large object columns (e.g. `BLOB`, `BYTEA`).
    pub blob_type: &'static str,
}

impl Profile {
    /// Column names used by [`Profile::setup_empty_table`].
    const DEFAULT_COLUMN_NAMES: &'static [&'static str] =
        &["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"];

    /// Opens a new connection using `connection_string` and default connection options.
    ///
    /// # Errors
    ///
    /// Fails if the ODBC environment is unavailable or the connection cannot be established.
    pub fn connection(&self) -> Result<Connection<'static>, Error> {
        let env = environment()?;
        env.connect_with_connection_string(self.connection_string, ConnectionOptions::default())
    }

    /// (Re)creates `table_name` with the given column types and the default column names
    /// `a`, `b`, …, returning the connection used.
    ///
    /// # Errors
    ///
    /// See [`Profile::create_table`].
    pub fn setup_empty_table(
        &self,
        table_name: &str,
        column_types: &[&str],
    ) -> Result<Connection<'static>, odbc_api::Error> {
        self.create_table(table_name, column_types, Self::DEFAULT_COLUMN_NAMES)
            .map(|(conn, _table)| conn)
    }

    /// Opens a connection, drops `table_name` if it exists and creates it anew with an `id`
    /// column of type `index_type` followed by the given data columns.
    ///
    /// # Errors
    ///
    /// Fails if connecting, dropping or creating the table fails.
    pub fn create_table<'a>(
        &self,
        table_name: &'a str,
        column_types: &'a [&'a str],
        column_names: &'a [&'a str],
    ) -> Result<(Connection<'static>, Table<'a>), odbc_api::Error> {
        let conn = self.connection()?;
        let table = Table::new(table_name, column_types, column_names);
        let statements = [
            table.sql_drop_if_exists(),
            table.sql_create_table(self.index_type),
        ];
        for statement in statements {
            conn.execute(&statement, (), None)?;
        }
        Ok((conn, table))
    }
}
/// A test table: its name plus the names and SQL types of its data columns.
///
/// Tables created from [`Table::sql_create_table`] also have an `id` column, which
/// [`Table::sql_all_ordered_by_id`] uses to order the rows it reads back.
pub struct Table<'a> {
    /// Name of the table.
    pub name: &'a str,
    /// SQL type of each data column.
    pub column_types: &'a [&'a str],
    /// Name of each data column, in the same order as `column_types`.
    pub column_names: &'a [&'a str],
}

impl<'a> Table<'a> {
    /// Describes a table with one data column per entry of `column_types`.
    ///
    /// Only the first `column_types.len()` entries of `column_names` are kept.
    ///
    /// # Panics
    ///
    /// Panics if `column_names` has fewer entries than `column_types`.
    pub fn new(name: &'a str, column_types: &'a [&'a str], column_names: &'a [&'a str]) -> Self {
        let num_columns = column_types.len();
        let column_names = &column_names[..num_columns];
        Table {
            name,
            column_types,
            column_names,
        }
    }

    /// SQL statement dropping the table if it exists.
    pub fn sql_drop_if_exists(&self) -> String {
        let name = self.name;
        format!("DROP TABLE IF EXISTS {name};")
    }

    /// SQL statement creating the table: an `id` column of type `index_type`, followed by the
    /// data columns.
    pub fn sql_create_table(&self, index_type: &str) -> String {
        let name = self.name;
        let definitions: Vec<String> = self
            .column_names
            .iter()
            .zip(self.column_types)
            .map(|(column, ty)| format!("{column} {ty}"))
            .collect();
        let definitions = definitions.join(", ");
        format!("CREATE TABLE {name} (id {index_type},{definitions});")
    }

    /// SQL query selecting all data columns, ordered by `id`.
    pub fn sql_all_ordered_by_id(&self) -> String {
        let name = self.name;
        let columns = self.column_names.join(",");
        format!("SELECT {columns} FROM {name} ORDER BY Id;")
    }

    /// SQL statement inserting one row, with a `?` placeholder per data column.
    pub fn sql_insert(&self) -> String {
        let name = self.name;
        let columns = self.column_names.join(",");
        let placeholders: Vec<&str> = repeat("?").take(self.column_names.len()).collect();
        let placeholders = placeholders.join(",");
        format!("INSERT INTO {name} ({columns}) VALUES ({placeholders});")
    }

    /// Reads all rows ordered by `id` and renders them with [`cursor_to_string`].
    ///
    /// # Panics
    ///
    /// Panics if the query fails or does not produce a result set.
    pub fn content_as_string(&self, conn: &Connection<'_>) -> String {
        let query = self.sql_all_ordered_by_id();
        let maybe_cursor = conn.execute(&query, (), None).unwrap();
        let cursor = maybe_cursor.unwrap();
        cursor_to_string(cursor)
    }
}
/// Renders every row produced by `cursor` as text.
///
/// Cells are separated by `,`, rows by `\n`, and NULL cells are written as `NULL`. Rows are
/// fetched in batches of 20 into a `TextRowSet` created with a `max_str_limit` of `Some(8192)`.
///
/// # Panics
///
/// Panics on any error while binding or fetching, or if a cell cannot be read as a `str`.
pub fn cursor_to_string(mut cursor: impl Cursor) -> String {
    const BATCH_SIZE: usize = 20;
    let mut row_set_buffer =
        buffers::TextRowSet::for_cursor(BATCH_SIZE, &mut cursor, Some(8192)).unwrap();
    let mut block_cursor = cursor.bind_buffer(&mut row_set_buffer).unwrap();
    let mut out = String::new();
    let mut batches_seen = 0usize;
    while let Some(batch) = block_cursor.fetch().unwrap() {
        // Separate the last row of the previous batch from the first row of this one.
        if batches_seen > 0 {
            out.push('\n');
        }
        batches_seen += 1;
        for row in 0..batch.num_rows() {
            if row > 0 {
                out.push('\n');
            }
            for col in 0..batch.num_cols() {
                if col > 0 {
                    out.push(',');
                }
                let cell = batch.at_as_str(col, row).unwrap().unwrap_or("NULL");
                out.push_str(cell);
            }
        }
    }
    out
}
/// Minimal [`RowSetBuffer`] binding a single buffer to the first column of a result set.
pub struct SingleColumnRowSetBuffer<C> {
    /// Number of rows reported as fetched by the last fetch. Boxed so its address stays the same
    /// even if the buffer value is moved. NOTE(review): presumably because a pointer to it is
    /// handed to the driver; confirm.
    num_rows_fetched: Box<usize>,
    /// Maximum number of rows fetched per batch.
    batch_size: usize,
    /// Buffer bound to column 1.
    column: C,
}

impl<T> SingleColumnRowSetBuffer<Vec<T>>
where
    T: Clone + Default,
{
    /// Allocates room for `batch_size` rows, each initialized with `T::default()`.
    pub fn new(batch_size: usize) -> Self {
        let num_rows_fetched = Box::new(0_usize);
        let column = vec![T::default(); batch_size];
        SingleColumnRowSetBuffer {
            num_rows_fetched,
            batch_size,
            column,
        }
    }

    /// Returns the first `num_rows_fetched` elements of the buffer, i.e. the rows reported by
    /// the most recent fetch.
    ///
    /// # Panics
    ///
    /// Panics if the reported row count exceeds the buffer length.
    pub fn get(&self) -> &[T] {
        let fetched = *self.num_rows_fetched;
        &self.column[..fetched]
    }
}

// SAFETY: NOTE(review): `column` is bound once in `bind_colmuns_to_cursor` and never resized by
// this type, and `num_rows_fetched` lives on the heap, so their addresses stay valid while bound.
// Confirm against the `RowSetBuffer` safety contract.
unsafe impl<C> RowSetBuffer for SingleColumnRowSetBuffer<C>
where
    C: CDataMut,
{
    /// `0` requests column-wise binding (ODBC `SQL_BIND_BY_COLUMN`).
    fn bind_type(&self) -> usize {
        0
    }

    /// Maximum number of rows per fetch.
    fn row_array_size(&self) -> usize {
        self.batch_size
    }

    /// Location the fetched row count is written to.
    fn mut_num_fetch_rows(&mut self) -> &mut usize {
        &mut *self.num_rows_fetched
    }

    /// Binds `column` to column number 1 of `cursor`.
    unsafe fn bind_colmuns_to_cursor(
        &mut self,
        mut cursor: StatementRef<'_>,
    ) -> Result<(), odbc_api::Error> {
        // SAFETY: `self.column` is owned by this buffer, which the caller keeps alive and
        // unmoved for as long as it is bound to `cursor` (per the trait's contract).
        let outcome = unsafe { cursor.bind_col(1, &mut self.column) };
        outcome.into_result(&cursor)?;
        Ok(())
    }

    /// Truncation detection is not supported by this buffer; calling this panics.
    fn find_truncation(&self) -> Option<TruncationInfo> {
        unimplemented!()
    }
}