libsql_wal/replication/injector.rs

1//! The injector is the module in charge of injecting frames into a replica database.
2
3use std::sync::Arc;
4
5use crate::error::Result;
6use crate::io::Io;
7use crate::segment::Frame;
8use crate::shared_wal::SharedWal;
9use crate::transaction::{Transaction, TxGuardOwned};
10
11/// The injector takes frames and injects them in the wal.
12pub struct Injector<IO: Io> {
13    // The wal to which we are injecting
14    wal: Arc<SharedWal<IO>>,
15    buffer: Vec<Box<Frame>>,
16    /// capacity of the frame buffer
17    capacity: usize,
18    tx: Option<TxGuardOwned<IO::File>>,
19    max_tx_frame_no: u64,
20    previous_durable_frame_no: u64,
21}
22
23impl<IO: Io> Injector<IO> {
24    pub fn new(wal: Arc<SharedWal<IO>>, buffer_capacity: usize) -> Result<Self> {
25        Ok(Self {
26            wal,
27            buffer: Vec::with_capacity(buffer_capacity),
28            capacity: buffer_capacity,
29            tx: None,
30            max_tx_frame_no: 0,
31            previous_durable_frame_no: 0,
32        })
33    }
34
35    pub fn set_durable(&mut self, durable_frame_no: u64) {
36        let mut old = self.wal.durable_frame_no.lock();
37        if *old <= durable_frame_no {
38            self.previous_durable_frame_no = *old;
39            *old = durable_frame_no;
40        } else {
41            todo!("primary reported older frame_no than current");
42        }
43    }
44
45    pub fn current_durable(&self) -> u64 {
46        *self.wal.durable_frame_no.lock()
47    }
48
49    pub fn maybe_begin_txn(&mut self) -> Result<()> {
50        if self.tx.is_none() {
51            let mut tx = Transaction::Read(self.wal.begin_read(u64::MAX));
52            self.wal.upgrade(&mut tx)?;
53            let tx = tx
54                .into_write()
55                .unwrap_or_else(|_| unreachable!())
56                .into_lock_owned();
57            assert!(self.tx.replace(tx).is_none());
58        }
59
60        Ok(())
61    }
62
63    pub async fn insert_frame(&mut self, frame: Box<Frame>) -> Result<Option<u64>> {
64        self.maybe_begin_txn()?;
65        let size_after = frame.size_after();
66        self.max_tx_frame_no = self.max_tx_frame_no.max(frame.header().frame_no());
67        self.buffer.push(frame);
68
69        if size_after.is_some() || self.capacity == self.buffer.len() {
70            self.flush(size_after).await?;
71        }
72
73        Ok(size_after.map(|_| self.max_tx_frame_no))
74    }
75
76    pub async fn flush(&mut self, size_after: Option<u32>) -> Result<()> {
77        if !self.buffer.is_empty() && self.tx.is_some() {
78            let last_committed_frame_no = self.max_tx_frame_no;
79            {
80                let tx = self.tx.as_mut().expect("we just checked that tx was there");
81                let buffer = std::mem::take(&mut self.buffer);
82                let current = self.wal.current.load();
83                let commit_data = size_after.map(|size| (size, self.max_tx_frame_no));
84                if commit_data.is_some() {
85                    self.max_tx_frame_no = 0;
86                }
87                let buffer = current.inject_frames(buffer, commit_data, tx).await?;
88                self.buffer = buffer;
89                self.buffer.clear();
90            }
91
92            if size_after.is_some() {
93                let mut tx = self.tx.take().unwrap();
94                self.wal
95                    .new_frame_notifier
96                    .send_replace(last_committed_frame_no);
97                // the strategy to swap the current log is to do it on change of durable boundary,
98                // when we have caught up with the current durable frame_no
99                if self.current_durable() != self.previous_durable_frame_no
100                    && self.current_durable() >= self.max_tx_frame_no
101                {
102                    let wal = self.wal.clone();
103                    // FIXME: tokio dependency here is annoying, we need an async version of swap_current.
104                    tokio::task::spawn_blocking(move || {
105                        tx.commit();
106                        wal.swap_current(&tx)
107                    })
108                    .await
109                    .unwrap()?
110                }
111            }
112        }
113
114        Ok(())
115    }
116
117    pub fn rollback(&mut self) {
118        self.buffer.clear();
119        if let Some(tx) = self.tx.as_mut() {
120            tx.reset(0);
121        }
122    }
123}
124
#[cfg(test)]
mod test {
    use tokio_stream::StreamExt;

    use crate::replication::replicator::Replicator;
    use crate::test::TestEnv;

    use super::*;

    /// Replicates a primary's frames into a replica through an `Injector` and checks that the
    /// replica observes the committed data.
    #[tokio::test]
    async fn inject_basic() {
        let primary_env = TestEnv::new();
        let primary_conn = primary_env.open_conn("test");
        let primary_shared = primary_env.shared("test");

        let replicator = Replicator::new(primary_shared.clone(), 1, true);
        let stream = replicator.into_frame_stream();
        tokio::pin!(stream);

        // setup replica
        let replica_env = TestEnv::new();
        let replica_conn = replica_env.open_conn("test");
        let replica_shared = replica_env.shared("test");
        let mut injector = Injector::new(replica_shared.clone(), 10).unwrap();

        // Asserts the number of rows visible in `test` on the replica.
        let assert_replica_count = |expected: usize| {
            replica_conn
                .query_row("select count(*) from test", (), |r| {
                    assert_eq!(r.get_unwrap::<_, usize>(0), expected);
                    Ok(())
                })
                .unwrap();
        };

        primary_conn.execute("create table test (x)", ()).unwrap();

        // The `create table` transaction is streamed as two frames; the last one must commit.
        let mut committed = None;
        for _ in 0..2 {
            let frame = stream.next().await.unwrap().unwrap();
            committed = injector.insert_frame(frame).await.unwrap();
        }
        assert!(
            committed.is_some(),
            "the last frame should commit the transaction"
        );
        assert_replica_count(0);

        for _ in 0..3 {
            primary_conn
                .execute("insert into test values (123)", ())
                .unwrap();
        }

        let frame = stream.next().await.unwrap().unwrap();
        let committed = injector.insert_frame(frame).await.unwrap();
        assert!(committed.is_some(), "the frame should commit the transaction");
        assert_replica_count(3);
    }
}