Skip to main content

fakecloud_rds/
runtime.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// Handle to a database container this runtime has started and is tracking.
#[derive(Clone, Debug)]
pub struct RunningDbContainer {
    /// Identifier printed by the container CLI's `create` command.
    pub container_id: String,
    /// Host-side port published for the database's container port.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14    cli: String,
15    containers: RwLock<HashMap<String, RunningDbContainer>>,
16    instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21    #[error("container runtime is unavailable")]
22    Unavailable,
23    #[error("container failed to start: {0}")]
24    ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28    pub fn new() -> Option<Self> {
29        let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30            if cli_available(&cli) {
31                cli
32            } else {
33                return None;
34            }
35        } else if cli_available("docker") {
36            "docker".to_string()
37        } else if cli_available("podman") {
38            "podman".to_string()
39        } else {
40            return None;
41        };
42
43        Some(Self {
44            cli,
45            containers: RwLock::new(HashMap::new()),
46            instance_id: format!("fakecloud-{}", std::process::id()),
47        })
48    }
49
50    pub fn cli_name(&self) -> &str {
51        &self.cli
52    }
53
54    pub async fn ensure_postgres(
55        &self,
56        db_instance_identifier: &str,
57        engine: &str,
58        engine_version: &str,
59        username: &str,
60        password: &str,
61        db_name: &str,
62    ) -> Result<RunningDbContainer, RuntimeError> {
63        self.stop_container(db_instance_identifier).await;
64
65        // Determine Docker image and port based on engine
66        let (image, port, env_vars) = match engine {
67            "postgres" => {
68                let major_version = engine_version.split('.').next().unwrap_or("16");
69                let image = format!("postgres:{}-alpine", major_version);
70                let env_vars = vec![
71                    format!("POSTGRES_USER={username}"),
72                    format!("POSTGRES_PASSWORD={password}"),
73                    format!("POSTGRES_DB={db_name}"),
74                ];
75                (image, "5432", env_vars)
76            }
77            "mysql" => {
78                let major_version = if engine_version.starts_with("5.7") {
79                    "5.7"
80                } else {
81                    "8.0"
82                };
83                let image = format!("mysql:{}", major_version);
84                let env_vars = vec![
85                    format!("MYSQL_ROOT_PASSWORD={password}"),
86                    format!("MYSQL_USER={username}"),
87                    format!("MYSQL_PASSWORD={password}"),
88                    format!("MYSQL_DATABASE={db_name}"),
89                ];
90                (image, "3306", env_vars)
91            }
92            "mariadb" => {
93                let major_version = if engine_version.starts_with("10.11") {
94                    "10.11"
95                } else {
96                    "10.6"
97                };
98                let image = format!("mariadb:{}", major_version);
99                let env_vars = vec![
100                    format!("MARIADB_ROOT_PASSWORD={password}"),
101                    format!("MARIADB_USER={username}"),
102                    format!("MARIADB_PASSWORD={password}"),
103                    format!("MARIADB_DATABASE={db_name}"),
104                ];
105                (image, "3306", env_vars)
106            }
107            _ => {
108                return Err(RuntimeError::ContainerStartFailed(format!(
109                    "Unsupported engine: {}",
110                    engine
111                )))
112            }
113        };
114
115        // Build container create args
116        let mut args = vec![
117            "create".to_string(),
118            "-p".to_string(),
119            format!(":{}", port),
120            "--label".to_string(),
121            format!("fakecloud-rds={db_instance_identifier}"),
122            "--label".to_string(),
123            format!("fakecloud-instance={}", self.instance_id),
124        ];
125
126        for env_var in env_vars {
127            args.push("-e".to_string());
128            args.push(env_var);
129        }
130
131        args.push(image);
132
133        let output = tokio::process::Command::new(&self.cli)
134            .args(&args)
135            .output()
136            .await
137            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
138
139        if !output.status.success() {
140            return Err(RuntimeError::ContainerStartFailed(
141                String::from_utf8_lossy(&output.stderr).trim().to_string(),
142            ));
143        }
144
145        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
146        let start_result = tokio::process::Command::new(&self.cli)
147            .args(["start", &container_id])
148            .output()
149            .await
150            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
151
152        if !start_result.status.success() {
153            self.remove_container(&container_id).await;
154            return Err(RuntimeError::ContainerStartFailed(format!(
155                "container start failed: {}",
156                String::from_utf8_lossy(&start_result.stderr).trim()
157            )));
158        }
159
160        let host_port = match self.lookup_port(&container_id, port).await {
161            Ok(host_port) => host_port,
162            Err(error) => {
163                self.remove_container(&container_id).await;
164                return Err(error);
165            }
166        };
167
168        // Wait for database to be ready
169        let wait_result = match engine {
170            "postgres" => {
171                self.wait_for_postgres(username, password, db_name, host_port)
172                    .await
173            }
174            "mysql" | "mariadb" => {
175                self.wait_for_mysql(username, password, db_name, host_port)
176                    .await
177            }
178            _ => unreachable!("engine already validated"),
179        };
180
181        if let Err(error) = wait_result {
182            self.remove_container(&container_id).await;
183            return Err(error);
184        }
185
186        let running = RunningDbContainer {
187            container_id,
188            host_port,
189        };
190        self.containers
191            .write()
192            .insert(db_instance_identifier.to_string(), running.clone());
193        Ok(running)
194    }
195
196    pub async fn stop_container(&self, db_instance_identifier: &str) {
197        let container = self.containers.write().remove(db_instance_identifier);
198        if let Some(container) = container {
199            self.remove_container(&container.container_id).await;
200        }
201    }
202
203    pub async fn restart_container(
204        &self,
205        db_instance_identifier: &str,
206        engine: &str,
207        username: &str,
208        password: &str,
209        db_name: &str,
210    ) -> Result<RunningDbContainer, RuntimeError> {
211        let running = self
212            .containers
213            .read()
214            .get(db_instance_identifier)
215            .cloned()
216            .ok_or(RuntimeError::Unavailable)?;
217
218        let output = tokio::process::Command::new(&self.cli)
219            .args(["restart", &running.container_id])
220            .output()
221            .await
222            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
223
224        if !output.status.success() {
225            return Err(RuntimeError::ContainerStartFailed(format!(
226                "container restart failed: {}",
227                String::from_utf8_lossy(&output.stderr).trim()
228            )));
229        }
230
231        let port = match engine {
232            "postgres" => "5432",
233            "mysql" | "mariadb" => "3306",
234            _ => "5432", // fallback
235        };
236
237        let host_port = self.lookup_port(&running.container_id, port).await?;
238
239        match engine {
240            "postgres" => {
241                self.wait_for_postgres(username, password, db_name, host_port)
242                    .await?
243            }
244            "mysql" | "mariadb" => {
245                self.wait_for_mysql(username, password, db_name, host_port)
246                    .await?
247            }
248            _ => {
249                self.wait_for_postgres(username, password, db_name, host_port)
250                    .await?
251            }
252        };
253        let running = RunningDbContainer {
254            container_id: running.container_id,
255            host_port,
256        };
257        self.containers
258            .write()
259            .insert(db_instance_identifier.to_string(), running.clone());
260        Ok(running)
261    }
262
263    pub async fn stop_all(&self) {
264        let containers: Vec<String> = {
265            let mut containers = self.containers.write();
266            containers
267                .drain()
268                .map(|(_, container)| container.container_id)
269                .collect()
270        };
271        for container_id in containers {
272            self.remove_container(&container_id).await;
273        }
274    }
275
276    async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
277        let port_output = tokio::process::Command::new(&self.cli)
278            .args(["port", container_id, port])
279            .output()
280            .await
281            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
282
283        let port_str = String::from_utf8_lossy(&port_output.stdout);
284        port_str
285            .trim()
286            .rsplit(':')
287            .next()
288            .and_then(|value| value.parse::<u16>().ok())
289            .ok_or_else(|| {
290                RuntimeError::ContainerStartFailed(format!(
291                    "could not determine container port from '{}'",
292                    port_str.trim()
293                ))
294            })
295    }
296
297    async fn wait_for_postgres(
298        &self,
299        username: &str,
300        password: &str,
301        db_name: &str,
302        host_port: u16,
303    ) -> Result<(), RuntimeError> {
304        for _ in 0..40 {
305            tokio::time::sleep(Duration::from_millis(500)).await;
306            let connection_string = format!(
307                "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
308            );
309            if let Ok((client, connection)) =
310                tokio_postgres::connect(&connection_string, NoTls).await
311            {
312                tokio::spawn(async move {
313                    let _ = connection.await;
314                });
315                if client.simple_query("SELECT 1").await.is_ok() {
316                    return Ok(());
317                }
318            }
319        }
320
321        Err(RuntimeError::ContainerStartFailed(
322            "postgres container did not become ready within 20 seconds".to_string(),
323        ))
324    }
325
326    async fn wait_for_mysql(
327        &self,
328        username: &str,
329        password: &str,
330        db_name: &str,
331        host_port: u16,
332    ) -> Result<(), RuntimeError> {
333        use mysql_async::prelude::*;
334        use mysql_async::OptsBuilder;
335
336        for attempt in 1..=40 {
337            let opts = OptsBuilder::default()
338                .ip_or_hostname("127.0.0.1")
339                .tcp_port(host_port)
340                .user(Some(username))
341                .pass(Some(password))
342                .db_name(Some(db_name));
343
344            match mysql_async::Conn::new(opts).await {
345                Ok(mut conn) => {
346                    if conn.query_drop("SELECT 1").await.is_ok() {
347                        let _ = conn.disconnect().await;
348                        return Ok(());
349                    }
350                }
351                Err(_) => {
352                    if attempt < 40 {
353                        tokio::time::sleep(Duration::from_millis(500)).await;
354                    }
355                    continue;
356                }
357            }
358        }
359
360        Err(RuntimeError::ContainerStartFailed(
361            "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
362        ))
363    }
364
365    async fn remove_container(&self, container_id: &str) {
366        let _ = tokio::process::Command::new(&self.cli)
367            .args(["rm", "-f", container_id])
368            .output()
369            .await;
370    }
371
372    pub async fn dump_database(
373        &self,
374        db_instance_identifier: &str,
375        username: &str,
376        db_name: &str,
377    ) -> Result<Vec<u8>, RuntimeError> {
378        let container = self
379            .containers
380            .read()
381            .get(db_instance_identifier)
382            .cloned()
383            .ok_or(RuntimeError::Unavailable)?;
384
385        let output = tokio::process::Command::new(&self.cli)
386            .args([
387                "exec",
388                &container.container_id,
389                "pg_dump",
390                "-U",
391                username,
392                "-d",
393                db_name,
394                "--no-password",
395            ])
396            .output()
397            .await
398            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
399
400        if !output.status.success() {
401            return Err(RuntimeError::ContainerStartFailed(format!(
402                "pg_dump failed: {}",
403                String::from_utf8_lossy(&output.stderr).trim()
404            )));
405        }
406
407        Ok(output.stdout)
408    }
409
410    pub async fn restore_database(
411        &self,
412        db_instance_identifier: &str,
413        username: &str,
414        db_name: &str,
415        dump_data: &[u8],
416    ) -> Result<(), RuntimeError> {
417        let container = self
418            .containers
419            .read()
420            .get(db_instance_identifier)
421            .cloned()
422            .ok_or(RuntimeError::Unavailable)?;
423
424        let mut child = tokio::process::Command::new(&self.cli)
425            .args([
426                "exec",
427                "-i",
428                &container.container_id,
429                "psql",
430                "-U",
431                username,
432                "-d",
433                db_name,
434                "--no-password",
435                "-v",
436                "ON_ERROR_STOP=1",
437            ])
438            .stdin(std::process::Stdio::piped())
439            .stdout(std::process::Stdio::piped())
440            .stderr(std::process::Stdio::piped())
441            .spawn()
442            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
443
444        if let Some(mut stdin) = child.stdin.take() {
445            use tokio::io::AsyncWriteExt;
446            stdin
447                .write_all(dump_data)
448                .await
449                .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
450            drop(stdin);
451        }
452
453        let output = child
454            .wait_with_output()
455            .await
456            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
457
458        if !output.status.success() {
459            return Err(RuntimeError::ContainerStartFailed(format!(
460                "psql restore failed: {}",
461                String::from_utf8_lossy(&output.stderr).trim()
462            )));
463        }
464
465        Ok(())
466    }
467}
468
/// Reports whether `<cli> info` can be spawned and exits successfully.
///
/// The command's stdout and stderr are discarded. A CLI that cannot be spawned at all
/// (e.g. not on `PATH`) counts as unavailable.
fn cli_available(cli: &str) -> bool {
    use std::process::{Command, Stdio};

    let probe = Command::new(cli)
        .arg("info")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();

    match probe {
        Ok(exit) => exit.success(),
        Err(_) => false,
    }
}
477}