use kevy_replicate::handshake::{HandshakeError, encode_ack, parse_replicate_from};
use kevy_resp::{Argv, parse_command_into};
use kevy_sys::Socket;
/// One replica connection: its socket, buffered I/O, and its position in the
/// replication state machine ([`ReplicaState`]).
pub struct ReplicaConn {
    /// Socket connected to the replica.
    pub sock: Socket,
    /// Raw descriptor of `sock`, cached at construction time from `Socket::raw`.
    pub fd: i32,
    /// Bytes received from the replica that have not been consumed yet; the
    /// handshake command is parsed from (and drained out of) this buffer.
    pub input: Vec<u8>,
    /// Bytes queued for sending to the replica (e.g. the handshake ACK).
    pub output: Vec<u8>,
    /// NOTE(review): presumably how many bytes of `output` have already been
    /// written to the socket; the handshake resets it to 0 — confirm with the writer.
    pub write_off: usize,
    /// Current state of this connection's replication lifecycle.
    pub state: ReplicaState,
    /// Replica address as `(IPv4, port)`; `0.0.0.0:0` when built via `ReplicaConn::new`.
    pub peer: (std::net::Ipv4Addr, u16),
}
/// Lifecycle of a [`ReplicaConn`].
///
/// A connection starts in `HandshakePending`. `advance_handshake` moves it to
/// `AckSent`, and `ReplicaConn::close` moves any non-closed state to `Closed`.
#[derive(Debug)]
pub enum ReplicaState {
    /// Waiting for the replica's `REPLICATE FROM <offset> ID <id>` command.
    HandshakePending,
    /// The handshake command was parsed and an ACK for `from_offset` was
    /// queued in the connection's `output` buffer.
    AckSent {
        replica_id: String,
        from_offset: u64,
    },
    /// NOTE(review): presumably incremental streaming to the replica, with
    /// `sent_offset` the offset sent so far — confirm against the streaming code.
    Streaming {
        replica_id: String,
        sent_offset: u64,
    },
    /// NOTE(review): presumably a full snapshot transfer. `serializing` looks
    /// like a channel delivering serialized snapshot bytes from another thread,
    /// with `snapshot_buf`/`snapshot_off` holding those bytes and the write
    /// progress through them — confirm.
    SnapshotShipping {
        replica_id: String,
        ack_offset: u64,
        serializing: Option<std::sync::mpsc::Receiver<Vec<u8>>>,
        snapshot_buf: Vec<u8>,
        snapshot_off: usize,
    },
    /// Terminal state set by `ReplicaConn::close`. It keeps the replica id
    /// (`None` if the handshake never completed) and the offset carried by the
    /// previous state (`from_offset`, `sent_offset` or `ack_offset`; 0 from
    /// `HandshakePending`).
    Closed {
        replica_id: Option<String>,
        sent_offset: u64,
    },
}
impl ReplicaConn {
    /// Wraps `sock` with an unknown peer address (`0.0.0.0:0`).
    // `new` has no caller in this file; the lint is silenced so the
    // convenience constructor can stay available.
    #[allow(dead_code)]
    pub fn new(sock: Socket) -> Self {
        Self::with_peer(sock, (std::net::Ipv4Addr::UNSPECIFIED, 0))
    }

    /// Wraps `sock` for the replica at `peer`, starting in
    /// [`ReplicaState::HandshakePending`] with empty buffers.
    pub fn with_peer(sock: Socket, peer: (std::net::Ipv4Addr, u16)) -> Self {
        let fd = sock.raw();
        Self {
            sock,
            fd,
            input: Vec::with_capacity(256),
            output: Vec::with_capacity(64),
            write_off: 0,
            state: ReplicaState::HandshakePending,
            peer,
        }
    }

    /// Moves the connection into [`ReplicaState::Closed`].
    ///
    /// The replica id (if the handshake completed) and the offset from the
    /// previous state are kept. Any other per-state data, such as snapshot
    /// buffers and the receiver, is dropped. Calling this on an already-closed
    /// connection leaves the state unchanged. `sock` is not touched here.
    pub fn close(&mut self) {
        // Take ownership of the old state so `replica_id` can be moved into
        // `Closed` instead of cloned.
        let previous = std::mem::replace(
            &mut self.state,
            ReplicaState::Closed {
                replica_id: None,
                sent_offset: 0,
            },
        );
        self.state = match previous {
            ReplicaState::HandshakePending => ReplicaState::Closed {
                replica_id: None,
                sent_offset: 0,
            },
            ReplicaState::AckSent {
                replica_id,
                from_offset,
            } => ReplicaState::Closed {
                replica_id: Some(replica_id),
                sent_offset: from_offset,
            },
            ReplicaState::Streaming {
                replica_id,
                sent_offset,
            } => ReplicaState::Closed {
                replica_id: Some(replica_id),
                sent_offset,
            },
            ReplicaState::SnapshotShipping {
                replica_id,
                ack_offset,
                ..
            } => ReplicaState::Closed {
                replica_id: Some(replica_id),
                sent_offset: ack_offset,
            },
            // Idempotent: an already-closed state is put back untouched.
            closed @ ReplicaState::Closed { .. } => closed,
        };
    }
}
/// Advances the replica handshake by at most one step.
///
/// Does nothing unless `conn` is in [`ReplicaState::HandshakePending`]. When
/// `conn.input` holds a complete command frame that is a valid replicate
/// request, the frame's bytes are drained from `conn.input`. An ACK for the
/// requested offset is then appended to `conn.output`, `conn.write_off` is
/// reset to 0, and the state becomes [`ReplicaState::AckSent`]. An incomplete
/// frame leaves everything untouched so the caller can retry after reading
/// more bytes.
///
/// # Errors
///
/// Returns [`HandshakeError::BadCommand`] when the input cannot be parsed as a
/// command. It propagates the error from `parse_replicate_from` when the command
/// is not an acceptable replicate request. On error, `conn.input` is not
/// consumed and the state stays `HandshakePending`.
pub(crate) fn advance_handshake(conn: &mut ReplicaConn) -> Result<(), HandshakeError> {
    match conn.state {
        ReplicaState::HandshakePending => {}
        // Handshake already done or connection closed: nothing to advance.
        _ => return Ok(()),
    }

    let mut argv = Argv::default();
    let frame_len = match parse_command_into(&conn.input, &mut argv) {
        Err(_) => return Err(HandshakeError::BadCommand),
        // Only part of a frame has arrived; wait for the next read.
        Ok(None) => return Ok(()),
        Ok(Some(len)) => len,
    };

    // Validate before consuming, so a rejected command leaves the input intact.
    let request = parse_replicate_from(&argv)?;

    conn.input.drain(..frame_len);
    conn.output
        .extend_from_slice(&encode_ack(request.from_offset));
    conn.write_off = 0;
    conn.state = ReplicaState::AckSent {
        replica_id: request.replica_id,
        from_offset: request.from_offset,
    };
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ReplicaConn` around an invalid descriptor so the state
    /// machine can be exercised without real I/O.
    fn fake_conn() -> ReplicaConn {
        // SAFETY: NOTE(review) — -1 is never a valid descriptor, so no live fd
        // is aliased or double-closed. Presumably dropping the `Socket` only
        // issues a `close(-1)` that fails harmlessly with EBADF; confirm this
        // against the contract of `kevy_sys::Socket::from_raw_fd`.
        let sock = unsafe { Socket::from_raw_fd(-1) };
        ReplicaConn {
            sock,
            fd: -1,
            input: Vec::new(),
            output: Vec::new(),
            write_off: 0,
            state: ReplicaState::HandshakePending,
            peer: (std::net::Ipv4Addr::UNSPECIFIED, 0),
        }
    }

    /// Encodes `REPLICATE FROM <offset> ID <id>` as a RESP array of bulk strings.
    fn resp_replicate_from(offset: &str, id: &str) -> Vec<u8> {
        let mut v = Vec::new();
        v.extend_from_slice(b"*5\r\n");
        for arg in [b"REPLICATE".as_slice(), b"FROM", offset.as_bytes(), b"ID", id.as_bytes()] {
            v.extend_from_slice(format!("${}\r\n", arg.len()).as_bytes());
            v.extend_from_slice(arg);
            v.extend_from_slice(b"\r\n");
        }
        v
    }

    #[test]
    fn close_from_handshake_pending_carries_no_id() {
        let mut conn = fake_conn();
        conn.close();
        match conn.state {
            ReplicaState::Closed { replica_id, sent_offset } => {
                assert_eq!(replica_id, None);
                assert_eq!(sent_offset, 0);
            }
            other => panic!("expected Closed, got {other:?}"),
        }
    }

    #[test]
    fn close_from_ack_sent_preserves_id_and_offset() {
        let mut conn = fake_conn();
        conn.input = resp_replicate_from("17", "replica-x");
        advance_handshake(&mut conn).expect("handshake ok");
        conn.close();
        match conn.state {
            ReplicaState::Closed { replica_id, sent_offset } => {
                assert_eq!(replica_id.as_deref(), Some("replica-x"));
                assert_eq!(sent_offset, 17);
            }
            other => panic!("expected Closed, got {other:?}"),
        }
    }

    #[test]
    fn close_from_streaming_preserves_id_and_offset() {
        let mut conn = fake_conn();
        conn.state = ReplicaState::Streaming {
            replica_id: "replica-z".into(),
            sent_offset: 99,
        };
        conn.close();
        match conn.state {
            ReplicaState::Closed { replica_id, sent_offset } => {
                assert_eq!(replica_id.as_deref(), Some("replica-z"));
                assert_eq!(sent_offset, 99);
            }
            other => panic!("expected Closed, got {other:?}"),
        }
    }

    #[test]
    fn close_from_snapshot_shipping_preserves_id_and_ack_offset() {
        let mut conn = fake_conn();
        conn.state = ReplicaState::SnapshotShipping {
            replica_id: "replica-s".into(),
            ack_offset: 33,
            serializing: None,
            snapshot_buf: vec![1, 2, 3],
            snapshot_off: 1,
        };
        conn.close();
        match conn.state {
            ReplicaState::Closed { replica_id, sent_offset } => {
                assert_eq!(replica_id.as_deref(), Some("replica-s"));
                assert_eq!(sent_offset, 33);
            }
            other => panic!("expected Closed, got {other:?}"),
        }
    }

    #[test]
    fn close_is_idempotent() {
        let mut conn = fake_conn();
        conn.state = ReplicaState::Streaming {
            replica_id: "r".into(),
            sent_offset: 5,
        };
        conn.close();
        let snapshot = format!("{:?}", conn.state);
        conn.close();
        assert_eq!(format!("{:?}", conn.state), snapshot);
    }

    #[test]
    fn handshake_pending_to_ack_sent_on_complete_command() {
        let mut conn = fake_conn();
        conn.input = resp_replicate_from("42", "replica-a");
        advance_handshake(&mut conn).expect("ok");
        match &conn.state {
            ReplicaState::AckSent { replica_id, from_offset } => {
                assert_eq!(replica_id, "replica-a");
                assert_eq!(*from_offset, 42);
            }
            other => panic!("expected AckSent, got {other:?}"),
        }
        assert_eq!(conn.output, b"+ACK 42\r\n");
        assert!(conn.input.is_empty());
    }

    #[test]
    fn handshake_keeps_bytes_pipelined_after_the_command() {
        let mut conn = fake_conn();
        conn.input = resp_replicate_from("5", "replica-p");
        conn.input.extend_from_slice(b"*1\r\n$4\r\nPI");
        advance_handshake(&mut conn).expect("ok");
        assert!(matches!(conn.state, ReplicaState::AckSent { .. }));
        // Only the handshake frame is consumed; trailing bytes stay buffered.
        assert_eq!(conn.input, b"*1\r\n$4\r\nPI");
    }

    #[test]
    fn partial_handshake_stays_pending_and_waits_for_more_bytes() {
        let mut conn = fake_conn();
        let full = resp_replicate_from("0", "replica-a");
        conn.input = full[..full.len() / 2].to_vec();
        advance_handshake(&mut conn).expect("ok");
        assert!(matches!(conn.state, ReplicaState::HandshakePending));
        assert!(conn.output.is_empty());
        conn.input.extend_from_slice(&full[full.len() / 2..]);
        advance_handshake(&mut conn).expect("ok");
        assert!(matches!(conn.state, ReplicaState::AckSent { .. }));
    }

    #[test]
    fn wrong_command_is_rejected_at_handshake() {
        let mut conn = fake_conn();
        conn.input = b"*1\r\n$4\r\nPING\r\n".to_vec();
        let err = advance_handshake(&mut conn).unwrap_err();
        assert!(matches!(err, HandshakeError::WrongArity(_) | HandshakeError::BadCommand));
        assert!(matches!(conn.state, ReplicaState::HandshakePending));
    }

    #[test]
    fn inline_form_parses_then_handshake_rejects_arity() {
        let mut conn = fake_conn();
        conn.input = b"!garbage\r\n".to_vec();
        let err = advance_handshake(&mut conn).unwrap_err();
        assert_eq!(err, HandshakeError::WrongArity(1));
    }

    #[test]
    fn resp_level_malformed_input_returns_bad_command() {
        let mut conn = fake_conn();
        conn.input = b"*1\r\n!nope\r\n".to_vec();
        let err = advance_handshake(&mut conn).unwrap_err();
        assert_eq!(err, HandshakeError::BadCommand);
    }

    #[test]
    fn second_call_after_ack_is_noop() {
        let mut conn = fake_conn();
        conn.input = resp_replicate_from("7", "r");
        advance_handshake(&mut conn).unwrap();
        let out_before = conn.output.clone();
        advance_handshake(&mut conn).unwrap();
        assert_eq!(conn.output, out_before);
    }
}