use std::fmt;
use bytes::Bytes;
/// Kind of row change, identified in SQL by a one-letter code
/// (`I`, `U` or `D`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ChangeOperation {
    /// Row insertion, SQL code `I`.
    Insert,
    /// Row update, SQL code `U`.
    Update,
    /// Row deletion, SQL code `D`.
    Delete,
}

impl ChangeOperation {
    /// Parses a one-letter operation code.
    ///
    /// Surrounding whitespace is ignored and ASCII case is not significant.
    /// Returns `None` for anything that is not exactly `I`, `U` or `D`.
    ///
    /// The comparison is done on raw bytes: no temporary string is allocated,
    /// and non-ASCII characters whose Unicode uppercase form happens to be an
    /// accepted letter (for example the dotless `ı`, U+0131) are rejected
    /// instead of being silently treated as a valid code.
    #[must_use]
    pub fn from_sql(s: &str) -> Option<Self> {
        match s.trim().as_bytes() {
            [b'I' | b'i'] => Some(Self::Insert),
            [b'U' | b'u'] => Some(Self::Update),
            [b'D' | b'd'] => Some(Self::Delete),
            _ => None,
        }
    }

    /// Returns the one-letter SQL code for this operation.
    #[must_use]
    pub const fn as_sql(&self) -> &'static str {
        match self {
            Self::Insert => "I",
            Self::Update => "U",
            Self::Delete => "D",
        }
    }

    /// Returns `true` for [`ChangeOperation::Insert`].
    #[must_use]
    pub const fn is_insert(&self) -> bool {
        matches!(self, Self::Insert)
    }

    /// Returns `true` for [`ChangeOperation::Update`].
    #[must_use]
    pub const fn is_update(&self) -> bool {
        matches!(self, Self::Update)
    }

    /// Returns `true` for [`ChangeOperation::Delete`].
    #[must_use]
    pub const fn is_delete(&self) -> bool {
        matches!(self, Self::Delete)
    }
}

/// Formats the operation as its full upper-case name (`INSERT`, `UPDATE`,
/// `DELETE`), as opposed to the one-letter code returned by `as_sql`.
impl fmt::Display for ChangeOperation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Insert => "INSERT",
            Self::Update => "UPDATE",
            Self::Delete => "DELETE",
        })
    }
}
/// Change-tracking metadata describing a single change.
///
/// NOTE(review): the fields presumably mirror the `SYS_CHANGE_*` columns
/// selected by the query builder in this file — confirm against the code
/// that decodes result rows.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ChangeMetadata {
    /// Version associated with this change.
    pub version: i64,
    /// Version at which the row was created, when known.
    pub creation_version: Option<i64>,
    /// Kind of change.
    pub operation: ChangeOperation,
    /// Opaque changed-columns payload, if any.
    pub changed_columns: Option<Bytes>,
    /// Opaque change-context payload, if any.
    pub context: Option<Bytes>,
}

impl ChangeMetadata {
    /// Builds metadata from explicit values for every field.
    #[must_use]
    pub fn new(
        version: i64,
        creation_version: Option<i64>,
        operation: ChangeOperation,
        changed_columns: Option<Bytes>,
        context: Option<Bytes>,
    ) -> Self {
        Self {
            version,
            creation_version,
            operation,
            changed_columns,
            context,
        }
    }

    /// Metadata for an insert: the creation version equals `version`, and no
    /// column mask or context is attached.
    #[must_use]
    pub fn insert(version: i64) -> Self {
        Self::new(version, Some(version), ChangeOperation::Insert, None, None)
    }

    /// Metadata for an update of a row created at `creation_version`; no
    /// column mask or context is attached.
    #[must_use]
    pub fn update(version: i64, creation_version: i64) -> Self {
        Self::new(
            version,
            Some(creation_version),
            ChangeOperation::Update,
            None,
            None,
        )
    }

    /// Metadata for a delete: there is no creation version, column mask or
    /// context.
    #[must_use]
    pub fn delete(version: i64) -> Self {
        Self::new(version, None, ChangeOperation::Delete, None, None)
    }
}
/// Builder for `SELECT ... FROM CHANGETABLE(CHANGES ...)` statements.
///
/// Table, column and alias names are emitted as bracket-quoted identifiers,
/// so callers do not need to escape them beforehand.
#[derive(Debug, Clone)]
pub struct ChangeTrackingQuery {
    /// Table whose changes are queried.
    table_name: String,
    /// Version passed as the second `CHANGETABLE(CHANGES ...)` argument.
    last_sync_version: i64,
    /// Extra columns selected through the change-table alias by `to_sql`.
    columns: Option<Vec<String>>,
    /// Primary-key columns: selected by both builders and used as the join
    /// key by `to_sql_with_data`.
    primary_keys: Option<Vec<String>>,
    /// Alias given to the `CHANGETABLE` result; `CT` unless overridden.
    alias: String,
    /// Whether `, FORCESEEK` is appended inside `CHANGETABLE(...)`.
    force_seek: bool,
}

impl ChangeTrackingQuery {
    /// Starts a query for the changes made to `table_name` after
    /// `last_sync_version`, using the alias `CT` and no extra columns.
    #[must_use]
    pub fn changes(table_name: impl Into<String>, last_sync_version: i64) -> Self {
        Self {
            table_name: table_name.into(),
            last_sync_version,
            columns: None,
            primary_keys: None,
            alias: "CT".into(),
            force_seek: false,
        }
    }

    /// Sets the additional columns that `to_sql` selects through the
    /// change-table alias. `to_sql_with_data` does not read this list; it
    /// takes its data columns as an argument instead.
    #[must_use]
    pub fn with_columns(mut self, columns: &[&str]) -> Self {
        self.columns = Some(columns.iter().map(|&s| s.to_string()).collect());
        self
    }

    /// Sets the primary-key columns. They are added to the select list and,
    /// in `to_sql_with_data`, form the join condition.
    #[must_use]
    pub fn with_primary_keys(mut self, keys: &[&str]) -> Self {
        self.primary_keys = Some(keys.iter().map(|&s| s.to_string()).collect());
        self
    }

    /// Overrides the alias given to the `CHANGETABLE` result.
    #[must_use]
    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
        self.alias = alias.into();
        self
    }

    /// Requests the `FORCESEEK` option of `CHANGETABLE(CHANGES ...)`.
    #[must_use]
    pub fn with_force_seek(mut self) -> Self {
        self.force_seek = true;
        self
    }

    /// Renders a query that reads only from the change table: the five
    /// `SYS_CHANGE_*` columns, then the primary keys, then the columns set
    /// with `with_columns`, all qualified with the change-table alias.
    #[must_use]
    pub fn to_sql(&self) -> String {
        format!(
            "SELECT {} FROM CHANGETABLE(CHANGES {}, {}{}) AS {}",
            self.build_select_columns(),
            ChangeTracking::bracket_table_name(&self.table_name),
            self.last_sync_version,
            self.force_seek_hint(),
            bracket_identifier(&self.alias)
        )
    }

    /// Renders a query that left-joins the change table to the tracked table
    /// (aliased `T`) and additionally selects `data_columns` from `T`.
    ///
    /// The join condition equates every primary-key column on both sides.
    /// When no primary keys are configured — either never set or set to an
    /// empty list — the condition degenerates to `1=1`, which pairs every
    /// change row with every table row; callers normally want to configure
    /// primary keys first.
    ///
    /// An empty `data_columns` slice simply contributes nothing to the select
    /// list (no dangling comma is produced).
    ///
    /// The base table alias is always `T`, so `T` should not be chosen as the
    /// change-table alias.
    #[must_use]
    pub fn to_sql_with_data(&self, data_columns: &[&str]) -> String {
        let alias = bracket_identifier(&self.alias);
        let table = ChangeTracking::bracket_table_name(&self.table_name);

        // Quote each primary key once; it is reused for the select list and
        // for both sides of the join condition.
        let primary_keys: Vec<String> = self
            .primary_keys
            .iter()
            .flatten()
            .map(|pk| bracket_identifier(pk))
            .collect();

        // Collecting into one list and joining at the end means empty
        // sections never leave a stray separator behind.
        let mut columns = Self::system_columns(&alias);
        columns.extend(primary_keys.iter().map(|pk| format!("{alias}.{pk}")));
        columns.extend(
            data_columns
                .iter()
                .map(|col| format!("T.{}", bracket_identifier(col))),
        );
        let select_cols = columns.join(", ");

        // An empty key list must not yield an empty (syntactically invalid)
        // ON clause, so it is treated exactly like "no keys configured".
        let join_condition = if primary_keys.is_empty() {
            String::from("1=1")
        } else {
            primary_keys
                .iter()
                .map(|pk| format!("{alias}.{pk} = T.{pk}"))
                .collect::<Vec<_>>()
                .join(" AND ")
        };

        format!(
            "SELECT {select_cols} \
             FROM CHANGETABLE(CHANGES {table}, {version}{force_seek}) AS {alias} \
             LEFT OUTER JOIN {table} AS T ON {join_condition}",
            version = self.last_sync_version,
            force_seek = self.force_seek_hint(),
        )
    }

    /// Select list used by `to_sql`: system columns, primary keys, then the
    /// extra columns, each qualified with the change-table alias.
    fn build_select_columns(&self) -> String {
        let alias = bracket_identifier(&self.alias);
        let mut cols = Self::system_columns(&alias);
        let named = self.primary_keys.iter().chain(self.columns.iter()).flatten();
        cols.extend(named.map(|name| format!("{alias}.{}", bracket_identifier(name))));
        cols.join(", ")
    }

    /// The `SYS_CHANGE_*` columns qualified with the already-quoted `alias`.
    fn system_columns(alias: &str) -> Vec<String> {
        [
            "SYS_CHANGE_VERSION",
            "SYS_CHANGE_CREATION_VERSION",
            "SYS_CHANGE_OPERATION",
            "SYS_CHANGE_COLUMNS",
            "SYS_CHANGE_CONTEXT",
        ]
        .iter()
        .map(|col| format!("{alias}.{col}"))
        .collect()
    }

    /// Text appended after the version argument inside `CHANGETABLE(...)`.
    fn force_seek_hint(&self) -> &'static str {
        if self.force_seek {
            ", FORCESEEK"
        } else {
            ""
        }
    }
}
pub struct ChangeTracking;
impl ChangeTracking {
#[must_use]
pub const fn current_version_sql() -> &'static str {
"SELECT CHANGE_TRACKING_CURRENT_VERSION()"
}
#[must_use]
pub fn min_valid_version_sql(table_name: &str) -> String {
format!(
"SELECT CHANGE_TRACKING_MIN_VALID_VERSION(OBJECT_ID(N'{}'))",
escape_nvarchar_literal(table_name)
)
}
pub fn column_in_mask_sql(
table_name: &str,
column_name: &str,
mask_variable: &str,
) -> Result<String, crate::Error> {
crate::validation::validate_identifier(mask_variable)?;
Ok(format!(
"SELECT CHANGE_TRACKING_IS_COLUMN_IN_MASK(\
COLUMNPROPERTY(OBJECT_ID(N'{}'), N'{}', 'ColumnId'), \
{mask_variable})",
escape_nvarchar_literal(table_name),
escape_nvarchar_literal(column_name)
))
}
#[must_use]
pub fn enable_database_sql(
database_name: &str,
retention_days: u32,
auto_cleanup: bool,
) -> String {
let cleanup = if auto_cleanup { "ON" } else { "OFF" };
format!(
"ALTER DATABASE [{}] SET CHANGE_TRACKING = ON \
(CHANGE_RETENTION = {retention_days} DAYS, AUTO_CLEANUP = {cleanup})",
database_name.replace(']', "]]")
)
}
#[must_use]
pub fn enable_table_sql(table_name: &str, track_columns_updated: bool) -> String {
let track_cols = if track_columns_updated { "ON" } else { "OFF" };
let table = Self::bracket_table_name(table_name);
format!(
"ALTER TABLE {table} ENABLE CHANGE_TRACKING \
WITH (TRACK_COLUMNS_UPDATED = {track_cols})"
)
}
#[must_use]
pub fn disable_table_sql(table_name: &str) -> String {
let table = Self::bracket_table_name(table_name);
format!("ALTER TABLE {table} DISABLE CHANGE_TRACKING")
}
fn bracket_table_name(table_name: &str) -> String {
table_name
.split('.')
.map(|part| format!("[{}]", part.replace(']', "]]")))
.collect::<Vec<_>>()
.join(".")
}
#[must_use]
pub fn disable_database_sql(database_name: &str) -> String {
format!(
"ALTER DATABASE [{}] SET CHANGE_TRACKING = OFF",
database_name.replace(']', "]]")
)
}
}
/// Escapes `s` for use inside a single-quoted SQL string literal by doubling
/// every single quote. The surrounding quotes are not added.
fn escape_nvarchar_literal(s: &str) -> String {
    // Splitting on the quote and re-joining with a doubled quote rewrites
    // each occurrence, including leading, trailing and adjacent ones.
    s.split('\'').collect::<Vec<_>>().join("''")
}
/// Wraps `name` in square brackets, doubling every `]` inside it so the
/// result is always a single bracket-quoted identifier.
fn bracket_identifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('[');
    for ch in name.chars() {
        quoted.push(ch);
        if ch == ']' {
            // A literal closing bracket is written twice.
            quoted.push(']');
        }
    }
    quoted.push(']');
    quoted
}
/// Result of comparing a client's last sync version with the minimum valid
/// version of a table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SyncVersionStatus {
    /// The last sync version is at or above the minimum valid version.
    Valid,
    /// The last sync version is below the minimum valid version.
    TooOld,
    /// No minimum valid version was supplied.
    NotEnabled,
}

impl SyncVersionStatus {
    /// Classifies `last_sync_version` against `min_valid_version`.
    ///
    /// `None` maps to [`SyncVersionStatus::NotEnabled`]. Otherwise the
    /// version is [`SyncVersionStatus::Valid`] when it is greater than or
    /// equal to the minimum, and [`SyncVersionStatus::TooOld`] when it is
    /// smaller.
    #[must_use]
    pub fn check(last_sync_version: i64, min_valid_version: Option<i64>) -> Self {
        min_valid_version.map_or(Self::NotEnabled, |min| {
            if last_sync_version < min {
                Self::TooOld
            } else {
                Self::Valid
            }
        })
    }

    /// `true` only for [`SyncVersionStatus::Valid`].
    #[must_use]
    pub const fn can_sync_incrementally(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::TooOld | Self::NotEnabled => false,
        }
    }

    /// `true` only for [`SyncVersionStatus::TooOld`]; note that
    /// [`SyncVersionStatus::NotEnabled`] reports `false` here as well.
    #[must_use]
    pub const fn requires_full_sync(&self) -> bool {
        match self {
            Self::TooOld => true,
            Self::Valid | Self::NotEnabled => false,
        }
    }
}
// Unit tests for the change-tracking SQL builders. Most assertions pin exact
// fragments of the generated SQL, so they double as a specification of the
// quoting and escaping rules.
#[cfg(test)]
// `unwrap` is acceptable here: a failed unwrap simply fails the test.
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
// Parsing accepts the three one-letter codes, ignores ASCII case and
// surrounding whitespace, and rejects unknown or empty input.
#[test]
fn test_change_operation_from_sql() {
assert_eq!(
ChangeOperation::from_sql("I"),
Some(ChangeOperation::Insert)
);
assert_eq!(
ChangeOperation::from_sql("U"),
Some(ChangeOperation::Update)
);
assert_eq!(
ChangeOperation::from_sql("D"),
Some(ChangeOperation::Delete)
);
assert_eq!(
ChangeOperation::from_sql("i"),
Some(ChangeOperation::Insert)
);
assert_eq!(
ChangeOperation::from_sql(" U "),
Some(ChangeOperation::Update)
);
assert_eq!(ChangeOperation::from_sql("X"), None);
assert_eq!(ChangeOperation::from_sql(""), None);
}
// Each variant maps back to its one-letter code.
#[test]
fn test_change_operation_as_sql() {
assert_eq!(ChangeOperation::Insert.as_sql(), "I");
assert_eq!(ChangeOperation::Update.as_sql(), "U");
assert_eq!(ChangeOperation::Delete.as_sql(), "D");
}
// Exactly one `is_*` predicate is true for each variant.
#[test]
fn test_change_operation_predicates() {
assert!(ChangeOperation::Insert.is_insert());
assert!(!ChangeOperation::Insert.is_update());
assert!(!ChangeOperation::Insert.is_delete());
assert!(!ChangeOperation::Update.is_insert());
assert!(ChangeOperation::Update.is_update());
assert!(!ChangeOperation::Update.is_delete());
assert!(!ChangeOperation::Delete.is_insert());
assert!(!ChangeOperation::Delete.is_update());
assert!(ChangeOperation::Delete.is_delete());
}
// The shortcut constructors fill in the operation and the creation version:
// insert reuses `version`, update takes it explicitly, delete has none.
#[test]
fn test_change_metadata_constructors() {
let insert = ChangeMetadata::insert(42);
assert_eq!(insert.version, 42);
assert_eq!(insert.creation_version, Some(42));
assert_eq!(insert.operation, ChangeOperation::Insert);
let update = ChangeMetadata::update(50, 42);
assert_eq!(update.version, 50);
assert_eq!(update.creation_version, Some(42));
assert_eq!(update.operation, ChangeOperation::Update);
let delete = ChangeMetadata::delete(60);
assert_eq!(delete.version, 60);
assert_eq!(delete.creation_version, None);
assert_eq!(delete.operation, ChangeOperation::Delete);
}
// A bare query selects the system columns from a quoted, aliased CHANGETABLE.
#[test]
fn test_change_tracking_query_basic() {
let query = ChangeTrackingQuery::changes("Products", 42);
let sql = query.to_sql();
assert!(sql.contains("CHANGETABLE(CHANGES [Products], 42)"));
assert!(sql.contains("SYS_CHANGE_VERSION"));
assert!(sql.contains("SYS_CHANGE_OPERATION"));
assert!(
sql.contains(") AS [CT]"),
"CHANGETABLE must be aliased or the query is not executable: {sql}"
);
}
// Extra columns are selected through the change-table alias.
#[test]
fn test_change_tracking_query_with_columns() {
let query = ChangeTrackingQuery::changes("Products", 42).with_columns(&["Name", "Price"]);
let sql = query.to_sql();
assert!(sql.contains("[CT].[Name]"));
assert!(sql.contains("[CT].[Price]"));
}
// Primary keys are selected through the change-table alias.
#[test]
fn test_change_tracking_query_with_primary_keys() {
let query = ChangeTrackingQuery::changes("Products", 42).with_primary_keys(&["ProductId"]);
let sql = query.to_sql();
assert!(sql.contains("[CT].[ProductId]"));
}
// `with_force_seek` adds the FORCESEEK option to the statement.
#[test]
fn test_change_tracking_query_force_seek() {
let query = ChangeTrackingQuery::changes("Products", 42).with_force_seek();
let sql = query.to_sql();
assert!(sql.contains("FORCESEEK"));
}
// The data variant left-joins the base table as `T` on the primary key and
// selects the requested data columns from `T`.
#[test]
fn test_change_tracking_query_with_data() {
let query = ChangeTrackingQuery::changes("Products", 42).with_primary_keys(&["ProductId"]);
let sql = query.to_sql_with_data(&["Name", "Price"]);
assert!(sql.contains("LEFT OUTER JOIN [Products] AS T"));
assert!(sql.contains("[CT].[ProductId] = T.[ProductId]"));
assert!(sql.contains("T.[Name]"));
assert!(sql.contains("T.[Price]"));
}
// Injection regression test: hostile table names, aliases, columns and
// primary keys must each remain a single bracket-quoted identifier.
#[test]
fn test_change_tracking_query_brackets_hostile_identifiers() {
let hostile_table = "T, 0) AS CT; DROP TABLE x--";
let sql = ChangeTrackingQuery::changes(hostile_table, 42).to_sql();
assert!(
sql.contains("CHANGETABLE(CHANGES [T, 0) AS CT; DROP TABLE x--], 42)"),
"hostile table name must stay one quoted identifier: {sql}"
);
// An interior `]` is doubled rather than terminating the identifier.
let sql = ChangeTrackingQuery::changes("foo]; DROP TABLE bar--", 1).to_sql();
assert!(sql.contains("CHANGETABLE(CHANGES [foo]]; DROP TABLE bar--], 1)"));
// A dot separates name parts; each part is quoted on its own.
let sql = ChangeTrackingQuery::changes("dbo.Items", 1).to_sql();
assert!(sql.contains("CHANGETABLE(CHANGES [dbo].[Items], 1)"));
// Alias and column names receive the same escaping.
let sql = ChangeTrackingQuery::changes("Products", 1)
.with_alias("A]; DROP TABLE x--")
.with_columns(&["C] FROM x--"])
.to_sql();
assert!(sql.contains("AS [A]]; DROP TABLE x--]"));
assert!(sql.contains("[A]]; DROP TABLE x--].[C]] FROM x--]"));
// So do primary keys and data columns in the joined variant.
let sql = ChangeTrackingQuery::changes("Products", 1)
.with_primary_keys(&["P]--"])
.to_sql_with_data(&["D]--"]);
assert!(sql.contains("[CT].[P]]--] = T.[P]]--]"));
assert!(sql.contains("T.[D]]--]"));
}
// Smoke test for the version and column-mask helper statements.
#[test]
fn test_change_tracking_helper_sql() {
assert_eq!(
ChangeTracking::current_version_sql(),
"SELECT CHANGE_TRACKING_CURRENT_VERSION()"
);
let min_sql = ChangeTracking::min_valid_version_sql("Products");
assert!(min_sql.contains("CHANGE_TRACKING_MIN_VALID_VERSION"));
assert!(min_sql.contains("Products"));
let mask_sql = ChangeTracking::column_in_mask_sql("Products", "Price", "@mask").unwrap();
assert!(mask_sql.contains("CHANGE_TRACKING_IS_COLUMN_IN_MASK"));
assert!(mask_sql.contains("Price"));
assert!(mask_sql.contains("@mask"));
}
// Names embedded as N'...' literals have single quotes doubled, so a hostile
// name cannot close the literal early.
#[test]
fn test_nvarchar_literal_names_cannot_break_out() {
let hostile = "x'); SELECT 1--";
let sql = ChangeTracking::min_valid_version_sql(hostile);
assert!(
sql.contains("N'x''); SELECT 1--'"),
"single quotes must be doubled, got: {sql}"
);
assert!(!sql.contains("N'x');"), "literal must not end early: {sql}");
let sql = ChangeTracking::column_in_mask_sql(hostile, hostile, "@mask").unwrap();
assert!(sql.contains("N'x''); SELECT 1--'"));
assert!(!sql.contains("N'x');"));
}
// The mask variable is spliced in verbatim, so anything that fails identifier
// validation (including the empty string) must be rejected.
#[test]
fn test_mask_variable_is_validated() {
assert!(ChangeTracking::column_in_mask_sql("T", "C", "@mask").is_ok());
assert!(ChangeTracking::column_in_mask_sql("T", "C", "@mask); DROP TABLE x--").is_err());
assert!(ChangeTracking::column_in_mask_sql("T", "C", "1 OR 1=1").is_err());
assert!(ChangeTracking::column_in_mask_sql("T", "C", "").is_err());
}
// Database names get the same `]` doubling in both enable and disable.
#[test]
fn test_database_name_brackets_are_escaped() {
let hostile = "x]; DROP DATABASE foo--";
let sql = ChangeTracking::enable_database_sql(hostile, 2, true);
assert!(
sql.contains("ALTER DATABASE [x]]; DROP DATABASE foo--]"),
"interior ] must be doubled, got: {sql}"
);
assert!(!sql.contains("ALTER DATABASE [x];"));
let sql = ChangeTracking::disable_database_sql(hostile);
assert!(sql.contains("ALTER DATABASE [x]]; DROP DATABASE foo--]"));
assert!(!sql.contains("ALTER DATABASE [x];"));
}
// Happy-path output of the enable/disable statements, including
// schema-qualified table names.
#[test]
fn test_change_tracking_enable_sql() {
let db_sql = ChangeTracking::enable_database_sql("MyDB", 7, true);
assert!(db_sql.contains("[MyDB]"));
assert!(db_sql.contains("7 DAYS"));
assert!(db_sql.contains("AUTO_CLEANUP = ON"));
let table_sql = ChangeTracking::enable_table_sql("Products", true);
assert!(table_sql.contains("[Products]"));
assert!(table_sql.contains("TRACK_COLUMNS_UPDATED = ON"));
let qualified = ChangeTracking::enable_table_sql("dbo.Products", true);
assert!(qualified.contains("ALTER TABLE [dbo].[Products]"));
let disable = ChangeTracking::disable_table_sql("dbo.Products");
assert!(disable.contains("ALTER TABLE [dbo].[Products]"));
}
// Table names in ALTER TABLE statements double interior `]` per name part.
#[test]
fn test_bracket_escapes_closing_brackets() {
let sql = ChangeTracking::enable_table_sql("foo]; DROP TABLE bar--", true);
assert!(
sql.contains("ALTER TABLE [foo]]; DROP TABLE bar--]"),
"interior ] must be doubled, got: {sql}"
);
assert!(
!sql.contains("ALTER TABLE [foo];"),
"the identifier must not be terminated early: {sql}"
);
let sql = ChangeTracking::disable_table_sql("we]ird.na]me");
assert!(sql.contains("ALTER TABLE [we]]ird].[na]]me]"));
}
// A version equal to the minimum is still valid; a missing minimum means
// NotEnabled, which does not allow incremental sync.
#[test]
fn test_sync_version_status() {
assert!(SyncVersionStatus::check(100, Some(50)).can_sync_incrementally());
assert!(SyncVersionStatus::check(50, Some(50)).can_sync_incrementally());
assert!(SyncVersionStatus::check(40, Some(50)).requires_full_sync());
let status = SyncVersionStatus::check(100, None);
assert_eq!(status, SyncVersionStatus::NotEnabled);
assert!(!status.can_sync_incrementally());
}
}