use std::collections::HashMap;
use async_trait::async_trait;
use aws_sdk_dynamodb::types::{AttributeValue, KeysAndAttributes, TransactWriteItem, WriteRequest};
use crate::errors::{GraphDDBError, Result};
/// A DynamoDB item or key: attribute name mapped to its attribute value.
pub type Item = HashMap<String, AttributeValue>;
/// Parameters for [`DynamoClient::get_item`].
#[derive(Debug, Clone, Default)]
pub struct GetItemInput {
    /// Table to read from.
    pub table_name: String,
    /// Primary key of the item to fetch.
    pub key: Item,
    /// Request a strongly consistent read when `true`.
    pub consistent_read: bool,
    /// Optional projection restricting which attributes are returned.
    pub projection_expression: Option<String>,
    /// Placeholder substitutions (e.g. `#name` -> `name`) referenced by `projection_expression`.
    pub expression_attribute_names: Option<HashMap<String, String>>,
}
/// Result of [`DynamoClient::get_item`].
#[derive(Debug, Clone, Default)]
pub struct GetItemOutput {
    /// The fetched item, or `None` when the service returned no item for the key.
    pub item: Option<Item>,
}
/// Parameters for one page of [`DynamoClient::query`].
#[derive(Debug, Clone, Default)]
pub struct QueryInput {
    /// Table to query.
    pub table_name: String,
    /// Key condition selecting the partition (and optionally a sort-key range).
    pub key_condition_expression: String,
    /// Placeholder name substitutions (`#name` -> attribute name) used by the expressions.
    pub expression_attribute_names: HashMap<String, String>,
    /// Placeholder value substitutions (`:value` -> attribute value) used by the expressions.
    pub expression_attribute_values: HashMap<String, AttributeValue>,
    /// Secondary index to query instead of the base table, if any.
    pub index_name: Option<String>,
    /// Optional filter applied to the items matched by the key condition.
    pub filter_expression: Option<String>,
    /// Forwarded as the request's `Limit` parameter when set.
    pub limit: Option<i32>,
    /// Resume point, normally the `last_evaluated_key` of the previous page.
    pub exclusive_start_key: Option<Item>,
    /// Request a strongly consistent read when `true`.
    pub consistent_read: bool,
}
/// One page of results from [`DynamoClient::query`].
#[derive(Debug, Clone, Default)]
pub struct QueryOutput {
    /// Items returned in this page (empty when none matched).
    pub items: Vec<Item>,
    /// Present when the service reports more results; pass it back as
    /// [`QueryInput::exclusive_start_key`] to fetch the next page.
    pub last_evaluated_key: Option<Item>,
}
/// Result of a single-item write (`put_item`, `update_item`, `delete_item`).
#[derive(Debug, Clone, Default)]
pub struct WriteOutput {
    /// Attributes returned by the write. Only populated when the caller asked for the
    /// previous values (`return_all_old`) and a previous item existed.
    pub attributes: Option<Item>,
}
#[async_trait]
pub trait DynamoClient: Send + Sync {
async fn get_item(&self, input: GetItemInput) -> Result<GetItemOutput>;
async fn query(&self, input: QueryInput) -> Result<QueryOutput>;
async fn put_item(
&self,
table_name: &str,
item: Item,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput>;
#[allow(clippy::too_many_arguments)]
async fn update_item(
&self,
table_name: &str,
key: Item,
update_expression: Option<String>,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput>;
async fn delete_item(
&self,
table_name: &str,
key: Item,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput>;
async fn batch_get_item(
&self,
table_name: &str,
keys: Vec<Item>,
projection_expression: Option<String>,
names: Option<HashMap<String, String>>,
) -> Result<(Vec<Item>, Vec<Item>)>;
async fn batch_write_item(
&self,
table_name: &str,
requests: Vec<WriteRequest>,
) -> Result<Vec<WriteRequest>>;
async fn transact_write_items(&self, items: Vec<TransactWriteItem>) -> Result<()>;
}
/// [`DynamoClient`] implementation backed by the official AWS SDK client.
///
/// Cloning is cheap when the wrapped `aws_sdk_dynamodb::Client` is itself cheap to clone.
#[derive(Debug, Clone)]
pub struct AwsDynamoClient {
    client: aws_sdk_dynamodb::Client,
}

impl AwsDynamoClient {
    /// Wraps an already-configured SDK client.
    #[must_use]
    pub fn new(client: aws_sdk_dynamodb::Client) -> Self {
        Self { client }
    }
}
/// Builds the crate's operation-execution error for a failed DynamoDB call.
///
/// `op` names the DynamoDB API (e.g. `"GetItem"`); `err` is rendered with its
/// `Display` impl and embedded in the message after the operation name.
fn op_err(op: &str, err: impl std::fmt::Display) -> GraphDDBError {
    let message = format!("DynamoDB {op} failed: {err}");
    GraphDDBError::operation_execution(message)
}
#[async_trait]
impl DynamoClient for AwsDynamoClient {
async fn get_item(&self, input: GetItemInput) -> Result<GetItemOutput> {
let mut req = self
.client
.get_item()
.table_name(&input.table_name)
.set_key(Some(input.key))
.consistent_read(input.consistent_read);
if let Some(pe) = input.projection_expression {
req = req.projection_expression(pe);
}
if let Some(names) = input.expression_attribute_names {
req = req.set_expression_attribute_names(Some(names));
}
let resp = req.send().await.map_err(|e| op_err("GetItem", e))?;
Ok(GetItemOutput { item: resp.item })
}
async fn query(&self, input: QueryInput) -> Result<QueryOutput> {
let mut req = self
.client
.query()
.table_name(&input.table_name)
.key_condition_expression(&input.key_condition_expression)
.set_expression_attribute_names(Some(input.expression_attribute_names))
.set_expression_attribute_values(Some(input.expression_attribute_values))
.consistent_read(input.consistent_read);
if let Some(idx) = input.index_name {
req = req.index_name(idx);
}
if let Some(fe) = input.filter_expression {
req = req.filter_expression(fe);
}
if let Some(lim) = input.limit {
req = req.limit(lim);
}
if let Some(esk) = input.exclusive_start_key {
req = req.set_exclusive_start_key(Some(esk));
}
let resp = req.send().await.map_err(|e| op_err("Query", e))?;
Ok(QueryOutput {
items: resp.items.unwrap_or_default(),
last_evaluated_key: resp.last_evaluated_key,
})
}
async fn put_item(
&self,
table_name: &str,
item: Item,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput> {
let mut req = self
.client
.put_item()
.table_name(table_name)
.set_item(Some(item));
if let Some(ce) = condition_expression {
req = req.condition_expression(ce);
}
if let Some(n) = names {
req = req.set_expression_attribute_names(Some(n));
}
if let Some(v) = values {
req = req.set_expression_attribute_values(Some(v));
}
if return_all_old {
req = req.return_values(aws_sdk_dynamodb::types::ReturnValue::AllOld);
}
let resp = req.send().await.map_err(|e| op_err("PutItem", e))?;
Ok(WriteOutput {
attributes: resp.attributes,
})
}
async fn update_item(
&self,
table_name: &str,
key: Item,
update_expression: Option<String>,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput> {
let mut req = self
.client
.update_item()
.table_name(table_name)
.set_key(Some(key));
if let Some(ue) = update_expression {
req = req.update_expression(ue);
}
if let Some(ce) = condition_expression {
req = req.condition_expression(ce);
}
if let Some(n) = names {
req = req.set_expression_attribute_names(Some(n));
}
if let Some(v) = values {
req = req.set_expression_attribute_values(Some(v));
}
if return_all_old {
req = req.return_values(aws_sdk_dynamodb::types::ReturnValue::AllOld);
}
let resp = req.send().await.map_err(|e| op_err("UpdateItem", e))?;
Ok(WriteOutput {
attributes: resp.attributes,
})
}
async fn delete_item(
&self,
table_name: &str,
key: Item,
condition_expression: Option<String>,
names: Option<HashMap<String, String>>,
values: Option<HashMap<String, AttributeValue>>,
return_all_old: bool,
) -> Result<WriteOutput> {
let mut req = self
.client
.delete_item()
.table_name(table_name)
.set_key(Some(key));
if let Some(ce) = condition_expression {
req = req.condition_expression(ce);
}
if let Some(n) = names {
req = req.set_expression_attribute_names(Some(n));
}
if let Some(v) = values {
req = req.set_expression_attribute_values(Some(v));
}
if return_all_old {
req = req.return_values(aws_sdk_dynamodb::types::ReturnValue::AllOld);
}
let resp = req.send().await.map_err(|e| op_err("DeleteItem", e))?;
Ok(WriteOutput {
attributes: resp.attributes,
})
}
async fn batch_get_item(
&self,
table_name: &str,
keys: Vec<Item>,
projection_expression: Option<String>,
names: Option<HashMap<String, String>>,
) -> Result<(Vec<Item>, Vec<Item>)> {
let mut kaa = KeysAndAttributes::builder().set_keys(Some(keys));
if let Some(pe) = projection_expression {
kaa = kaa.projection_expression(pe);
}
if let Some(n) = names {
kaa = kaa.set_expression_attribute_names(Some(n));
}
let kaa = kaa.build().map_err(|e| op_err("BatchGetItem", e))?;
let resp = self
.client
.batch_get_item()
.request_items(table_name, kaa)
.send()
.await
.map_err(|e| op_err("BatchGetItem", e))?;
let responses = resp
.responses
.and_then(|mut m| m.remove(table_name))
.unwrap_or_default();
let unprocessed = resp
.unprocessed_keys
.and_then(|mut m| m.remove(table_name))
.map(|kaa| kaa.keys)
.unwrap_or_default();
Ok((responses, unprocessed))
}
async fn batch_write_item(
&self,
table_name: &str,
requests: Vec<WriteRequest>,
) -> Result<Vec<WriteRequest>> {
let resp = self
.client
.batch_write_item()
.request_items(table_name, requests)
.send()
.await
.map_err(|e| op_err("BatchWriteItem", e))?;
Ok(resp
.unprocessed_items
.and_then(|mut m| m.remove(table_name))
.unwrap_or_default())
}
async fn transact_write_items(&self, items: Vec<TransactWriteItem>) -> Result<()> {
self.client
.transact_write_items()
.set_transact_items(Some(items))
.send()
.await
.map_err(|e| op_err("TransactWriteItems", e))?;
Ok(())
}
}