// athena_rs 3.22.1
//
// Hyper-performant polyglot database driver.
// Documentation
use once_cell::sync::Lazy;
use serde_json::{Map, Value};
use std::collections::HashMap;
use std::sync::Mutex as StdMutex;
use std::sync::MutexGuard as StdMutexGaurd;
use std::time::{Duration as StdDuration, Instant as StdInstant};

use super::config::recent_unique_conflict_cache_enabled;

/// One remembered unique-constraint violation.
#[derive(Clone)]
struct RecentUniqueViolationEntry {
    /// Moment the violation was stored; compared against the TTL on lookup.
    recorded_at: StdInstant,
    /// Name of the violated constraint, when the database reported one.
    constraint: Option<String>,
}

/// Age after which a remembered violation is treated as stale (three minutes).
const RECENT_UNIQUE_VIOLATION_TTL: StdDuration = StdDuration::from_secs(3 * 60);
/// Capacity of the violation cache; once reached, storing a new signature
/// first makes room by evicting older entries.
const RECENT_UNIQUE_VIOLATION_MAX_ENTRIES: usize = 20_000;

/// Process-wide map from insert signature to the violation it last produced.
/// The map is built on first access and every read/write goes through the mutex.
static RECENT_UNIQUE_VIOLATION_CACHE: Lazy<StdMutex<HashMap<String, RecentUniqueViolationEntry>>> =
    Lazy::new(|| StdMutex::new(HashMap::default()));

/// Builds the user-facing error text for an insert rejected by a unique
/// constraint, quoting the constraint name when it is known.
pub(crate) fn unique_violation_message(constraint: Option<&str>) -> String {
    if let Some(constraint_name) = constraint {
        return format!(
            "Insert would create a duplicate record (unique constraint '{}'). Use /gateway/update to modify the existing record or change the unique field values.",
            constraint_name
        );
    }
    String::from(
        "Insert would create a duplicate record. Use /gateway/update to modify the existing record or change the unique field values.",
    )
}

fn normalized_json_for_signature(value: &Value) -> Value {
    match value {
        Value::Object(map) => {
            let mut entries: Vec<(&String, &Value)> = map.iter().collect();
            entries.sort_by(|(left, _), (right, _)| left.cmp(right));
            let mut normalized = Map::new();
            for (key, child) in entries {
                normalized.insert(key.clone(), normalized_json_for_signature(child));
            }
            Value::Object(normalized)
        }
        Value::Array(items) => Value::Array(
            items
                .iter()
                .map(normalized_json_for_signature)
                .collect::<Vec<_>>(),
        ),
        _ => value.clone(),
    }
}

/// Builds the cache key for an insert attempt.
///
/// The key is the trimmed client name, the trimmed table name, and the
/// key-order-normalized JSON body, separated by U+001F (unit separator).
/// Returns `None` when `table_name` is empty or whitespace-only.
pub(crate) fn build_insert_duplicate_signature(
    client_name: &str,
    table_name: &str,
    insert_body: &Value,
) -> Option<String> {
    let table: &str = table_name.trim();
    if table.is_empty() {
        return None;
    }
    let client: &str = client_name.trim();
    let canonical_body: Value = normalized_json_for_signature(insert_body);
    let signature = format!("{}\u{001f}{}\u{001f}{}", client, table, canonical_body);
    Some(signature)
}

/// Looks up a previously recorded unique violation for `signature`.
///
/// Returns `None` in these cases:
/// - the cache is disabled;
/// - the cache mutex is poisoned;
/// - no entry exists;
/// - the entry is older than `RECENT_UNIQUE_VIOLATION_TTL`. In this case the
///   stale entry is removed.
///
/// For a live entry it returns `Some(constraint)`, where `constraint` is the
/// recorded constraint name, if one was captured.
pub(crate) fn lookup_recent_unique_violation(signature: &str) -> Option<Option<String>> {
    if !recent_unique_conflict_cache_enabled() {
        return None;
    }
    let mut guard: StdMutexGaurd<'_, HashMap<String, RecentUniqueViolationEntry>> =
        match RECENT_UNIQUE_VIOLATION_CACHE.lock() {
            Ok(guard) => guard,
            Err(_) => return None,
        };
    let found = guard.get(signature)?;
    if found.recorded_at.elapsed() <= RECENT_UNIQUE_VIOLATION_TTL {
        return Some(found.constraint.clone());
    }
    // Expired: drop it so the stale signature stops occupying a slot.
    guard.remove(signature);
    None
}

pub(crate) fn store_recent_unique_violation(signature: String, constraint: Option<String>) {
    if !recent_unique_conflict_cache_enabled() {
        return;
    }
    let Ok(mut cache) = RECENT_UNIQUE_VIOLATION_CACHE.lock() else {
        return;
    };

    if cache.len() >= RECENT_UNIQUE_VIOLATION_MAX_ENTRIES {
        let oldest_key: Option<String> = cache
            .iter()
            .min_by_key(|(_, entry)| entry.recorded_at)
            .map(|(key, _)| key.clone());
        if let Some(oldest_key) = oldest_key {
            cache.remove(&oldest_key);
        }
    }

    cache.insert(
        signature,
        RecentUniqueViolationEntry {
            constraint,
            recorded_at: StdInstant::now(),
        },
    );
}

/// Remembers a failed insert so that a later lookup with the same signature
/// can find the violation.
///
/// Nothing is stored unless both conditions hold:
/// - `error_code` is exactly `"unique_violation"`;
/// - a signature can be built (i.e. `table_name` is not blank).
///
/// The constraint name is taken from `details["constraint"]` when that field
/// is a JSON string; otherwise no constraint is recorded.
pub(crate) fn remember_unique_violation_from_insert(
    client_name: &str,
    table_name: &str,
    insert_body: &Value,
    error_code: &str,
    details: &Value,
) {
    let is_unique_violation: bool = error_code == "unique_violation";
    if is_unique_violation {
        if let Some(signature) =
            build_insert_duplicate_signature(client_name, table_name, insert_body)
        {
            let constraint: Option<String> = details
                .get("constraint")
                .and_then(|raw| raw.as_str())
                .map(String::from);
            store_recent_unique_violation(signature, constraint);
        }
    }
}