use async_trait::async_trait;
use futures_core::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::pin::Pin;
use crate::prelude::*;
/// Strength of a path lock, passed to [`RtdbAdapter::acquire_lock`] and reported back
/// in [`LockInfo::mode`].
///
/// Serialized in camelCase, i.e. as `"soft"` or `"hard"`.
///
/// NOTE(review): what distinguishes a soft from a hard lock is decided by the adapter
/// implementation; it is not visible in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum LockMode {
    /// Serialized as `"soft"`.
    Soft,
    /// Serialized as `"hard"`.
    Hard,
}
/// Description of a lock on a database path, as returned by
/// [`RtdbAdapter::acquire_lock`] and [`RtdbAdapter::check_lock`].
#[derive(Clone, Debug)]
pub struct LockInfo {
    /// User associated with the lock.
    pub user_id: Box<str>,
    /// Strength of the lock.
    pub mode: LockMode,
    /// Acquisition time. NOTE(review): presumably a Unix timestamp; the unit (seconds vs.
    /// milliseconds) is not established in this module — confirm with the adapter.
    pub acquired_at: u64,
    /// Lock lifetime in seconds (per the field name); expiry is handled by the adapter.
    pub ttl_secs: u64,
}
/// One aggregation computed per group, reading the named document `field`.
///
/// Internally tagged by `op` with camelCase variant names, e.g.
/// `{"op": "sum", "field": "price"}`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "camelCase")]
pub enum AggregateOp {
    /// Sum of `field` (tag `"sum"`).
    Sum { field: String },
    /// Average of `field` (tag `"avg"`).
    Avg { field: String },
    /// Minimum of `field` (tag `"min"`).
    Min { field: String },
    /// Maximum of `field` (tag `"max"`).
    Max { field: String },
}
/// Grouping/aggregation request carried by [`QueryOptions::aggregate`].
///
/// Serialized with camelCase keys (`groupBy`, `ops`). `ops` is omitted when empty and
/// defaults to an empty list when missing on input.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AggregateOptions {
    /// Document field whose value defines each group.
    pub group_by: String,
    /// Aggregations computed for every group.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub ops: Vec<AggregateOp>,
}
/// Declarative document filter: every populated condition must hold (logical AND).
///
/// Each map is keyed by document field name. Keys are serialized in camelCase
/// (`equals`, `notEquals`, `greaterThanOrEqual`, `inArray`, `arrayContainsAny`, ...).
/// Empty maps are left out of the serialized form, and missing ones deserialize as
/// empty. See [`QueryFilter::matches`] for the exact evaluation rules.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct QueryFilter {
    /// Field must equal the given value.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub equals: HashMap<String, Value>,
    /// Field must not equal the given value.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub not_equals: HashMap<String, Value>,
    /// Field must compare greater than the given threshold.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub greater_than: HashMap<String, Value>,
    /// Field must compare greater than or equal to the given threshold.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub greater_than_or_equal: HashMap<String, Value>,
    /// Field must compare less than the given threshold.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub less_than: HashMap<String, Value>,
    /// Field must compare less than or equal to the given threshold.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub less_than_or_equal: HashMap<String, Value>,
    /// Field must equal one of the listed values.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub in_array: HashMap<String, Vec<Value>>,
    /// Field must be an array containing the given value.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub array_contains: HashMap<String, Value>,
    /// Field must not equal any of the listed values.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub not_in_array: HashMap<String, Vec<Value>>,
    /// Field must be an array containing at least one of the listed values.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub array_contains_any: HashMap<String, Vec<Value>>,
    /// Field must be an array containing every listed value.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub array_contains_all: HashMap<String, Vec<Value>>,
}
impl QueryFilter {
    /// Creates an empty filter, which matches every document.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a filter holding a single `equals` condition on `field`.
    pub fn equals_one(field: impl Into<String>, value: Value) -> Self {
        Self::new().with_equals(field, value)
    }

    /// Adds an `equals` condition on `field`, replacing any existing one for that field.
    pub fn with_equals(mut self, field: impl Into<String>, value: Value) -> Self {
        self.equals.insert(field.into(), value);
        self
    }

    /// Adds a `not_equals` condition on `field`, replacing any existing one for that field.
    pub fn with_not_equals(mut self, field: impl Into<String>, value: Value) -> Self {
        self.not_equals.insert(field.into(), value);
        self
    }

    /// Adds a `greater_than` condition on `field`, replacing any existing one for that field.
    pub fn with_greater_than(mut self, field: impl Into<String>, value: Value) -> Self {
        self.greater_than.insert(field.into(), value);
        self
    }

    /// Adds a `greater_than_or_equal` condition on `field`, replacing any existing one.
    pub fn with_greater_than_or_equal(mut self, field: impl Into<String>, value: Value) -> Self {
        self.greater_than_or_equal.insert(field.into(), value);
        self
    }

    /// Adds a `less_than` condition on `field`, replacing any existing one for that field.
    pub fn with_less_than(mut self, field: impl Into<String>, value: Value) -> Self {
        self.less_than.insert(field.into(), value);
        self
    }

    /// Adds a `less_than_or_equal` condition on `field`, replacing any existing one.
    pub fn with_less_than_or_equal(mut self, field: impl Into<String>, value: Value) -> Self {
        self.less_than_or_equal.insert(field.into(), value);
        self
    }

    /// Adds an `in_array` condition on `field`, replacing any existing one for that field.
    pub fn with_in_array(mut self, field: impl Into<String>, values: Vec<Value>) -> Self {
        self.in_array.insert(field.into(), values);
        self
    }

    /// Adds an `array_contains` condition on `field`, replacing any existing one.
    pub fn with_array_contains(mut self, field: impl Into<String>, value: Value) -> Self {
        self.array_contains.insert(field.into(), value);
        self
    }

    /// Adds a `not_in_array` condition on `field`, replacing any existing one.
    pub fn with_not_in_array(mut self, field: impl Into<String>, values: Vec<Value>) -> Self {
        self.not_in_array.insert(field.into(), values);
        self
    }

    /// Adds an `array_contains_any` condition on `field`, replacing any existing one.
    pub fn with_array_contains_any(mut self, field: impl Into<String>, values: Vec<Value>) -> Self {
        self.array_contains_any.insert(field.into(), values);
        self
    }

    /// Adds an `array_contains_all` condition on `field`, replacing any existing one.
    pub fn with_array_contains_all(mut self, field: impl Into<String>, values: Vec<Value>) -> Self {
        self.array_contains_all.insert(field.into(), values);
        self
    }

    /// Returns `true` if `doc` satisfies every condition in this filter.
    ///
    /// Fields are looked up as top-level keys of `doc` via [`Value::get`]. If `doc` is not
    /// a JSON object, every field counts as absent.
    ///
    /// - `equals`: field present and equal to the value.
    /// - `not_equals`: field absent or not equal to the value.
    /// - `greater_than`, `greater_than_or_equal`, `less_than`, `less_than_or_equal`:
    ///   - The field must be present and have the same JSON type as the threshold, and it
    ///     must be ordered accordingly.
    ///   - A value of another type (e.g. `null` or `true` against a numeric threshold)
    ///     never satisfies a range condition.
    /// - `in_array`: field present and equal to one of the listed values.
    /// - `array_contains`: field is an array containing the value.
    /// - `not_in_array`: field absent or equal to none of the listed values.
    /// - `array_contains_any`: field is an array containing at least one listed value, so
    ///   an empty list never matches.
    /// - `array_contains_all`: field is an array containing every listed value.
    pub fn matches(&self, doc: &Value) -> bool {
        use std::cmp::Ordering;

        self.equals.iter().all(|(field, expected)| doc.get(field) == Some(expected))
            && self.not_equals.iter().all(|(field, rejected)| doc.get(field) != Some(rejected))
            && Self::range_holds(doc, &self.greater_than, Ordering::is_gt)
            && Self::range_holds(doc, &self.greater_than_or_equal, Ordering::is_ge)
            && Self::range_holds(doc, &self.less_than, Ordering::is_lt)
            && Self::range_holds(doc, &self.less_than_or_equal, Ordering::is_le)
            && self.in_array.iter().all(|(field, allowed)| {
                doc.get(field).is_some_and(|actual| allowed.contains(actual))
            })
            && self.array_contains.iter().all(|(field, required)| {
                Self::array_field(doc, field).is_some_and(|arr| arr.contains(required))
            })
            && self.not_in_array.iter().all(|(field, excluded)| {
                !doc.get(field).is_some_and(|actual| excluded.contains(actual))
            })
            && self.array_contains_any.iter().all(|(field, candidates)| {
                Self::array_field(doc, field)
                    .is_some_and(|arr| candidates.iter().any(|v| arr.contains(v)))
            })
            && self.array_contains_all.iter().all(|(field, required)| {
                Self::array_field(doc, field)
                    .is_some_and(|arr| required.iter().all(|v| arr.contains(v)))
            })
    }

    /// Returns `true` when no condition of any kind is set. Such a filter matches every
    /// document.
    pub fn is_empty(&self) -> bool {
        self.equals.is_empty()
            && self.not_equals.is_empty()
            && self.greater_than.is_empty()
            && self.greater_than_or_equal.is_empty()
            && self.less_than.is_empty()
            && self.less_than_or_equal.is_empty()
            && self.in_array.is_empty()
            && self.array_contains.is_empty()
            && self.not_in_array.is_empty()
            && self.array_contains_any.is_empty()
            && self.array_contains_all.is_empty()
    }

    /// Checks every `(field, threshold)` pair in `conds`.
    ///
    /// Each field must be present in `doc` and comparable with its threshold. The
    /// resulting ordering must be accepted by `accept`.
    fn range_holds(
        doc: &Value,
        conds: &HashMap<String, Value>,
        accept: fn(std::cmp::Ordering) -> bool,
    ) -> bool {
        conds.iter().all(|(field, threshold)| {
            doc.get(field)
                .and_then(|actual| Self::range_ordering(actual, threshold))
                .is_some_and(accept)
        })
    }

    /// Orders `actual` against `threshold`, or returns `None` when the two values are
    /// of different JSON types.
    ///
    /// The type check matters because `compare_json_values` falls back to comparing
    /// JSON text for mixed types. Without it, e.g. `null` ("null" > "18") would satisfy
    /// `greater_than: 18`.
    fn range_ordering(actual: &Value, threshold: &Value) -> Option<std::cmp::Ordering> {
        let same_type = std::mem::discriminant(actual) == std::mem::discriminant(threshold);
        same_type.then(|| compare_json_values(Some(actual), Some(threshold)))
    }

    /// Returns `doc[field]` if it is present and is a JSON array.
    fn array_field<'a>(doc: &'a Value, field: &str) -> Option<&'a Vec<Value>> {
        doc.get(field).and_then(Value::as_array)
    }
}
/// One sort key: a document field plus its direction.
///
/// Serialized with the Rust field names as keys (`field`, `ascending`).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SortField {
    /// Name of the document field to sort by.
    pub field: String,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}
impl SortField {
pub fn asc(field: impl Into<String>) -> Self {
Self { field: field.into(), ascending: true }
}
pub fn desc(field: impl Into<String>) -> Self {
Self { field: field.into(), ascending: false }
}
}
/// Options passed to [`RtdbAdapter::query`]. Every field is optional and defaults to
/// `None`.
#[derive(Default, Debug, Clone)]
pub struct QueryOptions {
    /// Document filter. NOTE(review): presumably applied with [`QueryFilter::matches`]
    /// semantics; the adapter decides.
    pub filter: Option<QueryFilter>,
    /// Sort keys. NOTE(review): presumably applied in list order — confirm in adapter.
    pub sort: Option<Vec<SortField>>,
    /// NOTE(review): presumably the maximum number of results to return.
    pub limit: Option<u32>,
    /// NOTE(review): presumably the number of leading results to skip.
    pub offset: Option<u32>,
    /// Grouping/aggregation request.
    pub aggregate: Option<AggregateOptions>,
}
impl QueryOptions {
    /// Creates options with nothing set; equivalent to `QueryOptions::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the document filter, replacing any previous one.
    pub fn with_filter(self, filter: QueryFilter) -> Self {
        Self { filter: Some(filter), ..self }
    }

    /// Sets the sort keys, replacing any previous ones.
    pub fn with_sort(self, sort: Vec<SortField>) -> Self {
        Self { sort: Some(sort), ..self }
    }

    /// Sets the result limit.
    pub fn with_limit(self, limit: u32) -> Self {
        Self { limit: Some(limit), ..self }
    }

    /// Sets the result offset.
    pub fn with_offset(self, offset: u32) -> Self {
        Self { offset: Some(offset), ..self }
    }

    /// Sets the aggregation request, replacing any previous one.
    pub fn with_aggregate(self, aggregate: AggregateOptions) -> Self {
        Self { aggregate: Some(aggregate), ..self }
    }
}
/// Parameters for [`RtdbAdapter::subscribe`].
#[derive(Clone, Debug)]
pub struct SubscriptionOptions {
    /// Path to watch for changes.
    pub path: Box<str>,
    /// Optional event filter. `None` (as built by `SubscriptionOptions::all`) means
    /// unfiltered.
    pub filter: Option<QueryFilter>,
}
impl SubscriptionOptions {
pub fn all(path: impl Into<Box<str>>) -> Self {
Self { path: path.into(), filter: None }
}
pub fn filtered(path: impl Into<Box<str>>, filter: QueryFilter) -> Self {
Self { path: path.into(), filter: Some(filter) }
}
}
/// Change notification delivered on a subscription stream (see
/// [`RtdbAdapter::subscribe`]).
///
/// Serialized as an internally tagged object. Its `action` key holds the camelCase
/// variant name: `"create"`, `"update"`, `"delete"`, `"lock"`, `"unlock"` or `"ready"`.
/// Optional fields are omitted when `None` and default to `None` when missing.
///
/// NOTE(review): `rename_all` on an enum renames only the variants. The `old_data`
/// field therefore keeps its snake_case name on the wire, unlike the camelCase keys of
/// `AggregateOptions` and `QueryFilter`. Confirm clients expect `old_data` before
/// changing it (serde's `rename_all_fields` would switch it to `oldData`).
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "action", rename_all = "camelCase")]
pub enum ChangeEvent {
    /// Creation at `path` with payload `data`.
    Create { path: Box<str>, data: Value },
    /// Update at `path` with payload `data`; `old_data` presumably holds the prior
    /// contents when the adapter provides them.
    Update {
        path: Box<str>,
        data: Value,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        old_data: Option<Value>,
    },
    /// Deletion at `path`; `old_data` presumably holds the removed contents when known.
    Delete {
        path: Box<str>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        old_data: Option<Value>,
    },
    /// Lock notification for `path`.
    Lock { path: Box<str>, data: Value },
    /// Unlock notification for `path`.
    Unlock { path: Box<str>, data: Value },
    /// Readiness notification for `path`, with an optional payload.
    /// NOTE(review): presumably marks the end of the initial snapshot — confirm.
    Ready {
        path: Box<str>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        data: Option<Value>,
    },
}
impl ChangeEvent {
    /// Returns the path carried by the event; every variant has one.
    pub fn path(&self) -> &str {
        match self {
            ChangeEvent::Create { path, .. }
            | ChangeEvent::Update { path, .. }
            | ChangeEvent::Delete { path, .. }
            | ChangeEvent::Lock { path, .. }
            | ChangeEvent::Unlock { path, .. }
            | ChangeEvent::Ready { path, .. } => path,
        }
    }

    /// Returns the last `/`-separated segment of [`Self::path`] (e.g. `"123"` for
    /// `"users/123"`, or the whole path if it has no `/`).
    ///
    /// Returns `None` when that segment is empty, i.e. for an empty path or one
    /// ending in `/`.
    pub fn id(&self) -> Option<&str> {
        self.path()
            .rsplit('/')
            .next()
            .filter(|segment| !segment.is_empty())
    }

    /// Returns everything before the last `/` of [`Self::path`] (e.g. `"users"` for
    /// `"users/123"`), or `None` if the path contains no `/`.
    pub fn parent_path(&self) -> Option<&str> {
        self.path().rsplit_once('/').map(|(parent, _)| parent)
    }

    /// Returns the event's `data` payload.
    ///
    /// - `Create`, `Update`, `Lock` and `Unlock`: always `Some`.
    /// - `Ready`: its optional payload.
    /// - `Delete`: always `None`; its `old_data` is not exposed here.
    pub fn data(&self) -> Option<&Value> {
        match self {
            ChangeEvent::Create { data, .. }
            | ChangeEvent::Update { data, .. }
            | ChangeEvent::Lock { data, .. }
            | ChangeEvent::Unlock { data, .. } => Some(data),
            ChangeEvent::Delete { .. } => None,
            ChangeEvent::Ready { data, .. } => data.as_ref(),
        }
    }

    /// Returns `true` for [`ChangeEvent::Create`].
    pub fn is_create(&self) -> bool {
        matches!(self, ChangeEvent::Create { .. })
    }

    /// Returns `true` for [`ChangeEvent::Update`].
    pub fn is_update(&self) -> bool {
        matches!(self, ChangeEvent::Update { .. })
    }

    /// Returns `true` for [`ChangeEvent::Delete`].
    pub fn is_delete(&self) -> bool {
        matches!(self, ChangeEvent::Delete { .. })
    }
}
fn compare_json_values(a: Option<&Value>, b: Option<&Value>) -> std::cmp::Ordering {
match (a, b) {
(None, None) => std::cmp::Ordering::Equal,
(None, Some(_)) => std::cmp::Ordering::Less,
(Some(_), None) => std::cmp::Ordering::Greater,
(Some(Value::Number(a)), Some(Value::Number(b))) => {
a.as_f64().partial_cmp(&b.as_f64()).unwrap_or(std::cmp::Ordering::Equal)
}
(Some(Value::String(a)), Some(Value::String(b))) => a.cmp(b),
(Some(Value::Bool(a)), Some(Value::Bool(b))) => a.cmp(b),
(Some(a), Some(b)) => a.to_string().cmp(&b.to_string()),
}
}
/// Renders a JSON value as a string key for grouping.
///
/// - Strings are used verbatim, without JSON quotes.
/// - Numbers, booleans and `null` use their JSON spelling.
/// - Arrays and objects are serialized with `serde_json::to_string`, or become an empty
///   string if that fails.
///
/// NOTE(review): keys are not type-tagged. The string `"1"` and the number `1` (or the
/// string `"null"` and JSON `null`) produce the same key and so share a group.
pub fn value_to_group_string(value: &Value) -> String {
    match value {
        Value::Null => String::from("null"),
        Value::Bool(flag) => flag.to_string(),
        Value::Number(num) => num.to_string(),
        Value::String(text) => text.to_owned(),
        Value::Array(_) | Value::Object(_) => serde_json::to_string(value).unwrap_or_default(),
    }
}
/// Statistics reported by [`RtdbAdapter::stats`].
///
/// Serialized under the snake_case Rust field names; this struct has no `rename_all`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct DbStats {
    /// Storage size in bytes (per the field name).
    pub size_bytes: u64,
    /// Number of stored records.
    pub record_count: u64,
    /// Number of tables. NOTE(review): what counts as a table is adapter-defined.
    pub table_count: u32,
}
/// Unit of work returned by [`RtdbAdapter::transaction`].
///
/// Write operations and `commit`/`rollback` take `&mut self`; `get` takes `&self`.
/// NOTE(review): isolation level, and whether `get` observes this transaction's
/// uncommitted writes, are adapter-defined — confirm with implementations.
#[async_trait]
pub trait Transaction: Send + Sync {
/// Creates an entry from `data` at/under `path` and returns a string identifier
/// (presumably the new document's id or path — confirm).
async fn create(&mut self, path: &str, data: Value) -> ClResult<Box<str>>;
/// Writes `data` to `path`. NOTE(review): replace vs. merge semantics are adapter-defined.
async fn update(&mut self, path: &str, data: Value) -> ClResult<()>;
/// Deletes the entry at `path`.
async fn delete(&mut self, path: &str) -> ClResult<()>;
/// Reads the value at `path`; `Ok(None)` presumably means nothing is stored there.
async fn get(&self, path: &str) -> ClResult<Option<Value>>;
/// Commits the transaction's changes.
async fn commit(&mut self) -> ClResult<()>;
/// Discards the transaction's changes.
async fn rollback(&mut self) -> ClResult<()>;
}
/// Storage backend interface for the real-time database.
///
/// Every operation is scoped by a `TnId`. It is presumably the tenant id:
/// `delete_tenant_databases` takes only this value. Most operations also take a
/// `db_id` naming the database. Implementations must be `Debug + Send + Sync`.
#[async_trait]
pub trait RtdbAdapter: Debug + Send + Sync {
/// Opens a new [`Transaction`] on the given database.
async fn transaction(&self, tn_id: TnId, db_id: &str) -> ClResult<Box<dyn Transaction>>;
/// Closes the given database. NOTE(review): effect on open transactions/subscriptions
/// is adapter-defined.
async fn close_db(&self, tn_id: TnId, db_id: &str) -> ClResult<()>;
/// Runs a query against `path` with `opts` (filter, sort, limit/offset, aggregate) and
/// returns the resulting JSON values. These are presumably documents, or aggregate rows
/// when `opts.aggregate` is set — confirm.
async fn query(
&self,
tn_id: TnId,
db_id: &str,
path: &str,
opts: QueryOptions,
) -> ClResult<Vec<Value>>;
/// Fetches the value at `path`; `Ok(None)` presumably means nothing is stored there.
async fn get(&self, tn_id: TnId, db_id: &str, path: &str) -> ClResult<Option<Value>>;
/// Opens a pinned, boxed, `Send` stream of [`ChangeEvent`]s for `opts.path`,
/// optionally restricted by `opts.filter`.
async fn subscribe(
&self,
tn_id: TnId,
db_id: &str,
opts: SubscriptionOptions,
) -> ClResult<Pin<Box<dyn Stream<Item = ChangeEvent> + Send>>>;
/// Requests an index on `field` for entries under `path`.
async fn create_index(&self, tn_id: TnId, db_id: &str, path: &str, field: &str)
-> ClResult<()>;
/// Returns size/count statistics for the database.
async fn stats(&self, tn_id: TnId, db_id: &str) -> ClResult<DbStats>;
/// Returns every stored entry as `(key, value)` pairs. The key is presumably the entry's
/// path — confirm.
async fn export_all(&self, tn_id: TnId, db_id: &str) -> ClResult<Vec<(Box<str>, Value)>>;
/// Attempts to lock `path` for `user_id` on connection `conn_id` with the given `mode`.
/// NOTE(review): whether `Some(LockInfo)` describes the newly granted lock or a
/// conflicting existing holder is not visible here — check implementations.
async fn acquire_lock(
&self,
tn_id: TnId,
db_id: &str,
path: &str,
user_id: &str,
mode: LockMode,
conn_id: &str,
) -> ClResult<Option<LockInfo>>;
/// Releases the lock on `path` identified by `user_id` and `conn_id`.
async fn release_lock(
&self,
tn_id: TnId,
db_id: &str,
path: &str,
user_id: &str,
conn_id: &str,
) -> ClResult<()>;
/// Returns the lock currently recorded for `path`, if any.
async fn check_lock(&self, tn_id: TnId, db_id: &str, path: &str) -> ClResult<Option<LockInfo>>;
/// Releases all locks in the database held by `user_id` on connection `conn_id`.
async fn release_all_locks(
&self,
tn_id: TnId,
db_id: &str,
user_id: &str,
conn_id: &str,
) -> ClResult<()>;
/// Deletes all databases belonging to `tn_id`.
async fn delete_tenant_databases(&self, tn_id: TnId) -> ClResult<()>;
}