use super::error::{MvError, MvState};
use arrow_schema::SchemaRef;
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::VecDeque;
/// Definition and runtime state of a single materialized view.
///
/// Instances are normally created through [`MaterializedView::new`], which
/// derives `operator_id` from `name` and sets the initial `state`.
#[derive(Debug, Clone)]
pub struct MaterializedView {
/// Name of the view; the registry keys views by this value.
pub name: String,
/// SQL text supplied at construction (presumably the defining query — confirm with callers).
pub sql: String,
/// Names of the inputs this view reads from (base tables or other views).
pub sources: Vec<String>,
/// Arrow schema handle supplied at construction (presumably the view's output schema — TODO confirm).
pub schema: SchemaRef,
/// Operator identifier; `new` sets it to `mv_` followed by the view name.
pub operator_id: String,
/// Current lifecycle state; `new` initializes it to `MvState::Running`.
pub state: MvState,
}
impl MaterializedView {
    /// Builds a view definition from its name, SQL text, source names and schema.
    ///
    /// The operator id is derived as `mv_<name>` and the view starts in
    /// `MvState::Running`.
    #[must_use]
    pub fn new(
        name: impl Into<String>,
        sql: impl Into<String>,
        sources: Vec<String>,
        schema: SchemaRef,
    ) -> Self {
        let name: String = name.into();
        Self {
            // Evaluated before `name` is moved into the struct on the next line.
            operator_id: format!("mv_{name}"),
            name,
            sql: sql.into(),
            sources,
            schema,
            state: MvState::Running,
        }
    }

    /// Test-only shortcut: a view with empty SQL and a schema consisting of a
    /// single non-nullable `Int64` column named `value`.
    #[cfg(test)]
    pub fn simple(name: impl Into<String>, sources: Vec<String>) -> Self {
        use arrow_schema::{DataType, Field, Schema};
        use std::sync::Arc;

        let value_field = Field::new("value", DataType::Int64, false);
        let schema = Schema::new(vec![value_field]);
        Self::new(name, "", sources, Arc::new(schema))
    }

    /// Returns `true` when `source` is listed among this view's direct sources.
    #[must_use]
    pub fn depends_on(&self, source: &str) -> bool {
        self.sources
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == source)
    }
}
/// Registry of materialized views and base tables, with the dependency
/// edges between them tracked in both directions.
#[derive(Debug, Default)]
pub struct MvRegistry {
/// Registered views, keyed by view name.
views: FxHashMap<String, MaterializedView>,
/// Names registered as base tables; views may list these as sources.
base_tables: FxHashSet<String>,
/// Forward edges: source name -> names of views that list it as a source.
dependents: FxHashMap<String, FxHashSet<String>>,
/// Reverse edges: view name -> names of the sources it lists.
dependencies: FxHashMap<String, FxHashSet<String>>,
/// Cached ordering of view names, rebuilt by `update_topo_order`;
/// a view appears after every view it depends on.
topo_order: Vec<String>,
}
impl MvRegistry {
    /// Creates an empty registry with no base tables and no views.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `name` as a base table that views may list as a source.
    pub fn register_base_table(&mut self, name: impl Into<String>) {
        self.base_tables.insert(name.into());
    }

    /// Returns `true` if `name` has been registered as a base table.
    #[must_use]
    pub fn is_base_table(&self, name: &str) -> bool {
        self.base_tables.contains(name)
    }

    /// Adds `view` to the registry and refreshes the cached topological order.
    ///
    /// All validation happens before any mutation, so a failed registration
    /// leaves the registry unchanged.
    ///
    /// # Errors
    ///
    /// * `MvError::DuplicateName` if a view with the same name already exists.
    /// * `MvError::SourceNotFound` if a source is neither a registered view
    ///   nor a registered base table.
    /// * `MvError::CycleDetected` if the new view's name is reachable from
    ///   its own sources.
    pub fn register(&mut self, view: MaterializedView) -> Result<(), MvError> {
        if self.views.contains_key(&view.name) {
            return Err(MvError::DuplicateName(view.name.clone()));
        }
        let missing = view.sources.iter().find(|source| {
            !(self.views.contains_key(source.as_str()) || self.is_base_table(source.as_str()))
        });
        if let Some(source) = missing {
            return Err(MvError::SourceNotFound(source.clone()));
        }
        if self.would_create_cycle(&view.name, &view.sources) {
            return Err(MvError::CycleDetected(view.name.clone()));
        }

        // Record each edge in both directions so lookups are cheap either way.
        for source in &view.sources {
            self.dependents
                .entry(source.clone())
                .or_default()
                .insert(view.name.clone());
            self.dependencies
                .entry(view.name.clone())
                .or_default()
                .insert(source.clone());
        }
        self.views.insert(view.name.clone(), view);
        self.update_topo_order();
        Ok(())
    }

    /// Removes the view `name` and returns it.
    ///
    /// # Errors
    ///
    /// * `MvError::ViewNotFound` if no such view is registered.
    /// * `MvError::HasDependents` (carrying the dependent view names) if other
    ///   views still list `name` as a source; use
    ///   [`MvRegistry::unregister_cascade`] to remove those as well.
    pub fn unregister(&mut self, name: &str) -> Result<MaterializedView, MvError> {
        if !self.views.contains_key(name) {
            return Err(MvError::ViewNotFound(name.to_string()));
        }
        let blocking = self
            .dependents
            .get(name)
            .filter(|deps| !deps.is_empty());
        if let Some(deps) = blocking {
            let dep_names: Vec<_> = deps.iter().cloned().collect();
            return Err(MvError::HasDependents(name.to_string(), dep_names));
        }
        self.remove_view(name)
    }

    /// Removes the view `name` together with every view that depends on it,
    /// directly or transitively.
    ///
    /// The returned views are ordered dependents-first, with `name` itself
    /// last. The topological order is rebuilt once, after all removals.
    ///
    /// # Errors
    ///
    /// Returns `MvError::ViewNotFound` if no such view is registered.
    pub fn unregister_cascade(&mut self, name: &str) -> Result<Vec<MaterializedView>, MvError> {
        if !self.views.contains_key(name) {
            return Err(MvError::ViewNotFound(name.to_string()));
        }

        let mut removal_order = Vec::new();
        let mut visited = FxHashSet::default();
        // Seeding with the root guarantees it is only ever queued once (last).
        visited.insert(name.to_string());
        self.collect_dependents_recursive(name, &mut removal_order, &mut visited);
        removal_order.push(name.to_string());

        // Every collected name comes from the dependents graph, which only
        // records registered views; a miss is skipped rather than treated as fatal.
        let removed: Vec<MaterializedView> = removal_order
            .iter()
            .filter_map(|view_name| self.detach_view(view_name))
            .collect();
        self.update_topo_order();
        Ok(removed)
    }

    /// Appends the transitive dependents of `name` to `result` in post-order
    /// (a view is pushed only after everything that depends on it).
    ///
    /// `visited` holds names already scheduled so each view is emitted once.
    fn collect_dependents_recursive(
        &self,
        name: &str,
        result: &mut Vec<String>,
        visited: &mut FxHashSet<String>,
    ) {
        if let Some(deps) = self.dependents.get(name) {
            for dep in deps {
                if visited.insert(dep.clone()) {
                    self.collect_dependents_recursive(dep, result, visited);
                    result.push(dep.clone());
                }
            }
        }
    }

    /// Removes `name` from the view map and from both edge maps without
    /// touching the cached topological order.
    ///
    /// Returns `None` (and changes nothing) if the view is not registered.
    fn detach_view(&mut self, name: &str) -> Option<MaterializedView> {
        let view = self.views.remove(name)?;
        if let Some(sources) = self.dependencies.remove(name) {
            for source in sources {
                if let Some(deps) = self.dependents.get_mut(&source) {
                    deps.remove(name);
                }
            }
        }
        self.dependents.remove(name);
        Some(view)
    }

    /// Removes a single view and rebuilds the topological order.
    ///
    /// Returns `MvError::ViewNotFound` if the view is not registered.
    fn remove_view(&mut self, name: &str) -> Result<MaterializedView, MvError> {
        let view = self
            .detach_view(name)
            .ok_or_else(|| MvError::ViewNotFound(name.to_string()))?;
        self.update_topo_order();
        Ok(view)
    }

    /// Looks up a view by name.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&MaterializedView> {
        self.views.get(name)
    }

    /// Looks up a view by name for in-place modification.
    #[must_use]
    pub fn get_mut(&mut self, name: &str) -> Option<&mut MaterializedView> {
        self.views.get_mut(name)
    }

    /// Returns the cached ordering of view names in which every view appears
    /// after the views it depends on. Base tables are not included.
    #[must_use]
    pub fn topo_order(&self) -> &[String] {
        &self.topo_order
    }

    /// Iterates over the names of views that directly list `source` as a source.
    ///
    /// Yields nothing if `source` has no recorded dependents.
    pub fn get_dependents(&self, source: &str) -> impl Iterator<Item = &str> {
        self.dependents
            .get(source)
            .into_iter()
            .flatten()
            .map(String::as_str)
    }

    /// Iterates over the names of the direct sources of `view`.
    ///
    /// Yields nothing if `view` has no recorded sources.
    pub fn get_dependencies(&self, view: &str) -> impl Iterator<Item = &str> {
        self.dependencies
            .get(view)
            .into_iter()
            .flatten()
            .map(String::as_str)
    }

    /// Number of registered views (base tables are not counted).
    #[must_use]
    pub fn len(&self) -> usize {
        self.views.len()
    }

    /// Returns `true` when no views are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.views.is_empty()
    }

    /// Iterates over all registered views in unspecified order.
    pub fn views(&self) -> impl Iterator<Item = &MaterializedView> {
        self.views.values()
    }

    /// Returns the set of registered base-table names.
    #[must_use]
    pub fn base_tables(&self) -> &FxHashSet<String> {
        &self.base_tables
    }

    /// Returns the views that `name` transitively depends on, dependencies
    /// first, followed by `name` itself when it is a registered view.
    ///
    /// Base tables are visited but never included in the result.
    #[must_use]
    pub fn dependency_chain(&self, name: &str) -> Vec<String> {
        let mut chain = Vec::new();
        let mut visited = FxHashSet::default();
        self.collect_dependencies_recursive(name, &mut chain, &mut visited);
        chain
    }

    /// Post-order walk over `dependencies`: pushes `name` after all of its
    /// sources, and only if `name` is a registered view.
    fn collect_dependencies_recursive(
        &self,
        name: &str,
        result: &mut Vec<String>,
        visited: &mut FxHashSet<String>,
    ) {
        if !visited.insert(name.to_string()) {
            return;
        }
        if let Some(deps) = self.dependencies.get(name) {
            for dep in deps {
                self.collect_dependencies_recursive(dep, result, visited);
            }
        }
        if self.views.contains_key(name) {
            result.push(name.to_string());
        }
    }

    /// Returns `true` if `new_name` is reachable from any of `sources` by
    /// following recorded dependency edges, i.e. registering a view called
    /// `new_name` over those sources would close a cycle.
    fn would_create_cycle(&self, new_name: &str, sources: &[String]) -> bool {
        // Borrowed names are enough here; nothing outlives the registry borrow.
        let mut visited: FxHashSet<&str> = FxHashSet::default();
        let mut stack: Vec<&str> = sources.iter().map(String::as_str).collect();
        while let Some(current) = stack.pop() {
            if current == new_name {
                return true;
            }
            if visited.insert(current) {
                if let Some(deps) = self.dependencies.get(current) {
                    stack.extend(deps.iter().map(String::as_str));
                }
            }
        }
        false
    }

    /// Rebuilds `topo_order` from scratch by repeatedly emitting views whose
    /// view-dependencies have all been emitted.
    ///
    /// Only dependencies that are themselves views count towards a view's
    /// in-degree; base tables are treated as always available.
    fn update_topo_order(&mut self) {
        let mut in_degree: FxHashMap<String, usize> = FxHashMap::default();
        let mut queue: VecDeque<String> = VecDeque::new();
        for name in self.views.keys() {
            let deps = self.dependencies.get(name).map_or(0, |d| {
                d.iter().filter(|dep| self.views.contains_key(*dep)).count()
            });
            in_degree.insert(name.clone(), deps);
            if deps == 0 {
                queue.push_back(name.clone());
            }
        }

        self.topo_order.clear();
        while let Some(name) = queue.pop_front() {
            self.topo_order.push(name.clone());
            if let Some(dependents) = self.dependents.get(&name) {
                for dep in dependents {
                    // Names absent from `in_degree` are not registered views; skip them.
                    if let Some(count) = in_degree.get_mut(dep) {
                        *count = count.saturating_sub(1);
                        if *count == 0 {
                            queue.push_back(dep.clone());
                        }
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test view named `name` that reads from `sources`.
    fn mv(name: &str, sources: Vec<&str>) -> MaterializedView {
        MaterializedView::simple(name, sources.into_iter().map(String::from).collect())
    }

    #[test]
    fn test_simple_registration() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        let view = mv("ohlc_1s", vec!["trades"]);
        registry.register(view).unwrap();
        assert_eq!(registry.len(), 1);
        assert!(registry.get("ohlc_1s").is_some());
    }

    #[test]
    fn test_cascading_registration() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        registry.register(mv("ohlc_1m", vec!["ohlc_1s"])).unwrap();
        registry.register(mv("ohlc_1h", vec!["ohlc_1m"])).unwrap();
        assert_eq!(registry.topo_order(), &["ohlc_1s", "ohlc_1m", "ohlc_1h"]);
    }

    #[test]
    fn test_duplicate_name_error() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        let result = registry.register(mv("ohlc_1s", vec!["trades"]));
        assert!(matches!(result, Err(MvError::DuplicateName(_))));
    }

    #[test]
    fn test_source_not_found_error() {
        let mut registry = MvRegistry::new();
        let result = registry.register(mv("view", vec!["nonexistent"]));
        assert!(matches!(result, Err(MvError::SourceNotFound(_))));
    }

    #[test]
    fn test_cycle_detection_direct() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("a");
        registry.register(mv("b", vec!["a"])).unwrap();
        registry.register(mv("c", vec!["b"])).unwrap();
        registry.register(mv("d", vec!["c"])).unwrap();

        // Sources must already exist, so `register` cannot be used to build a
        // multi-view cycle; exercise the detector directly on the chain instead.
        assert!(registry.would_create_cycle("b", &["d".to_string()]));
        assert!(!registry.would_create_cycle("e", &["d".to_string()]));

        // A view that lists its own name as a source is rejected as a cycle
        // (the source check passes here because "a" is a base table).
        let result = registry.register(mv("a", vec!["a"]));
        assert!(matches!(result, Err(MvError::CycleDetected(_))));
        assert_eq!(registry.len(), 3);
    }

    #[test]
    fn test_multi_source_view() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("orders");
        registry.register_base_table("payments");
        registry
            .register(mv("order_payments", vec!["orders", "payments"]))
            .unwrap();
        assert_eq!(registry.topo_order(), &["order_payments"]);
        let deps: Vec<_> = registry.get_dependencies("order_payments").collect();
        assert!(deps.contains(&"orders"));
        assert!(deps.contains(&"payments"));
    }

    #[test]
    fn test_diamond_dependency() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("source");
        registry.register(mv("a", vec!["source"])).unwrap();
        registry.register(mv("b", vec!["source"])).unwrap();
        registry.register(mv("c", vec!["a", "b"])).unwrap();
        let order = registry.topo_order();
        let c_idx = order.iter().position(|x| x == "c").unwrap();
        let a_idx = order.iter().position(|x| x == "a").unwrap();
        let b_idx = order.iter().position(|x| x == "b").unwrap();
        assert!(c_idx > a_idx);
        assert!(c_idx > b_idx);
    }

    #[test]
    fn test_unregister_simple() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        let removed = registry.unregister("ohlc_1s").unwrap();
        assert_eq!(removed.name, "ohlc_1s");
        assert!(registry.is_empty());
    }

    #[test]
    fn test_unregister_with_dependents_error() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        registry.register(mv("ohlc_1m", vec!["ohlc_1s"])).unwrap();
        let result = registry.unregister("ohlc_1s");
        assert!(matches!(result, Err(MvError::HasDependents(_, _))));
    }

    #[test]
    fn test_unregister_cascade() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        registry.register(mv("ohlc_1m", vec!["ohlc_1s"])).unwrap();
        registry.register(mv("ohlc_1h", vec!["ohlc_1m"])).unwrap();
        let removed = registry.unregister_cascade("ohlc_1s").unwrap();
        assert_eq!(removed.len(), 3);
        assert!(registry.is_empty());
        assert!(registry.topo_order().is_empty());
        // Dependents come out first, the requested view last.
        assert_eq!(removed[0].name, "ohlc_1h");
        assert_eq!(removed[1].name, "ohlc_1m");
        assert_eq!(removed[2].name, "ohlc_1s");
    }

    #[test]
    fn test_unregister_cascade_diamond() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("source");
        registry.register(mv("a", vec!["source"])).unwrap();
        registry.register(mv("b", vec!["source"])).unwrap();
        registry.register(mv("c", vec!["a", "b"])).unwrap();

        // Cascading from one arm removes the join view but leaves the other arm.
        let removed = registry.unregister_cascade("a").unwrap();
        assert_eq!(removed.len(), 2);
        assert_eq!(removed[0].name, "c");
        assert_eq!(removed[1].name, "a");
        assert_eq!(registry.len(), 1);
        assert!(registry.get("b").is_some());
        assert_eq!(registry.topo_order(), &["b"]);
        assert_eq!(registry.get_dependents("b").count(), 0);
    }

    #[test]
    fn test_dependency_chain() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        registry.register(mv("ohlc_1m", vec!["ohlc_1s"])).unwrap();
        registry.register(mv("ohlc_1h", vec!["ohlc_1m"])).unwrap();
        let chain = registry.dependency_chain("ohlc_1h");
        assert_eq!(chain, vec!["ohlc_1s", "ohlc_1m", "ohlc_1h"]);
    }

    #[test]
    fn test_get_dependents() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("a", vec!["trades"])).unwrap();
        registry.register(mv("b", vec!["trades"])).unwrap();
        registry.register(mv("c", vec!["a"])).unwrap();
        let dependents: Vec<_> = registry.get_dependents("trades").collect();
        assert!(dependents.contains(&"a"));
        assert!(dependents.contains(&"b"));
        assert!(!dependents.contains(&"c"));
        let a_dependents: Vec<_> = registry.get_dependents("a").collect();
        assert_eq!(a_dependents, vec!["c"]);
    }

    #[test]
    fn test_view_state_update() {
        let mut registry = MvRegistry::new();
        registry.register_base_table("trades");
        registry.register(mv("ohlc_1s", vec!["trades"])).unwrap();
        let view = registry.get_mut("ohlc_1s").unwrap();
        assert_eq!(view.state, MvState::Running);
        view.state = MvState::Dropping;
        assert_eq!(view.state, MvState::Dropping);
    }
}