drasi_mssql_common/
keys.rs1use crate::config::{MsSqlSourceConfig, TableKeyConfig};
18use crate::error::{MsSqlError, PrimaryKeyError};
19use crate::types::{extract_column_value, value_to_string};
20use anyhow::{anyhow, Result};
21use drasi_core::models::ElementValue;
22use log::warn;
23use std::collections::HashMap;
24use tiberius::{Client, Row};
25use tokio::net::TcpStream;
26use tokio_util::compat::Compat;
27
/// In-memory lookup from table name to the ordered list of columns that
/// form that table's primary key.
///
/// Entries are populated by `discover_keys` (catalog query plus configured
/// overrides) and read back through `get` / `generate_element_id`.
#[derive(Debug, Clone)]
pub struct PrimaryKeyCache {
    /// Table name -> primary-key column names, in key order.
    keys: HashMap<String, Vec<String>>,
}
33
34impl PrimaryKeyCache {
35 pub fn new() -> Self {
37 Self {
38 keys: HashMap::new(),
39 }
40 }
41
42 pub async fn discover_keys(
47 &mut self,
48 client: &mut Client<Compat<TcpStream>>,
49 config: &MsSqlSourceConfig,
50 ) -> Result<()> {
51 let query = "
52 SELECT
53 t.name AS table_name,
54 c.name AS column_name,
55 ic.key_ordinal
56 FROM sys.indexes i
57 INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id
58 AND i.index_id = ic.index_id
59 INNER JOIN sys.columns c ON ic.object_id = c.object_id
60 AND ic.column_id = c.column_id
61 INNER JOIN sys.tables t ON i.object_id = t.object_id
62 WHERE i.is_primary_key = 1
63 ORDER BY t.name, ic.key_ordinal
64 ";
65
66 let stream = client.query(query, &[]).await?;
67 let rows = stream.into_first_result().await?;
68
69 for row in rows {
70 let table_name: &str = row.get(0).ok_or_else(|| anyhow!("Missing table_name"))?;
71 let column_name: &str = row.get(1).ok_or_else(|| anyhow!("Missing column_name"))?;
72
73 self.keys
74 .entry(table_name.to_string())
75 .or_default()
76 .push(column_name.to_string());
77 }
78
79 for tk in &config.table_keys {
81 self.keys.insert(tk.table.clone(), tk.key_columns.clone());
82 }
83
84 log::info!("Discovered primary keys for {} tables", self.keys.len());
85 for (table, keys) in &self.keys {
86 log::debug!("Table '{table}' primary key: {keys:?}");
87 }
88
89 Ok(())
90 }
91
92 pub fn get(&self, table: &str) -> Option<&Vec<String>> {
95 if let Some(keys) = self.keys.get(table) {
97 return Some(keys);
98 }
99
100 if let Some(table_only) = table.split('.').nth(1) {
102 if let Some(keys) = self.keys.get(table_only) {
103 return Some(keys);
104 }
105 }
106
107 None
108 }
109
110 pub fn generate_element_id(&self, table: &str, row: &Row) -> Result<String> {
127 let keys = match self.get(table) {
128 Some(keys) => keys,
129 None => {
130 return Err(MsSqlError::PrimaryKey(PrimaryKeyError::NotConfigured {
131 table: table.to_string(),
132 })
133 .into());
134 }
135 };
136
137 let mut key_values = Vec::new();
138 let mut null_columns = Vec::new();
139
140 for pk_col in keys {
141 if let Some(col_idx) = row.columns().iter().position(|c| c.name() == pk_col) {
143 let value = extract_column_value(row, col_idx)?;
144
145 if !matches!(value, ElementValue::Null) {
146 key_values.push(value_to_string(&value));
147 } else {
148 null_columns.push(pk_col.clone());
149 }
150 } else {
151 return Err(MsSqlError::PrimaryKey(PrimaryKeyError::ColumnNotFound {
152 table: table.to_string(),
153 column: pk_col.clone(),
154 })
155 .into());
156 }
157 }
158
159 if !key_values.is_empty() {
161 if !null_columns.is_empty() {
163 warn!(
164 "NULL value(s) in primary key column(s) {null_columns:?} for table '{table}'. \
165 Using remaining key columns for element ID."
166 );
167 }
168 Ok(format!("{}:{}", table, key_values.join("_")))
169 } else {
170 Err(MsSqlError::PrimaryKey(PrimaryKeyError::AllNull {
172 table: table.to_string(),
173 columns: keys.clone(),
174 })
175 .into())
176 }
177 }
178}
179
180impl Default for PrimaryKeyCache {
181 fn default() -> Self {
182 Self::new()
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache holding exactly one table with the given key columns.
    fn cache_with(table: &str, columns: &[&str]) -> PrimaryKeyCache {
        let mut cache = PrimaryKeyCache::new();
        let owned: Vec<String> = columns.iter().map(|c| c.to_string()).collect();
        cache.keys.insert(table.to_string(), owned);
        cache
    }

    /// A freshly created cache knows no tables.
    #[test]
    fn test_empty_cache() {
        let cache = PrimaryKeyCache::new();
        assert!(cache.get("orders").is_none());
    }

    /// A single-column key is returned for the table it was stored under.
    #[test]
    fn test_insert_and_get() {
        let cache = cache_with("orders", &["order_id"]);
        let expected = vec!["order_id".to_string()];

        assert_eq!(cache.get("orders").unwrap(), &expected);
    }

    /// Composite keys keep all columns in insertion order.
    #[test]
    fn test_composite_key() {
        let cache = cache_with("order_items", &["order_id", "product_id"]);
        let expected = vec!["order_id".to_string(), "product_id".to_string()];

        let keys = cache.get("order_items").unwrap();
        assert_eq!(keys.len(), 2);
        assert_eq!(keys, &expected);
    }
}