use diesel::prelude::*;
use diesel::result::{DatabaseErrorKind, Error};
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::rc::Rc;
/// A model type backed by SQLite, offering basic read-only lookups.
///
/// Implementors must be `Clone`; records are addressed by [`Model::Id`].
pub trait Model: Sized + Clone {
/// Identifier used to look up a single record (presumably the table's
/// primary key — confirm per implementor). May be unsized, e.g. `str`.
type Id: ?Sized + PartialEq + Debug;
/// Returns every record of this model, as loaded by the implementation.
fn all(conn: &mut SqliteConnection) -> QueryResult<Vec<Self>>;
/// Looks up one record by `id`; `Ok(None)` presumably means no record matched.
fn find(conn: &mut SqliteConnection, id: &Self::Id) -> QueryResult<Option<Self>>;
/// Reports whether a record with the given `id` exists.
fn id_exists(conn: &mut SqliteConnection, id: &Self::Id) -> QueryResult<bool>;
}
/// A [`Model`] whose records can be written back to the database.
pub trait ModelUpdate: Model {
    /// Persists `self` and returns the model value produced by the implementation.
    fn update(self, conn: &mut SqliteConnection) -> QueryResult<Self>;

    /// Calls [`ModelUpdate::update`] on each item, in order, and collects the results.
    ///
    /// # Errors
    ///
    /// Stops at the first failing update and returns its error. Items after it
    /// are not attempted, and updates already performed are not undone here
    /// (wrap the call in a transaction if atomicity matters).
    fn update_all(conn: &mut SqliteConnection, items: Vec<Self>) -> QueryResult<Vec<Self>> {
        items
            .into_iter()
            .map(|record| record.update(conn))
            .collect()
    }
}
/// A [`Model`] whose records can be deleted by id.
pub trait ModelDelete: Model {
    /// Deletes the record identified by `id`, returning the count reported by
    /// the implementation.
    fn delete(conn: &mut SqliteConnection, id: &Self::Id) -> QueryResult<usize>;

    /// Deletes each id in order and returns the sum of the per-id counts.
    ///
    /// # Errors
    ///
    /// Stops at the first failing delete and returns its error; deletes already
    /// performed are not undone here.
    fn delete_all(conn: &mut SqliteConnection, ids: Vec<&Self::Id>) -> QueryResult<usize> {
        ids.into_iter().try_fold(0usize, |deleted, id| {
            Self::delete(conn, id).map(|count| deleted + count)
        })
    }
}
/// An insertable value (e.g. a "new record" struct) that yields a [`Model`] once stored.
pub trait ModelInsert: Sized {
    /// The model type produced by a successful insert.
    type Model: Model;

    /// Inserts `self` and returns the resulting model value.
    fn insert(self, conn: &mut SqliteConnection) -> QueryResult<Self::Model>;

    /// Inserts each item, in order, and collects the resulting models.
    ///
    /// # Errors
    ///
    /// Stops at the first failing insert and returns its error; inserts already
    /// performed are not undone here.
    fn insert_all(conn: &mut SqliteConnection, items: Vec<Self>) -> QueryResult<Vec<Self::Model>> {
        items
            .into_iter()
            .map(|pending| pending.insert(conn))
            .collect()
    }
}
/// A [`Model`] arranged as a tree whose records can be re-parented and reordered.
pub trait ModelTree: Model + TreeSeq {
    /// Moves record `id` under `parent_id` at position `seq`.
    ///
    /// What `None` means for `parent_id` / `seq` is defined by the implementor
    /// (presumably "top level" and "append" respectively — confirm).
    fn move_to(
        conn: &mut SqliteConnection,
        id: &Self::Id,
        parent_id: Option<&Self::Id>,
        seq: Option<i32>,
    ) -> QueryResult<usize>;

    /// Moves each id in `ids`, in order, under `parent_id`.
    ///
    /// When `start_seq` is `Some(s)`, the ids are requested at the consecutive
    /// positions `s`, `s + 1`, …; when it is `None`, every move receives `None`.
    /// Returns the sum of the counts returned by each [`ModelTree::move_to`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing move and returns its error; earlier moves are
    /// not undone here.
    fn move_all_to(
        conn: &mut SqliteConnection,
        ids: &[&Self::Id],
        parent_id: &Self::Id,
        start_seq: Option<i32>,
    ) -> QueryResult<usize> {
        let mut total: usize = 0;
        let mut seq = start_seq;
        for id in ids {
            total += Self::move_to(conn, id, Some(parent_id), seq)?;
            // Advance only an explicit position; `None` stays `None`.
            seq = seq.map(|s| s + 1);
        }
        Ok(total)
    }
}
/// A [`Model`] kept as an ordered list whose records can be reordered.
pub trait ModelList: Model + TreeSeq {
    /// Moves record `id` to position `seq` (the meaning of `None` is defined by
    /// the implementor — presumably "append"; confirm).
    fn move_to(conn: &mut SqliteConnection, id: &Self::Id, seq: Option<i32>) -> QueryResult<usize>;

    /// Moves each id in `ids`, in order.
    ///
    /// When `start_seq` is `Some(s)`, the ids are requested at the consecutive
    /// positions `s`, `s + 1`, …; when it is `None`, every move receives `None`.
    /// Returns the sum of the counts returned by each [`ModelList::move_to`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing move and returns its error; earlier moves are
    /// not undone here.
    fn move_all_to(
        conn: &mut SqliteConnection,
        ids: &[&Self::Id],
        start_seq: Option<i32>,
    ) -> QueryResult<usize> {
        let mut total: usize = 0;
        let mut seq = start_seq;
        for id in ids {
            total += Self::move_to(conn, id, seq)?;
            // Advance only an explicit position; `None` stays `None`.
            seq = seq.map(|s| s + 1);
        }
        Ok(total)
    }
}
/// Returns `true` when `seqs` is exactly `start, start + 1, start + 2, …`.
///
/// The comparison is positional, so the slice must already be in ascending
/// order to validate. An empty slice is trivially valid.
pub fn validate_seq_numbers(seqs: &[i32], start: i32) -> bool {
    seqs.iter()
        .enumerate()
        .all(|(offset, &value)| value == start + offset as i32)
}
pub trait TreeSeq {
type ParentId: ?Sized + PartialEq + Debug;
const START_SEQ: i32 = 1;
fn is_valid_parent(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
) -> QueryResult<bool>;
fn count_children(conn: &mut SqliteConnection, parent_id: &Self::ParentId) -> QueryResult<i32>;
fn get_seq_numbers(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
) -> QueryResult<Vec<i32>>;
fn validate_seq_numbers(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
) -> QueryResult<bool> {
let seq_numbers = Self::get_seq_numbers(conn, parent_id)?;
Ok(validate_seq_numbers(&seq_numbers, Self::START_SEQ))
}
fn reset_seq(conn: &mut SqliteConnection, parent_id: &Self::ParentId) -> QueryResult<usize>;
fn increment_seq_gte(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
seq: i32,
) -> QueryResult<usize>;
fn decrement_seq_gt(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
seq: i32,
) -> QueryResult<usize>;
fn increment_seq_gte_lt(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
start: i32,
end: i32,
) -> QueryResult<usize>;
fn decrement_seq_gt_lte(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
start: i32,
end: i32,
) -> QueryResult<usize>;
fn update_seq_before_insert(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
seq: Option<i32>,
) -> QueryResult<(i32, usize)> {
let mut n: usize = 0;
if !Self::is_valid_parent(conn, &parent_id)? {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new(format!("Invalid parent ID: {:?}", parent_id)),
));
}
let count = Self::count_children(conn, parent_id)?;
let seq = if let Some(seq) = seq {
if seq < Self::START_SEQ || seq > Self::START_SEQ + count {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new("Invalid sequence number".to_string()),
));
} else if seq < Self::START_SEQ + count {
n = Self::increment_seq_gte(conn, parent_id, seq)?;
} seq
} else {
count + Self::START_SEQ
};
Ok((seq, n))
}
fn update_seq_before_move_to(
conn: &mut SqliteConnection,
old_parent_id: &Self::ParentId,
parent_id: &Self::ParentId,
old_seq: i32,
seq: Option<i32>,
) -> QueryResult<Option<(i32, usize)>> {
if parent_id == old_parent_id {
return Ok(None);
}
if !Self::is_valid_parent(conn, parent_id)? {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new(format!("Invalid parent ID: {:?}", parent_id)),
));
}
let count = Self::count_children(conn, parent_id)?;
let max_seq = Self::START_SEQ + count;
let seq = if let Some(seq) = seq {
if seq < Self::START_SEQ || seq > max_seq {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new(format!(
"Invalid sequence number: {} (count: {})",
seq, count
)),
));
}
seq
} else {
max_seq };
let mut n = Self::decrement_seq_gt(conn, &old_parent_id, old_seq)?;
if seq < max_seq {
n += Self::increment_seq_gte(conn, &parent_id, seq)?;
};
Ok(Some((seq, n)))
}
fn update_seq_before_move_in(
conn: &mut SqliteConnection,
parent_id: &Self::ParentId,
old_seq: i32,
seq: Option<i32>,
) -> QueryResult<Option<(i32, usize)>> {
let count = Self::count_children(conn, parent_id)?;
let max_seq = Self::START_SEQ + count - 1;
let seq = if let Some(seq) = seq {
if seq < Self::START_SEQ || seq > max_seq {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new(format!(
"Invalid sequence number: {} (count: {})",
seq, count
)),
));
}
seq
} else {
max_seq };
if seq == old_seq {
return Ok(None);
}
let mut n: usize = 0;
if seq > old_seq {
n += Self::decrement_seq_gt_lte(conn, &parent_id, old_seq, seq)?;
} else if seq < old_seq {
n += Self::increment_seq_gte_lt(conn, &parent_id, seq, old_seq)?;
};
Ok(Some((seq, n)))
}
fn update_seq_before_move(
conn: &mut SqliteConnection,
old_parent_id: &Self::ParentId,
parent_id: &Self::ParentId,
old_seq: i32,
seq: Option<i32>,
) -> QueryResult<Option<(i32, usize)>> {
if old_parent_id == parent_id {
Self::update_seq_before_move_in(conn, parent_id, old_seq, seq)
} else {
Self::update_seq_before_move_to(conn, old_parent_id, parent_id, old_seq, seq)
}
}
}
pub trait SequenceVisible {
const START_SEQ: i32 = 1;
fn count_visible(conn: &mut SqliteConnection) -> QueryResult<i32>;
fn get_seq_numbers(conn: &mut SqliteConnection) -> QueryResult<Vec<i32>>;
fn reset_seq(conn: &mut SqliteConnection) -> QueryResult<usize>;
fn increment_seq_before_insert(conn: &mut SqliteConnection, seq: i32) -> QueryResult<usize>;
fn validate_seq_numbers(conn: &mut SqliteConnection) -> QueryResult<bool> {
let seq_numbers = Self::get_seq_numbers(conn)?;
for (i, seq) in seq_numbers.iter().enumerate() {
if (i as i32) + Self::START_SEQ != *seq {
return Ok(false);
}
}
Ok(true)
}
fn update_seq_before_insert(conn: &mut SqliteConnection, seq: i32) -> QueryResult<i32> {
let mut seq = seq;
let count = Self::count_visible(conn)?;
if seq == 0 {
seq = count + 1;
} else if 0 < seq && seq <= count + 1 {
if seq < count + 1 {
Self::increment_seq_before_insert(conn, seq)?;
}
} else {
return Err(Error::DatabaseError(
DatabaseErrorKind::RestrictViolation,
Box::new("Invalid sequence number".to_string()),
));
}
Ok(seq)
}
fn decrement_seq_after_delete(conn: &mut SqliteConnection, seq: i32) -> QueryResult<usize>;
}
/// Shared, interior-mutable handle to a tree node. Built from `Rc` + `RefCell`,
/// so it is single-threaded and borrow rules are checked at runtime.
pub type NodeRef<T> = Rc<RefCell<T>>;
/// An in-memory tree node built from a [`Model`] record, holding shared
/// handles to its children.
pub trait TreeNode: From<Self::Model> {
    /// The database model each node is constructed from.
    type Model: Model;
    /// Node identifier; also used as the key when linking parents and children.
    type Id: ?Sized + Eq + PartialEq + Debug + Clone + Hash;

    /// This node's identifier.
    fn id(&self) -> Self::Id;
    /// This node's position among its siblings.
    fn sequence(&self) -> i32;
    /// Identifier of this node's parent.
    fn parent_id(&self) -> Self::Id;
    /// Read-only access to the child handles.
    fn children(&self) -> &Vec<NodeRef<Self>>;
    /// Mutable access to the child handles.
    fn children_mut(&mut self) -> &mut Vec<NodeRef<Self>>;

    /// Orders the direct children by ascending `sequence()` (stable, so ties
    /// keep their current relative order). Does not recurse.
    fn sort_children(&mut self) {
        self.children_mut()
            .sort_by_key(|child| child.borrow().sequence());
    }
}
/// Recursively sorts the children of `item`, and of every descendant, by
/// `sequence()`.
///
/// Holds a mutable borrow of `item` while descending, so a node that is its
/// own descendant would trigger a `RefCell` borrow panic.
fn sort_tree_ref<N: TreeNode>(item: &mut NodeRef<N>) {
    let mut node = item.borrow_mut();
    node.sort_children();
    node.children_mut()
        .iter_mut()
        .for_each(sort_tree_ref::<N>);
}
fn sort_tree_list<N: TreeNode>(items: &mut Vec<NodeRef<N>>) {
items.sort_by(|a, b| {
let a_seq = a.borrow().sequence();
let b_seq = b.borrow().sequence();
a_seq.cmp(&b_seq)
});
for item in items.iter_mut() {
sort_tree_ref(item);
}
}
/// Builds a forest of [`TreeNode`]s from flat `models`.
///
/// Each model is converted with `N::from` and attached as a child of the node
/// whose `id()` equals its `parent_id()`, when such a node is among `models`.
/// Returns the nodes whose `parent_id()` equals `root_id`, sorted recursively
/// by `sequence()`. Nodes whose parent is neither `root_id` nor present in
/// `models` are not reachable from the result.
///
/// Children are linked in input order, so siblings with equal `sequence()`
/// keep a deterministic (input) order after the stable sort. A record naming
/// itself as its parent is never linked under itself. Such a link would form an
/// `Rc` cycle (leaked memory) and a `RefCell` borrow panic during sorting.
/// Ids are expected to be unique; for duplicates, children attach to the last one.
pub(crate) fn build_tree_list<M, N>(models: Vec<M>, root_id: N::Id) -> Vec<NodeRef<N>>
where
    N: TreeNode<Model = M>,
    M: Model,
{
    let mut map: HashMap<N::Id, NodeRef<N>> = HashMap::with_capacity(models.len());
    // (parent id, child handle) pairs in input order; linking from this list
    // instead of `map.values()` avoids HashMap's unspecified iteration order.
    let mut links: Vec<(N::Id, NodeRef<N>)> = Vec::with_capacity(models.len());
    let mut roots = Vec::new();
    for model in models {
        let node = N::from(model);
        let id = node.id();
        let parent_id = node.parent_id();
        let item: NodeRef<N> = Rc::new(RefCell::new(node));
        if parent_id == root_id {
            roots.push(Rc::clone(&item));
        }
        if parent_id != id {
            links.push((parent_id, Rc::clone(&item)));
        }
        map.insert(id, item);
    }
    for (parent_id, child) in &links {
        if let Some(parent) = map.get(parent_id) {
            parent.borrow_mut().children_mut().push(Rc::clone(child));
        }
    }
    sort_tree_list(&mut roots);
    roots
}