Skip to main content

fakecloud_rds/
state.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use chrono::{DateTime, Utc};
6use fakecloud_aws::arn::Arn;
7use parking_lot::RwLock;
8use uuid::Uuid;
9
/// Shared handle to the RDS emulator state.
///
/// Clones of the `Arc` all point at the same [`RdsState`]; access goes through the
/// `parking_lot::RwLock` read/write guards.
pub type SharedRdsState = Arc<RwLock<RdsState>>;
11
/// Supported DB instance classes — single source of truth.
///
/// [`default_orderable_options`] emits one orderable option per class in this
/// list for every advertised engine version, so adding a class here widens the
/// orderable catalogue automatically.
pub const SUPPORTED_INSTANCE_CLASSES: &[&str] = &[
    "db.t3.micro",
    "db.t3.small",
    "db.t3.medium",
    "db.t3.large",
    "db.t4g.micro",
    "db.t4g.small",
    "db.m5.large",
];
22
/// An emulated RDS DB instance as held in [`RdsState::instances`].
///
/// `Debug` is implemented by hand (not derived) so that `master_user_password`
/// is always printed as `"<redacted>"`.
#[derive(Clone)]
pub struct DbInstance {
    /// Caller-chosen identifier; also the key under which
    /// [`RdsState::finish_instance_creation`] stores the instance.
    pub db_instance_identifier: String,
    /// NOTE(review): presumably built with [`RdsState::db_instance_arn`] — confirm at the call site.
    pub db_instance_arn: String,
    pub db_instance_class: String,
    pub engine: String,
    pub engine_version: String,
    pub db_instance_status: String,
    pub master_username: String,
    pub db_name: Option<String>,
    pub endpoint_address: String,
    pub port: i32,
    pub allocated_storage: i32,
    pub publicly_accessible: bool,
    pub deletion_protection: bool,
    pub created_at: DateTime<Utc>,
    /// NOTE(review): presumably generated by [`RdsState::next_dbi_resource_id`] — confirm.
    pub dbi_resource_id: String,
    /// Master password, stored in plain text. Only the hand-written `Debug` impl
    /// redacts it; do not log it by any other means.
    pub master_user_password: String,
    /// NOTE(review): presumably the id of the container backing this instance,
    /// with `host_port` the host port it is reachable on — confirm against the
    /// container runtime code.
    pub container_id: String,
    pub host_port: u16,
    pub tags: Vec<RdsTag>,
    /// Identifier of the source instance when this instance is a read replica.
    pub read_replica_source_db_instance_identifier: Option<String>,
    /// Identifiers of read replicas created from this instance.
    pub read_replica_db_instance_identifiers: Vec<String>,
    pub vpc_security_group_ids: Vec<String>,
    pub db_parameter_group_name: Option<String>,
    pub backup_retention_period: i32,
    pub preferred_backup_window: String,
    pub latest_restorable_time: Option<DateTime<Utc>>,
    pub option_group_name: Option<String>,
    pub multi_az: bool,
    /// Requested modifications not yet applied; `None` when nothing is pending
    /// (assumed from the field name — TODO confirm in the modify handler).
    pub pending_modified_values: Option<PendingModifiedValues>,
}
55
/// Modifications recorded in [`DbInstance::pending_modified_values`].
///
/// NOTE(review): presumably a `Some` field marks an attribute with a queued
/// change and `None` means "unchanged" — confirm in the modify handler.
/// `Debug` is hand-written so that `master_user_password` is only ever printed
/// as `Some("<redacted>")` or `None`.
#[derive(Clone, Default)]
pub struct PendingModifiedValues {
    pub db_instance_class: Option<String>,
    pub allocated_storage: Option<i32>,
    pub backup_retention_period: Option<i32>,
    pub multi_az: Option<bool>,
    pub engine_version: Option<String>,
    pub master_user_password: Option<String>,
}
65
66impl fmt::Debug for PendingModifiedValues {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        f.debug_struct("PendingModifiedValues")
69            .field("db_instance_class", &self.db_instance_class)
70            .field("allocated_storage", &self.allocated_storage)
71            .field("backup_retention_period", &self.backup_retention_period)
72            .field("multi_az", &self.multi_az)
73            .field("engine_version", &self.engine_version)
74            .field(
75                "master_user_password",
76                &self.master_user_password.as_ref().map(|_| "<redacted>"),
77            )
78            .finish()
79    }
80}
81
82impl fmt::Debug for DbInstance {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.debug_struct("DbInstance")
85            .field("db_instance_identifier", &self.db_instance_identifier)
86            .field("db_instance_arn", &self.db_instance_arn)
87            .field("db_instance_class", &self.db_instance_class)
88            .field("engine", &self.engine)
89            .field("engine_version", &self.engine_version)
90            .field("db_instance_status", &self.db_instance_status)
91            .field("master_username", &self.master_username)
92            .field("db_name", &self.db_name)
93            .field("endpoint_address", &self.endpoint_address)
94            .field("port", &self.port)
95            .field("allocated_storage", &self.allocated_storage)
96            .field("publicly_accessible", &self.publicly_accessible)
97            .field("deletion_protection", &self.deletion_protection)
98            .field("created_at", &self.created_at)
99            .field("dbi_resource_id", &self.dbi_resource_id)
100            .field("master_user_password", &"<redacted>")
101            .field("container_id", &self.container_id)
102            .field("host_port", &self.host_port)
103            .field("tags", &self.tags)
104            .field(
105                "read_replica_source_db_instance_identifier",
106                &self.read_replica_source_db_instance_identifier,
107            )
108            .field(
109                "read_replica_db_instance_identifiers",
110                &self.read_replica_db_instance_identifiers,
111            )
112            .field("vpc_security_group_ids", &self.vpc_security_group_ids)
113            .field("db_parameter_group_name", &self.db_parameter_group_name)
114            .field("backup_retention_period", &self.backup_retention_period)
115            .field("preferred_backup_window", &self.preferred_backup_window)
116            .field("latest_restorable_time", &self.latest_restorable_time)
117            .field("option_group_name", &self.option_group_name)
118            .field("multi_az", &self.multi_az)
119            .field("pending_modified_values", &self.pending_modified_values)
120            .finish()
121    }
122}
123
/// A key/value tag; used by DB instances, snapshots, subnet groups and
/// parameter groups (each carries a `tags: Vec<RdsTag>`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RdsTag {
    pub key: String,
    pub value: String,
}
129
/// A stored DB snapshot.
///
/// `Debug` is hand-written: `master_user_password` is printed as
/// `"<redacted>"` and `dump_data` only as its byte length.
#[derive(Clone)]
pub struct DbSnapshot {
    pub db_snapshot_identifier: String,
    pub db_snapshot_arn: String,
    /// Identifier of the DB instance the snapshot was taken from.
    pub db_instance_identifier: String,
    pub snapshot_create_time: DateTime<Utc>,
    pub engine: String,
    pub engine_version: String,
    pub allocated_storage: i32,
    pub status: String,
    pub port: i32,
    pub master_username: String,
    pub db_name: Option<String>,
    pub dbi_resource_id: String,
    pub snapshot_type: String,
    /// Plain-text master password; redacted by `Debug`.
    pub master_user_password: String,
    pub tags: Vec<RdsTag>,
    /// NOTE(review): presumably a database dump captured from the source
    /// instance — confirm format (SQL text vs. binary) with the snapshot code.
    pub dump_data: Vec<u8>,
}
149
150impl fmt::Debug for DbSnapshot {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("DbSnapshot")
153            .field("db_snapshot_identifier", &self.db_snapshot_identifier)
154            .field("db_snapshot_arn", &self.db_snapshot_arn)
155            .field("db_instance_identifier", &self.db_instance_identifier)
156            .field("snapshot_create_time", &self.snapshot_create_time)
157            .field("engine", &self.engine)
158            .field("engine_version", &self.engine_version)
159            .field("allocated_storage", &self.allocated_storage)
160            .field("status", &self.status)
161            .field("port", &self.port)
162            .field("master_username", &self.master_username)
163            .field("db_name", &self.db_name)
164            .field("dbi_resource_id", &self.dbi_resource_id)
165            .field("snapshot_type", &self.snapshot_type)
166            .field("master_user_password", &"<redacted>")
167            .field("tags", &self.tags)
168            .field("dump_data", &format!("<{} bytes>", self.dump_data.len()))
169            .finish()
170    }
171}
172
/// In-memory state of the emulated RDS service for one account and region.
///
/// Deriving `Debug` is safe for secrets because the instance and snapshot
/// types redact their passwords in their own `Debug` impls.
#[derive(Debug)]
pub struct RdsState {
    /// Account id embedded in every ARN this state builds.
    pub account_id: String,
    /// Region embedded in every ARN this state builds.
    pub region: String,
    /// Created instances, keyed by `db_instance_identifier`.
    pub instances: HashMap<String, DbInstance>,
    /// Identifiers reserved by `begin_instance_creation` that have not yet been
    /// finished or cancelled; checked so two creations cannot claim one id.
    pub in_progress_instance_ids: HashSet<String>,
    /// NOTE(review): presumably keyed by snapshot identifier — confirm at insert sites.
    pub snapshots: HashMap<String, DbSnapshot>,
    /// NOTE(review): presumably keyed by subnet group name — confirm at insert sites.
    pub subnet_groups: HashMap<String, DbSubnetGroup>,
    /// Parameter groups keyed by name; seeded with the built-in `default.<family>`
    /// groups on construction and on `reset`.
    pub parameter_groups: HashMap<String, DbParameterGroup>,
}
183
/// Metadata for one advertised engine version (see `default_engine_versions`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EngineVersionInfo {
    /// Engine name, e.g. `postgres`, `mysql`, `mariadb`.
    pub engine: String,
    pub engine_version: String,
    /// Parameter-group family this version belongs to, e.g. `postgres16`.
    pub db_parameter_group_family: String,
    pub db_engine_description: String,
    pub db_engine_version_description: String,
    pub status: String,
}
193
/// One orderable (engine, engine version, instance class) combination
/// (see `default_orderable_options`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderableDbInstanceOption {
    pub engine: String,
    pub engine_version: String,
    pub db_instance_class: String,
    pub license_model: String,
    pub storage_type: String,
    /// Storage bounds — presumably in GiB, as in the RDS API; TODO confirm.
    pub min_storage_size: i32,
    pub max_storage_size: i32,
}
204
/// An emulated DB subnet group.
#[derive(Debug, Clone)]
pub struct DbSubnetGroup {
    pub db_subnet_group_name: String,
    pub db_subnet_group_arn: String,
    pub db_subnet_group_description: String,
    pub vpc_id: String,
    pub subnet_ids: Vec<String>,
    /// NOTE(review): presumably index-aligned with `subnet_ids` — confirm at the
    /// construction site.
    pub subnet_availability_zones: Vec<String>,
    pub tags: Vec<RdsTag>,
}
215
/// An emulated DB parameter group.
#[derive(Debug, Clone)]
pub struct DbParameterGroup {
    pub db_parameter_group_name: String,
    pub db_parameter_group_arn: String,
    /// Family the group applies to, e.g. `postgres16`.
    pub db_parameter_group_family: String,
    pub description: String,
    /// Parameter name → value; empty for the built-in `default.*` groups.
    pub parameters: HashMap<String, String>,
    pub tags: Vec<RdsTag>,
}
225
226impl RdsState {
227    pub fn new(account_id: &str, region: &str) -> Self {
228        Self {
229            account_id: account_id.to_string(),
230            region: region.to_string(),
231            instances: HashMap::new(),
232            in_progress_instance_ids: HashSet::new(),
233            snapshots: HashMap::new(),
234            subnet_groups: HashMap::new(),
235            parameter_groups: default_parameter_groups(account_id, region),
236        }
237    }
238
239    pub fn reset(&mut self) {
240        self.instances.clear();
241        self.in_progress_instance_ids.clear();
242        self.snapshots.clear();
243        self.subnet_groups.clear();
244        self.parameter_groups = default_parameter_groups(&self.account_id, &self.region);
245    }
246
247    pub fn db_instance_arn(&self, db_instance_identifier: &str) -> String {
248        Arn::new(
249            "rds",
250            &self.region,
251            &self.account_id,
252            &format!("db:{db_instance_identifier}"),
253        )
254        .to_string()
255    }
256
257    pub fn db_snapshot_arn(&self, db_snapshot_identifier: &str) -> String {
258        Arn::new(
259            "rds",
260            &self.region,
261            &self.account_id,
262            &format!("snapshot:{db_snapshot_identifier}"),
263        )
264        .to_string()
265    }
266
267    pub fn db_subnet_group_arn(&self, db_subnet_group_name: &str) -> String {
268        Arn::new(
269            "rds",
270            &self.region,
271            &self.account_id,
272            &format!("subgrp:{db_subnet_group_name}"),
273        )
274        .to_string()
275    }
276
277    pub fn db_parameter_group_arn(&self, db_parameter_group_name: &str) -> String {
278        Arn::new(
279            "rds",
280            &self.region,
281            &self.account_id,
282            &format!("pg:{db_parameter_group_name}"),
283        )
284        .to_string()
285    }
286
287    pub fn next_dbi_resource_id(&self) -> String {
288        format!("db-{}", Uuid::new_v4().simple())
289    }
290
291    pub fn begin_instance_creation(&mut self, db_instance_identifier: &str) -> bool {
292        if self.instances.contains_key(db_instance_identifier)
293            || self
294                .in_progress_instance_ids
295                .contains(db_instance_identifier)
296        {
297            return false;
298        }
299
300        self.in_progress_instance_ids
301            .insert(db_instance_identifier.to_string());
302        true
303    }
304
305    pub fn finish_instance_creation(&mut self, instance: DbInstance) {
306        self.in_progress_instance_ids
307            .remove(&instance.db_instance_identifier);
308        self.instances
309            .insert(instance.db_instance_identifier.clone(), instance);
310    }
311
312    pub fn cancel_instance_creation(&mut self, db_instance_identifier: &str) {
313        self.in_progress_instance_ids.remove(db_instance_identifier);
314    }
315}
316
317pub fn default_engine_versions() -> Vec<EngineVersionInfo> {
318    vec![
319        // PostgreSQL versions
320        EngineVersionInfo {
321            engine: "postgres".to_string(),
322            engine_version: "16.3".to_string(),
323            db_parameter_group_family: "postgres16".to_string(),
324            db_engine_description: "PostgreSQL".to_string(),
325            db_engine_version_description: "PostgreSQL 16.3".to_string(),
326            status: "available".to_string(),
327        },
328        EngineVersionInfo {
329            engine: "postgres".to_string(),
330            engine_version: "15.5".to_string(),
331            db_parameter_group_family: "postgres15".to_string(),
332            db_engine_description: "PostgreSQL".to_string(),
333            db_engine_version_description: "PostgreSQL 15.5".to_string(),
334            status: "available".to_string(),
335        },
336        EngineVersionInfo {
337            engine: "postgres".to_string(),
338            engine_version: "14.10".to_string(),
339            db_parameter_group_family: "postgres14".to_string(),
340            db_engine_description: "PostgreSQL".to_string(),
341            db_engine_version_description: "PostgreSQL 14.10".to_string(),
342            status: "available".to_string(),
343        },
344        EngineVersionInfo {
345            engine: "postgres".to_string(),
346            engine_version: "13.13".to_string(),
347            db_parameter_group_family: "postgres13".to_string(),
348            db_engine_description: "PostgreSQL".to_string(),
349            db_engine_version_description: "PostgreSQL 13.13".to_string(),
350            status: "available".to_string(),
351        },
352        // MySQL versions
353        EngineVersionInfo {
354            engine: "mysql".to_string(),
355            engine_version: "8.0.35".to_string(),
356            db_parameter_group_family: "mysql8.0".to_string(),
357            db_engine_description: "MySQL Community Edition".to_string(),
358            db_engine_version_description: "MySQL 8.0.35".to_string(),
359            status: "available".to_string(),
360        },
361        EngineVersionInfo {
362            engine: "mysql".to_string(),
363            engine_version: "8.0.28".to_string(),
364            db_parameter_group_family: "mysql8.0".to_string(),
365            db_engine_description: "MySQL Community Edition".to_string(),
366            db_engine_version_description: "MySQL 8.0.28".to_string(),
367            status: "available".to_string(),
368        },
369        EngineVersionInfo {
370            engine: "mysql".to_string(),
371            engine_version: "5.7.44".to_string(),
372            db_parameter_group_family: "mysql5.7".to_string(),
373            db_engine_description: "MySQL Community Edition".to_string(),
374            db_engine_version_description: "MySQL 5.7.44".to_string(),
375            status: "available".to_string(),
376        },
377        // MariaDB versions
378        EngineVersionInfo {
379            engine: "mariadb".to_string(),
380            engine_version: "10.11.6".to_string(),
381            db_parameter_group_family: "mariadb10.11".to_string(),
382            db_engine_description: "MariaDB Community Edition".to_string(),
383            db_engine_version_description: "MariaDB 10.11.6".to_string(),
384            status: "available".to_string(),
385        },
386        EngineVersionInfo {
387            engine: "mariadb".to_string(),
388            engine_version: "10.6.16".to_string(),
389            db_parameter_group_family: "mariadb10.6".to_string(),
390            db_engine_description: "MariaDB Community Edition".to_string(),
391            db_engine_version_description: "MariaDB 10.6.16".to_string(),
392            status: "available".to_string(),
393        },
394    ]
395}
396
397pub fn default_orderable_options() -> Vec<OrderableDbInstanceOption> {
398    let mut options = Vec::new();
399    let engines_and_versions = vec![
400        ("postgres", "16.3", "postgresql-license"),
401        ("postgres", "15.5", "postgresql-license"),
402        ("postgres", "14.10", "postgresql-license"),
403        ("postgres", "13.13", "postgresql-license"),
404        ("mysql", "8.0.35", "general-public-license"),
405        ("mysql", "8.0.28", "general-public-license"),
406        ("mysql", "5.7.44", "general-public-license"),
407        ("mariadb", "10.11.6", "general-public-license"),
408        ("mariadb", "10.6.16", "general-public-license"),
409    ];
410
411    for (engine, version, license) in engines_and_versions {
412        for class in SUPPORTED_INSTANCE_CLASSES {
413            options.push(OrderableDbInstanceOption {
414                engine: engine.to_string(),
415                engine_version: version.to_string(),
416                db_instance_class: class.to_string(),
417                license_model: license.to_string(),
418                storage_type: "gp2".to_string(),
419                min_storage_size: 20,
420                max_storage_size: 16384,
421            });
422        }
423    }
424
425    options
426}
427
428pub fn default_parameter_groups(
429    account_id: &str,
430    region: &str,
431) -> HashMap<String, DbParameterGroup> {
432    let mut groups = HashMap::new();
433
434    let families = vec![
435        ("postgres16", "Default parameter group for postgres16"),
436        ("postgres15", "Default parameter group for postgres15"),
437        ("postgres14", "Default parameter group for postgres14"),
438        ("postgres13", "Default parameter group for postgres13"),
439        ("mysql8.0", "Default parameter group for mysql8.0"),
440        ("mysql5.7", "Default parameter group for mysql5.7"),
441        ("mariadb10.11", "Default parameter group for mariadb10.11"),
442        ("mariadb10.6", "Default parameter group for mariadb10.6"),
443    ];
444
445    for (family, description) in families {
446        let group_name = format!("default.{}", family);
447        let group = DbParameterGroup {
448            db_parameter_group_name: group_name.clone(),
449            db_parameter_group_arn: format!("arn:aws:rds:{region}:{account_id}:pg:{group_name}"),
450            db_parameter_group_family: family.to_string(),
451            description: description.to_string(),
452            parameters: HashMap::new(),
453            tags: Vec::new(),
454        };
455        groups.insert(group_name, group);
456    }
457
458    groups
459}
460
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use chrono::Utc;

    use super::{
        default_engine_versions, default_orderable_options, default_parameter_groups, DbInstance,
        DbSubnetGroup, RdsState, SUPPORTED_INSTANCE_CLASSES,
    };

    const ACCOUNT: &str = "123456789012";
    const REGION: &str = "us-east-1";

    /// Builds a minimal, fully-populated `DbInstance` named `identifier`.
    fn sample_instance(identifier: &str) -> DbInstance {
        let created_at = Utc::now();
        DbInstance {
            db_instance_identifier: identifier.to_string(),
            db_instance_arn: format!("arn:aws:rds:{REGION}:{ACCOUNT}:db:{identifier}"),
            db_instance_class: "db.t3.micro".to_string(),
            engine: "postgres".to_string(),
            engine_version: "16.3".to_string(),
            db_instance_status: "available".to_string(),
            master_username: "admin".to_string(),
            db_name: Some("postgres".to_string()),
            endpoint_address: "127.0.0.1".to_string(),
            port: 5432,
            allocated_storage: 20,
            publicly_accessible: true,
            deletion_protection: false,
            created_at,
            dbi_resource_id: "db-test".to_string(),
            master_user_password: "secret123".to_string(),
            container_id: "container-id".to_string(),
            host_port: 15432,
            tags: Vec::new(),
            read_replica_source_db_instance_identifier: None,
            read_replica_db_instance_identifiers: Vec::new(),
            vpc_security_group_ids: Vec::new(),
            db_parameter_group_name: None,
            backup_retention_period: 1,
            preferred_backup_window: "03:00-04:00".to_string(),
            latest_restorable_time: Some(created_at),
            option_group_name: None,
            multi_az: false,
            pending_modified_values: None,
        }
    }

    /// Sorted keys of `map`, for order-independent comparisons.
    fn sorted_keys<V>(map: &HashMap<String, V>) -> Vec<String> {
        let mut keys: Vec<String> = map.keys().cloned().collect();
        keys.sort_unstable();
        keys
    }

    #[test]
    fn new_initializes_account_and_region() {
        let state = RdsState::new(ACCOUNT, REGION);

        assert_eq!(state.account_id, ACCOUNT);
        assert_eq!(state.region, REGION);
        assert!(state.instances.is_empty());
        assert!(state.in_progress_instance_ids.is_empty());
        assert!(state.snapshots.is_empty());
        assert!(state.subnet_groups.is_empty());
        assert_eq!(
            sorted_keys(&state.parameter_groups),
            sorted_keys(&default_parameter_groups(ACCOUNT, REGION))
        );
    }

    #[test]
    fn reset_clears_instances() {
        let mut state = RdsState::new(ACCOUNT, REGION);
        state
            .instances
            .insert("db-1".to_string(), sample_instance("db-1"));
        assert!(state.begin_instance_creation("db-2"));
        state.subnet_groups.insert(
            "subnets".to_string(),
            DbSubnetGroup {
                db_subnet_group_name: "subnets".to_string(),
                db_subnet_group_arn: format!("arn:aws:rds:{REGION}:{ACCOUNT}:subgrp:subnets"),
                db_subnet_group_description: "test".to_string(),
                vpc_id: "vpc-1".to_string(),
                subnet_ids: vec!["subnet-1".to_string()],
                subnet_availability_zones: vec!["us-east-1a".to_string()],
                tags: Vec::new(),
            },
        );
        state.parameter_groups.clear();

        state.reset();

        assert!(state.instances.is_empty());
        assert!(state.in_progress_instance_ids.is_empty());
        assert!(state.subnet_groups.is_empty());
        // Built-in parameter groups come back; account and region are untouched.
        assert_eq!(
            sorted_keys(&state.parameter_groups),
            sorted_keys(&default_parameter_groups(ACCOUNT, REGION))
        );
        assert_eq!(state.account_id, ACCOUNT);
        assert_eq!(state.region, REGION);
    }

    #[test]
    fn default_engine_versions_cover_all_engines() {
        let versions = default_engine_versions();
        let count = |engine: &str| versions.iter().filter(|v| v.engine == engine).count();

        assert_eq!(versions.len(), 9);
        assert_eq!(count("postgres"), 4);
        assert_eq!(count("mysql"), 3);
        assert_eq!(count("mariadb"), 2);

        // The newest PostgreSQL release is listed first.
        assert_eq!(versions[0].engine, "postgres");
        assert_eq!(versions[0].engine_version, "16.3");
        assert_eq!(versions[0].db_parameter_group_family, "postgres16");
    }

    #[test]
    fn default_orderable_options_match_engine_versions() {
        let versions = default_engine_versions();
        let options = default_orderable_options();

        // One option per (engine version, instance class) pair: 9 * 7.
        assert_eq!(
            options.len(),
            versions.len() * SUPPORTED_INSTANCE_CLASSES.len()
        );
        assert_eq!(options.len(), 63);

        for version in &versions {
            for class in SUPPORTED_INSTANCE_CLASSES {
                assert!(
                    options.iter().any(|opt| {
                        opt.engine == version.engine
                            && opt.engine_version == version.engine_version
                            && opt.db_instance_class == *class
                    }),
                    "missing option {} {} {}",
                    version.engine,
                    version.engine_version,
                    class
                );
            }
        }

        for opt in &options {
            let expected_license = if opt.engine == "postgres" {
                "postgresql-license"
            } else {
                "general-public-license"
            };
            assert_eq!(opt.license_model, expected_license);
        }
    }

    #[test]
    fn default_parameter_groups_cover_every_engine_family() {
        let groups = default_parameter_groups(ACCOUNT, REGION);

        // 9 versions share 8 distinct families (two mysql8.0 releases).
        assert_eq!(groups.len(), 8);
        for version in default_engine_versions() {
            let name = format!("default.{}", version.db_parameter_group_family);
            let group = groups
                .get(&name)
                .unwrap_or_else(|| panic!("missing parameter group {name}"));
            assert_eq!(group.db_parameter_group_name, name);
            assert_eq!(
                group.db_parameter_group_family,
                version.db_parameter_group_family
            );
            assert!(group.parameters.is_empty());
            assert!(group.tags.is_empty());
        }
    }

    #[test]
    fn begin_instance_creation_rejects_duplicate_identifiers() {
        let mut state = RdsState::new(ACCOUNT, REGION);

        assert!(state.begin_instance_creation("db-1"));
        assert!(!state.begin_instance_creation("db-1"));

        state.cancel_instance_creation("db-1");
        assert!(state.begin_instance_creation("db-1"));
    }

    #[test]
    fn begin_instance_creation_rejects_existing_instance() {
        let mut state = RdsState::new(ACCOUNT, REGION);

        assert!(state.begin_instance_creation("db-1"));
        state.finish_instance_creation(sample_instance("db-1"));
        assert!(state.in_progress_instance_ids.is_empty());
        assert!(state.instances.contains_key("db-1"));

        // The id is now taken by a real instance, and the failed attempt reserves nothing.
        assert!(!state.begin_instance_creation("db-1"));
        assert!(state.in_progress_instance_ids.is_empty());
    }
}