// File: connection.rs (crate drasi_source_postgres)

// Copyright 2025 The Drasi Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15use anyhow::{anyhow, Result};
16use bytes::{Buf, BytesMut};
17use log::{debug, info, trace, warn};
18use std::collections::HashMap;
19use tokio::io::{AsyncReadExt, AsyncWriteExt};
20use tokio::net::TcpStream;
21
22use super::protocol::{
23    parse_backend_message, AuthenticationMessage, BackendMessage, FrontendMessage, StartupMessage,
24    TransactionStatus,
25};
26use super::scram::ScramClient;
27use super::types::{ReplicationSlotInfo, StandbyStatusUpdate};
28
29pub struct ReplicationConnection {
30    stream: TcpStream,
31    read_buffer: BytesMut,
32    write_buffer: BytesMut,
33    parameters: HashMap<String, String>,
34    process_id: Option<i32>,
35    secret_key: Option<i32>,
36    transaction_status: TransactionStatus,
37    in_copy_mode: bool,
38}
39
40impl ReplicationConnection {
41    pub async fn connect(
42        host: &str,
43        port: u16,
44        database: &str,
45        user: &str,
46        password: &str,
47    ) -> Result<Self> {
48        info!("Connecting to PostgreSQL at {host}:{port}");
49
50        let stream = TcpStream::connect((host, port)).await?;
51        stream.set_nodelay(true)?;
52
53        let mut conn = Self {
54            stream,
55            read_buffer: BytesMut::with_capacity(8192),
56            write_buffer: BytesMut::with_capacity(8192),
57            parameters: HashMap::new(),
58            process_id: None,
59            secret_key: None,
60            transaction_status: TransactionStatus::Idle,
61            in_copy_mode: false,
62        };
63
64        conn.startup_replication(database, user, password).await?;
65
66        Ok(conn)
67    }
68
69    async fn startup_replication(
70        &mut self,
71        database: &str,
72        user: &str,
73        password: &str,
74    ) -> Result<()> {
75        debug!("Starting replication protocol handshake");
76
77        // Send startup message
78        let startup = StartupMessage::new_replication(database, user);
79        self.send_message(FrontendMessage::StartupMessage(startup))
80            .await?;
81
82        // Handle authentication
83        loop {
84            let msg = self.read_message().await?;
85            match msg {
86                BackendMessage::Authentication(auth) => {
87                    match auth {
88                        AuthenticationMessage::Ok => {
89                            debug!("Authentication successful");
90                            break;
91                        }
92                        AuthenticationMessage::CleartextPassword => {
93                            debug!("Server requested cleartext password");
94                            self.send_message(FrontendMessage::PasswordMessage(
95                                password.to_string(),
96                            ))
97                            .await?;
98                        }
99                        AuthenticationMessage::MD5Password(_) => {
100                            return Err(anyhow!(
101                                "MD5 authentication is not supported (insecure). \
102                                 Please configure PostgreSQL to use scram-sha-256 in pg_hba.conf"
103                            ));
104                        }
105                        AuthenticationMessage::SASL(mechanisms) => {
106                            if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
107                                debug!("Server requested SCRAM-SHA-256 authentication");
108                                let mut scram_client = ScramClient::new(user, password);
109
110                                // Send SASLInitialResponse
111                                let client_first = scram_client.client_first_message();
112                                self.send_sasl_initial_response("SCRAM-SHA-256", &client_first)
113                                    .await?;
114
115                                // Continue SASL exchange
116                                loop {
117                                    let sasl_msg = self.read_message().await?;
118                                    match sasl_msg {
119                                        BackendMessage::Authentication(
120                                            AuthenticationMessage::SASLContinue(data),
121                                        ) => {
122                                            let server_first = String::from_utf8_lossy(&data);
123                                            scram_client
124                                                .process_server_first_message(&server_first)?;
125
126                                            let client_final =
127                                                scram_client.client_final_message()?;
128                                            self.send_sasl_response(&client_final).await?;
129                                        }
130                                        BackendMessage::Authentication(
131                                            AuthenticationMessage::SASLFinal(data),
132                                        ) => {
133                                            let server_final = String::from_utf8_lossy(&data);
134                                            scram_client.verify_server_final(&server_final)?;
135                                            debug!("SCRAM-SHA-256 authentication successful");
136                                            break;
137                                        }
138                                        BackendMessage::ErrorResponse(err) => {
139                                            return Err(anyhow!(
140                                                "SASL authentication failed: {}",
141                                                err.message
142                                            ));
143                                        }
144                                        _ => {
145                                            warn!("Unexpected message during SASL: {sasl_msg:?}");
146                                        }
147                                    }
148                                }
149                            } else {
150                                return Err(anyhow!("No supported SASL mechanisms"));
151                            }
152                        }
153                        _ => {
154                            return Err(anyhow!("Unsupported authentication method"));
155                        }
156                    }
157                }
158                BackendMessage::ErrorResponse(err) => {
159                    return Err(anyhow!("Authentication failed: {}", err.message));
160                }
161                _ => {
162                    warn!("Unexpected message during authentication: {msg:?}");
163                }
164            }
165        }
166
167        // Wait for ReadyForQuery
168        loop {
169            let msg = self.read_message().await?;
170            match msg {
171                BackendMessage::BackendKeyData {
172                    process_id,
173                    secret_key,
174                } => {
175                    self.process_id = Some(process_id);
176                    self.secret_key = Some(secret_key);
177                    debug!("Received backend key data: pid={process_id}");
178                }
179                BackendMessage::ParameterStatus { name, value } => {
180                    debug!("Parameter: {name} = {value}");
181                    self.parameters.insert(name, value);
182                }
183                BackendMessage::ReadyForQuery(status) => {
184                    self.transaction_status = status;
185                    debug!("Connection ready, status: {status:?}");
186                    break;
187                }
188                BackendMessage::ErrorResponse(err) => {
189                    return Err(anyhow!("Startup failed: {}", err.message));
190                }
191                BackendMessage::NoticeResponse(notice) => {
192                    info!("Notice: {}", notice.message);
193                }
194                _ => {
195                    warn!("Unexpected message during startup: {msg:?}");
196                }
197            }
198        }
199
200        Ok(())
201    }
202
203    pub async fn identify_system(&mut self) -> Result<HashMap<String, String>> {
204        debug!("Sending IDENTIFY_SYSTEM command");
205
206        self.send_message(FrontendMessage::Query("IDENTIFY_SYSTEM".to_string()))
207            .await?;
208
209        let mut system_info = HashMap::new();
210
211        loop {
212            let msg = self.read_message().await?;
213            match msg {
214                BackendMessage::RowDescription(_) => {
215                    // Skip row description
216                }
217                BackendMessage::DataRow(row) => {
218                    // Parse system identification
219                    if row.len() >= 4 {
220                        if let Some(Some(systemid)) = row.first() {
221                            system_info.insert(
222                                "systemid".to_string(),
223                                String::from_utf8_lossy(systemid).to_string(),
224                            );
225                        }
226                        if let Some(Some(timeline)) = row.get(1) {
227                            system_info.insert(
228                                "timeline".to_string(),
229                                String::from_utf8_lossy(timeline).to_string(),
230                            );
231                        }
232                        if let Some(Some(xlogpos)) = row.get(2) {
233                            system_info.insert(
234                                "xlogpos".to_string(),
235                                String::from_utf8_lossy(xlogpos).to_string(),
236                            );
237                        }
238                        if let Some(Some(dbname)) = row.get(3) {
239                            system_info.insert(
240                                "dbname".to_string(),
241                                String::from_utf8_lossy(dbname).to_string(),
242                            );
243                        }
244                    }
245                }
246                BackendMessage::CommandComplete(_) => {
247                    // Command completed
248                }
249                BackendMessage::ReadyForQuery(status) => {
250                    self.transaction_status = status;
251                    break;
252                }
253                BackendMessage::ErrorResponse(err) => {
254                    return Err(anyhow!("IDENTIFY_SYSTEM failed: {}", err.message));
255                }
256                _ => {
257                    warn!("Unexpected message during IDENTIFY_SYSTEM: {msg:?}");
258                }
259            }
260        }
261
262        Ok(system_info)
263    }
264
265    pub async fn create_replication_slot(
266        &mut self,
267        slot_name: &str,
268        temporary: bool,
269    ) -> Result<ReplicationSlotInfo> {
270        debug!("Creating replication slot: {slot_name}");
271
272        let query = if temporary {
273            format!("CREATE_REPLICATION_SLOT {slot_name} TEMPORARY LOGICAL pgoutput")
274        } else {
275            format!("CREATE_REPLICATION_SLOT {slot_name} LOGICAL pgoutput")
276        };
277
278        self.send_message(FrontendMessage::Query(query)).await?;
279
280        let mut slot_info = ReplicationSlotInfo {
281            slot_name: slot_name.to_string(),
282            consistent_point: String::new(),
283            snapshot_name: None,
284            output_plugin: "pgoutput".to_string(),
285            restart_lsn: None,
286        };
287
288        loop {
289            let msg = self.read_message().await?;
290            match msg {
291                BackendMessage::RowDescription(_) => {
292                    // Skip row description
293                }
294                BackendMessage::DataRow(row) => {
295                    // Parse slot creation result
296                    if row.len() >= 4 {
297                        if let Some(Some(consistent_point)) = row.get(1) {
298                            slot_info.consistent_point =
299                                String::from_utf8_lossy(consistent_point).to_string();
300                        }
301                        if let Some(Some(snapshot_name)) = row.get(2) {
302                            slot_info.snapshot_name =
303                                Some(String::from_utf8_lossy(snapshot_name).to_string());
304                        }
305                    }
306                }
307                BackendMessage::CommandComplete(_) => {
308                    // Command completed
309                }
310                BackendMessage::ReadyForQuery(status) => {
311                    self.transaction_status = status;
312                    break;
313                }
314                BackendMessage::ErrorResponse(err) => {
315                    if err.message.contains("already exists") {
316                        debug!("Replication slot already exists: {slot_name}");
317                        // Drain the ReadyForQuery that PostgreSQL sends after ErrorResponse
318                        loop {
319                            let drain_msg = self.read_message().await?;
320                            if let BackendMessage::ReadyForQuery(status) = drain_msg {
321                                self.transaction_status = status;
322                                break;
323                            }
324                        }
325                        return self.get_replication_slot_info(slot_name).await;
326                    }
327                    return Err(anyhow!("CREATE_REPLICATION_SLOT failed: {}", err.message));
328                }
329                _ => {
330                    warn!("Unexpected message during CREATE_REPLICATION_SLOT: {msg:?}");
331                }
332            }
333        }
334
335        Ok(slot_info)
336    }
337
338    pub async fn get_replication_slot_info(
339        &mut self,
340        slot_name: &str,
341    ) -> Result<ReplicationSlotInfo> {
342        debug!("Querying existing replication slot: {slot_name}");
343
344        let slot_name_escaped = slot_name.replace('\'', "''");
345        let query = format!(
346            "SELECT slot_name, confirmed_flush_lsn, restart_lsn, plugin FROM pg_replication_slots WHERE slot_name = '{slot_name_escaped}'"
347        );
348
349        self.send_message(FrontendMessage::Query(query)).await?;
350
351        let mut slot_info = ReplicationSlotInfo {
352            slot_name: slot_name.to_string(),
353            consistent_point: "0/0".to_string(),
354            snapshot_name: None,
355            output_plugin: "pgoutput".to_string(),
356            restart_lsn: None,
357        };
358        let mut found_row = false;
359
360        loop {
361            let msg = self.read_message().await?;
362            match msg {
363                BackendMessage::RowDescription(_) => {
364                    // Skip row description
365                }
366                BackendMessage::DataRow(row) => {
367                    found_row = true;
368                    if row.len() >= 4 {
369                        if let Some(Some(confirmed_flush_lsn)) = row.get(1) {
370                            let lsn = String::from_utf8_lossy(confirmed_flush_lsn).to_string();
371                            if !lsn.is_empty() {
372                                slot_info.consistent_point = lsn;
373                            }
374                        }
375                        if let Some(Some(restart_lsn_val)) = row.get(2) {
376                            let lsn = String::from_utf8_lossy(restart_lsn_val).to_string();
377                            if !lsn.is_empty() {
378                                slot_info.restart_lsn = Some(lsn.clone());
379                                // Fall back to restart_lsn for consistent_point if confirmed_flush is unset
380                                if slot_info.consistent_point == "0/0" {
381                                    slot_info.consistent_point = lsn;
382                                }
383                            }
384                        }
385                        if let Some(Some(plugin)) = row.get(3) {
386                            slot_info.output_plugin = String::from_utf8_lossy(plugin).to_string();
387                        }
388                    }
389                }
390                BackendMessage::CommandComplete(_) => {
391                    // Command completed
392                }
393                BackendMessage::ReadyForQuery(status) => {
394                    self.transaction_status = status;
395                    break;
396                }
397                BackendMessage::ErrorResponse(err) => {
398                    return Err(anyhow!("Failed to query replication slot: {}", err.message));
399                }
400                _ => {
401                    warn!("Unexpected message during slot query: {msg:?}");
402                }
403            }
404        }
405
406        if !found_row {
407            return Err(anyhow!("Replication slot not found: {slot_name}"));
408        }
409
410        info!(
411            "Using existing replication slot: {slot_name} at LSN {}",
412            slot_info.consistent_point
413        );
414        Ok(slot_info)
415    }
416
417    pub async fn start_replication(
418        &mut self,
419        slot_name: &str,
420        start_lsn: Option<u64>,
421        options: HashMap<String, String>,
422    ) -> Result<()> {
423        debug!("Starting replication from slot: {slot_name}");
424
425        let mut query = format!("START_REPLICATION SLOT {slot_name} LOGICAL");
426
427        if let Some(lsn) = start_lsn {
428            query.push_str(&format!(" {}", format_lsn(lsn)));
429        } else {
430            query.push_str(" 0/0");
431        }
432
433        if !options.is_empty() {
434            query.push_str(" (");
435            let opts: Vec<String> = options.iter().map(|(k, v)| format!("{k} '{v}'")).collect();
436            query.push_str(&opts.join(", "));
437            query.push(')');
438        }
439
440        self.send_message(FrontendMessage::Query(query)).await?;
441
442        // Wait for CopyBothResponse
443        loop {
444            let msg = self.read_message().await?;
445            match msg {
446                BackendMessage::CopyBothResponse => {
447                    debug!("Entered COPY BOTH mode for replication");
448                    self.in_copy_mode = true;
449                    break;
450                }
451                BackendMessage::ErrorResponse(err) => {
452                    return Err(anyhow!("START_REPLICATION failed: {}", err.message));
453                }
454                BackendMessage::ReadyForQuery(_) => {
455                    // This is normal - PostgreSQL sends ReadyForQuery before entering COPY mode
456                    debug!("Received ReadyForQuery before entering COPY mode");
457                }
458                _ => {
459                    debug!("Message during START_REPLICATION: {msg:?}");
460                }
461            }
462        }
463
464        Ok(())
465    }
466
467    pub async fn read_replication_message(&mut self) -> Result<BackendMessage> {
468        if !self.in_copy_mode {
469            return Err(anyhow!("Not in COPY mode"));
470        }
471
472        self.read_message().await
473    }
474
475    pub async fn send_standby_status(&mut self, status: StandbyStatusUpdate) -> Result<()> {
476        if !self.in_copy_mode {
477            return Err(anyhow!("Not in COPY mode"));
478        }
479
480        let timestamp = chrono::Utc::now().timestamp_micros() - 946684800000000; // PostgreSQL epoch
481
482        self.send_message(FrontendMessage::StandbyStatusUpdate {
483            write_lsn: status.write_lsn,
484            flush_lsn: status.flush_lsn,
485            apply_lsn: status.apply_lsn,
486            timestamp,
487            reply: if status.reply_requested { 1 } else { 0 },
488        })
489        .await
490    }
491
492    async fn send_message(&mut self, msg: FrontendMessage) -> Result<()> {
493        self.write_buffer.clear();
494        msg.encode(&mut self.write_buffer)?;
495
496        self.stream.write_all(&self.write_buffer).await?;
497        self.stream.flush().await?;
498
499        trace!("Sent message: {msg:?}");
500        Ok(())
501    }
502
503    async fn send_sasl_initial_response(&mut self, mechanism: &str, response: &str) -> Result<()> {
504        self.send_message(FrontendMessage::SASLInitialResponse {
505            mechanism: mechanism.to_string(),
506            data: response.as_bytes().to_vec(),
507        })
508        .await
509    }
510
511    async fn send_sasl_response(&mut self, response: &str) -> Result<()> {
512        self.send_message(FrontendMessage::SASLResponse(response.as_bytes().to_vec()))
513            .await
514    }
515
516    async fn read_message(&mut self) -> Result<BackendMessage> {
517        loop {
518            // Try to parse a message from the buffer
519            if let Some(msg) = self.try_parse_message()? {
520                trace!("Received message: {msg:?}");
521                return Ok(msg);
522            }
523
524            // Read more data
525            let mut temp_buf = vec![0u8; 4096];
526            let n = self.stream.read(&mut temp_buf).await?;
527            if n == 0 {
528                return Err(anyhow!("Connection closed by server"));
529            }
530
531            self.read_buffer.extend_from_slice(&temp_buf[..n]);
532        }
533    }
534
535    fn try_parse_message(&mut self) -> Result<Option<BackendMessage>> {
536        if self.read_buffer.len() < 5 {
537            return Ok(None); // Need at least type + length
538        }
539
540        let msg_type = self.read_buffer[0];
541        let length = u32::from_be_bytes([
542            self.read_buffer[1],
543            self.read_buffer[2],
544            self.read_buffer[3],
545            self.read_buffer[4],
546        ]) as usize;
547
548        if length < 4 {
549            return Err(anyhow!("Invalid message length: {length}"));
550        }
551
552        let total_length = 1 + length; // Type byte + length (includes self)
553
554        if self.read_buffer.len() < total_length {
555            return Ok(None); // Need more data
556        }
557
558        // Extract message
559        let body = self.read_buffer[5..total_length].to_vec();
560        self.read_buffer.advance(total_length);
561
562        // Parse message
563        let msg = parse_backend_message(msg_type, &body)?;
564        Ok(Some(msg))
565    }
566
567    pub async fn close(mut self) -> Result<()> {
568        if self.in_copy_mode {
569            let _ = self.send_message(FrontendMessage::CopyDone).await;
570        }
571        let _ = self.send_message(FrontendMessage::Terminate).await;
572        let _ = self.stream.shutdown().await;
573        Ok(())
574    }
575}
576
/// Renders a WAL position (LSN) in PostgreSQL's `"high/low"` text form: the upper
/// and lower 32 bits as unpadded uppercase hex, e.g. `0x1A3F00` becomes `"0/1A3F00"`.
pub(crate) fn format_lsn(lsn: u64) -> String {
    // The truncating casts are intentional: they split the value into its two halves.
    let high = (lsn >> 32) as u32;
    let low = lsn as u32;
    format!("{high:X}/{low:X}")
}

582/// Parses a PostgreSQL LSN string in `"high/low"` hex notation into a `u64`.
583/// Returns an error if the string is not in `"X/Y"` format.
584pub(crate) fn parse_lsn(lsn_str: &str) -> Result<u64> {
585    let parts: Vec<&str> = lsn_str.split('/').collect();
586    if parts.len() != 2 {
587        return Err(anyhow!("Invalid LSN format: {lsn_str}"));
588    }
589
590    let high = u64::from_str_radix(parts[0], 16)?;
591    let low = u64::from_str_radix(parts[1], 16)?;
592
593    Ok((high << 32) | low)
594}