// feophantlib/engine/transactions/transaction_manager.rs

//! This is the interface to transaction visibility (the clog in postgres).
2use super::{TransactionId, TransactionIdError, TransactionStatus};
3use std::sync::Arc;
4use thiserror::Error;
5use tokio::sync::RwLock;
6
7#[derive(Clone, Debug)]
8pub struct TransactionManager {
9    tran_min: TransactionId, //Used to index the known transactions array
10    known_trans: Arc<RwLock<Vec<TransactionStatus>>>,
11}
12
13impl TransactionManager {
14    pub fn new() -> TransactionManager {
15        let tran_min = TransactionId::new(1); //Must start at 1 since 0 is used for active rows
16        let known_trans = Arc::new(RwLock::new(vec![TransactionStatus::Aborted])); //First transaction will be cancelled
17        TransactionManager {
18            tran_min,
19            known_trans,
20        }
21    }
22
23    pub async fn start_trans(&mut self) -> Result<TransactionId, TransactionManagerError> {
24        let mut known_trans = self.known_trans.write().await;
25
26        known_trans.push(TransactionStatus::InProgress);
27
28        Ok(self.tran_min.checked_add(known_trans.len() - 1)?)
29    }
30
31    pub async fn get_status(
32        &mut self,
33        tran_id: TransactionId,
34    ) -> Result<TransactionStatus, TransactionManagerError> {
35        if tran_id < self.tran_min {
36            return Err(TransactionManagerError::TooOld(tran_id, self.tran_min));
37        }
38
39        let known_trans = self.known_trans.read().await;
40
41        if tran_id > self.tran_min.checked_add(known_trans.len())? {
42            return Err(TransactionManagerError::InTheFuture(
43                tran_id,
44                self.tran_min,
45                known_trans.len(),
46            ));
47        }
48
49        let index = tran_id.checked_sub(self.tran_min)?;
50
51        Ok(known_trans[index])
52    }
53
54    async fn update_trans(
55        &mut self,
56        tran_id: TransactionId,
57        new_status: TransactionStatus,
58    ) -> Result<(), TransactionManagerError> {
59        if tran_id < self.tran_min {
60            return Err(TransactionManagerError::TooOld(tran_id, self.tran_min));
61        }
62
63        let mut known_trans = self.known_trans.write().await;
64
65        if tran_id > self.tran_min.checked_add(known_trans.len())? {
66            return Err(TransactionManagerError::InTheFuture(
67                tran_id,
68                self.tran_min,
69                known_trans.len(),
70            ));
71        }
72
73        let index = tran_id.checked_sub(self.tran_min)?;
74
75        if known_trans[index] != TransactionStatus::InProgress {
76            return Err(TransactionManagerError::NotInProgress(
77                tran_id,
78                known_trans[index],
79            ));
80        }
81
82        known_trans[index] = new_status;
83
84        Ok(())
85    }
86
87    pub async fn commit_trans(
88        &mut self,
89        tran_id: TransactionId,
90    ) -> Result<(), TransactionManagerError> {
91        self.update_trans(tran_id, TransactionStatus::Commited)
92            .await
93    }
94
95    pub async fn abort_trans(
96        &mut self,
97        tran_id: TransactionId,
98    ) -> Result<(), TransactionManagerError> {
99        self.update_trans(tran_id, TransactionStatus::Aborted).await
100    }
101
102    //TODO work on figuring out how to save / load this
103    pub fn serialize() {}
104
105    pub fn parse() {}
106}
107
108impl Default for TransactionManager {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114#[derive(Error, Debug)]
115pub enum TransactionManagerError {
116    #[error(transparent)]
117    TransactionIdError(#[from] TransactionIdError),
118    #[error("Transaction Id {0} too low compared to {1}")]
119    TooOld(TransactionId, TransactionId),
120    #[error("Transaction Id {0} exceeds the min {1} and size {2}")]
121    InTheFuture(TransactionId, TransactionId, usize),
122    #[error("Transaction Id {0} not in progress, found {1}")]
123    NotInProgress(TransactionId, TransactionStatus),
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Walks two transactions through start, commit and abort, checking each status.
    #[tokio::test]
    async fn tran_man_statuses() -> Result<(), Box<dyn std::error::Error>> {
        let mut tm = TransactionManager::new();
        let tran1 = tm.start_trans().await?;
        let tran2 = tm.start_trans().await?;

        assert_ne!(tran1, tran2);
        assert!(tran1 < tran2);

        assert_eq!(tm.get_status(tran1).await?, TransactionStatus::InProgress);
        assert_eq!(tm.get_status(tran2).await?, TransactionStatus::InProgress);

        assert!(tm.commit_trans(tran1).await.is_ok());
        assert!(tm.commit_trans(tran1).await.is_err());

        assert_eq!(tm.get_status(tran1).await?, TransactionStatus::Commited);
        assert_eq!(tm.get_status(tran2).await?, TransactionStatus::InProgress);

        assert!(tm.abort_trans(tran2).await.is_ok());
        assert!(tm.abort_trans(tran2).await.is_err());

        assert_eq!(tm.get_status(tran1).await?, TransactionStatus::Commited);
        assert_eq!(tm.get_status(tran2).await?, TransactionStatus::Aborted);

        Ok(())
    }

    /// Checks that out-of-range ids and already-finished transactions are rejected
    /// with the matching error variant.
    #[tokio::test]
    async fn tran_man_rejects_invalid_ids() -> Result<(), Box<dyn std::error::Error>> {
        let mut tm = TransactionManager::new();

        // `new` pre-records id 1 as aborted.
        assert_eq!(
            tm.get_status(TransactionId::new(1)).await?,
            TransactionStatus::Aborted
        );

        // The reserved transaction is not in progress, so it cannot be committed.
        assert!(matches!(
            tm.commit_trans(TransactionId::new(1)).await,
            Err(TransactionManagerError::NotInProgress(_, _))
        ));

        // Ids below the minimum are rejected.
        assert!(matches!(
            tm.get_status(TransactionId::new(0)).await,
            Err(TransactionManagerError::TooOld(_, _))
        ));

        let tran = tm.start_trans().await?;
        assert_eq!(tm.get_status(tran).await?, TransactionStatus::InProgress);

        // Ids that have never been handed out are rejected.
        assert!(matches!(
            tm.get_status(TransactionId::new(100)).await,
            Err(TransactionManagerError::InTheFuture(_, _, _))
        ));
        assert!(matches!(
            tm.commit_trans(TransactionId::new(100)).await,
            Err(TransactionManagerError::InTheFuture(_, _, _))
        ));

        Ok(())
    }
}