use std::sync::Arc;
use bytes::Bytes;
use crate::codec::{decode_key, decode_value, encode_key, encode_value};
use crate::error::RelError;
use crate::index::{index_range_for_eq, IndexSpec, IndexWriter};
use crate::predicate::Predicate;
use crate::schema::{FieldValue, Record, Schema};
use donadb::types::{BlockHeight, DomainId};
use donadb::DonaDb;
/// Configuration consumed by `RelTable::new` to build a table handle.
pub struct RelTableConfig {
/// Schema describing the table's fields.
pub schema: Schema,
/// Storage domain under which the table's primary rows are read and written.
pub domain_id: DomainId,
/// Secondary index specifications handed to the table's `IndexWriter`.
pub indexes: Vec<IndexSpec>,
}
impl RelTableConfig {
    /// Builds a configuration for `schema` stored under `domain_id`.
    ///
    /// One [`IndexSpec`] is requested for every field reported by
    /// `schema.indexed_fields()`. A field whose `IndexSpec::new` call fails is
    /// skipped without reporting the error, so the resulting configuration
    /// simply carries no index for it.
    pub fn new(schema: Schema, domain_id: DomainId) -> Self {
        let mut indexes = Vec::new();
        for (_, field) in schema.indexed_fields() {
            if let Ok(spec) = IndexSpec::new(&schema, &field.name) {
                indexes.push(spec);
            }
        }
        Self {
            schema,
            domain_id,
            indexes,
        }
    }

    /// Returns the configuration with its primary `domain_id` replaced.
    ///
    /// The `schema` and the already-built `indexes` are carried over unchanged.
    pub fn with_domain(self, domain_id: DomainId) -> Self {
        Self { domain_id, ..self }
    }
}
/// Handle to a relational table layered on top of a shared [`DonaDb`].
///
/// Cloning copies the `Arc` handles and the domain id; the database, schema
/// and index writer themselves are shared between clones.
#[derive(Clone)]
pub struct RelTable {
/// Shared handle to the underlying database.
db: Arc<DonaDb>,
/// Schema used to validate, encode and decode records.
schema: Arc<Schema>,
/// Storage domain that holds the table's primary rows.
domain_id: DomainId,
/// Maintains secondary index entries alongside primary writes.
writer: Arc<IndexWriter>,
}
impl RelTable {
pub fn new(db: Arc<DonaDb>, config: RelTableConfig) -> Self {
let writer = IndexWriter::new(config.indexes);
Self {
db,
schema: Arc::new(config.schema),
domain_id: config.domain_id,
writer: Arc::new(writer),
}
}
/// Returns the schema this table validates and decodes records against.
pub fn schema(&self) -> &Schema {
    self.schema.as_ref()
}
/// Returns the storage domain that holds this table's primary rows.
pub fn domain_id(&self) -> DomainId {
self.domain_id
}
/// Inserts or replaces a single record and keeps secondary indexes in sync.
///
/// The record is validated against the table schema, then its primary row and
/// index entries are written through one batch opened for `block_height` and
/// `entropy`. If a live row already exists under the same primary key, that
/// row's index entries are removed in the same batch before the new ones are
/// written.
///
/// # Errors
///
/// Returns an error if validation, key/value encoding, the lookup or decoding
/// of the existing row, index maintenance, or the batch commit fails.
pub fn put_record(
    &self,
    record: &Record,
    block_height: BlockHeight,
    entropy: &[u8],
) -> Result<(), RelError> {
    record.validate(&self.schema)?;
    let primary_key = encode_key(record, &self.schema)?;
    let primary_value = encode_value(record, &self.schema)?;
    let mut batch = self.db.begin_batch(block_height, entropy);

    // An empty value is the tombstone written by `delete_record`. It holds no
    // record to decode and owns no index entries, so it is treated as absent,
    // exactly as `get_record` does.
    let existing = self
        .db
        .get(self.domain_id, &primary_key)?
        .filter(|bytes| !bytes.is_empty());
    if let Some(old_bytes) = existing {
        let old_record = self.decode_full_record(&primary_key, &old_bytes)?;
        self.writer
            .remove_indexes(&mut batch, &old_record, &self.schema, &primary_key)?;
    }

    batch.put(self.domain_id, primary_key.clone(), primary_value);
    self.writer
        .write_indexes(&mut batch, record, &self.schema, &primary_key)?;
    batch.commit()?;
    Ok(())
}
/// Inserts or replaces several records through a single batch.
///
/// Each record is validated and encoded in slice order. For every record that
/// overwrites a live row, the old row's index entries are removed before the
/// new row and its index entries are queued. The batch is committed once,
/// after all records were processed; if any record fails, the function returns
/// early without committing.
///
/// NOTE(review): the existing row is read from the database, not from the
/// pending batch. If `records` contains the same primary key more than once,
/// the index entries queued for the earlier occurrence are presumably not
/// removed — confirm against the batch semantics before relying on this.
///
/// # Errors
///
/// Returns an error if validation, encoding, an existing-row lookup or decode,
/// index maintenance, or the final commit fails.
pub fn put_records(
    &self,
    records: &[Record],
    block_height: BlockHeight,
    entropy: &[u8],
) -> Result<(), RelError> {
    let mut batch = self.db.begin_batch(block_height, entropy);
    for record in records {
        record.validate(&self.schema)?;
        let primary_key = encode_key(record, &self.schema)?;
        let primary_value = encode_value(record, &self.schema)?;

        // An empty value is the tombstone written by `delete_record`; skip it
        // instead of trying to decode it as a record.
        let existing = self
            .db
            .get(self.domain_id, &primary_key)?
            .filter(|bytes| !bytes.is_empty());
        if let Some(old_bytes) = existing {
            let old = self.decode_full_record(&primary_key, &old_bytes)?;
            self.writer
                .remove_indexes(&mut batch, &old, &self.schema, &primary_key)?;
        }

        batch.put(self.domain_id, primary_key.clone(), primary_value);
        self.writer
            .write_indexes(&mut batch, record, &self.schema, &primary_key)?;
    }
    batch.commit()?;
    Ok(())
}
/// Fetches the record stored under the given primary-key values.
///
/// `key_values` must supply one value per key field of the schema. Returns
/// `Ok(None)` when no value is stored under the key or when the stored value
/// is empty.
pub fn get_record(&self, key_values: &[FieldValue]) -> Result<Option<Record>, RelError> {
    let primary_key = self.build_key_from_values(key_values)?;
    let bytes = match self.db.get(self.domain_id, &primary_key)? {
        Some(bytes) => bytes,
        None => return Ok(None),
    };
    if bytes.is_empty() {
        return Ok(None);
    }
    self.decode_full_record(&primary_key, &bytes).map(Some)
}
/// Fetches the record under the given primary-key values via `DonaDb::get_at`
/// for `block_height`.
///
/// Returns `Ok(None)` when the lookup yields no value or an empty value.
pub fn get_record_at(
    &self,
    key_values: &[FieldValue],
    block_height: BlockHeight,
) -> Result<Option<Record>, RelError> {
    let primary_key = self.build_key_from_values(key_values)?;
    self.db
        .get_at(self.domain_id, &primary_key, block_height)?
        .filter(|bytes| !bytes.is_empty())
        .map(|bytes| self.decode_full_record(&primary_key, &bytes))
        .transpose()
}
/// Returns every live record that satisfies all of `predicates`.
///
/// If any predicate is an equality test on a field that has a secondary
/// index, the first such predicate drives an index range scan and only the
/// rows it points at are fetched. Otherwise the whole table is scanned. In
/// both cases every candidate is decoded and re-checked against all
/// predicates, so the index only narrows the candidate set.
///
/// Result order follows the scan that produced the candidates (index order
/// or primary-key scan order).
///
/// # Errors
///
/// Returns an error if a scan or lookup fails, a row cannot be decoded, or a
/// predicate names a field the schema does not contain.
pub fn scan_where(&self, predicates: &[Predicate]) -> Result<Vec<Record>, RelError> {
    // Only equality predicates can use an index, so non-`Eq` predicates are
    // skipped here rather than ending the search on the first indexed field.
    let index_hit = predicates.iter().find_map(|pred| {
        if let crate::predicate::FieldOp::Eq(ref val) = pred.op {
            self.writer
                .specs
                .iter()
                .find(|spec| spec.field_name == pred.field)
                .map(|spec| (val, spec))
        } else {
            None
        }
    });

    let raw_results: Vec<(Bytes, Bytes)> = match index_hit {
        Some((val, spec)) => {
            let field = &self.schema.fields[spec.field_idx];
            let (start, end) = index_range_for_eq(val, &field.field_type)?;
            let mut rows = Vec::new();
            for (_idx_key, pk_bytes) in self.db.scan(spec.domain_id, &start, &end)? {
                // Index entries with an empty value carry no primary key.
                if pk_bytes.is_empty() {
                    continue;
                }
                if let Some(val_bytes) = self.db.get(self.domain_id, &pk_bytes)? {
                    // Skip rows whose primary value is empty.
                    if !val_bytes.is_empty() {
                        rows.push((pk_bytes, val_bytes));
                    }
                }
            }
            rows
        }
        None => self.full_scan_raw()?,
    };

    let mut records = Vec::new();
    for (pk_bytes, val_bytes) in raw_results {
        let record = self.decode_full_record(&pk_bytes, &val_bytes)?;
        if self.record_matches(&record, predicates)? {
            records.push(record);
        }
    }
    Ok(records)
}
/// Decodes the live records returned by a primary-key scan between
/// `start_key` and `end_key`.
///
/// Both bounds are encoded like a primary key and must therefore supply one
/// value per key field. Bound inclusivity is whatever `DonaDb::scan` applies.
/// Rows with an empty value are skipped.
pub fn scan_range(
    &self,
    start_key: &[FieldValue],
    end_key: &[FieldValue],
) -> Result<Vec<Record>, RelError> {
    let start = self.build_key_from_values(start_key)?;
    let end = self.build_key_from_values(end_key)?;
    self.db
        .scan(self.domain_id, &start, &end)?
        .into_iter()
        .filter(|(_, val)| !val.is_empty())
        .map(|(pk, val)| self.decode_full_record(&pk, &val))
        .collect()
}
/// Decodes the live records returned by `DonaDb::scan_prefix_domain` for the
/// raw, already-encoded key `prefix`.
///
/// Rows with an empty value are skipped.
pub fn scan_prefix_raw(&self, prefix: &[u8]) -> Result<Vec<Record>, RelError> {
    let mut records = Vec::new();
    for (pk, val) in self.db.scan_prefix_domain(self.domain_id, prefix)? {
        if val.is_empty() {
            continue;
        }
        let record = self.decode_full_record(&pk, &val)?;
        records.push(record);
    }
    Ok(records)
}
/// Counts the records matching all `predicates`.
///
/// This runs `scan_where` and measures the resulting vector, so every
/// matching record is decoded along the way.
pub fn count_where(&self, predicates: &[Predicate]) -> Result<usize, RelError> {
    self.scan_where(predicates).map(|matches| matches.len())
}
/// Decodes every live row of the table.
///
/// Stops at, and returns, the first decode error.
pub fn scan_all_raw(&self) -> Result<Vec<Record>, RelError> {
    let pairs = self.full_scan_raw()?;
    let mut records = Vec::with_capacity(pairs.len());
    for (pk, val) in pairs {
        records.push(self.decode_full_record(&pk, &val)?);
    }
    Ok(records)
}
/// Resolves a reference field of `record` to the record it points at in
/// `target`.
///
/// The value of `ref_field` is used as the complete primary key of `target`,
/// so `target` must have exactly one key field for the lookup to succeed.
/// Returns `Ok(None)` when the reference is `FieldValue::Null` or when
/// `target` has no live row under that key.
///
/// # Errors
///
/// Returns `RelError::UnknownField` if this table's schema has no field named
/// `ref_field`, `RelError::MissingField` if `record` has no value at that
/// field's position, or any error raised by the lookup in `target`.
pub fn follow_ref(
    &self,
    record: &Record,
    ref_field: &str,
    target: &RelTable,
) -> Result<Option<Record>, RelError> {
    let (idx, _) = self
        .schema
        .field(ref_field)
        .ok_or_else(|| RelError::UnknownField(ref_field.to_string()))?;
    let fk_value = record
        .values
        .get(idx)
        .ok_or_else(|| RelError::MissingField(ref_field.to_string()))?;
    if fk_value == &FieldValue::Null {
        return Ok(None);
    }
    // Borrow the value as a one-element slice instead of cloning it.
    target.get_record(std::slice::from_ref(fk_value))
}
/// Deletes the record stored under the given primary-key values.
///
/// Deletion removes the record's index entries and overwrites the primary row
/// with an empty value (a tombstone) in one batch opened for `block_height`
/// and `entropy`. If no live row exists under the key — it was never written
/// or is already tombstoned — nothing is written and `Ok(())` is returned.
///
/// # Errors
///
/// Returns an error if key encoding, the lookup or decoding of the existing
/// row, index maintenance, or the batch commit fails.
pub fn delete_record(
    &self,
    key_values: &[FieldValue],
    block_height: BlockHeight,
    entropy: &[u8],
) -> Result<(), RelError> {
    let primary_key = self.build_key_from_values(key_values)?;
    let old_bytes = match self.db.get(self.domain_id, &primary_key)? {
        Some(bytes) if !bytes.is_empty() => bytes,
        // Absent, or an empty tombstone left by an earlier delete: there is
        // no record to decode and no index entry to remove.
        _ => return Ok(()),
    };
    let old_record = self.decode_full_record(&primary_key, &old_bytes)?;
    let mut batch = self.db.begin_batch(block_height, entropy);
    self.writer
        .remove_indexes(&mut batch, &old_record, &self.schema, &primary_key)?;
    // The empty value marks the row as deleted for all readers in this file.
    batch.put(self.domain_id, primary_key, Bytes::new());
    batch.commit()?;
    Ok(())
}
/// Encodes `key_values` into the byte form of a primary key.
///
/// The values are paired positionally with the schema's key fields and
/// encoded one after another into a single buffer. Fails with
/// `RelError::Schema` when the number of values differs from the number of
/// key fields.
fn build_key_from_values(&self, key_values: &[FieldValue]) -> Result<Bytes, RelError> {
    use crate::codec::encode_field_key;
    use bytes::BytesMut;

    let key_fields = self.schema.key_fields();
    let expected = key_fields.len();
    let supplied = key_values.len();
    if supplied != expected {
        let message = format!(
            "schema '{}' has {} key fields, got {}",
            self.schema.name, expected, supplied
        );
        return Err(RelError::Schema(message));
    }

    let mut encoded = BytesMut::new();
    for ((_, field), value) in key_fields.iter().zip(key_values) {
        encode_field_key(&mut encoded, &field.field_type, value)?;
    }
    Ok(encoded.freeze())
}
/// Returns the raw `(key, value)` pairs of every row in the table's domain,
/// leaving out pairs whose value is empty.
fn full_scan_raw(&self) -> Result<Vec<(Bytes, Bytes)>, RelError> {
    let mut live = Vec::new();
    for (key, value) in self.db.scan_all(self.domain_id)? {
        if !value.is_empty() {
            live.push((key, value));
        }
    }
    Ok(live)
}
/// Rebuilds a full record from its encoded primary key and value.
///
/// Every field starts out as `FieldValue::Null`. Key fields are filled from
/// `pk_bytes` at the positions reported by `decode_key`. The values decoded
/// from `val_bytes` are applied by position, except at positions the schema
/// marks as key fields, which keep the value taken from the key.
fn decode_full_record(&self, pk_bytes: &Bytes, val_bytes: &Bytes) -> Result<Record, RelError> {
    let field_count = self.schema.fields.len();
    let mut values: Vec<FieldValue> = vec![FieldValue::Null; field_count];

    for (position, key_part) in decode_key(pk_bytes, &self.schema)? {
        values[position] = key_part;
    }

    let payload = decode_value(val_bytes, &self.schema)?;
    for (position, decoded) in payload.into_iter().enumerate() {
        if self.schema.fields[position].is_key {
            // The key bytes are authoritative for key fields.
            continue;
        }
        values[position] = decoded;
    }

    Ok(Record::new(values))
}
/// Reports whether `record` satisfies every predicate in `predicates`.
///
/// Predicates are checked in order and evaluation stops at the first one that
/// fails. A record with no value at a predicate's field position is evaluated
/// as `FieldValue::Null`. Fails with `RelError::UnknownField` when a predicate
/// names a field the schema does not contain.
fn record_matches(&self, record: &Record, predicates: &[Predicate]) -> Result<bool, RelError> {
    for pred in predicates {
        let position = match self.schema.field(&pred.field) {
            Some((position, _)) => position,
            None => return Err(RelError::UnknownField(pred.field.clone())),
        };
        let value = record.values.get(position).unwrap_or(&FieldValue::Null);
        if !pred.evaluate(value) {
            return Ok(false);
        }
    }
    Ok(true)
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
use std::sync::Arc;
use tempfile::TempDir;
use crate::predicate::Predicate;
use crate::schema::{Field, FieldType, FieldValue, Record, Schema};
/// Opens a `DonaDb` rooted at `dir` with fixed settings for the tests.
fn open_db(dir: &TempDir) -> Arc<DonaDb> {
    use donadb::DonaDbConfig;

    let config = DonaDbConfig {
        data_dir: dir.path().to_path_buf(),
        shard_count: 16,
        compaction_threads: 2,
        block_cache_bytes: 8 * 1024 * 1024,
        write_buffer_bytes: 16 * 1024 * 1024,
        ..Default::default()
    };
    let db = DonaDb::open(config).unwrap();
    Arc::new(db)
}
/// Schema shared by the tests: an `address` key, a plain `balance` value and
/// an indexed `nonce` value.
fn account_schema() -> Schema {
    let fields = vec![
        Field::key("address", FieldType::Address),
        Field::value("balance", FieldType::U128),
        Field::indexed_value("nonce", FieldType::U64),
    ];
    Schema::new("accounts", fields)
}
/// Builds an account record whose values follow the field order of
/// `account_schema`: address, balance, nonce.
fn make_account(addr: [u8; 32], balance: u128, nonce: u64) -> Record {
    let address_value = FieldValue::Bytes(addr.to_vec());
    let balance_value = FieldValue::U128(balance);
    let nonce_value = FieldValue::U64(nonce);
    Record::new(vec![address_value, balance_value, nonce_value])
}
#[test]
fn test_put_and_get() {
    // A record written with `put_record` is readable through `get_record`.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 1));

    let addr = [0xABu8; 32];
    tbl.put_record(&make_account(addr, 1_000_000, 0), 1, b"entropy1")
        .unwrap();

    let key = [FieldValue::Bytes(addr.to_vec())];
    let fetched = tbl.get_record(&key).unwrap();
    assert!(fetched.is_some());
    let r = fetched.unwrap();
    assert_eq!(r.values[1], FieldValue::U128(1_000_000));
    assert_eq!(r.values[2], FieldValue::U64(0));
}
#[test]
fn test_scan_where_full_scan() {
    // `balance` has no index, so this predicate exercises the full-scan path.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 1));

    for i in 0u8..10 {
        let mut addr = [0u8; 32];
        addr[31] = i;
        let account = make_account(addr, u128::from(i) * 100, u64::from(i));
        tbl.put_record(&account, u64::from(i) + 1, b"e").unwrap();
    }
    tbl.db.finalize_block(10).unwrap();

    // Balances are 0, 100, ..., 900; five of them are >= 500.
    let predicate = Predicate::gte("balance", FieldValue::U128(500));
    let results = tbl.scan_where(&[predicate]).unwrap();
    assert_eq!(results.len(), 5);
}
#[test]
fn test_scan_where_index_accelerated() {
    // `nonce` is an indexed field, so an equality predicate can use its index.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 2));

    for i in 0u8..5 {
        let mut addr = [0u8; 32];
        addr[31] = i;
        let account = make_account(addr, 1000, u64::from(i));
        tbl.put_record(&account, u64::from(i) + 1, b"e").unwrap();
    }
    tbl.db.finalize_block(5).unwrap();

    let predicate = Predicate::eq("nonce", FieldValue::U64(3));
    let results = tbl.scan_where(&[predicate]).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].values[2], FieldValue::U64(3));
}
#[test]
fn test_point_in_time_read() {
    // Two versions of one account are written at heights 10 and 20; each
    // height must read back the balance written at that height.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 3));

    let addr = [0x01u8; 32];
    tbl.put_record(&make_account(addr, 500, 0), 10, b"e").unwrap();
    tbl.put_record(&make_account(addr, 900, 1), 20, b"e").unwrap();

    let key = [FieldValue::Bytes(addr.to_vec())];
    let at_10 = tbl.get_record_at(&key, 10).unwrap().unwrap();
    assert_eq!(at_10.values[1], FieldValue::U128(500));

    let at_20 = tbl.get_record_at(&key, 20).unwrap().unwrap();
    assert_eq!(at_20.values[1], FieldValue::U128(900));
}
#[test]
fn test_follow_ref() {
    // A token row's `holder` value is followed into the accounts table.
    let dir = TempDir::new().unwrap();
    let db = open_db(&dir);
    let acc_tbl = RelTable::new(Arc::clone(&db), RelTableConfig::new(account_schema(), 4));

    let token_fields = vec![
        Field::key("cell_id", FieldType::U64),
        Field::value("holder", FieldType::Address),
        Field::value("amount", FieldType::U128),
    ];
    let token_schema = Schema::new("token_balances", token_fields);
    let tok_tbl = RelTable::new(Arc::clone(&db), RelTableConfig::new(token_schema, 5));

    let addr = [0x11u8; 32];
    acc_tbl
        .put_record(&make_account(addr, 5000, 2), 1, b"e")
        .unwrap();

    let tok_record = Record::new(vec![
        FieldValue::U64(42),
        FieldValue::Bytes(addr.to_vec()),
        FieldValue::U128(250),
    ]);
    tok_tbl.put_record(&tok_record, 2, b"e").unwrap();
    db.finalize_block(2).unwrap();

    let token_key = [FieldValue::U64(42)];
    let fetched_tok = tok_tbl.get_record(&token_key).unwrap().unwrap();
    let account = tok_tbl
        .follow_ref(&fetched_tok, "holder", &acc_tbl)
        .unwrap();
    assert!(account.is_some());
    assert_eq!(account.unwrap().values[1], FieldValue::U128(5000));
}
#[test]
fn test_delete_record() {
    // After `delete_record`, the key no longer resolves to a record.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 6));

    let addr = [0x22u8; 32];
    let key = [FieldValue::Bytes(addr.to_vec())];
    tbl.put_record(&make_account(addr, 1000, 0), 1, b"e").unwrap();
    tbl.delete_record(&key, 2, b"e").unwrap();

    let result = tbl.get_record(&key).unwrap();
    assert!(result.is_none());
}
#[test]
fn test_batch_put_records() {
    // Twenty accounts are written through one batch and then counted.
    let dir = TempDir::new().unwrap();
    let tbl = RelTable::new(open_db(&dir), RelTableConfig::new(account_schema(), 7));

    let mut records: Vec<Record> = Vec::with_capacity(20);
    for i in 0u8..20 {
        let mut addr = [0u8; 32];
        addr[31] = i;
        records.push(make_account(addr, u128::from(i) * 1000, u64::from(i)));
    }
    tbl.put_records(&records, 1, b"batch").unwrap();
    tbl.db.finalize_block(1).unwrap();

    // Balances are 0, 1000, ..., 19000; ten of them are >= 10_000.
    let predicate = Predicate::gte("balance", FieldValue::U128(10_000));
    let count = tbl.count_where(&[predicate]).unwrap();
    assert_eq!(count, 10);
}
}