use crate::ObjectKey;
use crate::change_tracker::ChangeTracker;
use crate::flush::{FlushOrderer, FlushPlan, PendingOp};
use serde::Serialize;
use sqlmodel_core::{Error, Model, Value};
use std::collections::{HashMap, HashSet};
/// Collects pending inserts, updates and deletes and turns them into an ordered flush plan.
///
/// `register_model` records a model with the flush orderer and stores its table's
/// foreign-key dependencies, which `check_cycles` inspects.
#[derive(Default)]
pub struct UnitOfWork {
    /// Queued inserts, in the order they were tracked.
    new_objects: Vec<TrackedInsert>,
    /// Queued updates, in the order they were tracked.
    dirty_objects: Vec<TrackedUpdate>,
    /// Queued deletes, in the order they were tracked.
    deleted_objects: Vec<TrackedDelete>,
    /// Backs `snapshot`, `is_dirty`, `changed_fields` and `track_dirty_auto`.
    change_tracker: ChangeTracker,
    /// Receives registered models and orders the pending operations into a `FlushPlan`.
    orderer: FlushOrderer,
    /// Names of every table registered through `register_model`.
    tables: HashSet<&'static str>,
    /// For each registered table, the tables named by its fields' foreign keys.
    table_dependencies: HashMap<&'static str, Vec<&'static str>>,
}
/// A queued INSERT, captured by `UnitOfWork::track_new` from `Model::to_row`.
struct TrackedInsert {
    /// Identity of the tracked object.
    key: ObjectKey,
    /// Target table (`Model::TABLE_NAME`).
    table: &'static str,
    /// Column names, index-aligned with `values`.
    columns: Vec<&'static str>,
    /// Column values, index-aligned with `columns`.
    values: Vec<Value>,
}
/// A queued UPDATE, captured by `UnitOfWork::track_dirty`.
struct TrackedUpdate {
    /// Identity of the tracked object.
    key: ObjectKey,
    /// Target table (`Model::TABLE_NAME`).
    table: &'static str,
    /// Primary-key column names (`Model::PRIMARY_KEY`) used to match the row.
    pk_columns: Vec<&'static str>,
    /// Primary-key values (`Model::primary_key_value`), aligned with `pk_columns`.
    pk_values: Vec<Value>,
    /// Columns to assign; expected to line up index-by-index with `set_values`.
    set_columns: Vec<&'static str>,
    /// New values for `set_columns`, taken from `Model::to_row`.
    set_values: Vec<Value>,
}
/// A queued DELETE, captured by `UnitOfWork::track_deleted`.
struct TrackedDelete {
    /// Identity of the tracked object.
    key: ObjectKey,
    /// Target table (`Model::TABLE_NAME`).
    table: &'static str,
    /// Primary-key column names (`Model::PRIMARY_KEY`) used to match the row.
    pk_columns: Vec<&'static str>,
    /// Primary-key values (`Model::primary_key_value`), aligned with `pk_columns`.
    pk_values: Vec<Value>,
}
/// Errors produced by `UnitOfWork`; converted into `sqlmodel_core::Error::Custom` via `From`.
#[derive(Debug, Clone)]
pub enum UowError {
    /// The foreign-key graph of registered tables contains a cycle. `tables` is the path
    /// of tables that closed it, ending with the table that repeats.
    CycleDetected {
        tables: Vec<&'static str>,
    },
    /// An object was tracked more than once, `state` naming how it was already tracked.
    /// NOTE(review): no code visible here constructs this variant.
    AlreadyTracked {
        key: ObjectKey,
        state: &'static str,
    },
}
impl std::fmt::Display for UowError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
UowError::CycleDetected { tables } => {
write!(f, "Dependency cycle detected: {}", tables.join(" -> "))
}
UowError::AlreadyTracked { key, state } => {
write!(f, "Object {:?} already tracked as {}", key, state)
}
}
}
}
// Marker impl: the message comes from `Display`; `source()` keeps its default (`None`).
impl std::error::Error for UowError {}
impl From<UowError> for Error {
fn from(e: UowError) -> Self {
Error::Custom(e.to_string())
}
}
impl UnitOfWork {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register_model<T: Model>(&mut self) {
self.orderer.register_model::<T>();
let table = T::TABLE_NAME;
self.tables.insert(table);
let deps: Vec<&'static str> = T::fields()
.iter()
.filter_map(|f| f.foreign_key)
.filter_map(|fk| fk.split('.').next())
.collect();
self.table_dependencies.insert(table, deps);
}
pub fn track_new<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
let row = model.to_row();
let columns: Vec<&'static str> = row.iter().map(|(col, _)| *col).collect();
let values: Vec<Value> = row.into_iter().map(|(_, val)| val).collect();
self.new_objects.push(TrackedInsert {
key,
table: T::TABLE_NAME,
columns,
values,
});
}
pub fn track_dirty<T: Model + Serialize>(
&mut self,
model: &T,
key: ObjectKey,
changed_columns: Vec<&'static str>,
) {
if changed_columns.is_empty() {
return;
}
let row = model.to_row();
let row_map: HashMap<&str, Value> = row.into_iter().collect();
let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
let pk_values = model.primary_key_value();
let set_columns = changed_columns;
let set_values: Vec<Value> = set_columns
.iter()
.filter_map(|col| row_map.get(*col).cloned())
.collect();
self.dirty_objects.push(TrackedUpdate {
key,
table: T::TABLE_NAME,
pk_columns,
pk_values,
set_columns,
set_values,
});
}
pub fn track_dirty_auto<T: Model + Serialize>(&mut self, model: &T, key: ObjectKey) {
let changed = self.change_tracker.changed_fields(&key, model);
if !changed.is_empty() {
self.track_dirty(model, key, changed);
}
}
pub fn track_deleted<T: Model>(&mut self, model: &T, key: ObjectKey) {
let pk_columns: Vec<&'static str> = T::PRIMARY_KEY.to_vec();
let pk_values = model.primary_key_value();
self.deleted_objects.push(TrackedDelete {
key,
table: T::TABLE_NAME,
pk_columns,
pk_values,
});
}
pub fn snapshot<T: Model + Serialize>(&mut self, key: ObjectKey, model: &T) {
self.change_tracker.snapshot(key, model);
}
pub fn is_dirty<T: Model + Serialize>(&self, key: &ObjectKey, model: &T) -> bool {
self.change_tracker.is_dirty(key, model)
}
pub fn changed_fields<T: Model + Serialize>(
&self,
key: &ObjectKey,
model: &T,
) -> Vec<&'static str> {
self.change_tracker.changed_fields(key, model)
}
pub fn check_cycles(&self) -> Result<(), UowError> {
let mut visited = HashSet::new();
let mut rec_stack = HashSet::new();
let mut cycle_path = Vec::new();
for table in &self.tables {
if !visited.contains(table)
&& self.detect_cycle_dfs(table, &mut visited, &mut rec_stack, &mut cycle_path)
{
return Err(UowError::CycleDetected { tables: cycle_path });
}
}
Ok(())
}
fn detect_cycle_dfs(
&self,
table: &'static str,
visited: &mut HashSet<&'static str>,
rec_stack: &mut HashSet<&'static str>,
path: &mut Vec<&'static str>,
) -> bool {
visited.insert(table);
rec_stack.insert(table);
path.push(table);
if let Some(deps) = self.table_dependencies.get(table) {
for dep in deps {
if !self.tables.contains(dep) {
continue;
}
if !visited.contains(dep) {
if self.detect_cycle_dfs(dep, visited, rec_stack, path) {
return true;
}
} else if rec_stack.contains(dep) {
path.push(dep);
return true;
}
}
}
rec_stack.remove(table);
path.pop();
false
}
pub fn compute_flush_plan(&self) -> Result<FlushPlan, UowError> {
self.check_cycles()?;
let mut ops = Vec::new();
for insert in &self.new_objects {
ops.push(PendingOp::Insert {
key: insert.key,
table: insert.table,
columns: insert.columns.clone(),
values: insert.values.clone(),
});
}
for update in &self.dirty_objects {
ops.push(PendingOp::Update {
key: update.key,
table: update.table,
pk_columns: update.pk_columns.clone(),
pk_values: update.pk_values.clone(),
set_columns: update.set_columns.clone(),
set_values: update.set_values.clone(),
});
}
for delete in &self.deleted_objects {
ops.push(PendingOp::Delete {
key: delete.key,
table: delete.table,
pk_columns: delete.pk_columns.clone(),
pk_values: delete.pk_values.clone(),
});
}
Ok(self.orderer.order(ops))
}
pub fn clear(&mut self) {
self.new_objects.clear();
self.dirty_objects.clear();
self.deleted_objects.clear();
self.change_tracker.clear_all();
}
#[must_use]
pub fn has_changes(&self) -> bool {
!self.new_objects.is_empty()
|| !self.dirty_objects.is_empty()
|| !self.deleted_objects.is_empty()
}
#[must_use]
pub fn pending_count(&self) -> PendingCounts {
PendingCounts {
new: self.new_objects.len(),
dirty: self.dirty_objects.len(),
deleted: self.deleted_objects.len(),
}
}
#[must_use]
pub fn change_tracker(&self) -> &ChangeTracker {
&self.change_tracker
}
pub fn change_tracker_mut(&mut self) -> &mut ChangeTracker {
&mut self.change_tracker
}
}
/// Per-kind tally of the operations queued in a `UnitOfWork`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PendingCounts {
    /// Number of queued inserts.
    pub new: usize,
    /// Number of queued updates.
    pub dirty: usize,
    /// Number of queued deletes.
    pub deleted: usize,
}

impl PendingCounts {
    /// Sum of inserts, updates and deletes.
    #[must_use]
    pub fn total(&self) -> usize {
        let Self {
            new,
            dirty,
            deleted,
        } = *self;
        new + dirty + deleted
    }

    /// `true` when no operation of any kind is queued.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        [self.new, self.dirty, self.deleted]
            .iter()
            .all(|&count| count == 0)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `UnitOfWork`, built on two fixture models: `Team` (table "teams", no
    // foreign keys) and `Hero` (table "heroes", whose `team_id` references `teams.id`).
    use super::*;
    use serde::{Deserialize, Serialize};
    use sqlmodel_core::{FieldInfo, Row, SqlType};

    // Fixture model without foreign keys.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Team {
        id: Option<i64>,
        name: String,
    }

    impl Model for Team {
        const TABLE_NAME: &'static str = "teams";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
            ];
            FIELDS
        }
        // A missing `id` is written as `Value::Null`.
        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
            ]
        }
        // Stub: ignores `_row` and always returns an empty Team.
        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: None,
                name: String::new(),
            })
        }
        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }
        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    // Fixture model whose nullable `team_id` is a foreign key to `teams.id`, so once both
    // models are registered "heroes" depends on "teams".
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Hero {
        id: Option<i64>,
        name: String,
        team_id: Option<i64>,
    }

    impl Model for Hero {
        const TABLE_NAME: &'static str = "heroes";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        fn fields() -> &'static [FieldInfo] {
            static FIELDS: &[FieldInfo] = &[
                FieldInfo::new("id", "id", SqlType::BigInt).primary_key(true),
                FieldInfo::new("name", "name", SqlType::Text),
                FieldInfo::new("team_id", "team_id", SqlType::BigInt)
                    .nullable(true)
                    .foreign_key("teams.id"),
            ];
            FIELDS
        }
        // Missing `id` / `team_id` are written as `Value::Null`.
        fn to_row(&self) -> Vec<(&'static str, Value)> {
            vec![
                ("id", self.id.map_or(Value::Null, Value::BigInt)),
                ("name", Value::Text(self.name.clone())),
                ("team_id", self.team_id.map_or(Value::Null, Value::BigInt)),
            ]
        }
        // Stub: ignores `_row` and always returns an empty Hero.
        fn from_row(_row: &Row) -> sqlmodel_core::Result<Self> {
            Ok(Self {
                id: None,
                name: String::new(),
                team_id: None,
            })
        }
        fn primary_key_value(&self) -> Vec<Value> {
            vec![self.id.map_or(Value::Null, Value::BigInt)]
        }
        fn is_new(&self) -> bool {
            self.id.is_none()
        }
    }

    // Builds an `ObjectKey` for model `T` from a single BigInt primary-key value.
    fn make_key<T: Model + 'static>(pk: i64) -> ObjectKey {
        ObjectKey::from_pk::<T>(&[Value::BigInt(pk)])
    }

    // track_new queues exactly one insert and nothing else.
    #[test]
    fn test_track_new_object() {
        let mut uow = UnitOfWork::new();
        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        let key = make_key::<Team>(1);
        uow.track_new(&team, key);
        assert!(uow.has_changes());
        assert_eq!(uow.pending_count().new, 1);
        assert_eq!(uow.pending_count().dirty, 0);
        assert_eq!(uow.pending_count().deleted, 0);
    }

    // track_dirty with a non-empty column list queues one update.
    #[test]
    fn test_track_dirty_object() {
        let mut uow = UnitOfWork::new();
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);
        uow.track_dirty(&hero, key, vec!["name"]);
        assert!(uow.has_changes());
        assert_eq!(uow.pending_count().dirty, 1);
    }

    // track_deleted queues one delete.
    #[test]
    fn test_track_deleted_object() {
        let mut uow = UnitOfWork::new();
        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        let key = make_key::<Team>(1);
        uow.track_deleted(&team, key);
        assert!(uow.has_changes());
        assert_eq!(uow.pending_count().deleted, 1);
    }

    // The hero is tracked before the team, yet the plan must insert "teams" first
    // because heroes.team_id references teams.id.
    #[test]
    fn test_compute_flush_plan_orders_correctly() {
        let mut uow = UnitOfWork::new();
        uow.register_model::<Team>();
        uow.register_model::<Hero>();
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        uow.track_new(&hero, make_key::<Hero>(1));
        uow.track_new(&team, make_key::<Team>(1));
        let plan = uow.compute_flush_plan().unwrap();
        assert_eq!(plan.inserts[0].table(), "teams");
        assert_eq!(plan.inserts[1].table(), "heroes");
    }

    // clear() empties every queue. The delete deliberately uses a key (pk 2) that does
    // not match the model's id; tracking does not cross-check them.
    #[test]
    fn test_clear_removes_all_tracked() {
        let mut uow = UnitOfWork::new();
        let team = Team {
            id: Some(1),
            name: "Avengers".to_string(),
        };
        uow.track_new(&team, make_key::<Team>(1));
        uow.track_deleted(&team, make_key::<Team>(2));
        assert!(uow.has_changes());
        uow.clear();
        assert!(!uow.has_changes());
        assert!(uow.pending_count().is_empty());
    }

    // After a snapshot, an identical model is clean and a renamed one reports only "name".
    #[test]
    fn test_snapshot_and_dirty_detection() {
        let mut uow = UnitOfWork::new();
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);
        uow.snapshot(key, &hero);
        assert!(!uow.is_dirty(&key, &hero));
        let modified = Hero {
            id: Some(1),
            name: "Peter Parker".to_string(),
            team_id: Some(1),
        };
        assert!(uow.is_dirty(&key, &modified));
        let changed = uow.changed_fields(&key, &modified);
        assert_eq!(changed, vec!["name"]);
    }

    // track_dirty_auto derives the changed columns from the snapshot and queues one update.
    #[test]
    fn test_track_dirty_auto() {
        let mut uow = UnitOfWork::new();
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);
        uow.snapshot(key, &hero);
        let modified = Hero {
            id: Some(1),
            name: "Peter Parker".to_string(),
            team_id: Some(2),
        };
        uow.track_dirty_auto(&modified, key);
        assert_eq!(uow.pending_count().dirty, 1);
    }

    // heroes -> teams is acyclic, so the cycle check passes.
    #[test]
    fn test_no_cycle_in_normal_hierarchy() {
        let mut uow = UnitOfWork::new();
        uow.register_model::<Team>();
        uow.register_model::<Hero>();
        assert!(uow.check_cycles().is_ok());
    }

    // total() sums all three counters; is_empty() is true only when all are zero.
    #[test]
    fn test_pending_counts() {
        let counts = PendingCounts {
            new: 3,
            dirty: 2,
            deleted: 1,
        };
        assert_eq!(counts.total(), 6);
        assert!(!counts.is_empty());
        let empty = PendingCounts::default();
        assert!(empty.is_empty());
        assert_eq!(empty.total(), 0);
    }

    // An empty changed-column list must not queue an update.
    #[test]
    fn test_empty_dirty_not_tracked() {
        let mut uow = UnitOfWork::new();
        let hero = Hero {
            id: Some(1),
            name: "Spider-Man".to_string(),
            team_id: Some(1),
        };
        let key = make_key::<Hero>(1);
        uow.track_dirty(&hero, key, vec![]);
        assert!(!uow.has_changes());
    }
}