use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::core::types::MsgId;
use crate::io::mbuf::{MbufPool, MbufQueue};
use crate::io::reactor::{ConnRole, Transport};
use crate::msg::{ConsistencyLevel, Msg, MsgQueue};
use super::NetError;
/// Maximum number of messages a connection's inbound or outbound message
/// queue may hold; `Conn::enqueue_in` / `Conn::enqueue_out` reject a new
/// message once the respective queue has reached this length.
pub const MAX_CONN_QUEUE_SIZE: usize = 20000;

/// Running I/O counters kept per connection.
///
/// All counters start at zero (`Default`) and only ever increase.
#[derive(Clone, Debug, Default)]
pub struct ConnStats {
    /// Sum of byte counts passed to `Conn::record_recv`.
    pub recv_bytes: u64,
    /// Sum of byte counts passed to `Conn::record_send`.
    pub send_bytes: u64,
    /// Messages accepted into the inbound queue by `Conn::enqueue_in`.
    pub recv_msgs: u64,
    /// Messages accepted into the outbound queue by `Conn::enqueue_out`.
    pub send_msgs: u64,
    /// Number of `Conn::record_recv` calls.
    pub recv_events: u64,
    /// Number of `Conn::record_send` calls.
    pub send_events: u64,
}
/// Opaque numeric identifier assigned to each connection at construction.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnHandle(u64);

impl ConnHandle {
    /// Returns the underlying numeric id.
    #[must_use]
    pub fn raw(self) -> u64 {
        let ConnHandle(id) = self;
        id
    }
}

/// Next id to hand out; ids are issued starting from 1.
static NEXT_HANDLE: AtomicU64 = AtomicU64::new(1);

/// Allocates a new, previously unissued handle (until the counter wraps).
///
/// `Relaxed` suffices: `fetch_add` is an atomic read-modify-write, so each
/// caller observes a distinct value, and no other memory is published
/// through this counter.
fn next_handle() -> ConnHandle {
    let id = NEXT_HANDLE.fetch_add(1, Ordering::Relaxed);
    ConnHandle(id)
}
/// A single connection: its transport, buffered I/O chains, message queues,
/// per-connection settings and statistics.
///
/// Fields are dropped in declaration order, so `transport` is released
/// before the buffers and queues below it.
// The bool fields are independent flags, each with its own getter/setter
// (or, for `dyn_mode`, derived once from the role), so folding them into an
// enum or bitset would not model them better.
#[allow(clippy::struct_excessive_bools)]
pub struct Conn {
/// Identifier allocated by `next_handle` in `Conn::new`.
handle: ConnHandle,
/// Role fixed at construction; there is no setter.
role: ConnRole,
/// Underlying transport; `None` after `close` or `take_transport`.
transport: Option<Box<dyn Transport>>,
/// Peer address captured from the transport in `new` / `set_transport`;
/// not cleared when the transport is closed or taken.
peer_addr: Option<SocketAddr>,
/// Receive-side mbuf chain.
recv: MbufQueue,
/// Send-side mbuf chain.
send: MbufQueue,
/// Inbound message queue; `enqueue_in` caps it at `MAX_CONN_QUEUE_SIZE`.
imsg_q: MsgQueue,
/// Outbound message queue; `enqueue_out` caps it at `MAX_CONN_QUEUE_SIZE`.
omsg_q: MsgQueue,
/// NOTE(review): presumably the message currently being read — confirm.
rmsg: Option<Msg>,
/// NOTE(review): presumably the message currently being written — confirm.
smsg: Option<Msg>,
/// Byte/message/event counters.
stats: ConnStats,
/// Set by `set_eof`; never cleared in this file.
eof: bool,
/// Set by `set_done` or `close`; never cleared in this file.
done: bool,
/// Most recent error message recorded via `set_err`.
err: Option<String>,
/// Read consistency level; `DcOne` by default.
read_consistency: ConsistencyLevel,
/// Write consistency level; `DcOne` by default.
write_consistency: ConsistencyLevel,
/// `true` by default.
same_dc: bool,
/// `true` for the `DnodePeer*` roles; derived once in `new`.
dyn_mode: bool,
/// `false` by default.
dnode_secured: bool,
/// `false` by default.
crypto_key_sent: bool,
/// 32-byte key, unset until `set_aes_key`; the `Debug` impl reports only
/// whether it is set.
aes_key: Option<[u8; 32]>,
/// NOTE(review): map between message ids; the key/value meaning is not
/// visible in this file — confirm against callers.
outstanding: HashMap<MsgId, MsgId>,
/// Mbuf pool; `MbufPool::default()` unless replaced via `set_mbuf_pool`.
pool: MbufPool,
}
impl std::fmt::Debug for Conn {
    /// Formats a diagnostic summary of the connection.
    ///
    /// Queues, chains and the `outstanding` map are reported by length. The
    /// transport and the `rmsg` / `smsg` slots are reported by presence only.
    /// The AES key is never printed — only whether one is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // NOTE(review): `pool` is not printed because `MbufPool: Debug` is not
        // visible here. This read presumably keeps the dead_code lint quiet for
        // the field, as the previous discard tuple did — confirm and drop if
        // unneeded.
        let _ = &self.pool;
        f.debug_struct("Conn")
            .field("handle", &self.handle)
            .field("role", &self.role)
            .field("peer_addr", &self.peer_addr)
            .field("has_transport", &self.transport.is_some())
            .field("recv_chain", &self.recv.len())
            .field("send_chain", &self.send.len())
            .field("imsg_q", &self.imsg_q.len())
            .field("omsg_q", &self.omsg_q.len())
            .field("has_rmsg", &self.rmsg.is_some())
            .field("has_smsg", &self.smsg.is_some())
            .field("recv_bytes", &self.stats.recv_bytes)
            .field("send_bytes", &self.stats.send_bytes)
            .field("recv_msgs", &self.stats.recv_msgs)
            .field("send_msgs", &self.stats.send_msgs)
            .field("recv_events", &self.stats.recv_events)
            .field("send_events", &self.stats.send_events)
            .field("eof", &self.eof)
            .field("done", &self.done)
            .field("err", &self.err)
            .field("read_consistency", &self.read_consistency)
            .field("write_consistency", &self.write_consistency)
            .field("same_dc", &self.same_dc)
            .field("dyn_mode", &self.dyn_mode)
            .field("dnode_secured", &self.dnode_secured)
            .field("crypto_key_sent", &self.crypto_key_sent)
            .field("aes_key_set", &self.aes_key.is_some())
            .field("outstanding", &self.outstanding.len())
            .finish()
    }
}
impl Conn {
    /// Wraps `transport` in a new connection playing `role`.
    ///
    /// The peer address is read from the transport once, here. New connections
    /// start with `DcOne` read/write consistency and `same_dc == true`.
    /// `dyn_mode` is on exactly for the `DnodePeerProxy`, `DnodePeerClient`
    /// and `DnodePeerServer` roles.
    pub fn new(transport: Box<dyn Transport>, role: ConnRole) -> Self {
        let peer_addr = transport.peer_addr();
        Self {
            handle: next_handle(),
            role,
            transport: Some(transport),
            peer_addr,
            recv: MbufQueue::new(),
            send: MbufQueue::new(),
            imsg_q: MsgQueue::new(),
            omsg_q: MsgQueue::new(),
            rmsg: None,
            smsg: None,
            stats: ConnStats::default(),
            eof: false,
            done: false,
            err: None,
            read_consistency: ConsistencyLevel::DcOne,
            write_consistency: ConsistencyLevel::DcOne,
            same_dc: true,
            dyn_mode: matches!(
                role,
                ConnRole::DnodePeerProxy | ConnRole::DnodePeerClient | ConnRole::DnodePeerServer
            ),
            dnode_secured: false,
            crypto_key_sent: false,
            aes_key: None,
            outstanding: HashMap::new(),
            pool: MbufPool::default(),
        }
    }

    /// Identifier assigned at construction.
    #[must_use]
    pub fn handle(&self) -> ConnHandle {
        self.handle
    }

    /// Role this connection was created with.
    #[must_use]
    pub fn role(&self) -> ConnRole {
        self.role
    }

    /// Peer address captured from the most recently installed transport.
    ///
    /// It is retained after `close` / `take_transport`.
    #[must_use]
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.peer_addr
    }

    /// Accumulated I/O counters.
    #[must_use]
    pub fn stats(&self) -> &ConnStats {
        &self.stats
    }

    /// Receive-side mbuf chain.
    #[must_use]
    pub fn recv_chain(&self) -> &MbufQueue {
        &self.recv
    }

    /// Mutable receive-side mbuf chain.
    pub fn recv_chain_mut(&mut self) -> &mut MbufQueue {
        &mut self.recv
    }

    /// Send-side mbuf chain.
    #[must_use]
    pub fn send_chain(&self) -> &MbufQueue {
        &self.send
    }

    /// Mutable send-side mbuf chain.
    pub fn send_chain_mut(&mut self) -> &mut MbufQueue {
        &mut self.send
    }

    /// Inbound message queue.
    #[must_use]
    pub fn imsg_q(&self) -> &MsgQueue {
        &self.imsg_q
    }

    /// Mutable inbound message queue.
    ///
    /// Pushing through this reference bypasses the `MAX_CONN_QUEUE_SIZE`
    /// cap and the `recv_msgs` counter that `enqueue_in` maintains.
    pub fn imsg_q_mut(&mut self) -> &mut MsgQueue {
        &mut self.imsg_q
    }

    /// Outbound message queue.
    #[must_use]
    pub fn omsg_q(&self) -> &MsgQueue {
        &self.omsg_q
    }

    /// Mutable outbound message queue.
    ///
    /// Pushing through this reference bypasses the `MAX_CONN_QUEUE_SIZE`
    /// cap and the `send_msgs` counter that `enqueue_out` maintains.
    pub fn omsg_q_mut(&mut self) -> &mut MsgQueue {
        &mut self.omsg_q
    }

    /// The `rmsg` slot, if occupied.
    #[must_use]
    pub fn rmsg(&self) -> Option<&Msg> {
        self.rmsg.as_ref()
    }

    /// Mutable access to the `rmsg` slot, if occupied.
    pub fn rmsg_mut(&mut self) -> Option<&mut Msg> {
        self.rmsg.as_mut()
    }

    /// Moves the message out of the `rmsg` slot, leaving it empty.
    pub fn take_rmsg(&mut self) -> Option<Msg> {
        self.rmsg.take()
    }

    /// Replaces the `rmsg` slot, dropping any previous message.
    pub fn set_rmsg(&mut self, msg: Option<Msg>) {
        self.rmsg = msg;
    }

    /// The `smsg` slot, if occupied.
    #[must_use]
    pub fn smsg(&self) -> Option<&Msg> {
        self.smsg.as_ref()
    }

    /// Moves the message out of the `smsg` slot, leaving it empty.
    pub fn take_smsg(&mut self) -> Option<Msg> {
        self.smsg.take()
    }

    /// Replaces the `smsg` slot, dropping any previous message.
    pub fn set_smsg(&mut self, msg: Option<Msg>) {
        self.smsg = msg;
    }

    /// Whether `set_eof` has been called.
    #[must_use]
    pub fn is_eof(&self) -> bool {
        self.eof
    }

    /// Marks end-of-file. There is no way to clear it.
    pub fn set_eof(&mut self) {
        self.eof = true;
    }

    /// Whether `set_done` or `close` has been called.
    #[must_use]
    pub fn is_done(&self) -> bool {
        self.done
    }

    /// Marks the connection done. Unlike `close`, this keeps the transport.
    pub fn set_done(&mut self) {
        self.done = true;
    }

    /// Most recently recorded error message.
    #[must_use]
    pub fn err(&self) -> Option<&str> {
        self.err.as_deref()
    }

    /// Records an error message, overwriting any previous one.
    pub fn set_err<S: Into<String>>(&mut self, msg: S) {
        self.err = Some(msg.into());
    }

    /// Current read consistency level.
    #[must_use]
    pub fn read_consistency(&self) -> ConsistencyLevel {
        self.read_consistency
    }

    /// Current write consistency level.
    #[must_use]
    pub fn write_consistency(&self) -> ConsistencyLevel {
        self.write_consistency
    }

    /// Sets the read consistency level.
    pub fn set_read_consistency(&mut self, c: ConsistencyLevel) {
        self.read_consistency = c;
    }

    /// Sets the write consistency level.
    pub fn set_write_consistency(&mut self, c: ConsistencyLevel) {
        self.write_consistency = c;
    }

    /// The `same_dc` flag (`true` by default).
    #[must_use]
    pub fn same_dc(&self) -> bool {
        self.same_dc
    }

    /// Sets the `same_dc` flag.
    pub fn set_same_dc(&mut self, on: bool) {
        self.same_dc = on;
    }

    /// Whether the role is one of the `DnodePeer*` roles (fixed at construction).
    #[must_use]
    pub fn dyn_mode(&self) -> bool {
        self.dyn_mode
    }

    /// The `dnode_secured` flag (`false` by default).
    #[must_use]
    pub fn dnode_secured(&self) -> bool {
        self.dnode_secured
    }

    /// Sets the `dnode_secured` flag.
    pub fn set_dnode_secured(&mut self, on: bool) {
        self.dnode_secured = on;
    }

    /// The `crypto_key_sent` flag (`false` by default).
    #[must_use]
    pub fn crypto_key_sent(&self) -> bool {
        self.crypto_key_sent
    }

    /// Sets the `crypto_key_sent` flag.
    pub fn set_crypto_key_sent(&mut self, on: bool) {
        self.crypto_key_sent = on;
    }

    /// The 32-byte AES key, if one has been set.
    #[must_use]
    pub fn aes_key(&self) -> Option<&[u8; 32]> {
        self.aes_key.as_ref()
    }

    /// Installs the AES key, replacing any previous one.
    pub fn set_aes_key(&mut self, key: [u8; 32]) {
        self.aes_key = Some(key);
    }

    /// Id-to-id map of outstanding messages (meaning defined by callers).
    #[must_use]
    pub fn outstanding(&self) -> &HashMap<MsgId, MsgId> {
        &self.outstanding
    }

    /// Mutable id-to-id map of outstanding messages.
    pub fn outstanding_mut(&mut self) -> &mut HashMap<MsgId, MsgId> {
        &mut self.outstanding
    }

    /// The connection's mbuf pool.
    #[must_use]
    pub fn mbuf_pool(&self) -> &MbufPool {
        &self.pool
    }

    /// Replaces the connection's mbuf pool.
    pub fn set_mbuf_pool(&mut self, pool: MbufPool) {
        self.pool = pool;
    }

    /// Detaches and returns the transport.
    ///
    /// Unlike `close`, this does not mark the connection done. `peer_addr`
    /// is retained.
    pub fn take_transport(&mut self) -> Option<Box<dyn Transport>> {
        self.transport.take()
    }

    /// Installs `transport`, dropping any previous one, and refreshes
    /// `peer_addr` from it. The `done` flag is left unchanged.
    pub fn set_transport(&mut self, transport: Box<dyn Transport>) {
        self.peer_addr = transport.peer_addr();
        self.transport = Some(transport);
    }

    /// Whether a transport is currently attached.
    #[must_use]
    pub fn has_transport(&self) -> bool {
        self.transport.is_some()
    }

    /// Mutable access to the attached transport, if any.
    pub fn transport_mut(&mut self) -> Option<&mut Box<dyn Transport>> {
        self.transport.as_mut()
    }

    /// Appends `msg` to the inbound queue and counts it in `recv_msgs`.
    ///
    /// # Errors
    ///
    /// Returns `NetError::PoolExhausted` if the inbound queue already holds
    /// `MAX_CONN_QUEUE_SIZE` messages. In that case `msg` is dropped and the
    /// counters are not changed.
    pub fn enqueue_in(&mut self, msg: Msg) -> Result<(), NetError> {
        if self.imsg_q.len() >= MAX_CONN_QUEUE_SIZE {
            return Err(NetError::PoolExhausted);
        }
        self.imsg_q.push_back(msg);
        self.stats.recv_msgs = self.stats.recv_msgs.saturating_add(1);
        Ok(())
    }

    /// Appends `msg` to the outbound queue and counts it in `send_msgs`.
    ///
    /// # Errors
    ///
    /// Returns `NetError::PoolExhausted` if the outbound queue already holds
    /// `MAX_CONN_QUEUE_SIZE` messages. In that case `msg` is dropped and the
    /// counters are not changed.
    pub fn enqueue_out(&mut self, msg: Msg) -> Result<(), NetError> {
        if self.omsg_q.len() >= MAX_CONN_QUEUE_SIZE {
            return Err(NetError::PoolExhausted);
        }
        self.omsg_q.push_back(msg);
        self.stats.send_msgs = self.stats.send_msgs.saturating_add(1);
        Ok(())
    }

    /// Drops the transport and marks the connection done. Safe to call repeatedly.
    ///
    /// Queued messages, buffers and `peer_addr` are kept.
    pub fn close(&mut self) {
        self.transport = None;
        self.done = true;
    }

    /// Checks that the connection can still be driven.
    ///
    /// This currently performs no I/O. It succeeds whenever a transport is
    /// attached, including when the connection is already marked done.
    ///
    /// # Errors
    ///
    /// Returns `NetError::Closed` if no transport is attached.
    pub fn run(&mut self) -> Result<(), NetError> {
        if self.transport.is_none() {
            return Err(NetError::Closed);
        }
        Ok(())
    }

    /// Records one receive event of `bytes` bytes.
    pub fn record_recv(&mut self, bytes: usize) {
        // Saturate rather than panic (debug) or wrap (release) on overflow.
        let n = u64::try_from(bytes).unwrap_or(u64::MAX);
        self.stats.recv_bytes = self.stats.recv_bytes.saturating_add(n);
        self.stats.recv_events = self.stats.recv_events.saturating_add(1);
    }

    /// Records one send event of `bytes` bytes.
    pub fn record_send(&mut self, bytes: usize) {
        // Saturate rather than panic (debug) or wrap (release) on overflow.
        let n = u64::try_from(bytes).unwrap_or(u64::MAX);
        self.stats.send_bytes = self.stats.send_bytes.saturating_add(n);
        self.stats.send_events = self.stats.send_events.saturating_add(1);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::reactor::TcpTransport;
    use crate::msg::MsgType;
    use tokio::net::{TcpListener, TcpStream};

    /// Builds a `Conn` over `stream`, using `role` for both the transport
    /// and the connection.
    fn wrap(stream: TcpStream, role: ConnRole) -> Conn {
        let transport = TcpTransport::new(stream, role);
        Conn::new(Box::new(transport), role)
    }

    /// Opens a loopback TCP connection and returns `(client, server)` conns,
    /// constructing the client side first.
    async fn pair() -> (Conn, Conn) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let acceptor = tokio::spawn(async move { listener.accept().await.unwrap().0 });
        let client_stream = TcpStream::connect(addr).await.unwrap();
        let server_stream = acceptor.await.unwrap();
        (
            wrap(client_stream, ConnRole::Client),
            wrap(server_stream, ConnRole::Server),
        )
    }

    #[tokio::test]
    async fn enqueue_in_and_out() {
        let (mut conn, _peer) = pair().await;
        let request = Msg::new(1, MsgType::ReqRedisGet, true);
        conn.enqueue_in(request).unwrap();
        let response = Msg::new(2, MsgType::RspRedisStatus, false);
        conn.enqueue_out(response).unwrap();
        assert_eq!(conn.imsg_q().len(), 1);
        assert_eq!(conn.omsg_q().len(), 1);
        let stats = conn.stats();
        assert_eq!(stats.recv_msgs, 1);
        assert_eq!(stats.send_msgs, 1);
    }

    #[tokio::test]
    async fn close_drops_transport() {
        let (mut conn, _peer) = pair().await;
        assert!(conn.has_transport());
        conn.close();
        assert!(!conn.has_transport());
        assert!(conn.is_done());
    }

    #[tokio::test]
    async fn handle_is_unique() {
        let (first, second) = pair().await;
        assert_ne!(first.handle(), second.handle());
    }

    #[tokio::test]
    async fn role_seed_drives_dyn_mode() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // Only the connecting side is wrapped; the accepted stream is dropped at once.
        let _acceptor = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            drop(accepted);
        });
        let stream = TcpStream::connect(addr).await.unwrap();
        let conn = wrap(stream, ConnRole::DnodePeerServer);
        assert!(conn.dyn_mode());
        assert!(conn.same_dc());
    }
}