use std::io::{Read, Write};
use std::net::{Shutdown, TcpListener, TcpStream, ToSocketAddrs};
use std::sync::mpsc::{Receiver, Sender, channel, RecvTimeoutError};
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use crate::elector::{Elector, Outbound};
use crate::message::Message;
use crate::wire::{DecodeError, decode, encode};
/// Initial capacity and upper bound, in bytes, of the per-connection buffer
/// holding inbound data that has not been decoded yet.
const READ_BUF_CAP: usize = 16_384;
/// Pause before retrying after a failed outbound dial or a failed `accept`.
const READ_RETRY_BACKOFF: Duration = Duration::from_millis(100);
/// Events forwarded from the inbound connection threads to the orchestrator.
pub enum InboundEvent {
/// A decoded message, paired with the sender id taken from the message
/// payload itself (not from the connection it arrived on).
Message(String, Message),
/// An inbound connection ended (EOF, read error, oversized buffered data, or
/// an undecodable frame); carries the remote socket address as a string.
InboundConnFailed(String),
}
/// State shared between the `Transport` handle and its background threads.
struct Shared {
/// The election state machine. It is locked by the orchestrator to feed it events and
/// by the `Transport` accessors.
elector: Mutex<Elector>,
/// Pending outbound messages keyed by peer node id. The orchestrator fills these
/// queues and each peer's outbound thread drains its own.
out_queues: Mutex<std::collections::HashMap<String, std::collections::VecDeque<Message>>>,
}
/// Maximum number of messages buffered per peer in `out_queues`.
const MAX_PENDING_PER_PEER: usize = 256;
/// Network location of a remote election peer.
#[derive(Debug, Clone)]
pub struct PeerAddr {
/// Node id of the peer; also the key of its outbound queue.
pub node_id: String,
/// Host name or IP address; resolved again on every dial attempt.
pub host: String,
/// TCP port to dial on `host`.
pub port: u16,
}
/// Handle to the running election transport threads.
///
/// Dropping the handle only raises the stop flag; `shutdown` also joins the
/// tracked threads.
pub struct Transport {
/// Cooperative stop flag polled by every background thread.
stop: Arc<AtomicBool>,
/// Join handles for the listener, outbound and orchestrator threads
/// (per-connection inbound reader threads are detached and not tracked here).
handles: Vec<JoinHandle<()>>,
/// Shared elector and outbound queues.
shared: Arc<Shared>,
/// NOTE(review): `spawn` initialises this as a clone of the same `Arc` as
/// `shared`, so the two fields always point at the same state; redundant.
state_view: Arc<Shared>,
}
impl Transport {
pub fn spawn(
elector: Elector,
hb_interval: Duration,
listen_addr: (std::net::IpAddr, u16),
peers: Vec<PeerAddr>,
) -> std::io::Result<Self> {
let shared = Arc::new(Shared {
elector: Mutex::new(elector),
out_queues: Mutex::new(std::collections::HashMap::new()),
});
let stop = Arc::new(AtomicBool::new(false));
let mut handles = Vec::new();
let (inbound_tx, inbound_rx) = channel::<InboundEvent>();
let listener = TcpListener::bind(listen_addr)?;
listener.set_nonblocking(false)?;
let listener_stop = stop.clone();
let listener_tx = inbound_tx.clone();
handles.push(
std::thread::Builder::new()
.name("kevy-elect-listener".to_string())
.spawn(move || {
accept_loop(listener, listener_tx, listener_stop);
})?,
);
for peer in &peers {
let peer_stop = stop.clone();
let peer_shared = shared.clone();
let peer_clone = peer.clone();
handles.push(
std::thread::Builder::new()
.name(format!("kevy-elect-out-{}", peer.node_id))
.spawn(move || {
outbound_loop(peer_clone, peer_shared, peer_stop);
})?,
);
}
let orch_stop = stop.clone();
let orch_shared = shared.clone();
handles.push(
std::thread::Builder::new()
.name("kevy-elect-orchestrator".to_string())
.spawn(move || {
orchestrator_loop(orch_shared, inbound_rx, hb_interval, orch_stop);
})?,
);
Ok(Self {
stop,
handles,
state_view: shared.clone(),
shared,
})
}
pub fn state_snapshot(&self) -> ElectorSnapshot {
let e = self.state_view.elector.lock().expect("elector lock");
let now = std::time::Instant::now();
let down_peers: Vec<String> = e
.peer_ids
.iter()
.filter(|id| id.as_str() != e.node_id.as_str())
.filter(|id| e.is_peer_down(id, now))
.cloned()
.collect();
ElectorSnapshot {
role: e.role(),
epoch: e.epoch(),
current_primary: e.current_primary().map(str::to_string),
down_peers,
}
}
pub fn set_repl_offset(&self, offset: u64) {
self.shared
.elector
.lock()
.expect("elector lock")
.set_repl_offset(offset);
}
pub fn shutdown(mut self) {
self.stop.store(true, Ordering::Relaxed);
for h in self.handles.drain(..) {
let _ = h.join();
}
}
}
impl Drop for Transport {
fn drop(&mut self) {
self.stop.store(true, Ordering::Relaxed);
}
}
/// Point-in-time copy of elector state, produced by `Transport::state_snapshot`.
#[derive(Debug, Clone)]
pub struct ElectorSnapshot {
/// This node's role as reported by the elector.
pub role: crate::message::Role,
/// Epoch as reported by the elector.
pub epoch: u64,
/// Id of the node the elector currently names as primary, if any.
pub current_primary: Option<String>,
/// Ids of peers other than this node that the elector reported as down
/// when the snapshot was taken.
pub down_peers: Vec<String>,
}
/// Accepts inbound peer connections until `stop` is raised.
///
/// Each accepted connection gets a detached `inbound_read_loop` thread that
/// reports into `tx`. The listener is polled in non-blocking mode so the stop
/// flag is noticed within roughly one poll interval.
///
/// # Panics
///
/// Panics if the listener cannot be switched to non-blocking mode.
fn accept_loop(listener: TcpListener, tx: Sender<InboundEvent>, stop: Arc<AtomicBool>) {
    /// Sleep between polls while no connection is pending.
    const IDLE_POLL: Duration = Duration::from_millis(20);

    listener
        .set_nonblocking(true)
        .expect("listener set_nonblocking(true)");
    while !stop.load(Ordering::Relaxed) {
        match listener.accept() {
            Ok((stream, addr)) => {
                // Whether an accepted socket inherits the listener's
                // non-blocking mode is platform-dependent, so force blocking
                // mode. The reader relies on a read timeout, which has no
                // effect on a non-blocking socket. Its WouldBlock branch would
                // then spin without sleeping. If the switch fails, drop
                // (close) the connection instead of risking that.
                if stream.set_nonblocking(false).is_err() {
                    continue;
                }
                let conn_tx = tx.clone();
                let conn_stop = stop.clone();
                let addr_str = addr.to_string();
                // If the thread cannot be spawned, the closure (and with it the
                // stream) is dropped, closing the connection.
                let _ = std::thread::Builder::new()
                    .name(format!("kevy-elect-in-{addr_str}"))
                    .spawn(move || {
                        inbound_read_loop(stream, addr_str, conn_tx, conn_stop);
                    });
            }
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                std::thread::sleep(IDLE_POLL);
            }
            Err(_) => {
                // Transient accept failure: back off before polling again.
                std::thread::sleep(READ_RETRY_BACKOFF);
            }
        }
    }
}
/// Reads length-framed messages from one inbound connection and forwards
/// them to the orchestrator via `tx`.
///
/// Each decoded message is sent as `InboundEvent::Message`, tagged with the
/// sender id carried in the message. The loop ends and sends
/// `InboundEvent::InboundConnFailed(peer_addr)` on EOF or a read error. It does
/// the same on an undecodable frame, a decode that makes no progress, or when
/// the incomplete trailing frame grows beyond `READ_BUF_CAP` bytes. It returns
/// silently once `stop` is raised.
fn inbound_read_loop(
    mut stream: TcpStream,
    peer_addr: String,
    tx: Sender<InboundEvent>,
    stop: Arc<AtomicBool>,
) {
    let _ = stream.set_nodelay(true);
    // Bounded blocking reads so the stop flag is re-checked regularly.
    let _ = stream.set_read_timeout(Some(Duration::from_millis(200)));
    let report_failure = || {
        let _ = tx.send(InboundEvent::InboundConnFailed(peer_addr.clone()));
    };
    let mut buf: Vec<u8> = Vec::with_capacity(READ_BUF_CAP);
    let mut chunk = [0u8; 1024];
    while !stop.load(Ordering::Relaxed) {
        let n = match stream.read(&mut chunk) {
            Ok(0) => {
                report_failure();
                return;
            }
            Ok(n) => n,
            Err(e)
                if matches!(
                    e.kind(),
                    std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                ) =>
            {
                continue;
            }
            Err(_) => {
                report_failure();
                return;
            }
        };
        buf.extend_from_slice(&chunk[..n]);
        while !buf.is_empty() {
            match decode(&buf) {
                // A decode that consumes nothing would never make progress
                // and would spin here forever; treat it as a protocol error.
                Ok((_, 0)) => {
                    report_failure();
                    return;
                }
                Ok((msg, used)) => {
                    let from = message_sender(&msg);
                    let _ = tx.send(InboundEvent::Message(from, msg));
                    buf.drain(..used);
                }
                Err(DecodeError::Truncated) => break,
                Err(_) => {
                    report_failure();
                    return;
                }
            }
        }
        // Only an incomplete trailing frame is left at this point. Capping it
        // bounds memory per connection. Checking after decoding, rather than
        // before, avoids rejecting a chunk that completes a large frame just
        // because leftover + chunk briefly exceeded the cap.
        if buf.len() > READ_BUF_CAP {
            report_failure();
            return;
        }
    }
}
/// Returns the node id a message is attributed to: the heartbeat's
/// `node_id`, the offer's `candidate_id`, the accept's `accepter_id`, or the
/// announce's `new_primary_id`.
///
/// The id comes from the message payload, not from the connection the
/// message arrived on.
fn message_sender(msg: &Message) -> String {
    let id = match msg {
        Message::Hb { node_id: id, .. }
        | Message::Offer { candidate_id: id, .. }
        | Message::Accept { accepter_id: id, .. }
        | Message::Announce { new_primary_id: id, .. } => id,
    };
    id.clone()
}
/// Delivers queued messages to a single peer until `stop` is raised.
///
/// (Re)dials the peer whenever there is no live connection, backing off
/// `READ_RETRY_BACKOFF` after a failed dial. While connected it pops the
/// peer's queue and polls every millisecond when the queue is empty. On a
/// write error it closes the connection and pushes the message back to the
/// head of the queue, so it is retried on the next connection.
fn outbound_loop(peer: PeerAddr, shared: Arc<Shared>, stop: Arc<AtomicBool>) {
    let mut conn: Option<TcpStream> = None;
    while !stop.load(Ordering::Relaxed) {
        if conn.is_none() {
            conn = dial(&peer);
        }
        let Some(sock) = conn.as_mut() else {
            // Still disconnected: back off before the next dial attempt.
            std::thread::sleep(READ_RETRY_BACKOFF);
            continue;
        };

        // The queue lock is a temporary released at the end of this statement,
        // before the (possibly slow) socket write below.
        let pending = shared
            .out_queues
            .lock()
            .expect("out_queues lock")
            .get_mut(&peer.node_id)
            .and_then(|q| q.pop_front());
        let Some(msg) = pending else {
            std::thread::sleep(Duration::from_millis(1));
            continue;
        };

        let frame = encode(&msg);
        if sock.write_all(&frame).is_ok() {
            continue;
        }

        // Write failed: tear the connection down and requeue at the front.
        let _ = sock.shutdown(Shutdown::Both);
        conn = None;
        shared
            .out_queues
            .lock()
            .expect("out_queues lock")
            .entry(peer.node_id.clone())
            .or_default()
            .push_front(msg);
    }
}
/// Opens a connection to `peer`.
///
/// Tries each address that `host:port` resolves to, in order, and returns the
/// first that connects within 500 ms. Returns `None` if resolution fails or
/// no address accepts the connection.
///
/// The returned stream has Nagle's algorithm disabled and a write timeout set.
/// Without the timeout, a peer that stops reading could block `write_all` in
/// the outbound thread forever, which would also hang `Transport::shutdown`.
/// A timed-out write surfaces as an error, and the outbound loop then drops
/// the connection and redials.
fn dial(peer: &PeerAddr) -> Option<TcpStream> {
    const CONNECT_TIMEOUT: Duration = Duration::from_millis(500);
    const WRITE_TIMEOUT: Duration = Duration::from_secs(1);

    let stream = (peer.host.as_str(), peer.port)
        .to_socket_addrs()
        .ok()?
        .find_map(|sa| TcpStream::connect_timeout(&sa, CONNECT_TIMEOUT).ok())?;
    let _ = stream.set_nodelay(true);
    let _ = stream.set_write_timeout(Some(WRITE_TIMEOUT));
    Some(stream)
}
fn orchestrator_loop(
shared: Arc<Shared>,
inbound_rx: Receiver<InboundEvent>,
hb_interval: Duration,
stop: Arc<AtomicBool>,
) {
while !stop.load(Ordering::Relaxed) {
let mut outs: Vec<Outbound> = Vec::new();
match inbound_rx.recv_timeout(hb_interval) {
Ok(InboundEvent::Message(from, msg)) => {
let now = Instant::now();
let mut e = shared.elector.lock().expect("elector lock");
outs.extend(e.on_message(&from, msg, now));
outs.extend(e.tick(now));
}
Ok(InboundEvent::InboundConnFailed(_)) => {
}
Err(RecvTimeoutError::Timeout) => {
let now = Instant::now();
let mut e = shared.elector.lock().expect("elector lock");
outs.extend(e.tick(now));
}
Err(RecvTimeoutError::Disconnected) => return,
}
if !outs.is_empty() {
let mut qs = shared.out_queues.lock().expect("out_queues lock");
for out in outs {
let targets: Vec<String> = if out.to == Outbound::BROADCAST {
qs.keys().cloned().collect()
} else {
vec![out.to]
};
for target in targets {
let q = qs.entry(target).or_default();
if q.len() < MAX_PENDING_PER_PEER {
q.push_back(out.msg.clone());
}
}
}
}
}
}