database_replicator/replication/subscription.rs

1// ABOUTME: Subscription management for logical replication on target database
2// ABOUTME: Creates and manages PostgreSQL subscriptions to receive replicated data
3
4use anyhow::{Context, Result};
5use std::time::Duration;
6use tokio_postgres::Client;
7
8/// Create a subscription to a publication on the source database
9pub async fn create_subscription(
10    client: &Client,
11    subscription_name: &str,
12    source_connection_string: &str,
13    publication_name: &str,
14) -> Result<()> {
15    // Validate subscription name to prevent SQL injection
16    crate::utils::validate_postgres_identifier(subscription_name).with_context(|| {
17        format!(
18            "Invalid subscription name '{}': must be a valid PostgreSQL identifier",
19            subscription_name
20        )
21    })?;
22
23    // Validate publication name to prevent SQL injection
24    crate::utils::validate_postgres_identifier(publication_name).with_context(|| {
25        format!(
26            "Invalid publication name '{}': must be a valid PostgreSQL identifier",
27            publication_name
28        )
29    })?;
30
31    tracing::info!("Creating subscription '{}'...", subscription_name);
32
33    // SECURITY NOTE: PostgreSQL subscriptions store connection strings (including passwords)
34    // in the pg_subscription system catalog, visible to users with access to that table.
35    //
36    // To avoid storing passwords in the catalog:
37    // 1. Configure .pgpass file on the TARGET PostgreSQL server
38    // 2. Use password-less connection string (omit password from URL)
39    // 3. The subscription will read credentials from .pgpass
40    //
41    // See: https://www.postgresql.org/docs/current/libpq-pgpass.html
42    //
43    // For now, we use the provided connection string as-is for compatibility.
44    // Users concerned about password exposure should configure .pgpass on the target server.
45
46    tracing::warn!(
47        "⚠ Security Note: Subscription connection strings are stored in pg_subscription catalog"
48    );
49    tracing::warn!(
50        "  To avoid storing passwords, configure .pgpass on the target PostgreSQL server"
51    );
52
53    let query = format!(
54        "CREATE SUBSCRIPTION \"{}\" CONNECTION '{}' PUBLICATION \"{}\"",
55        subscription_name, source_connection_string, publication_name
56    );
57
58    match client.execute(&query, &[]).await {
59        Ok(_) => {
60            tracing::info!(
61                "✓ Subscription '{}' created successfully",
62                subscription_name
63            );
64            Ok(())
65        }
66        Err(e) => {
67            let err_str = e.to_string();
68            // Subscription might already exist - that's okay
69            if err_str.contains("already exists") {
70                tracing::info!("✓ Subscription '{}' already exists", subscription_name);
71                Ok(())
72            } else if err_str.contains("permission denied") || err_str.contains("must be superuser")
73            {
74                anyhow::bail!(
75                    "Permission denied: Cannot create subscription '{}'.\n\
76                     Only superusers can create subscriptions in PostgreSQL.\n\
77                     Contact your database administrator to:\n\
78                     1. Grant superuser: ALTER ROLE <user> WITH SUPERUSER;\n\
79                     2. Or create the subscription on your behalf\n\
80                     Error: {}",
81                    subscription_name,
82                    err_str
83                )
84            } else if err_str.contains("publication") && err_str.contains("does not exist") {
85                anyhow::bail!(
86                    "Publication does not exist: Cannot create subscription '{}'.\n\
87                     The publication '{}' was not found on the source database.\n\
88                     Make sure the publication exists before creating the subscription.\n\
89                     Error: {}",
90                    subscription_name,
91                    publication_name,
92                    err_str
93                )
94            } else if err_str.contains("could not connect to the publisher")
95                || err_str.contains("connection")
96            {
97                anyhow::bail!(
98                    "Connection failed: Cannot connect to source database for subscription '{}'.\n\
99                     Please verify:\n\
100                     - The source database is accessible from the target\n\
101                     - The connection string is correct\n\
102                     - Firewall rules allow connections\n\
103                     - The source user has REPLICATION privilege\n\
104                     Error: {}",
105                    subscription_name,
106                    err_str
107                )
108            } else if err_str.contains("replication slot") {
109                anyhow::bail!(
110                    "Replication slot error: Cannot create subscription '{}'.\n\
111                     The source database may have reached the maximum number of replication slots.\n\
112                     Check 'max_replication_slots' on the source database.\n\
113                     Error: {}",
114                    subscription_name,
115                    err_str
116                )
117            } else {
118                anyhow::bail!(
119                    "Failed to create subscription '{}': {}\n\
120                     \n\
121                     Common causes:\n\
122                     - Insufficient privileges (need SUPERUSER on target)\n\
123                     - Publication does not exist on source\n\
124                     - Cannot connect to source database\n\
125                     - max_replication_slots limit reached on source",
126                    subscription_name,
127                    err_str
128                )
129            }
130        }
131    }
132}
133
134/// List all subscriptions in the database
135pub async fn list_subscriptions(client: &Client) -> Result<Vec<String>> {
136    let rows = client
137        .query("SELECT subname FROM pg_subscription ORDER BY subname", &[])
138        .await
139        .context("Failed to list subscriptions")?;
140
141    let subscriptions: Vec<String> = rows.iter().map(|row| row.get(0)).collect();
142
143    Ok(subscriptions)
144}
145
146/// Drop a subscription
147pub async fn drop_subscription(client: &Client, subscription_name: &str) -> Result<()> {
148    // Validate subscription name to prevent SQL injection
149    crate::utils::validate_postgres_identifier(subscription_name).with_context(|| {
150        format!(
151            "Invalid subscription name '{}': must be a valid PostgreSQL identifier",
152            subscription_name
153        )
154    })?;
155
156    tracing::info!("Dropping subscription '{}'...", subscription_name);
157
158    let query = format!("DROP SUBSCRIPTION IF EXISTS \"{}\"", subscription_name);
159
160    client.execute(&query, &[]).await.context(format!(
161        "Failed to drop subscription '{}'",
162        subscription_name
163    ))?;
164
165    tracing::info!("✓ Subscription '{}' dropped", subscription_name);
166    Ok(())
167}
168
/// Coarse replication state of a subscription.
///
/// The single-letter codes noted on the variants are the sync-state codes each
/// variant stands for.
#[derive(Clone, Debug, PartialEq)]
pub enum SubscriptionState {
    /// Changes are being streamed (code `r`).
    Streaming,
    /// Sync has not started yet (code `i`).
    Initializing,
    /// Table data is being copied (code `d`).
    Copying,
    /// The subscription is synchronizing (code `s`).
    Syncing,
    /// An error or unrecognised state; carries the raw state text.
    Error(String),
    /// No subscription with the requested name exists.
    NotFound,
}
185
186/// Detect the current state of a subscription
187pub async fn detect_subscription_state(
188    client: &Client,
189    subscription_name: &str,
190) -> Result<SubscriptionState> {
191    // Query pg_stat_subscription to get subscription state
192    let rows = client
193        .query(
194            "SELECT srsubstate FROM pg_stat_subscription WHERE subname = $1",
195            &[&subscription_name],
196        )
197        .await
198        .context(format!(
199            "Failed to query subscription status for '{}'",
200            subscription_name
201        ))?;
202
203    if rows.is_empty() {
204        return Ok(SubscriptionState::NotFound);
205    }
206
207    let state: String = rows[0].get(0);
208
209    match state.as_str() {
210        "r" => Ok(SubscriptionState::Streaming),
211        "i" => Ok(SubscriptionState::Initializing),
212        "d" => Ok(SubscriptionState::Copying),
213        "s" => Ok(SubscriptionState::Syncing),
214        other => Ok(SubscriptionState::Error(other.to_string())),
215    }
216}
217
218/// Wait for subscription to complete initial sync and enter streaming state
219/// Returns when subscription reaches 'r' (ready/streaming) state
220pub async fn wait_for_sync(
221    client: &Client,
222    subscription_name: &str,
223    timeout_secs: u64,
224) -> Result<()> {
225    tracing::info!(
226        "Waiting for subscription '{}' to sync...",
227        subscription_name
228    );
229
230    let start = std::time::Instant::now();
231    let timeout = Duration::from_secs(timeout_secs);
232
233    loop {
234        let row = client
235            .query_one(
236                "SELECT srsubstate FROM pg_stat_subscription WHERE subname = $1",
237                &[&subscription_name],
238            )
239            .await
240            .context(format!(
241                "Failed to query subscription status for '{}'",
242                subscription_name
243            ))?;
244
245        let state: String = row.get(0);
246
247        match state.as_str() {
248            "r" => {
249                tracing::info!(
250                    "✓ Subscription '{}' is ready and streaming",
251                    subscription_name
252                );
253                return Ok(());
254            }
255            "i" => {
256                tracing::info!("Subscription '{}' is initializing...", subscription_name);
257            }
258            "d" => {
259                tracing::info!("Subscription '{}' is copying data...", subscription_name);
260            }
261            "s" => {
262                tracing::info!("Subscription '{}' is syncing...", subscription_name);
263            }
264            _ => {
265                tracing::warn!(
266                    "Subscription '{}' in unexpected state: {}",
267                    subscription_name,
268                    state
269                );
270            }
271        }
272
273        if start.elapsed() > timeout {
274            anyhow::bail!(
275                "Timeout waiting for subscription '{}' to sync after {} seconds.\n\
276                 The subscription is in state '{}' and has not reached 'ready' (streaming) state.\n\
277                 \n\
278                 Possible causes:\n\
279                 - Large database taking longer than expected to copy\n\
280                 - Network issues slowing down data transfer\n\
281                 - Source database under heavy load\n\
282                 \n\
283                 Suggestions:\n\
284                 - Increase the timeout value and try again\n\
285                 - Check replication status with 'status' command\n\
286                 - Monitor source database load and network connectivity",
287                subscription_name,
288                timeout_secs,
289                state
290            );
291        }
292
293        tokio::time::sleep(Duration::from_secs(2)).await;
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    #[tokio::test]
    #[ignore]
    async fn test_create_and_list_subscriptions() {
        // This test requires two databases: source and target
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let target_url = std::env::var("TEST_TARGET_URL").unwrap();

        let source_client = connect(&source_url).await.unwrap();
        let target_client = connect(&target_url).await.unwrap();

        let sub_name = "test_subscription";
        let pub_name = "test_publication";
        let db_name = "postgres"; // Assume testing on postgres database
        let filter = crate::filters::ReplicationFilter::empty();

        // Create publication on source
        crate::replication::create_publication(&source_client, db_name, pub_name, &filter)
            .await
            .unwrap();

        // Clean up subscription if exists
        let _ = drop_subscription(&target_client, sub_name).await;

        // Create subscription on target
        let result = create_subscription(&target_client, sub_name, &source_url, pub_name).await;
        match &result {
            Ok(_) => println!("✓ Subscription created successfully"),
            Err(e) => {
                println!("Error creating subscription: {:?}", e);
                // If target doesn't support subscriptions, skip rest of test
                let msg = e.to_string();
                if msg.contains("not supported") || msg.contains("permission") {
                    println!("Skipping test - target might not support subscriptions");
                    // Best-effort cleanup so the source publication doesn't leak into later runs.
                    let _ = crate::replication::drop_publication(&source_client, pub_name).await;
                    return;
                }
            }
        }
        assert!(result.is_ok(), "Failed to create subscription");

        // List subscriptions
        let subs = list_subscriptions(&target_client).await.unwrap();
        println!("Subscriptions: {:?}", subs);
        assert!(subs.contains(&sub_name.to_string()));

        // Clean up
        drop_subscription(&target_client, sub_name).await.unwrap();
        crate::replication::drop_publication(&source_client, pub_name)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[ignore]
    async fn test_drop_subscription() {
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let target_url = std::env::var("TEST_TARGET_URL").unwrap();

        let source_client = connect(&source_url).await.unwrap();
        let target_client = connect(&target_url).await.unwrap();

        let sub_name = "test_drop_subscription";
        let pub_name = "test_drop_publication";
        let db_name = "postgres";
        let filter = crate::filters::ReplicationFilter::empty();

        // Create publication on source
        crate::replication::create_publication(&source_client, db_name, pub_name, &filter)
            .await
            .unwrap();

        // Clean up a subscription left behind by an earlier failed run
        let _ = drop_subscription(&target_client, sub_name).await;

        // Create subscription on target
        create_subscription(&target_client, sub_name, &source_url, pub_name)
            .await
            .unwrap();

        // Drop it
        let result = drop_subscription(&target_client, sub_name).await;
        assert!(result.is_ok());

        // Verify it's gone
        let subs = list_subscriptions(&target_client).await.unwrap();
        assert!(!subs.contains(&sub_name.to_string()));

        // Clean up publication
        crate::replication::drop_publication(&source_client, pub_name)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[ignore]
    async fn test_wait_for_sync() {
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let target_url = std::env::var("TEST_TARGET_URL").unwrap();

        let source_client = connect(&source_url).await.unwrap();
        let target_client = connect(&target_url).await.unwrap();

        let sub_name = "test_wait_subscription";
        let pub_name = "test_wait_publication";
        let db_name = "postgres";
        let filter = crate::filters::ReplicationFilter::empty();

        // Create publication on source
        crate::replication::create_publication(&source_client, db_name, pub_name, &filter)
            .await
            .unwrap();

        // Clean up subscription if exists
        let _ = drop_subscription(&target_client, sub_name).await;

        // Create subscription on target
        create_subscription(&target_client, sub_name, &source_url, pub_name)
            .await
            .unwrap();

        // Wait for sync (30 second timeout)
        let result = wait_for_sync(&target_client, sub_name, 30).await;
        assert!(result.is_ok(), "Failed to wait for sync: {:?}", result);

        // Clean up
        drop_subscription(&target_client, sub_name).await.unwrap();
        crate::replication::drop_publication(&source_client, pub_name)
            .await
            .unwrap();
    }
}