synclite/
restore.rs

1use std::cmp::Ordering;
2
3use tracing::{debug, error, info};
4
5use crate::{
6    db::create_db_pool,
7    error::{Error, RestoreError},
8    pos::Pos,
9    replica::ReplicaClient,
10};
11
12pub async fn restore(
13    client: &impl ReplicaClient,
14    db_output: &str,
15    encryption_key: Option<String>,
16) -> Result<(), Error> {
17    info!("Restoring");
18    let generation = find_latest_generation(client)
19        .await?
20        .ok_or(RestoreError::NoGeneration)?;
21    info!("Latest generation to restore '{generation}'");
22    let snapshot_idx = find_max_snapshot_for_generation(client, &generation)
23        .await?
24        .ok_or_else(|| RestoreError::NoSnapshotForGeneration {
25            generation: generation.clone(),
26        })?;
27    let wal_idx = find_latest_wal_index(client, &generation)
28        .await?
29        .ok_or_else(|| RestoreError::LatestWalIndex {
30            generation: generation.clone(),
31        })?;
32    assert!(snapshot_idx <= wal_idx, "TODO");
33
34    // Download snapshat to db output path
35    let mut snapshot_reader = client
36        .snapshot_reader(&generation, &Pos::format(snapshot_idx))
37        .await?;
38    let mut db = tokio::fs::OpenOptions::new()
39        .create_new(true)
40        .write(true)
41        .open(db_output)
42        .await
43        .map_err(RestoreError::CreateRestoreDb)?;
44    tokio::io::copy(snapshot_reader.as_mut(), &mut db)
45        .await
46        .map_err(RestoreError::CreateRestoreDb)?;
47    db.sync_all().await.map_err(RestoreError::CreateRestoreDb)?;
48    drop(db);
49
50    // Download all wal files from snapshot_idx..=wal_idx and apply the wal segments to the db
51    let wal_segments = client.wal_segments(&generation).await?;
52
53    let mut wal_segment_indexes: Vec<Vec<Pos>> = vec![];
54
55    for wal_segment in wal_segments {
56        if wal_segment.index() < snapshot_idx {
57            continue;
58        }
59        if wal_segment.index() > wal_idx {
60            break;
61        }
62        if let Some(last_wal_segment) = wal_segment_indexes.last_mut() {
63            if last_wal_segment
64                .last()
65                .expect("always a wal segment in index")
66                .index()
67                < wal_segment.index()
68            {
69                wal_segment_indexes.push(vec![wal_segment]);
70            } else {
71                last_wal_segment.push(wal_segment);
72            }
73        } else {
74            wal_segment_indexes.push(vec![wal_segment]);
75        };
76    }
77
78    for index in wal_segment_indexes {
79        let wal_file_name = format!("{db_output}-wal");
80        let wal_idx_file_name = format!("{db_output}-shm");
81        let tmp_wal_file_name = format!("{wal_file_name}-{}-tmp", index[0].index_str());
82        let mut wal_file = tokio::fs::OpenOptions::new()
83            .create_new(true)
84            .write(true)
85            .open(&tmp_wal_file_name)
86            .await
87            .map_err(RestoreError::CreateTemporaryWal)?;
88
89        if tokio::fs::remove_file(&wal_idx_file_name).await.is_ok() {
90            info!("Removed wal index file");
91        } else {
92            info!("Failed to remove wal index file");
93        }
94
95        for wal_segment in index.clone() {
96            let mut wal = client.wal_segment_reader(&wal_segment).await?;
97            tokio::io::copy(wal.as_mut(), &mut wal_file)
98                .await
99                .map_err(RestoreError::CreateTemporaryWal)?;
100        }
101        wal_file
102            .sync_all()
103            .await
104            .map_err(RestoreError::CreateTemporaryWal)?;
105        tokio::fs::rename(tmp_wal_file_name, &wal_file_name)
106            .await
107            .map_err(RestoreError::CreateTemporaryWal)?;
108
109        // Apply wal to db
110        let pool = create_db_pool(db_output, encryption_key.clone()).await?;
111        let (col1, col2, col3): (u32, u32, u32) =
112            sqlx::query_as("PRAGMA wal_checkpoint('TRUNCATE');")
113                .fetch_one(&pool)
114                .await
115                .map_err(RestoreError::TruncateWal)?;
116        if col1 != 0 {
117            error!("truncation checkpoint failed during restore ({col1},{col2},{col3})");
118            return Err(RestoreError::ApplyWal)?;
119        }
120        debug!("checkpoint result({col1},{col2},{col3})");
121        // TODO: this is needed when doing manual restore. Try to
122        // write a test that fails without this.
123        pool.close().await;
124    }
125
126    Ok(())
127}
128
129async fn find_max_snapshot_for_generation(
130    client: &dyn ReplicaClient,
131    generation: &str,
132) -> Result<Option<usize>, Error> {
133    client
134        .snapshots(generation)
135        .await
136        .map(|snapshots| snapshots.into_iter().last().map(|snapshot| snapshot.index))
137}
138
139async fn find_latest_generation(client: &dyn ReplicaClient) -> Result<Option<String>, Error> {
140    let generations = client.generations().await?;
141    let mut futures = Vec::with_capacity(generations.len());
142
143    for generation in &generations {
144        futures.push(async {
145            let snapshots = client.snapshots(generation).await?;
146            let first_snapshot_time = snapshots.first().and_then(|s| s.updated_at);
147            Ok((generation.clone(), first_snapshot_time))
148        });
149    }
150
151    futures::future::try_join_all(futures).await.map(|results| {
152        results
153            .into_iter()
154            .max_by(|s1, s2| match (s1.1, s2.1) {
155                (Some(s1), Some(s2)) => s1.cmp(&s2),
156                (Some(_), None) => Ordering::Greater,
157                (None, Some(_)) => Ordering::Less,
158                (None, None) => Ordering::Equal,
159            })
160            .map(|(latest_generation, _)| latest_generation)
161    })
162}
163
164async fn find_latest_wal_index(
165    client: &dyn ReplicaClient,
166    generation: &str,
167) -> Result<Option<usize>, Error> {
168    client
169        .wal_segments(generation)
170        .await
171        .map(|segments| segments.last().map(Pos::index))
172}
173
174pub async fn has_backup(client: &impl ReplicaClient) -> Result<bool, Error> {
175    find_latest_generation(client)
176        .await
177        .map(|generation| generation.is_some())
178}
179
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::replica::s3::S3Replica;

    use super::*;

    /// Writes one snapshot per generation and checks which generation is reported as latest.
    ///
    /// Needs an S3-compatible service reachable at `http://localhost:9000` that accepts
    /// the `minioadmin` credentials exported below (presumably a local MinIO instance).
    /// Each run uses a freshly named bucket, so runs do not see each other's objects.
    #[tokio::test]
    async fn test_find_latest_generation() {
        std::env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
        std::env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");

        let config = crate::replica::s3::Config {
            bucket: Uuid::new_v4().to_string(),
            endpoint: Some("http://localhost:9000".to_string()),
            region: "eu-west-1".to_string(),
            prefix: "tmp.db".to_string(),
        };
        let s3_client = S3Replica::new(config).await.unwrap();
        s3_client.create_bucket().await.unwrap();

        // Deliberately unsorted: the result must follow upload time, not the name.
        let generations = ["7", "3", "2", "1", "101241", "8", "5"];
        for generation in generations {
            let tmp_dir = tempfile::tempdir().unwrap();
            let tmp_file_path = tmp_dir.path().join("tmp.db");
            // Fail loudly instead of silently uploading from an empty path.
            let tmp_file_path = tmp_file_path
                .to_str()
                .expect("temporary file path is valid UTF-8");
            tokio::fs::write(tmp_file_path, "foobar".as_bytes())
                .await
                .unwrap();

            let index = 123;
            s3_client
                .write_snapshot(generation, &Pos::format(index), tmp_file_path)
                .await
                .unwrap();
        }

        // "5" was uploaded last, so its snapshot is expected to carry the newest timestamp.
        let latest_generation = find_latest_generation(&s3_client).await.unwrap();
        assert_eq!(latest_generation, Some("5".to_string()));
    }
}
221}