1use std::cmp::Ordering;
2
3use tracing::{debug, error, info};
4
5use crate::{
6 db::create_db_pool,
7 error::{Error, RestoreError},
8 pos::Pos,
9 replica::ReplicaClient,
10};
11
12pub async fn restore(
13 client: &impl ReplicaClient,
14 db_output: &str,
15 encryption_key: Option<String>,
16) -> Result<(), Error> {
17 info!("Restoring");
18 let generation = find_latest_generation(client)
19 .await?
20 .ok_or(RestoreError::NoGeneration)?;
21 info!("Latest generation to restore '{generation}'");
22 let snapshot_idx = find_max_snapshot_for_generation(client, &generation)
23 .await?
24 .ok_or_else(|| RestoreError::NoSnapshotForGeneration {
25 generation: generation.clone(),
26 })?;
27 let wal_idx = find_latest_wal_index(client, &generation)
28 .await?
29 .ok_or_else(|| RestoreError::LatestWalIndex {
30 generation: generation.clone(),
31 })?;
32 assert!(snapshot_idx <= wal_idx, "TODO");
33
34 let mut snapshot_reader = client
36 .snapshot_reader(&generation, &Pos::format(snapshot_idx))
37 .await?;
38 let mut db = tokio::fs::OpenOptions::new()
39 .create_new(true)
40 .write(true)
41 .open(db_output)
42 .await
43 .map_err(RestoreError::CreateRestoreDb)?;
44 tokio::io::copy(snapshot_reader.as_mut(), &mut db)
45 .await
46 .map_err(RestoreError::CreateRestoreDb)?;
47 db.sync_all().await.map_err(RestoreError::CreateRestoreDb)?;
48 drop(db);
49
50 let wal_segments = client.wal_segments(&generation).await?;
52
53 let mut wal_segment_indexes: Vec<Vec<Pos>> = vec![];
54
55 for wal_segment in wal_segments {
56 if wal_segment.index() < snapshot_idx {
57 continue;
58 }
59 if wal_segment.index() > wal_idx {
60 break;
61 }
62 if let Some(last_wal_segment) = wal_segment_indexes.last_mut() {
63 if last_wal_segment
64 .last()
65 .expect("always a wal segment in index")
66 .index()
67 < wal_segment.index()
68 {
69 wal_segment_indexes.push(vec![wal_segment]);
70 } else {
71 last_wal_segment.push(wal_segment);
72 }
73 } else {
74 wal_segment_indexes.push(vec![wal_segment]);
75 };
76 }
77
78 for index in wal_segment_indexes {
79 let wal_file_name = format!("{db_output}-wal");
80 let wal_idx_file_name = format!("{db_output}-shm");
81 let tmp_wal_file_name = format!("{wal_file_name}-{}-tmp", index[0].index_str());
82 let mut wal_file = tokio::fs::OpenOptions::new()
83 .create_new(true)
84 .write(true)
85 .open(&tmp_wal_file_name)
86 .await
87 .map_err(RestoreError::CreateTemporaryWal)?;
88
89 if tokio::fs::remove_file(&wal_idx_file_name).await.is_ok() {
90 info!("Removed wal index file");
91 } else {
92 info!("Failed to remove wal index file");
93 }
94
95 for wal_segment in index.clone() {
96 let mut wal = client.wal_segment_reader(&wal_segment).await?;
97 tokio::io::copy(wal.as_mut(), &mut wal_file)
98 .await
99 .map_err(RestoreError::CreateTemporaryWal)?;
100 }
101 wal_file
102 .sync_all()
103 .await
104 .map_err(RestoreError::CreateTemporaryWal)?;
105 tokio::fs::rename(tmp_wal_file_name, &wal_file_name)
106 .await
107 .map_err(RestoreError::CreateTemporaryWal)?;
108
109 let pool = create_db_pool(db_output, encryption_key.clone()).await?;
111 let (col1, col2, col3): (u32, u32, u32) =
112 sqlx::query_as("PRAGMA wal_checkpoint('TRUNCATE');")
113 .fetch_one(&pool)
114 .await
115 .map_err(RestoreError::TruncateWal)?;
116 if col1 != 0 {
117 error!("truncation checkpoint failed during restore ({col1},{col2},{col3})");
118 return Err(RestoreError::ApplyWal)?;
119 }
120 debug!("checkpoint result({col1},{col2},{col3})");
121 pool.close().await;
124 }
125
126 Ok(())
127}
128
129async fn find_max_snapshot_for_generation(
130 client: &dyn ReplicaClient,
131 generation: &str,
132) -> Result<Option<usize>, Error> {
133 client
134 .snapshots(generation)
135 .await
136 .map(|snapshots| snapshots.into_iter().last().map(|snapshot| snapshot.index))
137}
138
139async fn find_latest_generation(client: &dyn ReplicaClient) -> Result<Option<String>, Error> {
140 let generations = client.generations().await?;
141 let mut futures = Vec::with_capacity(generations.len());
142
143 for generation in &generations {
144 futures.push(async {
145 let snapshots = client.snapshots(generation).await?;
146 let first_snapshot_time = snapshots.first().and_then(|s| s.updated_at);
147 Ok((generation.clone(), first_snapshot_time))
148 });
149 }
150
151 futures::future::try_join_all(futures).await.map(|results| {
152 results
153 .into_iter()
154 .max_by(|s1, s2| match (s1.1, s2.1) {
155 (Some(s1), Some(s2)) => s1.cmp(&s2),
156 (Some(_), None) => Ordering::Greater,
157 (None, Some(_)) => Ordering::Less,
158 (None, None) => Ordering::Equal,
159 })
160 .map(|(latest_generation, _)| latest_generation)
161 })
162}
163
164async fn find_latest_wal_index(
165 client: &dyn ReplicaClient,
166 generation: &str,
167) -> Result<Option<usize>, Error> {
168 client
169 .wal_segments(generation)
170 .await
171 .map(|segments| segments.last().map(Pos::index))
172}
173
174pub async fn has_backup(client: &impl ReplicaClient) -> Result<bool, Error> {
175 find_latest_generation(client)
176 .await
177 .map(|generation| generation.is_some())
178}
179
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::replica::s3::S3Replica;

    use super::*;

    /// Uploads one snapshot per generation and checks that the generation uploaded last is
    /// reported as the latest. Requires an S3-compatible server (e.g. MinIO) on `localhost:9000`.
    #[tokio::test]
    async fn test_find_latest_generation() {
        for (key, value) in [
            ("AWS_ACCESS_KEY_ID", "minioadmin"),
            ("AWS_SECRET_ACCESS_KEY", "minioadmin"),
        ] {
            std::env::set_var(key, value);
        }

        let config = crate::replica::s3::Config {
            bucket: Uuid::new_v4().to_string(),
            endpoint: Some("http://localhost:9000".to_string()),
            region: "eu-west-1".to_string(),
            prefix: "tmp.db".to_string(),
        };
        let s3_client = S3Replica::new(config.clone()).await.unwrap();
        s3_client.create_bucket().await.unwrap();

        // Upload order matters: "5" goes last, so its snapshot should carry the newest
        // `updated_at`.
        let snapshot_index = 123;
        for generation in ["7", "3", "2", "1", "101241", "8", "5"] {
            let scratch_dir = tempfile::tempdir().unwrap();
            let snapshot_path = scratch_dir.path().join("tmp.db");
            let snapshot_path = snapshot_path.to_str().unwrap_or_default();
            tokio::fs::write(snapshot_path, b"foobar").await.unwrap();

            s3_client
                .write_snapshot(generation, &Pos::format(snapshot_index), snapshot_path)
                .await
                .unwrap();
        }

        let latest_generation = find_latest_generation(&s3_client).await.unwrap();
        assert_eq!(latest_generation, Some("5".to_string()));
    }
}