libsql_wal/replication/
injector.rs1use std::sync::Arc;
4
5use crate::error::Result;
6use crate::io::Io;
7use crate::segment::Frame;
8use crate::shared_wal::SharedWal;
9use crate::transaction::{Transaction, TxGuardOwned};
10
11pub struct Injector<IO: Io> {
13 wal: Arc<SharedWal<IO>>,
15 buffer: Vec<Box<Frame>>,
16 capacity: usize,
18 tx: Option<TxGuardOwned<IO::File>>,
19 max_tx_frame_no: u64,
20 previous_durable_frame_no: u64,
21}
22
23impl<IO: Io> Injector<IO> {
24 pub fn new(wal: Arc<SharedWal<IO>>, buffer_capacity: usize) -> Result<Self> {
25 Ok(Self {
26 wal,
27 buffer: Vec::with_capacity(buffer_capacity),
28 capacity: buffer_capacity,
29 tx: None,
30 max_tx_frame_no: 0,
31 previous_durable_frame_no: 0,
32 })
33 }
34
35 pub fn set_durable(&mut self, durable_frame_no: u64) {
36 let mut old = self.wal.durable_frame_no.lock();
37 if *old <= durable_frame_no {
38 self.previous_durable_frame_no = *old;
39 *old = durable_frame_no;
40 } else {
41 todo!("primary reported older frame_no than current");
42 }
43 }
44
45 pub fn current_durable(&self) -> u64 {
46 *self.wal.durable_frame_no.lock()
47 }
48
49 pub fn maybe_begin_txn(&mut self) -> Result<()> {
50 if self.tx.is_none() {
51 let mut tx = Transaction::Read(self.wal.begin_read(u64::MAX));
52 self.wal.upgrade(&mut tx)?;
53 let tx = tx
54 .into_write()
55 .unwrap_or_else(|_| unreachable!())
56 .into_lock_owned();
57 assert!(self.tx.replace(tx).is_none());
58 }
59
60 Ok(())
61 }
62
63 pub async fn insert_frame(&mut self, frame: Box<Frame>) -> Result<Option<u64>> {
64 self.maybe_begin_txn()?;
65 let size_after = frame.size_after();
66 self.max_tx_frame_no = self.max_tx_frame_no.max(frame.header().frame_no());
67 self.buffer.push(frame);
68
69 if size_after.is_some() || self.capacity == self.buffer.len() {
70 self.flush(size_after).await?;
71 }
72
73 Ok(size_after.map(|_| self.max_tx_frame_no))
74 }
75
76 pub async fn flush(&mut self, size_after: Option<u32>) -> Result<()> {
77 if !self.buffer.is_empty() && self.tx.is_some() {
78 let last_committed_frame_no = self.max_tx_frame_no;
79 {
80 let tx = self.tx.as_mut().expect("we just checked that tx was there");
81 let buffer = std::mem::take(&mut self.buffer);
82 let current = self.wal.current.load();
83 let commit_data = size_after.map(|size| (size, self.max_tx_frame_no));
84 if commit_data.is_some() {
85 self.max_tx_frame_no = 0;
86 }
87 let buffer = current.inject_frames(buffer, commit_data, tx).await?;
88 self.buffer = buffer;
89 self.buffer.clear();
90 }
91
92 if size_after.is_some() {
93 let mut tx = self.tx.take().unwrap();
94 self.wal
95 .new_frame_notifier
96 .send_replace(last_committed_frame_no);
97 if self.current_durable() != self.previous_durable_frame_no
100 && self.current_durable() >= self.max_tx_frame_no
101 {
102 let wal = self.wal.clone();
103 tokio::task::spawn_blocking(move || {
105 tx.commit();
106 wal.swap_current(&tx)
107 })
108 .await
109 .unwrap()?
110 }
111 }
112 }
113
114 Ok(())
115 }
116
117 pub fn rollback(&mut self) {
118 self.buffer.clear();
119 if let Some(tx) = self.tx.as_mut() {
120 tx.reset(0);
121 }
122 }
123}
124
#[cfg(test)]
mod test {
    use tokio_stream::StreamExt;

    use crate::replication::replicator::Replicator;
    use crate::test::TestEnv;

    use super::*;

    /// Streams frames from a primary into a replica through an `Injector`, and checks
    /// that the replica sees the committed schema and rows.
    #[tokio::test]
    async fn inject_basic() {
        let primary_env = TestEnv::new();
        let primary_conn = primary_env.open_conn("test");
        let primary_shared = primary_env.shared("test");

        let replicator = Replicator::new(primary_shared.clone(), 1, true);
        let stream = replicator.into_frame_stream();
        tokio::pin!(stream);

        let replica_env = TestEnv::new();
        let replica_conn = replica_env.open_conn("test");
        let replica_shared = replica_env.shared("test");

        // Row count of `test` as seen through the replica connection.
        let replica_count = || -> usize {
            replica_conn
                .query_row("select count(*) from test", (), |r| {
                    Ok(r.get_unwrap::<_, usize>(0))
                })
                .unwrap()
        };

        let mut injector = Injector::new(replica_shared.clone(), 10).unwrap();

        primary_conn.execute("create table test (x)", ()).unwrap();

        // The test expects the `create table` transaction to arrive as two frames.
        for _ in 0..2 {
            let frame = stream.next().await.unwrap().unwrap();
            injector.insert_frame(frame).await.unwrap();
        }

        assert_eq!(replica_count(), 0);

        for _ in 0..3 {
            primary_conn
                .execute("insert into test values (123)", ())
                .unwrap();
        }

        // NOTE(review): a single frame is expected to carry all three inserts,
        // presumably because they modify the same page — confirm against `Replicator`.
        let frame = stream.next().await.unwrap().unwrap();
        injector.insert_frame(frame).await.unwrap();

        assert_eq!(replica_count(), 3);
    }
}