use crate::config::{MsSqlSourceConfig, TableKeyConfig};
use crate::error::{MsSqlError, PrimaryKeyError};
use crate::types::{extract_column_value, value_to_string};
use anyhow::{anyhow, Result};
use drasi_core::models::ElementValue;
use log::warn;
use std::collections::HashMap;
use tiberius::{Client, Row};
use tokio::net::TcpStream;
use tokio_util::compat::Compat;
/// Cache of primary-key column names per table, used to build element IDs from
/// SQL Server rows.
///
/// Populated from the database catalog by [`PrimaryKeyCache::discover_keys`];
/// user-configured key columns (`table_keys`) are applied after discovery and so
/// take precedence.
#[derive(Debug)]
pub struct PrimaryKeyCache {
    /// Lookup name (as stored by `discover_keys` / config) -> key column names.
    keys: HashMap<String, Vec<String>>,
}
impl PrimaryKeyCache {
    /// Creates an empty cache; populate it with [`Self::discover_keys`].
    pub fn new() -> Self {
        Self {
            keys: HashMap::new(),
        }
    }

    /// Loads primary-key columns from the SQL Server catalog, then applies the
    /// overrides in `config.table_keys`.
    ///
    /// Discovered keys are stored under the schema-qualified name (`schema.table`).
    /// They are also stored under the bare table name when only one schema has a
    /// primary-keyed table of that name, so same-named tables in different schemas
    /// are never merged. Columns are listed in `key_ordinal` order.
    ///
    /// Overrides take precedence over discovered keys. A bare override (no `.`)
    /// also shadows the discovered `schema.table` entries for that table name.
    /// Qualified overrides are stored as given.
    ///
    /// The cache contents are replaced, not appended to. Calling this again (for
    /// example after a reconnect) therefore does not duplicate key columns. If the
    /// query fails, the existing contents are left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if the catalog query fails or a result row lacks one of
    /// the name columns.
    pub async fn discover_keys(
        &mut self,
        client: &mut Client<Compat<TcpStream>>,
        config: &MsSqlSourceConfig,
    ) -> Result<()> {
        // `key_ordinal > 0` keeps only key columns (0 marks a non-key column).
        // Ordering by it yields each key's columns in declaration order.
        let query = "
            SELECT
                s.name AS schema_name,
                t.name AS table_name,
                c.name AS column_name,
                ic.key_ordinal
            FROM sys.indexes i
            INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id
                AND i.index_id = ic.index_id
            INNER JOIN sys.columns c ON ic.object_id = c.object_id
                AND ic.column_id = c.column_id
            INNER JOIN sys.tables t ON i.object_id = t.object_id
            INNER JOIN sys.schemas s ON t.schema_id = s.schema_id
            WHERE i.is_primary_key = 1
                AND ic.key_ordinal > 0
            ORDER BY s.name, t.name, ic.key_ordinal
        ";
        let stream = client.query(query, &[]).await?;
        let rows = stream.into_first_result().await?;

        // Group by (schema, table) so e.g. dbo.orders and sales.orders keep
        // separate key column lists.
        let mut discovered: HashMap<(String, String), Vec<String>> = HashMap::new();
        for row in rows {
            let schema_name: &str = row.get(0).ok_or_else(|| anyhow!("Missing schema_name"))?;
            let table_name: &str = row.get(1).ok_or_else(|| anyhow!("Missing table_name"))?;
            let column_name: &str = row.get(2).ok_or_else(|| anyhow!("Missing column_name"))?;
            discovered
                .entry((schema_name.to_string(), table_name.to_string()))
                .or_default()
                .push(column_name.to_string());
        }
        let table_count = discovered.len();
        let mut keys = Self::index_discovered_keys(discovered);

        // Pass 1: bare overrides shadow discovered keys for that table name in
        // every schema. Lookups of `schema.table` used to fall back to the bare
        // name, so this keeps overrides effective for qualified lookups. At this
        // point `keys` holds only discovered entries.
        for tk in &config.table_keys {
            if tk.table.contains('.') {
                continue;
            }
            let suffix = format!(".{}", tk.table);
            for (name, columns) in keys.iter_mut() {
                if name.ends_with(suffix.as_str()) {
                    *columns = tk.key_columns.clone();
                }
            }
        }
        // Pass 2: store every override under its configured name; for duplicate
        // names the later entry wins.
        for tk in &config.table_keys {
            keys.insert(tk.table.clone(), tk.key_columns.clone());
        }
        self.keys = keys;

        log::info!("Discovered primary keys for {table_count} tables");
        for (table, keys) in &self.keys {
            log::debug!("Table '{table}' primary key: {keys:?}");
        }
        Ok(())
    }

    /// Flattens `(schema, table) -> columns` into lookup names.
    ///
    /// `schema.table` is always stored. The bare `table` is stored only when no
    /// other schema has a primary-keyed table of that name.
    fn index_discovered_keys(
        discovered: HashMap<(String, String), Vec<String>>,
    ) -> HashMap<String, Vec<String>> {
        let mut schema_counts: HashMap<String, usize> = HashMap::new();
        for (_, table) in discovered.keys() {
            *schema_counts.entry(table.clone()).or_default() += 1;
        }
        for (table, count) in &schema_counts {
            if *count > 1 {
                warn!(
                    "Table name '{table}' exists in {count} schemas; unqualified lookups \
                     for it are disabled. Use 'schema.{table}' or configure table_keys."
                );
            }
        }

        let mut keys = HashMap::with_capacity(discovered.len() * 2);
        for ((schema, table), columns) in discovered {
            if schema_counts.get(&table) == Some(&1) {
                keys.insert(table.clone(), columns.clone());
            }
            keys.insert(format!("{schema}.{table}"), columns);
        }
        keys
    }

    /// Returns the key columns for `table`.
    ///
    /// Tries the exact name first, then drops leading dot-separated qualifiers
    /// one at a time. For example, `db.schema.table` tries `schema.table`, then
    /// `table`. Returns `None` when no candidate matches.
    pub fn get(&self, table: &str) -> Option<&Vec<String>> {
        let mut candidate = table;
        loop {
            if let Some(keys) = self.keys.get(candidate) {
                return Some(keys);
            }
            // No more qualifiers to strip: give up.
            let (_, rest) = candidate.split_once('.')?;
            candidate = rest;
        }
    }

    /// Builds an element ID of the form `{table}:{v1}_{v2}...` from the row's
    /// primary-key values, in key-column order.
    ///
    /// NULL key values are skipped, with a warning, as long as at least one key
    /// column is non-NULL. Because values are joined with `_` and NULLs are
    /// dropped, distinct keys can produce the same ID. For example, `("a_b", "c")`
    /// and `("a", "b_c")` collide, and so do `(NULL, 5)` and `(5, NULL)`. The
    /// format is kept as-is because existing element IDs depend on it.
    ///
    /// # Errors
    ///
    /// - `PrimaryKeyError::NotConfigured` if no key columns are known for `table`.
    /// - `PrimaryKeyError::ColumnNotFound` if a key column is missing from `row`.
    /// - `PrimaryKeyError::AllNull` if every key column is NULL (or no key columns
    ///   are configured).
    /// - Any error returned by `extract_column_value`.
    pub fn generate_element_id(&self, table: &str, row: &Row) -> Result<String> {
        let keys = self.get(table).ok_or_else(|| {
            MsSqlError::PrimaryKey(PrimaryKeyError::NotConfigured {
                table: table.to_string(),
            })
        })?;

        let mut key_values = Vec::with_capacity(keys.len());
        let mut null_columns = Vec::new();
        for pk_col in keys {
            let col_idx = row
                .columns()
                .iter()
                .position(|c| c.name() == pk_col)
                .ok_or_else(|| {
                    MsSqlError::PrimaryKey(PrimaryKeyError::ColumnNotFound {
                        table: table.to_string(),
                        column: pk_col.clone(),
                    })
                })?;
            let value = extract_column_value(row, col_idx)?;
            if matches!(value, ElementValue::Null) {
                null_columns.push(pk_col.clone());
            } else {
                key_values.push(value_to_string(&value));
            }
        }

        if key_values.is_empty() {
            return Err(MsSqlError::PrimaryKey(PrimaryKeyError::AllNull {
                table: table.to_string(),
                columns: keys.clone(),
            })
            .into());
        }
        if !null_columns.is_empty() {
            warn!(
                "NULL value(s) in primary key column(s) {null_columns:?} for table '{table}'. \
                 Using remaining key columns for element ID."
            );
        }
        Ok(format!("{}:{}", table, key_values.join("_")))
    }
}
impl Default for PrimaryKeyCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache directly from `(table, key columns)` pairs, bypassing the database.
    fn cache_with(entries: &[(&str, &[&str])]) -> PrimaryKeyCache {
        let mut cache = PrimaryKeyCache::new();
        for (table, columns) in entries {
            cache.keys.insert(
                table.to_string(),
                columns.iter().map(|c| c.to_string()).collect(),
            );
        }
        cache
    }

    #[test]
    fn test_empty_cache() {
        let cache = PrimaryKeyCache::new();
        assert!(cache.get("orders").is_none());
    }

    #[test]
    fn test_default_is_empty() {
        let cache = PrimaryKeyCache::default();
        assert!(cache.get("orders").is_none());
        assert!(cache.get("dbo.orders").is_none());
    }

    #[test]
    fn test_insert_and_get() {
        let cache = cache_with(&[("orders", &["order_id"])]);
        assert_eq!(cache.get("orders").unwrap(), &vec!["order_id"]);
    }

    #[test]
    fn test_composite_key() {
        let cache = cache_with(&[("order_items", &["order_id", "product_id"])]);
        let keys = cache.get("order_items").unwrap();
        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0], "order_id");
        assert_eq!(keys[1], "product_id");
    }

    #[test]
    fn test_qualified_name_falls_back_to_table_name() {
        let cache = cache_with(&[("orders", &["order_id"])]);
        assert_eq!(cache.get("dbo.orders").unwrap(), &vec!["order_id"]);
    }

    #[test]
    fn test_exact_match_takes_precedence() {
        let cache = cache_with(&[("orders", &["a"]), ("dbo.orders", &["b"])]);
        assert_eq!(cache.get("dbo.orders").unwrap(), &vec!["b"]);
        assert_eq!(cache.get("orders").unwrap(), &vec!["a"]);
    }

    #[test]
    fn test_unknown_qualified_table() {
        let cache = cache_with(&[("orders", &["order_id"])]);
        assert!(cache.get("dbo.customers").is_none());
    }
}