database_replicator/replication/publication.rs

1// ABOUTME: Publication management for logical replication on source database
2// ABOUTME: Creates and manages PostgreSQL publications for table replication
3
4use anyhow::{bail, Context, Result};
5use tokio_postgres::Client;
6
7use crate::filters::ReplicationFilter;
8use crate::table_rules::TableRuleKind;
9
10/// Create a publication for tables with optional filtering
11///
12/// When table filters are specified, creates a publication for only the filtered tables.
13/// Without filters, creates a publication for all tables.
14///
15/// # Arguments
16///
17/// * `client` - Connected client to the database
18/// * `db_name` - Name of the database (for filtering context)
19/// * `publication_name` - Name of the publication to create
20/// * `filter` - Replication filter for table inclusion/exclusion
21///
22/// # Returns
23///
24/// Returns `Ok(())` if publication is created or already exists
25pub async fn create_publication(
26    client: &Client,
27    db_name: &str,
28    publication_name: &str,
29    filter: &ReplicationFilter,
30) -> Result<()> {
31    // Validate publication name to prevent SQL injection
32    crate::utils::validate_postgres_identifier(publication_name).with_context(|| {
33        format!(
34            "Invalid publication name '{}': must be a valid PostgreSQL identifier",
35            publication_name
36        )
37    })?;
38
39    tracing::info!("Creating publication '{}'...", publication_name);
40
41    if filter.is_empty() {
42        let query = format!("CREATE PUBLICATION \"{}\" FOR ALL TABLES", publication_name);
43        return execute_publication_query(client, publication_name, &query).await;
44    }
45
46    let tables = crate::migration::list_tables(client).await?;
47
48    let mut plain_tables = Vec::new();
49    let mut predicate_tables = Vec::new();
50
51    for table in tables {
52        // Build "schema.table" identifier for include/exclude logic
53        let table_identifier = if table.schema == "public" {
54            table.name.clone()
55        } else {
56            format!("{}.{}", table.schema, table.name)
57        };
58
59        if !filter.should_replicate_table(db_name, &table_identifier) {
60            continue;
61        }
62
63        // Validate schema/table names
64        crate::utils::validate_postgres_identifier(&table.schema).with_context(|| {
65            format!(
66                "Invalid schema name '{}' for table '{}': must be a valid PostgreSQL identifier",
67                table.schema, table.name
68            )
69        })?;
70        crate::utils::validate_postgres_identifier(&table.name).with_context(|| {
71            format!(
72                "Invalid table name '{}' in schema '{}': must be a valid PostgreSQL identifier",
73                table.name, table.schema
74            )
75        })?;
76
77        let fq_table = format!("\"{}\".\"{}\"", table.schema, table.name);
78
79        match filter
80            .table_rules()
81            .rule_for_table(db_name, &table.schema, &table.name)
82        {
83            Some(TableRuleKind::SchemaOnly) => {
84                tracing::debug!(
85                    "Excluding table '{}' from publication (schema-only)",
86                    table_identifier
87                );
88            }
89            Some(TableRuleKind::Predicate(pred)) => {
90                predicate_tables.push((fq_table, pred));
91            }
92            None => {
93                plain_tables.push(fq_table);
94            }
95        }
96    }
97
98    if plain_tables.is_empty() && predicate_tables.is_empty() {
99        bail!(
100            "No tables available for publication '{}' after applying filters and schema-only rules",
101            publication_name
102        );
103    }
104
105    let has_predicates = !predicate_tables.is_empty();
106    let server_version = get_server_version(client).await?;
107    if has_predicates && server_version < 150000 {
108        bail!(
109            "Table-level predicates require PostgreSQL 15+. Detected server version {}.\n\
110             Upgrade the source database or remove --table-filter/--time-filter for logical replication.",
111            server_version
112        );
113    }
114
115    let mut clauses = Vec::new();
116    clauses.extend(plain_tables);
117    clauses.extend(
118        predicate_tables
119            .iter()
120            .map(|(table, predicate)| format!("{} WHERE ({})", table, predicate)),
121    );
122
123    let query = format!(
124        "CREATE PUBLICATION \"{}\" FOR TABLE {}",
125        publication_name,
126        clauses.join(", ")
127    );
128
129    execute_publication_query(client, publication_name, &query).await
130}
131
132async fn execute_publication_query(
133    client: &Client,
134    publication_name: &str,
135    query: &str,
136) -> Result<()> {
137    match client.execute(query, &[]).await {
138        Ok(_) => {
139            tracing::info!("✓ Publication '{}' created successfully", publication_name);
140            Ok(())
141        }
142        Err(e) => {
143            let err_str = e.to_string();
144            // Publication might already exist - that's okay
145            if err_str.contains("already exists") {
146                tracing::info!("✓ Publication '{}' already exists", publication_name);
147                Ok(())
148            } else if err_str.contains("permission denied") || err_str.contains("must be owner") {
149                anyhow::bail!(
150                    "Permission denied: Cannot create publication '{}'.\n\
151                     You need superuser or owner privileges on the database.\n\
152                     Grant with: GRANT CREATE ON DATABASE <dbname> TO <user>;\n\
153                     Error: {}",
154                    publication_name,
155                    err_str
156                )
157            } else if err_str.contains("wal_level") || err_str.contains("logical replication") {
158                anyhow::bail!(
159                    "Logical replication not enabled: Cannot create publication '{}'.\n\
160                     The database parameter 'wal_level' must be set to 'logical'.\n\
161                     Contact your database administrator to update postgresql.conf:\n\
162                     wal_level = logical\n\
163                     Error: {}",
164                    publication_name,
165                    err_str
166                )
167            } else {
168                anyhow::bail!(
169                    "Failed to create publication '{}': {}\n\
170                     \n\
171                     Common causes:\n\
172                     - Insufficient privileges (need CREATE privilege on database)\n\
173                     - Logical replication not enabled (wal_level must be 'logical')\n\
174                     - Database does not support publications",
175                    publication_name,
176                    err_str
177                )
178            }
179        }
180    }
181}
182
183async fn get_server_version(client: &Client) -> Result<i32> {
184    let row = client
185        .query_one("SHOW server_version_num", &[])
186        .await
187        .context("Failed to query server version")?;
188    let version_str: String = row.get(0);
189    version_str.parse::<i32>().with_context(|| {
190        format!(
191            "Invalid server_version_num '{}'. Expected integer.",
192            version_str
193        )
194    })
195}
196
197/// List all publications in the database
198pub async fn list_publications(client: &Client) -> Result<Vec<String>> {
199    let rows = client
200        .query("SELECT pubname FROM pg_publication ORDER BY pubname", &[])
201        .await
202        .context("Failed to list publications")?;
203
204    let publications: Vec<String> = rows.iter().map(|row| row.get(0)).collect();
205
206    Ok(publications)
207}
208
209/// Drop a publication
210pub async fn drop_publication(client: &Client, publication_name: &str) -> Result<()> {
211    // Validate publication name to prevent SQL injection
212    crate::utils::validate_postgres_identifier(publication_name).with_context(|| {
213        format!(
214            "Invalid publication name '{}': must be a valid PostgreSQL identifier",
215            publication_name
216        )
217    })?;
218
219    tracing::info!("Dropping publication '{}'...", publication_name);
220
221    let query = format!("DROP PUBLICATION IF EXISTS \"{}\"", publication_name);
222
223    client
224        .execute(&query, &[])
225        .await
226        .context(format!("Failed to drop publication '{}'", publication_name))?;
227
228    tracing::info!("✓ Publication '{}' dropped", publication_name);
229    Ok(())
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::postgres::connect;

    /// Integration test: create an unfiltered publication, confirm it appears in
    /// `list_publications`, then drop it. Needs `TEST_SOURCE_URL`; ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_create_and_list_publications() {
        let client = connect(&std::env::var("TEST_SOURCE_URL").unwrap())
            .await
            .unwrap();

        let pub_name = "test_publication";
        // Assume testing on postgres database
        let db_name = "postgres";
        let filter = ReplicationFilter::empty();

        // Best-effort cleanup left over from an earlier run.
        let _ = drop_publication(&client, pub_name).await;

        let outcome = create_publication(&client, db_name, pub_name, &filter).await;
        if let Err(err) = &outcome {
            println!("Error creating publication: {:?}", err);
            let msg = err.to_string();
            // If Neon doesn't support publications, skip rest of test
            if msg.contains("not supported") || msg.contains("permission") {
                println!("Skipping test - Neon might not support publications on pooler");
                return;
            }
        } else {
            println!("✓ Publication created successfully");
        }
        assert!(outcome.is_ok(), "Failed to create publication");

        let pubs = list_publications(&client).await.unwrap();
        println!("Publications: {:?}", pubs);
        assert!(pubs.iter().any(|p| p == pub_name));

        drop_publication(&client, pub_name).await.unwrap();
    }

    /// Integration test: a created publication no longer shows up after being dropped.
    /// Needs `TEST_SOURCE_URL`; ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_drop_publication() {
        let client = connect(&std::env::var("TEST_SOURCE_URL").unwrap())
            .await
            .unwrap();

        let pub_name = "test_drop_publication";
        let db_name = "postgres";
        let filter = ReplicationFilter::empty();

        create_publication(&client, db_name, pub_name, &filter)
            .await
            .unwrap();

        let dropped = drop_publication(&client, pub_name).await;
        assert!(dropped.is_ok());

        let remaining = list_publications(&client).await.unwrap();
        assert!(remaining.iter().all(|p| p != pub_name));
    }
}