libsql_wal/replication/
replicator.rs1use std::sync::Arc;
2
3use roaring::RoaringBitmap;
4use tokio::sync::watch;
5use tokio_stream::{Stream, StreamExt};
6
7use crate::io::Io;
8use crate::replication::Error;
9use crate::segment::Frame;
10use crate::shared_wal::SharedWal;
11
12use super::Result;
13
14pub struct Replicator<IO: Io> {
15 shared: Arc<SharedWal<IO>>,
16 new_frame_notifier: watch::Receiver<u64>,
17 next_frame_no: u64,
18 wait_for_more: bool,
19}
20
21impl<IO: Io> Replicator<IO> {
22 pub fn new(shared: Arc<SharedWal<IO>>, next_frame_no: u64, wait_for_more: bool) -> Self {
23 let new_frame_notifier = shared.new_frame_notifier.subscribe();
24 Self {
25 shared,
26 new_frame_notifier,
27 next_frame_no,
28 wait_for_more,
29 }
30 }
31
32 #[tracing::instrument(skip(self))]
46 pub fn into_frame_stream(mut self) -> impl Stream<Item = Result<Box<Frame>>> + Send {
47 async_stream::try_stream! {
48 loop {
49 tracing::debug!(next_frame_no = self.next_frame_no);
52 let most_recent_frame_no = *self
53 .new_frame_notifier
54 .wait_for(|fno| *fno >= self.next_frame_no)
55 .await
56 .expect("channel cannot be closed because we hold a ref to the sending end");
57
58 tracing::debug!(most_recent_frame_no, "new frame_no available");
59
60 let mut commit_frame_no = 0;
61 if most_recent_frame_no >= self.next_frame_no {
63 let current = self.shared.current.load();
66 let mut seen = RoaringBitmap::new();
67 let (stream, replicated_until, size_after) = current.frame_stream_from(self.next_frame_no, &mut seen);
68 let should_replicate_from_tail = replicated_until != self.next_frame_no;
69
70 {
71 tokio::pin!(stream);
72
73 let mut stream = stream.peekable();
74
75 tracing::debug!(replicated_until, "replicating from current log");
76 loop {
77 let Some(frame) = stream.next().await else { break };
78 let mut frame = frame.map_err(|e| Error::CurrentSegment(e.into()))?;
79 commit_frame_no = frame.header().frame_no().max(commit_frame_no);
80 if stream.peek().await.is_none() && !should_replicate_from_tail {
81 frame.header_mut().set_size_after(size_after);
82 self.next_frame_no = commit_frame_no + 1;
83 }
84
85 yield frame
86 }
87 }
88
89 if should_replicate_from_tail {
92 let replicated_until = {
93 let (stream, replicated_until) = current
94 .tail()
95 .stream_pages_from(replicated_until, self.next_frame_no, &mut seen).await;
96 tokio::pin!(stream);
97
98 tracing::debug!(replicated_until, "replicating from tail");
99 let mut stream = stream.peekable();
100
101 let should_replicate_from_storage = replicated_until != self.next_frame_no;
102
103 loop {
104 let Some(frame) = stream.next().await else { break };
105 let mut frame = frame.map_err(|e| Error::SealedSegment(e.into()))?;
106 commit_frame_no = frame.header().frame_no().max(commit_frame_no);
107 if stream.peek().await.is_none() && !should_replicate_from_storage {
108 frame.header_mut().set_size_after(size_after);
109 self.next_frame_no = commit_frame_no + 1;
110 }
111
112 yield frame
113 }
114
115 should_replicate_from_storage.then_some(replicated_until)
116 };
117
118 if let Some(replicated_until) = replicated_until {
121 tracing::debug!("replicating from durable storage");
122 let stream = self
123 .shared
124 .stored_segments
125 .stream(&mut seen, replicated_until, self.next_frame_no)
126 .peekable();
127
128 tokio::pin!(stream);
129
130 loop {
131 let Some(frame) = stream.next().await else { break };
132 let mut frame = frame?;
133 commit_frame_no = frame.header().frame_no().max(commit_frame_no);
134 if stream.peek().await.is_none() {
135 frame.header_mut().set_size_after(size_after);
136 self.next_frame_no = commit_frame_no + 1;
137 }
138
139 yield frame
140 }
141 }
142 }
143 }
144
145 if !self.wait_for_more {
146 break
147 }
148 }
149 }
150 }
151}
152
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tempfile::NamedTempFile;
    use tokio_stream::StreamExt;

    use crate::io::FileExt;
    use crate::test::{seal_current_segment, TestEnv};

    use super::*;

    /// Writes `frame`'s page into `file` at the byte offset of its page number,
    /// assuming 4096-byte pages.
    fn write_frame(file: &NamedTempFile, frame: &Frame) {
        let offset = (frame.header().page_no() - 1) * 4096;
        file.as_file()
            .write_all_at(frame.data(), offset as _)
            .unwrap();
    }

    /// Opens the database image at `path` and asserts that table `test` holds `expected` rows.
    fn assert_row_count(path: &std::path::Path, expected: usize) {
        let conn = libsql_sys::rusqlite::Connection::open(path).unwrap();
        conn.query_row("select count(0) from test", (), |row| {
            let count = row.get_unwrap::<_, usize>(0);
            assert_eq!(count, expected);
            Ok(())
        })
        .unwrap();
    }

    #[tokio::test]
    async fn stream_from_current_log() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("create table test (x)", ()).unwrap();

        for _ in 0..50 {
            conn.execute("insert into test values (randomblob(128))", ())
                .unwrap();
        }

        let replicator = Replicator::new(shared.clone(), 1, true);

        let tmp = NamedTempFile::new().unwrap();
        let stream = replicator.into_frame_stream();
        tokio::pin!(stream);

        // Apply frames until the commit frame (non-zero size_after) of the first step.
        let mut last_frame_no = 0;
        let mut size_after;
        loop {
            let frame = stream.next().await.unwrap().unwrap();
            size_after = frame.header().size_after();
            last_frame_no = last_frame_no.max(frame.header().frame_no());
            write_frame(&tmp, &frame);
            if size_after != 0 {
                break;
            }
        }

        assert_eq!(size_after, 4);
        assert_eq!(last_frame_no, 55);
        assert_row_count(tmp.path(), 50);

        seal_current_segment(&shared);

        for _ in 0..50 {
            conn.execute("insert into test values (randomblob(128))", ())
                .unwrap();
        }

        // The same stream must resume strictly after what it already sent.
        loop {
            let frame = stream.next().await.unwrap().unwrap();
            assert!(frame.header().frame_no() > last_frame_no);
            size_after = frame.header().size_after();
            write_frame(&tmp, &frame);
            if size_after != 0 {
                break;
            }
        }

        assert_eq!(size_after, 6);
        assert_row_count(tmp.path(), 100);

        // A fresh replica starting from frame 1 must rebuild the full database.
        {
            let tmp = NamedTempFile::new().unwrap();
            let replicator = Replicator::new(shared.clone(), 1, true);
            let stream = replicator.into_frame_stream();

            tokio::pin!(stream);

            loop {
                let frame = stream.next().await.unwrap().unwrap();
                write_frame(&tmp, &frame);
                if frame.header().size_after() != 0 {
                    break;
                }
            }

            assert_row_count(tmp.path(), 100);
        }
    }

    #[tokio::test]
    async fn stream_from_storage() {
        let env = TestEnv::new_store(true);
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("create table test (x)", ()).unwrap();

        conn.execute("insert into test values (randomblob(128))", ())
            .unwrap();

        tokio::task::spawn_blocking({
            let shared = shared.clone();
            move || seal_current_segment(&shared)
        })
        .await
        .unwrap();

        conn.execute("create table test2 (x)", ()).unwrap();
        conn.execute("insert into test2 values (randomblob(128))", ())
            .unwrap();

        tokio::task::spawn_blocking({
            let shared = shared.clone();
            move || seal_current_segment(&shared)
        })
        .await
        .unwrap();

        // Wait until the sealed segments have been moved out of the tail into storage.
        while !shared.current.load().tail().is_empty() {
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        let db_content = std::fs::read(env.db_path("test").join("data")).unwrap();

        let replicator = Replicator::new(shared, 1, true);
        let stream = replicator.into_frame_stream().take(3);

        tokio::pin!(stream);

        let tmp = NamedTempFile::new().unwrap();
        let mut replica_content = vec![0u8; db_content.len()];
        while let Some(f) = stream.next().await {
            let frame = f.unwrap();
            write_frame(&tmp, &frame);
            let offset = (frame.header().page_no() as usize - 1) * 4096;
            replica_content[offset..offset + 4096].copy_from_slice(frame.data());
        }

        assert_eq!(db_payload(&replica_content), db_payload(&db_content));
    }

    /// Truncates `db` to a whole number of 4096-byte pages.
    fn db_payload(db: &[u8]) -> &[u8] {
        let size = (db.len() / 4096) * 4096;
        &db[..size]
    }
}