use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use aws_sdk_dynamodb::types::{AttributeValue, TransactWriteItem, WriteRequest};
use graphddb_runtime::client::{
DynamoClient, GetItemInput, GetItemOutput, Item, QueryInput, QueryOutput, WriteOutput,
};
use graphddb_runtime::{GraphDDBError, GraphDDBRuntime, Middleware, Recovery};
use serde_json::{json, Value as Json};
/// Call log filled in by the fake client: one vector per client operation kind,
/// each in call order.
#[derive(Default)]
struct Recorded {
    /// Every input handed to `get_item`.
    get_items: Vec<GetItemInput>,
    /// Every input handed to `query`.
    queries: Vec<QueryInput>,
    /// `(table, item, cond)` for each `put_item` call that was not failed on purpose.
    puts: Vec<(String, Item, Option<String>)>,
    /// `(table, key, cond)` for each `update_item` call that was not failed on purpose.
    updates: Vec<(String, Item, Option<String>)>,
    /// `(table, key, cond)` for each `delete_item` call that was not failed on purpose.
    deletes: Vec<(String, Item, Option<String>)>,
    /// The item list of each `transact_write_items` call that was not failed on purpose.
    transacts: Vec<Vec<TransactWriteItem>>,
}
/// In-memory stand-in for a DynamoDB client used by the tests in this file.
///
/// Every field is behind `Arc<Mutex<..>>`, so a clone of the client (or of a single
/// field, e.g. `fake.rec.clone()`) shares state with the instance moved into the runtime.
#[derive(Clone, Default)]
struct FakeClient {
// Log of the calls the client has received.
rec: Arc<Mutex<Recorded>>,
// Canned item returned by every `get_item` call (`None` => no item).
get_item_response: Arc<Mutex<Option<Item>>>,
// Canned items returned by every `query` call.
query_response: Arc<Mutex<Vec<Item>>>,
// When `Some(msg)`, the next put/update/delete/transact call takes the message
// and fails with it instead of recording the write.
fail_next_write: Arc<Mutex<Option<String>>>,
}
/// Shorthand for building a string (`S`) attribute value from a `&str`.
fn s(v: &str) -> AttributeValue {
    let owned = String::from(v);
    AttributeValue::S(owned)
}
impl FakeClient {
    /// Takes the pending `fail_next_write` message, if one is set, and converts it
    /// into an operation-execution error. Returns `Ok(())` when nothing is pending.
    ///
    /// The message is consumed, so only a single write call fails per arming.
    fn take_injected_failure(&self) -> graphddb_runtime::Result<()> {
        let pending = self.fail_next_write.lock().unwrap().take();
        match pending {
            Some(msg) => Err(GraphDDBError::operation_execution(msg)),
            None => Ok(()),
        }
    }
}

/// Fake `DynamoClient`: reads return canned responses, writes are appended to the
/// shared `Recorded` log unless a failure has been armed via `fail_next_write`.
#[async_trait]
impl DynamoClient for FakeClient {
    /// Records `input` and returns the canned `get_item_response`.
    async fn get_item(&self, input: GetItemInput) -> graphddb_runtime::Result<GetItemOutput> {
        let canned = self.get_item_response.lock().unwrap().clone();
        self.rec.lock().unwrap().get_items.push(input);
        Ok(GetItemOutput { item: canned })
    }

    /// Records `input` and returns the canned `query_response` as a single page
    /// (`last_evaluated_key` is always `None`).
    async fn query(&self, input: QueryInput) -> graphddb_runtime::Result<QueryOutput> {
        let canned = self.query_response.lock().unwrap().clone();
        self.rec.lock().unwrap().queries.push(input);
        let page = QueryOutput {
            items: canned,
            last_evaluated_key: None,
        };
        Ok(page)
    }

    /// Records `(table, item, cond)` in `puts`, or fails with the armed error
    /// without recording anything.
    async fn put_item(
        &self,
        table: &str,
        item: Item,
        cond: Option<String>,
        _n: Option<HashMap<String, String>>,
        _v: Option<HashMap<String, AttributeValue>>,
        _old: bool,
    ) -> graphddb_runtime::Result<WriteOutput> {
        self.take_injected_failure()?;
        let entry = (table.to_string(), item, cond);
        self.rec.lock().unwrap().puts.push(entry);
        Ok(WriteOutput::default())
    }

    /// Records `(table, key, cond)` in `updates`, or fails with the armed error
    /// without recording anything. The update expression itself is ignored.
    async fn update_item(
        &self,
        table: &str,
        key: Item,
        _ue: Option<String>,
        cond: Option<String>,
        _n: Option<HashMap<String, String>>,
        _v: Option<HashMap<String, AttributeValue>>,
        _old: bool,
    ) -> graphddb_runtime::Result<WriteOutput> {
        self.take_injected_failure()?;
        let entry = (table.to_string(), key, cond);
        self.rec.lock().unwrap().updates.push(entry);
        Ok(WriteOutput::default())
    }

    /// Records `(table, key, cond)` in `deletes`, or fails with the armed error
    /// without recording anything.
    async fn delete_item(
        &self,
        table: &str,
        key: Item,
        cond: Option<String>,
        _n: Option<HashMap<String, String>>,
        _v: Option<HashMap<String, AttributeValue>>,
        _old: bool,
    ) -> graphddb_runtime::Result<WriteOutput> {
        self.take_injected_failure()?;
        let entry = (table.to_string(), key, cond);
        self.rec.lock().unwrap().deletes.push(entry);
        Ok(WriteOutput::default())
    }

    /// Stub: ignores its arguments and always returns two empty vectors.
    async fn batch_get_item(
        &self,
        _t: &str,
        _k: Vec<Item>,
        _pe: Option<String>,
        _n: Option<HashMap<String, String>>,
    ) -> graphddb_runtime::Result<(Vec<Item>, Vec<Item>)> {
        Ok((Vec::new(), Vec::new()))
    }

    /// Stub: ignores its arguments and always returns an empty vector.
    async fn batch_write_item(
        &self,
        _t: &str,
        _r: Vec<WriteRequest>,
    ) -> graphddb_runtime::Result<Vec<WriteRequest>> {
        Ok(Vec::new())
    }

    /// Records the whole item list as one entry in `transacts`, or fails with the
    /// armed error without recording anything.
    async fn transact_write_items(
        &self,
        items: Vec<TransactWriteItem>,
    ) -> graphddb_runtime::Result<()> {
        self.take_injected_failure()?;
        self.rec.lock().unwrap().transacts.push(items);
        Ok(())
    }
}
/// Builds the entity manifest passed to the runtime: a single `UserModel` entity
/// stored in table `T` with `userId`, `name` and `status` fields.
fn manifest() -> Json {
    let user_model = json!({
        "table": "T",
        "key": { "pkTemplate": "USER#{userId}", "skTemplate": "PROFILE" },
        "fields": { "userId": {}, "name": {}, "status": {} }
    });

    json!({
        "version": "1.1",
        "entities": { "UserModel": user_model }
    })
}
/// Builds the operations document passed to the runtime.
///
/// It declares two queries (`getUser`, `listUsers`), two commands (`putUser`,
/// `deleteUser`) and one transaction (`putManyUsers`), all against table `T` and
/// the `UserModel` entity.
fn operations() -> Json {
    // Single-item lookup by full key.
    let get_user = json!({
        "cardinality": "one",
        "params": { "userId": { "type": "string", "required": true } },
        "operations": [ {
            "type": "GetItem",
            "tableName": "T",
            "keyCondition": { "PK": "USER#{userId}", "SK": "PROFILE" },
            "projection": ["userId", "name", "status"],
            "resultPath": "$",
            "entity": "UserModel"
        } ]
    });

    // Multi-item lookup by partition key only.
    let list_users = json!({
        "cardinality": "many",
        "params": { "userId": { "type": "string", "required": true } },
        "operations": [ {
            "type": "Query",
            "tableName": "T",
            "keyCondition": { "PK": "USER#{userId}" },
            "projection": ["userId", "name"],
            "resultPath": "$",
            "entity": "UserModel"
        } ]
    });

    let put_user = json!({
        "type": "PutItem",
        "tableName": "T",
        "entity": "UserModel",
        "item": { "userId": "{userId}", "name": "{name}", "status": "{status}" },
        "params": {
            "userId": {"type":"string","required":true},
            "name": {"type":"string","required":true},
            "status": {"type":"string","required":true}
        }
    });

    let delete_user = json!({
        "type": "DeleteItem",
        "tableName": "T",
        "entity": "UserModel",
        "keyCondition": { "PK": "USER#{userId}", "SK": "PROFILE" },
        "params": { "userId": {"type":"string","required":true} }
    });

    // One `Put` per element of the `users` array parameter.
    let put_many_users = json!({
        "params": { "users": { "type": "array", "element": {
            "userId": {"type":"string","required":true},
            "name": {"type":"string","required":true}
        } } },
        "items": [ {
            "type": "Put",
            "tableName": "T",
            "entity": "UserModel",
            "forEach": { "source": "users" },
            "item": { "userId": "{item.userId}", "name": "{item.name}", "status": "active" }
        } ]
    });

    json!({
        "version": "1.1",
        "queries": {
            "getUser": get_user,
            "listUsers": list_users
        },
        "commands": {
            "putUser": put_user,
            "deleteUser": delete_user
        },
        "transactions": {
            "putManyUsers": put_many_users
        },
        "contracts": {}
    })
}
/// Creates a runtime backed by `fake`, using this file's `manifest()` and
/// `operations()` documents. Panics if the runtime rejects them.
fn runtime(fake: FakeClient) -> GraphDDBRuntime {
    let client = Arc::new(fake);
    let built = GraphDDBRuntime::new(client, manifest(), operations(), None, None);
    built.unwrap()
}
fn params(v: Json) -> serde_json::Map<String, Json> {
v.as_object().unwrap().clone()
}
#[tokio::test]
async fn r1_read_before_fires_and_mutates_params() {
let fake = FakeClient::default();
*fake.get_item_response.lock().unwrap() = Some(HashMap::from([
("userId".into(), s("bob")),
("name".into(), s("Bob")),
]));
let rec = fake.rec.clone();
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(Middleware::new().read_before(move |ctx| {
*f.lock().unwrap() = true;
let p = ctx
.params
.get_mut("params")
.unwrap()
.as_object_mut()
.unwrap();
p.insert("userId".into(), json!("bob"));
Ok(())
}));
rt.execute_query("getUser", ¶ms(json!({"userId": "alice"})))
.await
.unwrap();
assert!(*fired.lock().unwrap(), "R1 did not fire");
let key = &rec.lock().unwrap().get_items[0].key;
assert_eq!(
key.get("PK"),
Some(&s("USER#bob")),
"R1's param mutation was not applied to the physical key"
);
}
#[tokio::test]
async fn r2_read_op_before_sees_operation() {
let fake = FakeClient::default();
let seen = Arc::new(Mutex::new(None));
let sk = seen.clone();
let mut rt = runtime(fake);
rt.use_middleware(Middleware::new().read_op_before(move |ctx| {
*sk.lock().unwrap() = Some((ctx.op_type.clone(), ctx.key.get("PK").cloned()));
Ok(())
}));
rt.execute_query("getUser", ¶ms(json!({"userId": "alice"})))
.await
.unwrap();
let (op_type, pk) = seen.lock().unwrap().clone().expect("R2 did not fire");
assert_eq!(op_type, "GetItem");
assert_eq!(pk, Some(s("USER#alice")));
}
#[tokio::test]
async fn r3_read_op_after_transforms_items() {
let fake = FakeClient::default();
*fake.query_response.lock().unwrap() = vec![HashMap::from([
("userId".into(), s("a")),
("name".into(), s("A")),
])];
let mut rt = runtime(fake);
rt.use_middleware(
Middleware::new().read_op_after(|_ctx, _items| vec![json!({"userId": "z", "name": "Zed"})]),
);
let result = rt
.execute_query("listUsers", ¶ms(json!({"userId": "a"})))
.await
.unwrap();
let items = result.get("items").unwrap().as_array().unwrap();
assert_eq!(items.len(), 1);
assert_eq!(
items[0].get("name").unwrap(),
"Zed",
"R3 transform not applied"
);
}
#[tokio::test]
async fn r4_read_after_replaces_result_lifo_onion() {
let fake = FakeClient::default();
*fake.get_item_response.lock().unwrap() = Some(HashMap::from([("name".into(), s("Bob"))]));
let mut rt = runtime(fake);
rt.use_middleware(Middleware::new().read_after(|_ctx, r| json!({"first": r})));
rt.use_middleware(Middleware::new().read_after(|_ctx, r| json!({"second": r})));
let result = rt
.execute_query("getUser", ¶ms(json!({"userId": "bob"})))
.await
.unwrap();
assert!(
result.get("first").is_some(),
"outer R4 (first-registered) did not wrap last"
);
assert!(
result["first"].get("second").is_some(),
"R4 onion order wrong (expected LIFO)"
);
}
#[tokio::test]
async fn r5_read_on_error_recovers() {
let fake = FakeClient::default();
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(
Middleware::new().read_before(|_ctx| Err(GraphDDBError::new("boom from R1"))),
);
rt.use_middleware(Middleware::new().read_on_error(move |_ctx, err| {
*f.lock().unwrap() = true;
assert!(
err.message.contains("boom"),
"R5 saw the wrong error: {}",
err.message
);
Recovery::Recover(json!({"recovered": true}))
}));
let result = rt
.execute_query("getUser", ¶ms(json!({"userId": "x"})))
.await
.unwrap();
assert!(*fired.lock().unwrap(), "R5 did not fire");
assert_eq!(
result,
json!({"recovered": true}),
"R5 recovery value not returned"
);
}
#[tokio::test]
async fn w1_write_before_injects_field() {
let fake = FakeClient::default();
let rec = fake.rec.clone();
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(Middleware::new().write_before(move |ctx| {
*f.lock().unwrap() = true;
assert_eq!(ctx.kind, "put");
ctx.input
.get_mut("item")
.unwrap()
.as_object_mut()
.unwrap()
.insert("status".into(), json!("banned"));
Ok(())
}));
rt.execute_command(
"putUser",
¶ms(json!({"userId": "u1", "name": "N", "status": "active"})),
)
.await
.unwrap();
assert!(*fired.lock().unwrap(), "W1 did not fire");
let (_t, item, _c) = rec.lock().unwrap().puts[0].clone();
assert_eq!(
item.get("status"),
Some(&s("banned")),
"W1 field injection not applied to the persisted item"
);
}
#[tokio::test]
async fn w1_kind_rewrite_delete_to_update() {
let fake = FakeClient::default();
let rec = fake.rec.clone();
let mut rt = runtime(fake);
rt.use_middleware(Middleware::new().write_before(|ctx| {
assert_eq!(ctx.kind, "delete");
ctx.kind = "update".into();
ctx.input
.insert("changes".into(), json!({"status": "deleted"}));
Ok(())
}));
rt.execute_command("deleteUser", ¶ms(json!({"userId": "u1"})))
.await
.unwrap();
let r = rec.lock().unwrap();
assert!(
r.deletes.is_empty(),
"delete should have been rewritten to an update"
);
assert_eq!(r.updates.len(), 1, "soft-delete update was not issued");
}
#[tokio::test]
async fn w2_write_after_observes_change() {
let fake = FakeClient::default();
let mut rt = runtime(fake);
let seen = Arc::new(Mutex::new(None));
let sk = seen.clone();
rt.use_middleware(Middleware::new().write_after(move |ctx, change| {
*sk.lock().unwrap() = Some((ctx.kind.clone(), change.clone()));
}));
rt.execute_command(
"putUser",
¶ms(json!({"userId": "u1", "name": "N", "status": "active"})),
)
.await
.unwrap();
let (kind, change) = seen.lock().unwrap().clone().expect("W2 did not fire");
assert_eq!(kind, "put");
assert!(change.is_object(), "W2 change should be an object");
}
#[tokio::test]
async fn w3_persist_before_mutates_items() {
let fake = FakeClient::default();
let rec = fake.rec.clone();
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(Middleware::new().persist_before(move |ctx| {
*f.lock().unwrap() = true;
assert_eq!(
ctx.items.len(),
1,
"single-op persist batch should have one item"
);
assert_eq!(ctx.items[0].op_kind, "Put");
Ok(())
}));
rt.execute_command(
"putUser",
¶ms(json!({"userId": "u1", "name": "N", "status": "active"})),
)
.await
.unwrap();
assert!(*fired.lock().unwrap(), "W3 did not fire");
assert_eq!(rec.lock().unwrap().puts.len(), 1);
}
#[tokio::test]
async fn w4_persist_after_fires() {
let fake = FakeClient::default();
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(Middleware::new().persist_after(move |ctx, _results| {
*f.lock().unwrap() = true;
assert_eq!(ctx.items.len(), 1);
}));
rt.execute_command(
"putUser",
¶ms(json!({"userId": "u1", "name": "N", "status": "active"})),
)
.await
.unwrap();
assert!(*fired.lock().unwrap(), "W4 did not fire");
}
#[tokio::test]
async fn w5_write_on_error_recovers() {
let fake = FakeClient::default();
*fake.fail_next_write.lock().unwrap() = Some("ConditionalCheckFailed".into());
let mut rt = runtime(fake);
let fired = Arc::new(Mutex::new(false));
let f = fired.clone();
rt.use_middleware(Middleware::new().write_on_error(move |_ctx, err| {
*f.lock().unwrap() = true;
assert!(err.message.contains("ConditionalCheckFailed"));
Recovery::Recover(Json::Null) }));
rt.execute_command(
"putUser",
¶ms(json!({"userId": "u1", "name": "N", "status": "active"})),
)
.await
.unwrap();
assert!(*fired.lock().unwrap(), "W5 did not fire");
}
#[tokio::test]
async fn ordering_before_fifo_after_lifo() {
let fake = FakeClient::default();
*fake.get_item_response.lock().unwrap() = Some(HashMap::from([("name".into(), s("B"))]));
let mut rt = runtime(fake);
let order = Arc::new(Mutex::new(Vec::<&'static str>::new()));
let o1b = order.clone();
let o1a = order.clone();
rt.use_middleware(
Middleware::new()
.read_before(move |_c| {
o1b.lock().unwrap().push("before1");
Ok(())
})
.read_after(move |_c, r| {
o1a.lock().unwrap().push("after1");
r
}),
);
let o2b = order.clone();
let o2a = order.clone();
rt.use_middleware(
Middleware::new()
.read_before(move |_c| {
o2b.lock().unwrap().push("before2");
Ok(())
})
.read_after(move |_c, r| {
o2a.lock().unwrap().push("after2");
r
}),
);
rt.execute_query("getUser", ¶ms(json!({"userId": "b"})))
.await
.unwrap();
let seq = order.lock().unwrap().clone();
assert_eq!(
seq,
vec!["before1", "before2", "after2", "after1"],
"hook ordering wrong"
);
}
#[tokio::test]
async fn transaction_w1_fires_per_logical_op() {
let fake = FakeClient::default();
let rec = fake.rec.clone();
let mut rt = runtime(fake);
let kinds = Arc::new(Mutex::new(Vec::<String>::new()));
let txns = Arc::new(Mutex::new(Vec::<Option<u64>>::new()));
let k = kinds.clone();
let t = txns.clone();
rt.use_middleware(Middleware::new().write_before(move |ctx| {
k.lock().unwrap().push(ctx.kind.clone());
t.lock().unwrap().push(ctx.transaction);
Ok(())
}));
rt.execute_transaction(
"putManyUsers",
¶ms(json!({"users": [
{"userId": "a", "name": "A"},
{"userId": "b", "name": "B"},
{"userId": "c", "name": "C"}
]})),
)
.await
.unwrap();
assert_eq!(
kinds.lock().unwrap().clone(),
vec!["put", "put", "put"],
"W1 should fire once per logical op"
);
let txns = txns.lock().unwrap().clone();
assert_eq!(txns.len(), 3);
assert!(
txns.iter().all(|t| t.is_some() && *t == txns[0]),
"all logical ops must share one transaction id"
);
assert_eq!(
rec.lock().unwrap().transacts.len(),
1,
"one atomic batch sent"
);
assert_eq!(
rec.lock().unwrap().transacts[0].len(),
3,
"batch has 3 items"
);
}
#[tokio::test]
async fn transaction_w1_throw_aborts_whole_batch() {
let fake = FakeClient::default();
let rec = fake.rec.clone();
let mut rt = runtime(fake);
rt.use_middleware(Middleware::new().write_before(|ctx| {
if ctx
.input
.get("item")
.and_then(|i| i.get("userId"))
.and_then(Json::as_str)
== Some("b")
{
return Err(GraphDDBError::new("veto b"));
}
Ok(())
}));
let r = rt
.execute_transaction(
"putManyUsers",
¶ms(json!({"users": [{"userId":"a","name":"A"},{"userId":"b","name":"B"}]})),
)
.await;
assert!(r.is_err(), "a W1 throw must abort the transaction");
assert!(
rec.lock().unwrap().transacts.is_empty(),
"nothing should have been sent after a W1 abort"
);
}
#[tokio::test]
async fn transaction_w3_fires_once_for_batch() {
let fake = FakeClient::default();
let mut rt = runtime(fake);
let calls = Arc::new(Mutex::new(0usize));
let item_counts = Arc::new(Mutex::new(Vec::<usize>::new()));
let c = calls.clone();
let ic = item_counts.clone();
rt.use_middleware(Middleware::new().persist_before(move |ctx| {
*c.lock().unwrap() += 1;
ic.lock().unwrap().push(ctx.items.len());
Ok(())
}));
rt.execute_transaction(
"putManyUsers",
¶ms(json!({"users": [{"userId":"a","name":"A"},{"userId":"b","name":"B"}]})),
)
.await
.unwrap();
assert_eq!(
*calls.lock().unwrap(),
1,
"W3 must fire exactly once per atomic batch"
);
assert_eq!(
item_counts.lock().unwrap().clone(),
vec![2],
"W3 saw the full 2-item batch"
);
}
#[tokio::test]
async fn getitem_sends_projection_expression_fast_path() {
let fake = FakeClient::default();
*fake.get_item_response.lock().unwrap() = Some(HashMap::from([
("userId".into(), s("alice")),
("name".into(), s("Alice")),
("status".into(), s("active")),
]));
let rec = fake.rec.clone();
let rt = runtime(fake);
rt.execute_query("getUser", ¶ms(json!({"userId": "alice"})))
.await
.unwrap();
let got = &rec.lock().unwrap().get_items[0];
let pe = got
.projection_expression
.clone()
.expect("GetItem must send a ProjectionExpression (#237)");
let names = got.expression_attribute_names.clone().unwrap_or_default();
let projected: std::collections::HashSet<String> = pe
.split(',')
.map(|t| names.get(t.trim()).cloned().unwrap_or_default())
.collect();
for f in ["userId", "name", "status", "PK", "SK"] {
assert!(
projected.contains(f),
"projection missing {f}: {projected:?}"
);
}
}
#[tokio::test]
async fn query_sends_projection_expression_fast_path() {
let fake = FakeClient::default();
*fake.query_response.lock().unwrap() = vec![HashMap::from([
("userId".into(), s("u1")),
("name".into(), s("U1")),
])];
let rec = fake.rec.clone();
let rt = runtime(fake);
rt.execute_query("listUsers", ¶ms(json!({"userId": "u1"})))
.await
.unwrap();
let got = &rec.lock().unwrap().queries[0];
let pe = got
.projection_expression
.clone()
.expect("Query must send a ProjectionExpression (#237)");
let projected: std::collections::HashSet<String> = pe
.split(',')
.map(|t| {
got.expression_attribute_names
.get(t.trim())
.cloned()
.unwrap_or_default()
})
.collect();
for f in ["userId", "name", "PK"] {
assert!(
projected.contains(f),
"projection missing {f}: {projected:?}"
);
}
assert_eq!(
got.expression_attribute_names.get("#k0"),
Some(&"PK".to_string())
);
}
#[tokio::test]
async fn getitem_sends_projection_expression_slow_path() {
let fake = FakeClient::default();
*fake.get_item_response.lock().unwrap() = Some(HashMap::from([
("userId".into(), s("alice")),
("name".into(), s("Alice")),
]));
let rec = fake.rec.clone();
let mut rt = runtime(fake);
rt.use_middleware(Middleware::new().read_before(|_ctx| Ok(())));
rt.execute_query("getUser", ¶ms(json!({"userId": "alice"})))
.await
.unwrap();
let got = &rec.lock().unwrap().get_items[0];
assert!(
got.projection_expression.is_some(),
"GetItem must send a ProjectionExpression on the middleware path too (#237)"
);
}