Skip to main content

atomr_persistence_azure/
journal.rs

1//! Azure Table Storage `Journal`.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use atomr_persistence::{Journal, JournalError, PersistentRepr};
7
8use crate::config::AzureConfig;
9use crate::entities::EventEntity;
10use crate::rest::TableClient;
11
12pub struct AzureJournal {
13    client: TableClient,
14    cfg: AzureConfig,
15}
16
17impl AzureJournal {
18    pub async fn connect(cfg: AzureConfig) -> Result<Arc<Self>, JournalError> {
19        let client = TableClient::new(&cfg.endpoint, &cfg.account, &cfg.key)?;
20        if cfg.auto_create_tables {
21            client.create_table_if_absent(&cfg.journal_table).await?;
22        }
23        Ok(Arc::new(Self { client, cfg }))
24    }
25
26    pub fn config(&self) -> &AzureConfig {
27        &self.cfg
28    }
29
30    async fn current_max(&self, pid: &str) -> Result<u64, JournalError> {
31        let filter = format!("PartitionKey eq '{pid}'");
32        let entities: Vec<EventEntity> =
33            self.client.query_entities(&self.cfg.journal_table, &filter, None).await?;
34        Ok(entities.into_iter().map(|e| e.sequence_nr as u64).max().unwrap_or(0))
35    }
36}
37
/// Escapes `pid` for embedding inside a single-quoted OData string literal by
/// doubling every `'` character; all other characters are copied unchanged.
fn escape_pk(pid: &str) -> String {
    let mut escaped = String::with_capacity(pid.len());
    for ch in pid.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
41
42#[async_trait]
43impl Journal for AzureJournal {
44    async fn write_messages(&self, messages: Vec<PersistentRepr>) -> Result<(), JournalError> {
45        if messages.is_empty() {
46            return Ok(());
47        }
48        let mut by_pid: std::collections::BTreeMap<String, Vec<PersistentRepr>> =
49            std::collections::BTreeMap::new();
50        for m in messages {
51            by_pid.entry(m.persistence_id.clone()).or_default().push(m);
52        }
53        for (pid, batch) in by_pid {
54            let mut expected = self.current_max(&pid).await? + 1;
55            for msg in batch {
56                if msg.sequence_nr != expected {
57                    return Err(JournalError::SequenceOutOfOrder { expected, got: msg.sequence_nr });
58                }
59                expected += 1;
60                let entity = EventEntity::from_repr(&msg);
61                self.client.insert_entity(&self.cfg.journal_table, &entity).await?;
62            }
63        }
64        Ok(())
65    }
66
67    async fn delete_messages_to(
68        &self,
69        persistence_id: &str,
70        to_sequence_nr: u64,
71    ) -> Result<(), JournalError> {
72        let pk = escape_pk(persistence_id);
73        let filter = format!("PartitionKey eq '{pk}' and SequenceNr le {to}", to = to_sequence_nr as i64);
74        let entities: Vec<EventEntity> =
75            self.client.query_entities(&self.cfg.journal_table, &filter, None).await?;
76        for mut entity in entities {
77            entity.deleted = true;
78            self.client
79                .upsert_entity(
80                    &self.cfg.journal_table,
81                    &entity.partition_key.clone(),
82                    &entity.row_key.clone(),
83                    &entity,
84                )
85                .await?;
86        }
87        Ok(())
88    }
89
90    async fn replay_messages(
91        &self,
92        persistence_id: &str,
93        from: u64,
94        to: u64,
95        max: u64,
96    ) -> Result<Vec<PersistentRepr>, JournalError> {
97        let pk = escape_pk(persistence_id);
98        let to_bound = to.min(i64::MAX as u64) as i64;
99        let filter = format!(
100            "PartitionKey eq '{pk}' and SequenceNr ge {from} and SequenceNr le {to_bound} and Deleted eq false",
101            from = from as i64,
102        );
103        let top = if max > u32::MAX as u64 { None } else { Some(max as u32) };
104        let mut entities: Vec<EventEntity> =
105            self.client.query_entities(&self.cfg.journal_table, &filter, top).await?;
106        entities.sort_by_key(|e| e.sequence_nr);
107        let limit = if max > usize::MAX as u64 { usize::MAX } else { max as usize };
108        Ok(entities.into_iter().take(limit).map(EventEntity::into_repr).collect())
109    }
110
111    async fn highest_sequence_nr(&self, persistence_id: &str, _from: u64) -> Result<u64, JournalError> {
112        self.current_max(persistence_id).await
113    }
114}