use std::fmt;
use std::hash::{
Hash,
Hasher
};
use conciliator::{
Buffer,
Inline,
Paint
};
use fnv::FnvHasher;
use liter::{
Bind,
database,
Entry,
HasKey,
Table,
Value,
Ref
};
use rusqlite::OptionalExtension;
use rusqlite::Result as SqlResult;
use serde::{Serialize, Deserialize};
use crate::peer::{
NodeID,
Address,
Clock,
Status
};
/// The node's SQLite-backed database, generated by liter's `#[database]`
/// macro from the tables listed in this tuple struct.
///
/// It holds local configuration (`Config`), the known peers
/// (`Conspirator`) and their addresses (`PeerAddress`), the replicated
/// operation log (`Op`, ordered by `LogEntry`), and the key/value `Pair`s
/// produced by applying `Action::Write` ops.
#[database]
#[derive(Debug)]
pub struct Store(
    Config,
    Conspirator,
    LogEntry,
    Op,
    Pair,
    PeerAddress
);
/// One row of the node's local key/value configuration (`config` table).
///
/// Values are read with `Store::read_config` and written (upserted) with
/// `Store::write_config`.
#[derive(Debug, Table, PartialEq, Eq)]
pub struct Config {
    // Primary key: the name of the setting.
    #[key]
    key: String,
    val: String
}
/// A known peer node (`conspirator` table).
///
/// `Store::add_conspirator` creates rows as active, with clock `0`,
/// `Status::Unreachable` and reach-out index `0`. They are later modified by
/// applied ops (`Action::Name`, `Action::Active`) and by
/// `Store::persist_clocks`.
#[derive(Debug, Table, PartialEq, Eq)]
pub struct Conspirator {
    #[key]
    pub id: NodeID,
    // Empty when the row was created implicitly because an op targeted an
    // unknown node (`Store::apply_op`).
    pub name: String,
    pub active: bool,
    // `Store::persist_clocks` only ever raises this (its UPDATE uses `max`).
    pub clock: Clock,
    pub state: Status,
    // NOTE(review): presumably the rotating index handed to
    // `Store::reach_out_addr` — confirm with callers.
    pub reach_out_idx: usize,
}
/// A replicated key/value entry (`pair` table).
///
/// Entries are maintained by applying `Action::Write` ops: a non-empty value
/// upserts the entry and an empty value deletes it.
#[derive(PartialEq, Eq, Hash, Clone, Serialize, Deserialize, Table)]
pub struct Pair {
    #[key]
    key: String,
    value: String
}
/// A network address recorded for a conspirator (`peeraddress` table).
///
/// Rows are keyed by `(id, conspirator)`. `id` is a per-conspirator sequence
/// number, assigned as `max(id) + 1` (starting at `0`) by
/// `PeerAddress::ADD_OR_IGNORE_SQL`. Each `(conspirator, address)`
/// combination may appear only once.
#[derive(Debug, Table)]
#[unique(conspirator, address)]
pub struct PeerAddress {
    #[key]
    id: usize,
    #[key]
    conspirator: Ref<Conspirator>,
    address: Address
}
/// The position of a sorted `Op` in the local operation log (`logentry`
/// table).
///
/// An op that is stored but has no log entry is "held": it is picked up
/// again by `Op::RELEASE_HELD_SQL` once other ops have been sorted.
/// `Store::insert_log_entry` shifts every entry at or after the insertion
/// point up by one, so positions stay contiguous.
#[derive(Debug, Table)]
pub struct LogEntry {
    // Position in the log; must be unique across entries.
    #[unique]
    sort: i64,
    #[key]
    op: Ref<Op>
}
/// A replicated operation (`op` table).
///
/// - `origin` uniquely identifies the op.
/// - `previous` is the version of the op at the end of the creator's sorted
///   log when the op was built (`Store::create_op`), or `None` if that log
///   was empty. It anchors the op's position when sorting.
/// - `action` is the effect applied when the op is replayed; `target` names
///   the conspirator it concerns (a row is created for it if missing).
#[derive(PartialEq, Eq, Clone, Hash, Table, Serialize, Deserialize)]
pub struct Op {
    #[key]
    pub origin: Version,
    pub previous: Option<Version>,
    pub target: NodeID,
    pub action: Action
}
/// Identifies an op by its originating node and that node's counter.
///
/// `Store::create_op` gives each new op the node's highest sorted counter
/// plus one, starting at `0`. In the `op` table a version occupies two
/// columns, such as `origin_node` and `origin_counter`.
#[derive(PartialEq, Eq, Copy, Hash, Clone, Value, Serialize, Deserialize)]
pub struct Version {
    pub node: NodeID,
    pub counter: u64
}
/// What an `Op` does when it is applied (see `Store::apply_op`).
#[derive(Hash, PartialEq, Eq, Debug, Clone, Value, Serialize, Deserialize)]
pub enum Action {
    // Sets the target conspirator's name.
    Name(String),
    // Records an additional address for the target conspirator.
    AddAddress(Address),
    // Sets the target conspirator's `active` flag.
    Active(bool),
    // Upserts `key = value` in the `pair` table; an empty `value` deletes
    // `key` instead.
    Write {
        key: String,
        value: String
    }
}
/// SQL used by `Store` to read and update the `conspirator` table.
impl Conspirator {
    /// Sets the `name` of conspirator `?1` to `?2`.
    const UPDATE_NAME_SQL: &'static str =
        "UPDATE conspirator \
         SET name = ?2 \
         WHERE id = ?1";
    /// Sets the `active` flag of conspirator `?1` to `?2`.
    const UPDATE_ACTIVE_SQL: &'static str =
        "UPDATE conspirator \
         SET active = ?2 \
         WHERE id = ?1";
    /// Raises the `clock` of conspirator `?1` to `?2`. Because the new value
    /// is `max(?2, clock)`, the stored clock is never lowered.
    const UPDATE_CLOCK_SQL: &'static str =
        "UPDATE conspirator \
         SET clock = max(?2, clock) \
         WHERE id = ?1";
    /// Selects the `id` of the conspirator(s) named `?1`.
    ///
    /// NOTE(review): `name` is not declared unique, and implicitly created
    /// rows all have an empty name, so several rows may match. Confirm
    /// which row callers expect.
    const GET_ID_BY_NAME_SQL: &'static str =
        "SELECT id \
         FROM conspirator \
         WHERE name = ?1";
    /// Selects every column of conspirator `?1`, followed by the number of
    /// `peeraddress` rows that reference it.
    const GET_BY_ID_WITH_ADDR_COUNT_SQL: &'static str =
        "SELECT conspirator.*, ( \
         SELECT COUNT(*) \
         FROM peeraddress \
         WHERE conspirator.id == peeraddress.conspirator \
         ) \
         FROM conspirator \
         WHERE id = ?1";
    /// Same shape as `GET_BY_ID_WITH_ADDR_COUNT_SQL`, for every conspirator.
    const GET_WITH_ADDR_COUNT_SQL: &'static str =
        "SELECT conspirator.*, ( \
         SELECT COUNT(*) \
         FROM peeraddress \
         WHERE conspirator.id == peeraddress.conspirator \
         ) \
         FROM conspirator";
}
/// SQL used by `Store` to maintain and query the `peeraddress` table.
impl PeerAddress {
    /// Adds address `?2` for conspirator `?1` and gives it the next free
    /// per-conspirator `id` (`max(id) + 1`, or `0` for the first address).
    ///
    /// Because of `INSERT OR IGNORE`, re-adding an existing
    /// `(conspirator, address)` pair changes zero rows instead of failing.
    const ADD_OR_IGNORE_SQL: &'static str =
        "INSERT OR IGNORE INTO peeraddress \
         VALUES ( \
         IFNULL( \
         (SELECT max(id) + 1 FROM peeraddress WHERE conspirator = ?1), \
         0 \
         ), \
         ?1, \
         ?2 \
         )";
    /// Selects the address of conspirator `?1` whose `id` equals `?2` modulo
    /// (highest address id + 1), so an increasing `?2` cycles through all of
    /// its addresses. No row is produced when the conspirator has no
    /// addresses: `?2 % 1` is `0`, and no address has id `0`.
    ///
    /// This uses the core `%` operator instead of `mod()`. `mod()` is one of
    /// SQLite's optional math functions and only exists in builds compiled
    /// with `SQLITE_ENABLE_MATH_FUNCTIONS`. It also returns a REAL rather
    /// than an INTEGER.
    const SELECT_REACH_OUT_ADDR_SQL: &'static str =
        "SELECT address \
         FROM peeraddress \
         WHERE conspirator = ?1 \
         AND id = ?2 % (1 + IFNULL( \
         (SELECT max(id) FROM peeraddress WHERE conspirator = ?1), \
         0 \
         ))";
    /// Selects `1` if address `?2` is not yet recorded for conspirator `?1`,
    /// otherwise `0`.
    const NOT_EXISTS_SQL: &'static str =
        "SELECT NOT EXISTS ( \
         SELECT 1 \
         FROM peeraddress \
         WHERE conspirator = ?1 \
         AND address = ?2 \
         )";
}
/// SQL used by `Store` to store, order and select ops.
impl Op {
    /// Inserts a complete op row from 11 positional parameters
    /// (NOTE(review): presumably one per `op` column — keep in sync with
    /// the field layout of `Op`). An op whose `origin` already exists is
    /// ignored; `Store::store_op` detects this from the changed-row count.
    pub const INSERT_OR_IGNORE: &'static str =
        "INSERT OR IGNORE INTO op \
         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
    /// Selects `(sort, op columns)` for every sorted op, from the highest
    /// log position down. Used by `Store::sort_op`.
    const SORT_REV_SQL: &'static str =
        "SELECT logentry.sort, op.* FROM op \
         JOIN logentry ON op.origin_node == logentry.op_node \
         AND op.origin_counter == logentry.op_counter \
         WHERE sort IS NOT NULL \
         ORDER BY sort DESC";
    /// Selects every sorted op in log order. `SELECT *` also returns the
    /// joined `logentry` columns after the op columns.
    pub const SORT_SQL: &'static str =
        "SELECT * FROM op \
         JOIN logentry ON op.origin_node == logentry.op_node \
         AND op.origin_counter == logentry.op_counter \
         WHERE sort IS NOT NULL \
         ORDER BY sort";
    /// Selects the sorted ops at log positions `>= ?1`, in log order.
    const SORT_SINCE_SQL: &'static str =
        "SELECT * FROM op \
         JOIN logentry ON op.origin_node == logentry.op_node \
         AND op.origin_counter == logentry.op_counter \
         WHERE sort >= ?1 \
         ORDER BY sort";
    /// Selects "held" ops: stored in `op` but without a matching log entry.
    const RELEASE_HELD_SQL: &'static str =
        "SELECT * FROM op \
         WHERE NOT EXISTS ( \
         SELECT 1 FROM logentry \
         WHERE op.origin_node == logentry.op_node \
         AND op.origin_counter == logentry.op_counter \
         )";
    /// Selects the ops from node `?1` with a counter greater than `?2`, or
    /// all of that node's ops when `?2` is NULL. Held ops are included,
    /// because there is no join on `logentry`.
    const SELECT_BY_SINCE_SQL: &'static str =
        "SELECT * FROM op \
         WHERE origin_node = ?1 \
         AND (origin_counter > ?2 OR ?2 IS NULL)";
}
/// Renders an op as `Op{ <origin> &(<previous>) → <target>: <action> }`.
/// `none` appears in the parentheses when the op has no predecessor.
impl fmt::Display for Op {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let origin = &self.origin;
        let target = &self.target;
        let a = &self.action;
        if let Some(p) = &self.previous {
            write!(f, "Op{{ {origin} &({p}) → {target}: {a:?} }}")
        } else {
            write!(f, "Op{{ {origin} &(none) → {target}: {a:?} }}")
        }
    }
}
/// `Debug` output is identical to the `Display` form.
impl fmt::Debug for Op {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
impl Version {
pub fn new(node: NodeID, counter: u64) -> Self {
Self { node, counter }
}
}
/// Renders a version as `<node>/<counter>`.
impl fmt::Display for Version {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { node, counter } = self;
        write!(f, "{}/{}", node, counter)
    }
}
/// `Debug` output is identical to the `Display` form (`<node>/<counter>`).
impl fmt::Debug for Version {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}
/// Well-known `config` keys and SQL for the `config` table.
impl Config {
    // NOTE(review): the meaning of these keys is not visible in this file;
    // presumably the node id, its public key and its private key — confirm
    // with the code that reads and writes them.
    pub const ID: &'static str = "id";
    pub const KEY: &'static str = "key";
    pub const PRIVATE_KEY: &'static str = "private_key";
    /// Selects `1` if a config row with key `?1` exists, otherwise `0`.
    pub const EXISTS_SQL: &'static str =
        "SELECT EXISTS ( \
         SELECT 1 \
         FROM config \
         WHERE key = ?1 \
         )";
}
/// Two-pass shift that opens a free slot in the log
/// (see `Store::insert_log_entry`).
///
/// `sort` is UNIQUE, and SQLite checks uniqueness row by row during an
/// UPDATE. A single `sort = sort + 1` could therefore collide with a
/// neighbour that has not been moved yet. Instead, the affected rows are
/// first moved to distinct negative values and then flipped back.
impl LogEntry {
    /// Pass 1: every entry at position `>= ?1` goes to `-(sort + 1)`.
    /// These are all negative and distinct, so nothing collides.
    const PUSH_BACK_SQL_1: &'static str =
        "UPDATE logentry \
         SET sort = -(sort + 1) \
         WHERE sort >= ?1";
    /// Pass 2: negate the temporary values back, so each moved entry ends up
    /// at its old position + 1 and position `?1` is free.
    const PUSH_BACK_SQL_2: &'static str =
        "UPDATE logentry \
         SET sort = abs(sort) \
         WHERE sort < 0";
}
impl Store {
const SELECT_ALL_COUNTERS_SQL: &'static str =
"SELECT origin_node, max(origin_counter) FROM op \
JOIN logentry ON op.origin_node == logentry.op_node \
AND op.origin_counter == logentry.op_counter \
WHERE sort IS NOT NULL \
GROUP BY origin_node";
const SELECT_COUNTER_SQL: &'static str =
"SELECT max(origin_counter) FROM op \
JOIN logentry ON op.origin_node == logentry.op_node \
AND op.origin_counter == logentry.op_counter \
WHERE sort IS NOT NULL \
AND origin_node = ?1";
const SELECT_PREVIOUS_SQL: &'static str =
"SELECT origin_node, origin_counter FROM op \
JOIN logentry ON op.origin_node == logentry.op_node \
AND op.origin_counter == logentry.op_counter \
WHERE sort IS NOT NULL \
ORDER BY sort DESC \
LIMIT 1";
pub fn get_conspirator(&self, id: NodeID)
-> SqlResult<Option<(Conspirator, usize)>>
{
self.query_one_with(Conspirator::GET_BY_ID_WITH_ADDR_COUNT_SQL, &id)
.optional()
}
pub fn get_conspirators(&self) -> SqlResult<Vec<(Conspirator, usize)>> {
self.query_all(Conspirator::GET_WITH_ADDR_COUNT_SQL)
}
pub fn add_conspirator(&self, id: NodeID, name: &str) -> SqlResult<bool> {
self.connection.execute(
Conspirator::INSERT,
&(id, name, true, 0, Status::Unreachable, 0)
).map(|rows_changed| rows_changed == 1)
}
pub fn add_address(&self, id: NodeID, addr: Address) -> SqlResult<bool> {
self.connection.execute(PeerAddress::ADD_OR_IGNORE_SQL, &(id, addr))
.map(|rows_changed| rows_changed == 1)
}
pub fn new_address(&self, id: NodeID, addr: Address) -> SqlResult<bool> {
self.connection.query_one_with(PeerAddress::NOT_EXISTS_SQL, &(id, addr))
}
pub fn reach_out_addr(&self, id: NodeID, idx: usize)
-> SqlResult<Option<Address>>
{
self.query_one_with(PeerAddress::SELECT_REACH_OUT_ADDR_SQL, &(id, idx))
.optional()
}
pub fn find_conspirator_id(&self, name: &str) -> SqlResult<Option<NodeID>> {
self.query_one_with(Conspirator::GET_ID_BY_NAME_SQL, &name)
.optional()
}
pub fn persist_clocks<I>(&self, clocks: I) -> SqlResult<()>
where I: IntoIterator<Item = (NodeID, Clock)>
{
let mut stmt = self.prepare(Conspirator::UPDATE_CLOCK_SQL)?;
for (id, clock) in clocks {
(id, clock).bind_to(&mut stmt)?;
stmt.raw_execute()?;
}
stmt.finalize()?;
Ok(())
}
pub fn vector(&self) -> SqlResult<Vec<Version>> {
self.connection.query_all(Self::SELECT_ALL_COUNTERS_SQL)
}
pub fn sync(&self, with: &[Version]) -> SqlResult<Vec<Op>> {
let mut ops = Vec::new();
for Version {node, counter} in self.vector()? {
let theirs = with.iter()
.find_map(|&v| (v.node == node).then_some(v.counter));
if theirs < Some(counter) {
let needed = self.connection.query_all_with(
Op::SELECT_BY_SINCE_SQL,
&(node, theirs)
)?;
ops.extend(needed);
}
}
Ok(ops)
}
pub fn node_counter(&self, id: NodeID) -> SqlResult<Option<u64>> {
self.connection.query_one_with(Self::SELECT_COUNTER_SQL, &id)
}
pub fn previous(&self) -> SqlResult<Option<Version>> {
self.connection.query_one(Self::SELECT_PREVIOUS_SQL).optional()
}
pub fn create_op(&self, from: NodeID, target: NodeID, action: Action)
-> SqlResult<Op>
{
let previous = self.previous()?;
let counter = self.node_counter(from)?
.map(|c| c + 1)
.unwrap_or(0);
let origin = Version { node: from, counter };
Ok(Op {
origin,
previous,
target,
action
})
}
pub fn absorb_op(&self, op: Op) -> SqlResult<bool> {
if !self.store_op(&op)? {
match self.get(op.origin)? {
Some(existing) if op != existing => error!(
?op,
?existing,
"received op does not match existing"
),
Some(_) => {},
None => error!(
?op,
"failed to store op but also failed to retrieve existing",
)
}
return Ok(false);
}
let mut queue = vec![op];
let mut dirty = None;
while let Some(op) = queue.pop() {
if let Some(idx) = self.try_sort_op(&op)? {
self.insert_log_entry(idx, &op)?;
dirty = Some(std::cmp::min(dirty.unwrap_or(idx), idx));
queue = self.query_all(Op::RELEASE_HELD_SQL)?;
}
}
if let Some(dirty) = dirty {
self.apply_ops_since(dirty)?;
Ok(true)
}
else {Ok(false)}
}
fn store_op(&self, op: &Op) -> SqlResult<bool> {
self.execute(Op::INSERT_OR_IGNORE, op).map(|i| i == 1)
}
fn try_sort_op(&self, new: &Op) -> SqlResult<Option<i64>> {
if let Some(Version {node, counter}) = new.previous {
if self.node_counter(node)?.is_none_or(|c| c < counter) {
return Ok(None)
}
}
let current_counter = self.node_counter(new.origin.node)?;
let next_counter = current_counter.map(|c| c + 1).unwrap_or(0);
(new.origin.counter == next_counter)
.then(|| self.sort_op(new))
.transpose()
}
fn sort_op(&self, new: &Op) -> SqlResult<i64> {
let ops: Vec<(i64, Op)> = self.connection
.query_all(Op::SORT_REV_SQL)?;
let mut upper_bound = ops.len() as i64;
for (idx, op) in ops {
if new.previous == Some(op.origin) {
return Ok(idx + 1);
}
if op.previous == new.previous {
if new.origin.node > op.origin.node {
return Ok(upper_bound);
}
else {
upper_bound = idx;
}
}
}
Ok(0)
}
fn insert_log_entry(&self, idx: i64, op: &Op) -> SqlResult<()> {
let tx = self.unchecked_transaction()?;
let mut stmt = tx.prepare(LogEntry::PUSH_BACK_SQL_1)?;
stmt.execute([idx])?;
stmt.finalize()?;
let mut stmt = tx.prepare(LogEntry::PUSH_BACK_SQL_2)?;
stmt.execute([])?;
stmt.finalize()?;
let mut stmt = tx.prepare(LogEntry::UPSERT)?;
(Some(idx), op.origin).bind_to(&mut stmt)?;
let inserted = stmt.raw_execute()?;
if inserted != 1 {
println!("raw_execute for entry insert returned {inserted}!");
}
stmt.finalize()?;
tx.commit()
}
fn apply_ops_since(&self, index: i64) -> SqlResult<()> {
for op in self.query_all_with(Op::SORT_SINCE_SQL, &index)? {
self.apply_op(&op)?;
}
Ok(())
}
fn apply_op(&self, op: &Op) -> SqlResult<()> {
if self.get::<Conspirator>(op.target)?.is_none() {
self.add_conspirator(op.target, "")?;
}
match &op.action {
Action::Name(name) => {
self.execute(Conspirator::UPDATE_NAME_SQL, &(op.target, name))?;
}
Action::Active(a) => {
self.execute(Conspirator::UPDATE_ACTIVE_SQL, &(op.target, a))?;
}
Action::AddAddress(address) => {
self.add_address(op.target, *address)?;
}
Action::Write { key, value } if value.is_empty() => {
self.execute(Pair::DELETE, &key)?;
}
Action::Write { key, value } => {
self.execute(Pair::UPSERT, &(key, value))?;
}
}
Ok(())
}
pub fn hash(&self) -> SqlResult<u64> {
let ops: Vec<Op> = self.connection.query_all(Op::SORT_SQL)?;
let mut hasher = FnvHasher::default();
ops.hash(&mut hasher);
Ok(hasher.finish())
}
pub fn read_config(&self, key: &str) -> SqlResult<Option<String>> {
self.query_one_with("SELECT val FROM config WHERE key = ?1", &key)
.optional()
}
pub fn write_config(&self, key: &str, val: &str) -> SqlResult<bool> {
self.execute(Config::UPSERT, &(key, val)).map(|i| i == 1)
}
pub fn read_value(&self, key: &str) -> SqlResult<Option<String>> {
self.query_one_with("SELECT value FROM pair WHERE key = ?1", &key)
.optional()
}
}
impl Inline for Pair {
fn inline(&self, buffer: &mut Buffer) {
buffer.push_bold(&self.key)
.push("=")
.push(&self.value);
}
}
/// Basic insertion invariants for conspirators and their addresses, plus a
/// smoke test of `Store::hash` on an empty log.
#[test]
fn consistency() -> SqlResult<()> {
    let store = Store::create_in_memory()?;
    let node = NodeID::random();
    let dummy = Address::Dummy(123);

    // No conspirator row exists yet, so attaching an address must fail
    // (presumably rejected through the `Ref<Conspirator>` reference).
    store.add_address(node, dummy).unwrap_err();

    // The first insert succeeds; inserting the same id again is an error.
    let inserted = store.add_conspirator(node, "Test123")?;
    assert!(inserted);
    let duplicate = store.add_conspirator(node, "Test123");
    assert!(duplicate.is_err());

    // A new address is stored once; re-adding it is ignored.
    let first = store.add_address(node, dummy)?;
    let second = store.add_address(node, dummy)?;
    assert!(first && !second);

    store.hash().map(|_digest| ())
}