use ents::{
DraftError, EdgeDraft, EdgeQuery, EdgeValue, Ent, EntExt as _,
EntMutationError, Id, IncomingEdgeProvider, NullEdgeProvider, QueryEdge,
ReadEnt, Transactional,
};
use ents_sqlite::Txn;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use serde::{Deserialize, Serialize};
/// Minimal entity with no edges (`NullEdgeProvider`), used by the basic
/// create/get/commit/rollback tests.
#[derive(Clone, Serialize, Deserialize)]
struct TestEntity {
    name: String,
    value: i32,
    id: Id,
    last_updated: u64,
}
#[typetag::serde]
impl Ent for TestEntity {
    type EdgeProvider = NullEdgeProvider;
    fn id(&self) -> Id {
        self.id
    }
    fn set_id(&mut self, id: Id) {
        self.id = id;
    }
    fn last_updated(&self) -> u64 {
        self.last_updated
    }
    /// Stamps a fixed, deterministic timestamp so tests can assert on it.
    fn mark_updated(&mut self) -> Result<(), EntMutationError> {
        const FIXED_TIMESTAMP: u64 = 12345;
        self.last_updated = FIXED_TIMESTAMP;
        Ok(())
    }
}
impl TestEntity {
    /// Starts a builder with every field at its `Default` value.
    pub fn build() -> TestEntityBuilder {
        Default::default()
    }
}
/// Fluent builder for `TestEntity`. There are no setters for `id` or
/// `last_updated`, so both keep their `Default` values.
#[derive(Default)]
struct TestEntityBuilder {
    name: String,
    value: i32,
    id: Id,
    last_updated: u64,
}
impl TestEntityBuilder {
    /// Sets the entity's name.
    pub fn name(self, name: String) -> Self {
        Self { name, ..self }
    }
    /// Sets the entity's integer payload.
    pub fn value(self, value: i32) -> Self {
        Self { value, ..self }
    }
    /// Moves the collected fields into a `TestEntity`; always returns `Ok`.
    pub fn finish(self) -> anyhow::Result<TestEntity> {
        let Self {
            name,
            value,
            id,
            last_updated,
        } = self;
        Ok(TestEntity {
            name,
            value,
            id,
            last_updated,
        })
    }
}
/// Person entity whose `lives_in_link` is turned into a `lives_in` edge by its
/// edge provider, `TestPersonEdgeProvider`.
#[derive(Clone, Serialize, Deserialize)]
struct TestPerson {
    name: String,
    age: i32,
    // Id of the city this person lives in; the target of the `lives_in` edge.
    lives_in_link: Id,
    id: Id,
    last_updated: u64,
}
#[typetag::serde]
impl Ent for TestPerson {
    type EdgeProvider = TestPersonEdgeProvider;
    fn id(&self) -> Id {
        self.id
    }
    fn set_id(&mut self, id: Id) {
        self.id = id;
    }
    fn last_updated(&self) -> u64 {
        self.last_updated
    }
    /// Stamps a fixed, deterministic timestamp so tests can assert on it.
    fn mark_updated(&mut self) -> Result<(), EntMutationError> {
        const FIXED_TIMESTAMP: u64 = 12345;
        self.last_updated = FIXED_TIMESTAMP;
        Ok(())
    }
}
impl TestPerson {
    /// Borrows the id of the city this person lives in.
    pub fn lives_in_link(&self) -> &Id {
        &self.lives_in_link
    }
    /// Starts a builder with every field at its `Default` value.
    pub fn build() -> TestPersonBuilder {
        Default::default()
    }
}
/// Fluent builder for `TestPerson`. There is no setter for `id`, so it keeps
/// its `Default` value.
#[derive(Default)]
struct TestPersonBuilder {
    name: String,
    age: i32,
    lives_in_link: Id,
    id: Id,
    last_updated: u64,
}
impl TestPersonBuilder {
    /// Sets the person's name.
    pub fn name(self, name: String) -> Self {
        Self { name, ..self }
    }
    /// Sets the person's age.
    pub fn age(self, age: i32) -> Self {
        Self { age, ..self }
    }
    /// Sets the id of the city the person lives in.
    pub fn lives_in_link(self, lives_in_link: Id) -> Self {
        Self {
            lives_in_link,
            ..self
        }
    }
    /// Sets the initial `last_updated` timestamp.
    pub fn last_updated(self, last_updated: u64) -> Self {
        Self {
            last_updated,
            ..self
        }
    }
    /// Moves the collected fields into a `TestPerson`; always returns `Ok`.
    pub fn finish(self) -> anyhow::Result<TestPerson> {
        let Self {
            name,
            age,
            lives_in_link,
            id,
            last_updated,
        } = self;
        Ok(TestPerson {
            name,
            age,
            lives_in_link,
            id,
            last_updated,
        })
    }
}
/// Draft of the single `lives_in` edge from a person to their city.
///
/// NOTE(review): `PartialEq` is presumably needed so `ents` can tell whether an
/// update changed the drafted edges — confirm against the `EdgeDraft` trait.
#[derive(PartialEq)]
struct TestPersonEdgeDraft {
    person_id: Id,
    city_id: Id,
}
impl EdgeDraft for TestPersonEdgeDraft {
    /// Emits one `lives_in` edge value from `person_id` to `city_id`. The
    /// transaction is never consulted, so this always returns `Ok`.
    fn check<T: ReadEnt>(self, _txn: &T) -> Result<Vec<EdgeValue>, DraftError> {
        let Self { person_id, city_id } = self;
        let lives_in = EdgeValue::new(person_id, b"lives_in".to_vec(), city_id);
        Ok(vec![lives_in])
    }
}
/// Edge provider that drafts a person's `lives_in` edge from `lives_in_link`.
struct TestPersonEdgeProvider;
impl IncomingEdgeProvider<TestPerson> for TestPersonEdgeProvider {
    type Draft = TestPersonEdgeDraft;
    fn draft(ent: &TestPerson) -> Self::Draft {
        let city_id = *ent.lives_in_link();
        let person_id = ent.id();
        TestPersonEdgeDraft { person_id, city_id }
    }
}
impl TestPerson {
    /// Points this person at a different city; any edge change is left to
    /// whatever later persists the entity.
    pub fn set_lives_in_link(&mut self, city_id: Id) {
        self.lives_in_link = city_id;
    }
}
/// City entity with no edges of its own (`NullEdgeProvider`); the target of a
/// person's `lives_in` edge in the edge tests.
#[derive(Clone, Serialize, Deserialize)]
struct TestCity {
    name: String,
    population: i64,
    id: Id,
    last_updated: u64,
}
#[typetag::serde]
impl Ent for TestCity {
    type EdgeProvider = NullEdgeProvider;
    fn id(&self) -> Id {
        self.id
    }
    fn set_id(&mut self, id: Id) {
        self.id = id;
    }
    fn last_updated(&self) -> u64 {
        self.last_updated
    }
    /// Stamps a fixed, deterministic timestamp so tests can assert on it.
    fn mark_updated(&mut self) -> Result<(), EntMutationError> {
        const FIXED_TIMESTAMP: u64 = 12345;
        self.last_updated = FIXED_TIMESTAMP;
        Ok(())
    }
}
impl TestCity {
    /// Starts a builder with every field at its `Default` value.
    pub fn build() -> TestCityBuilder {
        Default::default()
    }
}
/// Fluent builder for `TestCity`. There are no setters for `id` or
/// `last_updated`, so both keep their `Default` values.
#[derive(Default)]
struct TestCityBuilder {
    name: String,
    population: i64,
    id: Id,
    last_updated: u64,
}
impl TestCityBuilder {
    /// Sets the city's name.
    pub fn name(self, name: String) -> Self {
        Self { name, ..self }
    }
    /// Sets the city's population.
    pub fn population(self, population: i64) -> Self {
        Self { population, ..self }
    }
    /// Moves the collected fields into a `TestCity`; always returns `Ok`.
    pub fn finish(self) -> anyhow::Result<TestCity> {
        let Self {
            name,
            population,
            id,
            last_updated,
        } = self;
        Ok(TestCity {
            name,
            population,
            id,
            last_updated,
        })
    }
}
/// Builds a connection pool over a fresh in-memory SQLite database that
/// already contains the `entities` and `edges` tables used by the tests.
///
/// SQLite gives every `:memory:` connection its own private database. A
/// default pool would open several such connections, and only the one
/// checked out here would get the schema. The pool is therefore capped at a
/// single connection, so every later `pool.get()` sees this schema and any
/// committed rows.
///
/// Idle-timeout and max-lifetime recycling are turned off for the same reason:
/// replacing the connection would silently discard the whole database.
///
/// # Panics
///
/// Panics if the pool cannot be built or the schema cannot be created.
fn setup_test_db() -> Pool<SqliteConnectionManager> {
    let pool = Pool::builder()
        .max_size(1)
        .idle_timeout(None)
        .max_lifetime(None)
        .build(SqliteConnectionManager::memory())
        .expect("failed to build in-memory SQLite pool");
    let conn = pool
        .get()
        .expect("failed to check out the pooled SQLite connection");
    conn.execute_batch(
        r#"
            CREATE TABLE IF NOT EXISTS entities (
                id INTEGER PRIMARY KEY,
                type TEXT NOT NULL,
                data TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS edges (
                source INTEGER NOT NULL,
                type TEXT NOT NULL,
                dest INTEGER NOT NULL,
                PRIMARY KEY (source, type, dest)
            );
            "#,
    )
    .expect("failed to create test schema");
    // `conn` is returned to the pool when it drops here, so callers can check it out again.
    pool
}
/// Round-trips an entity through `create` and `get` inside one transaction.
/// It checks the stored id and the concrete type recorded by typetag.
#[test]
fn test_insert_and_get() {
    let pool = setup_test_db();
    let mut conn = pool.get().unwrap();
    let tx = conn.transaction().unwrap();
    let txn = Txn::new(tx);
    let ent = TestEntity::build()
        .name("test".to_string())
        .value(42)
        .finish()
        .unwrap();
    let id = txn.create(ent).unwrap();
    // A single `expect` replaces the is_some()/unwrap() pair and names the failure.
    let retrieved_ent = txn
        .get(id)
        .unwrap()
        .expect("entity created in this transaction should be readable");
    assert_eq!(retrieved_ent.id(), id);
    assert!(retrieved_ent.is::<TestEntity>());
    assert_eq!(retrieved_ent.typetag_name(), "TestEntity");
}
/// Looking up an id that was never created yields `Ok(None)`, not an error.
#[test]
fn test_get_nonexistent() {
    let pool = setup_test_db();
    let mut conn = pool.get().unwrap();
    let txn = Txn::new(conn.transaction().unwrap());
    let missing = txn.get(999).unwrap();
    assert!(missing.is_none());
}
/// Writes from a committed transaction are visible to a later transaction
/// opened on the same pool.
#[test]
fn test_transaction_commit() {
    let pool = setup_test_db();
    let id = {
        let mut writer_conn = pool.get().unwrap();
        let writer = Txn::new(writer_conn.transaction().unwrap());
        let ent = TestEntity::build()
            .name("committed".to_string())
            .value(999)
            .finish()
            .unwrap();
        let new_id = writer.create(ent).unwrap();
        writer.commit().unwrap();
        new_id
    };
    let mut reader_conn = pool.get().unwrap();
    let reader = Txn::new(reader_conn.transaction().unwrap());
    let found = reader.get(id).unwrap();
    assert!(found.is_some());
}
/// Dropping a `Txn` without calling `commit` discards its writes: a later
/// transaction on the same pool must not see the entity.
#[test]
fn test_transaction_rollback() {
    let pool = setup_test_db();
    let id = {
        let mut conn = pool.get().unwrap();
        let tx = conn.transaction().unwrap();
        let txn = Txn::new(tx);
        let ent = TestEntity::build()
            .name("rolled_back".to_string())
            .value(888)
            .finish()
            .unwrap();
        // Deliberately no `commit`: `txn` drops at the end of this block, and
        // that drop is what this test expects to roll the insert back.
        txn.create(ent).unwrap()
    };
    let mut conn = pool.get().unwrap();
    let tx = conn.transaction().unwrap();
    let txn = Txn::new(tx);
    let retrieved = txn.get(id).unwrap();
    assert!(retrieved.is_none());
}
/// `update` applies the closure's edits and persists them, so a fresh `get`
/// sees the new name and value.
#[test]
fn test_update_without_cas() {
    let pool = setup_test_db();
    let mut conn = pool.get().unwrap();
    let txn = Txn::new(conn.transaction().unwrap());
    let mut ent = TestEntity::build()
        .name("original".to_string())
        .value(100)
        .finish()
        .unwrap();
    let id = txn.create(ent.clone()).unwrap();
    // The local copy needs the assigned id so `update` targets the stored row.
    ent.set_id(id);
    let apply_edit = |target: &mut TestEntity| {
        target.name = "updated".to_string();
        target.value = 200;
    };
    let applied = txn.update(&mut ent, apply_edit).unwrap();
    assert!(applied);
    let stored = txn.get(id).unwrap().unwrap();
    let stored_json = serde_json::to_value(&stored).unwrap();
    assert_eq!(stored_json["name"], "updated");
    assert_eq!(stored_json["value"], 200);
}
/// Edge-less entity whose builder can seed `last_updated`, so tests can check
/// that an update overwrites the timestamp.
#[derive(Clone, Serialize, Deserialize)]
struct TestEntityWithTimestamp {
    name: String,
    value: i32,
    id: Id,
    last_updated: u64,
}
#[typetag::serde]
impl Ent for TestEntityWithTimestamp {
    type EdgeProvider = NullEdgeProvider;
    fn id(&self) -> Id {
        self.id
    }
    fn set_id(&mut self, id: Id) {
        self.id = id;
    }
    fn last_updated(&self) -> u64 {
        self.last_updated
    }
    /// Stamps a fixed, deterministic timestamp so tests can assert on it.
    fn mark_updated(&mut self) -> Result<(), EntMutationError> {
        const FIXED_TIMESTAMP: u64 = 12345;
        self.last_updated = FIXED_TIMESTAMP;
        Ok(())
    }
}
impl TestEntityWithTimestamp {
    /// Starts a builder with every field at its `Default` value.
    pub fn build() -> TestEntityWithTimestampBuilder {
        Default::default()
    }
}
/// Fluent builder for `TestEntityWithTimestamp`. There is no setter for `id`,
/// so it keeps its `Default` value.
#[derive(Default)]
struct TestEntityWithTimestampBuilder {
    name: String,
    value: i32,
    id: Id,
    last_updated: u64,
}
impl TestEntityWithTimestampBuilder {
    /// Sets the entity's name.
    pub fn name(self, name: String) -> Self {
        Self { name, ..self }
    }
    /// Sets the entity's integer payload.
    pub fn value(self, value: i32) -> Self {
        Self { value, ..self }
    }
    /// Sets the initial `last_updated` timestamp.
    pub fn last_updated(self, last_updated: u64) -> Self {
        Self {
            last_updated,
            ..self
        }
    }
    /// Moves the collected fields into a `TestEntityWithTimestamp`; always
    /// returns `Ok`.
    pub fn finish(self) -> anyhow::Result<TestEntityWithTimestamp> {
        let Self {
            name,
            value,
            id,
            last_updated,
        } = self;
        Ok(TestEntityWithTimestamp {
            name,
            value,
            id,
            last_updated,
        })
    }
}
/// After `update`, the stored entity carries the edited fields and the
/// timestamp 12345, which is the value `mark_updated` writes. Presumably
/// `update` calls it — the seeded 1000 must not survive.
#[test]
fn test_update_with_timestamp() {
    let pool = setup_test_db();
    let mut conn = pool.get().unwrap();
    let txn = Txn::new(conn.transaction().unwrap());
    let initial_timestamp: u64 = 1000;
    let mut ent = TestEntityWithTimestamp::build()
        .name("original".to_string())
        .value(100)
        .last_updated(initial_timestamp)
        .finish()
        .unwrap();
    let id = txn.create(ent.clone()).unwrap();
    // The local copy needs the assigned id so `update` targets the stored row.
    ent.set_id(id);
    let applied = txn
        .update(&mut ent, |target: &mut TestEntityWithTimestamp| {
            target.value = 200;
            target.name = "updated".to_string();
        })
        .unwrap();
    assert!(applied);
    let stored = txn.get(id).unwrap().unwrap();
    let stored_json = serde_json::to_value(&stored).unwrap();
    assert_eq!(stored_json["name"], "updated");
    assert_eq!(stored_json["value"], 200);
    assert_eq!(stored_json["last_updated"], 12345);
}
/// Changing a person's `lives_in_link` through `update` must replace their one
/// `lives_in` edge. Re-applying the same link must still succeed and leave
/// exactly that edge in place.
#[test]
fn test_update4_edge_change() {
    let pool = setup_test_db();
    let mut conn = pool.get().unwrap();
    let txn = Txn::new(conn.transaction().unwrap());

    let make_city = |name: &str, population: i64| {
        TestCity::build()
            .name(name.to_string())
            .population(population)
            .finish()
            .unwrap()
    };
    let city1_id = txn.create(make_city("City1", 100)).unwrap();
    let city2_id = txn.create(make_city("City2", 200)).unwrap();

    let mut person = TestPerson::build()
        .name("Alice".to_string())
        .age(30)
        .lives_in_link(city1_id)
        .last_updated(0)
        .finish()
        .unwrap();
    let person_id = txn.create(person.clone()).unwrap();
    person.set_id(person_id);

    // Asserts the person has exactly one edge and that it points at `city`.
    let assert_lives_in = |city: Id| {
        let edges = txn.find_edges(person_id, EdgeQuery::asc(&[])).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].dest, city);
    };
    assert_lives_in(city1_id);

    // Move the person to City2: the edge must follow.
    let moved = txn
        .update(&mut person, |p: &mut TestPerson| {
            p.set_lives_in_link(city2_id);
        })
        .unwrap();
    assert!(moved);
    assert_lives_in(city2_id);

    // Re-apply the same link: the update succeeds and the edge is unchanged.
    // NOTE(review): this local set may be redundant if `update` already wrote
    // the closure's change back into `person`; kept to pin the "before" state.
    person.set_lives_in_link(city2_id);
    let unchanged = txn
        .update(&mut person, |p: &mut TestPerson| {
            p.set_lives_in_link(city2_id);
        })
        .unwrap();
    assert!(unchanged);
    assert_lives_in(city2_id);
}