libsql_wal/replication/
replicator.rs

1use std::sync::Arc;
2
3use roaring::RoaringBitmap;
4use tokio::sync::watch;
5use tokio_stream::{Stream, StreamExt};
6
7use crate::io::Io;
8use crate::replication::Error;
9use crate::segment::Frame;
10use crate::shared_wal::SharedWal;
11
12use super::Result;
13
14pub struct Replicator<IO: Io> {
15    shared: Arc<SharedWal<IO>>,
16    new_frame_notifier: watch::Receiver<u64>,
17    next_frame_no: u64,
18    wait_for_more: bool,
19}
20
21impl<IO: Io> Replicator<IO> {
22    pub fn new(shared: Arc<SharedWal<IO>>, next_frame_no: u64, wait_for_more: bool) -> Self {
23        let new_frame_notifier = shared.new_frame_notifier.subscribe();
24        Self {
25            shared,
26            new_frame_notifier,
27            next_frame_no,
28            wait_for_more,
29        }
30    }
31
32    /// Stream frames from this replicator. The replicator will wait for new frames to become
33    /// available, and never return.
34    ///
35    /// The replicator keeps track of how much progress has been made by the replica, and will
36    /// attempt to find the next frames to send with following strategy:
37    /// - First, replicate as much as possible from the current log
38    /// - The, if we still haven't caught up with `self.start_frame_no`, we select the next frames
39    /// to replicate from tail of current.
40    /// - Finally, if we still haven't reached `self.start_frame_no`, read from durable storage
41    /// (todo: maybe the replica should read from durable storage directly?)
42    ///
43    /// In a single replication step, the replicator guarantees that a minimal set of frames is
44    /// sent to the replica.
45    #[tracing::instrument(skip(self))]
46    pub fn into_frame_stream(mut self) -> impl Stream<Item = Result<Box<Frame>>> + Send {
47        async_stream::try_stream! {
48            loop {
49                // First we decide up to what frame_no we want to replicate in this step. If we are
50                // already up to date, wait for something to happen
51                tracing::debug!(next_frame_no = self.next_frame_no);
52                let most_recent_frame_no = *self
53                    .new_frame_notifier
54                    .wait_for(|fno| *fno >= self.next_frame_no)
55                    .await
56                    .expect("channel cannot be closed because we hold a ref to the sending end");
57
58                tracing::debug!(most_recent_frame_no, "new frame_no available");
59
60                let mut commit_frame_no = 0;
61                // we have stuff to replicate
62                if most_recent_frame_no >= self.next_frame_no {
63                    // first replicate the most recent version of each page from the current
64                    // segment. We also return how far back we have replicated from the current log
65                    let current = self.shared.current.load();
66                    let mut seen = RoaringBitmap::new();
67                    let (stream, replicated_until, size_after) = current.frame_stream_from(self.next_frame_no, &mut seen);
68                    let should_replicate_from_tail = replicated_until != self.next_frame_no;
69
70                    {
71                        tokio::pin!(stream);
72
73                        let mut stream = stream.peekable();
74
75                        tracing::debug!(replicated_until, "replicating from current log");
76                        loop {
77                            let Some(frame) = stream.next().await else { break };
78                            let mut frame = frame.map_err(|e| Error::CurrentSegment(e.into()))?;
79                            commit_frame_no = frame.header().frame_no().max(commit_frame_no);
80                            if stream.peek().await.is_none() && !should_replicate_from_tail {
81                                frame.header_mut().set_size_after(size_after);
82                                self.next_frame_no = commit_frame_no + 1;
83                            }
84
85                            yield frame
86                        }
87                    }
88
89                    // Replicating from the current segment wasn't enough to bring us up to date,
90                    // wee need to take frames from the sealed segments.
91                    if should_replicate_from_tail {
92                        let replicated_until = {
93                            let (stream, replicated_until) = current
94                                .tail()
95                                .stream_pages_from(replicated_until, self.next_frame_no, &mut seen).await;
96                            tokio::pin!(stream);
97
98                        tracing::debug!(replicated_until, "replicating from tail");
99                            let mut stream = stream.peekable();
100
101                            let should_replicate_from_storage = replicated_until != self.next_frame_no;
102
103                            loop {
104                                let Some(frame) = stream.next().await else { break };
105                                let mut frame = frame.map_err(|e| Error::SealedSegment(e.into()))?;
106                                commit_frame_no = frame.header().frame_no().max(commit_frame_no);
107                                if stream.peek().await.is_none() && !should_replicate_from_storage {
108                                    frame.header_mut().set_size_after(size_after);
109                                    self.next_frame_no = commit_frame_no + 1;
110                                }
111
112                                yield frame
113                            }
114
115                            should_replicate_from_storage.then_some(replicated_until)
116                        };
117
118                        // Replicating from sealed segments was not enough, so we replicate from
119                        // durable storage
120                        if let Some(replicated_until) = replicated_until {
121                            tracing::debug!("replicating from durable storage");
122                            let stream = self
123                                .shared
124                                .stored_segments
125                                .stream(&mut seen, replicated_until, self.next_frame_no)
126                                .peekable();
127
128                            tokio::pin!(stream);
129
130                            loop {
131                                let Some(frame) = stream.next().await else { break };
132                                let mut frame = frame?;
133                                commit_frame_no = frame.header().frame_no().max(commit_frame_no);
134                                if stream.peek().await.is_none() {
135                                    frame.header_mut().set_size_after(size_after);
136                                    self.next_frame_no = commit_frame_no + 1;
137                                }
138
139                                yield frame
140                            }
141                        }
142                    }
143                }
144
145                if !self.wait_for_more {
146                    break
147                }
148            }
149        }
150    }
151}
152
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tempfile::NamedTempFile;
    use tokio_stream::StreamExt;

    use crate::io::FileExt;
    use crate::test::{seal_current_segment, TestEnv};

    use super::*;

    /// Page size of the test databases.
    const PAGE_SIZE: usize = 4096;

    /// Byte offset of `frame`'s page in a database file (page numbers start at 1).
    fn page_offset(frame: &Frame) -> usize {
        (frame.header().page_no() as usize - 1) * PAGE_SIZE
    }

    /// Writes `frame`'s page data at its offset in `replica`.
    fn apply_frame(replica: &NamedTempFile, frame: &Frame) {
        replica
            .as_file()
            .write_all_at(frame.data(), page_offset(frame) as u64)
            .unwrap();
    }

    /// Opens the database at `path` and asserts that table `test` holds `expected` rows.
    fn assert_row_count(path: &std::path::Path, expected: usize) {
        let conn = libsql_sys::rusqlite::Connection::open(path).unwrap();
        conn.query_row("select count(0) from test", (), |row| {
            assert_eq!(row.get_unwrap::<_, usize>(0), expected);
            Ok(())
        })
        .unwrap();
    }

    #[tokio::test]
    async fn stream_from_current_log() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("create table test (x)", ()).unwrap();
        for _ in 0..50 {
            conn.execute("insert into test values (randomblob(128))", ())
                .unwrap();
        }

        let replica = NamedTempFile::new().unwrap();
        let stream = Replicator::new(shared.clone(), 1, true).into_frame_stream();
        tokio::pin!(stream);

        // Apply frames until the commit frame (non-zero size_after) shows up.
        let mut last_frame_no: u64 = 0;
        let size_after = loop {
            let frame = stream.next().await.unwrap().unwrap();
            last_frame_no = last_frame_no.max(frame.header().frame_no());
            apply_frame(&replica, &frame);
            let size_after = frame.header().size_after();
            if size_after != 0 {
                break size_after;
            }
        };

        assert_eq!(size_after, 4);
        assert_eq!(last_frame_no, 55);
        assert_row_count(replica.path(), 50);

        seal_current_segment(&shared);

        for _ in 0..50 {
            conn.execute("insert into test values (randomblob(128))", ())
                .unwrap();
        }

        // The same stream resumes after the last frame it already sent.
        let size_after = loop {
            let frame = stream.next().await.unwrap().unwrap();
            assert!(frame.header().frame_no() > last_frame_no);
            apply_frame(&replica, &frame);
            let size_after = frame.header().size_after();
            if size_after != 0 {
                break size_after;
            }
        };

        assert_eq!(size_after, 6);
        assert_row_count(replica.path(), 100);

        // A fresh replicator must rebuild the whole database from scratch.
        {
            let fresh_replica = NamedTempFile::new().unwrap();
            let fresh_stream = Replicator::new(shared.clone(), 1, true).into_frame_stream();
            tokio::pin!(fresh_stream);

            loop {
                let frame = fresh_stream.next().await.unwrap().unwrap();
                apply_frame(&fresh_replica, &frame);
                if frame.header().size_after() != 0 {
                    break;
                }
            }

            assert_row_count(fresh_replica.path(), 100);
        }
    }

    #[tokio::test]
    async fn stream_from_storage() {
        let env = TestEnv::new_store(true);
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("create table test (x)", ()).unwrap();
        conn.execute("insert into test values (randomblob(128))", ())
            .unwrap();

        tokio::task::spawn_blocking({
            let shared = shared.clone();
            move || seal_current_segment(&shared)
        })
        .await
        .unwrap();

        conn.execute("create table test2 (x)", ()).unwrap();
        conn.execute("insert into test2 values (randomblob(128))", ())
            .unwrap();

        tokio::task::spawn_blocking({
            let shared = shared.clone();
            move || seal_current_segment(&shared)
        })
        .await
        .unwrap();

        // Wait until every sealed segment has been moved to durable storage.
        while !shared.current.load().tail().is_empty() {
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        let db_content = std::fs::read(env.db_path("test").join("data")).unwrap();

        let stream = Replicator::new(shared, 1, true).into_frame_stream().take(3);
        tokio::pin!(stream);

        let replica = NamedTempFile::new().unwrap();
        let mut replica_content = vec![0u8; db_content.len()];
        while let Some(result) = stream.next().await {
            let frame = result.unwrap();
            apply_frame(&replica, &frame);
            let offset = page_offset(&frame);
            replica_content[offset..offset + PAGE_SIZE].copy_from_slice(frame.data());
        }

        assert_eq!(db_payload(&replica_content), db_payload(&db_content));
    }

    /// Truncates `db` to a whole number of pages.
    fn db_payload(db: &[u8]) -> &[u8] {
        let whole_pages = db.len() / PAGE_SIZE;
        &db[..whole_pages * PAGE_SIZE]
    }
}