use crate::table::{DynamoTable, GSITable};
use serde::Serialize;
use serde_dynamo::{AttributeValue, to_item};
use std::{
collections::{HashMap, HashSet},
fmt,
};
pub(crate) mod retry_config {
    use std::time::Duration;

    /// Computes the exponential-backoff delay for the given retry `attempt`.
    ///
    /// The delay doubles with every attempt (`initial * 2^attempt`) and is
    /// clamped to `max`. Overflow of the exponentiation or multiplication —
    /// possible once `attempt` grows — saturates to `max` instead of
    /// panicking in debug builds or wrapping to a tiny delay in release
    /// builds (the original `2u64.pow(attempt as u32)` had both failure
    /// modes).
    pub(crate) fn retry_delay(attempt: usize, initial: Duration, max: Duration) -> Duration {
        let max_ms = max.as_millis() as u64;
        // Any overflow along the way means the uncapped delay would exceed
        // `max`, so `None` collapses to the cap.
        let delay_ms = u32::try_from(attempt)
            .ok()
            .and_then(|a| 2u64.checked_pow(a))
            .and_then(|factor| (initial.as_millis() as u64).checked_mul(factor))
            .unwrap_or(max_ms);
        Duration::from_millis(delay_ms.min(max_ms))
    }
}
pub(crate) mod validation {
    use super::*;

    /// Asserts (via `crate::assert_not_reserved_key`) that `key` does not
    /// collide with a reserved attribute name.
    #[inline]
    fn validate_key(key: &str) {
        crate::assert_not_reserved_key(key);
    }

    /// Like `validate_key`, but treats a `None` key as trivially valid.
    #[inline]
    fn validate_optional_key(key: Option<&str>) {
        if let Some(k) = key {
            validate_key(k);
        }
    }

    /// Validates the partition-key and (optional) sort-key names declared by
    /// a `DynamoTable` implementation.
    pub(crate) fn validate_table_keys<T>()
    where
        T: DynamoTable,
        T::PK: fmt::Display + Clone + Send + Sync + fmt::Debug,
        T::SK: fmt::Display + Clone + Send + Sync + fmt::Debug,
    {
        validate_key(T::PARTITION_KEY);
        validate_optional_key(T::SORT_KEY);
    }

    /// Validates both the base-table key names and the GSI key names declared
    /// by a `GSITable` implementation.
    pub(crate) fn validate_gsi_keys<T>()
    where
        T: GSITable,
        T::PK: fmt::Display + Clone + Send + Sync + fmt::Debug,
        T::SK: fmt::Display + Clone + Send + Sync + fmt::Debug,
    {
        validate_table_keys::<T>();
        validate_key(T::GSI_PARTITION_KEY);
        validate_optional_key(T::GSI_SORT_KEY);
    }

    /// Debug-build sanity checks for update expressions that alias attribute
    /// names: each field must have a matching `#name` alias, aliases must be
    /// unique, and raw field names must never appear in the expression.
    ///
    /// The outer `cfg!(debug_assertions)` guard exists so the `HashSet`
    /// allocation (not just the `debug_assert*` checks, which compile out on
    /// their own) is skipped entirely in release builds.
    pub(crate) fn validate_aliased_field_names(field_names: &[&str], aliases: &[String]) {
        if cfg!(debug_assertions) {
            debug_assert_eq!(
                field_names.len(),
                aliases.len(),
                "Each update field must have an attribute-name alias"
            );
            let mut seen_aliases = HashSet::with_capacity(aliases.len());
            for (field, alias) in field_names.iter().zip(aliases.iter()) {
                debug_assert!(!field.is_empty(), "Field name must not be empty");
                debug_assert!(
                    alias.starts_with('#') && alias.len() > 1,
                    "Alias must use DynamoDB expression attribute name syntax: {alias}"
                );
                debug_assert_ne!(
                    *field,
                    alias.as_str(),
                    "Field name must not be used directly in the update expression: {field}"
                );
                // `insert` returns false on a duplicate, failing the assert.
                debug_assert!(
                    seen_aliases.insert(alias.as_str()),
                    "Duplicate attribute-name alias generated: {alias}"
                );
            }
        }
    }

    /// Debug-build check that every top-level key produced by serializing
    /// `filter_expression_values` passes `validate_key` (i.e. is not a
    /// reserved attribute name). No-op in release builds.
    pub(crate) fn validate_filter_expression_values<U: Serialize>(filter_expression_values: &U) {
        if cfg!(debug_assertions) {
            let filter_keys =
                to_item::<_, HashMap<String, AttributeValue>>(filter_expression_values)
                    .expect("valid serialization for validation");
            for key in filter_keys.keys() {
                validate_key(key);
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use super::validate_aliased_field_names;

        // Reserved words such as `status` are acceptable as *field* names as
        // long as the update expression refers to them through `#n` aliases.
        #[test]
        fn validate_field_names_allows_reserved_words_when_updates_alias_names() {
            validate_aliased_field_names(
                &["status", "status_reason", "last_updated_at"],
                &["#n0".to_string(), "#n1".to_string(), "#n2".to_string()],
            );
        }
    }
}
pub(crate) mod expressions {
    use aws_sdk_dynamodb::types::AttributeValue;
    use std::collections::HashMap;

    /// Incrementally assembles a DynamoDB key-condition expression together
    /// with its `:hash_value` / `:range_value` attribute-value bindings.
    pub(crate) struct KeyConditionBuilder {
        expression: String,
        values: HashMap<String, AttributeValue>,
    }

    impl KeyConditionBuilder {
        /// Starts with an empty expression and no bound values.
        pub(crate) fn new() -> Self {
            Self {
                expression: String::new(),
                values: HashMap::new(),
            }
        }

        /// Sets the partition-key clause (replacing any existing expression
        /// text) and binds `value` under the `:hash_value` placeholder.
        pub(crate) fn with_partition_key(mut self, field: &str, value: String) -> Self {
            self.expression = format!("{field} = :hash_value");
            let _ = self
                .values
                .insert(":hash_value".to_string(), AttributeValue::S(value));
            self
        }

        /// Appends a sort-key clause — joined with `and` when a clause is
        /// already present — and binds `value` under `:range_value`.
        pub(crate) fn with_sort_key(mut self, field: &str, value: String) -> Self {
            let clause = format!("{field} = :range_value");
            if self.expression.is_empty() {
                self.expression = clause;
            } else {
                self.expression.push_str(" and ");
                self.expression.push_str(&clause);
            }
            let _ = self
                .values
                .insert(":range_value".to_string(), AttributeValue::S(value));
            self
        }

        /// Consumes the builder, yielding the finished expression string and
        /// its expression-attribute-value map.
        pub(crate) fn build(self) -> (String, HashMap<String, AttributeValue>) {
            (self.expression, self.values)
        }
    }
}
pub(crate) mod query_builder {
    use super::{DynamoTable, GSITable, expressions};
    use aws_sdk_dynamodb::operation::query::builders::QueryFluentBuilder;
    use aws_sdk_dynamodb::types::{AttributeValue, ReturnConsumedCapacity, Select};
    use std::collections::HashMap;
    use std::fmt;

    /// Pre-resolved table/index metadata used to assemble DynamoDB `Query`
    /// requests against a base table, a GSI, or an explicitly named index.
    pub(crate) struct QueryBuilder<'a> {
        table_name: &'a str,
        // `Some` when the query targets an index rather than the base table.
        index_name: Option<String>,
        partition_key_field: &'a str,
        sort_key_field: Option<&'a str>,
        // Base-table partition key; only populated for GSI queries, where it
        // must be included in the `ExclusiveStartKey` alongside the index keys.
        table_pk_field: Option<&'a str>,
    }

    impl<'a> QueryBuilder<'a> {
        /// Builder for querying `T`'s base table directly.
        pub(crate) fn for_table<T>() -> Self
        where
            T: DynamoTable,
            T::PK: fmt::Display + Clone + Send + Sync + fmt::Debug,
            T::SK: fmt::Display + Clone + Send + Sync + fmt::Debug,
        {
            Self {
                table_name: T::TABLE,
                index_name: None,
                partition_key_field: T::PARTITION_KEY,
                sort_key_field: T::SORT_KEY,
                table_pk_field: None,
            }
        }

        /// Builder for querying `T`'s global secondary index; key fields come
        /// from the GSI, while the base-table PK is kept for pagination keys.
        pub(crate) fn for_gsi<T>() -> Self
        where
            T: GSITable,
            T::PK: fmt::Display + Clone + Send + Sync + fmt::Debug,
            T::SK: fmt::Display + Clone + Send + Sync + fmt::Debug,
        {
            Self {
                table_name: T::TABLE,
                index_name: Some(T::global_index_name()),
                partition_key_field: T::GSI_PARTITION_KEY,
                sort_key_field: T::GSI_SORT_KEY,
                table_pk_field: Some(T::PARTITION_KEY),
            }
        }

        /// Builder for an arbitrary named index that shares the base table's
        /// key schema (e.g. an LSI); unlike `for_gsi`, the table's own
        /// partition/sort key fields are used for the key condition.
        pub(crate) fn for_index<T>(index_name: String) -> Self
        where
            T: DynamoTable,
            T::PK: fmt::Display + Clone + Send + Sync + fmt::Debug,
            T::SK: fmt::Display + Clone + Send + Sync + fmt::Debug,
        {
            Self {
                table_name: T::TABLE,
                index_name: Some(index_name),
                partition_key_field: T::PARTITION_KEY,
                sort_key_field: T::SORT_KEY,
                table_pk_field: None,
            }
        }

        /// Assembles a `Query` request with key condition, optional
        /// pagination cursor, limit, and ordering.
        ///
        /// All key attributes are bound as `AttributeValue::S`, so this
        /// assumes string-typed keys throughout.
        ///
        /// Pagination: `exclusive_start_key` is the sort-key value of the last
        /// item seen; for GSI queries `exclusive_start_table_pk` supplies the
        /// base-table PK for the cursor, falling back to `start_key` when
        /// absent. NOTE(review): for a GSI whose base table also has a sort
        /// key, that key is not included in the `ExclusiveStartKey` — confirm
        /// the tables using this path have PK-only base schemas, or pagination
        /// will be rejected by DynamoDB.
        pub(crate) fn build_query(
            &self,
            client: &aws_sdk_dynamodb::Client,
            partition_key: String,
            sort_key: Option<String>,
            exclusive_start_key: Option<String>,
            exclusive_start_table_pk: Option<String>,
            limit: u16,
            scan_index_forward: bool,
        ) -> QueryFluentBuilder {
            // Index queries can only return projected attributes; base-table
            // queries return everything.
            let select = if self.index_name.is_some() {
                Select::AllProjectedAttributes
            } else {
                Select::AllAttributes
            };
            let mut builder = client
                .query()
                .table_name(self.table_name)
                .select(select)
                .return_consumed_capacity(if cfg!(feature = "consumed_capacity_stats") {
                    ReturnConsumedCapacity::Total
                } else {
                    ReturnConsumedCapacity::None
                })
                .scan_index_forward(scan_index_forward)
                // `limit` is `u16`, so the widening cast to `i32` is lossless.
                .limit(limit as i32);
            if let Some(ref index_name) = self.index_name {
                builder = builder.index_name(index_name);
            }
            if let Some(start_key) = exclusive_start_key {
                // The cursor always restarts within the same partition.
                builder = builder.exclusive_start_key(
                    self.partition_key_field,
                    AttributeValue::S(partition_key.clone()),
                );
                if let Some(table_pk_field) = self.table_pk_field {
                    // GSI cursors must also carry the base-table PK.
                    let table_pk_value =
                        exclusive_start_table_pk.unwrap_or_else(|| start_key.clone());
                    builder = builder
                        .exclusive_start_key(table_pk_field, AttributeValue::S(table_pk_value));
                }
                if let Some(sort_key_field) = self.sort_key_field {
                    builder =
                        builder.exclusive_start_key(sort_key_field, AttributeValue::S(start_key));
                }
            }
            let (condition_expr, condition_values) =
                self.build_key_condition(partition_key, sort_key);
            builder = builder.key_condition_expression(condition_expr);
            for (key, value) in condition_values {
                builder = builder.expression_attribute_values(key, value);
            }
            builder
        }

        /// Assembles a `Select::Count` query over a single partition (no sort
        /// key, no pagination) — returns item counts rather than items.
        pub(crate) fn build_count_query(
            &self,
            client: &aws_sdk_dynamodb::Client,
            partition_key: String,
        ) -> QueryFluentBuilder {
            let mut builder = client
                .query()
                .table_name(self.table_name)
                .select(Select::Count)
                .return_consumed_capacity(if cfg!(feature = "consumed_capacity_stats") {
                    ReturnConsumedCapacity::Total
                } else {
                    ReturnConsumedCapacity::None
                });
            if let Some(ref index_name) = self.index_name {
                builder = builder.index_name(index_name);
            }
            let condition_expr = format!("{} = :hash_value", self.partition_key_field);
            builder = builder
                .key_condition_expression(condition_expr)
                .expression_attribute_values(":hash_value", AttributeValue::S(partition_key));
            builder
        }

        /// Builds the key-condition expression for the partition key plus, if
        /// both the field and a value are present, the sort key.
        fn build_key_condition(
            &self,
            partition_key: String,
            sort_key: Option<String>,
        ) -> (String, HashMap<String, AttributeValue>) {
            let mut builder = expressions::KeyConditionBuilder::new()
                .with_partition_key(self.partition_key_field, partition_key);
            if let (Some(sort_key_field), Some(sort_value)) = (self.sort_key_field, sort_key) {
                builder = builder.with_sort_key(sort_key_field, sort_value);
            }
            builder.build()
        }
    }
}
pub(crate) mod batch_processor {
    use crate::Error;
    use futures_util::{StreamExt, TryStreamExt};
    use std::{cmp, future::Future};
    use tokio_stream::{self as stream};

    /// Splits a workload into fixed-size chunks and runs an async operation
    /// over them with bounded concurrency, merging per-chunk results into a
    /// single output value.
    #[allow(dead_code)]
    pub(crate) struct BatchProcessor {
        chunk_size: usize,
        concurrency: usize,
    }

    impl BatchProcessor {
        /// Creates a processor that issues chunks of `chunk_size` items with
        /// at most `concurrency` chunks in flight. Zero values are tolerated
        /// and treated as 1 by `process`.
        #[allow(dead_code)]
        pub(crate) fn new(chunk_size: usize, concurrency: usize) -> Self {
            Self {
                chunk_size,
                concurrency,
            }
        }

        /// Runs `operation` over `items` in chunks, folding each chunk's
        /// result into `output` via `merge_results`.
        ///
        /// Returns `Ok(output)` immediately for empty input; short-circuits
        /// on the first chunk or merge error. Chunk results are merged in
        /// *completion* order (`buffer_unordered`), so `merge_results` must
        /// not depend on input ordering.
        #[allow(dead_code)]
        pub(crate) async fn process<T, R, F, Fut, O, M>(
            &self,
            items: Vec<T>,
            operation: F,
            output: O,
            merge_results: M,
        ) -> Result<O, Error>
        where
            F: Fn(Vec<T>) -> Fut + Send + Sync,
            Fut: Future<Output = Result<R, Error>> + Send,
            T: Send + Clone + 'static,
            R: Send,
            O: Send,
            M: Fn(&mut O, R) -> Result<(), Error> + Send + Sync,
        {
            if items.is_empty() {
                return Ok(output);
            }
            // `slice::chunks` panics on a chunk size of 0; clamp to 1 so a
            // misconfigured processor degrades gracefully instead of panicking.
            let chunk_size = cmp::max(1, self.chunk_size);
            let batches: Vec<Vec<T>> = items
                .chunks(chunk_size)
                .map(|chunk| chunk.to_vec())
                .collect();
            // Never keep more futures in flight than there are batches, and
            // always allow at least one even if `concurrency` is 0.
            let concurrency = cmp::max(1, batches.len().min(self.concurrency));
            stream::iter(batches.into_iter().map(operation))
                .buffer_unordered(concurrency)
                .map_err(Into::<Error>::into)
                .try_fold(output, |mut acc, result| {
                    // Borrow so the closure can be re-entered per chunk.
                    let merge_results = &merge_results;
                    async move {
                        merge_results(&mut acc, result)?;
                        Ok(acc)
                    }
                })
                .await
        }
    }

    /// DynamoDB `BatchWriteItem` accepts at most 25 requests per call.
    pub(crate) const BATCH_WRITE_SIZE: usize = 25;
    /// DynamoDB `BatchGetItem` accepts at most 100 keys per call.
    pub(crate) const BATCH_READ_SIZE: usize = 100;
    /// Default number of concurrently in-flight batch requests.
    pub(crate) const DEFAULT_CONCURRENCY: usize = 10;
}