Skip to main content

drasi_mssql_common/
keys.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Primary key discovery and element ID generation
16
17use crate::config::{MsSqlSourceConfig, TableKeyConfig};
18use crate::error::{MsSqlError, PrimaryKeyError};
19use crate::types::{extract_column_value, value_to_string};
20use anyhow::{anyhow, Result};
21use drasi_core::models::ElementValue;
22use log::warn;
23use std::collections::HashMap;
24use tiberius::{Client, Row};
25use tokio::net::TcpStream;
26use tokio_util::compat::Compat;
27
/// Cache of primary key columns per table, used to build element IDs.
///
/// Populated by [`PrimaryKeyCache::discover_keys`] and queried with
/// [`PrimaryKeyCache::get`] / [`PrimaryKeyCache::generate_element_id`].
#[derive(Debug, Clone)]
pub struct PrimaryKeyCache {
    /// Map of table name -> primary key column names, in key order.
    keys: HashMap<String, Vec<String>>,
}
33
34impl PrimaryKeyCache {
35    /// Create a new empty cache
36    pub fn new() -> Self {
37        Self {
38            keys: HashMap::new(),
39        }
40    }
41
42    /// Discover primary keys from MS SQL system catalogs
43    ///
44    /// Queries sys.indexes, sys.index_columns, sys.columns, and sys.tables
45    /// to find primary key columns for all tables in the database.
46    pub async fn discover_keys(
47        &mut self,
48        client: &mut Client<Compat<TcpStream>>,
49        config: &MsSqlSourceConfig,
50    ) -> Result<()> {
51        let query = "
52            SELECT 
53                t.name AS table_name,
54                c.name AS column_name,
55                ic.key_ordinal
56            FROM sys.indexes i
57            INNER JOIN sys.index_columns ic ON i.object_id = ic.object_id 
58                AND i.index_id = ic.index_id
59            INNER JOIN sys.columns c ON ic.object_id = c.object_id 
60                AND ic.column_id = c.column_id
61            INNER JOIN sys.tables t ON i.object_id = t.object_id
62            WHERE i.is_primary_key = 1
63            ORDER BY t.name, ic.key_ordinal
64        ";
65
66        let stream = client.query(query, &[]).await?;
67        let rows = stream.into_first_result().await?;
68
69        for row in rows {
70            let table_name: &str = row.get(0).ok_or_else(|| anyhow!("Missing table_name"))?;
71            let column_name: &str = row.get(1).ok_or_else(|| anyhow!("Missing column_name"))?;
72
73            self.keys
74                .entry(table_name.to_string())
75                .or_default()
76                .push(column_name.to_string());
77        }
78
79        // Merge with configured table_keys (which take precedence)
80        for tk in &config.table_keys {
81            self.keys.insert(tk.table.clone(), tk.key_columns.clone());
82        }
83
84        log::info!("Discovered primary keys for {} tables", self.keys.len());
85        for (table, keys) in &self.keys {
86            log::debug!("Table '{table}' primary key: {keys:?}");
87        }
88
89        Ok(())
90    }
91
92    /// Get primary key columns for a table
93    /// Handles both "table" and "schema.table" formats
94    pub fn get(&self, table: &str) -> Option<&Vec<String>> {
95        // Try exact match first
96        if let Some(keys) = self.keys.get(table) {
97            return Some(keys);
98        }
99
100        // Try without schema prefix (e.g., "dbo.Orders" -> "Orders")
101        if let Some(table_only) = table.split('.').nth(1) {
102            if let Some(keys) = self.keys.get(table_only) {
103                return Some(keys);
104            }
105        }
106
107        None
108    }
109
110    /// Generate element ID from a row using primary key values
111    ///
112    /// Format: `{table_name}:{key_values}`
113    ///
114    /// # Arguments
115    /// * `table` - Table name
116    /// * `row` - Tiberius row with data
117    ///
118    /// # Returns
119    /// Element ID string
120    ///
121    /// # Errors
122    /// Returns an error if no primary key is configured for the table or if all
123    /// primary key values are NULL. This is intentional - without a stable primary
124    /// key, UPDATE and DELETE operations cannot be correctly matched to previous
125    /// INSERT operations, breaking change tracking.
126    pub fn generate_element_id(&self, table: &str, row: &Row) -> Result<String> {
127        let keys = match self.get(table) {
128            Some(keys) => keys,
129            None => {
130                return Err(MsSqlError::PrimaryKey(PrimaryKeyError::NotConfigured {
131                    table: table.to_string(),
132                })
133                .into());
134            }
135        };
136
137        let mut key_values = Vec::new();
138        let mut null_columns = Vec::new();
139
140        for pk_col in keys {
141            // Find column index
142            if let Some(col_idx) = row.columns().iter().position(|c| c.name() == pk_col) {
143                let value = extract_column_value(row, col_idx)?;
144
145                if !matches!(value, ElementValue::Null) {
146                    key_values.push(value_to_string(&value));
147                } else {
148                    null_columns.push(pk_col.clone());
149                }
150            } else {
151                return Err(MsSqlError::PrimaryKey(PrimaryKeyError::ColumnNotFound {
152                    table: table.to_string(),
153                    column: pk_col.clone(),
154                })
155                .into());
156            }
157        }
158
159        // Generate element ID
160        if !key_values.is_empty() {
161            // Warn if some (but not all) key columns are NULL
162            if !null_columns.is_empty() {
163                warn!(
164                    "NULL value(s) in primary key column(s) {null_columns:?} for table '{table}'. \
165                     Using remaining key columns for element ID."
166                );
167            }
168            Ok(format!("{}:{}", table, key_values.join("_")))
169        } else {
170            // All primary key values are NULL - this is an error
171            Err(MsSqlError::PrimaryKey(PrimaryKeyError::AllNull {
172                table: table.to_string(),
173                columns: keys.clone(),
174            })
175            .into())
176        }
177    }
178}
179
180impl Default for PrimaryKeyCache {
181    fn default() -> Self {
182        Self::new()
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache pre-populated with the given `table -> key columns` entries.
    fn cache_with(entries: &[(&str, &[&str])]) -> PrimaryKeyCache {
        let mut cache = PrimaryKeyCache::new();
        for (table, columns) in entries {
            cache.keys.insert(
                (*table).to_string(),
                columns.iter().map(|c| (*c).to_string()).collect(),
            );
        }
        cache
    }

    #[test]
    fn test_empty_cache() {
        let cache = PrimaryKeyCache::new();
        assert!(cache.get("orders").is_none());
    }

    #[test]
    fn test_default_is_empty() {
        assert!(PrimaryKeyCache::default().get("orders").is_none());
    }

    #[test]
    fn test_insert_and_get() {
        let cache = cache_with(&[("orders", &["order_id"])]);
        assert_eq!(cache.get("orders").unwrap(), &vec!["order_id"]);
    }

    #[test]
    fn test_composite_key() {
        let cache = cache_with(&[("order_items", &["order_id", "product_id"])]);

        let keys = cache.get("order_items").unwrap();
        assert_eq!(keys.len(), 2);
        assert_eq!(keys[0], "order_id");
        assert_eq!(keys[1], "product_id");
    }

    #[test]
    fn test_schema_qualified_lookup_falls_back_to_table_name() {
        let cache = cache_with(&[("orders", &["order_id"])]);
        assert_eq!(cache.get("dbo.orders").unwrap(), &vec!["order_id"]);
        assert!(cache.get("dbo.customers").is_none());
    }

    #[test]
    fn test_exact_match_takes_precedence_over_fallback() {
        let cache = cache_with(&[("orders", &["order_id"]), ("sales.orders", &["sale_id"])]);
        assert_eq!(cache.get("sales.orders").unwrap(), &vec!["sale_id"]);
        assert_eq!(cache.get("dbo.orders").unwrap(), &vec!["order_id"]);
    }
}