Skip to main content

fakecloud_rds/
runtime.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// A database container that `RdsRuntime` has started and is tracking.
///
/// Values are cloned out of the runtime's internal map, so holding one does not
/// keep the container alive or tracked.
#[derive(Clone, Debug)]
pub struct RunningDbContainer {
    /// Container identifier as printed by the CLI's `create` command.
    pub container_id: String,
    /// Host-side port the CLI's `port` command reported for the database port.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14    cli: String,
15    containers: RwLock<HashMap<String, RunningDbContainer>>,
16    instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21    #[error("container runtime is unavailable")]
22    Unavailable,
23    #[error("container failed to start: {0}")]
24    ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28    pub fn new() -> Option<Self> {
29        let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30            if cli_available(&cli) {
31                cli
32            } else {
33                return None;
34            }
35        } else if cli_available("docker") {
36            "docker".to_string()
37        } else if cli_available("podman") {
38            "podman".to_string()
39        } else {
40            return None;
41        };
42
43        Some(Self {
44            cli,
45            containers: RwLock::new(HashMap::new()),
46            instance_id: format!("fakecloud-{}", std::process::id()),
47        })
48    }
49
50    pub fn cli_name(&self) -> &str {
51        &self.cli
52    }
53
54    pub async fn ensure_postgres(
55        &self,
56        db_instance_identifier: &str,
57        engine: &str,
58        engine_version: &str,
59        username: &str,
60        password: &str,
61        db_name: &str,
62    ) -> Result<RunningDbContainer, RuntimeError> {
63        self.stop_container(db_instance_identifier).await;
64
65        // Determine Docker image and port based on engine
66        let (image, port, env_vars) = match engine {
67            "postgres" => {
68                let major_version = engine_version.split('.').next().unwrap_or("16");
69                let image = format!("postgres:{}-alpine", major_version);
70                let env_vars = vec![
71                    format!("POSTGRES_USER={username}"),
72                    format!("POSTGRES_PASSWORD={password}"),
73                    format!("POSTGRES_DB={db_name}"),
74                ];
75                (image, "5432", env_vars)
76            }
77            "mysql" => {
78                let major_version = if engine_version.starts_with("5.7") {
79                    "5.7"
80                } else {
81                    "8.0"
82                };
83                let image = format!("mysql:{}", major_version);
84                let env_vars = vec![
85                    format!("MYSQL_ROOT_PASSWORD={password}"),
86                    format!("MYSQL_USER={username}"),
87                    format!("MYSQL_PASSWORD={password}"),
88                    format!("MYSQL_DATABASE={db_name}"),
89                ];
90                (image, "3306", env_vars)
91            }
92            "mariadb" => {
93                let major_version = if engine_version.starts_with("10.11") {
94                    "10.11"
95                } else {
96                    "10.6"
97                };
98                let image = format!("mariadb:{}", major_version);
99                let env_vars = vec![
100                    format!("MARIADB_ROOT_PASSWORD={password}"),
101                    format!("MARIADB_USER={username}"),
102                    format!("MARIADB_PASSWORD={password}"),
103                    format!("MARIADB_DATABASE={db_name}"),
104                ];
105                (image, "3306", env_vars)
106            }
107            _ => {
108                return Err(RuntimeError::ContainerStartFailed(format!(
109                    "Unsupported engine: {}",
110                    engine
111                )))
112            }
113        };
114
115        // Build container create args
116        let mut args = vec![
117            "create".to_string(),
118            "-p".to_string(),
119            format!(":{}", port),
120            "--label".to_string(),
121            format!("fakecloud-rds={db_instance_identifier}"),
122            "--label".to_string(),
123            format!("fakecloud-instance={}", self.instance_id),
124        ];
125
126        for env_var in env_vars {
127            args.push("-e".to_string());
128            args.push(env_var);
129        }
130
131        args.push(image);
132
133        let output = tokio::process::Command::new(&self.cli)
134            .args(&args)
135            .output()
136            .await
137            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
138
139        if !output.status.success() {
140            return Err(RuntimeError::ContainerStartFailed(
141                String::from_utf8_lossy(&output.stderr).trim().to_string(),
142            ));
143        }
144
145        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
146        let start_result = tokio::process::Command::new(&self.cli)
147            .args(["start", &container_id])
148            .output()
149            .await
150            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
151
152        if !start_result.status.success() {
153            self.remove_container(&container_id).await;
154            return Err(RuntimeError::ContainerStartFailed(format!(
155                "container start failed: {}",
156                String::from_utf8_lossy(&start_result.stderr).trim()
157            )));
158        }
159
160        let host_port = match self.lookup_port(&container_id, port).await {
161            Ok(host_port) => host_port,
162            Err(error) => {
163                self.remove_container(&container_id).await;
164                return Err(error);
165            }
166        };
167
168        // Wait for database to be ready
169        let wait_result = match engine {
170            "postgres" => {
171                self.wait_for_postgres(username, password, db_name, host_port)
172                    .await
173            }
174            "mysql" | "mariadb" => {
175                self.wait_for_mysql(username, password, db_name, host_port)
176                    .await
177            }
178            _ => unreachable!("engine already validated"),
179        };
180
181        if let Err(error) = wait_result {
182            self.remove_container(&container_id).await;
183            return Err(error);
184        }
185
186        let running = RunningDbContainer {
187            container_id,
188            host_port,
189        };
190        self.containers
191            .write()
192            .insert(db_instance_identifier.to_string(), running.clone());
193        Ok(running)
194    }
195
196    pub async fn stop_container(&self, db_instance_identifier: &str) {
197        let container = self.containers.write().remove(db_instance_identifier);
198        if let Some(container) = container {
199            self.remove_container(&container.container_id).await;
200        }
201    }
202
203    pub async fn restart_container(
204        &self,
205        db_instance_identifier: &str,
206        engine: &str,
207        username: &str,
208        password: &str,
209        db_name: &str,
210    ) -> Result<RunningDbContainer, RuntimeError> {
211        let running = self
212            .containers
213            .read()
214            .get(db_instance_identifier)
215            .cloned()
216            .ok_or(RuntimeError::Unavailable)?;
217
218        let output = tokio::process::Command::new(&self.cli)
219            .args(["restart", &running.container_id])
220            .output()
221            .await
222            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
223
224        if !output.status.success() {
225            return Err(RuntimeError::ContainerStartFailed(format!(
226                "container restart failed: {}",
227                String::from_utf8_lossy(&output.stderr).trim()
228            )));
229        }
230
231        let port = match engine {
232            "postgres" => "5432",
233            "mysql" | "mariadb" => "3306",
234            _ => "5432", // fallback
235        };
236
237        let host_port = self.lookup_port(&running.container_id, port).await?;
238
239        match engine {
240            "postgres" => {
241                self.wait_for_postgres(username, password, db_name, host_port)
242                    .await?
243            }
244            "mysql" | "mariadb" => {
245                self.wait_for_mysql(username, password, db_name, host_port)
246                    .await?
247            }
248            _ => {
249                self.wait_for_postgres(username, password, db_name, host_port)
250                    .await?
251            }
252        };
253        let running = RunningDbContainer {
254            container_id: running.container_id,
255            host_port,
256        };
257        self.containers
258            .write()
259            .insert(db_instance_identifier.to_string(), running.clone());
260        Ok(running)
261    }
262
263    pub async fn stop_all(&self) {
264        let containers: Vec<String> = {
265            let mut containers = self.containers.write();
266            containers
267                .drain()
268                .map(|(_, container)| container.container_id)
269                .collect()
270        };
271        for container_id in containers {
272            self.remove_container(&container_id).await;
273        }
274    }
275
276    async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
277        let port_output = tokio::process::Command::new(&self.cli)
278            .args(["port", container_id, port])
279            .output()
280            .await
281            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
282
283        let port_str = String::from_utf8_lossy(&port_output.stdout);
284        port_str
285            .trim()
286            .rsplit(':')
287            .next()
288            .and_then(|value| value.parse::<u16>().ok())
289            .ok_or_else(|| {
290                RuntimeError::ContainerStartFailed(format!(
291                    "could not determine container port from '{}'",
292                    port_str.trim()
293                ))
294            })
295    }
296
297    async fn wait_for_postgres(
298        &self,
299        username: &str,
300        password: &str,
301        db_name: &str,
302        host_port: u16,
303    ) -> Result<(), RuntimeError> {
304        for _ in 0..40 {
305            tokio::time::sleep(Duration::from_millis(500)).await;
306            let connection_string = format!(
307                "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
308            );
309            if let Ok((client, connection)) =
310                tokio_postgres::connect(&connection_string, NoTls).await
311            {
312                tokio::spawn(async move {
313                    let _ = connection.await;
314                });
315                if client.simple_query("SELECT 1").await.is_ok() {
316                    return Ok(());
317                }
318            }
319        }
320
321        Err(RuntimeError::ContainerStartFailed(
322            "postgres container did not become ready within 20 seconds".to_string(),
323        ))
324    }
325
326    async fn wait_for_mysql(
327        &self,
328        username: &str,
329        password: &str,
330        db_name: &str,
331        host_port: u16,
332    ) -> Result<(), RuntimeError> {
333        use mysql_async::prelude::*;
334        use mysql_async::OptsBuilder;
335
336        for attempt in 1..=40 {
337            let opts = OptsBuilder::default()
338                .ip_or_hostname("127.0.0.1")
339                .tcp_port(host_port)
340                .user(Some(username))
341                .pass(Some(password))
342                .db_name(Some(db_name));
343
344            match mysql_async::Conn::new(opts).await {
345                Ok(mut conn) => {
346                    if conn.query_drop("SELECT 1").await.is_ok() {
347                        let _ = conn.disconnect().await;
348                        return Ok(());
349                    }
350                }
351                Err(_) => {
352                    if attempt < 40 {
353                        tokio::time::sleep(Duration::from_millis(500)).await;
354                    }
355                    continue;
356                }
357            }
358        }
359
360        Err(RuntimeError::ContainerStartFailed(
361            "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
362        ))
363    }
364
365    async fn remove_container(&self, container_id: &str) {
366        let _ = tokio::process::Command::new(&self.cli)
367            .args(["rm", "-f", container_id])
368            .output()
369            .await;
370    }
371
372    pub async fn dump_database(
373        &self,
374        db_instance_identifier: &str,
375        engine: &str,
376        username: &str,
377        password: &str,
378        db_name: &str,
379    ) -> Result<Vec<u8>, RuntimeError> {
380        let container = self
381            .containers
382            .read()
383            .get(db_instance_identifier)
384            .cloned()
385            .ok_or(RuntimeError::Unavailable)?;
386
387        let args: Vec<String> = match engine {
388            "mysql" | "mariadb" => vec![
389                "exec".into(),
390                container.container_id.clone(),
391                "mysqldump".into(),
392                "-u".into(),
393                username.into(),
394                format!("-p{password}"),
395                db_name.into(),
396            ],
397            _ => vec![
398                "exec".into(),
399                container.container_id.clone(),
400                "pg_dump".into(),
401                "-U".into(),
402                username.into(),
403                "-d".into(),
404                db_name.into(),
405                "--no-password".into(),
406            ],
407        };
408
409        let output = tokio::process::Command::new(&self.cli)
410            .args(&args)
411            .output()
412            .await
413            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
414
415        if !output.status.success() {
416            return Err(RuntimeError::ContainerStartFailed(format!(
417                "dump failed: {}",
418                String::from_utf8_lossy(&output.stderr).trim()
419            )));
420        }
421
422        Ok(output.stdout)
423    }
424
425    pub async fn restore_database(
426        &self,
427        db_instance_identifier: &str,
428        engine: &str,
429        username: &str,
430        password: &str,
431        db_name: &str,
432        dump_data: &[u8],
433    ) -> Result<(), RuntimeError> {
434        let container = self
435            .containers
436            .read()
437            .get(db_instance_identifier)
438            .cloned()
439            .ok_or(RuntimeError::Unavailable)?;
440
441        let args: Vec<String> = match engine {
442            "mysql" | "mariadb" => vec![
443                "exec".into(),
444                "-i".into(),
445                container.container_id.clone(),
446                "mysql".into(),
447                "-u".into(),
448                username.into(),
449                format!("-p{password}"),
450                db_name.into(),
451            ],
452            _ => vec![
453                "exec".into(),
454                "-i".into(),
455                container.container_id.clone(),
456                "psql".into(),
457                "-U".into(),
458                username.into(),
459                "-d".into(),
460                db_name.into(),
461                "--no-password".into(),
462                "-v".into(),
463                "ON_ERROR_STOP=1".into(),
464            ],
465        };
466
467        let mut child = tokio::process::Command::new(&self.cli)
468            .args(&args)
469            .stdin(std::process::Stdio::piped())
470            .stdout(std::process::Stdio::piped())
471            .stderr(std::process::Stdio::piped())
472            .spawn()
473            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
474
475        if let Some(mut stdin) = child.stdin.take() {
476            use tokio::io::AsyncWriteExt;
477            stdin
478                .write_all(dump_data)
479                .await
480                .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
481            drop(stdin);
482        }
483
484        let output = child
485            .wait_with_output()
486            .await
487            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
488
489        if !output.status.success() {
490            return Err(RuntimeError::ContainerStartFailed(format!(
491                "restore failed: {}",
492                String::from_utf8_lossy(&output.stderr).trim()
493            )));
494        }
495
496        Ok(())
497    }
498}
499
/// Returns `true` when `<cli> info` can be launched and exits with a success status.
///
/// The probe's stdout and stderr are discarded. Failing to launch the process
/// at all counts as "not available".
fn cli_available(cli: &str) -> bool {
    use std::process::{Command, Stdio};

    let probe = Command::new(cli)
        .arg("info")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();

    matches!(probe, Ok(exit) if exit.success())
}