Skip to main content

fakecloud_rds/state.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3use std::sync::Arc;
4
5use chrono::{DateTime, Utc};
6use fakecloud_aws::arn::Arn;
7use parking_lot::RwLock;
8use uuid::Uuid;
9
10pub type SharedRdsState = Arc<RwLock<RdsState>>;
11
12#[derive(Clone)]
13pub struct DbInstance {
14    pub db_instance_identifier: String,
15    pub db_instance_arn: String,
16    pub db_instance_class: String,
17    pub engine: String,
18    pub engine_version: String,
19    pub db_instance_status: String,
20    pub master_username: String,
21    pub db_name: Option<String>,
22    pub endpoint_address: String,
23    pub port: i32,
24    pub allocated_storage: i32,
25    pub publicly_accessible: bool,
26    pub deletion_protection: bool,
27    pub created_at: DateTime<Utc>,
28    pub dbi_resource_id: String,
29    pub master_user_password: String,
30    pub container_id: String,
31    pub host_port: u16,
32    pub tags: Vec<RdsTag>,
33    pub read_replica_source_db_instance_identifier: Option<String>,
34    pub read_replica_db_instance_identifiers: Vec<String>,
35    pub vpc_security_group_ids: Vec<String>,
36    pub db_parameter_group_name: Option<String>,
37    pub backup_retention_period: i32,
38    pub preferred_backup_window: String,
39    pub latest_restorable_time: DateTime<Utc>,
40    pub option_group_name: Option<String>,
41    pub multi_az: bool,
42    pub pending_modified_values: Option<PendingModifiedValues>,
43}
44
/// Modifications recorded against a DB instance but not yet applied.
///
/// A `Some` field marks a pending change to that attribute; `Default` yields a
/// value with nothing pending. `Debug` is implemented by hand to redact the password.
#[derive(Clone, Default)]
pub struct PendingModifiedValues {
    /// Target instance class.
    pub db_instance_class: Option<String>,
    /// Target allocated storage.
    pub allocated_storage: Option<i32>,
    /// Target backup retention period.
    pub backup_retention_period: Option<i32>,
    /// Target Multi-AZ setting.
    pub multi_az: Option<bool>,
    /// Target engine version.
    pub engine_version: Option<String>,
    /// New master password (plain text).
    pub master_user_password: Option<String>,
}
54
55impl fmt::Debug for PendingModifiedValues {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("PendingModifiedValues")
58            .field("db_instance_class", &self.db_instance_class)
59            .field("allocated_storage", &self.allocated_storage)
60            .field("backup_retention_period", &self.backup_retention_period)
61            .field("multi_az", &self.multi_az)
62            .field("engine_version", &self.engine_version)
63            .field(
64                "master_user_password",
65                &self.master_user_password.as_ref().map(|_| "<redacted>"),
66            )
67            .finish()
68    }
69}
70
71impl fmt::Debug for DbInstance {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        f.debug_struct("DbInstance")
74            .field("db_instance_identifier", &self.db_instance_identifier)
75            .field("db_instance_arn", &self.db_instance_arn)
76            .field("db_instance_class", &self.db_instance_class)
77            .field("engine", &self.engine)
78            .field("engine_version", &self.engine_version)
79            .field("db_instance_status", &self.db_instance_status)
80            .field("master_username", &self.master_username)
81            .field("db_name", &self.db_name)
82            .field("endpoint_address", &self.endpoint_address)
83            .field("port", &self.port)
84            .field("allocated_storage", &self.allocated_storage)
85            .field("publicly_accessible", &self.publicly_accessible)
86            .field("deletion_protection", &self.deletion_protection)
87            .field("created_at", &self.created_at)
88            .field("dbi_resource_id", &self.dbi_resource_id)
89            .field("master_user_password", &"<redacted>")
90            .field("container_id", &self.container_id)
91            .field("host_port", &self.host_port)
92            .field("tags", &self.tags)
93            .field(
94                "read_replica_source_db_instance_identifier",
95                &self.read_replica_source_db_instance_identifier,
96            )
97            .field(
98                "read_replica_db_instance_identifiers",
99                &self.read_replica_db_instance_identifiers,
100            )
101            .field("vpc_security_group_ids", &self.vpc_security_group_ids)
102            .field("db_parameter_group_name", &self.db_parameter_group_name)
103            .field("backup_retention_period", &self.backup_retention_period)
104            .field("preferred_backup_window", &self.preferred_backup_window)
105            .field("latest_restorable_time", &self.latest_restorable_time)
106            .field("option_group_name", &self.option_group_name)
107            .field("multi_az", &self.multi_az)
108            .field("pending_modified_values", &self.pending_modified_values)
109            .finish()
110    }
111}
112
/// A single key/value resource tag attached to an RDS resource.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RdsTag {
    /// Tag key.
    pub key: String,
    /// Tag value.
    pub value: String,
}
118
119#[derive(Clone)]
120pub struct DbSnapshot {
121    pub db_snapshot_identifier: String,
122    pub db_snapshot_arn: String,
123    pub db_instance_identifier: String,
124    pub snapshot_create_time: DateTime<Utc>,
125    pub engine: String,
126    pub engine_version: String,
127    pub allocated_storage: i32,
128    pub status: String,
129    pub port: i32,
130    pub master_username: String,
131    pub db_name: Option<String>,
132    pub dbi_resource_id: String,
133    pub snapshot_type: String,
134    pub master_user_password: String,
135    pub tags: Vec<RdsTag>,
136    pub dump_data: Vec<u8>,
137}
138
139impl fmt::Debug for DbSnapshot {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        f.debug_struct("DbSnapshot")
142            .field("db_snapshot_identifier", &self.db_snapshot_identifier)
143            .field("db_snapshot_arn", &self.db_snapshot_arn)
144            .field("db_instance_identifier", &self.db_instance_identifier)
145            .field("snapshot_create_time", &self.snapshot_create_time)
146            .field("engine", &self.engine)
147            .field("engine_version", &self.engine_version)
148            .field("allocated_storage", &self.allocated_storage)
149            .field("status", &self.status)
150            .field("port", &self.port)
151            .field("master_username", &self.master_username)
152            .field("db_name", &self.db_name)
153            .field("dbi_resource_id", &self.dbi_resource_id)
154            .field("snapshot_type", &self.snapshot_type)
155            .field("master_user_password", &"<redacted>")
156            .field("tags", &self.tags)
157            .field("dump_data", &format!("<{} bytes>", self.dump_data.len()))
158            .finish()
159    }
160}
161
162#[derive(Debug)]
163pub struct RdsState {
164    pub account_id: String,
165    pub region: String,
166    pub instances: HashMap<String, DbInstance>,
167    pub in_progress_instance_ids: HashSet<String>,
168    pub snapshots: HashMap<String, DbSnapshot>,
169    pub subnet_groups: HashMap<String, DbSubnetGroup>,
170    pub parameter_groups: HashMap<String, DbParameterGroup>,
171}
172
/// One entry of the engine-version catalogue (see `default_engine_versions`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EngineVersionInfo {
    /// Engine name, e.g. `postgres`.
    pub engine: String,
    /// Engine version, e.g. `16.3`.
    pub engine_version: String,
    /// Parameter group family this version belongs to, e.g. `postgres16`.
    pub db_parameter_group_family: String,
    /// Human-readable engine name.
    pub db_engine_description: String,
    /// Human-readable engine + version label.
    pub db_engine_version_description: String,
    /// Availability status, e.g. `available`.
    pub status: String,
}
182
/// One orderable engine-version / instance-class combination.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderableDbInstanceOption {
    pub engine: String,
    pub engine_version: String,
    /// Instance class, e.g. `db.t3.micro`.
    pub db_instance_class: String,
    /// License model, e.g. `postgresql-license`.
    pub license_model: String,
    /// Storage type, e.g. `gp2`.
    pub storage_type: String,
    /// Minimum allocatable storage (presumably GiB — confirm).
    pub min_storage_size: i32,
    /// Maximum allocatable storage (presumably GiB — confirm).
    pub max_storage_size: i32,
}
193
194#[derive(Debug, Clone)]
195pub struct DbSubnetGroup {
196    pub db_subnet_group_name: String,
197    pub db_subnet_group_arn: String,
198    pub db_subnet_group_description: String,
199    pub vpc_id: String,
200    pub subnet_ids: Vec<String>,
201    pub subnet_availability_zones: Vec<String>,
202    pub tags: Vec<RdsTag>,
203}
204
205#[derive(Debug, Clone)]
206pub struct DbParameterGroup {
207    pub db_parameter_group_name: String,
208    pub db_parameter_group_arn: String,
209    pub db_parameter_group_family: String,
210    pub description: String,
211    pub parameters: HashMap<String, String>,
212    pub tags: Vec<RdsTag>,
213}
214
215impl RdsState {
216    pub fn new(account_id: &str, region: &str) -> Self {
217        Self {
218            account_id: account_id.to_string(),
219            region: region.to_string(),
220            instances: HashMap::new(),
221            in_progress_instance_ids: HashSet::new(),
222            snapshots: HashMap::new(),
223            subnet_groups: HashMap::new(),
224            parameter_groups: default_parameter_groups(account_id, region),
225        }
226    }
227
228    pub fn reset(&mut self) {
229        self.instances.clear();
230        self.in_progress_instance_ids.clear();
231        self.snapshots.clear();
232        self.subnet_groups.clear();
233        self.parameter_groups = default_parameter_groups(&self.account_id, &self.region);
234    }
235
236    pub fn db_instance_arn(&self, db_instance_identifier: &str) -> String {
237        Arn::new(
238            "rds",
239            &self.region,
240            &self.account_id,
241            &format!("db:{db_instance_identifier}"),
242        )
243        .to_string()
244    }
245
246    pub fn db_snapshot_arn(&self, db_snapshot_identifier: &str) -> String {
247        Arn::new(
248            "rds",
249            &self.region,
250            &self.account_id,
251            &format!("snapshot:{db_snapshot_identifier}"),
252        )
253        .to_string()
254    }
255
256    pub fn db_subnet_group_arn(&self, db_subnet_group_name: &str) -> String {
257        Arn::new(
258            "rds",
259            &self.region,
260            &self.account_id,
261            &format!("subgrp:{db_subnet_group_name}"),
262        )
263        .to_string()
264    }
265
266    pub fn db_parameter_group_arn(&self, db_parameter_group_name: &str) -> String {
267        Arn::new(
268            "rds",
269            &self.region,
270            &self.account_id,
271            &format!("pg:{db_parameter_group_name}"),
272        )
273        .to_string()
274    }
275
276    pub fn next_dbi_resource_id(&self) -> String {
277        format!("db-{}", Uuid::new_v4().simple())
278    }
279
280    pub fn begin_instance_creation(&mut self, db_instance_identifier: &str) -> bool {
281        if self.instances.contains_key(db_instance_identifier)
282            || self
283                .in_progress_instance_ids
284                .contains(db_instance_identifier)
285        {
286            return false;
287        }
288
289        self.in_progress_instance_ids
290            .insert(db_instance_identifier.to_string());
291        true
292    }
293
294    pub fn finish_instance_creation(&mut self, instance: DbInstance) {
295        self.in_progress_instance_ids
296            .remove(&instance.db_instance_identifier);
297        self.instances
298            .insert(instance.db_instance_identifier.clone(), instance);
299    }
300
301    pub fn cancel_instance_creation(&mut self, db_instance_identifier: &str) {
302        self.in_progress_instance_ids.remove(db_instance_identifier);
303    }
304}
305
306pub fn default_engine_versions() -> Vec<EngineVersionInfo> {
307    vec![
308        // PostgreSQL versions
309        EngineVersionInfo {
310            engine: "postgres".to_string(),
311            engine_version: "16.3".to_string(),
312            db_parameter_group_family: "postgres16".to_string(),
313            db_engine_description: "PostgreSQL".to_string(),
314            db_engine_version_description: "PostgreSQL 16.3".to_string(),
315            status: "available".to_string(),
316        },
317        EngineVersionInfo {
318            engine: "postgres".to_string(),
319            engine_version: "15.5".to_string(),
320            db_parameter_group_family: "postgres15".to_string(),
321            db_engine_description: "PostgreSQL".to_string(),
322            db_engine_version_description: "PostgreSQL 15.5".to_string(),
323            status: "available".to_string(),
324        },
325        EngineVersionInfo {
326            engine: "postgres".to_string(),
327            engine_version: "14.10".to_string(),
328            db_parameter_group_family: "postgres14".to_string(),
329            db_engine_description: "PostgreSQL".to_string(),
330            db_engine_version_description: "PostgreSQL 14.10".to_string(),
331            status: "available".to_string(),
332        },
333        EngineVersionInfo {
334            engine: "postgres".to_string(),
335            engine_version: "13.13".to_string(),
336            db_parameter_group_family: "postgres13".to_string(),
337            db_engine_description: "PostgreSQL".to_string(),
338            db_engine_version_description: "PostgreSQL 13.13".to_string(),
339            status: "available".to_string(),
340        },
341        // MySQL versions
342        EngineVersionInfo {
343            engine: "mysql".to_string(),
344            engine_version: "8.0.35".to_string(),
345            db_parameter_group_family: "mysql8.0".to_string(),
346            db_engine_description: "MySQL Community Edition".to_string(),
347            db_engine_version_description: "MySQL 8.0.35".to_string(),
348            status: "available".to_string(),
349        },
350        EngineVersionInfo {
351            engine: "mysql".to_string(),
352            engine_version: "8.0.28".to_string(),
353            db_parameter_group_family: "mysql8.0".to_string(),
354            db_engine_description: "MySQL Community Edition".to_string(),
355            db_engine_version_description: "MySQL 8.0.28".to_string(),
356            status: "available".to_string(),
357        },
358        EngineVersionInfo {
359            engine: "mysql".to_string(),
360            engine_version: "5.7.44".to_string(),
361            db_parameter_group_family: "mysql5.7".to_string(),
362            db_engine_description: "MySQL Community Edition".to_string(),
363            db_engine_version_description: "MySQL 5.7.44".to_string(),
364            status: "available".to_string(),
365        },
366        // MariaDB versions
367        EngineVersionInfo {
368            engine: "mariadb".to_string(),
369            engine_version: "10.11.6".to_string(),
370            db_parameter_group_family: "mariadb10.11".to_string(),
371            db_engine_description: "MariaDB Community Edition".to_string(),
372            db_engine_version_description: "MariaDB 10.11.6".to_string(),
373            status: "available".to_string(),
374        },
375        EngineVersionInfo {
376            engine: "mariadb".to_string(),
377            engine_version: "10.6.16".to_string(),
378            db_parameter_group_family: "mariadb10.6".to_string(),
379            db_engine_description: "MariaDB Community Edition".to_string(),
380            db_engine_version_description: "MariaDB 10.6.16".to_string(),
381            status: "available".to_string(),
382        },
383    ]
384}
385
386pub fn default_orderable_options() -> Vec<OrderableDbInstanceOption> {
387    let mut options = Vec::new();
388    let engines_and_versions = vec![
389        ("postgres", "16.3", "postgresql-license"),
390        ("postgres", "15.5", "postgresql-license"),
391        ("postgres", "14.10", "postgresql-license"),
392        ("postgres", "13.13", "postgresql-license"),
393        ("mysql", "8.0.35", "general-public-license"),
394        ("mysql", "8.0.28", "general-public-license"),
395        ("mysql", "5.7.44", "general-public-license"),
396        ("mariadb", "10.11.6", "general-public-license"),
397        ("mariadb", "10.6.16", "general-public-license"),
398    ];
399
400    let instance_classes = vec![
401        "db.t3.micro",
402        "db.t3.small",
403        "db.t3.medium",
404        "db.t3.large",
405        "db.t4g.micro",
406        "db.t4g.small",
407        "db.m5.large",
408    ];
409
410    for (engine, version, license) in engines_and_versions {
411        for class in &instance_classes {
412            options.push(OrderableDbInstanceOption {
413                engine: engine.to_string(),
414                engine_version: version.to_string(),
415                db_instance_class: class.to_string(),
416                license_model: license.to_string(),
417                storage_type: "gp2".to_string(),
418                min_storage_size: 20,
419                max_storage_size: 16384,
420            });
421        }
422    }
423
424    options
425}
426
427pub fn default_parameter_groups(
428    account_id: &str,
429    region: &str,
430) -> HashMap<String, DbParameterGroup> {
431    let mut groups = HashMap::new();
432
433    let families = vec![
434        ("postgres16", "Default parameter group for postgres16"),
435        ("postgres15", "Default parameter group for postgres15"),
436        ("postgres14", "Default parameter group for postgres14"),
437        ("postgres13", "Default parameter group for postgres13"),
438        ("mysql8.0", "Default parameter group for mysql8.0"),
439        ("mysql5.7", "Default parameter group for mysql5.7"),
440        ("mariadb10.11", "Default parameter group for mariadb10.11"),
441        ("mariadb10.6", "Default parameter group for mariadb10.6"),
442    ];
443
444    for (family, description) in families {
445        let group_name = format!("default.{}", family);
446        let group = DbParameterGroup {
447            db_parameter_group_name: group_name.clone(),
448            db_parameter_group_arn: format!("arn:aws:rds:{region}:{account_id}:pg:{group_name}"),
449            db_parameter_group_family: family.to_string(),
450            description: description.to_string(),
451            parameters: HashMap::new(),
452            tags: Vec::new(),
453        };
454        groups.insert(group_name, group);
455    }
456
457    groups
458}
459
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::{default_engine_versions, default_orderable_options, DbInstance, RdsState};

    const ACCOUNT_ID: &str = "123456789012";
    const REGION: &str = "us-east-1";

    /// Builds a fully-populated `DbInstance` fixture with the given identifier.
    fn sample_instance(identifier: &str) -> DbInstance {
        let created_at = Utc::now();
        DbInstance {
            db_instance_identifier: identifier.to_string(),
            db_instance_arn: format!("arn:aws:rds:{REGION}:{ACCOUNT_ID}:db:{identifier}"),
            db_instance_class: "db.t3.micro".to_string(),
            engine: "postgres".to_string(),
            engine_version: "16.3".to_string(),
            db_instance_status: "available".to_string(),
            master_username: "admin".to_string(),
            db_name: Some("postgres".to_string()),
            endpoint_address: "127.0.0.1".to_string(),
            port: 5432,
            allocated_storage: 20,
            publicly_accessible: true,
            deletion_protection: false,
            created_at,
            dbi_resource_id: "db-test".to_string(),
            master_user_password: "secret123".to_string(),
            container_id: "container-id".to_string(),
            host_port: 15432,
            tags: Vec::new(),
            read_replica_source_db_instance_identifier: None,
            read_replica_db_instance_identifiers: Vec::new(),
            vpc_security_group_ids: Vec::new(),
            db_parameter_group_name: None,
            backup_retention_period: 1,
            preferred_backup_window: "03:00-04:00".to_string(),
            latest_restorable_time: created_at,
            option_group_name: None,
            multi_az: false,
            pending_modified_values: None,
        }
    }

    #[test]
    fn new_initializes_account_and_region() {
        let state = RdsState::new(ACCOUNT_ID, REGION);

        assert_eq!(state.account_id, ACCOUNT_ID);
        assert_eq!(state.region, REGION);
        assert!(state.instances.is_empty());
        assert!(state.in_progress_instance_ids.is_empty());
        assert!(state.snapshots.is_empty());
        assert!(state.subnet_groups.is_empty());
    }

    #[test]
    fn new_seeds_a_default_parameter_group_per_family() {
        let state = RdsState::new(ACCOUNT_ID, REGION);

        for version in default_engine_versions() {
            let name = format!("default.{}", version.db_parameter_group_family);
            let group = state
                .parameter_groups
                .get(&name)
                .unwrap_or_else(|| panic!("missing default parameter group {name}"));
            assert_eq!(group.db_parameter_group_family, version.db_parameter_group_family);
            assert!(group.parameters.is_empty());
        }
        // 4 postgres + 2 mysql + 2 mariadb distinct families
        assert_eq!(state.parameter_groups.len(), 8);
    }

    #[test]
    fn reset_clears_instances_and_restores_default_parameter_groups() {
        let mut state = RdsState::new(ACCOUNT_ID, REGION);
        state
            .instances
            .insert("db-1".to_string(), sample_instance("db-1"));
        assert!(state.begin_instance_creation("db-2"));
        state.parameter_groups.clear();

        state.reset();

        assert!(state.instances.is_empty());
        assert!(state.in_progress_instance_ids.is_empty());
        assert_eq!(state.parameter_groups.len(), 8);
        assert!(state.parameter_groups.contains_key("default.postgres16"));
    }

    #[test]
    fn default_engine_versions_cover_all_engines() {
        let versions = default_engine_versions();

        // 4 postgres + 3 mysql + 2 mariadb
        assert_eq!(versions.len(), 9);
        let count = |engine: &str| versions.iter().filter(|v| v.engine == engine).count();
        assert_eq!(count("postgres"), 4);
        assert_eq!(count("mysql"), 3);
        assert_eq!(count("mariadb"), 2);

        // The first entry is the newest PostgreSQL release.
        assert_eq!(versions[0].engine, "postgres");
        assert_eq!(versions[0].engine_version, "16.3");
        assert_eq!(versions[0].db_parameter_group_family, "postgres16");
    }

    #[test]
    fn default_orderable_options_match_engine_versions() {
        let versions = default_engine_versions();
        let options = default_orderable_options();

        // 9 engine versions x 7 instance classes
        assert_eq!(options.len(), 63);
        // Every advertised engine version must have at least one orderable option.
        for version in &versions {
            assert!(
                options.iter().any(|opt| {
                    opt.engine == version.engine && opt.engine_version == version.engine_version
                }),
                "no orderable option for {} {}",
                version.engine,
                version.engine_version
            );
        }
    }

    #[test]
    fn begin_instance_creation_rejects_duplicate_identifiers() {
        let mut state = RdsState::new(ACCOUNT_ID, REGION);

        assert!(state.begin_instance_creation("db-1"));
        assert!(!state.begin_instance_creation("db-1"));

        state.cancel_instance_creation("db-1");
        assert!(state.begin_instance_creation("db-1"));
    }

    #[test]
    fn finish_instance_creation_registers_instance_and_releases_reservation() {
        let mut state = RdsState::new(ACCOUNT_ID, REGION);

        assert!(state.begin_instance_creation("db-1"));
        state.finish_instance_creation(sample_instance("db-1"));

        assert!(state.in_progress_instance_ids.is_empty());
        assert!(state.instances.contains_key("db-1"));
        // The identifier is now taken by a live instance.
        assert!(!state.begin_instance_creation("db-1"));
    }
}