1use ociman::label;
10
11use crate::config::SeedConfig;
12use crate::seed::{SeedHash, SeedName};
13
14#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
17pub struct SeedEntry {
18 pub name: SeedName,
19 #[serde(flatten)]
20 pub config: SeedConfig,
21 pub hash: Option<SeedHash>,
22}
23
24pub const IMAGE_KEY: label::Key = label::Key::from_static_or_panic("pg-ephemeral.image");
25pub const INSTANCE_KEY: label::Key = label::Key::from_static_or_panic("pg-ephemeral.instance");
26pub const SEEDS_KEY: label::Key = label::Key::from_static_or_panic("pg-ephemeral.seeds");
27pub const SESSION_KEY: label::Key = label::Key::from_static_or_panic("pg-ephemeral.session");
28pub const SSL_CA_CERT_PEM_KEY: label::Key =
29 label::Key::from_static_or_panic("pg-ephemeral.ssl.ca-cert-pem");
30pub const SSL_HOSTNAME_KEY: label::Key =
31 label::Key::from_static_or_panic("pg-ephemeral.ssl.hostname");
32pub const SUPERUSER_APPLICATION_KEY: label::Key =
33 label::Key::from_static_or_panic("pg-ephemeral.superuser.application");
34pub const SUPERUSER_DATABASE_KEY: label::Key =
35 label::Key::from_static_or_panic("pg-ephemeral.superuser.database");
36pub const SUPERUSER_PASSWORD_KEY: label::Key =
37 label::Key::from_static_or_panic("pg-ephemeral.superuser.password");
38pub const SUPERUSER_USER_KEY: label::Key =
39 label::Key::from_static_or_panic("pg-ephemeral.superuser.user");
40pub const VERSION_KEY: label::Key = label::Key::from_static_or_panic("pg-ephemeral.version");
41
42#[derive(Debug, thiserror::Error)]
45pub enum ApplyError {
46 #[error("label {key} value exceeds limits")]
47 OversizedValue {
48 key: label::Key,
49 #[source]
50 source: label::Error,
51 },
52 #[error("failed to serialize seeds as JSON")]
53 SeedsJson(#[source] serde_json::Error),
54}
55
56#[derive(Debug, Clone, PartialEq)]
60pub struct Metadata {
61 pub version: semver::Version,
62 pub instance: crate::InstanceName,
63 pub image: ociman::image::Reference,
64 pub superuser: SuperuserMetadata,
65 pub seeds: Vec<SeedEntry>,
66 pub ssl: Option<SslMetadata>,
67}
68
69#[derive(Debug, Clone, PartialEq)]
70pub struct SuperuserMetadata {
71 pub user: pg_client::User,
72 pub database: pg_client::Database,
73 pub password: pg_client::config::Password,
74 pub application: Option<pg_client::config::ApplicationName>,
75}
76
77#[derive(Debug, Clone, PartialEq)]
78pub struct SslMetadata {
79 pub hostname: pg_client::config::HostName,
80 pub ca_cert_pem: String,
81}
82
83#[derive(Debug, thiserror::Error)]
86pub enum PrepareConfigError {
87 #[error("failed to materialize CA certificate")]
88 WriteCaCert(#[from] crate::certificate::WriteCaPemError),
89}
90
91impl Metadata {
92 pub fn prepare_config(
99 self,
100 host: pg_client::config::Host,
101 host_addr: Option<pg_client::config::HostAddr>,
102 port: pg_client::config::Port,
103 ) -> Result<pg_client::Config, PrepareConfigError> {
104 let (ssl_mode, ssl_root_cert) = match self.ssl {
105 Some(ssl) => {
106 let path = crate::certificate::write_ca_pem_to_temp(ssl.ca_cert_pem.as_bytes())?;
107 (
108 pg_client::config::SslMode::VerifyFull,
109 Some(pg_client::config::SslRootCert::File(path)),
110 )
111 }
112 None => (pg_client::config::SslMode::Disable, None),
113 };
114
115 Ok(pg_client::Config {
116 endpoint: pg_client::config::Endpoint::Network {
117 host,
118 channel_binding: None,
119 host_addr,
120 port: Some(port),
121 },
122 session: pg_client::config::Session {
123 application_name: self.superuser.application,
124 database: self.superuser.database,
125 password: Some(self.superuser.password),
126 user: self.superuser.user,
127 },
128 ssl_mode,
129 ssl_root_cert,
130 sqlx: Default::default(),
131 })
132 }
133}
134
135#[derive(Debug, thiserror::Error)]
138pub enum ReadError {
139 #[error("required label {0} is missing")]
140 Missing(label::Key),
141 #[error("label {key} value could not be parsed: {message}")]
142 ValueParse { key: label::Key, message: String },
143 #[error("label {key} JSON could not be decoded")]
144 Json {
145 key: label::Key,
146 #[source]
147 source: serde_json::Error,
148 },
149 #[error(
150 "ssl labels are inconsistent: {present} is set but {missing} is not — both must be \
151 present together"
152 )]
153 SslLabelsInconsistent {
154 present: label::Key,
155 missing: label::Key,
156 },
157}
158
159pub fn read_image(labels: &ociman::label::ImageLabels) -> Result<Metadata, ReadError> {
161 read(labels)
162}
163
164pub fn read_container(labels: &ociman::label::ContainerLabels) -> Result<Metadata, ReadError> {
166 read(labels)
167}
168
169fn read<S: ociman::label::Scope>(
170 labels: &ociman::label::ReadLabels<S>,
171) -> Result<Metadata, ReadError> {
172 let version = parse_required(labels, &VERSION_KEY)?;
173 let instance = parse_required(labels, &INSTANCE_KEY)?;
174 let image = parse_required_string_err(labels, &IMAGE_KEY)?;
175
176 let superuser = SuperuserMetadata {
177 user: parse_required(labels, &SUPERUSER_USER_KEY)?,
178 database: parse_required(labels, &SUPERUSER_DATABASE_KEY)?,
179 password: parse_required(labels, &SUPERUSER_PASSWORD_KEY)?,
180 application: parse_optional(labels, &SUPERUSER_APPLICATION_KEY)?,
181 };
182
183 let seeds_json = required(labels, &SEEDS_KEY)?;
184 let seeds: Vec<SeedEntry> =
185 serde_json::from_str(seeds_json).map_err(|source| ReadError::Json {
186 key: SEEDS_KEY.clone(),
187 source,
188 })?;
189
190 let ssl_hostname: Option<pg_client::config::HostName> =
191 parse_optional(labels, &SSL_HOSTNAME_KEY)?;
192 let ssl_ca_cert_pem = optional(labels, &SSL_CA_CERT_PEM_KEY).map(str::to_owned);
193
194 let ssl = match (ssl_hostname, ssl_ca_cert_pem) {
195 (Some(hostname), Some(ca_cert_pem)) => Some(SslMetadata {
196 hostname,
197 ca_cert_pem,
198 }),
199 (None, None) => None,
200 (Some(_), None) => {
201 return Err(ReadError::SslLabelsInconsistent {
202 present: SSL_HOSTNAME_KEY.clone(),
203 missing: SSL_CA_CERT_PEM_KEY.clone(),
204 });
205 }
206 (None, Some(_)) => {
207 return Err(ReadError::SslLabelsInconsistent {
208 present: SSL_CA_CERT_PEM_KEY.clone(),
209 missing: SSL_HOSTNAME_KEY.clone(),
210 });
211 }
212 };
213
214 Ok(Metadata {
215 version,
216 instance,
217 image,
218 superuser,
219 seeds,
220 ssl,
221 })
222}
223
224fn optional<'a, S: ociman::label::Scope>(
225 labels: &'a ociman::label::ReadLabels<S>,
226 key: &label::Key,
227) -> Option<&'a str> {
228 labels.get(key).map(ociman::label::ReadValue::as_str)
229}
230
231fn required<'a, S: ociman::label::Scope>(
232 labels: &'a ociman::label::ReadLabels<S>,
233 key: &label::Key,
234) -> Result<&'a str, ReadError> {
235 optional(labels, key).ok_or_else(|| ReadError::Missing(key.clone()))
236}
237
238fn parse_required<T, S: ociman::label::Scope>(
239 labels: &ociman::label::ReadLabels<S>,
240 key: &label::Key,
241) -> Result<T, ReadError>
242where
243 T: std::str::FromStr,
244 T::Err: std::fmt::Display,
245{
246 let raw = required(labels, key)?;
247 raw.parse().map_err(|error: T::Err| ReadError::ValueParse {
248 key: key.clone(),
249 message: error.to_string(),
250 })
251}
252
253fn parse_optional<T, S: ociman::label::Scope>(
254 labels: &ociman::label::ReadLabels<S>,
255 key: &label::Key,
256) -> Result<Option<T>, ReadError>
257where
258 T: std::str::FromStr,
259 T::Err: std::fmt::Display,
260{
261 match optional(labels, key) {
262 Some(raw) => raw
263 .parse()
264 .map(Some)
265 .map_err(|error: T::Err| ReadError::ValueParse {
266 key: key.clone(),
267 message: error.to_string(),
268 }),
269 None => Ok(None),
270 }
271}
272
273fn parse_required_string_err<T, S: ociman::label::Scope>(
276 labels: &ociman::label::ReadLabels<S>,
277 key: &label::Key,
278) -> Result<T, ReadError>
279where
280 T: std::str::FromStr<Err = String>,
281{
282 let raw = required(labels, key)?;
283 raw.parse()
284 .map_err(|message: String| ReadError::ValueParse {
285 key: key.clone(),
286 message,
287 })
288}
289
290pub(crate) fn apply(
292 ociman_definition: ociman::Definition,
293 definition: &crate::Definition,
294 password: &pg_client::config::Password,
295 ssl_bundle: Option<&crate::certificate::Bundle>,
296 seeds: &[SeedEntry],
297) -> Result<ociman::Definition, ApplyError> {
298 let image_reference = ociman::image::Reference::from(&definition.image).to_string();
299 let seeds_json = serde_json::to_string(seeds).map_err(ApplyError::SeedsJson)?;
300
301 let mut pairs: Vec<(label::Key, label::Value)> = vec![
302 (
303 VERSION_KEY.clone(),
304 to_value(&VERSION_KEY, crate::VERSION_STR)?,
305 ),
306 (
307 INSTANCE_KEY.clone(),
308 to_value(&INSTANCE_KEY, definition.instance_name.as_str())?,
309 ),
310 (IMAGE_KEY.clone(), to_value(&IMAGE_KEY, &image_reference)?),
311 (
312 SUPERUSER_USER_KEY.clone(),
313 to_value(&SUPERUSER_USER_KEY, definition.superuser.as_ref())?,
314 ),
315 (
316 SUPERUSER_DATABASE_KEY.clone(),
317 to_value(&SUPERUSER_DATABASE_KEY, definition.database.as_ref())?,
318 ),
319 (
320 SUPERUSER_PASSWORD_KEY.clone(),
321 to_value(&SUPERUSER_PASSWORD_KEY, password.as_ref())?,
322 ),
323 (SEEDS_KEY.clone(), to_value(&SEEDS_KEY, &seeds_json)?),
324 ];
325
326 if let Some(application_name) = &definition.application_name {
327 pairs.push((
328 SUPERUSER_APPLICATION_KEY.clone(),
329 to_value(&SUPERUSER_APPLICATION_KEY, application_name.as_ref())?,
330 ));
331 }
332
333 if let Some(crate::definition::SslConfig::Generated { hostname }) = &definition.ssl_config {
334 pairs.push((
335 SSL_HOSTNAME_KEY.clone(),
336 to_value(&SSL_HOSTNAME_KEY, hostname.as_str())?,
337 ));
338 }
339
340 if let Some(bundle) = ssl_bundle {
341 pairs.push((
342 SSL_CA_CERT_PEM_KEY.clone(),
343 to_value(&SSL_CA_CERT_PEM_KEY, &bundle.ca_cert_pem)?,
344 ));
345 }
346
347 if let Some(session_name) = &definition.session_name {
348 pairs.push((
349 SESSION_KEY.clone(),
350 to_value(&SESSION_KEY, session_name.as_str())?,
351 ));
352 }
353
354 Ok(ociman_definition.labels(pairs.iter().map(|(key, value)| (key, value))))
355}
356
357fn to_value(key: &label::Key, raw: &str) -> Result<label::Value, ApplyError> {
358 label::Value::try_from(raw.to_string()).map_err(|source| ApplyError::OversizedValue {
359 key: key.clone(),
360 source,
361 })
362}
363
364pub(crate) fn build_seed_entries(
367 definition: &crate::Definition,
368 loaded_seeds: &crate::seed::LoadedSeeds<'_>,
369) -> Vec<SeedEntry> {
370 let mut entries = Vec::with_capacity(definition.seeds.len());
371 for loaded_seed in loaded_seeds.iter_seeds() {
372 let name = loaded_seed.name().clone();
373 let seed = match definition.seeds.get(loaded_seed.name()) {
374 Some(seed) => seed,
375 None => unreachable!(
376 "loaded seed {name} must exist in definition.seeds; \
377 load_seeds populates from this map",
378 ),
379 };
380 entries.push(SeedEntry {
381 name,
382 config: seed.into(),
383 hash: loaded_seed.cache_status().hash().cloned(),
384 });
385 }
386 entries
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::seed::SeedCacheConfig;

    /// Parses a seed-name fixture, panicking if it is invalid.
    fn seed_name(raw: &str) -> SeedName {
        raw.parse().unwrap()
    }

    /// Parses a 64-hex-digit hash fixture, panicking if it is invalid.
    fn seed_hash(hex: &str) -> SeedHash {
        hex.parse().unwrap()
    }

    #[test]
    fn seed_entry_json_round_trip_compliant_hash() {
        let entry = SeedEntry {
            name: seed_name("schema"),
            config: SeedConfig::SqlFile {
                path: "schema.sql".into(),
                git_revision: None,
            },
            hash: Some(seed_hash(
                "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
            )),
        };

        // Pins the exact wire format: `config` is flattened between `name` and `hash`.
        let json = serde_json::to_string(&entry).unwrap();
        assert_eq!(
            json,
            r#"{"name":"schema","type":"sql-file","path":"schema.sql","git_revision":null,"hash":"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}"#
        );

        let parsed: SeedEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn seed_entry_json_round_trip_uncacheable() {
        let entry = SeedEntry {
            name: seed_name("dynamic"),
            config: SeedConfig::Command {
                command: "psql".to_string(),
                arguments: vec!["-c".to_string(), "SELECT 1".to_string()],
                cache: SeedCacheConfig::None,
            },
            hash: None,
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: SeedEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);

        // `value["hash"]` would also be Null if the key were missing entirely, so look
        // the key up with `get` to assert `hash` is emitted as an explicit `null`.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value.get("hash"), Some(&serde_json::Value::Null));
    }

    #[test]
    fn seed_list_json_round_trip() {
        let entries = vec![
            SeedEntry {
                name: seed_name("a"),
                config: SeedConfig::SqlStatement {
                    statement: "CREATE TABLE t (id INT)".to_string(),
                },
                hash: Some(seed_hash(&"1".repeat(64))),
            },
            SeedEntry {
                name: seed_name("b"),
                config: SeedConfig::ContainerScript {
                    script: "apt-get install -y foo".to_string(),
                },
                hash: Some(seed_hash(&"2".repeat(64))),
            },
        ];

        let json = serde_json::to_string(&entries).unwrap();
        let parsed: Vec<SeedEntry> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entries);
    }
}