pg_replica 0.2.0

Consensus-driven failover for PostgreSQL (Raft control plane)
use std::sync::mpsc::Sender as StdSender;
use std::sync::Arc;

use openraft::error::{
    InstallSnapshotError, NetworkError, RPCError, RaftError, RemoteError, Unreachable,
};
use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
use openraft::raft::{
    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
    VoteRequest, VoteResponse,
};
use openraft::{BasicNode, Raft};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use tokio::sync::OnceCell;

use crate::rtype::{NodeId, TypeConfig};

/// One cluster member as described by a single `id@addr` entry of a peer
/// specification string (see [`parse_peers`]).
pub struct Peer {
    /// Numeric identifier of the peer.
    pub id: u64,
    /// Address text that followed the `@`, whitespace-trimmed but otherwise
    /// unvalidated.
    pub addr: String,
}

/// Parses a comma-separated list of `id@addr` entries into [`Peer`]s.
///
/// Whitespace around each entry, around the id and around the address is
/// ignored. Entries that are blank, lack an `@`, or whose id is not a valid
/// `u64` are skipped silently. Only the first `@` separates id from address,
/// so the address itself may contain further `@` characters. The address is
/// not validated and may be empty.
pub fn parse_peers(spec: &str) -> Vec<Peer> {
    let mut peers = Vec::new();
    for raw in spec.split(',') {
        let token = raw.trim();
        // A blank token has no '@', so this guard also discards empty entries.
        let Some((id_text, addr_text)) = token.split_once('@') else {
            continue;
        };
        let Ok(id) = id_text.trim().parse::<u64>() else {
            continue;
        };
        let addr = addr_text.trim().to_owned();
        peers.push(Peer { id, addr });
    }
    peers
}

/// First byte of a frame body: the frame carries a Raft RPC and expects a
/// reply frame on the same connection.
pub const KIND_RAFT: u8 = 0;
/// First byte of a frame body: the frame carries a gossip message
/// (8-byte big-endian sender id followed by the payload); no reply is sent.
pub const KIND_GOSSIP: u8 = 1;
/// Second byte of a `KIND_RAFT` frame: JSON-encoded append-entries request.
const RPC_APPEND: u8 = 0;
/// Second byte of a `KIND_RAFT` frame: JSON-encoded vote request.
const RPC_VOTE: u8 = 1;
/// Second byte of a `KIND_RAFT` frame: JSON-encoded install-snapshot request.
const RPC_SNAPSHOT: u8 = 2;
/// Largest frame body (in bytes) that `read_frame` accepts: 256 MiB.
const MAX_FRAME: usize = 256 * 1024 * 1024;

/// Shared, set-once cell holding the local Raft instance. Incoming Raft RPCs
/// can only be dispatched once the cell has been populated.
pub type RaftSlot = Arc<OnceCell<Raft<TypeConfig>>>;
/// Shorthand for the RPC error type of this transport; `E` defaults to the
/// plain `RaftError` used by append-entries and vote.
type RpcError<E = RaftError<NodeId>> = RPCError<NodeId, BasicNode, E>;

/// Writes one length-prefixed frame to `stream` and flushes it.
///
/// The wire format is a 4-byte big-endian length followed by exactly that
/// many bytes of `bytes`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` without writing anything when
/// `bytes.len()` does not fit in the 32-bit length prefix. A truncating cast
/// here would emit a wrong length and desynchronise the stream. Any I/O error
/// from the underlying stream is returned unchanged.
async fn write_frame<S: AsyncWrite + Unpin>(stream: &mut S, bytes: &[u8]) -> std::io::Result<()> {
    let len = u32::try_from(bytes.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "frame does not fit in a 32-bit length prefix",
        )
    })?;
    stream.write_all(&len.to_be_bytes()).await?;
    stream.write_all(bytes).await?;
    stream.flush().await
}

/// Reads one length-prefixed frame from `stream` and returns its body.
///
/// Expects a 4-byte big-endian length followed by that many bytes.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidData` ("frame too large") when the announced
/// length exceeds `MAX_FRAME`; the body is not read in that case. Any I/O
/// error from the stream, including a premature end of stream reported by
/// `read_exact`, is returned unchanged.
async fn read_frame<S: AsyncRead + Unpin>(stream: &mut S) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    stream.read_exact(&mut header).await?;

    let body_len = u32::from_be_bytes(header) as usize;
    if body_len <= MAX_FRAME {
        let mut body = vec![0u8; body_len];
        stream.read_exact(&mut body).await?;
        return Ok(body);
    }
    Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "frame too large"))
}

/// Spawns the background accept loop for the node's inbound connections.
///
/// Every accepted connection is served by its own task running
/// `handle_conn`, which dispatches Raft RPCs to the instance in `slot` and
/// forwards gossip messages to `gossip_in`.
///
/// The listener is switched to non-blocking mode before it is handed to
/// Tokio, because `TcpListener::from_std` requires that and does not do it
/// itself. If that switch or the conversion fails, the server task exits
/// without reporting the failure. A failed `accept` is retried after a 100 ms
/// pause so a persistent error does not busy-loop.
///
/// Uses `tokio::spawn`, so it has to be called from within a Tokio runtime.
pub fn spawn_server(
    listener: std::net::TcpListener,
    slot: RaftSlot,
    gossip_in: StdSender<(u64, Vec<u8>)>,
) {
    tokio::spawn(async move {
        // A blocking socket would stall the runtime worker inside `accept`.
        if listener.set_nonblocking(true).is_err() {
            return;
        }
        let Ok(listener) = TcpListener::from_std(listener) else {
            return;
        };
        loop {
            match listener.accept().await {
                Ok((stream, _)) => {
                    tokio::spawn(handle_conn(stream, slot.clone(), gossip_in.clone()));
                }
                Err(_) => tokio::time::sleep(std::time::Duration::from_millis(100)).await,
            }
        }
    });
}

/// Serves one inbound connection until it is closed or becomes invalid.
///
/// Frames are read in a loop and routed by their first byte:
///
/// * `KIND_GOSSIP`: the next 8 bytes are the big-endian sender id and the
///   remainder is the payload; both are forwarded to `gossip_in`. Frames too
///   short to hold the id are ignored and the connection stays open. A failed
///   channel send is ignored as well.
/// * `KIND_RAFT`: the rest of the frame is handed to `dispatch_raft` and the
///   reply is written back as one frame.
///
/// The function returns (dropping the connection) on a read error, an empty
/// frame, an unknown kind byte, a Raft request that produced no reply, or a
/// failed reply write.
async fn handle_conn(mut stream: TcpStream, slot: RaftSlot, gossip_in: StdSender<(u64, Vec<u8>)>) {
    while let Ok(frame) = read_frame(&mut stream).await {
        let Some((&kind, body)) = frame.split_first() else {
            return;
        };
        match kind {
            KIND_GOSSIP => {
                // `body` excludes the kind byte, so the sender id needs 8 bytes here.
                if body.len() >= 8 {
                    let (id_bytes, message) = body.split_at(8);
                    let from = u64::from_be_bytes(
                        id_bytes.try_into().expect("split_at(8) yields exactly 8 bytes"),
                    );
                    let _ = gossip_in.send((from, message.to_vec()));
                }
            }
            KIND_RAFT => {
                let Some(reply) = dispatch_raft(body, &slot).await else {
                    return;
                };
                if write_frame(&mut stream, &reply).await.is_err() {
                    return;
                }
            }
            _ => return,
        }
    }
}

/// Decodes one Raft RPC, runs it against the local Raft instance and returns
/// the JSON-encoded result.
///
/// `payload` is a Raft frame without its kind byte: one RPC selector byte
/// (`RPC_APPEND`, `RPC_VOTE` or `RPC_SNAPSHOT`) followed by the JSON request.
/// The value returned by the Raft call is serialised as a whole, success or
/// error alike.
///
/// Returns `None` when the payload is empty, the Raft instance has not been
/// placed in `slot` yet, the selector is unknown, or the request or response
/// fails to (de)serialise.
async fn dispatch_raft(payload: &[u8], slot: &RaftSlot) -> Option<Vec<u8>> {
    let (&rpc, body) = payload.split_first()?;
    let raft = slot.get()?;

    let encoded = match rpc {
        RPC_APPEND => {
            let req = serde_json::from_slice::<AppendEntriesRequest<TypeConfig>>(body).ok()?;
            serde_json::to_vec(&raft.append_entries(req).await)
        }
        RPC_VOTE => {
            let req = serde_json::from_slice::<VoteRequest<NodeId>>(body).ok()?;
            serde_json::to_vec(&raft.vote(req).await)
        }
        RPC_SNAPSHOT => {
            let req = serde_json::from_slice::<InstallSnapshotRequest<TypeConfig>>(body).ok()?;
            serde_json::to_vec(&raft.install_snapshot(req).await)
        }
        _ => return None,
    };
    encoded.ok()
}

/// Stateless factory that hands out [`Connection`]s for outbound Raft RPCs.
#[derive(Clone)]
pub struct NetworkFactory;

impl RaftNetworkFactory<TypeConfig> for NetworkFactory {
    type Network = Connection;

    /// Builds the client for `target`, remembering the address stored in
    /// `node`. No socket is opened here.
    async fn new_client(&mut self, target: NodeId, node: &BasicNode) -> Self::Network {
        let addr = node.addr.clone();
        Connection { target, addr }
    }
}

/// Outbound Raft RPC client for a single remote node.
pub struct Connection {
    /// Id of the remote node; attached to errors reported by that node.
    target: NodeId,
    /// Address passed to `TcpStream::connect` for every RPC.
    addr: String,
}

impl Connection {
    /// Performs one request/response Raft RPC over a fresh TCP connection.
    ///
    /// The request frame is `KIND_RAFT`, the `rpc` selector byte, then `req`
    /// as JSON. The reply frame is expected to hold a JSON-encoded
    /// `Result<Resp, Err>` produced by the remote node.
    ///
    /// # Errors
    ///
    /// * `RPCError::Unreachable` when the TCP connection cannot be opened.
    /// * `RPCError::Network` when serialising the request, writing it,
    ///   reading the reply, or deserialising the reply fails.
    /// * `RPCError::RemoteError` when the remote node answered with an error.
    async fn call<Req, Resp, Err>(&self, rpc: u8, req: &Req) -> Result<Resp, RpcError<Err>>
    where
        Req: Serialize,
        Resp: DeserializeOwned,
        Err: std::error::Error + DeserializeOwned,
    {
        // Serialise before connecting so a bad request never touches the network.
        let body =
            serde_json::to_vec(req).map_err(|e| RPCError::Network(NetworkError::new(&e)))?;
        let mut frame = Vec::with_capacity(2 + body.len());
        frame.extend_from_slice(&[KIND_RAFT, rpc]);
        frame.extend_from_slice(&body);

        let mut conn = match TcpStream::connect(&self.addr).await {
            Ok(conn) => conn,
            Err(e) => return Err(RPCError::Unreachable(Unreachable::new(&e))),
        };
        if let Err(e) = write_frame(&mut conn, &frame).await {
            return Err(RPCError::Network(NetworkError::new(&e)));
        }
        let reply = match read_frame(&mut conn).await {
            Ok(reply) => reply,
            Err(e) => return Err(RPCError::Network(NetworkError::new(&e))),
        };

        match serde_json::from_slice::<Result<Resp, Err>>(&reply) {
            Ok(Ok(value)) => Ok(value),
            Ok(Err(remote)) => Err(RPCError::RemoteError(RemoteError::new(self.target, remote))),
            Err(e) => Err(RPCError::Network(NetworkError::new(&e))),
        }
    }
}

/// Raft transport implementation: each RPC is forwarded to `Connection::call`
/// with the matching selector byte. The `RPCOption` argument is currently
/// unused by all three methods.
impl RaftNetwork<TypeConfig> for Connection {
    /// Sends an append-entries request (`RPC_APPEND`) to the target node.
    async fn append_entries(
        &mut self,
        req: AppendEntriesRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<AppendEntriesResponse<NodeId>, RpcError> {
        self.call(RPC_APPEND, &req).await
    }

    /// Sends an install-snapshot request (`RPC_SNAPSHOT`) to the target node.
    async fn install_snapshot(
        &mut self,
        req: InstallSnapshotRequest<TypeConfig>,
        _option: RPCOption,
    ) -> Result<InstallSnapshotResponse<NodeId>, RpcError<RaftError<NodeId, InstallSnapshotError>>> {
        self.call(RPC_SNAPSHOT, &req).await
    }

    /// Sends a vote request (`RPC_VOTE`) to the target node.
    async fn vote(
        &mut self,
        req: VoteRequest<NodeId>,
        _option: RPCOption,
    ) -> Result<VoteResponse<NodeId>, RpcError> {
        self.call(RPC_VOTE, &req).await
    }
}

/// Cloneable handle for queueing gossip messages to the background sender
/// task created by `spawn_gossip_sender`.
#[derive(Clone)]
pub struct GossipHandle {
    /// Queue of fully built gossip frame bodies awaiting transmission.
    tx: UnboundedSender<Vec<u8>>,
}

impl GossipHandle {
    /// Queues `payload` for delivery to the peers, tagged with sender `from`.
    ///
    /// The queued frame body is the `KIND_GOSSIP` byte, `from` as 8 big-endian
    /// bytes, then `payload`. Never blocks. If the sender task is gone the
    /// message is dropped silently.
    pub fn broadcast(&self, from: u64, payload: &[u8]) {
        let sender = from.to_be_bytes();
        let frame = [&[KIND_GOSSIP][..], &sender[..], payload].concat();
        let _ = self.tx.send(frame);
    }
}

/// Spawns the background task that delivers queued gossip frames to every
/// peer and returns the handle used to queue them.
///
/// For each queued frame the task visits `peers` in order, skipping the entry
/// whose id equals `my_id`, opens a new TCP connection to the peer's address
/// and writes the frame. Delivery is best-effort: connection and write
/// failures are ignored.
///
/// Each per-peer attempt (connect plus write) is limited to `SEND_TIMEOUT`.
/// Without that bound a single unreachable peer would hold up delivery to all
/// remaining peers for the operating system's connect timeout while the
/// unbounded queue keeps growing.
///
/// The task ends once every `GossipHandle` clone has been dropped. Uses
/// `tokio::spawn`, so it has to be called from within a Tokio runtime.
pub fn spawn_gossip_sender(my_id: u64, peers: Vec<(u64, String)>) -> GossipHandle {
    // NOTE(review): 2 s is a conservative guess; tune to the gossip interval.
    const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);

    let (tx, mut rx) = unbounded_channel::<Vec<u8>>();
    tokio::spawn(async move {
        while let Some(frame) = rx.recv().await {
            for (id, addr) in &peers {
                if *id == my_id {
                    continue;
                }
                let attempt = async {
                    if let Ok(mut stream) = TcpStream::connect(addr).await {
                        let _ = write_frame(&mut stream, &frame).await;
                    }
                };
                // An elapsed timeout is treated like any other failed delivery.
                let _ = tokio::time::timeout(SEND_TIMEOUT, attempt).await;
            }
        }
    });
    GossipHandle { tx }
}