use anyhow::{anyhow, Result};
use async_trait::async_trait;
use drasi_core::models::{
Element, ElementMetadata, ElementPropertyMap, ElementReference, ElementValue, SourceChange,
};
use drasi_lib::bootstrap::BootstrapProvider;
use drasi_lib::bootstrap::{BootstrapContext, BootstrapRequest, BootstrapResult};
use drasi_lib::channels::{BootstrapEvent, SourceChangeEvent};
use drasi_mssql_common::{
validate_sql_identifier, MsSqlConnection, MsSqlSourceConfig, PrimaryKeyCache,
};
use log::{debug, info, warn};
use ordered_float::OrderedFloat;
use std::sync::Arc;
use tiberius::Row;
/// Bootstrap provider that emits the current contents of the configured
/// MS SQL tables as insert events for a query subscription.
///
/// Construct it with [`MsSqlBootstrapProvider::new`] or, with validation,
/// through [`MsSqlBootstrapProvider::builder`].
pub struct MsSqlBootstrapProvider {
    /// Source identifier handed to the per-request handler; it ends up in every
    /// element reference and event the bootstrap produces.
    source_id: String,
    /// Connection and table settings used for each bootstrap run.
    config: MsSqlSourceConfig,
}
impl MsSqlBootstrapProvider {
    /// Creates a provider from an already-populated configuration.
    ///
    /// No validation is performed here; use [`MsSqlBootstrapProvider::builder`]
    /// when the database name and user should be checked up front.
    pub fn new(source_id: impl Into<String>, config: MsSqlSourceConfig) -> Self {
        let source_id: String = source_id.into();
        Self { config, source_id }
    }

    /// Returns a builder seeded with `MsSqlSourceConfig::default()` and the
    /// source id `"mssql-bootstrap"`.
    pub fn builder() -> MsSqlBootstrapProviderBuilder {
        MsSqlBootstrapProviderBuilder::default()
    }
}
#[async_trait]
impl BootstrapProvider for MsSqlBootstrapProvider {
    /// Runs one snapshot bootstrap for the node labels in `request`.
    ///
    /// A fresh handler (with its own configuration copy) does the work; the
    /// result reports how many events were sent, with `last_sequence` left as
    /// `None` and `sequences_aligned` as `false`.
    async fn bootstrap(
        &self,
        request: BootstrapRequest,
        context: &BootstrapContext,
        event_tx: tokio::sync::mpsc::Sender<drasi_lib::channels::BootstrapEvent>,
        _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
    ) -> Result<BootstrapResult> {
        // Capture what we need for logging before `request` is moved into the handler.
        let query_id = request.query_id.clone();
        let label_count = request.node_labels.len();
        info!(
            "Starting MS SQL bootstrap for query '{}' with {} node labels",
            query_id, label_count
        );

        let mut handler = MsSqlBootstrapHandler::new(self.config.clone(), self.source_id.clone());
        let count = handler.execute(request, context, event_tx).await?;

        info!("Completed MS SQL bootstrap for query {query_id}: sent {count} records");
        let result = BootstrapResult {
            event_count: count,
            last_sequence: None,
            sequences_aligned: false,
        };
        Ok(result)
    }
}
/// Fluent builder for [`MsSqlBootstrapProvider`].
///
/// Starts from `MsSqlSourceConfig::default()`; `build` rejects an empty
/// database name or user.
pub struct MsSqlBootstrapProviderBuilder {
    /// Source id given to the built provider.
    source_id: String,
    /// Configuration accumulated by the `with_*` setters.
    config: MsSqlSourceConfig,
}
impl MsSqlBootstrapProviderBuilder {
pub fn new() -> Self {
Self {
config: MsSqlSourceConfig::default(),
source_id: "mssql-bootstrap".to_string(),
}
}
pub fn with_source_id(mut self, id: impl Into<String>) -> Self {
self.source_id = id.into();
self
}
pub fn with_host(mut self, host: impl Into<String>) -> Self {
self.config.host = host.into();
self
}
pub fn with_port(mut self, port: u16) -> Self {
self.config.port = port;
self
}
pub fn with_database(mut self, database: impl Into<String>) -> Self {
self.config.database = database.into();
self
}
pub fn with_user(mut self, user: impl Into<String>) -> Self {
self.config.user = user.into();
self
}
pub fn with_password(mut self, password: impl Into<String>) -> Self {
self.config.password = password.into();
self
}
pub fn with_tables(mut self, tables: Vec<String>) -> Self {
self.config.tables = tables;
self
}
pub fn build(self) -> Result<MsSqlBootstrapProvider> {
if self.config.database.is_empty() {
return Err(anyhow::anyhow!("Database name is required"));
}
if self.config.user.is_empty() {
return Err(anyhow::anyhow!("Database user is required"));
}
Ok(MsSqlBootstrapProvider::new(self.source_id, self.config))
}
}
impl Default for MsSqlBootstrapProviderBuilder {
fn default() -> Self {
Self::new()
}
}
/// Per-request worker that opens its own connection and performs a single
/// snapshot bootstrap.
struct MsSqlBootstrapHandler {
    /// Source id written into every element reference and event.
    source_id: String,
    /// Connection and table settings for this run.
    config: MsSqlSourceConfig,
    /// Primary-key metadata, populated at the start of `execute` and used to
    /// derive element ids from rows.
    pk_cache: PrimaryKeyCache,
}
impl MsSqlBootstrapHandler {
fn new(config: MsSqlSourceConfig, source_id: String) -> Self {
Self {
config,
source_id,
pk_cache: PrimaryKeyCache::new(),
}
}
async fn execute(
&mut self,
request: BootstrapRequest,
context: &BootstrapContext,
event_tx: tokio::sync::mpsc::Sender<BootstrapEvent>,
) -> Result<usize> {
info!(
"Bootstrap: Connecting to MS SQL at {}:{}",
self.config.host, self.config.port
);
let mut connection = MsSqlConnection::connect(&self.config).await?;
let client = connection.client_mut();
info!("Discovering primary keys");
self.pk_cache.discover_keys(client, &self.config).await?;
let tables = self.map_labels_to_tables(&request)?;
if tables.is_empty() {
warn!("No tables to bootstrap");
return Ok(0);
}
info!("Starting bootstrap transaction with snapshot isolation");
client
.simple_query("SET TRANSACTION ISOLATION LEVEL SNAPSHOT")
.await?
.into_results()
.await?;
client
.simple_query("BEGIN TRANSACTION")
.await?
.into_results()
.await?;
let mut total_count = 0;
for (label, table_name) in &tables {
info!("Bootstrapping table '{table_name}' with label '{label}'");
let count = self
.bootstrap_table(client, label, table_name, context, &event_tx)
.await?;
total_count += count;
info!("Bootstrapped {count} records from table '{table_name}'");
}
client
.simple_query("COMMIT TRANSACTION")
.await?
.into_results()
.await?;
info!("Bootstrap transaction committed");
Ok(total_count)
}
fn map_labels_to_tables(&self, request: &BootstrapRequest) -> Result<Vec<(String, String)>> {
let mut tables = Vec::new();
for label in &request.node_labels {
let mut matched = false;
if self.config.tables.contains(label) {
tables.push((label.clone(), label.clone()));
matched = true;
} else {
let with_schema = format!("dbo.{label}");
if self.config.tables.contains(&with_schema) {
tables.push((label.clone(), with_schema));
matched = true;
} else {
for table in &self.config.tables {
if table == label || table.ends_with(&format!(".{label}")) {
tables.push((label.clone(), table.clone()));
matched = true;
break;
}
}
}
}
if !matched {
warn!(
"Table for label '{}' not found in configured tables {:?}, skipping",
label, self.config.tables
);
}
}
Ok(tables)
}
async fn bootstrap_table(
&self,
client: &mut tiberius::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>,
label: &str,
table_name: &str,
context: &BootstrapContext,
event_tx: &tokio::sync::mpsc::Sender<BootstrapEvent>,
) -> Result<usize> {
debug!("Starting bootstrap of table '{table_name}' with label '{label}'");
validate_sql_identifier(table_name)?;
let query = format!("SELECT * FROM {table_name}");
let stream = client.query(&query, &[]).await?;
let rows = stream.into_results().await?;
let mut count = 0;
let mut batch = Vec::new();
let batch_size = 1000;
for row_set in rows {
for row in row_set {
let source_change = self.row_to_source_change(&row, label, table_name).await?;
batch.push(SourceChangeEvent {
source_id: self.source_id.clone(),
change: source_change,
timestamp: chrono::Utc::now(),
});
if batch.len() >= batch_size {
self.send_batch(&mut batch, context, event_tx).await?;
count += batch_size;
}
}
}
if !batch.is_empty() {
count += batch.len();
self.send_batch(&mut batch, context, event_tx).await?;
}
Ok(count)
}
async fn row_to_source_change(
&self,
row: &Row,
label: &str,
table_name: &str,
) -> Result<SourceChange> {
let mut properties = ElementPropertyMap::new();
for (idx, column) in row.columns().iter().enumerate() {
let column_name = column.name();
let element_value = self.convert_column_value(row, idx, column)?;
properties.insert(column_name, element_value);
}
let element_id = self.pk_cache.generate_element_id(table_name, row)?;
Ok(SourceChange::Insert {
element: Element::Node {
metadata: ElementMetadata {
reference: ElementReference::new(&self.source_id, &element_id),
labels: Arc::from([Arc::from(label)]),
effective_from: 0,
},
properties,
},
})
}
fn convert_column_value(
&self,
row: &Row,
idx: usize,
column: &tiberius::Column,
) -> Result<ElementValue> {
use tiberius::ColumnType;
match column.column_type() {
ColumnType::Bit | ColumnType::Bitn => {
if let Ok(Some(val)) = row.try_get::<bool, _>(idx) {
Ok(ElementValue::Bool(val))
} else {
Ok(ElementValue::Null)
}
}
ColumnType::Int1
| ColumnType::Int2
| ColumnType::Int4
| ColumnType::Int8
| ColumnType::Intn => {
if let Ok(Some(val)) = row.try_get::<i32, _>(idx) {
Ok(ElementValue::Integer(val as i64))
} else if let Ok(Some(val)) = row.try_get::<i64, _>(idx) {
Ok(ElementValue::Integer(val))
} else if let Ok(Some(val)) = row.try_get::<i16, _>(idx) {
Ok(ElementValue::Integer(val as i64))
} else if let Ok(Some(val)) = row.try_get::<u8, _>(idx) {
Ok(ElementValue::Integer(val as i64))
} else {
Ok(ElementValue::Null)
}
}
ColumnType::Float4
| ColumnType::Float8
| ColumnType::Floatn
| ColumnType::Numericn
| ColumnType::Decimaln => {
if let Ok(Some(val)) = row.try_get::<f32, _>(idx) {
Ok(ElementValue::Float(OrderedFloat(val as f64)))
} else if let Ok(Some(val)) = row.try_get::<f64, _>(idx) {
Ok(ElementValue::Float(OrderedFloat(val)))
} else {
Ok(ElementValue::Null)
}
}
ColumnType::BigVarChar
| ColumnType::BigChar
| ColumnType::NVarchar
| ColumnType::NChar
| ColumnType::BigVarBin
| ColumnType::BigBinary
| ColumnType::Text
| ColumnType::NText => {
if let Ok(Some(val)) = row.try_get::<&str, _>(idx) {
Ok(ElementValue::String(Arc::from(val)))
} else {
Ok(ElementValue::Null)
}
}
ColumnType::Datetime
| ColumnType::Datetime2
| ColumnType::Datetime4
| ColumnType::Datetimen => {
if let Ok(Some(val)) = row.try_get::<chrono::NaiveDateTime, _>(idx) {
Ok(ElementValue::String(Arc::from(val.to_string().as_str())))
} else if let Ok(Some(val)) = row.try_get::<chrono::DateTime<chrono::Utc>, _>(idx) {
Ok(ElementValue::String(Arc::from(val.to_rfc3339().as_str())))
} else {
Ok(ElementValue::Null)
}
}
_ => {
if let Ok(Some(val)) = row.try_get::<&str, _>(idx) {
Ok(ElementValue::String(Arc::from(val)))
} else {
warn!(
"Unsupported column type {:?} for column {}, treating as NULL",
column.column_type(),
column.name()
);
Ok(ElementValue::Null)
}
}
}
}
async fn send_batch(
&self,
batch: &mut Vec<SourceChangeEvent>,
context: &BootstrapContext,
event_tx: &tokio::sync::mpsc::Sender<BootstrapEvent>,
) -> Result<()> {
for event in batch.drain(..) {
let sequence = context.next_sequence();
let bootstrap_event = BootstrapEvent {
source_id: event.source_id,
change: event.change,
timestamp: event.timestamp,
sequence,
};
event_tx
.send(bootstrap_event)
.await
.map_err(|e| anyhow!("Failed to send bootstrap event: {e}"))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder defaults the source id and keeps the configured host.
    #[test]
    fn test_builder() {
        let provider = MsSqlBootstrapProvider::builder()
            .with_host("localhost")
            .with_database("testdb")
            .with_user("testuser")
            .with_password("testpass")
            .build()
            .unwrap();
        assert_eq!(provider.source_id, "mssql-bootstrap");
        assert_eq!(provider.config.host, "localhost");
    }

    /// Building without the required settings fails.
    #[test]
    fn test_builder_missing_required() {
        let result = MsSqlBootstrapProvider::builder()
            .with_host("localhost")
            .build();
        assert!(result.is_err());
    }

    /// Every setter is reflected in the built provider.
    #[test]
    fn test_builder_applies_overrides() {
        let provider = MsSqlBootstrapProvider::builder()
            .with_source_id("custom-source")
            .with_port(14330)
            .with_database("testdb")
            .with_user("testuser")
            .with_tables(vec!["dbo.Orders".to_string()])
            .build()
            .unwrap();
        assert_eq!(provider.source_id, "custom-source");
        assert_eq!(provider.config.port, 14330);
        assert_eq!(provider.config.database, "testdb");
        assert_eq!(provider.config.user, "testuser");
        assert_eq!(provider.config.tables, vec!["dbo.Orders".to_string()]);
    }

    /// An explicitly empty database name is rejected with a specific message.
    #[test]
    fn test_builder_rejects_empty_database() {
        let err = MsSqlBootstrapProvider::builder()
            .with_database("")
            .with_user("testuser")
            .build()
            .err()
            .expect("empty database must be rejected");
        assert_eq!(err.to_string(), "Database name is required");
    }

    /// An explicitly empty user is rejected with a specific message.
    #[test]
    fn test_builder_rejects_empty_user() {
        let err = MsSqlBootstrapProvider::builder()
            .with_database("testdb")
            .with_user("")
            .build()
            .err()
            .expect("empty user must be rejected");
        assert_eq!(err.to_string(), "Database user is required");
    }
}