1use crate::backend::{BackendClient, BackendConfig};
29use crate::switchover_buffer::SwitchoverBuffer;
30use crate::{ProxyError, Result};
31use chrono::{DateTime, Utc};
32use parking_lot::RwLock;
33use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::sync::Arc;
36use uuid::Uuid;
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
40#[serde(rename_all = "snake_case")]
41pub enum UpgradeState {
42 Pending,
44 StandbyCatchingUp,
47 ShadowExecuting,
49 Validated,
51 Cutover,
53 Draining,
55 Complete,
57 Failed,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct UpgradeJob {
64 pub id: Uuid,
65 pub from_version: u32,
66 pub to_version: u32,
67 pub from_address: String,
68 pub to_address: String,
69 pub state: UpgradeState,
70 pub started_at: DateTime<Utc>,
71 pub updated_at: DateTime<Utc>,
72 pub completed_at: Option<DateTime<Utc>>,
73 pub error: Option<String>,
74 pub shadow_statements: u64,
76 pub validated_rows: u64,
78}
79
80impl UpgradeJob {
81 fn new(req: &PlanRequest) -> Self {
82 let now = Utc::now();
83 Self {
84 id: Uuid::new_v4(),
85 from_version: req.from_version,
86 to_version: req.to_version,
87 from_address: req.from_address.clone(),
88 to_address: req.to_address.clone(),
89 state: UpgradeState::Pending,
90 started_at: now,
91 updated_at: now,
92 completed_at: None,
93 error: None,
94 shadow_statements: 0,
95 validated_rows: 0,
96 }
97 }
98
99 fn advance(&mut self, next: UpgradeState) {
100 self.state = next;
101 self.updated_at = Utc::now();
102 if matches!(next, UpgradeState::Complete | UpgradeState::Failed) {
103 self.completed_at = Some(self.updated_at);
104 }
105 }
106
107 fn fail(&mut self, reason: impl Into<String>) {
108 self.error = Some(reason.into());
109 self.advance(UpgradeState::Failed);
110 }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct PlanRequest {
116 pub from_version: u32,
118 pub to_version: u32,
120 #[serde(default)]
123 pub from_address: String,
124 pub to_address: String,
127}
128
129impl PlanRequest {
130 pub fn validate(&self) -> Result<()> {
131 if self.to_version <= self.from_version {
132 return Err(ProxyError::Configuration(format!(
133 "to_version ({}) must be greater than from_version ({})",
134 self.to_version, self.from_version
135 )));
136 }
137 if self.to_address.is_empty() {
138 return Err(ProxyError::Configuration(
139 "to_address must be provided".into(),
140 ));
141 }
142 Ok(())
143 }
144}
145
146pub struct UpgradeOrchestrator {
154 jobs: Arc<RwLock<HashMap<Uuid, UpgradeJob>>>,
155 switchover: Arc<SwitchoverBuffer>,
159 backend_template: BackendConfig,
162}
163
164impl UpgradeOrchestrator {
165 pub fn new(
166 switchover: Arc<SwitchoverBuffer>,
167 backend_template: BackendConfig,
168 ) -> Self {
169 Self {
170 jobs: Arc::new(RwLock::new(HashMap::new())),
171 switchover,
172 backend_template,
173 }
174 }
175
176 pub fn start(&self, req: PlanRequest) -> Result<Uuid> {
180 req.validate()?;
181 let job = UpgradeJob::new(&req);
182 let id = job.id;
183 self.jobs.write().insert(id, job);
184 tracing::info!(
185 job = %id,
186 from = req.from_version,
187 to = req.to_version,
188 "upgrade job created"
189 );
190 Ok(id)
191 }
192
193 pub fn get(&self, id: Uuid) -> Option<UpgradeJob> {
195 self.jobs.read().get(&id).cloned()
196 }
197
198 pub fn list(&self) -> Vec<UpgradeJob> {
200 self.jobs.read().values().cloned().collect()
201 }
202
203 pub async fn tick(&self, id: Uuid) -> Result<UpgradeJob> {
229 let mut snap = self
230 .get(id)
231 .ok_or_else(|| ProxyError::Internal(format!("upgrade job {} not found", id)))?;
232
233 let outcome = match snap.state {
234 UpgradeState::Pending => self.stage_create_replication(&snap).await,
235 UpgradeState::StandbyCatchingUp => self.stage_wait_catchup(&snap).await,
236 UpgradeState::ShadowExecuting => self.stage_settle_shadow(&snap).await,
237 UpgradeState::Validated => self.stage_cutover(&snap).await,
238 UpgradeState::Cutover => self.stage_drain(&snap).await,
239 UpgradeState::Draining => self.stage_retire(&snap).await,
240 UpgradeState::Complete | UpgradeState::Failed => Ok(snap.state),
241 };
242
243 match outcome {
244 Ok(next) => snap.advance(next),
245 Err(e) => snap.fail(e.to_string()),
246 }
247
248 self.jobs.write().insert(id, snap.clone());
249 Ok(snap)
250 }
251
252 async fn stage_create_replication(&self, job: &UpgradeJob) -> Result<UpgradeState> {
260 let pub_name = publication_name(job.id);
261
262 let source_cfg = self.backend_for(&job.from_address)?;
264 let mut source = BackendClient::connect(&source_cfg).await.map_err(|e| {
265 ProxyError::FailoverFailed(format!("connect source: {}", e))
266 })?;
267 let _ = source
269 .execute(&format!(
270 "DROP PUBLICATION IF EXISTS {}",
271 quote_ident(&pub_name)
272 ))
273 .await;
274 source
275 .execute(&format!(
276 "CREATE PUBLICATION {} FOR ALL TABLES",
277 quote_ident(&pub_name)
278 ))
279 .await
280 .map_err(|e| {
281 ProxyError::FailoverFailed(format!("CREATE PUBLICATION: {}", e))
282 })?;
283 source.close().await;
284
285 let target_cfg = self.backend_for(&job.to_address)?;
292 let conninfo = source_conninfo(&source_cfg);
293 let mut target = BackendClient::connect(&target_cfg).await.map_err(|e| {
294 ProxyError::FailoverFailed(format!("connect target: {}", e))
295 })?;
296 let _ = target
297 .execute(&format!(
298 "DROP SUBSCRIPTION IF EXISTS {}",
299 quote_ident(&pub_name)
300 ))
301 .await;
302 target
303 .execute(&format!(
304 "CREATE SUBSCRIPTION {} CONNECTION '{}' PUBLICATION {}",
305 quote_ident(&pub_name),
306 conninfo.replace('\'', "''"),
307 quote_ident(&pub_name)
308 ))
309 .await
310 .map_err(|e| {
311 ProxyError::FailoverFailed(format!("CREATE SUBSCRIPTION: {}", e))
312 })?;
313 target.close().await;
314
315 tracing::info!(job = %job.id, pub_name = %pub_name, "stage 1: replication created");
316 Ok(UpgradeState::StandbyCatchingUp)
317 }
318
319 async fn stage_wait_catchup(&self, job: &UpgradeJob) -> Result<UpgradeState> {
326 let target_cfg = self.backend_for(&job.to_address)?;
327 let mut target = BackendClient::connect(&target_cfg).await.map_err(|e| {
328 ProxyError::FailoverFailed(format!("connect target: {}", e))
329 })?;
330 let pub_name = publication_name(job.id);
331 let row = target
332 .query_scalar(&format!(
333 "SELECT subenabled FROM pg_subscription WHERE subname = '{}'",
334 pub_name.replace('\'', "''")
335 ))
336 .await
337 .map_err(|e| {
338 ProxyError::FailoverFailed(format!("subscription probe: {}", e))
339 })?;
340 target.close().await;
341
342 let enabled = row
343 .as_bool("subenabled")
344 .map_err(|e| {
345 ProxyError::FailoverFailed(format!("subenabled value: {}", e))
346 })?
347 .unwrap_or(false);
348 if !enabled {
349 return Err(ProxyError::FailoverFailed(format!(
350 "subscription {} not enabled on target",
351 pub_name
352 )));
353 }
354 tracing::info!(job = %job.id, "stage 2: subscription active");
355 Ok(UpgradeState::ShadowExecuting)
356 }
357
358 async fn stage_settle_shadow(&self, job: &UpgradeJob) -> Result<UpgradeState> {
366 tracing::info!(job = %job.id, "stage 3: shadow window settle");
367 tokio::time::sleep(std::time::Duration::from_millis(250)).await;
368 Ok(UpgradeState::Validated)
369 }
370
371 async fn stage_cutover(&self, job: &UpgradeJob) -> Result<UpgradeState> {
374 self.switchover.start_buffering();
375 tracing::info!(job = %job.id, "stage 4: switchover_buffer engaged; promoting target");
376
377 let target_cfg = self.backend_for(&job.to_address)?;
378 let mut target = BackendClient::connect(&target_cfg).await.map_err(|e| {
379 self.switchover.stop_buffering();
382 ProxyError::FailoverFailed(format!("connect target for promote: {}", e))
383 })?;
384
385 let result = target
387 .query_scalar("SELECT pg_promote(true, 60)")
388 .await
389 .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)));
390 target.close().await;
391
392 let row = match result {
393 Ok(r) => r,
394 Err(e) => {
395 self.switchover.stop_buffering();
396 return Err(e);
397 }
398 };
399 let promoted = row
400 .as_bool("pg_promote")
401 .map_err(|e| {
402 self.switchover.stop_buffering();
403 ProxyError::FailoverFailed(format!("pg_promote result: {}", e))
404 })?
405 .unwrap_or(false);
406 if !promoted {
407 self.switchover.stop_buffering();
408 return Err(ProxyError::FailoverFailed(
409 "pg_promote returned false".into(),
410 ));
411 }
412
413 tracing::info!(job = %job.id, "stage 4: target promoted");
414 Ok(UpgradeState::Cutover)
415 }
416
417 async fn stage_drain(&self, job: &UpgradeJob) -> Result<UpgradeState> {
421 tracing::info!(job = %job.id, "stage 5: draining buffered writes");
422 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
423 Ok(UpgradeState::Draining)
424 }
425
426 async fn stage_retire(&self, job: &UpgradeJob) -> Result<UpgradeState> {
432 self.switchover.stop_buffering();
433 tracing::info!(job = %job.id, "stage 6: switchover_buffer released");
434
435 let pub_name = publication_name(job.id);
436
437 if let Ok(target_cfg) = self.backend_for(&job.to_address) {
439 if let Ok(mut target) = BackendClient::connect(&target_cfg).await {
440 let _ = target
441 .execute(&format!(
442 "DROP SUBSCRIPTION IF EXISTS {}",
443 quote_ident(&pub_name)
444 ))
445 .await;
446 target.close().await;
447 }
448 }
449
450 if let Ok(source_cfg) = self.backend_for(&job.from_address) {
452 if let Ok(mut source) = BackendClient::connect(&source_cfg).await {
453 let _ = source
454 .execute(&format!(
455 "DROP PUBLICATION IF EXISTS {}",
456 quote_ident(&pub_name)
457 ))
458 .await;
459 source.close().await;
460 }
461 }
462
463 Ok(UpgradeState::Complete)
464 }
465
466 fn backend_for(&self, addr: &str) -> Result<BackendConfig> {
469 let (host, port) = parse_addr(addr)?;
470 let mut c = self.backend_template.clone();
471 c.host = host;
472 c.port = port;
473 Ok(c)
474 }
475
476 pub fn cancel(&self, id: Uuid, reason: &str) -> Result<UpgradeJob> {
480 let mut jobs = self.jobs.write();
481 let job = jobs
482 .get_mut(&id)
483 .ok_or_else(|| ProxyError::Internal(format!("upgrade job {} not found", id)))?;
484 if matches!(
485 job.state,
486 UpgradeState::Complete | UpgradeState::Failed
487 ) {
488 return Err(ProxyError::Internal(format!(
489 "job {} already terminal: {:?}",
490 id, job.state
491 )));
492 }
493 self.switchover.stop_buffering();
495 job.fail(format!("cancelled: {}", reason));
496 Ok(job.clone())
497 }
498}
499
/// Wraps `name` in double quotes for use as a SQL identifier, doubling any
/// double quote that appears inside it.
fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
517
518fn publication_name(id: Uuid) -> String {
522 format!("helios_upgrade_{}", id.simple())
523}
524
525fn source_conninfo(cfg: &BackendConfig) -> String {
529 let mut parts = vec![
530 format!("host={}", cfg.host),
531 format!("port={}", cfg.port),
532 format!("user={}", cfg.user),
533 ];
534 if let Some(pw) = &cfg.password {
535 parts.push(format!("password={}", pw));
536 }
537 if let Some(db) = &cfg.database {
538 parts.push(format!("dbname={}", db));
539 }
540 parts.join(" ")
541}
542
543fn parse_addr(addr: &str) -> Result<(String, u16)> {
547 let (host, port) = addr.rsplit_once(':').ok_or_else(|| {
548 ProxyError::Configuration(format!("expected host:port, got {:?}", addr))
549 })?;
550 let port: u16 = port.parse().map_err(|_| {
551 ProxyError::Configuration(format!("invalid port in {:?}", addr))
552 })?;
553 Ok((host.to_string(), port))
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::tls::default_client_config;
    use crate::backend::TlsMode;
    use crate::switchover_buffer::BufferConfig;
    use std::time::Duration;

    /// Backend settings with placeholder host/port and short timeouts.
    fn template() -> BackendConfig {
        BackendConfig {
            host: "placeholder".into(),
            port: 0,
            user: "postgres".into(),
            password: None,
            database: None,
            application_name: Some("helios-upgrade".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_millis(200),
            query_timeout: Duration::from_millis(200),
            tls_config: default_client_config(),
        }
    }

    /// Switchover buffer with default configuration.
    fn switchover() -> Arc<SwitchoverBuffer> {
        Arc::new(SwitchoverBuffer::new(BufferConfig::default()))
    }

    /// Orchestrator wired to the test template and a fresh buffer.
    fn orchestrator() -> UpgradeOrchestrator {
        UpgradeOrchestrator::new(switchover(), template())
    }

    /// Shorthand for building a plan request.
    fn plan(from_version: u32, to_version: u32, from: &str, to: &str) -> PlanRequest {
        PlanRequest {
            from_version,
            to_version,
            from_address: from.into(),
            to_address: to.into(),
        }
    }

    #[test]
    fn validate_rejects_downgrade() {
        let req = plan(17, 14, "pg-17:5432", "pg-14:5432");
        assert!(matches!(req.validate(), Err(ProxyError::Configuration(_))));
    }

    #[test]
    fn validate_rejects_same_version() {
        let req = plan(16, 16, "a", "b");
        assert!(req.validate().is_err());
    }

    #[test]
    fn validate_rejects_empty_target_address() {
        let req = plan(14, 17, "a", "");
        assert!(req.validate().is_err());
    }

    #[test]
    fn validate_accepts_proper_upgrade() {
        let req = plan(14, 17, "pg-14:5432", "pg-17:5432");
        assert!(req.validate().is_ok());
    }

    /// The first tick of a job whose source is unreachable must fail the job
    /// and record why.
    #[tokio::test]
    async fn tick_fails_job_on_unreachable_source() {
        let orch = orchestrator();
        let id = orch
            .start(plan(14, 17, "127.0.0.1:1", "127.0.0.1:2"))
            .unwrap();

        let result = orch.tick(id).await.unwrap();
        assert_eq!(result.state, UpgradeState::Failed);
        let err = result.error.expect("failure carries an error message");
        assert!(
            err.contains("connect") || err.contains("FailoverFailed") || err.contains("PUBLICATION"),
            "expected connect/SQL error, got {}",
            err
        );
    }

    /// Ticking again after the job failed must leave it failed.
    #[tokio::test]
    async fn tick_on_terminal_job_is_noop() {
        let orch = orchestrator();
        let id = orch
            .start(plan(14, 17, "127.0.0.1:1", "127.0.0.1:2"))
            .unwrap();

        let first = orch.tick(id).await.unwrap();
        assert_eq!(first.state, UpgradeState::Failed);

        let second = orch.tick(id).await.unwrap();
        assert_eq!(second.state, UpgradeState::Failed);
    }

    #[tokio::test]
    async fn cancel_marks_failed_with_reason() {
        let orch = orchestrator();
        let id = orch.start(plan(14, 17, "a:1", "b:2")).unwrap();

        let cancelled = orch.cancel(id, "operator request").unwrap();
        assert_eq!(cancelled.state, UpgradeState::Failed);
        assert!(cancelled.error.unwrap().contains("operator request"));
    }

    #[test]
    fn cancel_errors_on_terminal_job() {
        let orch = orchestrator();
        let id = orch.start(plan(14, 17, "a:1", "b:2")).unwrap();
        orch.cancel(id, "first cancel").unwrap();
        assert!(orch.cancel(id, "second cancel").is_err());
    }

    #[test]
    fn list_returns_every_known_job() {
        let orch = orchestrator();
        for to_version in 15..=17 {
            orch.start(plan(14, to_version, "a:1", "b:2")).unwrap();
        }
        assert_eq!(orch.list().len(), 3);
    }

    #[test]
    fn parse_addr_round_trip() {
        let (host, port) = parse_addr("pg-primary.svc:5432").unwrap();
        assert_eq!(host, "pg-primary.svc");
        assert_eq!(port, 5432);
    }

    #[test]
    fn parse_addr_supports_ipv6_style_host() {
        let (host, port) = parse_addr("[::1]:5432").unwrap();
        assert_eq!(host, "[::1]");
        assert_eq!(port, 5432);
    }

    #[test]
    fn parse_addr_rejects_missing_port() {
        assert!(parse_addr("pg-primary.svc").is_err());
        assert!(parse_addr("pg-primary.svc:").is_err());
        assert!(parse_addr("pg-primary.svc:not-a-port").is_err());
    }

    #[test]
    fn quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("simple"), "\"simple\"");
        assert_eq!(quote_ident("a\"b"), "\"a\"\"b\"");
    }

    #[test]
    fn publication_name_uses_simple_uuid_form() {
        let name = publication_name(Uuid::nil());
        assert_eq!(name, "helios_upgrade_00000000000000000000000000000000");
        assert!(name.len() < 63);
    }

    #[test]
    fn source_conninfo_includes_credentials() {
        let rendered = source_conninfo(&template());
        assert!(rendered.contains("host=placeholder"));
        assert!(rendered.contains("port=0"));
        assert!(rendered.contains("user=postgres"));
        assert!(!rendered.contains("password="));
        assert!(!rendered.contains("dbname="));
    }
}