1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// A database container started by [`RdsRuntime`] for one DB instance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunningDbContainer {
    /// Container ID as printed by the container CLI's `create` command.
    pub container_id: String,
    /// Host port on which the container's database port is published.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14 cli: String,
15 containers: RwLock<HashMap<String, RunningDbContainer>>,
16 instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21 #[error("container runtime is unavailable")]
22 Unavailable,
23 #[error("container failed to start: {0}")]
24 ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28 pub fn new() -> Option<Self> {
29 let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30 if cli_available(&cli) {
31 cli
32 } else {
33 return None;
34 }
35 } else if cli_available("docker") {
36 "docker".to_string()
37 } else if cli_available("podman") {
38 "podman".to_string()
39 } else {
40 return None;
41 };
42
43 Some(Self {
44 cli,
45 containers: RwLock::new(HashMap::new()),
46 instance_id: format!("fakecloud-{}", std::process::id()),
47 })
48 }
49
50 pub fn cli_name(&self) -> &str {
51 &self.cli
52 }
53
54 pub async fn ensure_postgres(
55 &self,
56 db_instance_identifier: &str,
57 engine: &str,
58 engine_version: &str,
59 username: &str,
60 password: &str,
61 db_name: &str,
62 ) -> Result<RunningDbContainer, RuntimeError> {
63 self.stop_container(db_instance_identifier).await;
64
65 let (image, port, env_vars) = match engine {
67 "postgres" => {
68 let major_version = engine_version.split('.').next().unwrap_or("16");
69 let image = format!("postgres:{}-alpine", major_version);
70 let env_vars = vec![
71 format!("POSTGRES_USER={username}"),
72 format!("POSTGRES_PASSWORD={password}"),
73 format!("POSTGRES_DB={db_name}"),
74 ];
75 (image, "5432", env_vars)
76 }
77 "mysql" => {
78 let major_version = if engine_version.starts_with("5.7") {
79 "5.7"
80 } else {
81 "8.0"
82 };
83 let image = format!("mysql:{}", major_version);
84 let env_vars = vec![
85 format!("MYSQL_ROOT_PASSWORD={password}"),
86 format!("MYSQL_USER={username}"),
87 format!("MYSQL_PASSWORD={password}"),
88 format!("MYSQL_DATABASE={db_name}"),
89 ];
90 (image, "3306", env_vars)
91 }
92 "mariadb" => {
93 let major_version = if engine_version.starts_with("10.11") {
94 "10.11"
95 } else {
96 "10.6"
97 };
98 let image = format!("mariadb:{}", major_version);
99 let env_vars = vec![
100 format!("MARIADB_ROOT_PASSWORD={password}"),
101 format!("MARIADB_USER={username}"),
102 format!("MARIADB_PASSWORD={password}"),
103 format!("MARIADB_DATABASE={db_name}"),
104 ];
105 (image, "3306", env_vars)
106 }
107 _ => {
108 return Err(RuntimeError::ContainerStartFailed(format!(
109 "Unsupported engine: {}",
110 engine
111 )))
112 }
113 };
114
115 let mut args = vec![
117 "create".to_string(),
118 "-p".to_string(),
119 format!(":{}", port),
120 "--label".to_string(),
121 format!("fakecloud-rds={db_instance_identifier}"),
122 "--label".to_string(),
123 format!("fakecloud-instance={}", self.instance_id),
124 ];
125
126 for env_var in env_vars {
127 args.push("-e".to_string());
128 args.push(env_var);
129 }
130
131 args.push(image);
132
133 let output = tokio::process::Command::new(&self.cli)
134 .args(&args)
135 .output()
136 .await
137 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
138
139 if !output.status.success() {
140 return Err(RuntimeError::ContainerStartFailed(
141 String::from_utf8_lossy(&output.stderr).trim().to_string(),
142 ));
143 }
144
145 let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
146 let start_result = tokio::process::Command::new(&self.cli)
147 .args(["start", &container_id])
148 .output()
149 .await
150 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
151
152 if !start_result.status.success() {
153 self.remove_container(&container_id).await;
154 return Err(RuntimeError::ContainerStartFailed(format!(
155 "container start failed: {}",
156 String::from_utf8_lossy(&start_result.stderr).trim()
157 )));
158 }
159
160 let host_port = match self.lookup_port(&container_id, port).await {
161 Ok(host_port) => host_port,
162 Err(error) => {
163 self.remove_container(&container_id).await;
164 return Err(error);
165 }
166 };
167
168 let wait_result = match engine {
170 "postgres" => {
171 self.wait_for_postgres(username, password, db_name, host_port)
172 .await
173 }
174 "mysql" | "mariadb" => {
175 self.wait_for_mysql(username, password, db_name, host_port)
176 .await
177 }
178 _ => unreachable!("engine already validated"),
179 };
180
181 if let Err(error) = wait_result {
182 self.remove_container(&container_id).await;
183 return Err(error);
184 }
185
186 let running = RunningDbContainer {
187 container_id,
188 host_port,
189 };
190 self.containers
191 .write()
192 .insert(db_instance_identifier.to_string(), running.clone());
193 Ok(running)
194 }
195
196 pub async fn stop_container(&self, db_instance_identifier: &str) {
197 let container = self.containers.write().remove(db_instance_identifier);
198 if let Some(container) = container {
199 self.remove_container(&container.container_id).await;
200 }
201 }
202
203 pub async fn restart_container(
204 &self,
205 db_instance_identifier: &str,
206 engine: &str,
207 username: &str,
208 password: &str,
209 db_name: &str,
210 ) -> Result<RunningDbContainer, RuntimeError> {
211 let running = self
212 .containers
213 .read()
214 .get(db_instance_identifier)
215 .cloned()
216 .ok_or(RuntimeError::Unavailable)?;
217
218 let output = tokio::process::Command::new(&self.cli)
219 .args(["restart", &running.container_id])
220 .output()
221 .await
222 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
223
224 if !output.status.success() {
225 return Err(RuntimeError::ContainerStartFailed(format!(
226 "container restart failed: {}",
227 String::from_utf8_lossy(&output.stderr).trim()
228 )));
229 }
230
231 let port = match engine {
232 "postgres" => "5432",
233 "mysql" | "mariadb" => "3306",
234 _ => "5432", };
236
237 let host_port = self.lookup_port(&running.container_id, port).await?;
238
239 match engine {
240 "postgres" => {
241 self.wait_for_postgres(username, password, db_name, host_port)
242 .await?
243 }
244 "mysql" | "mariadb" => {
245 self.wait_for_mysql(username, password, db_name, host_port)
246 .await?
247 }
248 _ => {
249 self.wait_for_postgres(username, password, db_name, host_port)
250 .await?
251 }
252 };
253 let running = RunningDbContainer {
254 container_id: running.container_id,
255 host_port,
256 };
257 self.containers
258 .write()
259 .insert(db_instance_identifier.to_string(), running.clone());
260 Ok(running)
261 }
262
263 pub async fn stop_all(&self) {
264 let containers: Vec<String> = {
265 let mut containers = self.containers.write();
266 containers
267 .drain()
268 .map(|(_, container)| container.container_id)
269 .collect()
270 };
271 for container_id in containers {
272 self.remove_container(&container_id).await;
273 }
274 }
275
276 async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
277 let port_output = tokio::process::Command::new(&self.cli)
278 .args(["port", container_id, port])
279 .output()
280 .await
281 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
282
283 let port_str = String::from_utf8_lossy(&port_output.stdout);
284 port_str
285 .trim()
286 .rsplit(':')
287 .next()
288 .and_then(|value| value.parse::<u16>().ok())
289 .ok_or_else(|| {
290 RuntimeError::ContainerStartFailed(format!(
291 "could not determine container port from '{}'",
292 port_str.trim()
293 ))
294 })
295 }
296
297 async fn wait_for_postgres(
298 &self,
299 username: &str,
300 password: &str,
301 db_name: &str,
302 host_port: u16,
303 ) -> Result<(), RuntimeError> {
304 for _ in 0..40 {
305 tokio::time::sleep(Duration::from_millis(500)).await;
306 let connection_string = format!(
307 "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
308 );
309 if let Ok((client, connection)) =
310 tokio_postgres::connect(&connection_string, NoTls).await
311 {
312 tokio::spawn(async move {
313 let _ = connection.await;
314 });
315 if client.simple_query("SELECT 1").await.is_ok() {
316 return Ok(());
317 }
318 }
319 }
320
321 Err(RuntimeError::ContainerStartFailed(
322 "postgres container did not become ready within 20 seconds".to_string(),
323 ))
324 }
325
326 async fn wait_for_mysql(
327 &self,
328 username: &str,
329 password: &str,
330 db_name: &str,
331 host_port: u16,
332 ) -> Result<(), RuntimeError> {
333 use mysql_async::prelude::*;
334 use mysql_async::OptsBuilder;
335
336 for attempt in 1..=40 {
337 let opts = OptsBuilder::default()
338 .ip_or_hostname("127.0.0.1")
339 .tcp_port(host_port)
340 .user(Some(username))
341 .pass(Some(password))
342 .db_name(Some(db_name));
343
344 match mysql_async::Conn::new(opts).await {
345 Ok(mut conn) => {
346 if conn.query_drop("SELECT 1").await.is_ok() {
347 let _ = conn.disconnect().await;
348 return Ok(());
349 }
350 }
351 Err(_) => {
352 if attempt < 40 {
353 tokio::time::sleep(Duration::from_millis(500)).await;
354 }
355 continue;
356 }
357 }
358 }
359
360 Err(RuntimeError::ContainerStartFailed(
361 "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
362 ))
363 }
364
365 async fn remove_container(&self, container_id: &str) {
366 let _ = tokio::process::Command::new(&self.cli)
367 .args(["rm", "-f", container_id])
368 .output()
369 .await;
370 }
371
372 pub async fn dump_database(
373 &self,
374 db_instance_identifier: &str,
375 username: &str,
376 db_name: &str,
377 ) -> Result<Vec<u8>, RuntimeError> {
378 let container = self
379 .containers
380 .read()
381 .get(db_instance_identifier)
382 .cloned()
383 .ok_or(RuntimeError::Unavailable)?;
384
385 let output = tokio::process::Command::new(&self.cli)
386 .args([
387 "exec",
388 &container.container_id,
389 "pg_dump",
390 "-U",
391 username,
392 "-d",
393 db_name,
394 "--no-password",
395 ])
396 .output()
397 .await
398 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
399
400 if !output.status.success() {
401 return Err(RuntimeError::ContainerStartFailed(format!(
402 "pg_dump failed: {}",
403 String::from_utf8_lossy(&output.stderr).trim()
404 )));
405 }
406
407 Ok(output.stdout)
408 }
409
410 pub async fn restore_database(
411 &self,
412 db_instance_identifier: &str,
413 username: &str,
414 db_name: &str,
415 dump_data: &[u8],
416 ) -> Result<(), RuntimeError> {
417 let container = self
418 .containers
419 .read()
420 .get(db_instance_identifier)
421 .cloned()
422 .ok_or(RuntimeError::Unavailable)?;
423
424 let mut child = tokio::process::Command::new(&self.cli)
425 .args([
426 "exec",
427 "-i",
428 &container.container_id,
429 "psql",
430 "-U",
431 username,
432 "-d",
433 db_name,
434 "--no-password",
435 "-v",
436 "ON_ERROR_STOP=1",
437 ])
438 .stdin(std::process::Stdio::piped())
439 .stdout(std::process::Stdio::piped())
440 .stderr(std::process::Stdio::piped())
441 .spawn()
442 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
443
444 if let Some(mut stdin) = child.stdin.take() {
445 use tokio::io::AsyncWriteExt;
446 stdin
447 .write_all(dump_data)
448 .await
449 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
450 drop(stdin);
451 }
452
453 let output = child
454 .wait_with_output()
455 .await
456 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
457
458 if !output.status.success() {
459 return Err(RuntimeError::ContainerStartFailed(format!(
460 "psql restore failed: {}",
461 String::from_utf8_lossy(&output.stderr).trim()
462 )));
463 }
464
465 Ok(())
466 }
467}
468
/// Reports whether the container CLI `cli` is usable.
///
/// The check runs `<cli> info` and requires it to exit successfully. The command's
/// stdout and stderr are discarded. A failure to spawn the process, for example when
/// the binary is not on `PATH`, counts as unavailable.
fn cli_available(cli: &str) -> bool {
    let probe = std::process::Command::new(cli)
        .arg("info")
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
    matches!(probe, Ok(exit) if exit.success())
}