pgwire_replication/client/tokio_client.rs
use crate::config::ReplicationConfig;
use crate::error::{PgWireError, Result};
use crate::lsn::Lsn;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;

use std::sync::Arc;

#[cfg(not(feature = "tls-rustls"))]
use crate::config::SslMode;

use super::worker::{ReplicationEvent, ReplicationEventReceiver, SharedProgress, WorkerState};
/// PostgreSQL logical replication client.
///
/// This client spawns a background worker task that maintains the replication
/// connection and streams events to the consumer via a bounded channel.
///
/// # Example
///
/// ```no_run
/// use pgwire_replication::client::{ReplicationClient, ReplicationEvent};
/// use pgwire_replication::config::ReplicationConfig;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     let config = ReplicationConfig::new(
///         "localhost",
///         "postgres",
///         "password",
///         "mydb",
///         "my_slot",
///         "my_pub",
///     );
///
///     let mut client = ReplicationClient::connect(config).await?;
///
///     while let Some(ev) = client.recv().await? {
///         match ev {
///             ReplicationEvent::XLogData { data, wal_end, .. } => {
///                 process_change(&data);
///                 client.update_applied_lsn(wal_end);
///             }
///             ReplicationEvent::KeepAlive { .. } => {}
///             ReplicationEvent::StoppedAt { reached } => {
///                 println!("Reached stop LSN: {reached}");
///                 break;
///             }
///             _ => {}
///         }
///     }
///
///     Ok(())
/// }
///
/// fn process_change(_data: &bytes::Bytes) {
///     // user-defined
/// }
/// ```
pub struct ReplicationClient {
    rx: ReplicationEventReceiver,
    progress: Arc<SharedProgress>,
    stop_tx: watch::Sender<bool>,
    join: Option<JoinHandle<std::result::Result<(), PgWireError>>>,
}

impl ReplicationClient {
    /// Connect to PostgreSQL and start streaming replication events.
    ///
    /// This establishes a TCP connection (optionally upgrading to TLS),
    /// authenticates, and starts the replication stream. Events are buffered
    /// in a channel of size `config.buffer_events`.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the TCP connection fails
    /// - the TLS handshake fails (when enabled)
    /// - authentication fails
    /// - the replication slot doesn't exist
    /// - the publication doesn't exist
    pub async fn connect(cfg: ReplicationConfig) -> Result<Self> {
        let (tx, rx) = mpsc::channel(cfg.buffer_events);

        // Progress is shared via atomics: cheap, monotonic, no async backpressure.
        let progress = Arc::new(SharedProgress::new(cfg.start_lsn));

        let (stop_tx, stop_rx) = watch::channel(false);

        let progress_for_worker = Arc::clone(&progress);
        // The worker owns one copy of the config; `run_worker` borrows the other.
        let cfg_for_worker = cfg.clone();

        let join = tokio::spawn(async move {
            let mut worker = WorkerState::new(cfg_for_worker, progress_for_worker, stop_rx, tx);
            let res = run_worker(&mut worker, &cfg).await;
            if let Err(ref e) = res {
                tracing::error!("replication worker terminated with error: {e}");
            }
            res
        });

        Ok(Self {
            rx,
            progress,
            stop_tx,
            join: Some(join),
        })
    }

    /// Receive the next replication event.
    ///
    /// - `Ok(Some(event))` => received an event
    /// - `Ok(None)` => replication ended normally (stop requested or `stop_at_lsn` reached)
    /// - `Err(e)` => replication ended abnormally
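    ///
    /// # Example
    ///
    /// A minimal sketch of the receive loop (the event handling is a placeholder):
    ///
    /// ```no_run
    /// # use pgwire_replication::client::ReplicationClient;
    /// # async fn demo(mut client: ReplicationClient) -> Result<(), Box<dyn std::error::Error>> {
    /// while let Some(event) = client.recv().await? {
    ///     // Handle the event; an abnormal end surfaces here through `?`.
    ///     let _ = event;
    /// }
    /// // recv() returned Ok(None): the stream ended normally.
    /// # Ok(())
    /// # }
    /// ```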
    pub async fn recv(&mut self) -> Result<Option<ReplicationEvent>> {
        match self.rx.recv().await {
            Some(Ok(ev)) => Ok(Some(ev)),
            Some(Err(e)) => Err(e),
            // Channel closed: the worker exited; join it and surface its final result.
            None => self.handle_worker_shutdown().await,
        }
    }

    async fn handle_worker_shutdown(&mut self) -> Result<Option<ReplicationEvent>> {
        let join = self
            .join
            .take()
            .ok_or_else(|| PgWireError::Internal("replication worker already joined".into()))?;

        match join.await {
            Ok(Ok(())) => Ok(None),
            Ok(Err(e)) => Err(e),
            Err(join_err) => Err(PgWireError::Task(format!(
                "replication worker panicked: {join_err}"
            ))),
        }
    }

    /// Update the applied/durable LSN reported to the server.
    ///
    /// Semantics: call this only once you have durably persisted all events up to `lsn`.
    /// This update is monotonic and cheap; wire feedback is still governed by the worker's
    /// `status_interval` and keepalive reply requests.
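    ///
    /// # Example
    ///
    /// A minimal sketch; `persist` below is a hypothetical stand-in for the
    /// application's durable write path:
    ///
    /// ```no_run
    /// # use pgwire_replication::client::{ReplicationClient, ReplicationEvent};
    /// # async fn demo(mut client: ReplicationClient) -> Result<(), Box<dyn std::error::Error>> {
    /// # fn persist(_data: &bytes::Bytes) { /* hypothetical: write + fsync */ }
    /// if let Some(ReplicationEvent::XLogData { data, wal_end, .. }) = client.recv().await? {
    ///     persist(&data); // 1. make the change durable
    ///     client.update_applied_lsn(wal_end); // 2. only then acknowledge it
    /// }
    /// # Ok(())
    /// # }
    /// ```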
    #[inline]
    pub fn update_applied_lsn(&self, lsn: Lsn) {
        self.progress.update_applied(lsn);
    }

    /// Request the worker to stop gracefully.
    ///
    /// After calling this, [`recv()`](Self::recv) will return remaining buffered
    /// events, then `Ok(None)` once the worker exits cleanly.
    ///
    /// The worker reacts by sending a CopyDone message to the server to cleanly
    /// terminate the replication stream.
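    ///
    /// # Example
    ///
    /// A minimal sketch of a graceful stop; the drain loop is a placeholder:
    ///
    /// ```no_run
    /// # use pgwire_replication::client::ReplicationClient;
    /// # async fn demo(mut client: ReplicationClient) -> Result<(), Box<dyn std::error::Error>> {
    /// client.stop();
    /// // Keep receiving until Ok(None) marks the clean exit.
    /// while let Some(event) = client.recv().await? {
    ///     let _ = event;
    /// }
    /// # Ok(())
    /// # }
    /// ```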
    #[inline]
    pub fn stop(&self) {
        let _ = self.stop_tx.send(true);
    }

    /// Returns `true` while the background worker task is still running.
    pub fn is_running(&self) -> bool {
        self.join
            .as_ref()
            .map(|j| !j.is_finished())
            .unwrap_or(false)
    }

    /// Wait for the worker task to complete and return its result.
    ///
    /// This consumes the client. Use this for diagnostics or to ensure
    /// clean shutdown after calling [`stop()`](Self::stop).
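    ///
    /// # Example
    ///
    /// A minimal sketch; note that [`recv()`](Self::recv) joins the worker itself
    /// once it returns `Ok(None)`, so `join()` is for flows that skip draining:
    ///
    /// ```no_run
    /// # use pgwire_replication::client::ReplicationClient;
    /// # async fn demo(client: ReplicationClient) -> Result<(), Box<dyn std::error::Error>> {
    /// client.stop();
    /// client.join().await?; // propagates the worker's final result
    /// # Ok(())
    /// # }
    /// ```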
    pub async fn join(mut self) -> Result<()> {
        self.join_mut().await
    }

    /// Abort the worker task immediately.
    ///
    /// This is a hard cancel and does not send CopyDone.
    /// Prefer [`stop()`](Self::stop) or [`shutdown()`](Self::shutdown) for
    /// graceful termination.
    pub fn abort(&mut self) {
        if let Some(join) = self.join.take() {
            join.abort();
        }
    }

    /// Request a graceful stop and wait for the worker to exit.
    ///
    /// Any events still buffered in the channel are discarded; drain them with
    /// [`recv()`](Self::recv) first if you need them.
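    ///
    /// # Example
    ///
    /// A minimal sketch of shutting down at the end of a run:
    ///
    /// ```no_run
    /// # use pgwire_replication::client::ReplicationClient;
    /// # async fn demo(mut client: ReplicationClient) -> Result<(), Box<dyn std::error::Error>> {
    /// client.shutdown().await?; // stop, drain (discarding), then join
    /// # Ok(())
    /// # }
    /// ```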
    pub async fn shutdown(&mut self) -> Result<()> {
        self.stop();

        // Drain events until the worker closes the channel.
        while let Some(msg) = self.rx.recv().await {
            match msg {
                // Discard; callers who need these events should drain via recv()
                // before calling shutdown().
                Ok(_ev) => {}
                Err(e) => return Err(e),
            }
        }

        self.join_mut().await
    }

    /// Join the worker task in place without consuming the client.
    async fn join_mut(&mut self) -> Result<()> {
        let join = self
            .join
            .take()
            .ok_or_else(|| PgWireError::Task("worker already joined".into()))?;

        match join.await {
            Ok(inner) => inner,
            Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
        }
    }
}

impl Drop for ReplicationClient {
    fn drop(&mut self) {
        let _ = self.stop_tx.send(true);

        // We cannot .await in Drop. If a runtime is available, detach a task that
        // awaits the worker so it can exit cleanly instead of being aborted.
        if let Some(join) = self.join.take() {
            match tokio::runtime::Handle::try_current() {
                Ok(handle) => {
                    handle.spawn(async move {
                        let _ = join.await;
                    });
                }
                Err(_) => {
                    // No Tokio runtime available (dropping outside async context).
                    // Fall back to abort to avoid leaking the task indefinitely.
                    tracing::debug!(
                        "dropping ReplicationClient outside a Tokio runtime; aborting worker task"
                    );
                    join.abort();
                }
            }
        }
    }
}

async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result<()> {
    let tcp = TcpStream::connect((cfg.host.as_str(), cfg.port)).await?;
    // Feedback messages are small and latency-sensitive; disable Nagle.
    tcp.set_nodelay(true)?;

    // Exactly one of the two cfg blocks below is compiled in; each one is the
    // function's trailing expression and returns the worker's result.
    #[cfg(feature = "tls-rustls")]
    {
        use crate::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
        let upgraded = maybe_upgrade_to_tls(tcp, &cfg.tls, &cfg.host).await?;
        match upgraded {
            MaybeTlsStream::Plain(mut s) => worker.run_on_stream(&mut s).await,
            MaybeTlsStream::Tls(mut s) => worker.run_on_stream(s.as_mut()).await,
        }
    }

    #[cfg(not(feature = "tls-rustls"))]
    {
        if !matches!(cfg.tls.mode, SslMode::Disable) {
            return Err(PgWireError::Tls(
                "TLS requested but the tls-rustls feature is not enabled".into(),
            ));
        }
        let mut s = tcp;
        worker.run_on_stream(&mut s).await
    }
}