pgwire_replication/client/tokio_client.rs
1use crate::config::ReplicationConfig;
2use crate::error::{PgWireError, Result};
3use crate::lsn::Lsn;
4
5use tokio::net::TcpStream;
6#[cfg(unix)]
7use tokio::net::UnixStream;
8
9use tokio::sync::{mpsc, watch};
10use tokio::task::JoinHandle;
11
12use std::sync::Arc;
13
14#[cfg(not(feature = "tls-rustls"))]
15use crate::config::SslMode;
16
17use super::worker::{ReplicationEvent, ReplicationEventReceiver, SharedProgress, WorkerState};
18
/// PostgreSQL logical replication client.
///
/// This client spawns a background worker task that maintains the replication
/// connection and streams events to the consumer via a bounded channel.
///
/// [`connect`](Self::connect) only starts the worker; it does not report
/// connection or protocol failures. Those surface from [`recv`](Self::recv)
/// (or [`join`](Self::join) / [`shutdown`](Self::shutdown)) once the worker
/// terminates. Dropping the client signals the worker to stop.
///
/// # Example
///
/// ```no_run
/// use pgwire_replication::client::{ReplicationClient, ReplicationEvent};
/// use pgwire_replication::config::ReplicationConfig;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let config = ReplicationConfig::new(
///         "localhost",
///         "postgres",
///         "password",
///         "mydb",
///         "my_slot",
///         "my_pub",
///     );
///
///     let mut client = ReplicationClient::connect(config).await?;
///
///     while let Some(ev) = client.recv().await? {
///         match ev {
///             ReplicationEvent::XLogData { data, wal_end, .. } => {
///                 process_change(&data);
///                 client.update_applied_lsn(wal_end);
///             }
///             ReplicationEvent::KeepAlive { .. } => {}
///             ReplicationEvent::StoppedAt { reached } => {
///                 println!("Reached stop LSN: {reached}");
///                 break;
///             }
///             _ => {}
///         }
///     }
///
///     Ok(())
/// }
///
/// fn process_change(_data: &bytes::Bytes) {
///     // user-defined
/// }
/// ```
pub struct ReplicationClient {
    // Receiving end of the bounded event channel fed by the worker task.
    rx: ReplicationEventReceiver,
    // Progress shared with the worker; updated through `update_applied_lsn`.
    progress: Arc<SharedProgress>,
    // Stop signal watched by the worker; `true` requests a graceful stop.
    stop_tx: watch::Sender<bool>,
    // Worker task handle; `None` once it has been joined or aborted via this client.
    join: Option<JoinHandle<std::result::Result<(), PgWireError>>>,
}
71
72impl ReplicationClient {
73 /// Connect to PostgreSQL and start streaming replication events.
74 ///
75 /// This establishes a TCP connection (optionally upgrading to TLS),
76 /// authenticates, and starts the replication stream. Events are buffered
77 /// in a channel of size `config.buffer_events`.
78 ///
79 /// # Errors
80 ///
81 /// Returns an error if:
82 /// - TCP connection fails
83 /// - TLS handshake fails (when enabled)
84 /// - Authentication fails
85 /// - Replication slot doesn't exist
86 /// - Publication doesn't exist
87 /// - Unix socket does not exist (when host starts with `/`)
88 /// - TLS requested with Unix socket connection
89 pub async fn connect(cfg: ReplicationConfig) -> Result<Self> {
90 let (tx, rx) = mpsc::channel(cfg.buffer_events);
91
92 // Progress is shared via atomics: cheap, monotonic, no async backpressure.
93 let progress = Arc::new(SharedProgress::new(cfg.start_lsn));
94
95 let (stop_tx, stop_rx) = watch::channel(false);
96
97 let progress_for_worker = Arc::clone(&progress);
98 let cfg_for_worker = cfg.clone();
99
100 let join = tokio::spawn(async move {
101 let mut worker = WorkerState::new(cfg_for_worker, progress_for_worker, stop_rx, tx);
102 let res = run_worker(&mut worker, &cfg).await;
103 if let Err(ref e) = res {
104 tracing::error!("replication worker terminated with error: {e}");
105 }
106 res
107 });
108
109 Ok(Self {
110 rx,
111 progress,
112 stop_tx,
113 join: Some(join),
114 })
115 }
116
117 /// Receive the next replication event.
118 ///
119 /// - `Ok(Some(event))` => received an event
120 /// - `Ok(None)` => replication ended normally (stop requested or stop_at_lsn reached)
121 /// - `Err(e)` => replication ended abnormally
122 pub async fn recv(&mut self) -> Result<Option<ReplicationEvent>> {
123 match self.rx.recv().await {
124 Some(Ok(ev)) => Ok(Some(ev)),
125 Some(Err(e)) => Err(e),
126 None => self.handle_worker_shutdown().await,
127 }
128 }
129
130 async fn handle_worker_shutdown(&mut self) -> Result<Option<ReplicationEvent>> {
131 let join = self
132 .join
133 .take()
134 .ok_or_else(|| PgWireError::Internal("replication worker already joined".into()))?;
135
136 match join.await {
137 Ok(Ok(())) => Ok(None),
138 Ok(Err(e)) => Err(e),
139 Err(join_err) => Err(PgWireError::Task(format!(
140 "replication worker panicked: {join_err}"
141 ))),
142 }
143 }
144
145 /// Update the applied/durable LSN reported to the server.
146 ///
147 /// Semantics: call this only once you have durably persisted all events up to `lsn`.
148 /// This update is monotonic and cheap; wire feedback is still governed by the worker’s
149 /// `status_interval` and keepalive reply requests.
150 #[inline]
151 pub fn update_applied_lsn(&self, lsn: Lsn) {
152 self.progress.update_applied(lsn);
153 }
154
155 /// Request the worker to stop gracefully.
156 ///
157 /// After calling this, [`recv()`](Self::recv) will return remaining buffered
158 /// events, then `Ok(None)` once the worker exits cleanly.
159 ///
160 /// This sends a CopyDone message to the server to cleanly terminate
161 /// the replication stream.
162 #[inline]
163 pub fn stop(&self) {
164 let _ = self.stop_tx.send(true);
165 }
166
167 pub fn is_running(&self) -> bool {
168 self.join
169 .as_ref()
170 .map(|j| !j.is_finished())
171 .unwrap_or(false)
172 }
173
174 /// Wait for the worker task to complete and return its result.
175 ///
176 /// This consumes the client. Use this for diagnostics or to ensure
177 /// clean shutdown after calling [`stop()`](Self::stop).
178 pub async fn join(mut self) -> Result<()> {
179 let join = self
180 .join
181 .take()
182 .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
183
184 match join.await {
185 Ok(inner) => inner,
186 Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
187 }
188 }
189
190 /// Abort the worker task immediately.
191 ///
192 /// This is a hard cancel and does not send CopyDone.
193 /// Prefer `stop()`/`shutdown()` for graceful termination.
194 pub fn abort(&mut self) {
195 if let Some(join) = self.join.take() {
196 join.abort();
197 }
198 }
199
200 /// Request a graceful stop and wait for the worker to exit.
201 pub async fn shutdown(&mut self) -> Result<()> {
202 self.stop();
203
204 // Drain events until the worker closes the channel.
205 while let Some(msg) = self.rx.recv().await {
206 match msg {
207 Ok(_ev) => {} //discard; caller can drain themselves if they need events
208 Err(e) => return Err(e),
209 }
210 }
211
212 self.join_mut().await
213 }
214
215 /// Wait for the worker task to complete and return its result.
216 async fn join_mut(&mut self) -> Result<()> {
217 let join = self
218 .join
219 .take()
220 .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
221
222 match join.await {
223 Ok(inner) => inner,
224 Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
225 }
226 }
227}
228
229impl Drop for ReplicationClient {
230 fn drop(&mut self) {
231 let _ = self.stop_tx.send(true);
232
233 // We cannot .await here. Prefer to detach a join in the background
234 // so the worker can exit cleanly without being aborted.
235 if let Some(join) = self.join.take() {
236 match tokio::runtime::Handle::try_current() {
237 Ok(handle) => {
238 handle.spawn(async move {
239 let _ = join.await;
240 });
241 }
242 Err(_) => {
243 // No Tokio runtime available (dropping outside async context).
244 // Fall back to abort to avoid a potentially unbounded leaked task.
245 tracing::debug!(
246 "dropping ReplicationClient outside a Tokio runtime; aborting worker task"
247 );
248 join.abort();
249 }
250 }
251 }
252 }
253}
254
255async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result<()> {
256 #[cfg(unix)]
257 if cfg.is_unix_socket() {
258 if cfg.tls.mode.requires_tls() {
259 return Err(PgWireError::Tls(
260 "TLS is not supported over Unix domain sockets".into(),
261 ));
262 }
263
264 let path = cfg.unix_socket_path();
265 let mut stream = UnixStream::connect(&path).await.map_err(|e| {
266 PgWireError::Io(format!(
267 "failed to connect to Unix socket {}: {e}",
268 path.display()
269 ))
270 })?;
271
272 return worker.run_on_stream(&mut stream).await;
273 }
274
275 let tcp = TcpStream::connect((cfg.host.as_str(), cfg.port)).await?;
276 tcp.set_nodelay(true)?;
277
278 #[cfg(feature = "tls-rustls")]
279 {
280 use crate::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
281 let upgraded = maybe_upgrade_to_tls(tcp, &cfg.tls, &cfg.host).await?;
282 match upgraded {
283 MaybeTlsStream::Plain(mut s) => worker.run_on_stream(&mut s).await,
284 MaybeTlsStream::Tls(mut s) => worker.run_on_stream(s.as_mut()).await,
285 }
286 }
287
288 #[cfg(not(feature = "tls-rustls"))]
289 {
290 if !matches!(cfg.tls.mode, SslMode::Disable) {
291 return Err(PgWireError::Tls("tls-rustls feature not enabled".into()));
292 }
293 let mut s = tcp;
294 worker.run_on_stream(&mut s).await
295 }
296}