1use std::collections::HashMap;
2use std::time::Duration;
3
4use parking_lot::RwLock;
5use tokio_postgres::NoTls;
6
/// A database container started for an emulated RDS instance.
#[derive(Clone, Debug)]
pub struct RunningDbContainer {
    /// Identifier printed by the container CLI's `create` command.
    pub container_id: String,
    /// Host port mapped to the database's port inside the container.
    pub host_port: u16,
}
12
13pub struct RdsRuntime {
14 cli: String,
15 containers: RwLock<HashMap<String, RunningDbContainer>>,
16 instance_id: String,
17}
18
19#[derive(Debug, thiserror::Error)]
20pub enum RuntimeError {
21 #[error("container runtime is unavailable")]
22 Unavailable,
23 #[error("container failed to start: {0}")]
24 ContainerStartFailed(String),
25}
26
27impl RdsRuntime {
28 pub fn new() -> Option<Self> {
29 let cli = if let Ok(cli) = std::env::var("FAKECLOUD_CONTAINER_CLI") {
30 if cli_available(&cli) {
31 cli
32 } else {
33 return None;
34 }
35 } else if cli_available("docker") {
36 "docker".to_string()
37 } else if cli_available("podman") {
38 "podman".to_string()
39 } else {
40 return None;
41 };
42
43 Some(Self {
44 cli,
45 containers: RwLock::new(HashMap::new()),
46 instance_id: format!("fakecloud-{}", std::process::id()),
47 })
48 }
49
50 pub fn cli_name(&self) -> &str {
51 &self.cli
52 }
53
54 pub async fn ensure_postgres(
55 &self,
56 db_instance_identifier: &str,
57 engine: &str,
58 engine_version: &str,
59 username: &str,
60 password: &str,
61 db_name: &str,
62 ) -> Result<RunningDbContainer, RuntimeError> {
63 self.stop_container(db_instance_identifier).await;
64
65 let (image, port, env_vars) = match engine {
67 "postgres" => {
68 let major_version = engine_version.split('.').next().unwrap_or("16");
69 let image = format!("postgres:{}-alpine", major_version);
70 let env_vars = vec![
71 format!("POSTGRES_USER={username}"),
72 format!("POSTGRES_PASSWORD={password}"),
73 format!("POSTGRES_DB={db_name}"),
74 ];
75 (image, "5432", env_vars)
76 }
77 "mysql" => {
78 let major_version = if engine_version.starts_with("5.7") {
79 "5.7"
80 } else {
81 "8.0"
82 };
83 let image = format!("mysql:{}", major_version);
84 let env_vars = vec![
85 format!("MYSQL_ROOT_PASSWORD={password}"),
86 format!("MYSQL_USER={username}"),
87 format!("MYSQL_PASSWORD={password}"),
88 format!("MYSQL_DATABASE={db_name}"),
89 ];
90 (image, "3306", env_vars)
91 }
92 "mariadb" => {
93 let major_version = if engine_version.starts_with("10.11") {
94 "10.11"
95 } else {
96 "10.6"
97 };
98 let image = format!("mariadb:{}", major_version);
99 let env_vars = vec![
100 format!("MARIADB_ROOT_PASSWORD={password}"),
101 format!("MARIADB_USER={username}"),
102 format!("MARIADB_PASSWORD={password}"),
103 format!("MARIADB_DATABASE={db_name}"),
104 ];
105 (image, "3306", env_vars)
106 }
107 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
108 let image = "gvenzl/oracle-free:23-slim".to_string();
113 let env_vars = vec![
114 format!("ORACLE_PASSWORD={password}"),
115 format!("APP_USER={username}"),
116 format!("APP_USER_PASSWORD={password}"),
117 format!("ORACLE_DATABASE={db_name}"),
118 ];
119 (image, "1521", env_vars)
120 }
121 "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
122 let image = "mcr.microsoft.com/mssql/server:2022-latest".to_string();
128 let env_vars = vec![
129 "ACCEPT_EULA=Y".to_string(),
130 format!("MSSQL_SA_PASSWORD={password}"),
131 "MSSQL_PID=Express".to_string(),
132 ];
133 (image, "1433", env_vars)
134 }
135 "db2-se" | "db2-ae" => {
136 let image = "icr.io/db2_community/db2:latest".to_string();
140 let env_vars = vec![
141 "LICENSE=accept".to_string(),
142 "DB2INSTANCE=db2inst1".to_string(),
143 format!("DB2INST1_PASSWORD={password}"),
144 format!("DBNAME={db_name}"),
145 ];
146 (image, "50000", env_vars)
147 }
148 _ => {
149 return Err(RuntimeError::ContainerStartFailed(format!(
150 "Unsupported engine: {}",
151 engine
152 )))
153 }
154 };
155
156 let needs_privileged = matches!(engine, "db2-se" | "db2-ae");
158
159 let mut args = vec![
161 "create".to_string(),
162 "-p".to_string(),
163 format!(":{}", port),
164 "--label".to_string(),
165 format!("fakecloud-rds={db_instance_identifier}"),
166 "--label".to_string(),
167 format!("fakecloud-instance={}", self.instance_id),
168 ];
169
170 if needs_privileged {
171 args.push("--privileged".to_string());
172 }
173
174 for env_var in env_vars {
175 args.push("-e".to_string());
176 args.push(env_var);
177 }
178
179 args.push(image);
180
181 let output = tokio::process::Command::new(&self.cli)
182 .args(&args)
183 .output()
184 .await
185 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
186
187 if !output.status.success() {
188 return Err(RuntimeError::ContainerStartFailed(
189 String::from_utf8_lossy(&output.stderr).trim().to_string(),
190 ));
191 }
192
193 let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
194 let start_result = tokio::process::Command::new(&self.cli)
195 .args(["start", &container_id])
196 .output()
197 .await
198 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
199
200 if !start_result.status.success() {
201 self.remove_container(&container_id).await;
202 return Err(RuntimeError::ContainerStartFailed(format!(
203 "container start failed: {}",
204 String::from_utf8_lossy(&start_result.stderr).trim()
205 )));
206 }
207
208 let host_port = match self.lookup_port(&container_id, port).await {
209 Ok(host_port) => host_port,
210 Err(error) => {
211 self.remove_container(&container_id).await;
212 return Err(error);
213 }
214 };
215
216 let wait_result = match engine {
218 "postgres" => {
219 self.wait_for_postgres(username, password, db_name, host_port)
220 .await
221 }
222 "mysql" | "mariadb" => {
223 self.wait_for_mysql(username, password, db_name, host_port)
224 .await
225 }
226 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
227 self.wait_for_oracle(&container_id, host_port).await
228 }
229 "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
230 self.wait_for_sqlserver(&container_id, host_port).await
231 }
232 "db2-se" | "db2-ae" => self.wait_for_db2(&container_id, host_port).await,
233 _ => unreachable!("engine already validated"),
234 };
235
236 if let Err(error) = wait_result {
237 self.remove_container(&container_id).await;
238 return Err(error);
239 }
240
241 let running = RunningDbContainer {
242 container_id,
243 host_port,
244 };
245 self.containers
246 .write()
247 .insert(db_instance_identifier.to_string(), running.clone());
248 Ok(running)
249 }
250
251 pub async fn stop_container(&self, db_instance_identifier: &str) {
252 let container = self.containers.write().remove(db_instance_identifier);
253 if let Some(container) = container {
254 self.remove_container(&container.container_id).await;
255 }
256 }
257
258 pub async fn restart_container(
259 &self,
260 db_instance_identifier: &str,
261 engine: &str,
262 username: &str,
263 password: &str,
264 db_name: &str,
265 ) -> Result<RunningDbContainer, RuntimeError> {
266 let running = self
267 .containers
268 .read()
269 .get(db_instance_identifier)
270 .cloned()
271 .ok_or(RuntimeError::Unavailable)?;
272
273 let output = tokio::process::Command::new(&self.cli)
274 .args(["restart", &running.container_id])
275 .output()
276 .await
277 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
278
279 if !output.status.success() {
280 return Err(RuntimeError::ContainerStartFailed(format!(
281 "container restart failed: {}",
282 String::from_utf8_lossy(&output.stderr).trim()
283 )));
284 }
285
286 let port = match engine {
287 "postgres" => "5432",
288 "mysql" | "mariadb" => "3306",
289 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => "1521",
290 "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => "1433",
291 "db2-se" | "db2-ae" => "50000",
292 _ => "5432", };
294
295 let host_port = self.lookup_port(&running.container_id, port).await?;
296
297 match engine {
298 "postgres" => {
299 self.wait_for_postgres(username, password, db_name, host_port)
300 .await?
301 }
302 "mysql" | "mariadb" => {
303 self.wait_for_mysql(username, password, db_name, host_port)
304 .await?
305 }
306 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" => {
307 self.wait_for_oracle(&running.container_id, host_port)
308 .await?
309 }
310 "sqlserver-ee" | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" => {
311 self.wait_for_sqlserver(&running.container_id, host_port)
312 .await?
313 }
314 "db2-se" | "db2-ae" => self.wait_for_db2(&running.container_id, host_port).await?,
315 _ => {
316 self.wait_for_postgres(username, password, db_name, host_port)
317 .await?
318 }
319 };
320 let running = RunningDbContainer {
321 container_id: running.container_id,
322 host_port,
323 };
324 self.containers
325 .write()
326 .insert(db_instance_identifier.to_string(), running.clone());
327 Ok(running)
328 }
329
330 pub async fn stop_all(&self) {
331 let containers: Vec<String> = {
332 let mut containers = self.containers.write();
333 containers
334 .drain()
335 .map(|(_, container)| container.container_id)
336 .collect()
337 };
338 for container_id in containers {
339 self.remove_container(&container_id).await;
340 }
341 }
342
343 async fn lookup_port(&self, container_id: &str, port: &str) -> Result<u16, RuntimeError> {
344 let port_output = tokio::process::Command::new(&self.cli)
345 .args(["port", container_id, port])
346 .output()
347 .await
348 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
349
350 let port_str = String::from_utf8_lossy(&port_output.stdout);
351 port_str
352 .trim()
353 .rsplit(':')
354 .next()
355 .and_then(|value| value.parse::<u16>().ok())
356 .ok_or_else(|| {
357 RuntimeError::ContainerStartFailed(format!(
358 "could not determine container port from '{}'",
359 port_str.trim()
360 ))
361 })
362 }
363
364 async fn wait_for_postgres(
365 &self,
366 username: &str,
367 password: &str,
368 db_name: &str,
369 host_port: u16,
370 ) -> Result<(), RuntimeError> {
371 for _ in 0..40 {
372 tokio::time::sleep(Duration::from_millis(500)).await;
373 let connection_string = format!(
374 "host=127.0.0.1 port={host_port} user={username} password={password} dbname={db_name}"
375 );
376 if let Ok((client, connection)) =
377 tokio_postgres::connect(&connection_string, NoTls).await
378 {
379 tokio::spawn(async move {
380 let _ = connection.await;
381 });
382 if client.simple_query("SELECT 1").await.is_ok() {
383 return Ok(());
384 }
385 }
386 }
387
388 Err(RuntimeError::ContainerStartFailed(
389 "postgres container did not become ready within 20 seconds".to_string(),
390 ))
391 }
392
393 async fn wait_for_mysql(
394 &self,
395 username: &str,
396 password: &str,
397 db_name: &str,
398 host_port: u16,
399 ) -> Result<(), RuntimeError> {
400 use mysql_async::prelude::*;
401 use mysql_async::OptsBuilder;
402
403 for attempt in 1..=40 {
404 let opts = OptsBuilder::default()
405 .ip_or_hostname("127.0.0.1")
406 .tcp_port(host_port)
407 .user(Some(username))
408 .pass(Some(password))
409 .db_name(Some(db_name));
410
411 match mysql_async::Conn::new(opts).await {
412 Ok(mut conn) => {
413 if conn.query_drop("SELECT 1").await.is_ok() {
414 let _ = conn.disconnect().await;
415 return Ok(());
416 }
417 }
418 Err(_) => {
419 if attempt < 40 {
420 tokio::time::sleep(Duration::from_millis(500)).await;
421 }
422 continue;
423 }
424 }
425 }
426
427 Err(RuntimeError::ContainerStartFailed(
428 "MySQL/MariaDB container did not become ready within 20 seconds".to_string(),
429 ))
430 }
431
432 async fn wait_for_oracle(
438 &self,
439 container_id: &str,
440 host_port: u16,
441 ) -> Result<(), RuntimeError> {
442 self.wait_for_log_marker(container_id, "DATABASE IS READY TO USE!", 240)
443 .await?;
444 self.wait_for_tcp(host_port, 30).await
445 }
446
447 async fn wait_for_sqlserver(
451 &self,
452 container_id: &str,
453 host_port: u16,
454 ) -> Result<(), RuntimeError> {
455 self.wait_for_log_marker(
456 container_id,
457 "SQL Server is now ready for client connections",
458 180,
459 )
460 .await?;
461 self.wait_for_tcp(host_port, 30).await
462 }
463
464 async fn wait_for_db2(&self, container_id: &str, host_port: u16) -> Result<(), RuntimeError> {
468 self.wait_for_log_marker(container_id, "Setup has completed", 360)
469 .await?;
470 self.wait_for_tcp(host_port, 60).await
471 }
472
473 async fn wait_for_log_marker(
476 &self,
477 container_id: &str,
478 marker: &str,
479 deadline_secs: u64,
480 ) -> Result<(), RuntimeError> {
481 let deadline = std::time::Instant::now() + Duration::from_secs(deadline_secs);
482 while std::time::Instant::now() < deadline {
483 let output = tokio::process::Command::new(&self.cli)
484 .args(["logs", container_id])
485 .output()
486 .await
487 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
488 let stdout = String::from_utf8_lossy(&output.stdout);
489 let stderr = String::from_utf8_lossy(&output.stderr);
490 if stdout.contains(marker) || stderr.contains(marker) {
491 return Ok(());
492 }
493 tokio::time::sleep(Duration::from_secs(2)).await;
494 }
495 Err(RuntimeError::ContainerStartFailed(format!(
496 "container did not log '{}' within {} seconds",
497 marker, deadline_secs
498 )))
499 }
500
501 async fn wait_for_tcp(&self, host_port: u16, deadline_secs: u64) -> Result<(), RuntimeError> {
505 let deadline = std::time::Instant::now() + Duration::from_secs(deadline_secs);
506 while std::time::Instant::now() < deadline {
507 if tokio::net::TcpStream::connect(("127.0.0.1", host_port))
508 .await
509 .is_ok()
510 {
511 return Ok(());
512 }
513 tokio::time::sleep(Duration::from_millis(500)).await;
514 }
515 Err(RuntimeError::ContainerStartFailed(format!(
516 "TCP probe to 127.0.0.1:{} did not succeed within {}s",
517 host_port, deadline_secs
518 )))
519 }
520
521 async fn remove_container(&self, container_id: &str) {
522 let _ = tokio::process::Command::new(&self.cli)
523 .args(["rm", "-f", container_id])
524 .output()
525 .await;
526 }
527
528 pub async fn dump_database(
529 &self,
530 db_instance_identifier: &str,
531 engine: &str,
532 username: &str,
533 password: &str,
534 db_name: &str,
535 ) -> Result<Vec<u8>, RuntimeError> {
536 let container = self
537 .containers
538 .read()
539 .get(db_instance_identifier)
540 .cloned()
541 .ok_or(RuntimeError::Unavailable)?;
542
543 let args: Vec<String> = match engine {
544 "mysql" | "mariadb" => vec![
545 "exec".into(),
546 container.container_id.clone(),
547 "mysqldump".into(),
548 "-u".into(),
549 username.into(),
550 format!("-p{password}"),
551 db_name.into(),
552 ],
553 "postgres" => vec![
554 "exec".into(),
555 container.container_id.clone(),
556 "pg_dump".into(),
557 "-U".into(),
558 username.into(),
559 "-d".into(),
560 db_name.into(),
561 "--no-password".into(),
562 ],
563 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" | "sqlserver-ee"
570 | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" | "db2-se" | "db2-ae" => {
571 return Err(RuntimeError::ContainerStartFailed(format!(
572 "engine {engine} is not yet supported by the snapshot/read-replica path; \
573 emulator stores the API state but cannot dump the underlying database"
574 )));
575 }
576 other => {
577 return Err(RuntimeError::ContainerStartFailed(format!(
578 "engine {other} is not supported by dump_database"
579 )));
580 }
581 };
582
583 let output = tokio::process::Command::new(&self.cli)
584 .args(&args)
585 .output()
586 .await
587 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
588
589 if !output.status.success() {
590 return Err(RuntimeError::ContainerStartFailed(format!(
591 "dump failed: {}",
592 String::from_utf8_lossy(&output.stderr).trim()
593 )));
594 }
595
596 Ok(output.stdout)
597 }
598
599 pub async fn restore_database(
600 &self,
601 db_instance_identifier: &str,
602 engine: &str,
603 username: &str,
604 password: &str,
605 db_name: &str,
606 dump_data: &[u8],
607 ) -> Result<(), RuntimeError> {
608 let container = self
609 .containers
610 .read()
611 .get(db_instance_identifier)
612 .cloned()
613 .ok_or(RuntimeError::Unavailable)?;
614
615 let args: Vec<String> = match engine {
616 "mysql" | "mariadb" => vec![
617 "exec".into(),
618 "-i".into(),
619 container.container_id.clone(),
620 "mysql".into(),
621 "-u".into(),
622 username.into(),
623 format!("-p{password}"),
624 db_name.into(),
625 ],
626 "postgres" => vec![
627 "exec".into(),
628 "-i".into(),
629 container.container_id.clone(),
630 "psql".into(),
631 "-U".into(),
632 username.into(),
633 "-d".into(),
634 db_name.into(),
635 "--no-password".into(),
636 "-v".into(),
637 "ON_ERROR_STOP=1".into(),
638 ],
639 "oracle-ee" | "oracle-se2" | "oracle-ee-cdb" | "oracle-se2-cdb" | "sqlserver-ee"
640 | "sqlserver-se" | "sqlserver-ex" | "sqlserver-web" | "db2-se" | "db2-ae" => {
641 return Err(RuntimeError::ContainerStartFailed(format!(
642 "engine {engine} is not yet supported by the snapshot-restore path"
643 )));
644 }
645 other => {
646 return Err(RuntimeError::ContainerStartFailed(format!(
647 "engine {other} is not supported by restore_database"
648 )));
649 }
650 };
651
652 let mut child = tokio::process::Command::new(&self.cli)
653 .args(&args)
654 .stdin(std::process::Stdio::piped())
655 .stdout(std::process::Stdio::piped())
656 .stderr(std::process::Stdio::piped())
657 .spawn()
658 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
659
660 if let Some(mut stdin) = child.stdin.take() {
661 use tokio::io::AsyncWriteExt;
662 stdin
663 .write_all(dump_data)
664 .await
665 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
666 drop(stdin);
667 }
668
669 let output = child
670 .wait_with_output()
671 .await
672 .map_err(|e| RuntimeError::ContainerStartFailed(e.to_string()))?;
673
674 if !output.status.success() {
675 return Err(RuntimeError::ContainerStartFailed(format!(
676 "restore failed: {}",
677 String::from_utf8_lossy(&output.stderr).trim()
678 )));
679 }
680
681 Ok(())
682 }
683}
684
/// Reports whether `cli` can be launched and `<cli> info` exits successfully.
///
/// Output is discarded. A CLI that cannot be spawned at all counts as unavailable.
fn cli_available(cli: &str) -> bool {
    use std::process::{Command, Stdio};

    let probe = Command::new(cli)
        .arg("info")
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    match probe {
        Ok(exit) => exit.success(),
        Err(_) => false,
    }
}
693}