use crate::config::ReplicationConfig;
use crate::error::{PgWireError, Result};
use crate::lsn::Lsn;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
use std::sync::Arc;
#[cfg(not(feature = "tls-rustls"))]
use crate::config::SslMode;
use super::worker::{ReplicationEvent, ReplicationEventReceiver, SharedProgress, WorkerState};
/// Handle to a replication worker running on a background Tokio task.
///
/// Created by `ReplicationClient::connect`; events are pulled with `recv`, and the
/// worker is stopped via `stop`, `shutdown`, `abort`, or by dropping the handle.
pub struct ReplicationClient {
/// Receiving end of the bounded mpsc channel the worker sends events (or errors) into.
rx: ReplicationEventReceiver,
/// Progress state shared with the worker; the consumer records applied LSNs here
/// through `update_applied_lsn`.
progress: Arc<SharedProgress>,
/// Watch channel used to ask the worker to stop (`true` means stop).
stop_tx: watch::Sender<bool>,
/// Handle of the spawned worker task; `None` once it has been joined or aborted.
join: Option<JoinHandle<std::result::Result<(), PgWireError>>>,
}
impl ReplicationClient {
pub async fn connect(cfg: ReplicationConfig) -> Result<Self> {
let (tx, rx) = mpsc::channel(cfg.buffer_events);
let progress = Arc::new(SharedProgress::new(cfg.start_lsn));
let (stop_tx, stop_rx) = watch::channel(false);
let progress_for_worker = Arc::clone(&progress);
let cfg_for_worker = cfg.clone();
let join = tokio::spawn(async move {
let mut worker = WorkerState::new(cfg_for_worker, progress_for_worker, stop_rx, tx);
let res = run_worker(&mut worker, &cfg).await;
if let Err(ref e) = res {
tracing::error!("replication worker terminated with error: {e}");
}
res
});
Ok(Self {
rx,
progress,
stop_tx,
join: Some(join),
})
}
pub async fn recv(&mut self) -> Result<Option<ReplicationEvent>> {
match self.rx.recv().await {
Some(Ok(ev)) => Ok(Some(ev)),
Some(Err(e)) => Err(e),
None => self.handle_worker_shutdown().await,
}
}
async fn handle_worker_shutdown(&mut self) -> Result<Option<ReplicationEvent>> {
let join = self
.join
.take()
.ok_or_else(|| PgWireError::Internal("replication worker already joined".into()))?;
match join.await {
Ok(Ok(())) => Ok(None),
Ok(Err(e)) => Err(e),
Err(join_err) => Err(PgWireError::Task(format!(
"replication worker panicked: {join_err}"
))),
}
}
#[inline]
pub fn update_applied_lsn(&self, lsn: Lsn) {
self.progress.update_applied(lsn);
}
#[inline]
pub fn stop(&self) {
let _ = self.stop_tx.send(true);
}
pub fn is_running(&self) -> bool {
self.join
.as_ref()
.map(|j| !j.is_finished())
.unwrap_or(false)
}
pub async fn join(mut self) -> Result<()> {
let join = self
.join
.take()
.ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
match join.await {
Ok(inner) => inner,
Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
}
}
pub fn abort(&mut self) {
if let Some(join) = self.join.take() {
join.abort();
}
}
pub async fn shutdown(&mut self) -> Result<()> {
self.stop();
while let Some(msg) = self.rx.recv().await {
match msg {
Ok(_ev) => {} Err(e) => return Err(e),
}
}
self.join_mut().await
}
async fn join_mut(&mut self) -> Result<()> {
let join = self
.join
.take()
.ok_or_else(|| PgWireError::Task("worker already joined".into()))?;
match join.await {
Ok(inner) => inner,
Err(e) => Err(PgWireError::Task(format!("join error: {e}"))),
}
}
}
impl Drop for ReplicationClient {
    /// Signals the worker to stop and releases its task handle.
    ///
    /// Inside a Tokio runtime the handle is simply dropped: dropping a
    /// `JoinHandle` detaches the task rather than cancelling it, so the worker
    /// can wind down after seeing the stop signal. Its error, if any, is
    /// already logged by the task spawned in `connect`, so there is nothing to
    /// await here. Outside a runtime the task is aborted instead.
    fn drop(&mut self) {
        self.stop();
        if let Some(handle) = self.join.take() {
            if tokio::runtime::Handle::try_current().is_ok() {
                drop(handle);
            } else {
                tracing::debug!(
                    "dropping ReplicationClient outside a Tokio runtime; aborting worker task"
                );
                handle.abort();
            }
        }
    }
}
async fn run_worker(worker: &mut WorkerState, cfg: &ReplicationConfig) -> Result<()> {
#[cfg(unix)]
if cfg.is_unix_socket() {
if cfg.tls.mode.requires_tls() {
return Err(PgWireError::Tls(
"TLS is not supported over Unix domain sockets".into(),
));
}
let path = cfg.unix_socket_path();
let mut stream = UnixStream::connect(&path).await.map_err(|e| {
PgWireError::Io(std::sync::Arc::new(std::io::Error::new(
e.kind(),
format!("failed to connect to Unix socket {}: {e}", path.display()),
)))
})?;
return worker.run_on_stream(&mut stream).await;
}
let tcp = TcpStream::connect((cfg.host.as_str(), cfg.port)).await?;
tcp.set_nodelay(true)?;
#[cfg(feature = "tls-rustls")]
{
use crate::tls::rustls::{maybe_upgrade_to_tls, MaybeTlsStream};
let upgraded = maybe_upgrade_to_tls(tcp, &cfg.tls, &cfg.host).await?;
match upgraded {
MaybeTlsStream::Plain(mut s) => worker.run_on_stream(&mut s).await,
MaybeTlsStream::Tls(mut s) => worker.run_on_stream(s.as_mut()).await,
}
}
#[cfg(not(feature = "tls-rustls"))]
{
if !matches!(cfg.tls.mode, SslMode::Disable) {
return Err(PgWireError::Tls("tls-rustls feature not enabled".into()));
}
let mut s = tcp;
worker.run_on_stream(&mut s).await
}
}