use std::collections::HashMap;
use std::time::Duration;
use aws_sdk_dynamodb::types::WriteRequest;
use serde_json::Value as Json;
use crate::attribute::deserialize;
use crate::client::{DynamoClient, Item};
use crate::errors::{GraphDDBError, Result};
/// Maximum number of keys sent in one `BatchGetItem` call (mirrors the
/// 100-key per-request limit documented for DynamoDB); used as the default
/// chunk size in `BatchOptions`.
pub const BATCH_GET_MAX_KEYS: usize = 100;
/// Maximum number of write requests sent in one `BatchWriteItem` call (mirrors
/// the 25-request per-call limit documented for DynamoDB); `batch_write` splits
/// its input into chunks of this size.
pub const BATCH_WRITE_MAX_ITEMS: usize = 25;
/// Number of retries of unprocessed keys/items allowed per chunk, in addition
/// to the initial request, before the batch helpers give up with an error.
pub const BATCH_MAX_RETRY_ATTEMPTS: usize = 10;
/// Exponential backoff delay to wait before retry number `attempt`.
///
/// The delay is `50 ms * 2^(attempt - 1)`, capped at 1000 ms: attempts
/// 1, 2, 3, 4, 5 wait 50, 100, 200, 400, 800 ms, and every attempt from 6
/// onward waits the full 1000 ms. Attempt `0` yields 25 ms.
pub fn compute_backoff_delay(attempt: u32) -> Duration {
    // Convert the exponent without wrapping. A plain `attempt as i32` turns
    // values >= 2^31 negative (yielding a near-zero delay instead of the cap),
    // and `i32::MIN - 1` overflows and panics in debug builds.
    let exponent = i32::try_from(attempt).unwrap_or(i32::MAX).saturating_sub(1);
    // Huge exponents produce `inf`, which the cap clamps back to 1000 ms.
    let ms = (50.0f64 * 2f64.powi(exponent)).min(1000.0);
    Duration::from_millis(ms as u64)
}
/// Renders a key map as a canonical JSON string.
///
/// The entries are ordered by attribute name and emitted as an array of
/// `[name, value]` pairs, so two equal maps always produce the same string
/// regardless of `HashMap` iteration order. If serialization fails, an empty
/// string is returned.
pub fn serialize_key(key: &HashMap<String, Json>) -> String {
    let mut entries: Vec<(&String, &Json)> = key.iter().collect();
    // Map keys are unique, so an unstable sort yields the same order.
    entries.sort_unstable_by_key(|&(name, _)| name);

    let mut pairs: Vec<Json> = Vec::with_capacity(entries.len());
    for (name, value) in entries {
        let pair = vec![Json::String(name.to_owned()), value.to_owned()];
        pairs.push(Json::Array(pair));
    }

    let encoded = Json::Array(pairs);
    serde_json::to_string(&encoded).unwrap_or_default()
}
/// Splits `items` into owned chunks of at most `size` elements each.
///
/// A `size` of `0` disables splitting: the result is a single chunk holding
/// every item (a single empty chunk when `items` is empty).
fn chunk<T: Clone>(items: &[T], size: usize) -> Vec<Vec<T>> {
    match size {
        0 => vec![items.to_vec()],
        n => items.chunks(n).map(<[T]>::to_vec).collect(),
    }
}
/// Tuning knobs for `batch_get`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BatchOptions {
    /// Maximum number of keys sent in a single `batch_get_item` call; the key
    /// list is split into chunks of this size. `0` disables splitting, so all
    /// keys are sent in one call.
    pub max_batch_get_items: usize,
    /// When `true`, wait an exponentially growing delay before each retry of
    /// unprocessed keys; when `false`, retries are issued immediately.
    pub sleep: bool,
}
impl Default for BatchOptions {
fn default() -> Self {
Self {
max_batch_get_items: BATCH_GET_MAX_KEYS,
sleep: true,
}
}
}
/// Fetches every key in `serialized_keys` from `physical_table`.
///
/// The keys are split into chunks of `opts.max_batch_get_items`. The chunks
/// are fetched one after another, each with its own retries of unprocessed
/// keys. Items from earlier chunks precede items from later chunks in the
/// result. An empty key list returns an empty result without calling the client.
///
/// # Errors
/// Returns the first error produced by any chunk: a client error, or a
/// failure after the retry budget is exhausted.
pub async fn batch_get(
    client: &dyn DynamoClient,
    physical_table: &str,
    serialized_keys: Vec<Item>,
    projection_expression: Option<String>,
    names: Option<HashMap<String, String>>,
    opts: BatchOptions,
) -> Result<Vec<Item>> {
    if serialized_keys.is_empty() {
        return Ok(Vec::new());
    }

    let key_groups = chunk(&serialized_keys, opts.max_batch_get_items);
    let mut collected: Vec<Item> = Vec::new();
    for keys in key_groups {
        let fetched = get_chunk(
            client,
            physical_table,
            keys,
            &projection_expression,
            &names,
            opts,
        )
        .await?;
        collected.extend(fetched);
    }
    Ok(collected)
}
/// Fetches one chunk of keys, retrying any keys reported as unprocessed.
///
/// The function makes one initial `batch_get_item` call. It then makes up to
/// `BATCH_MAX_RETRY_ATTEMPTS` retries, each resending only the keys returned
/// as unprocessed. When `opts.sleep` is set, it waits `compute_backoff_delay`
/// before each retry. Items returned by every call are accumulated in call order.
///
/// # Errors
/// Propagates client errors. Returns an operation-execution error if keys are
/// still unprocessed after the final retry.
async fn get_chunk(
    client: &dyn DynamoClient,
    physical_table: &str,
    keys: Vec<Item>,
    projection_expression: &Option<String>,
    names: &Option<HashMap<String, String>>,
    opts: BatchOptions,
) -> Result<Vec<Item>> {
    let mut pending = keys;
    let mut attempt: usize = 0;
    let mut out: Vec<Item> = Vec::new();
    while !pending.is_empty() {
        // `pending` is moved into the request rather than cloned. It is never
        // read again on this iteration: it is either replaced by the
        // unprocessed keys below, or the loop exits or returns.
        let (responses, unprocessed) = client
            .batch_get_item(
                physical_table,
                pending,
                projection_expression.clone(),
                names.clone(),
            )
            .await?;
        out.extend(responses);
        if unprocessed.is_empty() {
            break;
        }
        if attempt >= BATCH_MAX_RETRY_ATTEMPTS {
            return Err(GraphDDBError::operation_execution(format!(
                "BatchGet exceeded the maximum of {BATCH_MAX_RETRY_ATTEMPTS} retry attempts \
                with {} key(s) still unprocessed for table '{physical_table}' (likely \
                sustained throttling).",
                unprocessed.len()
            )));
        }
        pending = unprocessed;
        attempt += 1;
        if opts.sleep {
            // `attempt` never exceeds BATCH_MAX_RETRY_ATTEMPTS, so the cast is lossless.
            tokio::time::sleep(compute_backoff_delay(attempt as u32)).await;
        }
    }
    Ok(out)
}
/// Sends every write request in `requests` to `physical_table`.
///
/// The requests are sent in chunks of `BATCH_WRITE_MAX_ITEMS`, one chunk at a
/// time. Each chunk retries its unprocessed items, sleeping between retries
/// when `sleep` is `true`. An empty request list is a no-op.
///
/// # Errors
/// Stops at, and returns, the first error produced by any chunk.
pub async fn batch_write(
    client: &dyn DynamoClient,
    physical_table: &str,
    requests: Vec<WriteRequest>,
    sleep: bool,
) -> Result<()> {
    if !requests.is_empty() {
        let groups = chunk(&requests, BATCH_WRITE_MAX_ITEMS);
        for group in groups {
            write_chunk(client, physical_table, group, sleep).await?;
        }
    }
    Ok(())
}
/// Sends one chunk of write requests, retrying any items reported as
/// unprocessed.
///
/// The function makes one initial `batch_write_item` call. It then makes up to
/// `BATCH_MAX_RETRY_ATTEMPTS` retries, each resending only the unprocessed
/// requests. When `sleep` is set, it waits `compute_backoff_delay` before each
/// retry.
///
/// # Errors
/// Propagates client errors. Returns an operation-execution error if items are
/// still unprocessed after the final retry.
async fn write_chunk(
    client: &dyn DynamoClient,
    physical_table: &str,
    requests: Vec<WriteRequest>,
    sleep: bool,
) -> Result<()> {
    let mut pending = requests;
    let mut attempt: usize = 0;
    while !pending.is_empty() {
        // Move `pending` into the request instead of cloning it. It is not
        // used again on this iteration: it is either replaced by the
        // unprocessed requests below, or the loop exits or returns.
        let unprocessed = client.batch_write_item(physical_table, pending).await?;
        if unprocessed.is_empty() {
            break;
        }
        if attempt >= BATCH_MAX_RETRY_ATTEMPTS {
            return Err(GraphDDBError::operation_execution(format!(
                "BatchWrite exceeded the maximum of {BATCH_MAX_RETRY_ATTEMPTS} retry attempts \
                with {} item(s) still unprocessed for table '{physical_table}' (likely \
                sustained throttling).",
                unprocessed.len()
            )));
        }
        pending = unprocessed;
        attempt += 1;
        if sleep {
            // `attempt` never exceeds BATCH_MAX_RETRY_ATTEMPTS, so the cast is lossless.
            tokio::time::sleep(compute_backoff_delay(attempt as u32)).await;
        }
    }
    Ok(())
}
pub fn marker_for(item: &Item, key_attrs: &[String]) -> String {
let mut key: HashMap<String, Json> = HashMap::new();
for attr in key_attrs {
let v = item
.get(attr)
.map(deserialize)
.map(value_to_json_scalar)
.unwrap_or(Json::Null);
key.insert(attr.clone(), v);
}
serialize_key(&key)
}
/// Reduces a decoded attribute value to a JSON scalar for use in a key marker.
///
/// - `Null` becomes JSON `null`.
/// - `Bool` becomes a JSON boolean.
/// - Strings (`S`) and numbers (`N`) become JSON strings. A number's string
///   form is passed through unchanged rather than parsed.
/// - Any other variant falls back to its `Debug` rendering, wrapped in a JSON
///   string.
fn value_to_json_scalar(v: crate::value::Value) -> Json {
    use crate::value::Value;
    match v {
        Value::Null => Json::Null,
        Value::Bool(flag) => Json::Bool(flag),
        Value::S(text) | Value::N(text) => Json::String(text),
        other => Json::String(format!("{other:?}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The delay doubles from 50 ms per attempt and saturates at the 1 s cap.
    #[test]
    fn backoff_schedule() {
        let cases = [(1, 50), (2, 100), (5, 800), (6, 1000), (10, 1000)];
        for (attempt, expected_ms) in cases {
            assert_eq!(
                compute_backoff_delay(attempt),
                Duration::from_millis(expected_ms),
                "attempt {attempt}"
            );
        }
    }

    /// 250 items in chunks of 100 give two full chunks and a 50-item tail.
    #[test]
    fn chunking() {
        let values: Vec<i32> = (0..250).collect();
        let sizes: Vec<usize> = chunk(&values, 100).iter().map(Vec::len).collect();
        assert_eq!(sizes, [100, 100, 50]);
    }

    /// Serialization orders entries by attribute name, not insertion order.
    #[test]
    fn serialize_key_sorted() {
        let key: HashMap<String, Json> = HashMap::from([
            ("SK".to_string(), Json::String("s".to_string())),
            ("PK".to_string(), Json::String("p".to_string())),
        ]);
        assert_eq!(serialize_key(&key), r#"[["PK","p"],["SK","s"]]"#);
    }
}