Skip to main content

fakecloud_rds/
runtime.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// Handle to a database container started by [`RdsRuntime`].
///
/// This is a plain value type: cloning it does not touch the container.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RunningDbContainer {
    /// Identifier reported by the container CLI when the container was created.
    pub container_id: String,
    /// Host-side port that the container's database port is published on.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14    cli: String,
15    containers: RwLock<HashMap<String, RunningDbContainer>>,
16    instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21    #[error("container runtime is unavailable")]
22    Unavailable,
23    #[error("container failed to start: {0}")]
24    ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28    pub fn new() -> Option<Self> {
29        let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30            if cli_available(&cli) {
31                cli
32            } else {
33                return None;
34            }
35        } else if cli_available("docker") {
36            "docker".to_string()
37        } else if cli_available("podman") {
38            "podman".to_string()
39        } else {
40            return None;
41        };
42
43        Some(Self {
44            cli,
45            containers: RwLock::new(HashMap::new()),
46            instance_id: format!("fakecloud-{}", std::process::id()),
47        })
48    }
49
50    pub fn cli_name(&self) -> &str {
51        &self.cli
52    }
53
54    pub async fn ensure_postgres(
55        &self,
56        db_instance_identifier: &str,
57        engine: &str,
58        engine_version: &str,
59        username: &str,
60        password: &str,
61        db_name: &str,
62    ) -> Result<RunningDbContainer, RuntimeError> {
63        self.stop_container(db_instance_identifier).await;
64
65        // Determine Docker image and port based on engine
66        let (image, port, env_vars) = match engine {
67            "postgres" => {
68                let major_version = engine_version.split('.').next().unwrap_or("16");
69                let image = format!("postgres:{}-alpine", major_version);
70                let env_vars = vec![
71                    format!("POSTGRES_USER={username}"),
72                    format!("POSTGRES_PASSWORD={password}"),
73                    format!("POSTGRES_DB={db_name}"),
74                ];
75                (image, "5432", env_vars)
76            }
77            "mysql" => {
78                let major_version = if engine_version.starts_with("5.7") {
79                    "5.7"
80                } else {
81                    "8.0"
82                };
83                let image = format!("mysql:{}", major_version);
84                let env_vars = vec![
85                    format!("MYSQL_ROOT_PASSWORD={password}"),
86                    format!("MYSQL_USER={username}"),
87                    format!("MYSQL_PASSWORD={password}"),
88                    format!("MYSQL_DATABASE={db_name}"),
89                ];
90                (image, "3306", env_vars)
91            }
92            "mariadb" => {
93                let major_version = if engine_version.starts_with("10.11") {
94                    "10.11"
95                } else {
96                    "10.6"
97                };
98                let image = format!("mariadb:{}", major_version);
99                let env_vars = vec![
100                    format!("MARIADB_ROOT_PASSWORD={password}"),
101                    format!("MARIADB_USER={username}"),
102                    format!("MARIADB_PASSWORD={password}"),
103                    format!("MARIADB_DATABASE={db_name}"),
104                ];
105                (image, "3306", env_vars)
106            }
107            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
108                // Oracle Database Free is the no-cost dev edition shipped by
109                // Oracle. The container exposes a "FREEPDB1" pluggable
110                // database and creates the SYSTEM user with the password
111                // from ORACLE_PASSWORD.
112                let image = "gvenzl/oracle-free:23-slim".to_string();
113                let env_vars = vec![
114                    format!("ORACLE_PASSWORD={password}"),
115                    format!("APP_USER={username}"),
116                    format!("APP_USER_PASSWORD={password}"),
117                    format!("ORACLE_DATABASE={db_name}"),
118                ];
119                (image, "1521", env_vars)
120            }
121            "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
122                // SQL Server Express is free for dev/test with no license
123                // ceiling. SA password must satisfy MSSQL's complexity
124                // requirements (>=8 chars, mix of classes); callers should
125                // supply one that does or the container will refuse to
126                // start.
127                let image = "mcr.microsoft.com/mssql/server:2022-latest".to_string();
128                let env_vars = vec![
129                    "ACCEPT_EULA=Y".to_string(),
130                    format!("MSSQL_SA_PASSWORD={password}"),
131                    "MSSQL_PID=Express".to_string(),
132                ];
133                (image, "1433", env_vars)
134            }
135            "db2-se" | "db2-ae" => {
136                // Db2 Community Edition is free under the standard IBM
137                // Community License. The container exposes a single
138                // database named after DBNAME, owned by db2inst1.
139                let image = "icr.io/db2_community/db2:latest".to_string();
140                let env_vars = vec![
141                    "LICENSE=accept".to_string(),
142                    "DB2INSTANCE=db2inst1".to_string(),
143                    format!("DB2INST1_PASSWORD={password}"),
144                    format!("DBNAME={db_name}"),
145                ];
146                (image, "50000", env_vars)
147            }
148            _ => {
149                return Err(RuntimeError::ContainerStartFailed(format!(
150                    "Unsupported engine: {}",
151                    engine
152                )))
153            }
154        };
155
156        // Db2 needs --privileged to set kernel parameters during startup.
157        let needs_privileged = matches!(engine, "db2-se" | "db2-ae");
158
159        // Build container create args
160        let mut args = vec![
161            "create".to_string(),
162            "-p".to_string(),
163            format!(":{}", port),
164            "--label".to_string(),
165            format!("fakecloud-rds={db_instance_identifier}"),
166            "--label".to_string(),
167            format!("fakecloud-instance={}", self.instance_id),
168        ];
169
170        if needs_privileged {
171            args.push("--privileged".to_string());
172        }
173
174        for env_var in env_vars {
175            args.push("-e".to_string());
176            args.push(env_var);
177        }
178
179        args.push(image);
180
181        let output = tokio::process::Command::new(&self.cli)
182            .args(&args)
183            .output()
184            .await
185            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
186
187        if !output.status.success() {
188            return Err(RuntimeError::ContainerStartFailed(
189                String::from_utf8_lossy(&output.stderr).trim().to_string(),
190            ));
191        }
192
193        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
194        let start_result = tokio::process::Command::new(&self.cli)
195            .args(["start", &container_id])
196            .output()
197            .await
198            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
199
200        if !start_result.status.success() {
201            self.remove_container(&container_id).await;
202            return Err(RuntimeError::ContainerStartFailed(format!(
203                "container start failed: {}",
204                String::from_utf8_lossy(&start_result.stderr).trim()
205            )));
206        }
207
208        let host_port = match self.lookup_port(&container_id, port).await {
209            Ok(host_port) => host_port,
210            Err(error) => {
211                self.remove_container(&container_id).await;
212                return Err(error);
213            }
214        };
215
216        // Wait for database to be ready
217        let wait_result = match engine {
218            "postgres" => {
219                self.wait_for_postgres(username, password, db_name, host_port)
220                    .await
221            }
222            "mysql" | "mariadb" => {
223                self.wait_for_mysql(username, password, db_name, host_port)
224                    .await
225            }
226            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
227                self.wait_for_oracle(&container_id, host_port).await
228            }
229            "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
230                self.wait_for_sqlserver(&container_id, host_port).await
231            }
232            "db2-se" | "db2-ae" => self.wait_for_db2(&container_id, host_port).await,
233            _ => unreachable!("engine already validated"),
234        };
235
236        if let Err(error) = wait_result {
237            self.remove_container(&container_id).await;
238            return Err(error);
239        }
240
241        let running = RunningDbContainer {
242            container_id,
243            host_port,
244        };
245        self.containers
246            .write()
247            .insert(db_instance_identifier.to_string(), running.clone());
248        Ok(running)
249    }
250
251    pub async fn stop_container(&self, db_instance_identifier: &str) {
252        let container = self.containers.write().remove(db_instance_identifier);
253        if let Some(container) = container {
254            self.remove_container(&container.container_id).await;
255        }
256    }
257
258    pub async fn restart_container(
259        &self,
260        db_instance_identifier: &str,
261        engine: &str,
262        username: &str,
263        password: &str,
264        db_name: &str,
265    ) -> Result<RunningDbContainer, RuntimeError> {
266        let running = self
267            .containers
268            .read()
269            .get(db_instance_identifier)
270            .cloned()
271            .ok_or(RuntimeError::Unavailable)?;
272
273        let output = tokio::process::Command::new(&self.cli)
274            .args(["restart", &running.container_id])
275            .output()
276            .await
277            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
278
279        if !output.status.success() {
280            return Err(RuntimeError::ContainerStartFailed(format!(
281                "container restart failed: {}",
282                String::from_utf8_lossy(&output.stderr).trim()
283            )));
284        }
285
286        let port = match engine {
287            "postgres" => "5432",
288            "mysql" | "mariadb" => "3306",
289            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => "1521",
290            "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => "1433",
291            "db2-se" | "db2-ae" => "50000",
292            _ => "5432", // fallback
293        };
294
295        let host_port = self.lookup_port(&running.container_id, port).await?;
296
297        match engine {
298            "postgres" => {
299                self.wait_for_postgres(username, password, db_name, host_port)
300                    .await?
301            }
302            "mysql" | "mariadb" => {
303                self.wait_for_mysql(username, password, db_name, host_port)
304                    .await?
305            }
306            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
307                self.wait_for_oracle(&running.container_id, host_port)
308                    .await?
309            }
310            "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
311                self.wait_for_sqlserver(&running.container_id, host_port)
312                    .await?
313            }
314            "db2-se" | "db2-ae" => self.wait_for_db2(&running.container_id, host_port).await?,
315            _ => {
316                self.wait_for_postgres(username, password, db_name, host_port)
317                    .await?
318            }
319        };
320        let running = RunningDbContainer {
321            container_id: running.container_id,
322            host_port,
323        };
324        self.containers
325            .write()
326            .insert(db_instance_identifier.to_string(), running.clone());
327        Ok(running)
328    }
329
330    pub async fn stop_all(&self) {
331        let containers: Vec<String> = {
332            let mut containers = self.containers.write();
333            containers
334                .drain()
335                .map(|(_, container)| container.container_id)
336                .collect()
337        };
338        for container_id in containers {
339            self.remove_container(&container_id).await;
340        }
341    }
342
343    async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
344        let port_output = tokio::process::Command::new(&self.cli)
345            .args(["port", container_id, port])
346            .output()
347            .await
348            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
349
350        let port_str = String::from_utf8_lossy(&port_output.stdout);
351        port_str
352            .trim()
353            .rsplit(':')
354            .next()
355            .and_then(|value| value.parse::<u16>().ok())
356            .ok_or_else(|| {
357                RuntimeError::ContainerStartFailed(format!(
358                    "could not determine container port from '{}'",
359                    port_str.trim()
360                ))
361            })
362    }
363
364    async fn wait_for_postgres(
365        &self,
366        username: &str,
367        password: &str,
368        db_name: &str,
369        host_port: u16,
370    ) -> Result<(), RuntimeError> {
371        for _ in 0..40 {
372            tokio::time::sleep(Duration::from_millis(500)).await;
373            let connection_string = format!(
374                "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
375            );
376            if let Ok((client, connection)) =
377                tokio_postgres::connect(&connection_string, NoTls).await
378            {
379                tokio::spawn(async move {
380                    let _ = connection.await;
381                });
382                if client.simple_query("SELECT 1").await.is_ok() {
383                    return Ok(());
384                }
385            }
386        }
387
388        Err(RuntimeError::ContainerStartFailed(
389            "postgres container did not become ready within 20 seconds".to_string(),
390        ))
391    }
392
393    async fn wait_for_mysql(
394        &self,
395        username: &str,
396        password: &str,
397        db_name: &str,
398        host_port: u16,
399    ) -> Result<(), RuntimeError> {
400        use mysql_async::prelude::*;
401        use mysql_async::OptsBuilder;
402
403        for attempt in 1..=40 {
404            let opts = OptsBuilder::default()
405                .ip_or_hostname("127.0.0.1")
406                .tcp_port(host_port)
407                .user(Some(username))
408                .pass(Some(password))
409                .db_name(Some(db_name));
410
411            match mysql_async::Conn::new(opts).await {
412                Ok(mut conn) => {
413                    if conn.query_drop("SELECT 1").await.is_ok() {
414                        let _ = conn.disconnect().await;
415                        return Ok(());
416                    }
417                }
418                Err(_) => {
419                    if attempt < 40 {
420                        tokio::time::sleep(Duration::from_millis(500)).await;
421                    }
422                    continue;
423                }
424            }
425        }
426
427        Err(RuntimeError::ContainerStartFailed(
428            "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
429        ))
430    }
431
432    /// Wait for Oracle Database Free to finish bootstrapping. The
433    /// `gvenzl/oracle-free` image prints `DATABASE IS READY TO USE!`
434    /// to stdout once the listener accepts connections, so we poll
435    /// `docker logs` until that marker appears (or the deadline elapses).
436    /// Oracle XE/Free typically takes 30-90 seconds on first start.
437    async fn wait_for_oracle(
438        &self,
439        container_id: &str,
440        host_port: u16,
441    ) -> Result<(), RuntimeError> {
442        self.wait_for_log_marker(container_id, "DATABASE IS READY TO USE!", 240)
443            .await?;
444        self.wait_for_tcp(host_port, 30).await
445    }
446
447    /// Wait for SQL Server to be ready. The official mssql/server image
448    /// emits `SQL Server is now ready for client connections.` once it
449    /// accepts TCP connections on 1433.
450    async fn wait_for_sqlserver(
451        &self,
452        container_id: &str,
453        host_port: u16,
454    ) -> Result<(), RuntimeError> {
455        self.wait_for_log_marker(
456            container_id,
457            "SQL Server is now ready for client connections",
458            180,
459        )
460        .await?;
461        self.wait_for_tcp(host_port, 30).await
462    }
463
464    /// Wait for Db2 Community Edition to finish setup. The
465    /// `icr.io/db2_community/db2` image prints `Setup has completed.`
466    /// once the instance is up and the database has been created.
467    async fn wait_for_db2(&self, container_id: &str, host_port: u16) -> Result<(), RuntimeError> {
468        self.wait_for_log_marker(container_id, "Setup has completed", 360)
469            .await?;
470        self.wait_for_tcp(host_port, 60).await
471    }
472
473    /// Poll `docker logs <container>` until the supplied marker appears
474    /// in stdout or stderr. `deadline_secs` caps total wait.
475    async fn wait_for_log_marker(
476        &self,
477        container_id: &str,
478        marker: &str,
479        deadline_secs: u64,
480    ) -> Result<(), RuntimeError> {
481        let deadline = std::time::Instant::now() + Duration::from_secs(deadline_secs);
482        while std::time::Instant::now() < deadline {
483            let output = tokio::process::Command::new(&self.cli)
484                .args(["logs", container_id])
485                .output()
486                .await
487                .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
488            let stdout = String::from_utf8_lossy(&output.stdout);
489            let stderr = String::from_utf8_lossy(&output.stderr);
490            if stdout.contains(marker) || stderr.contains(marker) {
491                return Ok(());
492            }
493            tokio::time::sleep(Duration::from_secs(2)).await;
494        }
495        Err(RuntimeError::ContainerStartFailed(format!(
496            "container did not log '{}' within {} seconds",
497            marker, deadline_secs
498        )))
499    }
500
501    /// TCP-probe the host port until it accepts a connection or the
502    /// deadline elapses. Use after `wait_for_log_marker` since the
503    /// listener may bind a moment after the readiness log line.
504    async fn wait_for_tcp(&self, host_port: u16, deadline_secs: u64) -> Result<(), RuntimeError> {
505        let deadline = std::time::Instant::now() + Duration::from_secs(deadline_secs);
506        while std::time::Instant::now() < deadline {
507            if tokio::net::TcpStream::connect(("127.0.0.1", host_port))
508                .await
509                .is_ok()
510            {
511                return Ok(());
512            }
513            tokio::time::sleep(Duration::from_millis(500)).await;
514        }
515        Err(RuntimeError::ContainerStartFailed(format!(
516            "TCP probe to 127.0.0.1:{} did not succeed within {}s",
517            host_port, deadline_secs
518        )))
519    }
520
521    async fn remove_container(&self, container_id: &str) {
522        let _ = tokio::process::Command::new(&self.cli)
523            .args(["rm", "-f", container_id])
524            .output()
525            .await;
526    }
527
528    pub async fn dump_database(
529        &self,
530        db_instance_identifier: &str,
531        engine: &str,
532        username: &str,
533        password: &str,
534        db_name: &str,
535    ) -> Result<Vec<u8>, RuntimeError> {
536        let container = self
537            .containers
538            .read()
539            .get(db_instance_identifier)
540            .cloned()
541            .ok_or(RuntimeError::Unavailable)?;
542
543        let args: Vec<String> = match engine {
544            "mysql" | "mariadb" => vec![
545                "exec".into(),
546                container.container_id.clone(),
547                "mysqldump".into(),
548                "-u".into(),
549                username.into(),
550                format!("-p{password}"),
551                db_name.into(),
552            ],
553            "postgres" => vec![
554                "exec".into(),
555                container.container_id.clone(),
556                "pg_dump".into(),
557                "-U".into(),
558                username.into(),
559                "-d".into(),
560                db_name.into(),
561                "--no-password".into(),
562            ],
563            // Heavy engines don't ship with a portable dump CLI we can
564            // shell out to from the host the same way pg_dump and
565            // mysqldump are guaranteed available. Surface a clear
566            // error so callers (snapshot/read-replica) don't silently
567            // run the wrong tool against an Oracle/SQL Server/Db2
568            // container.
569            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" | "sqlserver-ee"
570            | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" | "db2-se" | "db2-ae" => {
571                return Err(RuntimeError::ContainerStartFailed(format!(
572                    "engine {engine} is not yet supported by the snapshot/read-replica path; \
573                     emulator stores the API state but cannot dump the underlying database"
574                )));
575            }
576            other => {
577                return Err(RuntimeError::ContainerStartFailed(format!(
578                    "engine {other} is not supported by dump_database"
579                )));
580            }
581        };
582
583        let output = tokio::process::Command::new(&self.cli)
584            .args(&args)
585            .output()
586            .await
587            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
588
589        if !output.status.success() {
590            return Err(RuntimeError::ContainerStartFailed(format!(
591                "dump failed: {}",
592                String::from_utf8_lossy(&output.stderr).trim()
593            )));
594        }
595
596        Ok(output.stdout)
597    }
598
599    pub async fn restore_database(
600        &self,
601        db_instance_identifier: &str,
602        engine: &str,
603        username: &str,
604        password: &str,
605        db_name: &str,
606        dump_data: &[u8],
607    ) -> Result<(), RuntimeError> {
608        let container = self
609            .containers
610            .read()
611            .get(db_instance_identifier)
612            .cloned()
613            .ok_or(RuntimeError::Unavailable)?;
614
615        let args: Vec<String> = match engine {
616            "mysql" | "mariadb" => vec![
617                "exec".into(),
618                "-i".into(),
619                container.container_id.clone(),
620                "mysql".into(),
621                "-u".into(),
622                username.into(),
623                format!("-p{password}"),
624                db_name.into(),
625            ],
626            "postgres" => vec![
627                "exec".into(),
628                "-i".into(),
629                container.container_id.clone(),
630                "psql".into(),
631                "-U".into(),
632                username.into(),
633                "-d".into(),
634                db_name.into(),
635                "--no-password".into(),
636                "-v".into(),
637                "ON_ERROR_STOP=1".into(),
638            ],
639            "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" | "sqlserver-ee"
640            | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" | "db2-se" | "db2-ae" => {
641                return Err(RuntimeError::ContainerStartFailed(format!(
642                    "engine {engine} is not yet supported by the snapshot-restore path"
643                )));
644            }
645            other => {
646                return Err(RuntimeError::ContainerStartFailed(format!(
647                    "engine {other} is not supported by restore_database"
648                )));
649            }
650        };
651
652        let mut child = tokio::process::Command::new(&self.cli)
653            .args(&args)
654            .stdin(std::process::Stdio::piped())
655            .stdout(std::process::Stdio::piped())
656            .stderr(std::process::Stdio::piped())
657            .spawn()
658            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
659
660        if let Some(mut stdin) = child.stdin.take() {
661            use tokio::io::AsyncWriteExt;
662            stdin
663                .write_all(dump_data)
664                .await
665                .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
666            drop(stdin);
667        }
668
669        let output = child
670            .wait_with_output()
671            .await
672            .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
673
674        if !output.status.success() {
675            return Err(RuntimeError::ContainerStartFailed(format!(
676                "restore failed: {}",
677                String::from_utf8_lossy(&output.stderr).trim()
678            )));
679        }
680
681        Ok(())
682    }
683}
684
/// Reports whether `cli info` can be run and exits successfully.
///
/// The probe's stdout and stderr are discarded; a binary that cannot be
/// spawned at all counts as unavailable.
fn cli_available(cli: &str) -> bool {
    let mut probe = std::process::Command::new(cli);
    probe
        .arg("info")
        .stdout(std::process::Stdio::null())
        .stderr(std::process::Stdio::null());

    matches!(probe.status(), Ok(status) if status.success())
}