use crate::error::{GroundDbError, Result};
use crate::schema::{SchemaDefinition, ViewDefinition, ViewType};
use crate::system_db::SystemDb;
use sqlparser::ast::{
Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::Mutex;
/// A table named in a view query's `FROM` or `JOIN` clause.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// Table name as written in the SQL (the last segment of a qualified
    /// name). `rewrite_view_sql` looks this up as a collection in the schema.
    pub collection: String,
    /// Alias the query gives the table, if any (e.g. `p` in `FROM posts p`).
    pub alias: Option<String>,
}
/// Everything `parse_view_query` extracts from one view definition.
#[derive(Debug, Clone)]
pub struct ParsedView {
    /// View name, as keyed in the schema's `views` map.
    pub name: String,
    /// Trimmed query text from the view definition, `:param` placeholders intact.
    pub original_sql: String,
    /// Tables found in the query's `FROM` / `JOIN` clauses.
    pub table_refs: Vec<TableRef>,
    /// One entry per recognised item in the `SELECT` projection.
    pub columns: Vec<ViewColumn>,
    /// Value of a numeric `LIMIT` literal, when the query has one.
    pub limit: Option<u64>,
    /// Factor parsed from a `buffer` setting such as `2x`; `1.0` when the
    /// setting is absent or does not parse.
    pub buffer_multiplier: f64,
    /// Copied from the view definition; `ViewEngine::materialize_view` only
    /// writes views for which this is `true`.
    pub materialize: bool,
    /// `true` when the view definition's type is `ViewType::Query`.
    pub is_query_template: bool,
    /// Keys of the view definition's `params` map (empty when there is none).
    pub param_names: Vec<String>,
}
impl ParsedView {
    /// Returns the distinct collection names appearing in `table_refs`.
    ///
    /// A collection referenced several times (for example under different
    /// aliases) appears once in the returned set.
    pub fn referenced_collections(&self) -> HashSet<String> {
        let mut names = HashSet::with_capacity(self.table_refs.len());
        for table in &self.table_refs {
            names.insert(table.collection.clone());
        }
        names
    }
}
/// One item of a view query's `SELECT` list.
#[derive(Debug, Clone)]
pub struct ViewColumn {
    /// Output name: the alias if one is given, otherwise the column name,
    /// `*` for a wildcard, or the expression's text for anything else.
    pub name: String,
    /// Qualifier written before the column in a two-part `a.b` reference.
    /// This is the text as it appears in the SQL, so it may be a table alias.
    pub source_collection: Option<String>,
    /// Column name when the item is a plain or two-part identifier.
    pub source_field: Option<String>,
}
/// Parsed view definitions plus an in-memory cache of each view's rows.
pub struct ViewEngine {
    /// Parsed views keyed by view name; filled once in `ViewEngine::new`.
    views: HashMap<String, ParsedView>,
    /// Cached rows per view name. Wrapped in a `Mutex` so the `&self`
    /// methods can read and replace entries.
    view_data: Mutex<HashMap<String, Vec<serde_json::Value>>>,
}
impl ViewEngine {
    /// Parses every view in `schema` and returns an engine with an empty cache.
    ///
    /// # Errors
    /// Returns the first error produced by `parse_view_query`.
    pub fn new(schema: &SchemaDefinition) -> Result<Self> {
        let views = schema
            .views
            .iter()
            .map(|(view_name, definition)| {
                parse_view_query(view_name, definition).map(|parsed| (view_name.clone(), parsed))
            })
            .collect::<Result<HashMap<_, _>>>()?;

        Ok(ViewEngine {
            views,
            view_data: Mutex::new(HashMap::new()),
        })
    }

    /// Looks up a parsed view by name.
    pub fn get_view(&self, name: &str) -> Option<&ParsedView> {
        self.views.get(name)
    }

    /// Names of all views whose query references `collection`.
    pub fn affected_views(&self, collection: &str) -> Vec<&str> {
        let mut hits = Vec::new();
        for (view_name, view) in &self.views {
            if view.table_refs.iter().any(|table| table.collection == collection) {
                hits.push(view_name.as_str());
            }
        }
        hits
    }

    /// Fills the cache from `db` for every known view that has stored data.
    ///
    /// Views without stored data keep whatever the cache already holds.
    ///
    /// # Errors
    /// Propagates database errors and JSON decoding errors; entries inserted
    /// before the failure stay in the cache.
    pub fn load_from_db(&self, db: &SystemDb) -> Result<()> {
        let mut cache = self.view_data.lock().unwrap();
        for view_name in self.views.keys() {
            let stored = match db.get_view_data(view_name)? {
                Some(json) => json,
                None => continue,
            };
            let rows: Vec<serde_json::Value> = serde_json::from_str(&stored)?;
            cache.insert(view_name.clone(), rows);
        }
        Ok(())
    }

    /// Writes every cached view to `db` as a JSON string.
    ///
    /// # Errors
    /// Propagates JSON encoding errors and database errors.
    pub fn save_to_db(&self, db: &SystemDb) -> Result<()> {
        let cache = self.view_data.lock().unwrap();
        for (view_name, rows) in cache.iter() {
            let serialized = serde_json::to_string(rows)?;
            db.set_view_data(view_name, &serialized)?;
        }
        Ok(())
    }

    /// Returns a copy of the cached rows for `name`, if any are cached.
    pub fn get_view_data(&self, name: &str) -> Option<Vec<serde_json::Value>> {
        self.view_data.lock().unwrap().get(name).cloned()
    }

    /// Replaces the cached rows for `name`.
    pub fn set_view_data(&self, name: &str, data: Vec<serde_json::Value>) {
        self.view_data.lock().unwrap().insert(name.to_owned(), data);
    }

    /// Writes the cached rows of `view_name` to `<root>/views/<view_name>.yaml`.
    ///
    /// Does nothing when the view is unknown, is not marked `materialize`, or
    /// has no cached rows. When the view has a limit, only that many leading
    /// rows are written. The cache lock is held until the file is written.
    ///
    /// # Errors
    /// Propagates directory-creation, YAML encoding and file-write errors.
    pub fn materialize_view(&self, root: &Path, view_name: &str) -> Result<()> {
        let view = match self.views.get(view_name).filter(|v| v.materialize) {
            Some(view) => view,
            None => return Ok(()),
        };

        let cache = self.view_data.lock().unwrap();
        let rows = match cache.get(view_name) {
            Some(rows) => rows,
            None => return Ok(()),
        };

        let views_dir = root.join("views");
        std::fs::create_dir_all(&views_dir)?;
        let output_path = views_dir.join(format!("{view_name}.yaml"));

        // No limit means "take everything", i.e. as many rows as are cached.
        let max_rows = view.limit.map_or(rows.len(), |limit| limit as usize);
        let limited: Vec<&serde_json::Value> = rows.iter().take(max_rows).collect();

        let yaml = serde_yaml::to_string(&limited)?;
        std::fs::write(&output_path, &yaml)?;
        Ok(())
    }

    /// Calls [`ViewEngine::materialize_view`] for every known view.
    ///
    /// # Errors
    /// Stops at, and returns, the first error.
    pub fn materialize_views(&self, root: &Path) -> Result<()> {
        for view_name in self.views.keys() {
            self.materialize_view(root, view_name)?;
        }
        Ok(())
    }
}
/// Result of `rewrite_view_sql`: executable SQL plus the limits derived from the view.
#[derive(Debug, Clone)]
pub struct RewrittenQuery {
    /// The view's SQL, prefixed with a generated `WITH` clause when the view
    /// references any tables.
    pub sql: String,
    /// Parameter names copied from the parsed view.
    pub param_names: Vec<String>,
    /// The view's limit multiplied by its buffer multiplier, rounded up;
    /// `None` when the view has no limit.
    pub buffer_limit: Option<usize>,
    /// The view's own limit, if any.
    pub original_limit: Option<usize>,
}
pub fn rewrite_view_sql(
parsed: &ParsedView,
schema: &SchemaDefinition,
) -> Result<RewrittenQuery> {
let mut cte_parts = Vec::new();
for table_ref in &parsed.table_refs {
let collection_name = &table_ref.collection;
let col_def = schema.collections.get(collection_name);
if col_def.is_none() {
return Err(GroundDbError::SqlParse(format!(
"View '{}': referenced collection '{}' not found in schema",
parsed.name, collection_name
)));
}
let col_def = col_def.unwrap();
let mut cte_columns = Vec::new();
cte_columns.push("id".to_string());
cte_columns.push("created_at".to_string());
cte_columns.push("modified_at".to_string());
if col_def.content {
cte_columns.push("content_text AS content".to_string());
}
for (field_name, _field_def) in &col_def.fields {
cte_columns.push(format!(
"json_extract(data_json, '$.{field_name}') AS {field_name}"
));
}
let columns_sql = cte_columns.join(",\n ");
let cte = format!(
"{collection_name} AS (\n SELECT\n {columns_sql}\n FROM documents\n WHERE collection = '{collection_name}'\n )"
);
cte_parts.push(cte);
}
let original_sql = parsed.original_sql.trim();
let full_sql = if cte_parts.is_empty() {
original_sql.to_string()
} else {
format!("WITH {}\n{}", cte_parts.join(",\n "), original_sql)
};
let buffer_limit = parsed.limit.map(|l| {
(l as f64 * parsed.buffer_multiplier).ceil() as usize
});
log::debug!(
"View '{}' rewritten SQL:\n{}",
parsed.name,
full_sql
);
Ok(RewrittenQuery {
sql: full_sql,
param_names: parsed.param_names.clone(),
buffer_limit,
original_limit: parsed.limit.map(|l| l as usize),
})
}
/// Parses one view definition into a [`ParsedView`].
///
/// The query is parsed with `:param` placeholders replaced by `NULL` (see
/// `replace_params`); the stored `original_sql` keeps the placeholders. Only
/// the first statement is inspected, and only when it is a query. Any other
/// statement kind yields a view with no table refs, columns or limit.
///
/// # Errors
/// Returns `GroundDbError::SqlParse` when the SQL does not parse or contains
/// no statements.
fn parse_view_query(name: &str, view_def: &ViewDefinition) -> Result<ParsedView> {
    let sql = view_def.query.trim().to_string();

    let statements = Parser::parse_sql(&GenericDialect {}, &replace_params(&sql))
        .map_err(|e| GroundDbError::SqlParse(format!("View '{name}': {e}")))?;
    let first = match statements.first() {
        Some(stmt) => stmt,
        None => {
            return Err(GroundDbError::SqlParse(format!(
                "View '{name}': no SQL statements found"
            )))
        }
    };

    let mut table_refs = Vec::new();
    let mut columns = Vec::new();
    let mut limit = None;
    if let Statement::Query(query) = first {
        extract_from_query(query, &mut table_refs, &mut columns, &mut limit);
    }

    // `buffer` looks like "2x"; anything else falls back to a factor of 1.
    let buffer_multiplier = match view_def.buffer.as_ref() {
        Some(raw) => raw
            .strip_suffix('x')
            .and_then(|n| n.parse::<f64>().ok())
            .unwrap_or(1.0),
        None => 1.0,
    };

    let param_names: Vec<String> = match view_def.params.as_ref() {
        Some(params) => params.keys().cloned().collect(),
        None => Vec::new(),
    };

    Ok(ParsedView {
        name: name.to_string(),
        original_sql: sql,
        table_refs,
        columns,
        limit,
        buffer_multiplier,
        materialize: view_def.materialize,
        is_query_template: view_def.view_type == Some(ViewType::Query),
        param_names,
    })
}
/// Replaces every `:name` placeholder in `sql` with `NULL` so the text can be
/// handed to the SQL parser.
///
/// A placeholder is a colon immediately followed by a letter or underscore and
/// then any run of alphanumerics/underscores. The following are copied through
/// untouched, because a colon there is not a parameter:
///
/// * single-quoted literals and double-quoted identifiers (`'a:b'`),
/// * `--` line comments and `/* ... */` block comments,
/// * the `::` cast operator (`col::text`),
/// * a colon not followed by a letter or underscore (`:1`, `: x`).
fn replace_params(sql: &str) -> String {
    /// Lexical context of the character currently being copied.
    #[derive(Clone, Copy)]
    enum Mode {
        Code,
        /// Inside a literal/identifier opened by the given quote character.
        Quoted(char),
        LineComment,
        BlockComment,
    }

    let mut result = String::with_capacity(sql.len());
    let mut chars = sql.chars().peekable();
    let mut mode = Mode::Code;

    while let Some(c) = chars.next() {
        match mode {
            Mode::Quoted(quote) => {
                result.push(c);
                // A doubled quote (`''`) closes and immediately reopens the
                // literal on the next character, which gives the same result.
                if c == quote {
                    mode = Mode::Code;
                }
            }
            Mode::LineComment => {
                result.push(c);
                if c == '\n' {
                    mode = Mode::Code;
                }
            }
            Mode::BlockComment => {
                result.push(c);
                if c == '*' && chars.peek() == Some(&'/') {
                    chars.next();
                    result.push('/');
                    mode = Mode::Code;
                }
            }
            Mode::Code => match c {
                '\'' | '"' => {
                    result.push(c);
                    mode = Mode::Quoted(c);
                }
                '-' if chars.peek() == Some(&'-') => {
                    chars.next();
                    result.push_str("--");
                    mode = Mode::LineComment;
                }
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    result.push_str("/*");
                    mode = Mode::BlockComment;
                }
                ':' => match chars.peek().copied() {
                    // `::` is a cast; consume both colons so the second one
                    // is not mistaken for the start of a placeholder.
                    Some(':') => {
                        chars.next();
                        result.push_str("::");
                    }
                    Some(next) if next.is_alphabetic() || next == '_' => {
                        while chars
                            .peek()
                            .map_or(false, |ch| ch.is_alphanumeric() || *ch == '_')
                        {
                            chars.next();
                        }
                        result.push_str("NULL");
                    }
                    _ => result.push(c),
                },
                _ => result.push(c),
            },
        }
    }
    result
}
fn extract_from_query(
query: &Query,
table_refs: &mut Vec<TableRef>,
columns: &mut Vec<ViewColumn>,
limit: &mut Option<u64>,
) {
if let SetExpr::Select(select) = query.body.as_ref() {
extract_from_select(select, table_refs, columns);
}
if let Some(expr) = &query.limit {
if let Expr::Value(sqlparser::ast::Value::Number(n, _)) = expr {
if let Ok(l) = n.parse::<u64>() {
*limit = Some(l);
}
}
}
}
/// Records the tables in a `SELECT`'s `FROM` clause and one [`ViewColumn`] per
/// recognised projection item.
///
/// Plain expressions, aliased expressions and the bare `*` wildcard are
/// recorded; any other kind of select item is skipped.
fn extract_from_select(
    select: &Select,
    table_refs: &mut Vec<TableRef>,
    columns: &mut Vec<ViewColumn>,
) {
    for source in &select.from {
        extract_from_table_with_joins(source, table_refs);
    }

    for item in &select.projection {
        let column = match item {
            SelectItem::UnnamedExpr(expr) => {
                let (name, source_collection, source_field) = extract_column_info(expr);
                ViewColumn {
                    name,
                    source_collection,
                    source_field,
                }
            }
            // The alias wins as the output name; the source info still comes
            // from the underlying expression.
            SelectItem::ExprWithAlias { expr, alias } => {
                let (_, source_collection, source_field) = extract_column_info(expr);
                ViewColumn {
                    name: alias.value.clone(),
                    source_collection,
                    source_field,
                }
            }
            SelectItem::Wildcard(_) => ViewColumn {
                name: "*".to_string(),
                source_collection: None,
                source_field: None,
            },
            _ => continue,
        };
        columns.push(column);
    }
}
/// Records the base table of a `FROM` item followed by each joined table, in
/// the order they appear.
fn extract_from_table_with_joins(
    table_with_joins: &TableWithJoins,
    table_refs: &mut Vec<TableRef>,
) {
    let joined = table_with_joins.joins.iter().map(|join| &join.relation);
    for factor in std::iter::once(&table_with_joins.relation).chain(joined) {
        extract_table_name(factor, table_refs);
    }
}
/// Appends a [`TableRef`] for `factor` when it is a plain named table.
///
/// The collection name is the last segment of the table's (possibly
/// qualified) name. Other kinds of table factor, and tables whose final name
/// segment is empty, are ignored.
fn extract_table_name(
    factor: &TableFactor,
    table_refs: &mut Vec<TableRef>,
) {
    let (name, alias) = match factor {
        TableFactor::Table { name, alias, .. } => (name, alias),
        _ => return,
    };

    let collection = match name.0.last() {
        Some(ident) if !ident.value.is_empty() => ident.value.clone(),
        _ => return,
    };

    table_refs.push(TableRef {
        collection,
        alias: alias.as_ref().map(|a| a.name.value.clone()),
    });
}
/// Describes a projected expression as `(name, source_collection, source_field)`.
///
/// * `col` gives `("col", None, Some("col"))`.
/// * `t.col` gives `("col", Some("t"), Some("col"))`.
/// * A compound identifier with any other number of parts gives its last part
///   as the name and no source info.
/// * Anything else gives the expression's text as the name and no source info.
fn extract_column_info(expr: &Expr) -> (String, Option<String>, Option<String>) {
    match expr {
        Expr::Identifier(ident) => {
            let column = ident.value.clone();
            (column.clone(), None, Some(column))
        }
        Expr::CompoundIdentifier(parts) => match parts.as_slice() {
            [qualifier, column] => (
                column.value.clone(),
                Some(qualifier.value.clone()),
                Some(column.value.clone()),
            ),
            other => {
                let name = other.last().map(|p| p.value.clone()).unwrap_or_default();
                (name, None, None)
            }
        },
        other => (other.to_string(), None, None),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::parse_schema_str;

    /// Schema with two collections (`users`, `posts`) and three views: a
    /// limited, buffered join (`post_feed`), a plain select (`user_lookup`)
    /// and a parameterised query template (`post_comments`).
    ///
    /// The YAML nesting is significant: collection and view settings must be
    /// indented under their names, and each query body under its `query: |`.
    fn test_schema() -> SchemaDefinition {
        parse_schema_str(
            r#"
collections:
  users:
    path: "users/{name}.md"
    fields:
      name: { type: string, required: true }
      email: { type: string, required: true }
      role: { type: string, enum: [admin, member, guest], default: member }
  posts:
    path: "posts/{status}/{date:YYYY-MM-DD}-{title}.md"
    fields:
      title: { type: string, required: true }
      author_id: { type: ref, target: users, required: true }
      date: { type: date, required: true }
      status: { type: string, enum: [draft, published, archived], default: draft }
    content: true
views:
  post_feed:
    query: |
      SELECT p.title, p.date, u.name AS author_name
      FROM posts p
      JOIN users u ON p.author_id = u.id
      WHERE p.status = 'published'
      ORDER BY p.date DESC
      LIMIT 100
    materialize: true
    buffer: 2x
  user_lookup:
    query: |
      SELECT id, name, email, role
      FROM users
      ORDER BY name ASC
    materialize: true
  post_comments:
    type: query
    query: |
      SELECT c.id, c.created_at
      FROM posts c
      WHERE c.id = :post_id
      ORDER BY c.created_at ASC
    params:
      post_id: { type: string }
"#,
        )
        .unwrap()
    }

    /// Every view in the schema is parsed and registered.
    #[test]
    fn test_view_engine_creation() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        assert_eq!(engine.views.len(), 3);
        assert!(engine.views.contains_key("post_feed"));
        assert!(engine.views.contains_key("user_lookup"));
        assert!(engine.views.contains_key("post_comments"));
    }

    /// A join view records both tables, its limit, buffer and columns.
    #[test]
    fn test_post_feed_view_parsing() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let feed = engine.get_view("post_feed").unwrap();
        let feed_collections = feed.referenced_collections();
        assert!(feed_collections.contains("posts"));
        assert!(feed_collections.contains("users"));
        assert_eq!(feed.limit, Some(100));
        assert_eq!(feed.buffer_multiplier, 2.0);
        assert!(feed.materialize);
        assert!(!feed.is_query_template);
        assert_eq!(feed.columns.len(), 3);
    }

    /// A single-table view without `LIMIT` has one collection and no limit.
    #[test]
    fn test_user_lookup_view_parsing() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let lookup = engine.get_view("user_lookup").unwrap();
        let lookup_collections = lookup.referenced_collections();
        assert!(lookup_collections.contains("users"));
        assert_eq!(lookup_collections.len(), 1);
        assert!(lookup.materialize);
        assert_eq!(lookup.limit, None);
    }

    /// `type: query` views are flagged as templates and expose their params.
    #[test]
    fn test_query_template_parsing() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let comments = engine.get_view("post_comments").unwrap();
        assert!(comments.is_query_template);
        assert!(comments.param_names.contains(&"post_id".to_string()));
    }

    /// Each collection maps back to the views that read from it.
    #[test]
    fn test_affected_views() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let affected = engine.affected_views("posts");
        assert!(affected.contains(&"post_feed"));
        assert!(affected.contains(&"post_comments"));
        let affected_users = engine.affected_views("users");
        assert!(affected_users.contains(&"post_feed"));
        assert!(affected_users.contains(&"user_lookup"));
    }

    /// Named placeholders become `NULL`.
    #[test]
    fn test_replace_params() {
        let sql = "SELECT * FROM posts WHERE id = :post_id AND status = :status";
        let cleaned = replace_params(sql);
        assert_eq!(
            cleaned,
            "SELECT * FROM posts WHERE id = NULL AND status = NULL"
        );
    }

    /// A single-table view gets one CTE exposing implicit and schema fields.
    #[test]
    fn test_rewrite_simple_select() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let view = engine.get_view("user_lookup").unwrap();
        let rewritten = rewrite_view_sql(view, &schema).unwrap();
        assert!(rewritten.sql.contains("WITH users AS"));
        assert!(rewritten.sql.contains("json_extract(data_json, '$.name') AS name"));
        assert!(rewritten.sql.contains("json_extract(data_json, '$.email') AS email"));
        assert!(rewritten.sql.contains("json_extract(data_json, '$.role') AS role"));
        assert!(rewritten.sql.contains("WHERE collection = 'users'"));
        assert!(rewritten.sql.contains("id"));
        assert!(rewritten.sql.contains("created_at"));
        assert!(rewritten.sql.contains("modified_at"));
        assert!(rewritten.buffer_limit.is_none());
        assert!(rewritten.original_limit.is_none());
    }

    /// A join view gets a CTE per table and keeps its original clauses.
    #[test]
    fn test_rewrite_join_query() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let view = engine.get_view("post_feed").unwrap();
        let rewritten = rewrite_view_sql(view, &schema).unwrap();
        assert!(rewritten.sql.contains("posts AS"));
        assert!(rewritten.sql.contains("users AS"));
        assert!(rewritten.sql.contains("JOIN"));
        assert!(rewritten.sql.contains("p.author_id = u.id"));
        assert!(rewritten.sql.contains("p.status = 'published'"));
        assert!(rewritten.sql.contains("ORDER BY p.date DESC"));
        assert_eq!(rewritten.buffer_limit, Some(200));
        assert_eq!(rewritten.original_limit, Some(100));
    }

    /// Implicit columns are selected directly, not via `json_extract`.
    #[test]
    fn test_rewrite_preserves_implicit_fields() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let view = engine.get_view("user_lookup").unwrap();
        let rewritten = rewrite_view_sql(view, &schema).unwrap();
        let cte_start = rewritten.sql.find("users AS").unwrap();
        let cte_section = &rewritten.sql[cte_start..];
        assert!(!cte_section.contains("json_extract(data_json, '$.id')"));
        assert!(!cte_section.contains("json_extract(data_json, '$.created_at')"));
    }

    /// Collections with `content: true` expose a `content` column.
    #[test]
    fn test_rewrite_content_collection() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let view = engine.get_view("post_feed").unwrap();
        let rewritten = rewrite_view_sql(view, &schema).unwrap();
        let posts_cte_start = rewritten.sql.find("posts AS").unwrap();
        let posts_section = &rewritten.sql[posts_cte_start..];
        assert!(posts_section.contains("content_text AS content"));
    }

    /// Placeholders survive the rewrite and are reported as param names.
    #[test]
    fn test_rewrite_parameterized_query() {
        let schema = test_schema();
        let engine = ViewEngine::new(&schema).unwrap();
        let view = engine.get_view("post_comments").unwrap();
        let rewritten = rewrite_view_sql(view, &schema).unwrap();
        assert!(rewritten.sql.contains(":post_id"));
        assert!(rewritten.param_names.contains(&"post_id".to_string()));
    }

    /// Referencing a collection missing from the schema is an error.
    #[test]
    fn test_rewrite_unknown_collection_errors() {
        let schema = test_schema();
        let parsed = ParsedView {
            name: "bad_view".to_string(),
            original_sql: "SELECT * FROM nonexistent".to_string(),
            table_refs: vec![TableRef {
                collection: "nonexistent".to_string(),
                alias: None,
            }],
            columns: vec![],
            limit: None,
            buffer_multiplier: 1.0,
            materialize: false,
            is_query_template: false,
            param_names: vec![],
        };
        let result = rewrite_view_sql(&parsed, &schema);
        assert!(result.is_err());
    }
}