use std::{collections::HashMap, ffi::OsString, fs, io::BufWriter, sync::{Arc, RwLock}};
use serde_json::{Map, Value};
use crate::{
database::{DbConfig, DbCollection, SchemaProvider},
executor::plan_executor::{Executor, PlanExecutor},
parser::{
aggregators_helper::AggregateRegistry,
analyzer::{AnalysisContext, AnalyzerError},
ast::Query
},
planner::plan_builder::PlanBuilder
};
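/// Shared, lock-protected handle to the internal database state. Callers
/// clone the `Arc` and take a read or write lock per operation.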
pub(crate) type ProtectedDb = Arc<RwLock<InternalDb>>;
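/// In-memory database state: the default configuration plus a map from
/// lowercased collection name to its shared `DbCollection`.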
#[derive(Default)]
pub(crate) struct InternalDb {
config: DbConfig,
collections: HashMap<String, Arc<DbCollection>>,
}
impl InternalDb {
pub fn into_protected(self) -> ProtectedDb {
Arc::new(RwLock::new(self))
}
fn new_db() -> Self {
Self::new_db_with_config(DbConfig::default())
}
fn new_db_with_config(config: DbConfig) -> Self {
Self {
config,
collections: HashMap::new(),
}
}
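    /// Creates (or replaces) a collection using the database's default
    /// configuration. Collection names are case-insensitive: keys are stored
    /// lowercased.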
pub fn create(&mut self, coll_name: &str) -> Arc<DbCollection> {
self.create_with_config(coll_name, self.config.clone())
}
pub fn create_with_config(&mut self, coll_name: &str, config: DbConfig) -> Arc<DbCollection> {
let collection = Arc::new(DbCollection::new_coll(coll_name, config));
self.collections.insert(coll_name.to_ascii_lowercase(), Arc::clone(&collection));
collection
}
pub fn get(&self, col_name: &str) -> Option<Arc<DbCollection>> {
self.collections.get(&col_name.to_ascii_lowercase()).map(Arc::clone)
}
pub fn list_collections(&self) -> Vec<String> {
        self.collections.keys().cloned().collect()
}
pub fn drop_collection(&mut self, col_name: &str) -> bool {
self.collections.remove(&col_name.to_ascii_lowercase()).is_some()
}
pub fn clear(&mut self) {
self.collections.clear()
}
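    /// Loads every top-level key of a JSON object as a collection, forwarding
    /// the `keep` flag to each collection's `load_from_json`, and returns the
    /// number of collections loaded.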
pub fn load_from_json(&mut self, json_value: Value, keep: bool) -> Result<usize, String> {
let Value::Object(object) = json_value else {
            return Err("Provided JSON does not contain an object at the root".to_string());
};
let mut total = 0;
for (name, items) in object {
let collection = self.create(&name);
collection.load_from_json(items, keep)?;
total += 1;
}
Ok(total)
}
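    /// Reads a JSON file from disk, loads it via `load_from_json`, and returns
    /// a human-readable status message.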
pub fn load_from_file(&mut self, file_path: &OsString) -> Result<String, String> {
let file_path_lossy = file_path.to_string_lossy();
let file_content = fs::read_to_string(file_path)
.map_err(|_| format!("Could not read file {}", file_path_lossy))?;
let json_value = serde_json::from_str::<Value>(&file_content)
.map_err(|_| format!("File {} does not contain valid JSON", file_path_lossy))?;
match self.load_from_json(json_value, false) {
Ok(loaded_collections) => Ok(format!("✔️ Loaded {} initial collections from {}", loaded_collections, file_path_lossy)),
            Err(error) => Err(format!("Failed to process file {}. Details: {}", file_path_lossy, error)),
}
}
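    /// Serializes the whole database into a single JSON object keyed by
    /// collection name; the counterpart of `load_from_json`.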
pub fn write_to_json(&self) -> Value {
let mut collections: Map<String, Value> = Map::new();
for (name, collection) in &self.collections {
let values = collection.get_all();
collections.insert(name.clone(), Value::Array(values));
}
Value::Object(collections)
}
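    /// Writes the whole database to `file_path` as pretty-printed JSON.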
    pub fn write_to_file(&self, file_path: &OsString) -> Result<(), String> {
        let file_path_lossy = file_path.to_string_lossy();
        let file = fs::File::create(file_path)
            .map_err(|_| format!("Could not create file {}", file_path_lossy))?;
        let mut writer = BufWriter::new(file);
        let data = self.write_to_json();
        serde_json::to_writer_pretty(&mut writer, &data)
            .map_err(|_| format!("Could not write JSON to {}", file_path_lossy))?;
        Ok(())
    }
}
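/// Thread-safe facade over [`InternalDb`]; every method acquires the inner
/// `RwLock` for the duration of the call.
///
/// Illustrative sketch (marked `ignore`, so it is not compiled as a doctest),
/// using only the constructors exercised in the tests below:
///
/// ```ignore
/// let db = Db::new_db_with_config(DbConfig::int("id"));
/// let users = db.create("users");
/// users.add(serde_json::json!({ "name": "Ada" }));
/// let rows = db.query("SELECT name FROM users").unwrap();
/// ```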
pub struct Db {
pub(crate) internal_db: ProtectedDb,
}
impl Db {
pub fn new_db() -> Self {
        Self {
internal_db: InternalDb::new_db().into_protected(),
}
}
pub fn new_db_with_config(config: DbConfig) -> Self {
        Self {
internal_db: InternalDb::new_db_with_config(config).into_protected(),
}
}
pub fn create(&self, coll_name: &str) -> Arc<DbCollection> {
self.internal_db.write().unwrap().create(coll_name)
}
pub fn create_with_config(&self, coll_name: &str, config: DbConfig) -> Arc<DbCollection> {
self.internal_db.write().unwrap().create_with_config(coll_name, config)
}
pub fn get(&self, col_name: &str) -> Option<Arc<DbCollection>> {
self.internal_db.read().unwrap().get(col_name)
}
pub fn list_collections(&self) -> Vec<String> {
self.internal_db.read().unwrap().list_collections()
}
pub fn drop_collection(&self, col_name: &str) -> bool {
self.internal_db.write().unwrap().drop_collection(col_name)
}
pub fn clear(&self) {
self.internal_db.write().unwrap().clear();
}
pub fn get_config(&self) -> DbConfig {
self.internal_db.read().unwrap().config.clone()
}
pub fn load_from_json(&self, json_value: Value, keep: bool) -> Result<usize, String> {
self.internal_db.write().unwrap().load_from_json(json_value, keep)
}
pub fn load_from_file(&self, file_path: &OsString) -> Result<String, String> {
self.internal_db.write().unwrap().load_from_file(file_path)
}
pub fn write_to_json(&self) -> Value {
self.internal_db.read().unwrap().write_to_json()
}
pub fn write_to_file(&self, file_path: &OsString) -> Result<(), String> {
self.internal_db.read().unwrap().write_to_file(file_path)
}
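    /// Runs a SQL string through the full pipeline: parse into a [`Query`]
    /// AST, analyze against this database's schema, build an execution plan,
    /// and execute it, returning the result rows as JSON values.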
pub fn query(&self, sql: &str) -> Result<Vec<serde_json::Value>, AnalyzerError> {
let q = Query::try_from(sql)
.map_err(|e| AnalyzerError::Other(format!("parse error: {e}")))?;
let aggregates = AggregateRegistry::default_aggregate_registry();
let analyzed = AnalysisContext::analyze_query(&q, self, &aggregates, Value::Null)?;
let plan = PlanBuilder::from_analyzed(&analyzed)?;
let exec = PlanExecutor::new(plan);
exec.execute(self)
}
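    /// Same pipeline as [`Db::query`], but binds positional `?` parameters
    /// from `args`: a scalar for a single parameter, or a JSON array for
    /// several (see the parameter tests below).
    ///
    /// Illustrative sketch mirroring the tests in this file:
    ///
    /// ```ignore
    /// let rows = db.query_with_args(
    ///     "SELECT id FROM t WHERE id >= ? AND cat = ?",
    ///     serde_json::json!([2, "a"]),
    /// ).unwrap();
    /// ```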
pub fn query_with_args(&self, sql: &str, args: Value) -> Result<Vec<Value>, AnalyzerError> {
let q = Query::try_from(sql)
.map_err(|e| AnalyzerError::Other(format!("parse error: {e}")))?;
let aggregates = AggregateRegistry::default_aggregate_registry();
let analyzed = AnalysisContext::analyze_query(&q, self, &aggregates, args)?;
let plan = PlanBuilder::from_analyzed(&analyzed)?;
let exec = PlanExecutor::new(plan);
exec.execute(self)
}
}
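/// Lets the analyzer resolve collection schemas; returns `None` when the lock
/// is poisoned or the collection does not exist.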
impl SchemaProvider for Db {
fn schema_of(&self, collection_ref: &str) -> Option<super::SchemaDict> {
let guard = self.internal_db.read().ok()?;
let coll = guard.get(collection_ref)?;
coll.schema()
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use crate::database::{DbConfig, IdType};
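    /// Shared fixture: one collection `t` with five rows, one of which has a
    /// null `amt`, and id generation disabled.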
fn mk_db() -> Db {
let db = Db::new_db_with_config(DbConfig { id_type: IdType::None, id_key: "id".into() });
let t = db.create("t");
t.add_batch(json!([
{ "id": 1, "cat": "a", "amt": 10.0 },
{ "id": 2, "cat": "a", "amt": 15.0 },
{ "id": 3, "cat": "b", "amt": 7.5 },
{ "id": 4, "cat": "b", "amt": null },
{ "id": 5, "cat": "a", "amt": 22.5 }
]));
db
}
#[test]
fn db_runner_full_pipeline_group_by_having() {
let db = mk_db();
let sql = r#"
SELECT t.cat AS cat, SUM(t.amt) AS total
FROM t
WHERE t.id > 1
GROUP BY t.cat
HAVING SUM(t.amt) > 20
ORDER BY t.cat
LIMIT 10
"#;
let rows = db.query(sql).expect("query should succeed");
assert_eq!(rows.len(), 1);
let obj = rows[0].as_object().unwrap();
assert_eq!(obj.get("cat").unwrap(), "a");
let total = obj.get("total").unwrap().as_f64().unwrap();
assert!((total - 37.5).abs() < 1e-9);
}
#[test]
fn db_runner_supports_from_list_cross_join() {
let db = mk_db();
let sql = r#"
SELECT COUNT(*) AS n
FROM t a, t b
"#;
let rows = db.query(sql).expect("query should succeed");
assert_eq!(rows.len(), 1);
assert_eq!(rows[0]["n"].as_i64().unwrap(), 25);
}
#[test]
fn db_runner_with_arg() {
let db = mk_db();
let sql = r#"
SELECT id, cat, amt
FROM t
WHERE id = ?
"#;
let rows = db.query_with_args(sql, json!(3)).expect("query should succeed");
assert_eq!(rows.len(), 1);
let obj = rows[0].as_object().unwrap();
assert_eq!(obj.get("id").unwrap(), 3);
assert_eq!(obj.get("cat").unwrap(), "b");
assert_eq!(obj.get("amt").unwrap(), 7.5);
}
#[test]
fn db_runner_with_args() {
let db = mk_db();
let sql = r#"
SELECT id, cat, amt
FROM t
WHERE id IN (?)
ORDER BY id
"#;
let rows = db.query_with_args(sql, json!([[2, 3]])).expect("query should succeed");
assert_eq!(rows.len(), 2);
let obj = rows[0].as_object().unwrap();
assert_eq!(obj.get("id").unwrap(), 2);
assert_eq!(obj.get("cat").unwrap(), "a");
assert_eq!(obj.get("amt").unwrap(), 15.0);
let obj = rows[1].as_object().unwrap();
assert_eq!(obj.get("id").unwrap(), 3);
assert_eq!(obj.get("cat").unwrap(), "b");
assert_eq!(obj.get("amt").unwrap(), 7.5);
}
#[test]
fn db_runner_in_with_empty_array_param_returns_no_rows() {
let db = mk_db();
let rows = db.query_with_args(
r#"
SELECT id FROM t
WHERE id IN (?)
"#,
serde_json::json!([[]]),
).expect("query should succeed");
assert!(rows.is_empty());
}
#[test]
fn db_runner_multiple_positional_params() {
let db = mk_db();
let rows = db.query_with_args(
r#"
SELECT id, cat
FROM t
WHERE id >= ? AND cat = ?
ORDER BY id
"#,
serde_json::json!([2, "a"]),
).expect("query should succeed");
let ids: Vec<i64> = rows.iter()
.map(|r| r["id"].as_i64().unwrap())
.collect();
assert_eq!(ids, vec![2, 5]);
}
#[test]
fn db_runner_param_in_function_and_order_by() {
let db = mk_db();
let sql = r#"
SELECT UPPER(cat) AS c
FROM t
WHERE cat = ?
ORDER BY c DESC
"#;
let rows = db.query_with_args(sql, serde_json::json!("a"))
.expect("query should succeed");
assert!(!rows.is_empty());
for r in rows {
assert_eq!(r["c"], serde_json::json!("A"));
}
}
#[test]
fn db_runner_in_with_mixed_literals_and_param_array() {
let db = mk_db();
let sql = r#"
SELECT id
FROM t
WHERE id IN (1, ?)
ORDER BY id
"#;
let rows = db.query_with_args(sql, serde_json::json!([[2, 3]]))
.expect("query should succeed");
let ids: Vec<i64> = rows.iter().map(|r| r["id"].as_i64().unwrap()).collect();
assert_eq!(ids, vec![1, 2, 3]);
}
#[test]
fn db_runner_insensitive_case() {
let db = mk_db();
let sql = r#"
SELECT COUNT(*) AS n
FROM t a, T b
"#;
let rows = db.query(sql).expect("query should succeed");
assert_eq!(rows.len(), 1);
assert_eq!(rows[0]["n"].as_i64().unwrap(), 25);
}
#[test]
fn test_db_load_from_json() {
use serde_json::json;
let input = json!({ "a": [{ "id": 1, "x": "foo" }] });
let db = Db::new_db_with_config(DbConfig::int("id"));
let count = db.load_from_json(input.clone(), false).unwrap();
assert_eq!(count, 1);
let out = db.write_to_json();
let arr = out.get("a").unwrap().as_array().unwrap();
assert_eq!(arr.len(), 1);
assert_eq!(arr[0].get("x").unwrap(), "foo");
}
#[test]
fn test_db_load_from_file() {
use tempfile::TempDir;
use std::{fs::File, io::Write, ffi::OsString};
use serde_json::json;
let tmp = TempDir::new().unwrap();
let path = tmp.path().join("db.json");
let mut f = File::create(&path).unwrap();
let data = json!({ "b": [{ "id": 2, "y": 42 }] });
f.write_all(data.to_string().as_bytes()).unwrap();
let os_path = OsString::from(path.to_string_lossy().into_owned());
let db = Db::new_db_with_config(DbConfig::int("id"));
let msg = db.load_from_file(&os_path).unwrap();
assert!(msg.contains("Loaded 1 initial collections"));
let out = db.write_to_json();
let arr = out.get("b").unwrap().as_array().unwrap();
assert_eq!(arr[0].get("y").unwrap(), 42);
}
#[test]
fn test_db_write_to_json() {
use serde_json::json;
let db = Db::new_db_with_config(DbConfig::int("id"));
let coll = db.create("z");
coll.add(json!({ "key": "value" }));
let out = db.write_to_json();
let arr = out.get("z").unwrap().as_array().unwrap();
assert_eq!(arr.len(), 1);
assert_eq!(arr[0].get("key").unwrap(), "value");
}
#[test]
fn test_db_write_to_file() {
use tempfile::TempDir;
use std::{ffi::OsString, fs};
use serde_json::json;
let tmp = TempDir::new().unwrap();
let path = tmp.path().join("out.json");
let os_path = OsString::from(path.to_string_lossy().into_owned());
let db = Db::new_db_with_config(DbConfig::int("id"));
let coll = db.create("c");
coll.add(json!({ "n": 3 }));
assert!(db.write_to_file(&os_path).is_ok());
let content = fs::read_to_string(path).unwrap();
let v: serde_json::Value = serde_json::from_str(&content).unwrap();
let arr = v.get("c").unwrap().as_array().unwrap();
assert_eq!(arr[0].get("n").unwrap(), 3);
}
}