use crate::wire::WireError;
use kevy_resp::Argv;
use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
/// A decoded command frame from the primary's replication stream, delivered
/// via `ReplicaEvent::Frame` / `ReplicaClient::next_frame`.
#[derive(Debug)]
pub struct DecodedFrame {
/// Replication offset recorded for this frame.
/// NOTE(review): whether this is the offset at the start or the end of the
/// frame is decided by the decoder outside this section — confirm before
/// relying on it.
pub offset: u64,
/// Command arguments, in the `kevy_resp::Argv` form produced by the RESP
/// command parser.
pub argv: Argv,
}
/// One item read from the replication stream by `ReplicaClient::next_event`.
///
/// `ReplicaClient::next_frame` (and the `Iterator` impl) only pass `Frame`
/// through; every other variant is reported there as
/// `ReplicaError::SnapshotInProgress`.
#[derive(Debug)]
pub enum ReplicaEvent {
/// A decoded command frame.
Frame(DecodedFrame),
/// A snapshot transfer is starting.
/// NOTE(review): produced by `next_event`, which is not shown here — confirm
/// exactly when it is emitted.
SnapshotBegin,
/// Raw snapshot bytes (presumably delivered in order between
/// `SnapshotBegin` and `SnapshotEnd`).
SnapshotChunk(Vec<u8>),
/// The snapshot transfer has finished.
SnapshotEnd {
/// Offset carried by the end-of-snapshot marker; presumably the offset from
/// which frame streaming resumes — confirm against `next_event`.
ack_offset: u64,
},
}
/// Errors produced while connecting to a primary or consuming its
/// replication stream.
#[derive(Debug)]
pub enum ReplicaError {
/// The primary closed the connection before a complete handshake reply line
/// arrived.
HandshakeRejected,
/// The handshake reply was not a valid `+ACK <offset>` line, or it grew past
/// the reply-length limit without a `\r\n` terminator.
AckMalformed,
/// The peer cut the replication stream short; `WireError::Truncated` is
/// converted into this variant rather than `Frame`.
Truncated,
/// A frame failed to decode for a reason other than truncation.
Frame(WireError),
/// A frame's offset did not match the offset the client expected next.
OffsetGap {
/// Offset the client expected.
expected: u64,
/// Offset actually received.
got: u64,
},
/// The primary sent something other than a snapshot chunk while a snapshot
/// was in progress.
UnexpectedInSnapshot,
/// `next_frame` (or the `Iterator` impl) hit a snapshot event; use
/// `next_event` to consume snapshots.
SnapshotInProgress,
/// Socket I/O failure, including address resolution, connect failures, and
/// handshake timeouts.
Io(io::Error),
}
impl std::fmt::Display for ReplicaError {
    /// Writes a short, lowercase, human-readable description of the failure.
    ///
    /// Variants that carry data interpolate it, including the wrapped error's
    /// own message for `Frame` and `Io`. Every other variant maps to fixed text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed = match self {
            Self::Frame(inner) => {
                return write!(f, "replication frame decode error: {inner}");
            }
            Self::OffsetGap { expected, got } => {
                return write!(f, "replication offset gap: expected {expected}, got {got}");
            }
            Self::Io(inner) => return write!(f, "replication socket I/O error: {inner}"),
            Self::HandshakeRejected => "primary rejected replication handshake",
            Self::AckMalformed => "primary sent malformed +ACK",
            Self::Truncated => "replication stream truncated by peer",
            Self::UnexpectedInSnapshot => "primary sent non-chunk bytes mid-snapshot",
            Self::SnapshotInProgress => "snapshot in progress; use next_event() to consume",
        };
        f.write_str(fixed)
    }
}
// `source()` keeps its default (`None`). The `Io` and `Frame` Display messages
// already embed the wrapped error's text. If `source()` is ever overridden,
// drop that text from Display so chained error reporters don't print it twice.
impl std::error::Error for ReplicaError {}
impl From<io::Error> for ReplicaError {
fn from(e: io::Error) -> Self {
ReplicaError::Io(e)
}
}
/// Converts frame-decoding errors.
///
/// Truncation is surfaced as the dedicated [`ReplicaError::Truncated`]
/// variant. Every other decode failure is wrapped in [`ReplicaError::Frame`].
impl From<WireError> for ReplicaError {
    fn from(err: WireError) -> Self {
        if matches!(err, WireError::Truncated) {
            Self::Truncated
        } else {
            Self::Frame(err)
        }
    }
}
/// Blocking client for a primary's replication stream.
///
/// Build one with `ReplicaClient::connect` or
/// `ReplicaClient::connect_with_timeout`, which perform the `REPLICATE FROM`
/// handshake before returning. The fields are `pub(crate)` so other modules in
/// the crate can drive them directly (presumably the decoder behind
/// `next_event` — confirm).
pub struct ReplicaClient {
/// Connection to the primary. Handshake timeouts are cleared before a client
/// is returned from `connect_with_timeout`.
pub(crate) sock: TcpStream,
/// Bytes received from the primary; `buf[..cursor]` has already been consumed.
/// NOTE(review): filled outside this section (presumably by `next_event`).
pub(crate) buf: Vec<u8>,
/// Index into `buf` of the first unconsumed byte. `maybe_compact_buf` drops
/// the consumed prefix and resets this to 0.
pub(crate) cursor: usize,
/// Offset the primary reported in its `+ACK` handshake reply.
pub(crate) primary_offset_at_handshake: u64,
/// Replication offset expected next; initialised to the requested
/// `from_offset`.
/// NOTE(review): presumably advanced per frame and compared to detect
/// `ReplicaError::OffsetGap` — confirm against `next_event`.
pub(crate) expected_offset: u64,
/// Whether a snapshot transfer is in progress; starts `false`.
/// NOTE(review): presumably toggled by `next_event` on snapshot begin/end.
pub(crate) in_snapshot: bool,
}
impl ReplicaClient {
    /// Consumed-byte count at which [`ReplicaClient::maybe_compact_buf`] shifts
    /// the unconsumed tail of `buf` back to the front. Waiting for a few KiB
    /// amortizes that copy instead of paying it after every frame.
    const COMPACT_THRESHOLD: usize = 4 * 1024;

    /// Connects to a primary and performs the replication handshake with a
    /// default timeout of 5 seconds.
    ///
    /// This is [`ReplicaClient::connect_with_timeout`] with
    /// `Duration::from_secs(5)`; see it for handshake details and errors.
    pub fn connect<A: ToSocketAddrs>(
        addr: A,
        replica_id: &str,
        from_offset: u64,
    ) -> Result<Self, ReplicaError> {
        Self::connect_with_timeout(addr, replica_id, from_offset, Duration::from_secs(5))
    }

    /// Connects to a primary, requests replication starting at `from_offset`,
    /// and waits for the primary's `+ACK <offset>` reply.
    ///
    /// The request sent is `REPLICATE FROM <from_offset> ID <replica_id>`.
    /// Every address that `addr` resolves to is tried in order. The first one
    /// that accepts a TCP connection within `connect_timeout` is used.
    ///
    /// `connect_timeout` also bounds writing the request and reading the reply.
    /// A primary that accepts the connection but never services it therefore
    /// cannot stall the caller indefinitely. Both socket timeouts are cleared
    /// before the client is returned, so later stream reads block without a
    /// limit.
    ///
    /// On success, [`ReplicaClient::primary_offset_at_handshake`] returns the
    /// offset from the `+ACK` reply. [`ReplicaClient::expected_offset`] starts
    /// at `from_offset`.
    ///
    /// # Errors
    ///
    /// - [`ReplicaError::Io`] in any of these cases:
    ///   - address resolution fails or resolves to nothing;
    ///   - every connection attempt fails (the last attempt's error is
    ///     returned);
    ///   - a socket operation fails or times out during the handshake.
    ///
    ///   The standard library also rejects a zero `connect_timeout` with an
    ///   I/O error.
    /// - [`ReplicaError::HandshakeRejected`] if the primary closes the
    ///   connection before sending a complete reply line.
    /// - [`ReplicaError::AckMalformed`] if the reply is not a valid
    ///   `+ACK <offset>` line.
    pub fn connect_with_timeout<A: ToSocketAddrs>(
        addr: A,
        replica_id: &str,
        from_offset: u64,
        connect_timeout: Duration,
    ) -> Result<Self, ReplicaError> {
        let mut last_err: Option<io::Error> = None;
        let mut sock: Option<TcpStream> = None;
        for sa in addr.to_socket_addrs()? {
            match TcpStream::connect_timeout(&sa, connect_timeout) {
                Ok(s) => {
                    sock = Some(s);
                    break;
                }
                Err(e) => last_err = Some(e),
            }
        }
        let mut sock = sock.ok_or_else(|| {
            ReplicaError::Io(last_err.unwrap_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "no socket address resolved")
            }))
        })?;

        // Bound the whole handshake (request write and reply read) by the
        // caller's timeout, not just the read.
        sock.set_write_timeout(Some(connect_timeout))?;
        sock.set_read_timeout(Some(connect_timeout))?;
        sock.write_all(&encode_replicate_from(from_offset, replica_id))?;
        let primary_offset = read_ack(&mut sock)?;

        // Streaming phase: clear the handshake timeouts and make sure the socket
        // blocks. Socket options are shared with clones from `socket_handle`,
        // so this must happen before the client is handed out.
        sock.set_write_timeout(None)?;
        sock.set_read_timeout(None)?;
        sock.set_nonblocking(false)?;

        Ok(ReplicaClient {
            sock,
            buf: Vec::with_capacity(8 * 1024),
            cursor: 0,
            primary_offset_at_handshake: primary_offset,
            expected_offset: from_offset,
            in_snapshot: false,
        })
    }

    /// Offset the primary reported in its `+ACK` handshake reply.
    pub fn primary_offset_at_handshake(&self) -> u64 {
        self.primary_offset_at_handshake
    }

    /// Returns a second handle to the replication socket via
    /// [`TcpStream::try_clone`].
    ///
    /// Both handles refer to the same connection, and socket options set on
    /// one apply to the other. For example, calling `shutdown` on the clone
    /// also affects this client. NOTE(review): presumably intended for stopping
    /// a blocked reader from another thread — confirm with callers.
    ///
    /// # Errors
    ///
    /// Propagates the error from `try_clone`.
    pub fn socket_handle(&self) -> io::Result<TcpStream> {
        self.sock.try_clone()
    }

    /// Replication offset the client expects next.
    ///
    /// Starts at the `from_offset` given to the constructor. It is presumably
    /// advanced by `next_event` as frames arrive (not shown here).
    pub fn expected_offset(&self) -> u64 {
        self.expected_offset
    }

    /// Returns the next command frame from the stream.
    ///
    /// Returns `None` whenever `next_event` does. Any non-frame event
    /// (snapshot begin, chunk, or end) is consumed and reported as
    /// `Err(ReplicaError::SnapshotInProgress)`. Callers that may receive a
    /// snapshot should use `next_event` directly.
    pub fn next_frame(&mut self) -> Option<Result<DecodedFrame, ReplicaError>> {
        match self.next_event()? {
            Ok(ReplicaEvent::Frame(f)) => Some(Ok(f)),
            Ok(_) => Some(Err(ReplicaError::SnapshotInProgress)),
            Err(e) => Some(Err(e)),
        }
    }

    /// Discards the consumed prefix `buf[..cursor]` once it reaches
    /// [`Self::COMPACT_THRESHOLD`] bytes, then resets `cursor` to 0.
    pub(crate) fn maybe_compact_buf(&mut self) {
        if self.cursor >= Self::COMPACT_THRESHOLD {
            self.buf.drain(..self.cursor);
            self.cursor = 0;
        }
    }
}
impl Iterator for ReplicaClient {
type Item = Result<DecodedFrame, ReplicaError>;
fn next(&mut self) -> Option<Self::Item> {
self.next_frame()
}
}
/// Encodes the handshake command `REPLICATE FROM <from_offset> ID <replica_id>`
/// as a RESP array of bulk strings.
///
/// Every element is length-prefixed (`$<len>\r\n<bytes>\r\n`), so any bytes in
/// `replica_id` are carried verbatim.
fn encode_replicate_from(from_offset: u64, replica_id: &str) -> Vec<u8> {
    let offset_text = from_offset.to_string();
    let parts: [&[u8]; 5] = [
        b"REPLICATE",
        b"FROM",
        offset_text.as_bytes(),
        b"ID",
        replica_id.as_bytes(),
    ];

    let mut out = Vec::with_capacity(64 + replica_id.len());
    out.extend_from_slice(format!("*{}\r\n", parts.len()).as_bytes());
    for part in parts {
        out.extend_from_slice(format!("${}\r\n", part.len()).as_bytes());
        out.extend_from_slice(part);
        out.extend_from_slice(b"\r\n");
    }
    out
}
/// Reads the primary's one-line handshake reply and parses it with
/// [`parse_ack_line`].
///
/// Generic over any [`Read`]: production passes the `TcpStream`, but any
/// reader (e.g. an in-memory `&[u8]`) works.
///
/// The reply is read one byte at a time on purpose. Whatever the primary
/// sends after the `\r\n` terminator must stay unread in `sock`, because the
/// caller starts decoding the replication stream from an empty buffer
/// afterwards. The line is capped at a few hundred bytes, so the per-byte
/// reads are cheap.
///
/// # Errors
///
/// - `ReplicaError::HandshakeRejected` if the reader hits EOF before a
///   complete `\r\n`-terminated line.
/// - `ReplicaError::AckMalformed` if more than `MAX_ACK_LINE` bytes arrive
///   without a terminator, or the line is not a valid `+ACK` reply.
/// - `ReplicaError::Io` for any read error other than `Interrupted`, which is
///   retried. This includes timeouts configured on the socket.
fn read_ack<R: Read + ?Sized>(sock: &mut R) -> Result<u64, ReplicaError> {
    /// Longest unterminated reply tolerated before giving up.
    const MAX_ACK_LINE: usize = 256;

    let mut line = Vec::with_capacity(32);
    let mut byte = [0u8; 1];
    loop {
        match sock.read(&mut byte) {
            Ok(0) => return Err(ReplicaError::HandshakeRejected),
            Ok(_) => {
                line.push(byte[0]);
                if line.ends_with(b"\r\n") {
                    break;
                }
                if line.len() > MAX_ACK_LINE {
                    return Err(ReplicaError::AckMalformed);
                }
            }
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(ReplicaError::Io(e)),
        }
    }
    parse_ack_line(&line)
}
/// Parses a complete handshake reply line of the form `+ACK <offset>\r\n`.
///
/// `<offset>` must be one or more ASCII decimal digits that fit in a `u64`.
/// The digit check is explicit because `u64::from_str` would otherwise also
/// accept a leading `+` (e.g. `+ACK +5`).
///
/// # Errors
///
/// Returns `ReplicaError::AckMalformed` in any of these cases:
/// - the `\r\n` terminator or the `+ACK ` prefix is missing;
/// - the offset is empty or contains a non-digit byte;
/// - the offset overflows `u64`.
fn parse_ack_line(line: &[u8]) -> Result<u64, ReplicaError> {
    let digits = line
        .strip_suffix(b"\r\n")
        .and_then(|body| body.strip_prefix(b"+ACK "))
        .filter(|d| !d.is_empty() && d.iter().all(u8::is_ascii_digit))
        .ok_or(ReplicaError::AckMalformed)?;
    // All bytes are ASCII digits, so the UTF-8 step cannot fail. `parse`
    // still rejects values that overflow `u64`.
    std::str::from_utf8(digits)
        .ok()
        .and_then(|s| s.parse::<u64>().ok())
        .ok_or(ReplicaError::AckMalformed)
}
#[cfg(test)]
impl ReplicaClient {
    /// Test-only constructor that wraps an already-connected socket and skips
    /// the `REPLICATE FROM` handshake.
    ///
    /// Both the handshake offset and the expected offset are set to
    /// `expected_offset`. The read buffer starts empty with 8 KiB reserved.
    pub(crate) fn from_socket_for_test(sock: TcpStream, expected_offset: u64) -> Self {
        let buf = Vec::with_capacity(8 * 1024);
        Self {
            sock,
            buf,
            cursor: 0,
            expected_offset,
            primary_offset_at_handshake: expected_offset,
            in_snapshot: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The bytes we send must round-trip through the RESP parser and the
    /// primary's handshake parser.
    #[test]
    fn encoded_replicate_from_matches_what_primary_parses() {
        let wire = encode_replicate_from(42, "replica-a");
        let mut argv = Argv::default();
        let used = kevy_resp::parse_command_into(&wire, &mut argv)
            .expect("parse ok")
            .expect("complete");
        assert_eq!(used, wire.len());
        let req = crate::handshake::parse_replicate_from(&argv).expect("handshake ok");
        assert_eq!(req.from_offset, 42);
        assert_eq!(req.replica_id, "replica-a");
    }

    #[test]
    fn ack_line_parses_offsets() {
        for (line, want) in [
            ("+ACK 0\r\n", 0),
            ("+ACK 42\r\n", 42),
            ("+ACK 12345678\r\n", 12_345_678),
        ] {
            assert_eq!(parse_ack_line(line.as_bytes()).unwrap(), want);
        }
    }

    #[test]
    fn ack_line_rejects_malformed() {
        for line in ["+PONG\r\n", "+ACK abc\r\n", "-ERR nope\r\n", "+ACK 1"] {
            assert!(matches!(
                parse_ack_line(line.as_bytes()),
                Err(ReplicaError::AckMalformed)
            ));
        }
    }

    #[test]
    fn ack_line_rejects_offset_overflow() {
        let too_big = b"+ACK 99999999999999999999999\r\n";
        assert!(matches!(
            parse_ack_line(too_big),
            Err(ReplicaError::AckMalformed)
        ));
    }

    #[test]
    fn from_io_error_wraps_into_io_variant() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "x");
        assert!(matches!(ReplicaError::from(io_err), ReplicaError::Io(_)));
    }

    #[test]
    fn from_wire_error_truncated_maps_to_truncated() {
        assert!(matches!(
            ReplicaError::from(WireError::Truncated),
            ReplicaError::Truncated
        ));
    }

    #[test]
    fn from_wire_error_other_maps_to_frame() {
        assert!(matches!(
            ReplicaError::from(WireError::BadEnvelope),
            ReplicaError::Frame(_)
        ));
    }
}