Skip to main content

pgwire_replication/client/
tokio_client.rs

1use crate::config::ReplicationConfig;
2use crate::error::{PgWireError, Result};
3use crate::lsn::Lsn;
4
5use tokio::net::TcpStream;
6#[cfg(unix)]
7use tokio::net::UnixStream;
8
9use tokio::sync::{mpsc, watch};
10use tokio::task::JoinHandle;
11
12use std::sync::Arc;
13
14#[cfg(not(feature = "tls-rustls"))]
15use crate::config::SslMode;
16
17use super::worker::{ReplicationEvent, ReplicationEventReceiver, SharedProgress, WorkerState};
18
19/// PostgreSQL logical replication client.
20///
21/// This client spawns a background worker task that maintains the replication
22/// connection and streams events to the consumer via a bounded channel.
23///
24/// # Example
25///
26/// ```no_run
27/// use pgwire_replication::client::{ReplicationClient, ReplicationEvent};
28/// use pgwire_replication::config::ReplicationConfig;
29///
30/// #[tokio::main]
31/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
32///     let config = ReplicationConfig::new(
33///         "localhost",
34///         "postgres",
35///         "password",
36///         "mydb",
37///         "my_slot",
38///         "my_pub",
39///     );
40///
41///     let mut client = ReplicationClient::connect(config).await?;
42///
43///     while let Some(ev) = client.recv().await? {
44///         match ev {
45///             ReplicationEvent::XLogData { data, wal_end, .. } => {
46///                 process_change(&data);
47///                 client.update_applied_lsn(wal_end);
48///             }
49///             ReplicationEvent::KeepAlive { .. } => {}
50///             ReplicationEvent::StoppedAt { reached } => {
51///                 println!("Reached stop LSN: {reached}");
52///                 break;
53///             }
54///             _ => {}
55///         }
56///     }
57///
58///     Ok(())
59/// }
60///
61/// fn process_change(_data: &bytes::Bytes) {
62///     // user-defined
63/// }
64/// ```
65pub struct ReplicationClient {
66    rx: ReplicationEventReceiver,
67    progress: Arc<SharedProgress>,
68    stop_tx: watch::Sender<bool>,
69    join: Option<JoinHandle<std::result::Result<(), PgWireError>>>,
70}
71
72impl ReplicationClient {
73    /// Connect to PostgreSQL and start streaming replication events.
74    ///
75    /// This establishes a TCP connection (optionally upgrading to TLS),
76    /// authenticates, and starts the replication stream. Events are buffered
77    /// in a channel of size `config.buffer_events`.
78    ///
79    /// # Errors
80    ///
81    /// Returns an error if:
82    /// - TCP connection fails
83    /// - TLS handshake fails (when enabled)
84    /// - Authentication fails
85    /// - Replication slot doesn't exist
86    /// - Publication doesn't exist
87    /// - Unix socket does not exist (when host starts with `/`)
88    /// - TLS requested with Unix socket connection
89    pub async fn connect(cfg: ReplicationConfig) -> Result<Self> {
90        let (tx, rx) = mpsc::channel(cfg.buffer_events);
91
92        // Progress is shared via atomics: cheap, monotonic, no async backpressure.
93        let progress = Arc::new(SharedProgress::new(cfg.start_lsn));
94
95        let (stop_tx, stop_rx) = watch::channel(false);
96
97        let progress_for_worker = Arc::clone(&progress);
98        let cfg_for_worker = cfg.clone();
99
100        let join = tokio::spawn(async move {
101            let mut worker = WorkerState::new(cfg_for_worker, progress_for_worker, stop_rx, tx);
102            let res = run_worker(&mut worker, &cfg).await;
103            if let Err(ref e) = res {
104                tracing::error!("replication worker terminated with error: {e}");
105            }
106            res
107        });
108
109        Ok(Self {
110            rx,
111            progress,
112            stop_tx,
113            join: Some(join),
114        })
115    }
116
117    /// Receive the next replication event.
118    ///
119    /// - `Ok(Some(event))` => received an event
120    /// - `Ok(None)`        => replication ended normally (stop requested or stop_at_lsn reached)
121    /// - `Err(e)`          => replication ended abnormally
122    pub async fn recv(&mut self) -> Result<Option<ReplicationEvent>> {
123        match self.rx.recv().await {
124            Some(Ok(ev)) => Ok(Some(ev)),
125            Some(Err(e)) => Err(e),
126            None => self.handle_worker_shutdown().await,
127        }
128    }
129
130    async fn handle_worker_shutdown(&mut self) -> Result<Option<ReplicationEvent>> {
131        let join = self
132            .join
133            .take()
134            .ok_or_else(|| PgWireError::Internal("replication worker already joined".into()))?;
135
136        match join.await {
137            Ok(Ok(())) => Ok(None),
138            Ok(Err(e)) => Err(e),
139            Err(join_err) => Err(PgWireError::Task(format!(
140                "replication worker panicked: {join_err}"
141            ))),
142        }
143    }
144
145    /// Update the applied/durable LSN reported to the server.
146    ///
147    /// Semantics: call this only once you have durably persisted all events up to `lsn`.
148    /// This update is monotonic and cheap; wire feedback is still governed by the worker’s
149    /// `status_interval` and keepalive reply requests.
150    #[inline]
151    pub fn update_applied_lsn(&self, lsn: Lsn) {
152        self.progress.update_applied(lsn);
153    }
154
155    /// Request the worker to stop gracefully.
156    ///
157    /// After calling this, [`recv()`](Self::recv) will return remaining buffered
158    /// events, then `Ok(None)` once the worker exits cleanly.
159    ///
160    /// This sends a CopyDone message to the server to cleanly terminate
161    /// the replication stream.
162    #[inline]
163    pub fn stop(&self) {
164        let _ = self.stop_tx.send(true);
165    }
166
167    pub fn is_running(&self) -> bool {
168        self.join
169            .as_ref()
170            .map(|j| !j.is_finished())
171            .unwrap_or(false)
172    }
173
174    /// Wait for the worker task to complete and return its result.
175    ///
176    /// This consumes the client. Use this for diagnostics or to ensure
177    /// clean shutdown after calling [`stop()`](Self::stop).
178    pub async fn join(mut self) -> Result<()> {
179        let join = self
180            .join
181            .take()
182            .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
183
184        match join.await {
185            Ok(inner) => inner,
186            Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
187        }
188    }
189
190    /// Abort the worker task immediately.
191    ///
192    /// This is a hard cancel and does not send CopyDone.
193    /// Prefer `stop()`/`shutdown()` for graceful termination.
194    pub fn abort(&mut self) {
195        if let Some(join) = self.join.take() {
196            join.abort();
197        }
198    }
199
200    /// Request a graceful stop and wait for the worker to exit.
201    pub async fn shutdown(&mut self) -> Result<()> {
202        self.stop();
203
204        // Drain events until the worker closes the channel.
205        while let Some(msg) = self.rx.recv().await {
206            match msg {
207                Ok(_ev) => {} //discard; caller can drain themselves if they need events
208                Err(e) => return Err(e),
209            }
210        }
211
212        self.join_mut().await
213    }
214
215    /// Wait for the worker task to complete and return its result.
216    async fn join_mut(&mut self) -> Result<()> {
217        let join = self
218            .join
219            .take()
220            .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
221
222        match join.await {
223            Ok(inner) => inner,
224            Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
225        }
226    }
227}
228
229impl Drop for ReplicationClient {
230    fn drop(&mut self) {
231        let _ = self.stop_tx.send(true);
232
233        // We cannot .await here. Prefer to detach a join in the background
234        // so the worker can exit cleanly without being aborted.
235        if let Some(join) = self.join.take() {
236            match tokio::runtime::Handle::try_current() {
237                Ok(handle) => {
238                    handle.spawn(async move {
239                        let _ = join.await;
240                    });
241                }
242                Err(_) => {
243                    // No Tokio runtime available (dropping outside async context).
244                    // Fall back to abort to avoid a potentially unbounded leaked task.
245                    tracing::debug!(
246                        "dropping ReplicationClient outside a Tokio runtime; aborting worker task"
247                    );
248                    join.abort();
249                }
250            }
251        }
252    }
253}
254
255async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result<()> {
256    #[cfg(unix)]
257    if cfg.is_unix_socket() {
258        if cfg.tls.mode.requires_tls() {
259            return Err(PgWireError::Tls(
260                "TLS is not supported over Unix domain sockets".into(),
261            ));
262        }
263
264        let path = cfg.unix_socket_path();
265        let mut stream = UnixStream::connect(&path).await.map_err(|e| {
266            PgWireError::Io(format!(
267                "failed to connect to Unix socket {}: {e}",
268                path.display()
269            ))
270        })?;
271
272        return worker.run_on_stream(&mut stream).await;
273    }
274
275    let tcp = TcpStream::connect((cfg.host.as_str(), cfg.port)).await?;
276    tcp.set_nodelay(true)?;
277
278    #[cfg(feature = "tls-rustls")]
279    {
280        use crate::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
281        let upgraded = maybe_upgrade_to_tls(tcp, &cfg.tls, &cfg.host).await?;
282        match upgraded {
283            MaybeTlsStream::Plain(mut s) => worker.run_on_stream(&mut s).await,
284            MaybeTlsStream::Tls(mut s) => worker.run_on_stream(s.as_mut()).await,
285        }
286    }
287
288    #[cfg(not(feature = "tls-rustls"))]
289    {
290        if !matches!(cfg.tls.mode, SslMode::Disable) {
291            return Err(PgWireError::Tls("tls-rustls feature not enabled".into()));
292        }
293        let mut s = tcp;
294        worker.run_on_stream(&mut s).await
295    }
296}