use std::future::Future;
use std::marker::PhantomData;
use crate::{
error::TitoError,
query::IndexQueryBuilder,
types::{
FieldValue, ReverseIndex, TitoCursor, TitoEngine, TitoFindPayload, TitoKvPair,
TitoModelOptions, TitoPaginated, TitoRelationshipConfig, TitoScanPayload, TitoTransaction,
},
utils::{next_string_lexicographically, previous_string_lexicographically},
};
use base64::{engine::general_purpose, Engine};
use chrono::Utc;
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;
/// Typed model handle over a `TitoEngine`; `T` is the record type the model reads and writes.
#[derive(Clone)]
pub struct TitoModel<E: TitoEngine, T> {
/// Engine used to run transactions for this model.
pub engine: E,
/// Partition count copied from `TitoModelOptions` when the model is constructed.
pub partition_count: u32,
// Zero-sized marker tying the record type `T` to the model without storing a value of it.
_phantom: PhantomData<T>,
}
/// Pending write created by `TitoModel::set`; nothing is stored until `execute` is called.
pub struct SetBuilder<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> {
// Model that will perform the write.
model: &'a TitoModel<E, T>,
// Record to store.
payload: T,
// Forwarded to the model's write path to decide whether timestamp fields are stamped.
timestamps: bool,
}
impl<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> SetBuilder<'a, E, T> {
    /// Overrides whether the write stamps timestamp fields on the stored value.
    pub fn timestamps(self, timestamps: bool) -> Self {
        Self { timestamps, ..self }
    }

    /// Performs the write inside `tx` and returns the record produced by the model.
    pub async fn execute(self, tx: &E::Transaction) -> Result<T, TitoError> {
        let Self {
            model,
            payload,
            timestamps,
        } = self;
        model.set_internal(payload, timestamps, tx).await
    }
}
/// Pending single-record lookup created by `TitoModel::get`; run it with `execute`.
pub struct GetBuilder<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> {
// Model that will perform the lookup.
model: &'a TitoModel<E, T>,
// Record id as passed to `TitoModel::get` (the table prefix is added later by the model).
id: String,
// Relationship names requested via `relationship`, in call order.
rels: Vec<String>,
}
impl<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> GetBuilder<'a, E, T> {
    /// Adds one relationship name to load together with the record.
    pub fn relationship(self, rel: impl Into<String>) -> Self {
        let mut builder = self;
        builder.rels.push(rel.into());
        builder
    }

    /// Runs the lookup, using `tx` when given; with `None` the model opens its own transaction.
    pub async fn execute(self, tx: Option<&E::Transaction>) -> Result<T, TitoError> {
        let Self { model, id, rels } = self;
        model.get_internal(&id, rels, tx).await
    }
}
/// Pending multi-record lookup created by `TitoModel::get_many`; run it with `execute`.
pub struct GetManyBuilder<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> {
// Model that will perform the lookup.
model: &'a TitoModel<E, T>,
// Record ids as passed to `TitoModel::get_many` (the table prefix is added later by the model).
ids: Vec<String>,
// Relationship names requested via `relationship`, in call order.
rels: Vec<String>,
}
impl<'a, E: TitoEngine, T: crate::types::TitoModelConstraints> GetManyBuilder<'a, E, T> {
    /// Adds one relationship name to load together with every record.
    pub fn relationship(self, rel: impl Into<String>) -> Self {
        let mut builder = self;
        builder.rels.push(rel.into());
        builder
    }

    /// Runs the lookup, using `tx` when given; with `None` the model opens its own transaction.
    pub async fn execute(self, tx: Option<&E::Transaction>) -> Result<Vec<T>, TitoError> {
        let Self { model, ids, rels } = self;
        model.get_many_internal(ids, rels, tx).await
    }
}
impl<E: TitoEngine, T: crate::types::TitoModelConstraints> TitoModel<E, T> {
/// Builds a model over `engine`, taking the partition count from `options`.
pub fn new(engine: E, options: TitoModelOptions) -> Self {
    let partition_count = options.partition_count;
    Self {
        engine,
        partition_count,
        _phantom: PhantomData,
    }
}
/// Returns the relationship configuration declared by the record type (`T::relationships()`).
pub fn relationships(&self) -> Vec<TitoRelationshipConfig> {
T::relationships()
}
/// Returns the key prefix of the record type (`T::key_prefix()`); this file prepends it to
/// record ids as `"{prefix}:{id}"` when building storage keys.
pub fn get_table(&self) -> String {
T::key_prefix()
}
/// Returns the text after the last `:` in `key`, or the whole key when it contains no `:`.
pub fn get_id_from_table(&self, key: String) -> String {
    // Splitting from the right yields the trailing segment first, so no Vec is needed.
    let tail = key.rsplit(':').next().map(str::to_owned);
    tail.unwrap_or(key)
}
/// Starts a query on the named index; the builder receives its own clone of this model.
pub fn query_by_index(&self, index: impl Into<String>) -> IndexQueryBuilder<E, T> {
    let model = self.clone();
    let index_name = index.into();
    IndexQueryBuilder::new(model, index_name)
}
/// Decodes a pagination cursor: standard base64 first, then JSON into a `TitoCursor`.
///
/// Either failure is reported as `TitoError::DeserializationFailed`.
fn decode_cursor(&self, cursor: String) -> Result<TitoCursor, TitoError> {
    let raw = match general_purpose::STANDARD.decode(cursor) {
        Ok(bytes) => bytes,
        Err(_) => {
            return Err(TitoError::DeserializationFailed(
                "Failed to decode cursor".to_string(),
            ));
        }
    };
    serde_json::from_slice::<TitoCursor>(&raw).map_err(|_| {
        TitoError::DeserializationFailed("Failed to deserialize cursor".to_string())
    })
}
/// Encodes `ids` as a pagination cursor: a `TitoCursor` serialized to JSON, then standard base64.
fn encode_cursors(&self, ids: Vec<Option<String>>) -> Result<String, TitoError> {
    match serde_json::to_vec(&TitoCursor { ids }) {
        Ok(serialized) => Ok(general_purpose::STANDARD.encode(&serialized)),
        Err(_) => Err(TitoError::SerializationFailed(
            "Failed to serialize cursor".to_string(),
        )),
    }
}
/// Runs `f` through the engine's `transaction` method and returns whatever it yields.
///
/// `f` receives the engine transaction and must produce a future resolving to `Result<R, Err>`.
// NOTE(review): commit/rollback/retry behaviour lives in `TitoEngine::transaction`, which is
// not visible here; the `Clone` bound on `F` suggests the closure may be invoked more than
// once — confirm against the engine implementation.
pub async fn tx<F, Fut, R, Err>(&self, f: F) -> Result<R, Err>
where
F: FnOnce(E::Transaction) -> Fut + Clone + Send,
Fut: Future<Output = Result<R, Err>> + Send,
Err: From<TitoError> + Send + Sync + std::fmt::Debug,
R: Send,
{
self.engine.transaction(f).await
}
/// Converts raw key/value byte pairs into `(key, json)` tuples.
///
/// Pairs whose key is not valid UTF-8 or whose value is not valid JSON are skipped silently,
/// so this function never actually returns `Err`.
fn to_results(
    &self,
    items: impl IntoIterator<Item = TitoKvPair>,
) -> Result<Vec<(String, Value)>, TitoError> {
    let decoded = items
        .into_iter()
        .filter_map(|(raw_key, raw_value)| {
            let key = String::from_utf8(raw_key).ok()?;
            let value = serde_json::from_slice::<Value>(&raw_value).ok()?;
            Some((key, value))
        })
        .collect();
    Ok(decoded)
}
/// Loads `key` inside `tx` and returns it together with its JSON value.
///
/// A missing key, an engine error and an undecodable value are all reported as
/// `TitoError::NotFound`, each with its own message.
async fn get_raw(&self, key: &str, tx: &E::Transaction) -> Result<(String, Value), TitoError> {
    let key = key.to_string();
    let bytes = match tx.get(key.clone()).await {
        Ok(Some(bytes)) => bytes,
        Ok(None) => {
            return Err(TitoError::NotFound(format!(
                "Key '{}' not found in database",
                key
            )));
        }
        Err(e) => {
            return Err(TitoError::NotFound(format!(
                "Failed to get key '{}': {}",
                key, e
            )));
        }
    };
    match serde_json::from_slice::<Value>(&bytes) {
        Ok(value) => Ok((key, value)),
        Err(e) => Err(TitoError::NotFound(format!(
            "Failed to deserialize value for key '{}': {}",
            key, e
        ))),
    }
}
/// Reads `key` inside `tx` and parses the stored bytes as JSON.
///
/// Engine errors are propagated; a missing key and an unparseable value both yield
/// `TitoError::NotFound("Not found")`.
pub async fn get_key(&self, key: &str, tx: &E::Transaction) -> Result<Value, TitoError> {
    let stored = tx.get(key.to_string()).await?;
    match stored {
        Some(bytes) => serde_json::from_slice::<Value>(&bytes)
            .map_err(|_| TitoError::NotFound("Not found".to_string())),
        None => Err(TitoError::NotFound("Not found".to_string())),
    }
}
async fn put_with_options<P>(
&self,
key: String,
payload: P,
timestamps: bool,
tx: &E::Transaction,
) -> Result<Value, TitoError>
where
P: Serialize + Unpin + std::marker::Send + Sync,
{
let mut value = serde_json::to_value(&payload)
.map_err(|e| TitoError::SerializationFailed(e.to_string()))?;
if timestamps {
if let serde_json::Value::Object(ref mut map) = value {
let now = Utc::now().timestamp();
let is_new = match tx.get(&key).await {
Ok(Some(_)) => false,
_ => true,
};
if is_new {
map.insert("created_at".to_string(), serde_json::json!(now));
}
map.insert("updated_at".to_string(), serde_json::json!(now));
}
}
let bytes = serde_json::to_vec(&value)
.map_err(|e| TitoError::SerializationFailed(e.to_string()))?;
tx.put(key, bytes).await?;
Ok(value)
}
/// Deletes the raw `key` inside `tx`.
///
/// Returns `Ok(true)` whenever the engine call succeeds; the code does not check whether the
/// key existed beforehand. Engine errors are propagated.
pub async fn delete(&self, key: String, tx: &E::Transaction) -> Result<bool, TitoError> {
tx.delete(key).await?;
Ok(true)
}
/// Builds a page from raw `(key, json)` rows, attaching the caller-supplied `cursor` as is.
///
/// Rows that fail to deserialize into `T` are dropped silently.
pub fn to_paginated_items_with_cursor(
    &self,
    items: Vec<(String, Value)>,
    cursor: String,
) -> Result<TitoPaginated<T>, TitoError> {
    let parsed: Vec<T> = items
        .into_iter()
        .filter_map(|(_, raw)| serde_json::from_value::<T>(raw).ok())
        .collect();
    Ok(TitoPaginated::new(parsed, Some(cursor)))
}
/// Builds a page from raw `(key, json)` rows.
///
/// Rows that fail to deserialize into `T` are dropped silently. When `has_more` is true and
/// at least one row was supplied, the cursor encodes the key of the last row — including a
/// row that was dropped — otherwise the page carries no cursor.
pub fn to_paginated_items(
    &self,
    items: Vec<(String, Value)>,
    has_more: bool,
) -> Result<TitoPaginated<T>, TitoError> {
    // Capture the final key before the rows are consumed.
    let last_key = items.last().map(|(key, _)| key.clone());
    let parsed: Vec<T> = items
        .into_iter()
        .filter_map(|(_, raw)| serde_json::from_value::<T>(raw).ok())
        .collect();
    let cursor = match last_key.filter(|_| has_more) {
        Some(key) => {
            let encoded = self.encode_cursors(vec![Some(key)]).map_err(|e| {
                TitoError::SerializationFailed(format!("Failed to encode cursor: {}", e))
            })?;
            Some(encoded)
        }
        None => None,
    };
    Ok(TitoPaginated::new(parsed, cursor))
}
/// Loads and decodes the `ReverseIndex` stored under `key`.
///
/// Engine errors are propagated. A missing entry and an undecodable entry are both reported
/// as `TitoError::NotFound`, so callers matching on `NotFound` cannot tell them apart.
async fn get_reverse_index(
    &self,
    key: &str,
    tx: &E::Transaction,
) -> Result<ReverseIndex, TitoError> {
    let stored = tx.get(key.to_string()).await?;
    let bytes = match stored {
        Some(bytes) => bytes,
        None => {
            return Err(TitoError::NotFound(format!(
                "Reverse index not found for key '{}'",
                key
            )));
        }
    };
    match serde_json::from_slice::<ReverseIndex>(&bytes) {
        Ok(index) => Ok(index),
        Err(e) => Err(TitoError::NotFound(format!(
            "Failed to deserialize reverse index for key '{}': {}",
            key, e
        ))),
    }
}
/// Collects the values reachable from `json` by following the dot-separated `field_path`.
///
/// Arrays met along the path fan out: every element is followed with the rest of the path.
/// At the end of the path an object contributes one `FieldValue::HashMapEntry` per entry,
/// and any other value contributes one `FieldValue::Simple`.
///
/// Returns `None` when a path segment is missing on any branch, when an array along the path
/// is empty, or when nothing was collected.
///
/// The document is walked by reference; only the values placed in the result are cloned.
pub fn get_nested_values(&self, json: &Value, field_path: &str) -> Option<Vec<FieldValue>> {
    let parts: Vec<&str> = field_path.split('.').collect();
    let mut results = Vec::new();
    // Explicit stack of (node, number of path segments already consumed).
    let mut to_process: Vec<(&Value, usize)> = vec![(json, 0)];

    while let Some((current, depth)) = to_process.pop() {
        let segment = match parts.get(depth) {
            Some(segment) => *segment,
            None => {
                // Whole path consumed: emit the leaf.
                match current.as_object() {
                    Some(obj) => {
                        results.extend(obj.iter().map(|(key, value)| {
                            FieldValue::HashMapEntry {
                                key: key.clone(),
                                value: value.clone(),
                            }
                        }));
                    }
                    None => results.push(FieldValue::Simple(current.clone())),
                }
                continue;
            }
        };

        // A missing segment on any branch invalidates the whole lookup.
        let nested = current.get(segment)?;
        match nested.as_array() {
            Some(array) if array.is_empty() => return None,
            Some(array) => to_process.extend(array.iter().map(|item| (item, depth + 1))),
            None => to_process.push((nested, depth + 1)),
        }
    }

    if results.is_empty() {
        None
    } else {
        Some(results)
    }
}
/// Starts a write of `payload`; timestamps are enabled by default and nothing is stored until
/// the returned builder's `execute` is called.
pub fn set(&self, payload: T) -> SetBuilder<'_, E, T> {
SetBuilder {
model: self,
payload,
timestamps: true,
}
}
/// Writes `payload` together with its index entries and reverse index inside `tx`.
///
/// Returns the record deserialized from the value that was actually stored, so any timestamp
/// fields added during the write are reflected in the result.
async fn set_internal(
&self,
payload: T,
timestamps: bool,
tx: &E::Transaction,
) -> Result<T, TitoError>
where
T: serde::de::DeserializeOwned,
{
let raw_id = payload.id();
// Storage key of the record: "{table}:{id}".
let id = format!("{}:{}", self.get_table(), raw_id);
// Drop the index entries recorded by a previous write before creating the new ones.
self.clear_indexes(&raw_id, tx).await?;
let value = serde_json::to_value(&payload).map_err(|e| {
TitoError::SerializationFailed(format!("Failed to serialize payload: {}", e))
})?;
// Only the main record honours the `timestamps` flag; the writes below pass `false`.
let stored_value = self
.put_with_options(id.clone(), &value, timestamps, tx)
.await?;
// `get_index_keys` is defined outside this view; each entry is used as
// (`.0` = storage key, `.1` = value stored under that key).
let all_index_data = self.get_index_keys(id.clone(), &payload, &stored_value)?;
let mut all_index_keys = vec![];
for data in all_index_data {
all_index_keys.push(data.0.clone());
self.put_with_options(data.0.clone(), &data.1, false, tx)
.await?;
}
// Remember every index key just written so a later write or removal can delete them.
let index_json_key = ReverseIndex {
value: all_index_keys,
};
let reverse_key = format!("reverse-index:{}", id);
self.put_with_options(reverse_key, index_json_key, false, tx)
.await?;
serde_json::from_value(stored_value).map_err(|e| {
TitoError::DeserializationFailed(format!("Failed to deserialize stored value: {}", e))
})
}
/// Loads the record with the given id inside `tx`, stitches the requested relationships into
/// it, and deserializes the result into `T`.
///
/// Every failure along the way (lookup, relationship loading, deserialization, empty result)
/// is reported as `TitoError::NotFound` with a message naming the prefixed id.
async fn get_one_with_tx(
    &self,
    id: &str,
    rels: Vec<String>,
    tx: &E::Transaction,
) -> Result<T, TitoError>
where
    T: serde::de::DeserializeOwned,
{
    let id = format!("{}:{}", self.get_table(), id);

    let record = self.get_raw(&id, tx).await.map_err(|e| {
        TitoError::NotFound(format!("Failed to get record with id '{}': {}", id, e))
    })?;

    let stitched = self
        .fetch_and_stitch_relationships(vec![record], rels.clone(), tx)
        .await
        .map_err(|e| {
            TitoError::NotFound(format!(
                "Failed to fetch relationships for id '{}' with rels {:?}: {}",
                id, rels, e
            ))
        })?;

    match stitched.into_iter().next() {
        Some((_, value)) => serde_json::from_value(value).map_err(|err| {
            TitoError::NotFound(format!(
                "Failed to deserialize record with id '{}': {}",
                id, err
            ))
        }),
        None => Err(TitoError::NotFound(format!(
            "No record found with id '{}'",
            id
        ))),
    }
}
/// Starts a lookup of the record with the given id; no relationships are requested until
/// `relationship` is called on the returned builder.
pub fn get(&self, id: &str) -> GetBuilder<'_, E, T> {
GetBuilder {
model: self,
id: id.to_string(),
rels: vec![],
}
}
/// Loads one record, reusing `tx` when supplied and otherwise running the lookup inside a
/// transaction opened through `self.tx`.
async fn get_internal(
    &self,
    id: &str,
    rels: Vec<String>,
    tx: Option<&E::Transaction>,
) -> Result<T, TitoError>
where
    T: serde::de::DeserializeOwned,
{
    if let Some(existing) = tx {
        return self.get_one_with_tx(id, rels, existing).await;
    }

    let owned_id = id.to_string();
    // Clone inside the closure so it only borrows its inputs and therefore stays `Clone`.
    self.tx(|fresh| {
        let record_id = owned_id.clone();
        let wanted = rels.clone();
        async move { self.get_one_with_tx(&record_id, wanted, &fresh).await }
    })
    .await
}
/// Forward range scan inside `tx`.
///
/// The range starts at `payload.start`, or just after the id held in `payload.cursor` when a
/// cursor is given, and ends (exclusively) at `payload.end`, or at the lexicographic successor
/// of `payload.start` when no end is given. One row more than `payload.limit` is requested so
/// the returned flag can report whether further rows exist; the extra row is trimmed off.
/// Without a limit (or with `u32::MAX`) the flag is always `false`.
pub async fn scan(
    &self,
    payload: TitoScanPayload,
    tx: &E::Transaction,
) -> Result<(Vec<(String, Value)>, bool), TitoError>
where
    T: DeserializeOwned,
{
    let lower = match payload.cursor.clone() {
        Some(encoded) => {
            let last_seen = self.decode_cursor(encoded)?.first_id()?;
            next_string_lexicographically(last_seen)
        }
        None => format!("{}", payload.start),
    };
    let upper = payload
        .end
        .clone()
        .unwrap_or_else(|| next_string_lexicographically(payload.start.clone()));

    let cap = payload.limit.unwrap_or(u32::MAX);
    let unbounded = cap == u32::MAX;
    // Saturating keeps `u32::MAX` as is and turns every other limit into `limit + 1`.
    let fetch_count = cap.saturating_add(1);

    let raw = tx.scan(lower..upper, fetch_count).await?;
    let mut rows = self.to_results(raw)?;

    let has_more = !unbounded && rows.len() > cap as usize;
    if has_more {
        rows.truncate(cap as usize);
    }
    Ok((rows, has_more))
}
/// Batch-loads the records with the given ids as raw `(key, json)` rows and stitches the
/// requested relationships into them.
///
/// Each id is prefixed with the table name before the lookup.
pub async fn get_many_raw(
    &self,
    ids: Vec<String>,
    rels: Vec<String>,
    tx: &E::Transaction,
) -> Result<Vec<(String, Value)>, TitoError>
where
    T: DeserializeOwned,
{
    let keys: Vec<String> = ids
        .into_iter()
        .map(|id| format!("{}:{}", self.get_table(), id))
        .collect();
    let fetched = self.batch_get(keys, tx).await?;
    let stitched = self.fetch_and_stitch_relationships(fetched, rels, tx).await?;
    Ok(stitched)
}
/// Batch-loads the records with the given ids inside `tx` and deserializes them into `T`.
///
/// Rows that fail to deserialize are dropped silently, so the result may be shorter than `ids`.
async fn get_many_with_tx(
    &self,
    ids: Vec<String>,
    rels: Vec<String>,
    tx: &E::Transaction,
) -> Result<Vec<T>, TitoError>
where
    T: DeserializeOwned,
{
    let rows = self.get_many_raw(ids, rels, tx).await?;
    let records = rows
        .into_iter()
        .filter_map(|(_, raw)| serde_json::from_value::<T>(raw).ok())
        .collect();
    Ok(records)
}
/// Starts a lookup of several records by id; no relationships are requested until
/// `relationship` is called on the returned builder.
pub fn get_many(&self, ids: Vec<String>) -> GetManyBuilder<'_, E, T> {
GetManyBuilder {
model: self,
ids,
rels: vec![],
}
}
/// Loads several records, reusing `tx` when supplied and otherwise running the lookup inside
/// a transaction opened through `self.tx`.
async fn get_many_internal(
    &self,
    ids: Vec<String>,
    rels: Vec<String>,
    tx: Option<&E::Transaction>,
) -> Result<Vec<T>, TitoError>
where
    T: DeserializeOwned,
{
    if let Some(existing) = tx {
        return self.get_many_with_tx(ids, rels, existing).await;
    }

    // Clone inside the closure so it only borrows its inputs and therefore stays `Clone`.
    self.tx(|fresh| {
        let wanted_ids = ids.clone();
        let wanted_rels = rels.clone();
        async move { self.get_many_with_tx(wanted_ids, wanted_rels, &fresh).await }
    })
    .await
}
/// Reverse range scan inside `tx`.
///
/// The range starts at `payload.start` and ends (exclusively) at `payload.end`, or at the
/// lexicographic successor of `payload.start` when no end is given. When a cursor is supplied
/// the end becomes the lexicographic predecessor of the id held in the cursor. One row more
/// than `payload.limit` is requested so the returned flag can report whether further rows
/// exist; the extra row is trimmed off. Without a limit (or with `u32::MAX`) the flag is
/// always `false`.
pub async fn scan_reverse(
    &self,
    payload: TitoScanPayload,
    tx: &E::Transaction,
) -> Result<(Vec<(String, Value)>, bool), TitoError>
where
    T: DeserializeOwned,
{
    let lower = format!("{}", payload.start.clone());
    let default_upper = match payload.end {
        Some(end) => end,
        None => next_string_lexicographically(payload.start.clone()),
    };
    let upper = match payload.cursor {
        Some(encoded) => {
            let last_seen = self.decode_cursor(encoded)?.first_id()?;
            previous_string_lexicographically(last_seen)
        }
        None => default_upper,
    };

    let cap = payload.limit.unwrap_or(u32::MAX);
    let unbounded = cap == u32::MAX;
    // Saturating keeps `u32::MAX` as is and turns every other limit into `limit + 1`.
    let fetch_count = cap.saturating_add(1);

    let raw = tx.scan_reverse(lower..upper, fetch_count).await?;
    let mut rows = self.to_results(raw)?;

    let has_more = !unbounded && rows.len() > cap as usize;
    if has_more {
        rows.truncate(cap as usize);
    }
    Ok((rows, has_more))
}
/// Deletes every index entry listed in the record's reverse index, then the reverse index
/// itself.
///
/// A `NotFound` from the reverse-index lookup means there is nothing to clear and counts as
/// success; any other error is propagated.
async fn clear_indexes(&self, raw_id: &str, tx: &E::Transaction) -> Result<(), TitoError> {
    let reverse_index_key = format!("reverse-index:{}:{}", self.get_table(), raw_id);
    let reverse_index = match self.get_reverse_index(&reverse_index_key, tx).await {
        Ok(found) => found,
        Err(TitoError::NotFound(_)) => return Ok(()),
        Err(other) => return Err(other),
    };
    for index_key in reverse_index.value {
        self.delete(index_key, tx).await?;
    }
    self.delete(reverse_index_key, tx).await?;
    Ok(())
}
/// Returns the text after the last `:` in `key` (the whole key when it contains no `:`).
///
/// Splitting always yields at least one segment, so the result is always `Some`.
pub fn get_last_id(&self, key: String) -> Option<String> {
    key.rsplit(':').next().map(str::to_string)
}
/// Fetches several raw keys in one engine call and decodes them into `(key, json)` rows.
///
/// Engine errors are propagated; rows that cannot be decoded are skipped by `to_results`.
pub async fn batch_get(
    &self,
    keys: Vec<String>,
    tx: &E::Transaction,
) -> Result<Vec<(String, Value)>, TitoError> {
    let pairs = tx.batch_get(keys).await?;
    self.to_results(pairs.into_iter())
}
/// Removes up to `batch_size` records whose `index` matches `value`, returning the ids of the
/// records that were removed (empty when nothing matched).
///
/// Stops at the first removal that fails and returns that error.
pub async fn remove_by_index(
    &self,
    index: &str,
    value: &str,
    batch_size: u32,
    tx: &E::Transaction,
) -> Result<Vec<String>, TitoError>
where
    T: DeserializeOwned,
{
    let mut query = self.query_by_index(index);
    query.value(value.to_string());
    query.limit(Some(batch_size));
    let page = query.execute(Some(tx)).await?;

    let mut removed = Vec::new();
    for entity in page.items {
        let entity_id = entity.id();
        self.remove(&entity_id, tx).await?;
        removed.push(entity_id);
    }
    Ok(removed)
}
/// Removes a record together with its index entries and its reverse index.
///
/// Deletion order: every index key listed in the reverse index, then the record itself, then
/// the reverse index.
///
/// # Errors
/// `TitoError::NotFound("Entity not found: …")` when the record has no reverse index. Any
/// other failure of the reverse-index lookup or of a delete is propagated unchanged, so a
/// storage failure is no longer reported as a missing entity.
pub async fn remove(&self, raw_id: &str, tx: &E::Transaction) -> Result<bool, TitoError> {
    let id = format!("{}:{}", self.get_table(), raw_id);
    let reverse_index_key = format!("reverse-index:{}", id);
    let reverse_index = match self.get_reverse_index(&reverse_index_key, tx).await {
        Ok(idx) => idx,
        Err(TitoError::NotFound(_)) => {
            return Err(TitoError::NotFound(format!("Entity not found: {}", id)));
        }
        // Mirror `clear_indexes`: only a genuine miss means "not found".
        Err(other) => return Err(other),
    };
    for index_key in reverse_index.value {
        self.delete(index_key, tx).await?;
    }
    self.delete(id, tx).await?;
    self.delete(reverse_index_key, tx).await?;
    Ok(true)
}
/// Scans records whose key starts at `"{table}:{payload.start}"` inside a fresh transaction,
/// stitches the requested relationships into them, and returns one page of results.
///
/// `payload.limit` and `payload.cursor` are passed through to `scan`; no explicit end bound
/// is supplied.
pub async fn find(&self, payload: TitoFindPayload) -> Result<TitoPaginated<T>, TitoError>
where
    T: DeserializeOwned,
{
    let prefix = format!("{}:{}", self.get_table(), payload.start);
    // Clone inside the closure so it only borrows its inputs and therefore stays `Clone`.
    self.tx(|txn| {
        let request = payload.clone();
        let scan_request = TitoScanPayload {
            start: prefix.clone(),
            end: None,
            limit: request.limit,
            cursor: request.cursor.clone(),
        };
        async move {
            let (rows, has_more) = self.scan(scan_request, &txn).await?;
            let stitched = self
                .fetch_and_stitch_relationships(rows, request.rels, &txn)
                .await?;
            self.to_paginated_items(stitched, has_more)
        }
    })
    .await
}
}