1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// A database container that [`RdsRuntime`] has started and is tracking.
#[derive(Clone, Debug)]
pub struct RunningDbContainer {
    /// Container id as reported on stdout by the CLI's `create` command.
    pub container_id: String,
    /// Host port that the container's database port is published on.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14 cli: String,
15 containers: RwLock<HashMap<String, RunningDbContainer>>,
16 instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21 #[error("container runtime is unavailable")]
22 Unavailable,
23 #[error("container failed to start: {0}")]
24 ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28 pub fn new() -> Option<Self> {
29 let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30 if cli_available(&cli) {
31 cli
32 } else {
33 return None;
34 }
35 } else if cli_available("docker") {
36 "docker".to_string()
37 } else if cli_available("podman") {
38 "podman".to_string()
39 } else {
40 return None;
41 };
42
43 Some(Self {
44 cli,
45 containers: RwLock::new(HashMap::new()),
46 instance_id: format!("fakecloud-{}", std::process::id()),
47 })
48 }
49
50 pub fn cli_name(&self) -> &str {
51 &self.cli
52 }
53
54 pub async fn ensure_postgres(
55 &self,
56 db_instance_identifier: &str,
57 engine: &str,
58 engine_version: &str,
59 username: &str,
60 password: &str,
61 db_name: &str,
62 ) -> Result<RunningDbContainer, RuntimeError> {
63 self.stop_container(db_instance_identifier).await;
64
65 let (image, port, env_vars) = match engine {
67 "postgres" => {
68 let major_version = engine_version.split('.').next().unwrap_or("16");
69 let image = format!("postgres:{}-alpine", major_version);
70 let env_vars = vec![
71 format!("POSTGRES_USER={username}"),
72 format!("POSTGRES_PASSWORD={password}"),
73 format!("POSTGRES_DB={db_name}"),
74 ];
75 (image, "5432", env_vars)
76 }
77 "mysql" => {
78 let major_version = if engine_version.starts_with("5.7") {
79 "5.7"
80 } else {
81 "8.0"
82 };
83 let image = format!("mysql:{}", major_version);
84 let env_vars = vec![
85 format!("MYSQL_ROOT_PASSWORD={password}"),
86 format!("MYSQL_USER={username}"),
87 format!("MYSQL_PASSWORD={password}"),
88 format!("MYSQL_DATABASE={db_name}"),
89 ];
90 (image, "3306", env_vars)
91 }
92 "mariadb" => {
93 let major_version = if engine_version.starts_with("10.11") {
94 "10.11"
95 } else {
96 "10.6"
97 };
98 let image = format!("mariadb:{}", major_version);
99 let env_vars = vec![
100 format!("MARIADB_ROOT_PASSWORD={password}"),
101 format!("MARIADB_USER={username}"),
102 format!("MARIADB_PASSWORD={password}"),
103 format!("MARIADB_DATABASE={db_name}"),
104 ];
105 (image, "3306", env_vars)
106 }
107 _ => {
108 return Err(RuntimeError::ContainerStartFailed(format!(
109 "Unsupported engine: {}",
110 engine
111 )))
112 }
113 };
114
115 let mut args = vec![
117 "create".to_string(),
118 "-p".to_string(),
119 format!(":{}", port),
120 "--label".to_string(),
121 format!("fakecloud-rds={db_instance_identifier}"),
122 "--label".to_string(),
123 format!("fakecloud-instance={}", self.instance_id),
124 ];
125
126 for env_var in env_vars {
127 args.push("-e".to_string());
128 args.push(env_var);
129 }
130
131 args.push(image);
132
133 let output = tokio::process::Command::new(&self.cli)
134 .args(&args)
135 .output()
136 .await
137 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
138
139 if !output.status.success() {
140 return Err(RuntimeError::ContainerStartFailed(
141 String::from_utf8_lossy(&output.stderr).trim().to_string(),
142 ));
143 }
144
145 let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
146 let start_result = tokio::process::Command::new(&self.cli)
147 .args(["start", &container_id])
148 .output()
149 .await
150 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
151
152 if !start_result.status.success() {
153 self.remove_container(&container_id).await;
154 return Err(RuntimeError::ContainerStartFailed(format!(
155 "container start failed: {}",
156 String::from_utf8_lossy(&start_result.stderr).trim()
157 )));
158 }
159
160 let host_port = match self.lookup_port(&container_id, port).await {
161 Ok(host_port) => host_port,
162 Err(error) => {
163 self.remove_container(&container_id).await;
164 return Err(error);
165 }
166 };
167
168 let wait_result = match engine {
170 "postgres" => {
171 self.wait_for_postgres(username, password, db_name, host_port)
172 .await
173 }
174 "mysql" | "mariadb" => {
175 self.wait_for_mysql(username, password, db_name, host_port)
176 .await
177 }
178 _ => unreachable!("engine already validated"),
179 };
180
181 if let Err(error) = wait_result {
182 self.remove_container(&container_id).await;
183 return Err(error);
184 }
185
186 let running = RunningDbContainer {
187 container_id,
188 host_port,
189 };
190 self.containers
191 .write()
192 .insert(db_instance_identifier.to_string(), running.clone());
193 Ok(running)
194 }
195
196 pub async fn stop_container(&self, db_instance_identifier: &str) {
197 let container = self.containers.write().remove(db_instance_identifier);
198 if let Some(container) = container {
199 self.remove_container(&container.container_id).await;
200 }
201 }
202
203 pub async fn restart_container(
204 &self,
205 db_instance_identifier: &str,
206 engine: &str,
207 username: &str,
208 password: &str,
209 db_name: &str,
210 ) -> Result<RunningDbContainer, RuntimeError> {
211 let running = self
212 .containers
213 .read()
214 .get(db_instance_identifier)
215 .cloned()
216 .ok_or(RuntimeError::Unavailable)?;
217
218 let output = tokio::process::Command::new(&self.cli)
219 .args(["restart", &running.container_id])
220 .output()
221 .await
222 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
223
224 if !output.status.success() {
225 return Err(RuntimeError::ContainerStartFailed(format!(
226 "container restart failed: {}",
227 String::from_utf8_lossy(&output.stderr).trim()
228 )));
229 }
230
231 let port = match engine {
232 "postgres" => "5432",
233 "mysql" | "mariadb" => "3306",
234 _ => "5432", };
236
237 let host_port = self.lookup_port(&running.container_id, port).await?;
238
239 match engine {
240 "postgres" => {
241 self.wait_for_postgres(username, password, db_name, host_port)
242 .await?
243 }
244 "mysql" | "mariadb" => {
245 self.wait_for_mysql(username, password, db_name, host_port)
246 .await?
247 }
248 _ => {
249 self.wait_for_postgres(username, password, db_name, host_port)
250 .await?
251 }
252 };
253 let running = RunningDbContainer {
254 container_id: running.container_id,
255 host_port,
256 };
257 self.containers
258 .write()
259 .insert(db_instance_identifier.to_string(), running.clone());
260 Ok(running)
261 }
262
263 pub async fn stop_all(&self) {
264 let containers: Vec<String> = {
265 let mut containers = self.containers.write();
266 containers
267 .drain()
268 .map(|(_, container)| container.container_id)
269 .collect()
270 };
271 for container_id in containers {
272 self.remove_container(&container_id).await;
273 }
274 }
275
276 async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
277 let port_output = tokio::process::Command::new(&self.cli)
278 .args(["port", container_id, port])
279 .output()
280 .await
281 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
282
283 let port_str = String::from_utf8_lossy(&port_output.stdout);
284 port_str
285 .trim()
286 .rsplit(':')
287 .next()
288 .and_then(|value| value.parse::<u16>().ok())
289 .ok_or_else(|| {
290 RuntimeError::ContainerStartFailed(format!(
291 "could not determine container port from '{}'",
292 port_str.trim()
293 ))
294 })
295 }
296
297 async fn wait_for_postgres(
298 &self,
299 username: &str,
300 password: &str,
301 db_name: &str,
302 host_port: u16,
303 ) -> Result<(), RuntimeError> {
304 for _ in 0..40 {
305 tokio::time::sleep(Duration::from_millis(500)).await;
306 let connection_string = format!(
307 "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
308 );
309 if let Ok((client, connection)) =
310 tokio_postgres::connect(&connection_string, NoTls).await
311 {
312 tokio::spawn(async move {
313 let _ = connection.await;
314 });
315 if client.simple_query("SELECT 1").await.is_ok() {
316 return Ok(());
317 }
318 }
319 }
320
321 Err(RuntimeError::ContainerStartFailed(
322 "postgres container did not become ready within 20 seconds".to_string(),
323 ))
324 }
325
326 async fn wait_for_mysql(
327 &self,
328 username: &str,
329 password: &str,
330 db_name: &str,
331 host_port: u16,
332 ) -> Result<(), RuntimeError> {
333 use mysql_async::prelude::*;
334 use mysql_async::OptsBuilder;
335
336 for attempt in 1..=40 {
337 let opts = OptsBuilder::default()
338 .ip_or_hostname("127.0.0.1")
339 .tcp_port(host_port)
340 .user(Some(username))
341 .pass(Some(password))
342 .db_name(Some(db_name));
343
344 match mysql_async::Conn::new(opts).await {
345 Ok(mut conn) => {
346 if conn.query_drop("SELECT 1").await.is_ok() {
347 let _ = conn.disconnect().await;
348 return Ok(());
349 }
350 }
351 Err(_) => {
352 if attempt < 40 {
353 tokio::time::sleep(Duration::from_millis(500)).await;
354 }
355 continue;
356 }
357 }
358 }
359
360 Err(RuntimeError::ContainerStartFailed(
361 "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
362 ))
363 }
364
365 async fn remove_container(&self, container_id: &str) {
366 let _ = tokio::process::Command::new(&self.cli)
367 .args(["rm", "-f", container_id])
368 .output()
369 .await;
370 }
371
372 pub async fn dump_database(
373 &self,
374 db_instance_identifier: &str,
375 engine: &str,
376 username: &str,
377 password: &str,
378 db_name: &str,
379 ) -> Result<Vec<u8>, RuntimeError> {
380 let container = self
381 .containers
382 .read()
383 .get(db_instance_identifier)
384 .cloned()
385 .ok_or(RuntimeError::Unavailable)?;
386
387 let args: Vec<String> = match engine {
388 "mysql" | "mariadb" => vec![
389 "exec".into(),
390 container.container_id.clone(),
391 "mysqldump".into(),
392 "-u".into(),
393 username.into(),
394 format!("-p{password}"),
395 db_name.into(),
396 ],
397 _ => vec![
398 "exec".into(),
399 container.container_id.clone(),
400 "pg_dump".into(),
401 "-U".into(),
402 username.into(),
403 "-d".into(),
404 db_name.into(),
405 "--no-password".into(),
406 ],
407 };
408
409 let output = tokio::process::Command::new(&self.cli)
410 .args(&args)
411 .output()
412 .await
413 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
414
415 if !output.status.success() {
416 return Err(RuntimeError::ContainerStartFailed(format!(
417 "dump failed: {}",
418 String::from_utf8_lossy(&output.stderr).trim()
419 )));
420 }
421
422 Ok(output.stdout)
423 }
424
425 pub async fn restore_database(
426 &self,
427 db_instance_identifier: &str,
428 engine: &str,
429 username: &str,
430 password: &str,
431 db_name: &str,
432 dump_data: &[u8],
433 ) -> Result<(), RuntimeError> {
434 let container = self
435 .containers
436 .read()
437 .get(db_instance_identifier)
438 .cloned()
439 .ok_or(RuntimeError::Unavailable)?;
440
441 let args: Vec<String> = match engine {
442 "mysql" | "mariadb" => vec![
443 "exec".into(),
444 "-i".into(),
445 container.container_id.clone(),
446 "mysql".into(),
447 "-u".into(),
448 username.into(),
449 format!("-p{password}"),
450 db_name.into(),
451 ],
452 _ => vec![
453 "exec".into(),
454 "-i".into(),
455 container.container_id.clone(),
456 "psql".into(),
457 "-U".into(),
458 username.into(),
459 "-d".into(),
460 db_name.into(),
461 "--no-password".into(),
462 "-v".into(),
463 "ON_ERROR_STOP=1".into(),
464 ],
465 };
466
467 let mut child = tokio::process::Command::new(&self.cli)
468 .args(&args)
469 .stdin(std::process::Stdio::piped())
470 .stdout(std::process::Stdio::piped())
471 .stderr(std::process::Stdio::piped())
472 .spawn()
473 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
474
475 if let Some(mut stdin) = child.stdin.take() {
476 use tokio::io::AsyncWriteExt;
477 stdin
478 .write_all(dump_data)
479 .await
480 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
481 drop(stdin);
482 }
483
484 let output = child
485 .wait_with_output()
486 .await
487 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
488
489 if !output.status.success() {
490 return Err(RuntimeError::ContainerStartFailed(format!(
491 "restore failed: {}",
492 String::from_utf8_lossy(&output.stderr).trim()
493 )));
494 }
495
496 Ok(())
497 }
498}
499
/// Returns `true` when `<cli> info` can be launched and exits successfully.
///
/// The probe's stdout and stderr are discarded. Failing to spawn the process
/// at all (for example, an unknown executable) counts as unavailable.
fn cli_available(cli: &str) -> bool {
    let probe = std::process::Command::new(cli)
        .arg("info")
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null())
        .status();
    matches!(probe, Ok(status) if status.success())
}