// database_replicator/replication/monitor.rs

1// ABOUTME: Replication monitoring utilities
2// ABOUTME: Queries replication status and lag from source and target databases
3
4use anyhow::{Context, Result};
5use tokio_postgres::Client;
6
/// Replication statistics from the source database (publisher).
///
/// One instance corresponds to one row of `pg_stat_replication`.
/// LSNs are carried in their textual (`X/Y`) representation, as produced
/// by the `::text` casts in [`get_replication_lag`].
#[derive(Debug, Clone)]
pub struct SourceReplicationStats {
    /// Walsender's `application_name`; logical-replication subscribers
    /// report their subscription name here (see the filtered query in
    /// `get_replication_lag`).
    pub application_name: String,
    /// Walsender state as reported by the server (e.g. "streaming").
    pub state: String,
    /// Last LSN sent on this connection, as text.
    pub sent_lsn: String,
    /// Last LSN written to disk by the standby/subscriber, as text.
    pub write_lsn: String,
    /// Last LSN flushed to disk by the standby/subscriber, as text.
    pub flush_lsn: String,
    /// Last LSN replayed/applied by the standby/subscriber, as text.
    pub replay_lsn: String,
    /// Write lag in milliseconds; `None` when the server reports NULL
    /// (e.g. not currently streaming).
    pub write_lag_ms: Option<i64>,
    /// Flush lag in milliseconds; `None` when the server reports NULL.
    pub flush_lag_ms: Option<i64>,
    /// Replay (apply) lag in milliseconds; `None` when the server reports NULL.
    pub replay_lag_ms: Option<i64>,
}
20
/// Subscription statistics from the target database (subscriber).
///
/// Populated from `pg_stat_subscription` by [`get_subscription_status`].
#[derive(Debug, Clone)]
pub struct SubscriptionStats {
    /// Subscription name (`subname`).
    pub subscription_name: String,
    /// PID of the subscription worker; `None` when no worker is running.
    pub pid: Option<i32>,
    /// Last LSN received from the publisher, as text; `None` when NULL.
    pub received_lsn: Option<String>,
    /// Last end-LSN reported back to the publisher, as text; `None` when NULL.
    pub latest_end_lsn: Option<String>,
    /// Subscription/table-sync state summary (see `get_subscription_status`
    /// for how this is derived).
    pub state: String,
}
30
31/// Get replication statistics from the source database
32/// Queries pg_stat_replication to see what's being replicated to subscribers
33pub async fn get_replication_lag(
34    client: &Client,
35    subscription_name: Option<&str>,
36) -> Result<Vec<SourceReplicationStats>> {
37    let query = if let Some(sub_name) = subscription_name {
38        format!(
39            "SELECT
40                application_name,
41                state,
42                sent_lsn::text,
43                write_lsn::text,
44                flush_lsn::text,
45                replay_lsn::text,
46                EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
47                EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
48                EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
49            FROM pg_stat_replication
50            WHERE application_name = '{}'",
51            sub_name
52        )
53    } else {
54        "SELECT
55            application_name,
56            state,
57            sent_lsn::text,
58            write_lsn::text,
59            flush_lsn::text,
60            replay_lsn::text,
61            EXTRACT(EPOCH FROM write_lag) * 1000 as write_lag_ms,
62            EXTRACT(EPOCH FROM flush_lag) * 1000 as flush_lag_ms,
63            EXTRACT(EPOCH FROM replay_lag) * 1000 as replay_lag_ms
64        FROM pg_stat_replication"
65            .to_string()
66    };
67
68    let rows = client
69        .query(&query, &[])
70        .await
71        .context("Failed to query replication statistics")?;
72
73    let mut stats = Vec::new();
74    for row in rows {
75        stats.push(SourceReplicationStats {
76            application_name: row.get(0),
77            state: row.get(1),
78            sent_lsn: row.get(2),
79            write_lsn: row.get(3),
80            flush_lsn: row.get(4),
81            replay_lsn: row.get(5),
82            write_lag_ms: row.get(6),
83            flush_lag_ms: row.get(7),
84            replay_lag_ms: row.get(8),
85        });
86    }
87
88    Ok(stats)
89}
90
91/// Get subscription status from the target database
92/// Queries pg_stat_subscription to see subscription state and progress
93pub async fn get_subscription_status(
94    client: &Client,
95    subscription_name: Option<&str>,
96) -> Result<Vec<SubscriptionStats>> {
97    let query = if let Some(sub_name) = subscription_name {
98        format!(
99            "SELECT
100                subname,
101                pid,
102                received_lsn::text,
103                latest_end_lsn::text,
104                srsubstate
105            FROM pg_stat_subscription
106            WHERE subname = '{}'",
107            sub_name
108        )
109    } else {
110        "SELECT
111            subname,
112            pid,
113            received_lsn::text,
114            latest_end_lsn::text,
115            srsubstate
116        FROM pg_stat_subscription"
117            .to_string()
118    };
119
120    let rows = client
121        .query(&query, &[])
122        .await
123        .context("Failed to query subscription statistics")?;
124
125    let mut stats = Vec::new();
126    for row in rows {
127        stats.push(SubscriptionStats {
128            subscription_name: row.get(0),
129            pid: row.get(1),
130            received_lsn: row.get(2),
131            latest_end_lsn: row.get(3),
132            state: row.get(4),
133        });
134    }
135
136    Ok(stats)
137}
138
139/// Check if replication is caught up (no lag)
140/// Returns true if all replication slots have < 1 second of replay lag
141pub async fn is_replication_caught_up(
142    client: &Client,
143    subscription_name: Option<&str>,
144) -> Result<bool> {
145    let stats = get_replication_lag(client, subscription_name).await?;
146
147    if stats.is_empty() {
148        // No active replication
149        return Ok(false);
150    }
151
152    for stat in stats {
153        // Check if replay lag is > 1000ms (1 second)
154        if let Some(lag_ms) = stat.replay_lag_ms {
155            if lag_ms > 1000 {
156                return Ok(false);
157            }
158        } else {
159            // If lag is NULL, it might be too far behind or not streaming yet
160            return Ok(false);
161        }
162    }
163
164    Ok(true)
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    /// True for errors we tolerate in environments where replication is not
    /// set up: the monitoring view is missing or we lack privileges to read
    /// it. (Matches the substrings the original tests checked for.)
    fn is_tolerated(e: &anyhow::Error) -> bool {
        let msg = e.to_string();
        msg.contains("relation") || msg.contains("permission")
    }

    // BUG FIX (applies to the three tests below): the old code printed a
    // "it's okay if ..." message for tolerated errors but then ran
    // `assert!(result.is_ok())` unconditionally, so the tolerance branch was
    // dead and tolerated errors still failed the test. Tolerated errors now
    // skip instead of failing.

    #[tokio::test]
    #[ignore]
    async fn test_get_replication_lag() {
        // Requires a source database reachable via TEST_SOURCE_URL.
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let client = connect(&source_url).await.unwrap();

        match get_replication_lag(&client, None).await {
            Ok(stats) => {
                println!("✓ Replication lag query succeeded");
                println!("Found {} replication slots", stats.len());
                for stat in stats {
                    println!(
                        "  - {}: {} (replay lag: {:?}ms)",
                        stat.application_name, stat.state, stat.replay_lag_ms
                    );
                }
            }
            Err(e) if is_tolerated(&e) => {
                println!("Skipping: replication not available: {:?}", e);
            }
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_subscription_status() {
        // Requires a target database reachable via TEST_TARGET_URL.
        let target_url = std::env::var("TEST_TARGET_URL").unwrap();
        let client = connect(&target_url).await.unwrap();

        match get_subscription_status(&client, None).await {
            Ok(stats) => {
                println!("✓ Subscription status query succeeded");
                println!("Found {} subscriptions", stats.len());
                for stat in stats {
                    println!(
                        "  - {}: state={} (pid: {:?})",
                        stat.subscription_name, stat.state, stat.pid
                    );
                }
            }
            Err(e) if is_tolerated(&e) => {
                println!("Skipping: subscriptions not available: {:?}", e);
            }
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_is_replication_caught_up() {
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let client = connect(&source_url).await.unwrap();

        match is_replication_caught_up(&client, None).await {
            Ok(caught_up) => {
                println!("✓ Caught up check succeeded: {}", caught_up);
            }
            Err(e) if is_tolerated(&e) => {
                println!("Skipping: replication not available: {:?}", e);
            }
            Err(e) => panic!("Unexpected error: {:?}", e),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_get_replication_lag_with_name() {
        let source_url = std::env::var("TEST_SOURCE_URL").unwrap();
        let client = connect(&source_url).await.unwrap();

        // Query for a specific subscription name; an empty result is fine —
        // only a query failure is an error here.
        match get_replication_lag(&client, Some("seren_migration_sub")).await {
            Ok(stats) => {
                println!("✓ Named replication lag query succeeded");
                println!("Found {} matching replication slots", stats.len());
            }
            Err(e) => panic!("Error querying named replication lag: {:?}", e),
        }
    }
}