use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use bytes::Buf;
use compact_str::{CompactString, ToCompactString};
use enquote::unquote;
use fallible_iterator::FallibleIterator;
use indexmap::IndexMap;
use itertools::Itertools;
use rusqlite::{params_from_iter, Connection};
use serde::{Deserialize, Serialize};
use sqlite3_parser::{
ast::{
As, Cmd, Expr, JoinConstraint, Name, OneSelect, Operator, QualifiedName, ResultColumn,
Select, SelectTable, Stmt,
},
lexer::sql::Parser,
};
use tokio::{
sync::{broadcast, mpsc},
task::block_in_place,
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
use crate::{
api::QueryEvent,
change::{Change, SqliteValue, SqliteValueRef},
schema::{NormalizedSchema, NormalizedTable},
};
/// Errors returned by [`normalize_sql`] when the input is not exactly one SQL statement.
#[derive(Debug, thiserror::Error)]
pub enum NormalizeStatementError {
    /// The SQL could not be tokenized/parsed.
    #[error(transparent)]
    Parse(#[from] sqlite3_parser::lexer::sql::Error),
    /// The first command parsed is not a plain statement (any `Cmd` other than `Cmd::Stmt`).
    #[error("unexpected statement: {0}")]
    UnexpectedStatement(Cmd),
    /// More than one command was supplied.
    #[error("only 1 statement is supported")]
    Multiple,
    /// The input contained no command at all.
    #[error("at least 1 statement is required")]
    NoStatement,
}
/// Parses `sql` and re-renders it in the parser's canonical textual form.
///
/// # Errors
///
/// Returns a [`NormalizeStatementError`] if the input cannot be parsed, contains no
/// command, starts with a command that is not a plain statement, or holds more than one
/// command.
pub fn normalize_sql(sql: &str) -> Result<String, NormalizeStatementError> {
    let mut commands = Parser::new(sql.as_bytes());

    let statement = match commands.next()? {
        Some(Cmd::Stmt(statement)) => statement,
        Some(other) => return Err(NormalizeStatementError::UnexpectedStatement(other)),
        None => return Err(NormalizeStatementError::NoStatement),
    };

    // Anything after the first command is rejected.
    match commands.next()? {
        Some(_) => Err(NormalizeStatementError::Multiple),
        None => Ok(Cmd::Stmt(statement).to_string()),
    }
}
/// Errors returned by [`unpack_columns`] while decoding a packed column buffer.
#[derive(Debug, thiserror::Error)]
pub enum UnpackError {
    /// The buffer ended before a declared value could be fully read.
    #[error("abort")]
    Abort,
    /// A column header is malformed (for example, it carries an unknown type tag).
    #[error("misuse")]
    Misuse,
}
/// Type tag stored in the low 3 bits of each packed column header (see [`unpack_columns`]).
#[derive(PartialEq, Debug)]
pub enum ColumnType {
    Integer = 1,
    Float = 2,
    Text = 3,
    Blob = 4,
    Null = 5,
}

impl ColumnType {
    /// Maps a raw type tag to its [`ColumnType`], or `None` for any other value.
    fn from_u8(u: u8) -> Option<Self> {
        match u {
            1 => Some(ColumnType::Integer),
            2 => Some(ColumnType::Float),
            3 => Some(ColumnType::Text),
            4 => Some(ColumnType::Blob),
            5 => Some(ColumnType::Null),
            _ => None,
        }
    }
}
pub fn unpack_columns(mut buf: &[u8]) -> Result<Vec<SqliteValueRef>, UnpackError> {
let mut ret = vec![];
let num_columns = buf.get_u8();
for _i in 0..num_columns {
if !buf.has_remaining() {
return Err(UnpackError::Abort);
}
let column_type_and_maybe_intlen = buf.get_u8();
let column_type = ColumnType::from_u8(column_type_and_maybe_intlen & 0x07);
let intlen = (column_type_and_maybe_intlen >> 3 & 0xFF) as usize;
match column_type {
Some(ColumnType::Blob) => {
if buf.remaining() < intlen {
return Err(UnpackError::Abort);
}
let len = buf.get_int(intlen) as usize;
if buf.remaining() < len {
return Err(UnpackError::Abort);
}
ret.push(SqliteValueRef::Blob(&buf[0..len]));
buf.advance(len);
}
Some(ColumnType::Float) => {
if buf.remaining() < 8 {
return Err(UnpackError::Abort);
}
ret.push(SqliteValueRef::Real(buf.get_f64()));
}
Some(ColumnType::Integer) => {
if buf.remaining() < intlen {
return Err(UnpackError::Abort);
}
ret.push(SqliteValueRef::Integer(buf.get_int(intlen)));
}
Some(ColumnType::Null) => {
ret.push(SqliteValueRef::Null);
}
Some(ColumnType::Text) => {
if buf.remaining() < intlen {
return Err(UnpackError::Abort);
}
let len = buf.get_int(intlen) as usize;
if buf.remaining() < len {
return Err(UnpackError::Abort);
}
ret.push(SqliteValueRef::Text(unsafe {
std::str::from_utf8_unchecked(&buf[0..len])
}));
buf.advance(len);
}
None => return Err(UnpackError::Misuse),
}
}
Ok(ret)
}
/// Kind of change carried by a [`QueryEvent::Row`] emitted for a matcher's result set.
///
/// Serialized in snake_case (`"upsert"` / `"delete"`).
#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ChangeType {
    /// The row was inserted into, or updated in, the query's result set.
    Upsert,
    /// The row was removed from the query's result set.
    Delete,
}
/// Commands processed by a matcher's background task.
pub enum MatcherCmd {
    /// Re-evaluate the query for these candidate primary keys (unpacked column values),
    /// keyed by table name.
    ProcessChange(HashMap<CompactString, Vec<Vec<SqliteValue>>>),
    /// Stop the task if the broadcast channel has no receivers left (presumably sent when a
    /// subscriber goes away — confirm with callers).
    Unsubscribe,
}
/// Cheaply cloneable handle to a live query matcher; all clones share one [`InnerMatcher`].
#[derive(Debug, Clone)]
pub struct Matcher(pub Arc<InnerMatcher>);
/// SQL used to compute incremental diffs for one source table of a matcher's query.
#[derive(Debug, Clone)]
pub struct MatcherStmt {
    /// The rewritten query, additionally restricted to the primary keys staged in
    /// `watch_{id}_{table}`, rendered with its last character (the terminator) removed.
    new_query: String,
    /// Selects the currently materialized rows of the query table whose primary keys are
    /// staged in `watch_{id}_{table}`.
    temp_query: String,
}
/// Shared state of a [`Matcher`].
#[derive(Debug)]
pub struct InnerMatcher {
    /// Matcher identifier; its simple form names the `query_{id}` and `watch_{id}_{table}`
    /// tables.
    pub id: Uuid,
    /// The user's `SELECT`, rewritten to project every table's primary-key columns (aliased
    /// `__corro_pk_{table-or-alias}_{pk}`) ahead of the `col_{i}`-aliased result columns.
    pub query: Stmt,
    /// Per source table, the SQL used to diff that table's candidate rows.
    pub statements: HashMap<String, MatcherStmt>,
    /// Table name -> aliases of its primary-key columns in `query` and the query table.
    pub pks: IndexMap<String, Vec<String>>,
    /// Analysis of the original query.
    pub parsed: ParsedSelect,
    /// Unqualified name of the results table (`query_{id}`).
    pub query_table: String,
    /// `query_table` qualified with the `watches` schema.
    pub qualified_table_name: String,
    /// Broadcasts incremental row events to subscribers.
    pub change_tx: broadcast::Sender<QueryEvent>,
    /// Sends commands to the matcher's background task.
    pub cmd_tx: mpsc::Sender<MatcherCmd>,
    /// Result column names reported by SQLite for the original query.
    pub col_names: Vec<CompactString>,
    /// Stops the background task; the task also cancels it when it ends.
    pub cancel: CancellationToken,
}
impl Matcher {
pub fn new(
id: Uuid,
schema: &NormalizedSchema,
mut conn: Connection,
init_tx: mpsc::Sender<QueryEvent>,
change_tx: broadcast::Sender<QueryEvent>,
sql: &str,
cancel: CancellationToken,
) -> Result<Self, MatcherError> {
let col_names: Vec<CompactString> = {
conn.prepare(sql)?
.column_names()
.into_iter()
.map(|s| s.to_compact_string())
.collect()
};
let mut parser = Parser::new(sql.as_bytes());
let (stmt, parsed) = match parser.next()?.ok_or(MatcherError::StatementRequired)? {
Cmd::Stmt(stmt) => {
let parsed = match stmt {
Stmt::Select(ref select) => extract_select_columns(select, schema)?,
_ => return Err(MatcherError::UnsupportedStatement),
};
(stmt, parsed)
}
_ => return Err(MatcherError::StatementRequired),
};
if parsed.table_columns.is_empty() {
return Err(MatcherError::TableRequired);
}
let mut statements = HashMap::new();
let mut pks = IndexMap::default();
let mut stmt = stmt.clone();
match &mut stmt {
Stmt::Select(select) => match &mut select.body.select {
OneSelect::Select { columns, .. } => {
let mut new_cols = parsed
.table_columns
.iter()
.filter_map(|(tbl_name, _cols)| {
schema.tables.get(tbl_name).map(|table| {
let tbl_name = parsed
.aliases
.iter()
.find_map(|(alias, actual)| {
(actual == tbl_name).then_some(alias)
})
.unwrap_or(tbl_name);
table
.pk
.iter()
.map(|pk| {
let alias = format!("__corro_pk_{tbl_name}_{pk}");
let entry: &mut Vec<String> =
pks.entry(table.name.clone()).or_default();
entry.push(alias.clone());
ResultColumn::Expr(
Expr::Qualified(
Name(tbl_name.clone()),
Name(pk.clone()),
),
Some(As::As(Name(alias))),
)
})
.collect::<Vec<_>>()
})
})
.flatten()
.collect::<Vec<_>>();
new_cols.append(&mut parsed.columns.clone());
*columns = new_cols;
}
_ => unreachable!(),
},
_ => unreachable!(),
}
let query_table = format!("query_{}", id.as_simple());
for (tbl_name, _cols) in parsed.table_columns.iter() {
let expr = table_to_expr(
&parsed.aliases,
schema
.tables
.get(tbl_name)
.expect("this should not happen, missing table in schema"),
&tbl_name,
id,
)?;
let mut stmt = stmt.clone();
match &mut stmt {
Stmt::Select(select) => match &mut select.body.select {
OneSelect::Select { where_clause, .. } => {
*where_clause = if let Some(prev) = where_clause.take() {
Some(Expr::Binary(Box::new(expr), Operator::And, Box::new(prev)))
} else {
Some(expr)
};
}
_ => {}
},
_ => {}
}
let mut new_query = Cmd::Stmt(stmt).to_string();
new_query.pop();
let mut tmp_cols = pks.values().cloned().flatten().collect::<Vec<String>>();
for i in 0..(parsed.columns.len()) {
tmp_cols.push(format!("col_{i}"));
}
let pk_cols = pks
.get(tbl_name)
.cloned()
.ok_or(MatcherError::MissingPrimaryKeys)?
.iter()
.cloned()
.collect::<Vec<_>>()
.join(",");
statements.insert(
tbl_name.clone(),
MatcherStmt {
new_query,
temp_query: format!(
"SELECT {} FROM {} WHERE ({}) IN watch_{}_{}",
tmp_cols.join(","),
query_table,
pk_cols,
id.as_simple(),
tbl_name,
),
},
);
}
let (cmd_tx, mut cmd_rx) = mpsc::channel(512);
let matcher = Self(Arc::new(InnerMatcher {
id,
query: stmt,
statements: statements,
pks,
parsed,
qualified_table_name: format!("watches.{query_table}"),
query_table,
change_tx,
cmd_tx,
col_names: col_names.clone(),
cancel: cancel.clone(),
}));
let mut tmp_cols = matcher
.0
.pks
.values()
.flatten()
.cloned()
.collect::<Vec<String>>();
for i in 0..(matcher.0.parsed.columns.len()) {
tmp_cols.push(format!("col_{i}"));
}
let create_temp_table = format!(
"CREATE TABLE {} (__corro_rowid INTEGER PRIMARY KEY AUTOINCREMENT, {});
CREATE UNIQUE INDEX watches.index_{}_pk ON {} ({});",
matcher.0.qualified_table_name,
tmp_cols.join(","),
matcher.0.id.as_simple(),
matcher.0.query_table,
matcher
.0
.pks
.values()
.flatten()
.cloned()
.collect::<Vec<_>>()
.join(","),
);
conn.execute_batch(&create_temp_table)?;
tokio::spawn({
let matcher = matcher.clone();
async move {
let _drop_guard = cancel.clone().drop_guard();
if let Err(e) = init_tx.send(QueryEvent::Columns(col_names)).await {
error!("could not send back columns, probably means no receivers! {e}");
return;
}
let mut query_cols = vec![];
for i in 0..(matcher.0.parsed.columns.len()) {
query_cols.push(format!("col_{i}"));
}
let res = block_in_place(|| {
let tx = conn.transaction()?;
let mut stmt_str = Cmd::Stmt(matcher.0.query.clone()).to_string();
stmt_str.pop();
let insert_into = format!(
"INSERT INTO {} ({}) {} RETURNING __corro_rowid,{}",
matcher.0.qualified_table_name,
tmp_cols.join(","),
stmt_str,
query_cols.join(","),
);
{
let mut prepped = tx.prepare(&insert_into)?;
let mut rows = prepped.query(())?;
loop {
match rows.next() {
Ok(Some(row)) => {
let rowid: i64 = row.get(0)?;
let cells = (1..=query_cols.len())
.map(|i| row.get::<_, SqliteValue>(i))
.collect::<rusqlite::Result<Vec<_>>>()?;
if let Err(e) = init_tx.blocking_send(QueryEvent::Row {
change_type: ChangeType::Upsert,
rowid,
cells,
}) {
error!("could not send back row: {e}");
return Err(MatcherError::ChangeReceiverClosed);
}
}
Ok(None) => {
break;
}
Err(e) => {
return Err(e.into());
}
}
}
}
tx.commit()?;
Ok::<_, MatcherError>(())
});
if let Err(e) = res {
_ = init_tx.send(QueryEvent::Error(e.to_compact_string())).await;
return;
}
if let Err(e) = init_tx.send(QueryEvent::EndOfQuery).await {
error!("could not send back end-of-query message: {e}");
return;
}
drop(init_tx);
loop {
let req = tokio::select! {
Some(req) = cmd_rx.recv() => req,
_ = cancel.cancelled() => return,
else => return,
};
match req {
MatcherCmd::ProcessChange(candidates) => {
if let Err(e) =
block_in_place(|| matcher.handle_change(&mut conn, candidates))
{
if matches!(e, MatcherError::ChangeReceiverClosed) {
continue;
}
error!("could not handle change: {e}");
}
}
MatcherCmd::Unsubscribe => {
if matcher.0.change_tx.receiver_count() == 0 {
info!(
"matcher {} has no more subscribers, we're done!",
matcher.0.id
);
break;
}
}
}
}
debug!(id = %id, "matcher loop is done");
if let Err(e) =
conn.execute_batch(&format!("DROP TABLE {}", matcher.0.qualified_table_name))
{
warn!(
"could not clean up temporary table {} => {e}",
matcher.0.qualified_table_name
);
}
}
});
Ok(matcher)
}
pub fn cmd_tx(&self) -> &mpsc::Sender<MatcherCmd> {
&self.0.cmd_tx
}
pub fn process_change<'a>(&self, changes: &[Change]) -> Result<(), MatcherError> {
let mut candidates: HashMap<CompactString, Vec<Vec<SqliteValue>>> = HashMap::new();
let grouped = changes
.iter()
.filter(|change| {
self.0
.parsed
.table_columns
.contains_key(change.table.as_str())
})
.group_by(|change| (change.table.as_str(), change.pk.as_slice()));
for ((table, pk), _) in grouped.into_iter() {
let pks = unpack_columns(pk)?
.into_iter()
.map(|v| v.to_owned())
.collect();
if let Some(v) = candidates.get_mut(table) {
v.push(pks);
} else {
candidates.insert(table.to_compact_string(), vec![pks]);
}
}
self.0
.cmd_tx
.try_send(MatcherCmd::ProcessChange(candidates))
.map_err(|_| MatcherError::ChangeQueueClosedOrFull)?;
Ok(())
}
pub fn table_name(&self) -> &str {
&self.0.qualified_table_name
}
pub fn handle_change(
&self,
conn: &mut Connection,
candidates: HashMap<CompactString, Vec<Vec<SqliteValue>>>,
) -> Result<(), MatcherError> {
let tx = conn.transaction()?;
let tables = candidates.keys().cloned().collect::<Vec<_>>();
for (table, pks) in candidates {
tx.prepare_cached(&format!(
"CREATE TEMP TABLE watch_{}_{} ({})",
self.0.id.as_simple(),
table,
self.0
.pks
.get(table.as_str())
.ok_or_else(|| MatcherError::MissingPrimaryKeys)?
.iter()
.map(|s| s.clone())
.collect::<Vec<_>>()
.join(",")
))?
.execute(())?;
for pks in pks {
tx.prepare_cached(&format!(
"INSERT INTO watch_{}_{} VALUES ({})",
self.0.id.as_simple(),
table,
(0..pks.len()).map(|_i| "?").collect::<Vec<_>>().join(",")
))?
.execute(params_from_iter(pks))?;
}
}
for table in tables.iter() {
let stmt = match self.0.statements.get(table.as_str()) {
Some(stmt) => stmt,
None => {
continue;
}
};
let mut actual_cols = vec![];
let mut tmp_cols = self
.0
.pks
.values()
.cloned()
.flatten()
.collect::<Vec<String>>();
for i in 0..(self.0.parsed.columns.len()) {
let col_name = format!("col_{i}");
tmp_cols.push(col_name.clone());
actual_cols.push(col_name);
}
let sql = format!(
"INSERT INTO {} ({})
SELECT * FROM (
{}
EXCEPT
{}
) WHERE 1
ON CONFLICT({})
DO UPDATE SET
{}
RETURNING __corro_rowid,{}",
self.0.qualified_table_name,
tmp_cols.join(","),
stmt.new_query,
stmt.temp_query,
self.0
.pks
.values()
.cloned()
.flatten()
.collect::<Vec<String>>()
.join(","),
(0..(self.0.parsed.columns.len()))
.map(|i| format!("col_{i} = excluded.col_{i}"))
.collect::<Vec<_>>()
.join(","),
actual_cols.join(",")
);
let insert_prepped = tx.prepare_cached(&sql)?;
let sql = format!(
"
DELETE FROM {} WHERE ({}) in (SELECT {} FROM (
{}
EXCEPT
{}
)) RETURNING __corro_rowid,{}",
self.0.qualified_table_name,
self.0
.pks
.values()
.cloned()
.flatten()
.collect::<Vec<String>>()
.join(","),
self.0
.pks
.values()
.cloned()
.flatten()
.collect::<Vec<String>>()
.join(","),
stmt.temp_query,
stmt.new_query,
actual_cols.join(",")
);
let delete_prepped = tx.prepare_cached(&sql)?;
for (change_type, mut prepped) in [
(ChangeType::Upsert, insert_prepped),
(ChangeType::Delete, delete_prepped),
] {
let col_count = prepped.column_count();
let mut rows = prepped.raw_query();
while let Ok(Some(row)) = rows.next() {
let rowid: i64 = row.get(0)?;
match (1..col_count)
.map(|i| row.get::<_, SqliteValue>(i))
.collect::<rusqlite::Result<Vec<_>>>()
{
Ok(cells) => {
if let Err(e) = self.0.change_tx.send(QueryEvent::Row {
rowid,
change_type,
cells,
}) {
error!("could not send back row to matcher sub sender: {e}");
return Err(MatcherError::ChangeReceiverClosed);
}
}
Err(e) => {
error!("could not deserialize row's cells: {e}");
return Ok(());
}
}
}
}
}
for table in tables {
tx.prepare_cached(&format!(
"DROP TABLE watch_{}_{}",
self.0.id.as_simple(),
table
))?
.execute(())?;
}
tx.commit()?;
Ok(())
}
pub fn subscribe(&self) -> broadcast::Receiver<QueryEvent> {
self.0.change_tx.subscribe()
}
pub fn receiver_count(&self) -> usize {
self.0.change_tx.receiver_count()
}
pub fn cancel(&self) -> CancellationToken {
self.0.cancel.clone()
}
}
/// Result of analysing a `SELECT`: referenced tables/columns, aliases and projected columns.
#[derive(Debug, Default)]
pub struct ParsedSelect {
    /// Table name -> names of its columns referenced by the query (insertion-ordered).
    table_columns: IndexMap<String, HashSet<String>>,
    /// Alias -> actual table name.
    aliases: HashMap<String, String>,
    /// The query's result columns, each re-aliased as `col_{i}`, with `*` / `t.*` expanded.
    pub columns: Vec<ResultColumn>,
    /// Analyses of sub-selects found in expressions (`EXISTS`, `IN (SELECT ...)`, scalar
    /// subqueries).
    children: Vec<Box<ParsedSelect>>,
}
fn extract_select_columns(
select: &Select,
schema: &NormalizedSchema,
) -> Result<ParsedSelect, MatcherError> {
let mut parsed = ParsedSelect::default();
match select.body.select {
OneSelect::Select {
ref from,
ref columns,
ref where_clause,
..
} => {
let from_table = match from {
Some(from) => {
let from_table = match &from.select {
Some(table) => match table.as_ref() {
SelectTable::Table(name, alias, _) => {
if schema.tables.contains_key(name.name.0.as_str()) {
if let Some(As::As(alias) | As::Elided(alias)) = alias {
parsed.aliases.insert(alias.0.clone(), name.name.0.clone());
} else if let Some(ref alias) = name.alias {
parsed.aliases.insert(alias.0.clone(), name.name.0.clone());
}
parsed.table_columns.entry(name.name.0.clone()).or_default();
Some(&name.name)
} else {
return Err(MatcherError::TableNotFound(name.name.0.clone()));
}
}
t => {
warn!("ignoring {t:?}");
None
}
},
_ => {
unreachable!()
}
};
if let Some(ref joins) = from.joins {
for join in joins.iter() {
let tbl_name = match &join.table {
SelectTable::Table(name, alias, _) => {
if let Some(As::As(alias) | As::Elided(alias)) = alias {
parsed.aliases.insert(alias.0.clone(), name.name.0.clone());
} else if let Some(ref alias) = name.alias {
parsed.aliases.insert(alias.0.clone(), name.name.0.clone());
}
parsed.table_columns.entry(name.name.0.clone()).or_default();
&name.name
}
t => {
warn!("ignoring JOIN's non-SelectTable::Table: {t:?}");
continue;
}
};
if let Some(constraint) = &join.constraint {
match constraint {
JoinConstraint::On(expr) => {
extract_expr_columns(expr, schema, &mut parsed)?;
}
JoinConstraint::Using(names) => {
let entry = parsed
.table_columns
.entry(tbl_name.0.clone())
.or_default();
for name in names.iter() {
entry.insert(name.0.clone());
}
}
}
}
}
}
if let Some(expr) = where_clause {
extract_expr_columns(expr, schema, &mut parsed)?;
}
from_table
}
_ => None,
};
extract_columns(columns.as_slice(), from_table, schema, &mut parsed)?;
}
_ => {}
}
Ok(parsed)
}
fn extract_expr_columns(
expr: &Expr,
schema: &NormalizedSchema,
parsed: &mut ParsedSelect,
) -> Result<(), MatcherError> {
match expr {
Expr::Qualified(tblname, colname) => {
let resolved_name = parsed.aliases.get(&tblname.0).unwrap_or(&tblname.0);
parsed
.table_columns
.entry(resolved_name.clone())
.or_default()
.insert(colname.0.clone());
}
Expr::DoublyQualified(schema_name, tblname, colname) if schema_name.0 == "main" => {
let resolved_name = parsed.aliases.get(&tblname.0).unwrap_or(&tblname.0);
parsed
.table_columns
.entry(resolved_name.clone())
.or_default()
.insert(colname.0.clone());
}
Expr::Name(colname) => {
let check_col_name = unquote(&colname.0).ok().unwrap_or(colname.0.clone());
let mut found = None;
for tbl in parsed.table_columns.keys() {
if let Some(tbl) = schema.tables.get(tbl) {
if tbl.columns.contains_key(&check_col_name) {
if found.is_some() {
return Err(MatcherError::QualificationRequired {
col_name: check_col_name.clone(),
});
}
found = Some(tbl.name.as_str());
}
}
}
if let Some(found) = found {
parsed
.table_columns
.entry(found.to_owned())
.or_default()
.insert(check_col_name.clone());
} else {
return Err(MatcherError::TableForColumnNotFound {
col_name: check_col_name.clone(),
});
}
}
Expr::Id(colname) => {
let check_col_name = unquote(&colname.0).ok().unwrap_or(colname.0.clone());
let mut found = None;
for tbl in parsed.table_columns.keys() {
if let Some(tbl) = schema.tables.get(tbl) {
if tbl.columns.contains_key(&check_col_name) {
if found.is_some() {
return Err(MatcherError::QualificationRequired {
col_name: check_col_name.clone(),
});
}
found = Some(tbl.name.as_str());
}
}
}
if let Some(found) = found {
parsed
.table_columns
.entry(found.to_owned())
.or_default()
.insert(colname.0.clone());
} else {
if colname.0.starts_with('"') {
return Ok(());
}
return Err(MatcherError::TableForColumnNotFound {
col_name: colname.0.clone(),
});
}
}
Expr::Between { lhs, .. } => extract_expr_columns(lhs, schema, parsed)?,
Expr::Binary(lhs, _, rhs) => {
extract_expr_columns(lhs, schema, parsed)?;
extract_expr_columns(rhs, schema, parsed)?;
}
Expr::Case {
base,
when_then_pairs,
else_expr,
} => {
if let Some(expr) = base {
extract_expr_columns(expr, schema, parsed)?;
}
for (when_expr, _then_expr) in when_then_pairs.iter() {
extract_expr_columns(when_expr, schema, parsed)?;
}
if let Some(expr) = else_expr {
extract_expr_columns(expr, schema, parsed)?;
}
}
Expr::Cast { expr, .. } => extract_expr_columns(expr, schema, parsed)?,
Expr::Collate(expr, _) => extract_expr_columns(expr, schema, parsed)?,
Expr::Exists(select) => {
parsed
.children
.push(Box::new(extract_select_columns(select, schema)?));
}
Expr::FunctionCall { args, .. } => {
if let Some(args) = args {
for expr in args.iter() {
extract_expr_columns(expr, schema, parsed)?;
}
}
}
Expr::InList { lhs, rhs, .. } => {
extract_expr_columns(lhs, schema, parsed)?;
if let Some(rhs) = rhs {
for expr in rhs.iter() {
extract_expr_columns(expr, schema, parsed)?;
}
}
}
Expr::InSelect { lhs, rhs, .. } => {
extract_expr_columns(lhs, schema, parsed)?;
parsed
.children
.push(Box::new(extract_select_columns(rhs, schema)?));
}
expr @ Expr::InTable { .. } => {
return Err(MatcherError::UnsupportedExpr { expr: expr.clone() })
}
Expr::IsNull(expr) => {
extract_expr_columns(expr, schema, parsed)?;
}
Expr::Like { lhs, rhs, .. } => {
extract_expr_columns(lhs, schema, parsed)?;
extract_expr_columns(rhs, schema, parsed)?;
}
Expr::NotNull(expr) => {
extract_expr_columns(expr, schema, parsed)?;
}
Expr::Parenthesized(parens) => {
for expr in parens.iter() {
extract_expr_columns(expr, schema, parsed)?;
}
}
Expr::Subquery(select) => {
parsed
.children
.push(Box::new(extract_select_columns(select, schema)?));
}
Expr::Unary(_, expr) => {
extract_expr_columns(expr, schema, parsed)?;
}
_ => {}
}
Ok(())
}
/// Appends the query's result columns to `parsed.columns`, aliasing each as `col_{i}`.
///
/// Column expansion:
/// - `*` is expanded to the columns of the FROM table `from`.
/// - `t.*` is expanded to the columns of `t`, resolving aliases.
/// - Columns referenced by result expressions are recorded in `parsed.table_columns`.
///
/// # Errors
///
/// - [`MatcherError::TableRequired`] for `*` without a FROM table (e.g. `FROM (SELECT ...)`).
/// - [`MatcherError::TableStarNotFound`] when the table to expand is not in `schema`.
/// - Errors from [`extract_expr_columns`] are propagated.
fn extract_columns(
    columns: &[ResultColumn],
    from: Option<&Name>,
    schema: &NormalizedSchema,
    parsed: &mut ParsedSelect,
) -> Result<(), MatcherError> {
    // Running index used to alias every projected column as `col_{i}`.
    let mut i = 0;
    for col in columns.iter() {
        match col {
            ResultColumn::Expr(expr, _) => {
                extract_expr_columns(expr, schema, parsed)?;
                parsed.columns.push(ResultColumn::Expr(
                    expr.clone(),
                    Some(As::As(Name(format!("col_{i}")))),
                ));
                i += 1;
            }
            ResultColumn::Star => {
                // `*` is only expanded against a plain FROM table. Without one (e.g. a
                // sub-select in FROM) report an error instead of panicking.
                let tbl_name = match from {
                    Some(tbl_name) => tbl_name,
                    None => return Err(MatcherError::TableRequired),
                };
                if let Some(table) = schema.tables.get(&tbl_name.0) {
                    let entry = parsed.table_columns.entry(table.name.clone()).or_default();
                    for col in table.columns.keys() {
                        entry.insert(col.clone());
                        parsed.columns.push(ResultColumn::Expr(
                            Expr::Name(Name(col.clone())),
                            Some(As::As(Name(format!("col_{i}")))),
                        ));
                        i += 1;
                    }
                } else {
                    return Err(MatcherError::TableStarNotFound {
                        tbl_name: tbl_name.0.clone(),
                    });
                }
            }
            ResultColumn::TableStar(tbl_name) => {
                let name = parsed
                    .aliases
                    .get(tbl_name.0.as_str())
                    .unwrap_or(&tbl_name.0);
                if let Some(table) = schema.tables.get(name) {
                    let entry = parsed.table_columns.entry(table.name.clone()).or_default();
                    for col in table.columns.keys() {
                        entry.insert(col.clone());
                        // Keep the qualifier as written (alias or table name).
                        parsed.columns.push(ResultColumn::Expr(
                            Expr::Qualified(tbl_name.clone(), Name(col.clone())),
                            Some(As::As(Name(format!("col_{i}")))),
                        ));
                        i += 1;
                    }
                } else {
                    return Err(MatcherError::TableStarNotFound {
                        tbl_name: name.clone(),
                    });
                }
            }
        }
    }
    Ok(())
}
/// Builds `(<t>.<pk1>, <t>.<pk2>, ...) IN watch_{id}_{table}` over `tbl`'s primary key.
///
/// `<t>` is the alias that `aliases` maps to `table` when there is one, because SQLite
/// requires an aliased table to be referenced by its alias; otherwise it is `table` itself.
/// The watch table name always uses the real table name. Currently always returns `Ok`.
fn table_to_expr(
    aliases: &HashMap<String, String>,
    tbl: &NormalizedTable,
    table: &str,
    id: Uuid,
) -> Result<Expr, MatcherError> {
    let qualifier = match aliases.iter().find(|(_, actual)| *actual == table) {
        Some((alias, _)) => alias.clone(),
        None => table.to_owned(),
    };

    let pk_columns: Vec<Expr> = tbl
        .pk
        .iter()
        .map(|pk| Expr::Qualified(Name(qualifier.clone()), Name(pk.to_owned())))
        .collect();
    let watch_table = Name(format!("watch_{}_{table}", id.as_simple()));

    Ok(Expr::in_table(
        Expr::Parenthesized(pk_columns),
        false,
        QualifiedName::single(watch_table),
        None,
    ))
}
/// Errors produced while building or running a [`Matcher`].
#[derive(Debug, thiserror::Error)]
pub enum MatcherError {
    /// The SQL could not be tokenized/parsed.
    #[error(transparent)]
    Lexer(#[from] sqlite3_parser::lexer::sql::Error),
    /// The SQL contained no command, or a command other than a plain statement.
    #[error("one statement is required for matching")]
    StatementRequired,
    /// The statement is not a `SELECT`.
    #[error("unsupported statement")]
    UnsupportedStatement,
    /// The query references no table.
    #[error("at least 1 table is required in FROM / JOIN clause")]
    TableRequired,
    /// Error reported by SQLite / rusqlite.
    #[error(transparent)]
    Sqlite(#[from] rusqlite::Error),
    /// A table referenced by the query is not in the schema.
    #[error("table not found in schema: {0}")]
    TableNotFound(String),
    // NOTE(review): the next three variants are not constructed anywhere in this file.
    /// A table has no primary key.
    #[error("no primary key for table: {0}")]
    NoPrimaryKey(String),
    /// An aggregate query does not select a table's primary key column.
    #[error("aggregate missing primary key {0}.{1}")]
    AggPrimaryKeyMissing(String, String),
    /// A `JOIN ... ON` expression is not supported.
    #[error("JOIN .. ON expression is not supported for join on table '{table}': {expr:?}")]
    JoinOnExprUnsupported { table: String, expr: Expr },
    /// The expression form is not supported (e.g. `IN table`).
    #[error("expression is not supported: {expr:?}")]
    UnsupportedExpr { expr: Expr },
    /// The table named by `*` / `tbl.*` is not in the schema.
    #[error("could not find table for {tbl_name}.* in corrosion's schema")]
    TableStarNotFound { tbl_name: String },
    /// An unqualified column name exists in more than one of the query's tables.
    #[error("<tbl>.{col_name} qualification required for ambiguous column name")]
    QualificationRequired { col_name: String },
    /// An unqualified column name exists in none of the query's tables.
    #[error("could not find table for column {col_name}")]
    TableForColumnNotFound { col_name: String },
    /// Internal invariant violation: a table has no recorded primary-key aliases.
    #[error("missing primary keys, this shouldn't happen")]
    MissingPrimaryKeys,
    /// The matcher's command queue is full or its task has stopped.
    #[error("change queue has been closed or is full")]
    ChangeQueueClosedOrFull,
    /// A row event could not be delivered because no receiver is listening.
    #[error("change receiver is closed")]
    ChangeReceiverClosed,
    /// A packed primary key could not be decoded.
    #[error(transparent)]
    Unpack(#[from] UnpackError),
}
#[cfg(test)]
mod tests {
    use rusqlite::params;

    use crate::{
        change::{row_to_change, SqliteValue},
        schema::{make_schema_inner, parse_sql},
        sqlite::{setup_conn, CrConn},
    };

    use super::*;

    // Smoke test: a matcher can be created for a simple single-table query. The runtime is
    // multi-threaded because the matcher task calls `block_in_place`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_matcher() -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        let schema_sql = "CREATE TABLE sw (pk TEXT primary key, sandwich TEXT);";
        let schema = parse_sql(schema_sql)?;

        // `"mad"` is a double-quoted identifier that SQLite resolves as a string literal.
        let sql = "SELECT sandwich FROM sw WHERE pk=\"mad\"";

        let cancel = CancellationToken::new();
        let id = Uuid::new_v4();

        let tmpdir = tempfile::tempdir()?;
        let db_path = tmpdir.path().join("test.db");

        let mut conn = CrConn::init(rusqlite::Connection::open(&db_path)?)?;
        setup_conn(
            &mut conn,
            &[(
                tmpdir
                    .path()
                    .join("watches.db")
                    .display()
                    .to_string()
                    .into(),
                "watches".into(),
            )]
            .into(),
        )?;

        {
            let tx = conn.transaction()?;
            make_schema_inner(&tx, &NormalizedSchema::default(), &schema)?;
            tx.commit()?;
        }

        // The matcher gets its own connection to the same database (with `watches` attached).
        let mut matcher_conn = rusqlite::Connection::open(&db_path).expect("could not open conn");
        setup_conn(
            &mut matcher_conn,
            &[(
                tmpdir
                    .path()
                    .join("watches.db")
                    .display()
                    .to_string()
                    .into(),
                "watches".into(),
            )]
            .into(),
        )?;

        let (tx, _rx) = mpsc::channel(1);
        let (change_tx, _change_rx) = broadcast::channel(1);
        let _matcher = Matcher::new(id, &schema, matcher_conn, tx, change_tx, sql, cancel)?;

        Ok(())
    }

    // End-to-end: the initial snapshot is streamed on `init_tx`, then an insert (Upsert event)
    // and a delete (Delete event) are reflected through `process_change` on `change_tx`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_diff() {
        let sql = "SELECT json_object(
'targets', json_array(cs.address||':'||cs.port),
'labels', json_object(
'__metrics_path__', JSON_EXTRACT(cs.meta, '$.path'),
'app', cs.app_name,
'vm_account_id', cs.organization_id,
'instance', cs.instance_id
)
)
FROM consul_services cs
LEFT JOIN machines m ON m.id = cs.instance_id
LEFT JOIN machine_versions mv ON m.id = mv.machine_id AND m.machine_version_id = mv.id
LEFT JOIN machine_version_statuses mvs ON m.id = mvs.machine_id AND m.machine_version_id = mvs.id
WHERE cs.node = 'test-hostname'
AND (mvs.status IS NULL OR mvs.status = 'started')
AND cs.name == 'app-prometheus'";

        let schema_sql = "
CREATE TABLE consul_services (
node TEXT NOT NULL,
id TEXT NOT NULL,
name TEXT NOT NULL DEFAULT '',
tags TEXT NOT NULL DEFAULT '[]',
meta TEXT NOT NULL DEFAULT '{}',
port INTEGER NOT NULL DEFAULT 0,
address TEXT NOT NULL DEFAULT '',
updated_at INTEGER NOT NULL DEFAULT 0,
app_id INTEGER AS (CAST(JSON_EXTRACT(meta, '$.app_id') AS INTEGER)), network_id INTEGER AS (
CAST(JSON_EXTRACT(meta, '$.network_id') AS INTEGER)
), app_name TEXT AS (JSON_EXTRACT(meta, '$.app_name')), instance_id TEXT AS (
COALESCE(
JSON_EXTRACT(meta, '$.machine_id'),
SUBSTR(JSON_EXTRACT(meta, '$.alloc_id'), 1, 8),
CASE
WHEN INSTR(id, '_nomad-task-') = 1 THEN SUBSTR(id, 13, 8)
ELSE NULL
END
)
), organization_id INTEGER AS (
CAST(
JSON_EXTRACT(meta, '$.organization_id') AS INTEGER
)
), protocol TEXT
AS (JSON_EXTRACT(meta, '$.protocol')),
PRIMARY KEY (node, id)
);
CREATE TABLE machines (
id TEXT NOT NULL PRIMARY KEY,
node TEXT NOT NULL DEFAULT '',
name TEXT NOT NULL DEFAULT '',
machine_version_id TEXT NOT NULL DEFAULT '',
app_id INTEGER NOT NULL DEFAULT 0,
organization_id INTEGER NOT NULL DEFAULT 0,
network_id INTEGER NOT NULL DEFAULT 0,
updated_at INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE machine_versions (
machine_id TEXT NOT NULL,
id TEXT NOT NULL DEFAULT '',
config TEXT NOT NULL DEFAULT '{}',
updated_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (machine_id, id)
);
CREATE TABLE machine_version_statuses (
machine_id TEXT NOT NULL,
id TEXT NOT NULL,
status TEXT NOT NULL DEFAULT '',
updated_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (machine_id, id)
);
";

        let schema = parse_sql(schema_sql).unwrap();

        let tmpdir = tempfile::tempdir().unwrap();
        let db_path = tmpdir.path().join("test.db");

        let mut conn =
            CrConn::init(rusqlite::Connection::open(&db_path).expect("could not open conn"))
                .expect("could not init crsql");

        setup_conn(
            &mut conn,
            &[(
                tmpdir
                    .path()
                    .join("watches.db")
                    .display()
                    .to_string()
                    .into(),
                "watches".into(),
            )]
            .into(),
        )
        .unwrap();

        {
            let tx = conn.transaction().unwrap();
            make_schema_inner(&tx, &NormalizedSchema::default(), &schema).unwrap();
            tx.commit().unwrap();
        }

        // Seed data (committed as one transaction): service-1 matches the query;
        // service-2 is excluded by `cs.name == 'app-prometheus'`.
        {
            let tx = conn.transaction().unwrap();
            tx.execute_batch(r#"
INSERT INTO consul_services (node, id, name, address, port, meta) VALUES ('test-hostname', 'service-1', 'app-prometheus', '127.0.0.1', 1, '{"path": "/1", "machine_id": "m-1"}');
INSERT INTO machines (id, machine_version_id) VALUES ('m-1', 'mv-1');
INSERT INTO machine_versions (machine_id, id) VALUES ('m-1', 'mv-1');
INSERT INTO machine_version_statuses (machine_id, id, status) VALUES ('m-1', 'mv-1', 'started');
INSERT INTO consul_services (node, id, name, address, port, meta) VALUES ('test-hostname', 'service-2', 'not-app-prometheus', '127.0.0.1', 1, '{"path": "/1", "machine_id": "m-2"}');
INSERT INTO machines (id, machine_version_id) VALUES ('m-2', 'mv-2');
INSERT INTO machine_versions (machine_id, id) VALUES ('m-2', 'mv-2');
INSERT INTO machine_version_statuses (machine_id, id, status) VALUES ('m-2', 'mv-2', 'started');
"#).unwrap();
            tx.commit().unwrap();
        }

        // Replays the seed changes (db_version 1) into a second database via `crsql_changes`.
        // NOTE(review): `tx` below is never committed, so these inserts are rolled back when it
        // drops, and nothing later reads `conn2`. Confirm whether a `tx.commit()` was intended.
        {
            let mut conn2 = CrConn::init(
                rusqlite::Connection::open(tmpdir.path().join("test2.db"))
                    .expect("could not open conn"),
            )
            .expect("could not init crsql");

            setup_conn(
                &mut conn2,
                &[(
                    tmpdir
                        .path()
                        .join("watches.db")
                        .display()
                        .to_string()
                        .into(),
                    "watches".into(),
                )]
                .into(),
            )
            .unwrap();

            {
                let tx = conn2.transaction().unwrap();
                make_schema_inner(&tx, &NormalizedSchema::default(), &schema).unwrap();
                tx.commit().unwrap();
            }

            let changes = {
                let mut prepped = conn.prepare_cached(r#"SELECT "table", pk, cid, val, col_version, db_version, seq, COALESCE(site_id, crsql_site_id()), cl FROM crsql_changes WHERE site_id IS NULL AND db_version = ? ORDER BY seq ASC"#).unwrap();
                let rows = prepped.query_map([1], row_to_change).unwrap();

                let mut changes = vec![];

                for row in rows {
                    changes.push(row.unwrap());
                }
                changes
            };

            let tx = conn2.transaction().unwrap();

            for change in changes {
                tx.prepare_cached(
                    r#"
INSERT INTO crsql_changes
("table", pk, cid, val, col_version, db_version, seq, site_id, cl)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?)
"#,
                )
                .unwrap()
                .execute(params![
                    change.table.as_str(),
                    change.pk,
                    change.cid.as_str(),
                    &change.val,
                    change.col_version,
                    change.db_version,
                    change.seq,
                    &change.site_id,
                    change.cl
                ])
                .unwrap();
            }
        }

        let cancel = CancellationToken::new();
        let id = Uuid::new_v4();

        let mut matcher_conn = rusqlite::Connection::open(&db_path).expect("could not open conn");
        setup_conn(
            &mut matcher_conn,
            &[(
                tmpdir
                    .path()
                    .join("watches.db")
                    .display()
                    .to_string()
                    .into(),
                "watches".into(),
            )]
            .into(),
        )
        .unwrap();

        {
            let (tx, mut rx) = mpsc::channel(1);
            let (change_tx, mut change_rx) = broadcast::channel(1);
            let matcher =
                Matcher::new(id, &schema, matcher_conn, tx, change_tx, sql, cancel).unwrap();

            // Initial snapshot: columns, the single matching row, then end-of-query.
            assert!(matches!(rx.recv().await.unwrap(), QueryEvent::Columns(_)));

            // `app` / `vm_account_id` are null because the seeded `meta` has neither
            // `app_name` nor `organization_id`.
            let cells = vec![SqliteValue::Text("{\"targets\":[\"127.0.0.1:1\"],\"labels\":{\"__metrics_path__\":\"/1\",\"app\":null,\"vm_account_id\":null,\"instance\":\"m-1\"}}".into())];

            assert_eq!(
                rx.recv().await.unwrap(),
                QueryEvent::Row {
                    rowid: 1,
                    change_type: ChangeType::Upsert,
                    cells
                }
            );

            assert!(matches!(rx.recv().await.unwrap(), QueryEvent::EndOfQuery));

            // A new matching service (db_version 2) must surface as an Upsert event.
            {
                let tx = conn.transaction().unwrap();
                tx.execute_batch(r#"
INSERT INTO consul_services (node, id, name, address, port, meta) VALUES ('test-hostname', 'service-3', 'app-prometheus', '127.0.0.1', 1, '{"path": "/1", "machine_id": "m-3"}');
INSERT INTO machines (id, machine_version_id) VALUES ('m-3', 'mv-3');
INSERT INTO machine_versions (machine_id, id) VALUES ('m-3', 'mv-3');
INSERT INTO machine_version_statuses (machine_id, id, status) VALUES ('m-3', 'mv-3', 'started');
"#).unwrap();
                tx.commit().unwrap();
            }

            let changes = {
                let mut prepped = conn.prepare_cached(r#"SELECT "table", pk, cid, val, col_version, db_version, seq, COALESCE(site_id, crsql_site_id()), cl FROM crsql_changes WHERE site_id IS NULL AND db_version = ? ORDER BY seq ASC"#).unwrap();
                let rows = prepped.query_map([2], row_to_change).unwrap();

                let mut changes = vec![];

                for row in rows {
                    changes.push(row.unwrap());
                }
                changes
            };

            matcher.process_change(&changes.as_slice()).unwrap();

            let cells = vec![SqliteValue::Text("{\"targets\":[\"127.0.0.1:1\"],\"labels\":{\"__metrics_path__\":\"/1\",\"app\":null,\"vm_account_id\":null,\"instance\":\"m-3\"}}".into())];

            assert_eq!(
                change_rx.recv().await.unwrap(),
                QueryEvent::Row {
                    rowid: 2,
                    change_type: ChangeType::Upsert,
                    cells
                }
            );

            // Deleting service-1 (db_version 3) must surface as a Delete event for rowid 1.
            {
                let tx = conn.transaction().unwrap();
                tx.execute_batch(r#"
DELETE FROM consul_services where node = 'test-hostname' AND id = 'service-1';
"#).unwrap();
                tx.commit().unwrap();
            }

            let changes = {
                let mut prepped = conn.prepare_cached(r#"SELECT "table", pk, cid, val, col_version, db_version, seq, COALESCE(site_id, crsql_site_id()), cl FROM crsql_changes WHERE site_id IS NULL AND db_version = ? ORDER BY seq ASC"#).unwrap();
                let rows = prepped.query_map([3], row_to_change).unwrap();

                let mut changes = vec![];

                for row in rows {
                    println!("change: {row:?}");
                    changes.push(row.unwrap());
                }
                changes
            };

            matcher.process_change(&changes.as_slice()).unwrap();

            let cells = vec![SqliteValue::Text("{\"targets\":[\"127.0.0.1:1\"],\"labels\":{\"__metrics_path__\":\"/1\",\"app\":null,\"vm_account_id\":null,\"instance\":\"m-1\"}}".into())];

            assert_eq!(
                change_rx.recv().await.unwrap(),
                QueryEvent::Row {
                    rowid: 1,
                    change_type: ChangeType::Delete,
                    cells
                }
            );
        }
    }
}