use std::collections::{HashMap, HashSet};
use std::str::FromStr;
use uuid::Uuid;
/// Granularity at which a [`ReadSet`] records what was read.
///
/// The default is [`TrackingMode::Adaptive`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TrackingMode {
    None,
    /// `ReadSet::includes_row` treats every row of a recorded table as read.
    Table,
    /// `Change::invalidates` filters updates/deletes by the recorded row ids.
    Row,
    #[default]
    Adaptive,
}

impl TrackingMode {
    /// Lowercase name of the mode; `FromStr` accepts exactly these names
    /// (case-insensitively).
    pub fn as_str(&self) -> &'static str {
        match *self {
            TrackingMode::None => "none",
            TrackingMode::Table => "table",
            TrackingMode::Row => "row",
            TrackingMode::Adaptive => "adaptive",
        }
    }
}

/// Error returned when a string does not name a [`TrackingMode`].
///
/// Carries the original input exactly as it was passed to `from_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseTrackingModeError(pub String);

impl std::fmt::Display for ParseTrackingModeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let input = &self.0;
        write!(f, "invalid tracking mode: {}", input)
    }
}

impl std::error::Error for ParseTrackingModeError {}

impl FromStr for TrackingMode {
    type Err = ParseTrackingModeError;

    /// Parses a mode name after Unicode lowercasing, so `"ROW"` and `"Row"`
    /// are both accepted.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const ALL: [TrackingMode; 4] = [
            TrackingMode::None,
            TrackingMode::Table,
            TrackingMode::Row,
            TrackingMode::Adaptive,
        ];
        let wanted = s.to_lowercase();
        ALL.iter()
            .copied()
            .find(|mode| mode.as_str() == wanted)
            .ok_or_else(|| ParseTrackingModeError(s.to_string()))
    }
}
/// Record of the tables, rows and filter columns that were read, consulted by
/// `Change::invalidates` to decide whether a write affects it.
#[derive(Debug, Clone, Default)]
pub struct ReadSet {
/// Names of every table that was read; `add_row` also inserts here.
pub tables: HashSet<String>,
/// Row ids that were read, keyed by table name. A table present in `tables`
/// but with no entry here is treated by `includes_row` as fully read.
pub rows: HashMap<String, HashSet<Uuid>>,
/// Column names recorded per table via `add_filter_column`.
/// NOTE(review): presumably the columns used in the read's predicates — confirm.
pub filter_columns: HashMap<String, HashSet<String>>,
/// Tracking granularity; in `TrackingMode::Table` mode `includes_row`
/// ignores `rows` entirely.
pub mode: TrackingMode,
}
impl ReadSet {
    /// Creates an empty read set in the default tracking mode
    /// ([`TrackingMode::Adaptive`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty read set in [`TrackingMode::Table`] mode.
    pub fn table_level() -> Self {
        Self {
            mode: TrackingMode::Table,
            ..Default::default()
        }
    }

    /// Creates an empty read set in [`TrackingMode::Row`] mode.
    pub fn row_level() -> Self {
        Self {
            mode: TrackingMode::Row,
            ..Default::default()
        }
    }

    /// Records that `table` was read.
    ///
    /// No row ids are recorded, so unless `add_row` is also called for this
    /// table, `includes_row` treats every row of it as read.
    pub fn add_table(&mut self, table: impl Into<String>) {
        self.tables.insert(table.into());
    }

    /// Records that row `row_id` of `table` was read (and marks the table as read).
    pub fn add_row(&mut self, table: impl Into<String>, row_id: Uuid) {
        let table = table.into();
        self.tables.insert(table.clone());
        self.rows.entry(table).or_default().insert(row_id);
    }

    /// Records `column` as a filter column of `table`.
    ///
    /// Does not add `table` to `tables`.
    pub fn add_filter_column(&mut self, table: impl Into<String>, column: impl Into<String>) {
        self.filter_columns
            .entry(table.into())
            .or_default()
            .insert(column.into());
    }

    /// Returns `true` if `table` was recorded as read.
    pub fn includes_table(&self, table: &str) -> bool {
        self.tables.contains(table)
    }

    /// Returns `true` if row `row_id` of `table` may have been read.
    ///
    /// Conservative: in table-level mode, or when no row ids were tracked for a
    /// read table, every row of that table counts as read.
    pub fn includes_row(&self, table: &str, row_id: Uuid) -> bool {
        if !self.tables.contains(table) {
            return false;
        }
        if self.mode == TrackingMode::Table {
            return true;
        }
        match self.rows.get(table) {
            Some(rows) => rows.contains(&row_id),
            None => true,
        }
    }

    /// Rough estimate of the memory held by this read set, in bytes.
    ///
    /// Uses fixed per-entry overheads (24 bytes per string / per row set,
    /// 16 bytes per row id, 64 bytes base); map keys of `rows` and
    /// `filter_columns` are not counted.
    pub fn memory_bytes(&self) -> usize {
        let table_bytes = self.tables.iter().map(|s| s.len() + 24).sum::<usize>();
        let row_bytes = self
            .rows
            .values()
            .map(|set| set.len() * 16 + 24)
            .sum::<usize>();
        let filter_bytes = self
            .filter_columns
            .values()
            .map(|set| set.iter().map(|s| s.len() + 24).sum::<usize>())
            .sum::<usize>();
        table_bytes + row_bytes + filter_bytes + 64
    }

    /// Total number of tracked row ids across all tables.
    pub fn row_count(&self) -> usize {
        self.rows.values().map(|set| set.len()).sum()
    }

    /// Folds `other` into `self`: tables, row ids and filter columns are unioned.
    ///
    /// A table that either side reads in full (see `reads_whole_table`) stays
    /// fully read after the merge; its row ids are dropped instead of unioned.
    /// Unioning them would make `includes_row` reject rows that one of the
    /// original read sets covered, hiding changes that should invalidate.
    /// `self.mode` is left unchanged.
    pub fn merge(&mut self, other: &ReadSet) {
        // Computed before mutating `self`, from both sides' pre-merge state.
        let whole_tables: HashSet<&str> = other
            .tables
            .iter()
            .map(String::as_str)
            .filter(|&table| self.reads_whole_table(table) || other.reads_whole_table(table))
            .collect();

        self.tables.extend(other.tables.iter().cloned());
        for table in &whole_tables {
            self.rows.remove(*table);
        }
        for (table, rows) in &other.rows {
            if whole_tables.contains(table.as_str()) {
                continue;
            }
            self.rows
                .entry(table.clone())
                .or_default()
                .extend(rows.iter().copied());
        }
        for (table, columns) in &other.filter_columns {
            self.filter_columns
                .entry(table.clone())
                .or_default()
                .extend(columns.iter().cloned());
        }
    }

    /// Returns `true` if `table` was read and every one of its rows counts as
    /// read: the set is table-level, or no row ids were tracked for the table.
    fn reads_whole_table(&self, table: &str) -> bool {
        self.tables.contains(table)
            && (self.mode == TrackingMode::Table || !self.rows.contains_key(table))
    }
}
/// Kind of write described by a [`Change`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChangeOperation {
    Insert,
    Update,
    Delete,
}

impl ChangeOperation {
    /// Upper-case keyword for the operation (`"INSERT"`, `"UPDATE"`, `"DELETE"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            ChangeOperation::Insert => "INSERT",
            ChangeOperation::Update => "UPDATE",
            ChangeOperation::Delete => "DELETE",
        }
    }
}

/// Error returned when a string does not name a [`ChangeOperation`].
///
/// Carries the original input exactly as it was passed to `from_str`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseChangeOperationError(pub String);

impl std::fmt::Display for ParseChangeOperationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let input = &self.0;
        write!(f, "invalid change operation: {}", input)
    }
}

impl std::error::Error for ParseChangeOperationError {}

impl FromStr for ChangeOperation {
    type Err = ParseChangeOperationError;

    /// Parses a full keyword or its one-letter abbreviation (`I`/`U`/`D`),
    /// after Unicode uppercasing.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let keyword = s.to_uppercase();
        let op = match keyword.as_str() {
            "INSERT" | "I" => ChangeOperation::Insert,
            "UPDATE" | "U" => ChangeOperation::Update,
            "DELETE" | "D" => ChangeOperation::Delete,
            _ => return Err(ParseChangeOperationError(s.to_string())),
        };
        Ok(op)
    }
}
/// A single write to a table, checked against a `ReadSet` via `invalidates`.
#[derive(Debug, Clone)]
pub struct Change {
/// Name of the table that was written.
pub table: String,
/// Kind of write (insert, update or delete).
pub operation: ChangeOperation,
/// Id of the affected row, if known. When `None`, `invalidates` cannot use
/// row-level filtering and falls back to the table-level answer.
pub row_id: Option<Uuid>,
/// Columns modified by the change, set via `with_columns`; empty by default.
pub changed_columns: Vec<String>,
}
impl Change {
    /// Creates a change of kind `operation` on `table`, with no row id and no
    /// changed columns.
    pub fn new(table: impl Into<String>, operation: ChangeOperation) -> Self {
        Self {
            table: table.into(),
            operation,
            row_id: None,
            changed_columns: Vec::new(),
        }
    }

    /// Sets the id of the affected row.
    pub fn with_row_id(mut self, row_id: Uuid) -> Self {
        self.row_id = Some(row_id);
        self
    }

    /// Sets the columns modified by this change (replacing any previous list).
    pub fn with_columns(mut self, columns: Vec<String>) -> Self {
        self.changed_columns = columns;
        self
    }

    /// Returns `true` if this change may affect what `read_set` recorded.
    ///
    /// - A change to a table the read set never read does not invalidate.
    /// - Outside [`TrackingMode::Row`] mode, or when `row_id` is unknown, any
    ///   change to a read table invalidates.
    /// - In row mode, `Insert` always invalidates: the new row's id cannot be
    ///   among the tracked ids, so row filtering cannot rule it out.
    /// - In row mode, `Delete` invalidates only if the deleted row may have
    ///   been read (`ReadSet::includes_row`).
    /// - In row mode, `Update` invalidates if the row may have been read, or if
    ///   it modifies a column recorded in `read_set.filter_columns` for the
    ///   table. Such an update can make a previously unread row match the
    ///   read's filter. This assumes `filter_columns` holds predicate columns.
    pub fn invalidates(&self, read_set: &ReadSet) -> bool {
        if !read_set.includes_table(&self.table) {
            return false;
        }
        let row_id = match (read_set.mode, self.row_id) {
            (TrackingMode::Row, Some(row_id)) => row_id,
            _ => return true,
        };
        match self.operation {
            ChangeOperation::Insert => true,
            ChangeOperation::Delete => read_set.includes_row(&self.table, row_id),
            ChangeOperation::Update => {
                read_set.includes_row(&self.table, row_id)
                    || self.modifies_filter_column(read_set)
            }
        }
    }

    /// Returns `true` if any of `changed_columns` is recorded as a filter
    /// column for this change's table in `read_set`.
    fn modifies_filter_column(&self, read_set: &ReadSet) -> bool {
        match read_set.filter_columns.get(&self.table) {
            Some(filters) => self
                .changed_columns
                .iter()
                .any(|column| filters.contains(column)),
            None => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tracking_mode_conversion() {
        let cases = [
            ("table", TrackingMode::Table),
            ("row", TrackingMode::Row),
            ("adaptive", TrackingMode::Adaptive),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(text.parse::<TrackingMode>(), Ok(expected));
        }
        assert!("invalid".parse::<TrackingMode>().is_err());
    }

    #[test]
    fn test_read_set_add_table() {
        let mut tracked = ReadSet::new();
        tracked.add_table("projects");
        assert!(tracked.includes_table("projects"));
        assert!(!tracked.includes_table("users"));
    }

    #[test]
    fn test_read_set_add_row() {
        let seen = Uuid::new_v4();
        let mut tracked = ReadSet::row_level();
        tracked.add_row("projects", seen);
        let unseen = Uuid::new_v4();
        assert!(tracked.includes_table("projects"));
        assert!(tracked.includes_row("projects", seen));
        assert!(!tracked.includes_row("projects", unseen));
    }

    #[test]
    fn test_change_invalidates_table_level() {
        let mut tracked = ReadSet::table_level();
        tracked.add_table("projects");
        let on_tracked_table = Change::new("projects", ChangeOperation::Insert);
        let on_other_table = Change::new("users", ChangeOperation::Insert);
        assert!(on_tracked_table.invalidates(&tracked));
        assert!(!on_other_table.invalidates(&tracked));
    }

    #[test]
    fn test_change_invalidates_row_level() {
        let tracked_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        let mut tracked = ReadSet::row_level();
        tracked.add_row("projects", tracked_id);
        let update_of = |id: Uuid| Change::new("projects", ChangeOperation::Update).with_row_id(id);
        assert!(update_of(tracked_id).invalidates(&tracked));
        assert!(!update_of(other_id).invalidates(&tracked));
        let insert = Change::new("projects", ChangeOperation::Insert).with_row_id(other_id);
        assert!(insert.invalidates(&tracked));
    }

    #[test]
    fn test_read_set_merge() {
        let mut combined = ReadSet::new();
        combined.add_table("projects");
        let extra = {
            let mut read_set = ReadSet::new();
            read_set.add_table("users");
            read_set
        };
        combined.merge(&extra);
        for table in ["projects", "users"] {
            assert!(combined.includes_table(table));
        }
    }
}