Skip to main content

alien_bindings/providers/kv/
aws_dynamodb.rs

1use crate::error::{map_cloud_client_error, ErrorData, Result};
2use crate::traits::{Binding, Kv, PutOptions, ScanResult};
3use alien_aws_clients::dynamodb::*;
4use alien_error::AlienError;
5use async_trait::async_trait;
6use base64::{prelude::BASE64_STANDARD, Engine};
7use chrono::Utc;
8use std::collections::HashMap;
9use std::fmt::{Debug, Formatter};
10
11use super::{validate_key, validate_value};
12
13/// AWS DynamoDB implementation of the KV trait.
14///
15/// Credential refresh is handled automatically by the underlying `AwsCredentialProvider`
16/// inside `DynamoDbClient`.
17pub struct AwsDynamodbKv {
18    client: DynamoDbClient,
19    table_name: String,
20}
21
22impl Debug for AwsDynamodbKv {
23    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("AwsDynamodbKv")
25            .field("table_name", &self.table_name)
26            .finish()
27    }
28}
29
30impl AwsDynamodbKv {
31    pub fn new(table_name: String, client: DynamoDbClient) -> Self {
32        Self { client, table_name }
33    }
34
35    /// Creates a hash bucket for load distribution
36    fn hash_bucket(&self, key: &str) -> String {
37        use std::collections::hash_map::DefaultHasher;
38        use std::hash::{Hash, Hasher};
39
40        let mut hasher = DefaultHasher::new();
41        key.hash(&mut hasher);
42        let bucket_id = hasher.finish() % 16; // 16 buckets for load distribution
43        format!("bucket_{}", bucket_id)
44    }
45
46    /// Checks if an item has expired based on TTL
47    fn is_expired(&self, ttl_epoch: Option<i64>) -> bool {
48        if let Some(ttl_timestamp) = ttl_epoch {
49            let now = Utc::now().timestamp();
50            now >= ttl_timestamp
51        } else {
52            false
53        }
54    }
55}
56
// Opts `AwsDynamodbKv` into the `Binding` trait. The impl body is intentionally empty:
// nothing from `Binding` is overridden here.
impl Binding for AwsDynamodbKv {}
58
59#[async_trait]
60impl Kv for AwsDynamodbKv {
61    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
62        validate_key(key)?;
63
64        let bucket = self.hash_bucket(key);
65        let mut primary_key = HashMap::new();
66        primary_key.insert("pk".to_string(), AttributeValue::s(bucket));
67        primary_key.insert("sk".to_string(), AttributeValue::s(key.to_string()));
68
69        let request = GetItemRequest::builder()
70            .table_name(self.table_name.clone())
71            .key(primary_key)
72            .build();
73
74        let response = self.client.get_item(request).await.map_err(|e| {
75            map_cloud_client_error(
76                e,
77                format!("Failed to get item with key '{}'", key),
78                Some(key.to_string()),
79            )
80        })?;
81
82        if let Some(item) = response.item {
83            // Check TTL expiry (logical expiry contract)
84            if let Some(ttl_attr) = item.get("ttl") {
85                if let Some(ttl_epoch) = ttl_attr.n.as_ref().and_then(|s| s.parse::<i64>().ok()) {
86                    if self.is_expired(Some(ttl_epoch)) {
87                        return Ok(None); // Logically expired
88                    }
89                }
90            }
91
92            let value = item
93                .get("value")
94                .and_then(|attr| attr.b.as_ref())
95                .and_then(|base64_value| BASE64_STANDARD.decode(base64_value).ok())
96                .ok_or_else(|| {
97                    AlienError::new(ErrorData::CloudPlatformError {
98                        message: format!("Missing or invalid value attribute for key '{}'", key),
99                        resource_id: Some(key.to_string()),
100                    })
101                })?;
102
103            Ok(Some(value))
104        } else {
105            Ok(None)
106        }
107    }
108
109    async fn put(&self, key: &str, value: Vec<u8>, options: Option<PutOptions>) -> Result<bool> {
110        validate_key(key)?;
111        validate_value(&value)?;
112
113        let bucket = self.hash_bucket(key);
114        let options = options.unwrap_or_default();
115
116        let mut item = HashMap::new();
117        item.insert("pk".to_string(), AttributeValue::s(bucket));
118        item.insert("sk".to_string(), AttributeValue::s(key.to_string()));
119        item.insert(
120            "value".to_string(),
121            AttributeValue::b(BASE64_STANDARD.encode(&value)),
122        );
123
124        if let Some(ttl) = options.ttl {
125            let expires_at = (Utc::now() + ttl).timestamp();
126            item.insert("ttl".to_string(), AttributeValue::n(expires_at.to_string()));
127        }
128
129        let request = if options.if_not_exists {
130            PutItemRequest::builder()
131                .table_name(self.table_name.clone())
132                .item(item)
133                .condition_expression(
134                    "attribute_not_exists(pk) AND attribute_not_exists(sk)".to_string(),
135                )
136                .build()
137        } else {
138            PutItemRequest::builder()
139                .table_name(self.table_name.clone())
140                .item(item)
141                .build()
142        };
143
144        match self.client.put_item(request).await {
145            Ok(_) => Ok(true),
146            Err(e) => {
147                // Check if this is a conditional check failure for if_not_exists
148                if options.if_not_exists {
149                    if let Some(alien_client_core::ErrorData::RemoteResourceConflict { .. }) =
150                        &e.error
151                    {
152                        return Ok(false);
153                    }
154                }
155                Err(map_cloud_client_error(
156                    e,
157                    format!("Failed to put item with key '{}'", key),
158                    Some(key.to_string()),
159                ))
160            }
161        }
162    }
163
164    async fn delete(&self, key: &str) -> Result<()> {
165        validate_key(key)?;
166
167        let bucket = self.hash_bucket(key);
168        let mut primary_key = HashMap::new();
169        primary_key.insert("pk".to_string(), AttributeValue::s(bucket));
170        primary_key.insert("sk".to_string(), AttributeValue::s(key.to_string()));
171
172        let request = DeleteItemRequest::builder()
173            .table_name(self.table_name.clone())
174            .key(primary_key)
175            .build();
176
177        self.client.delete_item(request).await.map_err(|e| {
178            map_cloud_client_error(
179                e,
180                format!("Failed to delete item with key '{}'", key),
181                Some(key.to_string()),
182            )
183        })?;
184
185        Ok(())
186    }
187
188    async fn exists(&self, key: &str) -> Result<bool> {
189        validate_key(key)?;
190
191        let bucket = self.hash_bucket(key);
192        let mut primary_key = HashMap::new();
193        primary_key.insert("pk".to_string(), AttributeValue::s(bucket));
194        primary_key.insert("sk".to_string(), AttributeValue::s(key.to_string()));
195
196        // Use expression attribute names to avoid reserved keyword 'ttl'
197        let mut expression_attribute_names = HashMap::new();
198        expression_attribute_names.insert("#ttl".to_string(), "ttl".to_string());
199
200        let request = GetItemRequest::builder()
201            .table_name(self.table_name.clone())
202            .key(primary_key)
203            .projection_expression("pk, #ttl".to_string()) // Get key and TTL for expiry check
204            .expression_attribute_names(expression_attribute_names)
205            .build();
206
207        let response = self.client.get_item(request).await.map_err(|e| {
208            map_cloud_client_error(
209                e,
210                format!("Failed to check existence of item with key '{}'", key),
211                Some(key.to_string()),
212            )
213        })?;
214
215        if let Some(item) = response.item {
216            // Check TTL expiry (logical expiry contract)
217            if let Some(ttl_attr) = item.get("ttl") {
218                if let Some(ttl_epoch) = ttl_attr.n.as_ref().and_then(|s| s.parse::<i64>().ok()) {
219                    if self.is_expired(Some(ttl_epoch)) {
220                        return Ok(false); // Logically expired
221                    }
222                }
223            }
224            Ok(true)
225        } else {
226            Ok(false)
227        }
228    }
229
230    async fn scan_prefix(
231        &self,
232        prefix: &str,
233        limit: Option<usize>,
234        _cursor: Option<String>,
235    ) -> Result<ScanResult> {
236        validate_key(prefix)?; // Prefix follows same key validation rules
237
238        // For prefix scans with hash-based bucketing, we must query ALL buckets
239        // since items with the same prefix can be distributed across different buckets
240        let mut all_items = Vec::new();
241        let mut total_fetched = 0;
242        let limit = limit.unwrap_or(1000);
243
244        // For simplicity, we'll query all 16 buckets sequentially
245        // In production, this could be parallelized for better performance
246        for bucket_id in 0..16 {
247            if total_fetched >= limit {
248                break;
249            }
250
251            let bucket = format!("bucket_{}", bucket_id);
252            let mut expression_attribute_values = HashMap::new();
253            expression_attribute_values.insert(":bucket".to_string(), AttributeValue::s(bucket));
254            expression_attribute_values
255                .insert(":prefix".to_string(), AttributeValue::s(prefix.to_string()));
256
257            // Build request for this bucket
258            let request = QueryRequest::builder()
259                .table_name(self.table_name.clone())
260                .key_condition_expression("pk = :bucket AND begins_with(sk, :prefix)".to_string())
261                .expression_attribute_values(expression_attribute_values)
262                .limit((limit - total_fetched) as i32)
263                .build();
264
265            let response = self.client.query(request).await.map_err(|e| {
266                map_cloud_client_error(
267                    e,
268                    format!("Failed to scan prefix '{}' in bucket {}", prefix, bucket_id),
269                    Some(prefix.to_string()),
270                )
271            })?;
272
273            // Process items from this bucket
274            for item in response.items {
275                if total_fetched >= limit {
276                    break;
277                }
278
279                // Check TTL expiry
280                if let Some(ttl_attr) = item.get("ttl") {
281                    if let Some(ttl_epoch) = ttl_attr.n.as_ref().and_then(|s| s.parse::<i64>().ok())
282                    {
283                        if self.is_expired(Some(ttl_epoch)) {
284                            continue; // Skip expired items
285                        }
286                    }
287                }
288
289                if let (Some(key_attr), Some(value_attr)) = (item.get("sk"), item.get("value")) {
290                    if let (Some(key), Some(base64_value)) =
291                        (key_attr.s.as_ref(), value_attr.b.as_ref())
292                    {
293                        if let Ok(value) = BASE64_STANDARD.decode(base64_value) {
294                            all_items.push((key.clone(), value));
295                            total_fetched += 1;
296                        }
297                    }
298                }
299            }
300        }
301
302        // For simplicity, we're not implementing cursor-based pagination across buckets
303        // In production, this would require more complex cursor state management
304        Ok(ScanResult {
305            items: all_items,
306            next_cursor: None,
307        })
308    }
309}