1use std::{num::NonZeroU16, str::FromStr};
2
3use serde::{Deserialize, Serialize};
4
5use super::{
6 error::*, Param, Params, TlsSecurity, DEFAULT_BRANCH_NAME_CONNECT, DEFAULT_DATABASE_NAME,
7 DEFAULT_HOST, DEFAULT_PORT,
8};
9
10#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
14#[serde(deny_unknown_fields)]
15pub struct CredentialsFile {
16 pub user: Option<String>,
17 pub host: Option<String>,
18 pub port: Option<NonZeroU16>,
19 pub password: Option<String>,
20 pub secret_key: Option<String>,
21 pub database: Option<String>,
22 pub branch: Option<String>,
23 pub tls_ca: Option<String>,
24 #[serde(default)]
25 pub tls_security: TlsSecurity,
26 pub tls_server_name: Option<String>,
27
28 #[serde(skip)]
29 pub(crate) warnings: Vec<Warning>,
30}
31
32impl From<&CredentialsFile> for Params {
33 fn from(credentials: &CredentialsFile) -> Self {
34 let host = if let Some(host) = credentials.host.clone() {
35 Param::Unparsed(host)
36 } else {
37 Param::Parsed(DEFAULT_HOST.clone())
38 };
39 let port = if let Some(port) = credentials.port {
40 Param::Parsed(port.into())
41 } else {
42 Param::Parsed(DEFAULT_PORT)
43 };
44
45 Params {
46 host,
47 port,
48 user: Param::from_unparsed(credentials.user.clone()),
49 password: Param::from_unparsed(credentials.password.clone()),
50 secret_key: Param::from_unparsed(credentials.secret_key.clone()),
51 database: Param::from_unparsed(credentials.database.clone()),
52 branch: Param::from_unparsed(credentials.branch.clone()),
53 tls_ca: Param::from_unparsed(credentials.tls_ca.clone()),
54 tls_security: Param::Parsed(credentials.tls_security),
55 tls_server_name: Param::from_unparsed(credentials.tls_server_name.clone()),
56 ..Default::default()
57 }
58 }
59}
60
61impl From<CredentialsFile> for Params {
62 fn from(credentials: CredentialsFile) -> Self {
63 Self::from(&credentials)
64 }
65}
66
67impl CredentialsFile {
68 pub fn warnings(&self) -> &[Warning] {
69 &self.warnings
70 }
71}
72
73impl FromStr for CredentialsFile {
74 type Err = ParseError;
75
76 fn from_str(s: &str) -> Result<Self, Self::Err> {
77 if let Ok(mut res) = serde_json::from_str::<CredentialsFile>(s) {
78 if (
80 Some(DEFAULT_DATABASE_NAME),
81 Some(DEFAULT_BRANCH_NAME_CONNECT),
82 ) == (res.database.as_deref(), res.branch.as_deref())
83 {
84 res.database = None;
85 res.branch = None;
86 }
87
88 if let (Some(database), Some(branch)) = (&res.database, &res.branch) {
90 if database != branch {
91 return Err(ParseError::InvalidCredentialsFile(
92 InvalidCredentialsFileError::ConflictingSettings(
93 ("database".to_string(), database.clone()),
94 ("branch".to_string(), branch.clone()),
95 ),
96 ));
97 }
98 }
99
100 return Ok(res);
101 }
102
103 let res = serde_json::from_str::<CredentialsFileCompat>(s).map_err(|e| {
104 ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::SerializationError(
105 e.to_string(),
106 ))
107 })?;
108
109 res.try_into()
110 }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114struct CredentialsFileCompat {
115 #[serde(default, skip_serializing_if = "Option::is_none")]
116 host: Option<String>,
117 #[serde(default, skip_serializing_if = "Option::is_none")]
118 port: Option<NonZeroU16>,
119 #[serde(default, skip_serializing_if = "Option::is_none")]
120 user: Option<String>,
121 #[serde(default, skip_serializing_if = "Option::is_none")]
122 password: Option<String>,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
124 secret_key: Option<String>,
125 #[serde(default, skip_serializing_if = "Option::is_none")]
126 database: Option<String>,
127 #[serde(default, skip_serializing_if = "Option::is_none")]
128 branch: Option<String>,
129 #[serde(default, skip_serializing_if = "Option::is_none")]
130 tls_cert_data: Option<String>, #[serde(default, skip_serializing_if = "Option::is_none")]
132 tls_ca: Option<String>,
133 #[serde(default, skip_serializing_if = "Option::is_none")]
134 tls_server_name: Option<String>,
135 #[serde(default, skip_serializing_if = "Option::is_none")]
136 tls_verify_hostname: Option<bool>, tls_security: Option<TlsSecurity>,
138}
139
140impl CredentialsFileCompat {
141 fn validate(&self) -> Vec<Warning> {
142 let mut warnings = Vec::new();
143 if self.database.as_deref() == Some(DEFAULT_DATABASE_NAME)
144 && self.branch.as_deref() == Some(DEFAULT_BRANCH_NAME_CONNECT)
145 {
146 warnings.push(Warning::DefaultDatabaseAndBranch);
147 }
148 if self.tls_verify_hostname.is_some() {
149 warnings.push(Warning::DeprecatedCredentialProperty(
150 "tls_verify_hostname".to_string(),
151 ));
152 }
153 if self.tls_cert_data.is_some() {
154 warnings.push(Warning::DeprecatedCredentialProperty(
155 "tls_cert_data".to_string(),
156 ));
157 }
158 warnings
159 }
160}
161
162impl TryInto<CredentialsFile> for CredentialsFileCompat {
163 type Error = ParseError;
164
165 fn try_into(self) -> Result<CredentialsFile, Self::Error> {
166 let expected_verify = match self.tls_security {
167 Some(TlsSecurity::Strict) => Some(true),
168 Some(TlsSecurity::NoHostVerification) => Some(false),
169 Some(TlsSecurity::Insecure) => Some(false),
170 _ => None,
171 };
172 if self.tls_verify_hostname.is_some()
173 && self.tls_security.is_some()
174 && expected_verify
175 .zip(self.tls_verify_hostname)
176 .map(|(actual, expected)| actual != expected)
177 .unwrap_or(false)
178 {
179 Err(ParseError::InvalidCredentialsFile(
180 InvalidCredentialsFileError::ConflictingSettings(
181 (
182 "tls_security".to_string(),
183 self.tls_security.unwrap().to_string(),
184 ),
185 (
186 "tls_verify_hostname".to_string(),
187 self.tls_verify_hostname.unwrap().to_string(),
188 ),
189 ),
190 ))
191 } else if self.tls_ca.is_some()
192 && self.tls_cert_data.is_some()
193 && self.tls_ca != self.tls_cert_data
194 {
195 return Err(ParseError::InvalidCredentialsFile(
196 InvalidCredentialsFileError::ConflictingSettings(
197 ("tls_ca".to_string(), self.tls_ca.unwrap().to_string()),
198 (
199 "tls_cert_data".to_string(),
200 self.tls_cert_data.unwrap().to_string(),
201 ),
202 ),
203 ));
204 } else {
205 let warnings = self.validate();
206
207 let mut database = self.database;
208 let mut branch = self.branch;
209
210 if (
212 Some(DEFAULT_DATABASE_NAME),
213 Some(DEFAULT_BRANCH_NAME_CONNECT),
214 ) == (database.as_deref(), branch.as_deref())
215 {
216 database = None;
217 branch = None;
218 }
219
220 if let (Some(database), Some(branch)) = (&database, &branch) {
222 if database != branch {
223 return Err(ParseError::InvalidCredentialsFile(
224 InvalidCredentialsFileError::ConflictingSettings(
225 ("database".to_string(), database.to_string()),
226 ("branch".to_string(), branch.to_string()),
227 ),
228 ));
229 }
230 }
231
232 Ok(CredentialsFile {
233 host: self.host,
234 port: self.port,
235 user: self.user,
236 password: self.password,
237 secret_key: self.secret_key,
238 database,
239 branch,
240 tls_ca: self.tls_ca.or(self.tls_cert_data.clone()),
241 tls_server_name: self.tls_server_name,
242 tls_security: self.tls_security.unwrap_or(match self.tls_verify_hostname {
243 None => TlsSecurity::Default,
244 Some(true) => TlsSecurity::Strict,
245 Some(false) => TlsSecurity::NoHostVerification,
246 }),
247 warnings,
248 })
249 }
250 }
251}
252
253#[derive(Debug, Clone, Deserialize)]
257pub struct CloudCredentialsFile {
258 pub(crate) secret_key: String,
259}
260
261impl FromStr for CloudCredentialsFile {
262 type Err = ParseError;
263
264 fn from_str(s: &str) -> Result<Self, Self::Err> {
265 serde_json::from_str(s).map_err(|e| {
266 ParseError::InvalidCredentialsFile(InvalidCredentialsFileError::SerializationError(
267 e.to_string(),
268 ))
269 })
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone `branch` entry is kept and `database` stays unset.
    #[test]
    fn test_credentials_file() {
        let parsed: CredentialsFile = r#"{"branch": "edgedb"}"#.parse().unwrap();

        assert_eq!(parsed.branch.as_deref(), Some("edgedb"));
        assert!(parsed.database.is_none());
    }

    /// The `edgedb` / `__default__` pair is cleared on both fields.
    #[test]
    fn test_credentials_file_default_database_and_branch() {
        let input = r#"{"database": "edgedb", "branch": "__default__"}"#;
        let parsed: CredentialsFile = input.parse().unwrap();

        assert!(parsed.database.is_none());
        assert!(parsed.branch.is_none());
    }
}