pgwire_replication/client/
tokio_client.rs

1use crate::config::ReplicationConfig;
2use crate::error::{PgWireError, Result};
3use crate::lsn::Lsn;
4use tokio::net::TcpStream;
5use tokio::sync::{mpsc, watch};
6use tokio::task::JoinHandle;
7
8use std::sync::Arc;
9
10#[cfg(not(feature = "tls-rustls"))]
11use crate::config::SslMode;
12
13use super::worker::{ReplicationEvent, ReplicationEventReceiver, SharedProgress, WorkerState};
14
15/// PostgreSQL logical replication client.
16///
17/// This client spawns a background worker task that maintains the replication
18/// connection and streams events to the consumer via a bounded channel.
19///
20/// # Example
21///
22/// ```no_run
23/// use pgwire_replication::client::{ReplicationClient, ReplicationEvent};
24/// use pgwire_replication::config::ReplicationConfig;
25///
26/// #[tokio::main]
27/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
28///     let config = ReplicationConfig::new(
29///         "localhost",
30///         "postgres",
31///         "password",
32///         "mydb",
33///         "my_slot",
34///         "my_pub",
35///     );
36///
37///     let mut client = ReplicationClient::connect(config).await?;
38///
39///     while let Some(ev) = client.recv().await? {
40///         match ev {
41///             ReplicationEvent::XLogData { data, wal_end, .. } => {
42///                 process_change(&data);
43///                 client.update_applied_lsn(wal_end);
44///             }
45///             ReplicationEvent::KeepAlive { .. } => {}
46///             ReplicationEvent::StoppedAt { reached } => {
47///                 println!("Reached stop LSN: {reached}");
48///                 break;
49///             }
50///             _ => {}
51///         }
52///     }
53///
54///     Ok(())
55/// }
56///
57/// fn process_change(_data: &bytes::Bytes) {
58///     // user-defined
59/// }
60/// ```
61pub struct ReplicationClient {
62    rx: ReplicationEventReceiver,
63    progress: Arc<SharedProgress>,
64    stop_tx: watch::Sender<bool>,
65    join: Option<JoinHandle<std::result::Result<(), PgWireError>>>,
66}
67
68impl ReplicationClient {
69    /// Connect to PostgreSQL and start streaming replication events.
70    ///
71    /// This establishes a TCP connection (optionally upgrading to TLS),
72    /// authenticates, and starts the replication stream. Events are buffered
73    /// in a channel of size `config.buffer_events`.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if:
78    /// - TCP connection fails
79    /// - TLS handshake fails (when enabled)
80    /// - Authentication fails
81    /// - Replication slot doesn't exist
82    /// - Publication doesn't exist
83    pub async fn connect(cfg: ReplicationConfig) -> Result<Self> {
84        let (tx, rx) = mpsc::channel(cfg.buffer_events);
85
86        // Progress is shared via atomics: cheap, monotonic, no async backpressure.
87        let progress = Arc::new(SharedProgress::new(cfg.start_lsn));
88
89        let (stop_tx, stop_rx) = watch::channel(false);
90
91        let progress_for_worker = Arc::clone(&progress);
92        let cfg_for_worker = cfg.clone();
93
94        let join = tokio::spawn(async move {
95            let mut worker = WorkerState::new(cfg_for_worker, progress_for_worker, stop_rx, tx);
96            let res = run_worker(&mut worker, &cfg).await;
97            if let Err(ref e) = res {
98                tracing::error!("replication worker terminated with error: {e}");
99            }
100            res
101        });
102
103        Ok(Self {
104            rx,
105            progress,
106            stop_tx,
107            join: Some(join),
108        })
109    }
110
111    /// Receive the next replication event.
112    ///
113    /// - `Ok(Some(event))` => received an event
114    /// - `Ok(None)`        => replication ended normally (stop requested or stop_at_lsn reached)
115    /// - `Err(e)`          => replication ended abnormally
116    pub async fn recv(&mut self) -> Result<Option<ReplicationEvent>> {
117        match self.rx.recv().await {
118            Some(Ok(ev)) => Ok(Some(ev)),
119            Some(Err(e)) => Err(e),
120            None => self.handle_worker_shutdown().await,
121        }
122    }
123
124    async fn handle_worker_shutdown(&mut self) -> Result<Option<ReplicationEvent>> {
125        let join = self
126            .join
127            .take()
128            .ok_or_else(|| PgWireError::Internal("replication worker already joined".into()))?;
129
130        match join.await {
131            Ok(Ok(())) => Ok(None),
132            Ok(Err(e)) => Err(e),
133            Err(join_err) => Err(PgWireError::Task(format!(
134                "replication worker panicked: {join_err}"
135            ))),
136        }
137    }
138
139    /// Update the applied/durable LSN reported to the server.
140    ///
141    /// Semantics: call this only once you have durably persisted all events up to `lsn`.
142    /// This update is monotonic and cheap; wire feedback is still governed by the worker’s
143    /// `status_interval` and keepalive reply requests.
144    #[inline]
145    pub fn update_applied_lsn(&self, lsn: Lsn) {
146        self.progress.update_applied(lsn);
147    }
148
149    /// Request the worker to stop gracefully.
150    ///
151    /// After calling this, [`recv()`](Self::recv) will return remaining buffered
152    /// events, then `Ok(None)` once the worker exits cleanly.
153    ///
154    /// This sends a CopyDone message to the server to cleanly terminate
155    /// the replication stream.
156    #[inline]
157    pub fn stop(&self) {
158        let _ = self.stop_tx.send(true);
159    }
160
161    pub fn is_running(&self) -> bool {
162        self.join
163            .as_ref()
164            .map(|j| !j.is_finished())
165            .unwrap_or(false)
166    }
167
168    /// Wait for the worker task to complete and return its result.
169    ///
170    /// This consumes the client. Use this for diagnostics or to ensure
171    /// clean shutdown after calling [`stop()`](Self::stop).
172    pub async fn join(mut self) -> Result<()> {
173        let join = self
174            .join
175            .take()
176            .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
177
178        match join.await {
179            Ok(inner) => inner,
180            Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
181        }
182    }
183
184    /// Abort the worker task immediately.
185    ///
186    /// This is a hard cancel and does not send CopyDone.
187    /// Prefer `stop()`/`shutdown()` for graceful termination.
188    pub fn abort(&mut self) {
189        if let Some(join) = self.join.take() {
190            join.abort();
191        }
192    }
193
194    /// Request a graceful stop and wait for the worker to exit.
195    pub async fn shutdown(&mut self) -> Result<()> {
196        self.stop();
197
198        // Drain events until the worker closes the channel.
199        while let Some(msg) = self.rx.recv().await {
200            match msg {
201                Ok(_ev) => {} //discard; caller can drain themselves if they need events
202                Err(e) => return Err(e),
203            }
204        }
205
206        self.join_mut().await
207    }
208
209    /// Wait for the worker task to complete and return its result.
210    async fn join_mut(&mut self) -> Result<()> {
211        let join = self
212            .join
213            .take()
214            .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
215
216        match join.await {
217            Ok(inner) => inner,
218            Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
219        }
220    }
221}
222
223impl Drop for ReplicationClient {
224    fn drop(&mut self) {
225        let _ = self.stop_tx.send(true);
226
227        // We cannot .await here. Prefer to detach a join in the background
228        // so the worker can exit cleanly without being aborted.
229        if let Some(join) = self.join.take() {
230            match tokio::runtime::Handle::try_current() {
231                Ok(handle) => {
232                    handle.spawn(async move {
233                        let _ = join.await;
234                    });
235                }
236                Err(_) => {
237                    // No Tokio runtime available (dropping outside async context).
238                    // Fall back to abort to avoid a potentially unbounded leaked task.
239                    tracing::debug!(
240                        "dropping ReplicationClient outside a Tokio runtime; aborting worker task"
241                    );
242                    join.abort();
243                }
244            }
245        }
246    }
247}
248
249async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result<()> {
250    let tcp = TcpStream::connect((cfg.host.as_str(), cfg.port)).await?;
251    tcp.set_nodelay(true)?;
252
253    #[cfg(feature = "tls-rustls")]
254    {
255        use crate::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
256        let upgraded = maybe_upgrade_to_tls(tcp, &cfg.tls, &cfg.host).await?;
257        match upgraded {
258            MaybeTlsStream::Plain(mut s) => worker.run_on_stream(&mut s).await,
259            MaybeTlsStream::Tls(mut s) => worker.run_on_stream(s.as_mut()).await,
260        }
261    }
262
263    #[cfg(not(feature = "tls-rustls"))]
264    {
265        if !matches!(cfg.tls.mode, SslMode::Disable) {
266            return Err(PgWireError::Tls("tls-rustls feature not enabled".into()));
267        }
268        let mut s = tcp;
269        worker.run_on_stream(&mut s).await
270    }
271}