Skip to main content

shelly_data/
adapter.rs

1use crate::error::{DataError, DataResult};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(rename_all = "snake_case")]
6pub enum AdapterKind {
7    None,
8    Postgres,
9    MySql,
10    Sqlite,
11}
12
13impl AdapterKind {
14    pub fn as_str(self) -> &'static str {
15        match self {
16            Self::None => "none",
17            Self::Postgres => "postgres",
18            Self::MySql => "mysql",
19            Self::Sqlite => "sqlite",
20        }
21    }
22
23    pub fn parse(raw: &str) -> DataResult<Self> {
24        match raw.trim().to_ascii_lowercase().as_str() {
25            "none" => Ok(Self::None),
26            "postgres" | "postgresql" | "pg" => Ok(Self::Postgres),
27            "mysql" => Ok(Self::MySql),
28            "sqlite" | "sqlite3" => Ok(Self::Sqlite),
29            value => Err(DataError::Config(format!(
30                "unsupported database adapter `{value}`; expected one of: none, postgres, mysql, sqlite"
31            ))),
32        }
33    }
34}
35
36#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
37pub struct DatabaseConfig {
38    pub adapter: AdapterKind,
39    pub url: Option<String>,
40    pub url_env: Option<String>,
41}
42
43impl Default for DatabaseConfig {
44    fn default() -> Self {
45        Self {
46            adapter: AdapterKind::None,
47            url: None,
48            url_env: Some("DATABASE_URL".to_string()),
49        }
50    }
51}
52
53impl DatabaseConfig {
54    pub fn from_toml_like_str(content: &str) -> DataResult<Self> {
55        let mut config = Self::default();
56        let mut in_database_section = false;
57
58        for raw_line in content.lines() {
59            let line = raw_line.trim();
60            if line.is_empty() || line.starts_with('#') {
61                continue;
62            }
63            if line.starts_with('[') && line.ends_with(']') {
64                in_database_section = line == "[database]";
65                continue;
66            }
67            if !in_database_section {
68                continue;
69            }
70
71            let Some((key, value)) = line.split_once('=') else {
72                continue;
73            };
74
75            let key = key.trim();
76            let value = strip_quotes(value.trim());
77            match key {
78                "adapter" => config.adapter = AdapterKind::parse(value)?,
79                "url" => config.url = Some(value.to_string()),
80                "url_env" => config.url_env = Some(value.to_string()),
81                _ => {}
82            }
83        }
84
85        Ok(config)
86    }
87
88    pub fn resolve_url(&self) -> Option<String> {
89        if let Some(url) = &self.url {
90            return Some(url.clone());
91        }
92        self.url_env
93            .as_deref()
94            .and_then(|env_name| std::env::var(env_name).ok())
95    }
96}
97
/// Removes one pair of matching surrounding quotes from `value`.
///
/// Both TOML string delimiters are recognized: `"double"` (basic) and
/// `'single'` (literal). Escape sequences inside the quotes are left as-is.
/// Values not wrapped in a matching pair, including a lone quote character,
/// are returned unchanged.
fn strip_quotes(value: &str) -> &str {
    ['"', '\'']
        .iter()
        .find_map(|&quote| value.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(value)
}
104
#[cfg(test)]
mod tests {
    // Unit tests for config parsing and URL resolution, plus property tests
    // for adapter-name parsing.
    use super::{AdapterKind, DatabaseConfig};
    use proptest::prelude::*;

    #[test]
    fn parse_database_config() {
        let config = DatabaseConfig::from_toml_like_str(
            r#"
[database]
adapter = "postgres"
url_env = "APP_DB_URL"
"#,
        )
        .unwrap();

        assert_eq!(config.adapter, AdapterKind::Postgres);
        assert_eq!(config.url, None);
        assert_eq!(config.url_env.as_deref(), Some("APP_DB_URL"));
    }

    #[test]
    fn empty_input_yields_default_config() {
        let config = DatabaseConfig::from_toml_like_str("").unwrap();
        assert_eq!(config, DatabaseConfig::default());
    }

    #[test]
    fn keys_outside_database_section_are_ignored() {
        let config = DatabaseConfig::from_toml_like_str(
            r#"
# leading comment
adapter = "mysql"

[server]
url = "http://example.invalid"

[database]
adapter = "sqlite3"
url = "sqlite://app.db"
"#,
        )
        .unwrap();

        assert_eq!(config.adapter, AdapterKind::Sqlite);
        assert_eq!(config.url.as_deref(), Some("sqlite://app.db"));
        // `url_env` was not overridden, so the default survives.
        assert_eq!(config.url_env.as_deref(), Some("DATABASE_URL"));
    }

    #[test]
    fn unsupported_adapter_is_an_error() {
        let result = DatabaseConfig::from_toml_like_str("[database]\nadapter = \"oracle\"\n");
        assert!(result.is_err());
    }

    #[test]
    fn explicit_url_takes_precedence_in_resolve_url() {
        let config = DatabaseConfig {
            adapter: AdapterKind::Postgres,
            url: Some("postgres://localhost/app".to_string()),
            url_env: Some("SHELLY_DATA_TEST_UNUSED_VAR".to_string()),
        };
        assert_eq!(config.resolve_url().as_deref(), Some("postgres://localhost/app"));
    }

    #[test]
    fn resolve_url_without_any_source_is_none() {
        let config = DatabaseConfig {
            adapter: AdapterKind::None,
            url: None,
            url_env: None,
        };
        assert_eq!(config.resolve_url(), None);
    }

    proptest! {
        #[test]
        fn adapter_parse_accepts_aliases_case_and_whitespace(
            alias in prop_oneof![
                Just("none"),
                Just("postgres"),
                Just("postgresql"),
                Just("pg"),
                Just("mysql"),
                Just("sqlite"),
                Just("sqlite3"),
            ],
            left_ws in 0usize..3,
            right_ws in 0usize..3,
            uppercase in any::<bool>(),
        ) {
            let alias = if uppercase {
                alias.to_ascii_uppercase()
            } else {
                alias.to_string()
            };
            let input = format!("{}{}{}", " ".repeat(left_ws), alias, " ".repeat(right_ws));
            let kind = AdapterKind::parse(&input).unwrap();
            let expected = match alias.to_ascii_lowercase().as_str() {
                "none" => AdapterKind::None,
                "postgres" | "postgresql" | "pg" => AdapterKind::Postgres,
                "mysql" => AdapterKind::MySql,
                "sqlite" | "sqlite3" => AdapterKind::Sqlite,
                _ => unreachable!("input generated from known aliases"),
            };
            prop_assert_eq!(kind, expected);
        }

        #[test]
        fn adapter_parse_rejects_unknown_values(raw in "[a-zA-Z0-9_\\-]{1,24}") {
            let normalized = raw.trim().to_ascii_lowercase();
            prop_assume!(
                normalized != "none" &&
                normalized != "postgres" &&
                normalized != "postgresql" &&
                normalized != "pg" &&
                normalized != "mysql" &&
                normalized != "sqlite" &&
                normalized != "sqlite3"
            );
            prop_assert!(AdapterKind::parse(&raw).is_err());
        }
    }
}