Skip to main content

alien_bindings/providers/kv/
azure_table_storage.rs

1use crate::error::{ErrorData, Result};
2use crate::traits::{Binding, Kv, PutOptions, ScanResult};
3use alien_azure_clients::tables::{
4    AzureTableStorageClient, EntityQueryOptions, TableEntity, TableStorageApi,
5};
6use alien_error::{AlienError, Context, IntoAlienError};
7use async_trait::async_trait;
8use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::HashMap;
13use std::fmt::{Debug, Formatter};
14
15use super::{validate_key, validate_value};
16
17/// Convert a KV operation to a Table Storage entity
18/// This only base64 encodes the raw bytes when creating the properties map, not in memory
19fn create_table_entity(
20    partition_key: String,
21    row_key: String,
22    value: &[u8],
23    expires_at: Option<DateTime<Utc>>,
24) -> TableEntity {
25    let mut properties = HashMap::new();
26
27    // Base64 encode the raw bytes only when storing in the properties map
28    // This keeps the original 32KB limit valid since we're not storing the encoded version in memory
29    properties.insert("Value".to_string(), Value::String(BASE64.encode(value)));
30
31    // Store creation timestamp
32    properties.insert(
33        "CreatedAt".to_string(),
34        Value::String(Utc::now().to_rfc3339()),
35    );
36
37    // Store expiration timestamp if provided
38    if let Some(expiry) = expires_at {
39        properties.insert("ExpiresAt".to_string(), Value::String(expiry.to_rfc3339()));
40    }
41
42    TableEntity {
43        partition_key,
44        row_key,
45        timestamp: None, // Azure will set this
46        properties,
47    }
48}
49
50/// Extract KV value from Table Storage entity
51fn extract_value_from_entity(entity: &TableEntity) -> Result<Vec<u8>> {
52    let value_str = entity
53        .properties
54        .get("Value")
55        .and_then(|v| v.as_str())
56        .ok_or_else(|| {
57            AlienError::new(ErrorData::InvalidInput {
58                operation_context: "Azure Table Storage KV extract value".to_string(),
59                details: "Entity missing Value property or not a string".to_string(),
60                field_name: Some("Value".to_string()),
61            })
62        })?;
63
64    // Decode base64 value
65    BASE64
66        .decode(value_str)
67        .into_alien_error()
68        .context(ErrorData::InvalidInput {
69            operation_context: "Azure Table Storage KV extract value".to_string(),
70            details: "Failed to decode base64 value".to_string(),
71            field_name: Some("Value".to_string()),
72        })
73}
74
75/// Check if entity has expired based on TTL
76fn is_entity_expired(entity: &TableEntity) -> bool {
77    if let Some(expires_at_value) = entity.properties.get("ExpiresAt") {
78        if let Some(expires_at_str) = expires_at_value.as_str() {
79            if let Ok(expires_at) = DateTime::parse_from_rfc3339(expires_at_str) {
80                return Utc::now() > expires_at.with_timezone(&Utc);
81            }
82        }
83    }
84    false
85}
86
/// Pagination position for a scan that walks the hash partitions in order.
///
/// The state is serialized to JSON and handed to callers as an opaque cursor string.
#[derive(Serialize, Deserialize)]
struct CursorState {
    /// Index of the partition the scan should continue from.
    current_partition: u32,
    /// Optional resume token inside `current_partition`; `None` when there is none.
    partition_continuation_token: Option<String>,
}
93
/// Azure Table Storage implementation of the KV trait
///
/// Keys are spread over `num_partitions` hash buckets; the bucket becomes the entity's
/// PartitionKey (`p<bucket>`) and the key itself is stored as the RowKey.
pub struct AzureTableStorageKv {
    /// Client used for every entity operation.
    client: AzureTableStorageClient,
    /// Resource group name passed to each client call.
    resource_group_name: String,
    /// Storage account name passed to each client call.
    account_name: String,
    /// Name of the table that holds the KV entities.
    table_name: String,
    /// Number of hash partitions keys are distributed across.
    num_partitions: u32,
}
102
103impl Debug for AzureTableStorageKv {
104    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("AzureTableStorageKv")
106            .field("resource_group_name", &self.resource_group_name)
107            .field("account_name", &self.account_name)
108            .field("table_name", &self.table_name)
109            .field("num_partitions", &self.num_partitions)
110            .finish()
111    }
112}
113
114impl AzureTableStorageKv {
115    pub fn new(
116        client: AzureTableStorageClient,
117        resource_group_name: String,
118        account_name: String,
119        table_name: String,
120    ) -> Self {
121        Self {
122            client,
123            resource_group_name,
124            account_name,
125            table_name,
126            num_partitions: 16, // 16 partitions for load distribution
127        }
128    }
129
130    /// Creates a hash bucket for load distribution
131    fn hash_bucket(&self, key: &str) -> u32 {
132        use std::collections::hash_map::DefaultHasher;
133        use std::hash::{Hash, Hasher};
134
135        let mut hasher = DefaultHasher::new();
136        key.hash(&mut hasher);
137        hasher.finish() as u32 % self.num_partitions
138    }
139
140    /// Splits key into partition key and row key
141    fn split_key(&self, key: &str) -> (String, String) {
142        // Use hash-based partitioning for load distribution
143        let partition_key = format!("p{}", self.hash_bucket(key));
144        (partition_key, key.to_string())
145    }
146
147    /// Combines partition key and row key back to original key
148    fn combine_key(&self, _partition_key: &str, row_key: &str) -> String {
149        row_key.to_string() // Row key contains the original key
150    }
151
152    /// Encodes cursor state as base64url JSON for safe HTTP transmission
153    fn encode_cursor(&self, state: &CursorState) -> String {
154        let json = serde_json::to_string(state).unwrap();
155        BASE64.encode(json.as_bytes())
156    }
157
158    /// Decodes cursor state from base64url JSON
159    fn decode_cursor(&self, cursor: &str) -> Result<CursorState> {
160        let decoded =
161            BASE64
162                .decode(cursor)
163                .into_alien_error()
164                .context(ErrorData::InvalidInput {
165                    operation_context: "Azure Table Storage KV cursor decoding".to_string(),
166                    details: "Invalid cursor encoding".to_string(),
167                    field_name: Some("cursor".to_string()),
168                })?;
169        let json =
170            String::from_utf8(decoded)
171                .into_alien_error()
172                .context(ErrorData::InvalidInput {
173                    operation_context: "Azure Table Storage KV cursor decoding".to_string(),
174                    details: "Invalid cursor UTF-8".to_string(),
175                    field_name: Some("cursor".to_string()),
176                })?;
177        serde_json::from_str(&json)
178            .into_alien_error()
179            .context(ErrorData::InvalidInput {
180                operation_context: "Azure Table Storage KV cursor decoding".to_string(),
181                details: "Invalid cursor JSON".to_string(),
182                field_name: Some("cursor".to_string()),
183            })
184    }
185}
186
// Marker implementation: no `Binding` items are overridden here.
impl Binding for AzureTableStorageKv {}
188
189#[async_trait]
190impl Kv for AzureTableStorageKv {
191    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
192        validate_key(key)?;
193
194        let (partition_key, row_key) = self.split_key(key);
195
196        match self
197            .client
198            .get_entity(
199                &self.resource_group_name,
200                &self.account_name,
201                &self.table_name,
202                &partition_key,
203                &row_key,
204                None,
205            )
206            .await
207        {
208            Ok(entity) => {
209                // Check if TTL has expired (client-side filtering)
210                if is_entity_expired(&entity) {
211                    return Ok(None); // Expired
212                }
213
214                let value = extract_value_from_entity(&entity)?;
215                Ok(Some(value))
216            }
217            Err(e) => {
218                use alien_client_core::ErrorData as CloudErrorData;
219                match e.error.as_ref() {
220                    Some(CloudErrorData::RemoteResourceNotFound { .. }) => Ok(None),
221                    _ => Err(crate::error::map_cloud_client_error(
222                        e,
223                        format!("Failed to get entity for key '{}'", key),
224                        Some(key.to_string()),
225                    )),
226                }
227            }
228        }
229    }
230
231    async fn put(&self, key: &str, value: Vec<u8>, options: Option<PutOptions>) -> Result<bool> {
232        validate_key(key)?;
233        validate_value(&value)?;
234
235        let options = options.unwrap_or_default();
236        let (partition_key, row_key) = self.split_key(key);
237
238        let expires_at = options.ttl.map(|d| Utc::now() + d);
239        let entity =
240            create_table_entity(partition_key.clone(), row_key.clone(), &value, expires_at);
241
242        if options.if_not_exists {
243            match self
244                .client
245                .insert_entity(
246                    &self.resource_group_name,
247                    &self.account_name,
248                    &self.table_name,
249                    &entity,
250                )
251                .await
252            {
253                Ok(_) => Ok(true),
254                Err(e) => {
255                    use alien_client_core::ErrorData as CloudErrorData;
256                    match e.error.as_ref() {
257                        Some(CloudErrorData::RemoteResourceConflict { .. }) => Ok(false),
258                        _ => Err(crate::error::map_cloud_client_error(
259                            e,
260                            format!("Failed to insert entity for key '{}'", key),
261                            Some(key.to_string()),
262                        )),
263                    }
264                }
265            }
266        } else {
267            // Insert Or Replace (upsert) - matches Azure REST API terminology
268            self.client
269                .insert_or_replace_entity(
270                    &self.resource_group_name,
271                    &self.account_name,
272                    &self.table_name,
273                    &partition_key,
274                    &row_key,
275                    &entity,
276                )
277                .await
278                .map_err(|e| {
279                    crate::error::map_cloud_client_error(
280                        e,
281                        format!("Failed to upsert entity for key '{}'", key),
282                        Some(key.to_string()),
283                    )
284                })?;
285            Ok(true)
286        }
287    }
288
289    async fn delete(&self, key: &str) -> Result<()> {
290        validate_key(key)?;
291
292        let (partition_key, row_key) = self.split_key(key);
293
294        // Delete entity, ignore if not found
295        match self
296            .client
297            .delete_entity(
298                &self.resource_group_name,
299                &self.account_name,
300                &self.table_name,
301                &partition_key,
302                &row_key,
303                None, // No specific ETag constraint
304            )
305            .await
306        {
307            Ok(_) => Ok(()),
308            Err(e) => {
309                use alien_client_core::ErrorData as CloudErrorData;
310                match e.error.as_ref() {
311                    Some(CloudErrorData::RemoteResourceNotFound { .. }) => Ok(()), // No error if key doesn't exist
312                    _ => Err(crate::error::map_cloud_client_error(
313                        e,
314                        format!("Failed to delete entity for key '{}'", key),
315                        Some(key.to_string()),
316                    )),
317                }
318            }
319        }
320    }
321
322    async fn exists(&self, key: &str) -> Result<bool> {
323        validate_key(key)?;
324
325        let (partition_key, row_key) = self.split_key(key);
326
327        match self
328            .client
329            .get_entity(
330                &self.resource_group_name,
331                &self.account_name,
332                &self.table_name,
333                &partition_key,
334                &row_key,
335                None,
336            )
337            .await
338        {
339            Ok(entity) => {
340                // Check TTL expiry
341                Ok(!is_entity_expired(&entity))
342            }
343            Err(e) => {
344                use alien_client_core::ErrorData as CloudErrorData;
345                match e.error.as_ref() {
346                    Some(CloudErrorData::RemoteResourceNotFound { .. }) => Ok(false),
347                    _ => Err(crate::error::map_cloud_client_error(
348                        e,
349                        format!("Failed to check existence of entity for key '{}'", key),
350                        Some(key.to_string()),
351                    )),
352                }
353            }
354        }
355    }
356
357    async fn scan_prefix(
358        &self,
359        prefix: &str,
360        limit: Option<usize>,
361        cursor: Option<String>,
362    ) -> Result<ScanResult> {
363        validate_key(prefix)?; // Prefix follows same key validation rules
364
365        // For prefix scans with hash-based partitioning, must fan-out across ALL partitions
366        // A RowKey-only filter forces expensive table-wide scans
367
368        // Decode cursor to get partition progress and continuation tokens
369        let cursor_state = cursor.as_ref().map(|c| self.decode_cursor(c)).transpose()?;
370
371        let mut all_items = Vec::new();
372        let mut total_fetched = 0;
373        let limit = limit.unwrap_or(1000);
374
375        // Start from the partition in cursor, or 0 if no cursor
376        let start_partition = cursor_state.as_ref().map_or(0, |cs| cs.current_partition);
377
378        for partition_id in start_partition..self.num_partitions {
379            let partition_key = format!("p{}", partition_id);
380
381            // Build filter with BOTH PartitionKey and RowKey conditions
382            // Use a range query approach that's compatible with Azure Table Storage
383            let prefix_end = format!("{}~", prefix); // Use tilde as it's after most printable chars
384            let filter = format!(
385                "(PartitionKey eq '{}') and (RowKey ge '{}') and (RowKey lt '{}')",
386                partition_key, prefix, prefix_end
387            );
388
389            // Note: We'll do TTL filtering client-side to avoid OData syntax issues
390            let filter_with_ttl = filter;
391
392            let query_options = EntityQueryOptions {
393                filter: Some(filter_with_ttl),
394                select: None,
395                top: Some((limit - total_fetched) as u32),
396            };
397
398            let response = self
399                .client
400                .query_entities(
401                    &self.resource_group_name,
402                    &self.account_name,
403                    &self.table_name,
404                    Some(query_options),
405                )
406                .await
407                .map_err(|e| {
408                    crate::error::map_cloud_client_error(
409                        e,
410                        format!("Failed to query entities with prefix '{}'", prefix),
411                        Some(prefix.to_string()),
412                    )
413                })?;
414
415            // Process entities from this partition
416            for entity in response.entities {
417                if total_fetched >= limit {
418                    break;
419                }
420
421                // Additional client-side TTL check for precision
422                if is_entity_expired(&entity) {
423                    continue; // Skip expired
424                }
425
426                let key = self.combine_key(&entity.partition_key, &entity.row_key);
427                let value = extract_value_from_entity(&entity)?;
428
429                all_items.push((key, value));
430                total_fetched += 1;
431            }
432
433            // If we hit the limit or have more data in this partition, encode cursor and return
434            if total_fetched >= limit || response.next_link.is_some() {
435                let next_cursor = self.encode_cursor(&CursorState {
436                    current_partition: partition_id,
437                    partition_continuation_token: response.next_link,
438                });
439                return Ok(ScanResult {
440                    items: all_items,
441                    next_cursor: Some(next_cursor),
442                });
443            }
444        }
445
446        // Scanned all partitions without hitting limit
447        Ok(ScanResult {
448            items: all_items,
449            next_cursor: None,
450        })
451    }
452}