gel_dsn/gel/credentials.rs

1use std::{num::NonZeroU16, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
5use super::{
6    error::*, Param, Params, TlsSecurity, DEFAULT_BRANCH_NAME_CONNECT, DEFAULT_DATABASE_NAME,
7    DEFAULT_HOST, DEFAULT_PORT,
8};
9
10/// An opaque type representing a credentials file.
11///
12/// Use [`std::str::FromStr`] to parse a credentials file from a string.
13#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(deny_unknown_fields)]
15pub struct CredentialsFile {
16    pub user: Option<String>,
17    pub host: Option<String>,
18    pub port: Option<NonZeroU16>,
19    pub password: Option<String>,
20    pub secret_key: Option<String>,
21    pub database: Option<String>,
22    pub branch: Option<String>,
23    pub tls_ca: Option<String>,
24    #[serde(default)]
25    pub tls_security: TlsSecurity,
26    pub tls_server_name: Option<String>,
27
28    #[serde(skip)]
29    pub(crate) warnings: Vec<Warning>,
30}
31
32impl From<&CredentialsFile> for Params {
33    fn from(credentials: &CredentialsFile) -> Self {
34        let host = if let Some(host) = credentials.host.clone() {
35            Param::Unparsed(host)
36        } else {
37            Param::Parsed(DEFAULT_HOST.clone())
38        };
39        let port = if let Some(port) = credentials.port {
40            Param::Parsed(port.into())
41        } else {
42            Param::Parsed(DEFAULT_PORT)
43        };
44
45        Params {
46            host,
47            port,
48            user: Param::from_unparsed(credentials.user.clone()),
49            password: Param::from_unparsed(credentials.password.clone()),
50            secret_key: Param::from_unparsed(credentials.secret_key.clone()),
51            database: Param::from_unparsed(credentials.database.clone()),
52            branch: Param::from_unparsed(credentials.branch.clone()),
53            tls_ca: Param::from_unparsed(credentials.tls_ca.clone()),
54            tls_security: Param::Parsed(credentials.tls_security),
55            tls_server_name: Param::from_unparsed(credentials.tls_server_name.clone()),
56            ..Default::default()
57        }
58    }
59}
60
61impl From<CredentialsFile> for Params {
62    fn from(credentials: CredentialsFile) -> Self {
63        Self::from(&credentials)
64    }
65}
66
67impl CredentialsFile {
68    pub fn warnings(&self) -> &[Warning] {
69        &self.warnings
70    }
71}
72
73impl FromStr for CredentialsFile {
74    type Err = ParseError;
75
76    fn from_str(s: &str) -> Result<Self, Self::Err> {
77        if let Ok(mut res) = serde_json::from_str::<CredentialsFile>(s) {
78            // Special case: treat database=__default__ and branch=edgedb as not set
79            if (
80                Some(DEFAULT_DATABASE_NAME),
81                Some(DEFAULT_BRANCH_NAME_CONNECT),
82            ) == (res.database.as_deref(), res.branch.as_deref())
83            {
84                res.database = None;
85                res.branch = None;
86            }
87
88            // Special case: don't allow database and branch to be set at the same time
89            if let (Some(database), Some(branch)) = (&res.database, &res.branch) {
90                if database != branch {
91                    return Err(ParseError::InvalidCredentialsFile(
92                        InvalidCredentialsFileError::ConflictingSettings(
93                            ("database".to_string(), database.clone()),
94                            ("branch".to_string(), branch.clone()),
95                        ),
96                    ));
97                }
98            }
99
100            return Ok(res);
101        }
102
103        let res = serde_json::from_str::<CredentialsFileCompat>(s).map_err(|e| {
104            ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::SerializationError(
105                e.to_string(),
106            ))
107        })?;
108
109        res.try_into()
110    }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114struct CredentialsFileCompat {
115    #[serde(default, skip_serializing_if = "Option::is_none")]
116    host: Option<String>,
117    #[serde(default, skip_serializing_if = "Option::is_none")]
118    port: Option<NonZeroU16>,
119    #[serde(default, skip_serializing_if = "Option::is_none")]
120    user: Option<String>,
121    #[serde(default, skip_serializing_if = "Option::is_none")]
122    password: Option<String>,
123    #[serde(default, skip_serializing_if = "Option::is_none")]
124    secret_key: Option<String>,
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    database: Option<String>,
127    #[serde(default, skip_serializing_if = "Option::is_none")]
128    branch: Option<String>,
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    tls_cert_data: Option<String>, // deprecated
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    tls_ca: Option<String>,
133    #[serde(default, skip_serializing_if = "Option::is_none")]
134    tls_server_name: Option<String>,
135    #[serde(default, skip_serializing_if = "Option::is_none")]
136    tls_verify_hostname: Option<bool>, // deprecated
137    tls_security: Option<TlsSecurity>,
138}
139
140impl CredentialsFileCompat {
141    fn validate(&self) -> Vec<Warning> {
142        let mut warnings = Vec::new();
143        if self.database.as_deref() == Some(DEFAULT_DATABASE_NAME)
144            && self.branch.as_deref() == Some(DEFAULT_BRANCH_NAME_CONNECT)
145        {
146            warnings.push(Warning::DefaultDatabaseAndBranch);
147        }
148        if self.tls_verify_hostname.is_some() {
149            warnings.push(Warning::DeprecatedCredentialProperty(
150                "tls_verify_hostname".to_string(),
151            ));
152        }
153        if self.tls_cert_data.is_some() {
154            warnings.push(Warning::DeprecatedCredentialProperty(
155                "tls_cert_data".to_string(),
156            ));
157        }
158        warnings
159    }
160}
161
162impl TryInto<CredentialsFile> for CredentialsFileCompat {
163    type Error = ParseError;
164
165    fn try_into(self) -> Result<CredentialsFile, Self::Error> {
166        let expected_verify = match self.tls_security {
167            Some(TlsSecurity::Strict) => Some(true),
168            Some(TlsSecurity::NoHostVerification) => Some(false),
169            Some(TlsSecurity::Insecure) => Some(false),
170            _ => None,
171        };
172        if self.tls_verify_hostname.is_some()
173            && self.tls_security.is_some()
174            && expected_verify
175                .zip(self.tls_verify_hostname)
176                .map(|(actual, expected)| actual != expected)
177                .unwrap_or(false)
178        {
179            Err(ParseError::InvalidCredentialsFile(
180                InvalidCredentialsFileError::ConflictingSettings(
181                    (
182                        "tls_security".to_string(),
183                        self.tls_security.unwrap().to_string(),
184                    ),
185                    (
186                        "tls_verify_hostname".to_string(),
187                        self.tls_verify_hostname.unwrap().to_string(),
188                    ),
189                ),
190            ))
191        } else if self.tls_ca.is_some()
192            && self.tls_cert_data.is_some()
193            && self.tls_ca != self.tls_cert_data
194        {
195            return Err(ParseError::InvalidCredentialsFile(
196                InvalidCredentialsFileError::ConflictingSettings(
197                    ("tls_ca".to_string(), self.tls_ca.unwrap().to_string()),
198                    (
199                        "tls_cert_data".to_string(),
200                        self.tls_cert_data.unwrap().to_string(),
201                    ),
202                ),
203            ));
204        } else {
205            let warnings = self.validate();
206
207            let mut database = self.database;
208            let mut branch = self.branch;
209
210            // Special case: treat database=__default__ and branch=edgedb as not set
211            if (
212                Some(DEFAULT_DATABASE_NAME),
213                Some(DEFAULT_BRANCH_NAME_CONNECT),
214            ) == (database.as_deref(), branch.as_deref())
215            {
216                database = None;
217                branch = None;
218            }
219
220            // Special case: don't allow database and branch to be set at the same time
221            if let (Some(database), Some(branch)) = (&database, &branch) {
222                if database != branch {
223                    return Err(ParseError::InvalidCredentialsFile(
224                        InvalidCredentialsFileError::ConflictingSettings(
225                            ("database".to_string(), database.to_string()),
226                            ("branch".to_string(), branch.to_string()),
227                        ),
228                    ));
229                }
230            }
231
232            Ok(CredentialsFile {
233                host: self.host,
234                port: self.port,
235                user: self.user,
236                password: self.password,
237                secret_key: self.secret_key,
238                database,
239                branch,
240                tls_ca: self.tls_ca.or(self.tls_cert_data.clone()),
241                tls_server_name: self.tls_server_name,
242                tls_security: self.tls_security.unwrap_or(match self.tls_verify_hostname {
243                    None => TlsSecurity::Default,
244                    Some(true) => TlsSecurity::Strict,
245                    Some(false) => TlsSecurity::NoHostVerification,
246                }),
247                warnings,
248            })
249        }
250    }
251}
252
253/// An opaque type representing a cloud credentials file.
254///
255/// Use [`std::str::FromStr`] to parse a cloud credentials file from a string.
256#[derive(Debug, Clone, Deserialize)]
257pub struct CloudCredentialsFile {
258    pub(crate) secret_key: String,
259}
260
261impl FromStr for CloudCredentialsFile {
262    type Err = ParseError;
263
264    fn from_str(s: &str) -> Result<Self, Self::Err> {
265        serde_json::from_str(s).map_err(|e| {
266            ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::SerializationError(
267                e.to_string(),
268            ))
269        })
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_credentials_file() {
        let credentials = CredentialsFile::from_str("{\"branch\": \"edgedb\"}").unwrap();
        assert_eq!(credentials.branch, Some("edgedb".to_string()));
        assert_eq!(credentials.database, None);
    }

    #[test]
    fn test_credentials_file_default_database_and_branch() {
        let credentials =
            CredentialsFile::from_str("{\"database\": \"edgedb\", \"branch\": \"__default__\"}")
                .unwrap();
        assert_eq!(credentials.database, None);
        assert_eq!(credentials.branch, None);
    }

    /// Different `database` and `branch` values are rejected.
    #[test]
    fn test_conflicting_database_and_branch() {
        let err = CredentialsFile::from_str(r#"{"database": "a", "branch": "b"}"#).unwrap_err();
        assert!(matches!(
            err,
            ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::ConflictingSettings(..))
        ));
    }

    /// Equal `database` and `branch` values are accepted and both kept.
    #[test]
    fn test_matching_database_and_branch_allowed() {
        let credentials =
            CredentialsFile::from_str(r#"{"database": "db", "branch": "db"}"#).unwrap();
        assert_eq!(credentials.database.as_deref(), Some("db"));
        assert_eq!(credentials.branch.as_deref(), Some("db"));
    }

    /// The legacy `tls_cert_data` key is mapped onto `tls_ca` with a warning.
    #[test]
    fn test_legacy_tls_cert_data() {
        let credentials = CredentialsFile::from_str(r#"{"tls_cert_data": "cert"}"#).unwrap();
        assert_eq!(credentials.tls_ca.as_deref(), Some("cert"));
        assert_eq!(
            credentials.warnings().to_vec(),
            vec![Warning::DeprecatedCredentialProperty(
                "tls_cert_data".to_string()
            )]
        );
    }

    /// The legacy `tls_verify_hostname` key selects a TLS security mode with a warning.
    #[test]
    fn test_legacy_tls_verify_hostname() {
        let credentials =
            CredentialsFile::from_str(r#"{"tls_verify_hostname": false}"#).unwrap();
        assert_eq!(credentials.tls_security, TlsSecurity::NoHostVerification);
        assert_eq!(
            credentials.warnings().to_vec(),
            vec![Warning::DeprecatedCredentialProperty(
                "tls_verify_hostname".to_string()
            )]
        );
    }

    /// Different `tls_ca` and legacy `tls_cert_data` values are rejected.
    #[test]
    fn test_conflicting_tls_ca_and_tls_cert_data() {
        let err =
            CredentialsFile::from_str(r#"{"tls_ca": "a", "tls_cert_data": "b"}"#).unwrap_err();
        assert!(matches!(
            err,
            ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::ConflictingSettings(..))
        ));
    }
}