use dashmap::DashMap;
use std::sync::Arc;
use crate::error::{DbxError, DbxResult};
/// Registry of named SQL views.
///
/// Stores each view's SQL text under the lowercased view name.
#[derive(Debug, Default)]
pub struct ViewRegistry {
// Lowercased view name -> SQL text of the view definition.
views: DashMap<String, String>,
}
impl ViewRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `sql` as the definition of the view `name`.
    ///
    /// The name is stored lowercased, which makes later lookups
    /// case-insensitive. An existing view with the same name is replaced.
    /// This implementation always returns `Ok(())`.
    pub fn create(&self, name: &str, sql: &str) -> DbxResult<()> {
        self.views.insert(name.to_lowercase(), sql.to_string());
        Ok(())
    }

    /// Removes the view `name` (case-insensitive).
    ///
    /// # Errors
    ///
    /// Returns `DbxError::InvalidArguments` when no such view is registered.
    pub fn drop(&self, name: &str) -> DbxResult<()> {
        self.views
            .remove(&name.to_lowercase())
            .map(|_| ())
            .ok_or_else(|| DbxError::InvalidArguments(format!("뷰 '{}' 를 찾을 수 없음", name)))
    }

    /// Returns `true` when a view called `name` (case-insensitive) is registered.
    pub fn exists(&self, name: &str) -> bool {
        self.views.contains_key(&name.to_lowercase())
    }

    /// Rewrites `sql` so that references to registered views become inline
    /// subqueries: `FROM <view>` turns into `FROM (<view sql>) AS <view>`.
    ///
    /// Matching is case-insensitive, accepts any run of whitespace between
    /// `FROM` and the view name, and requires the name to be a whole
    /// identifier (a view `v` does not match `FROM values_table`).
    ///
    /// Limitations: only the first reference to each view is rewritten, only
    /// `FROM` references are considered (not `JOIN`), and the text is not
    /// parsed, so string literals and comments are not recognised as such.
    pub fn expand(&self, sql: &str) -> String {
        let mut result = sql.to_string();
        for entry in self.views.iter() {
            let name = entry.key();
            if let Some(span) = Self::find_view_reference(&result, name) {
                let replacement = format!("FROM ({}) AS {}", entry.value(), name);
                result.replace_range(span, &replacement);
            }
        }
        result
    }

    /// Locates the first `FROM <name>` reference in `sql`.
    ///
    /// `name` must already be lowercased. The returned byte range covers the
    /// `FROM` keyword, the whitespace after it and the view name, and is valid
    /// for slicing `sql` itself.
    fn find_view_reference(sql: &str, name: &str) -> Option<std::ops::Range<usize>> {
        const KEYWORD: &str = "from";

        if name.is_empty() {
            return None;
        }

        // `str::to_lowercase` may change byte lengths, which would make
        // offsets found in the lowered copy invalid for `sql`. This fold keeps
        // every character at its original byte offset.
        let folded = Self::fold_case_preserving_offsets(sql);

        let mut cursor = 0;
        while let Some(found) = folded[cursor..].find(KEYWORD) {
            let keyword_start = cursor + found;
            let keyword_end = keyword_start + KEYWORD.len();
            cursor = keyword_end;

            // Reject "from" that is merely the tail of a longer identifier.
            let glued_to_identifier = folded[..keyword_start]
                .chars()
                .next_back()
                .map_or(false, Self::is_identifier_char);
            if glued_to_identifier {
                continue;
            }

            // At least one whitespace character must separate keyword and name.
            let after_keyword = &folded[keyword_end..];
            let candidate = after_keyword.trim_start();
            if candidate.len() == after_keyword.len() || !candidate.starts_with(name) {
                continue;
            }

            // The name must end the identifier, otherwise this is a different,
            // longer table name that merely shares the prefix.
            let ends_on_boundary = candidate[name.len()..]
                .chars()
                .next()
                .map_or(true, |next| !Self::is_identifier_char(next));
            if ends_on_boundary {
                let name_end = folded.len() - candidate.len() + name.len();
                return Some(keyword_start..name_end);
            }
        }
        None
    }

    /// Lowercases `text` character by character, keeping a character unchanged
    /// whenever its lowercase form would occupy a different number of bytes.
    /// The result therefore has exactly the same byte layout as `text`.
    fn fold_case_preserving_offsets(text: &str) -> String {
        let mut folded = String::with_capacity(text.len());
        for ch in text.chars() {
            let mut lowered = ch.to_lowercase();
            match (lowered.next(), lowered.next()) {
                (Some(single), None) if single.len_utf8() == ch.len_utf8() => folded.push(single),
                _ => folded.push(ch),
            }
        }
        folded
    }

    /// Returns `true` for characters treated as part of a SQL identifier.
    fn is_identifier_char(ch: char) -> bool {
        ch.is_alphanumeric() || ch == '_'
    }

    /// Returns the (lowercased) names of all registered views, in no
    /// particular order.
    pub fn list_views(&self) -> Vec<String> {
        self.views.iter().map(|e| e.key().clone()).collect()
    }
}
/// Reference-counted handle to a `ViewRegistry`.
pub type SharedViewRegistry = Arc<ViewRegistry>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Definition used by the tests that register the `active_users` view.
    const ACTIVE_USERS_SQL: &str = "SELECT id, name FROM users WHERE active = true";

    /// Builds a registry pre-populated with the given `(name, sql)` views.
    fn registry_with(views: &[(&str, &str)]) -> ViewRegistry {
        let registry = ViewRegistry::new();
        for (name, sql) in views {
            registry.create(name, sql).unwrap();
        }
        registry
    }

    #[test]
    fn test_create_and_exists() {
        let registry = ViewRegistry::new();
        assert!(!registry.exists("active_users"));

        registry.create("active_users", ACTIVE_USERS_SQL).unwrap();
        assert!(registry.exists("active_users"));
    }

    #[test]
    fn test_create_case_insensitive() {
        let registry = registry_with(&[("MyView", "SELECT 1")]);

        // Every casing of the name must resolve to the same view.
        for spelling in &["myview", "MyView", "MYVIEW"] {
            assert!(registry.exists(spelling));
        }
    }

    #[test]
    fn test_drop_view() {
        let registry = registry_with(&[("v", "SELECT 1 AS x")]);
        assert!(registry.exists("v"));

        registry.drop("v").unwrap();
        assert!(!registry.exists("v"));
    }

    #[test]
    fn test_drop_nonexistent_fails() {
        let registry = ViewRegistry::new();
        let outcome = registry.drop("nonexistent");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_expand_replaces_from_clause() {
        let registry = registry_with(&[("active_users", ACTIVE_USERS_SQL)]);

        let expanded = registry.expand("SELECT * FROM active_users");

        // The view body must appear as a parenthesised subquery ...
        assert!(
            expanded.contains("(SELECT id, name FROM users WHERE active = true)"),
            "서브쿼리가 삽입되어야 함: {}",
            expanded
        );
        // ... aliased with the view's own name.
        assert!(
            expanded.contains("AS active_users"),
            "별칭이 지정되어야 함: {}",
            expanded
        );
    }

    #[test]
    fn test_expand_no_match() {
        let registry = registry_with(&[("v", "SELECT 1")]);

        let sql = "SELECT * FROM users";
        let expanded = registry.expand(sql);

        assert_eq!(expanded, sql, "치환 없이 원래 SQL 유지");
    }

    #[test]
    fn test_list_views() {
        let registry = registry_with(&[("v1", "SELECT 1"), ("v2", "SELECT 2")]);

        // Listing order is unspecified, so compare the sorted names.
        let mut names = registry.list_views();
        names.sort();
        assert_eq!(names, vec!["v1", "v2"]);
    }
}
use arrow::record_batch::RecordBatch;
use std::collections::HashSet;
use std::sync::{Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Extracts the names of the tables a SQL statement reads from.
///
/// This is a keyword heuristic, not a SQL parser. A table name is the word
/// that follows `FROM` or `JOIN`, plus the first word of every further item in
/// a comma-separated `FROM` list (`FROM a, b c, d`). Commas nested more deeply
/// in parentheses than the list itself (function arguments, `USING (..)`,
/// subqueries) are ignored, and a list ends at a clause keyword such as
/// `WHERE` or `ORDER`.
///
/// Names are lowercased; the result is sorted and free of duplicates.
/// Keywords inside string literals or comments are not recognised as such.
fn extract_source_tables(sql: &str) -> Vec<String> {
    // Clause keywords that end a comma-separated `FROM` list.
    const LIST_TERMINATORS: [&str; 11] = [
        "WHERE",
        "GROUP",
        "ORDER",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        "WINDOW",
        "FETCH",
    ];

    let mut tables = Vec::new();
    // Current parenthesis nesting level.
    let mut depth: usize = 0;
    // Nesting levels of the `FROM` lists that are still open, innermost last.
    let mut open_lists: Vec<usize> = Vec::new();
    // True when the next word is expected to be a table name.
    let mut expect_table = false;

    for token in split_sql_tokens(sql) {
        match token {
            "(" => {
                // A parenthesis after FROM/JOIN starts a subquery, not a name.
                depth += 1;
                expect_table = false;
            }
            ")" => {
                depth = depth.saturating_sub(1);
                expect_table = false;
                // Lists opened inside the parentheses end with them.
                while open_lists.last().map_or(false, |&level| level > depth) {
                    open_lists.pop();
                }
            }
            // Only a comma on the level of the open list separates tables.
            "," => expect_table = open_lists.last() == Some(&depth),
            ";" => {
                expect_table = false;
                open_lists.clear();
            }
            word if word.eq_ignore_ascii_case("FROM") => {
                if open_lists.last() != Some(&depth) {
                    open_lists.push(depth);
                }
                expect_table = true;
            }
            word if word.eq_ignore_ascii_case("JOIN") => expect_table = true,
            word if expect_table => {
                expect_table = false;
                let name = word.to_lowercase();
                if name != "select" {
                    tables.push(name);
                }
            }
            word => {
                let ends_list = LIST_TERMINATORS
                    .iter()
                    .any(|keyword| word.eq_ignore_ascii_case(keyword));
                if ends_list && open_lists.last() == Some(&depth) {
                    open_lists.pop();
                }
            }
        }
    }

    tables.sort_unstable();
    tables.dedup();
    tables
}

/// Splits `sql` on whitespace and additionally isolates `(`, `)`, `,` and `;`
/// as single-character tokens.
fn split_sql_tokens(sql: &str) -> Vec<&str> {
    let mut tokens = Vec::new();
    for chunk in sql.split_whitespace() {
        let mut word_start = 0;
        for (idx, ch) in chunk.char_indices() {
            if matches!(ch, '(' | ')' | ',' | ';') {
                if word_start < idx {
                    tokens.push(&chunk[word_start..idx]);
                }
                // The delimiters are ASCII, so each is exactly one byte long.
                tokens.push(&chunk[idx..idx + 1]);
                word_start = idx + 1;
            }
        }
        if word_start < chunk.len() {
            tokens.push(&chunk[word_start..]);
        }
    }
    tokens
}
/// State kept for one materialized view.
struct MatViewEntry {
// SQL text that defines the view.
sql: String,
// Cached result batches; `None` until `set_cache` stores a result.
cache: Option<Vec<RecordBatch>>,
// Time of the last `set_cache` call; `None` if the cache was never stored.
refreshed_at: Option<Instant>,
// Maximum cache age in seconds; with `None` a stored cache never expires by age.
refresh_interval_secs: Option<u64>,
// Lowercased source table names extracted from `sql` by `extract_source_tables`.
source_tables: Vec<String>,
}
/// Collects the names of materialized views that were marked dirty and lets a
/// consumer block until at least one name is pending.
pub struct MatViewNotifier {
    /// Names marked dirty since the last `take` / `wait_and_take`.
    dirty: Mutex<HashSet<String>>,
    /// Signalled every time a name is marked dirty.
    condvar: Condvar,
}

impl MatViewNotifier {
    /// Creates a notifier with no pending names.
    fn new() -> Self {
        Self {
            dirty: Mutex::new(HashSet::new()),
            condvar: Condvar::new(),
        }
    }

    /// Locks the dirty set.
    ///
    /// A poisoned lock is recovered instead of propagating the panic: the set
    /// only holds names, so the worst a panicking holder can leave behind is a
    /// stale or missing name, and one crashed thread must not make every later
    /// caller panic too.
    fn lock_dirty(&self) -> std::sync::MutexGuard<'_, HashSet<String>> {
        self.dirty
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Records `mv_name` as dirty and wakes one thread blocked in
    /// `wait_and_take`.
    fn mark_dirty(&self, mv_name: &str) {
        let mut dirty = self.lock_dirty();
        dirty.insert(mv_name.to_string());
        self.condvar.notify_one();
    }

    /// Blocks until at least one name is pending, then returns all pending
    /// names and leaves the set empty.
    pub fn wait_and_take(&self) -> HashSet<String> {
        let mut dirty = self.lock_dirty();
        // Loop to guard against spurious wake-ups and against another consumer
        // draining the set first.
        while dirty.is_empty() {
            dirty = self
                .condvar
                .wait(dirty)
                .unwrap_or_else(std::sync::PoisonError::into_inner);
        }
        // Hand the whole set over instead of rebuilding it element by element.
        std::mem::take(&mut *dirty)
    }

    /// Returns all pending names without blocking (possibly none) and leaves
    /// the set empty.
    pub fn take(&self) -> HashSet<String> {
        std::mem::take(&mut *self.lock_dirty())
    }
}
/// Registry of materialized views: definitions, cached results and
/// change notification.
pub struct MaterializedViewRegistry {
// Lowercased view name -> per-view state, each behind its own `RwLock`.
views: DashMap<String, RwLock<MatViewEntry>>,
// Collects the names of views marked dirty by `notify_change`.
notifier: MatViewNotifier,
// Configurable minimum refresh interval in milliseconds; this file only stores and returns it.
min_refresh_interval_ms: std::sync::atomic::AtomicU64,
}
impl Default for MaterializedViewRegistry {
    /// Builds an empty registry whose minimum refresh interval is 1000 ms.
    fn default() -> Self {
        use std::sync::atomic::AtomicU64;

        // Initial value of the configurable minimum refresh interval.
        const INITIAL_MIN_REFRESH_INTERVAL_MS: u64 = 1000;

        let views = DashMap::new();
        let notifier = MatViewNotifier::new();
        let min_refresh_interval_ms = AtomicU64::new(INITIAL_MIN_REFRESH_INTERVAL_MS);

        Self {
            views,
            notifier,
            min_refresh_interval_ms,
        }
    }
}
impl MaterializedViewRegistry {
    /// Creates an empty registry with the default settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Stores a new minimum refresh interval, in milliseconds.
    pub fn set_min_refresh_interval_ms(&self, ms: u64) {
        use std::sync::atomic::Ordering;

        self.min_refresh_interval_ms.store(ms, Ordering::Relaxed);
    }

    /// Returns the stored minimum refresh interval, in milliseconds.
    pub fn min_refresh_interval_ms(&self) -> u64 {
        use std::sync::atomic::Ordering;

        self.min_refresh_interval_ms.load(Ordering::Relaxed)
    }

    /// Returns the stored minimum refresh interval as a `Duration`.
    pub fn min_refresh_interval(&self) -> Duration {
        let millis = self.min_refresh_interval_ms();
        Duration::from_millis(millis)
    }

    /// Registers the materialized view `name` defined by `sql`.
    ///
    /// The name is stored lowercased. The view starts without a cache, and its
    /// source tables are derived from `sql`. An existing view with the same
    /// name is replaced. This implementation always returns `Ok(())`.
    pub fn create(
        &self,
        name: &str,
        sql: &str,
        refresh_interval_secs: Option<u64>,
    ) -> DbxResult<()> {
        let fresh_entry = MatViewEntry {
            sql: sql.to_string(),
            cache: None,
            refreshed_at: None,
            refresh_interval_secs,
            source_tables: extract_source_tables(sql),
        };
        self.views
            .insert(name.to_lowercase(), RwLock::new(fresh_entry));
        Ok(())
    }

    /// Stores `batches` as the cached result of view `name` and stamps the
    /// refresh time with the current instant.
    ///
    /// # Errors
    ///
    /// Returns `DbxError::InvalidArguments` when the view does not exist.
    pub fn set_cache(&self, name: &str, batches: Vec<RecordBatch>) -> DbxResult<()> {
        let key = name.to_lowercase();
        match self.views.get(&key) {
            Some(slot) => {
                let mut view = slot.write().unwrap();
                view.cache = Some(batches);
                view.refreshed_at = Some(Instant::now());
                Ok(())
            }
            None => Err(DbxError::InvalidArguments(format!(
                "'{}' 구체화된 뷰를 찾을 수 없음",
                name
            ))),
        }
    }

    /// Reports whether view `name` has a cache that is still within its
    /// refresh interval.
    ///
    /// Unknown views and views that were never refreshed are not fresh; a
    /// refreshed view without an interval is always fresh.
    pub fn is_fresh(&self, name: &str) -> bool {
        let key = name.to_lowercase();
        self.views.get(&key).map_or(false, |slot| {
            let view = slot.read().unwrap();
            match view.refreshed_at {
                None => false,
                Some(stamp) => view
                    .refresh_interval_secs
                    .map_or(true, |limit| stamp.elapsed().as_secs() < limit),
            }
        })
    }

    /// Returns a clone of the cached batches of view `name`, or `None` when
    /// the view is unknown or has no cache yet.
    pub fn get_cache(&self, name: &str) -> Option<Vec<RecordBatch>> {
        let key = name.to_lowercase();
        self.views.get(&key).and_then(|slot| {
            let view = slot.read().unwrap();
            view.cache.clone()
        })
    }

    /// Returns the defining SQL of view `name`, or `None` when it is unknown.
    pub fn get_sql(&self, name: &str) -> Option<String> {
        let slot = self.views.get(&name.to_lowercase())?;
        let sql = slot.read().unwrap().sql.clone();
        Some(sql)
    }

    /// Returns the (lowercased) names of all registered views, in no
    /// particular order.
    pub fn list(&self) -> Vec<String> {
        let mut names = Vec::new();
        for entry in self.views.iter() {
            names.push(entry.key().clone());
        }
        names
    }

    /// Removes the view `name`.
    ///
    /// # Errors
    ///
    /// Returns `DbxError::InvalidArguments` when the view does not exist.
    pub fn remove(&self, name: &str) -> DbxResult<()> {
        match self.views.remove(&name.to_lowercase()) {
            Some(_) => Ok(()),
            None => Err(DbxError::InvalidArguments(format!(
                "'{}' 구체화된 뷰를 찾을 수 없음",
                name
            ))),
        }
    }

    /// Marks every view that reads from `table` as dirty.
    ///
    /// The table name is compared case-insensitively, and anything from the
    /// first `__shard_` onwards is cut off before the comparison, so a change
    /// to `users__shard_0` is treated as a change to `users`.
    pub fn notify_change(&self, table: &str) {
        // Cheap exit for the common case of no materialized views at all.
        if self.views.is_empty() {
            return;
        }

        let lowered = table.to_lowercase();
        let base_table = match lowered.find("__shard_") {
            Some(cut) => &lowered[..cut],
            None => lowered.as_str(),
        };

        for entry in self.views.iter() {
            // Scope the read guard so it is released before the notifier is
            // touched.
            let reads_table = {
                let view = entry.value().read().unwrap();
                view.source_tables.iter().any(|source| source == base_table)
            };
            if reads_table {
                self.notifier.mark_dirty(entry.key());
            }
        }
    }

    /// Blocks until at least one view is dirty, then returns and clears the
    /// set of dirty view names.
    pub fn wait_and_take_dirty(&self) -> HashSet<String> {
        self.notifier.wait_and_take()
    }

    /// Returns and clears the set of dirty view names without blocking.
    pub fn take_dirty(&self) -> HashSet<String> {
        self.notifier.take()
    }

    /// Returns the names of all views that were never refreshed or whose
    /// refresh interval has elapsed.
    ///
    /// A refreshed view without an interval is never stale.
    pub fn stale_views(&self) -> Vec<String> {
        let mut stale = Vec::new();
        for entry in self.views.iter() {
            let is_stale = {
                let view = entry.value().read().unwrap();
                match view.refreshed_at {
                    None => true,
                    Some(stamp) => view
                        .refresh_interval_secs
                        .map_or(false, |limit| stamp.elapsed().as_secs() >= limit),
                }
            };
            if is_stale {
                stale.push(entry.key().clone());
            }
        }
        stale
    }
}
/// Reference-counted handle to a `MaterializedViewRegistry`.
pub type SharedMaterializedViewRegistry = Arc<MaterializedViewRegistry>;
#[cfg(test)]
mod matview_tests {
    use super::*;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a single-row batch with one non-nullable `Int64` column `x`
    /// holding `n`.
    fn make_batch(n: i64) -> RecordBatch {
        let field = Field::new("x", DataType::Int64, false);
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![n]))]).unwrap()
    }

    #[test]
    fn test_materialized_view_cache() {
        let registry = MaterializedViewRegistry::new();
        registry
            .create("mv_users", "SELECT id FROM users", None)
            .unwrap();
        // Nothing cached yet, so the view cannot be fresh.
        assert!(!registry.is_fresh("mv_users"));

        let batches = vec![make_batch(1), make_batch(2)];
        registry.set_cache("mv_users", batches).unwrap();
        assert!(registry.is_fresh("mv_users"));

        let cached = registry.get_cache("mv_users").unwrap();
        assert_eq!(cached.len(), 2);
    }

    #[test]
    fn test_matview_drop() {
        let registry = MaterializedViewRegistry::new();
        registry.create("mv_test", "SELECT 1", None).unwrap();
        assert!(registry.get_sql("mv_test").is_some());

        registry.remove("mv_test").unwrap();
        assert!(registry.get_sql("mv_test").is_none());
    }

    #[test]
    fn test_matview_drop_nonexistent_fails() {
        let registry = MaterializedViewRegistry::new();
        let outcome = registry.remove("nonexistent");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_matview_with_interval() {
        let registry = MaterializedViewRegistry::new();
        registry
            .create("mv_sales", "SELECT * FROM sales", Some(300))
            .unwrap();
        assert!(!registry.is_fresh("mv_sales"));

        registry
            .set_cache("mv_sales", vec![make_batch(42)])
            .unwrap();
        // Just refreshed, well inside the 300 second interval.
        assert!(registry.is_fresh("mv_sales"));

        let cached = registry.get_cache("mv_sales").unwrap();
        assert_eq!(cached[0].num_rows(), 1);
    }

    #[test]
    fn test_matview_list() {
        let registry = MaterializedViewRegistry::new();
        for (name, sql) in &[("mv_a", "SELECT 1"), ("mv_b", "SELECT 2")] {
            registry.create(name, sql, None).unwrap();
        }

        // Listing order is unspecified, so compare the sorted names.
        let mut names = registry.list();
        names.sort();
        assert_eq!(names, vec!["mv_a", "mv_b"]);
    }

    #[test]
    fn test_extract_source_tables() {
        let cases = vec![
            (
                "SELECT id, name FROM users WHERE active = true",
                vec!["users"],
            ),
            (
                "SELECT * FROM orders JOIN users ON orders.uid = users.id",
                vec!["orders", "users"],
            ),
            ("SELECT AVG(price) FROM products", vec!["products"]),
            // Duplicate references collapse into a single entry.
            ("SELECT * FROM t1 JOIN t1 ON t1.a = t1.b", vec!["t1"]),
        ];

        for (sql, expected) in cases {
            assert_eq!(extract_source_tables(sql), expected);
        }
    }

    #[test]
    fn test_notify_change_marks_dirty() {
        let registry = MaterializedViewRegistry::new();
        registry
            .create("mv_users", "SELECT id FROM users", None)
            .unwrap();
        registry
            .create("mv_orders", "SELECT * FROM orders", None)
            .unwrap();

        // A change to `users` touches only the view reading from it.
        registry.notify_change("users");
        let after_users = registry.take_dirty();
        assert!(after_users.contains("mv_users"));
        assert!(!after_users.contains("mv_orders"));

        // `take_dirty` cleared the set, so only the new change is reported.
        registry.notify_change("orders");
        let after_orders = registry.take_dirty();
        assert!(after_orders.contains("mv_orders"));
        assert!(!after_orders.contains("mv_users"));
    }

    #[test]
    fn test_notify_change_shard_table() {
        let registry = MaterializedViewRegistry::new();
        registry
            .create("mv_users", "SELECT id FROM users", None)
            .unwrap();

        // The shard suffix is stripped before matching source tables.
        registry.notify_change("users__shard_0");
        let pending = registry.take_dirty();
        assert!(pending.contains("mv_users"));
    }

    #[test]
    fn test_configurable_min_refresh_interval() {
        let registry = MaterializedViewRegistry::new();
        assert_eq!(registry.min_refresh_interval_ms(), 1000);

        registry.set_min_refresh_interval_ms(500);
        assert_eq!(registry.min_refresh_interval_ms(), 500);
        assert_eq!(registry.min_refresh_interval(), Duration::from_millis(500));
    }

    #[test]
    fn test_notify_no_views_is_noop() {
        let registry = MaterializedViewRegistry::new();
        registry.notify_change("some_table");

        let pending = registry.take_dirty();
        assert!(pending.is_empty());
    }
}