use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;
use crate::msg::Msg;
use crate::msg::MsgParseResult;
use crate::msg::MsgType;
use crate::net::client::ClientHandler;
use crate::net::conn::Conn;
use crate::net::dispatcher::OutboundEnvelope;
use crate::net::NetError;
use crate::proto::dnode::{DmsgType, DnodeParser, ParseStep};
pub type DnodeClientHandler = ClientHandler;
pub async fn dnode_client_loop(
mut conn: Conn,
handler: ClientHandler,
mut rx: mpsc::Receiver<OutboundEnvelope>,
) -> Result<(), NetError> {
let mut read_buf = vec![0u8; 4096];
let mut accumulated = Vec::<u8>::new();
let mut parser = DnodeParser::new();
loop {
if conn.is_eof() && conn.imsg_q().is_empty() && conn.omsg_q().is_empty() {
conn.set_done();
return Ok(());
}
tokio::select! {
res = async {
if let Some(t) = conn.transport_mut() {
t.read(&mut read_buf).await
} else {
Ok(0)
}
} => {
let n = res?;
if n == 0 {
conn.set_eof();
continue;
}
conn.record_recv(n);
accumulated.extend_from_slice(&read_buf[..n]);
drive_dnode_parser(&mut conn, &handler, &mut accumulated, &mut parser).await?;
}
Some(env) = rx.recv() => {
let bytes: Vec<u8> = env
.rsp
.mbufs()
.iter()
.flat_map(|b| b.readable().to_vec())
.collect();
if !bytes.is_empty() {
let mut header_buf = conn.mbuf_pool().get();
crate::proto::dnode::dmsg_write(
&mut header_buf,
env.req_id,
crate::proto::dnode::DmsgType::Res,
0,
true,
None,
u32::try_from(bytes.len()).unwrap_or(u32::MAX),
)
.map_err(|e| NetError::Dnode(format!("{e:?}")))?;
let header_len = header_buf.readable().len();
if let Some(t) = conn.transport_mut() {
t.write_all(header_buf.readable()).await?;
t.write_all(&bytes).await?;
conn.record_send(header_len + bytes.len());
}
}
conn.outstanding_mut().remove(&env.req_id);
if let Some(front) = conn.omsg_q_mut().front() {
if front.id() == env.req_id {
let _ = conn.omsg_q_mut().pop_front();
}
}
}
}
}
}
async fn drive_dnode_parser(
conn: &mut Conn,
handler: &ClientHandler,
accumulated: &mut Vec<u8>,
parser: &mut DnodeParser,
) -> Result<(), NetError> {
loop {
if accumulated.is_empty() {
return Ok(());
}
let step = parser.step(accumulated.as_slice());
match step {
ParseStep::NeedMore { .. } => return Ok(()),
ParseStep::Error { consumed } => {
return Err(NetError::Dnode(format!(
"dnode header parse error after {consumed} bytes"
)));
}
ParseStep::HeaderDone { consumed } => {
let header_end = consumed;
let dmsg = parser.take_dmsg();
let plen = dmsg.plen as usize;
let total = header_end + plen;
if accumulated.len() < total {
parser.reset();
return Ok(());
}
let payload = accumulated[header_end..total].to_vec();
accumulated.drain(0..total);
parser.reset();
if is_gossip_ty(dmsg.ty) {
handle_gossip_frame(handler, dmsg.ty, &payload);
continue;
}
let decoded = if dmsg.is_encrypted() {
let Some(key) = conn.aes_key() else {
return Err(NetError::Dnode(
"dnode payload marked encrypted but no aes key bound".into(),
));
};
decrypt_dnode_payload(key, &payload)?
} else {
payload
};
let mut msg = Msg::new(dmsg.id, MsgType::Unknown, true);
let dmsg_ty = dmsg.ty;
msg.set_dmsg(dmsg);
let parse_result = match handler.data_store() {
crate::conf::DataStore::Redis | crate::conf::DataStore::Noxu => {
crate::proto::redis::redis_parse_req(&mut msg, &decoded)
}
crate::conf::DataStore::Memcache => {
crate::proto::memcache::memcache_parse_req(&mut msg, &decoded)
}
};
if matches!(dmsg_ty, DmsgType::ReqForward) {
msg.set_routing(crate::msg::MsgRouting::LocalNodeOnly);
}
match parse_result {
MsgParseResult::Ok | MsgParseResult::Noop => {
let pool = conn.mbuf_pool().clone();
let mut buf = pool.get();
buf.recv(&decoded);
msg.mbufs_mut().push_back(buf);
msg.recompute_mlen();
conn.outstanding_mut().insert(msg.id(), msg.id());
conn.enqueue_out(Msg::new(msg.id(), msg.ty(), true))?;
let outcome = handler
.dispatcher()
.dispatch(msg, handler.response_tx().clone());
match outcome {
crate::net::dispatcher::DispatchOutcome::Pending
| crate::net::dispatcher::DispatchOutcome::Drop => {}
crate::net::dispatcher::DispatchOutcome::Inline(rsp)
| crate::net::dispatcher::DispatchOutcome::Error(rsp) => {
let env = OutboundEnvelope {
req_id: rsp.id(),
rsp,
span: tracing::Span::current(),
source_peer_idx: None,
};
let _ = handler.response_tx().send(env).await;
}
}
}
MsgParseResult::Again => return Ok(()),
other => {
return Err(NetError::Parse(format!("dnode payload parse: {other:?}")));
}
}
}
}
}
}
/// Decrypts a dnode frame payload with the given AES key.
///
/// # Errors
///
/// Any failure reported by `crate::crypto::Crypto::aes_decrypt` becomes a
/// [`NetError::Dnode`] with a fixed message; the underlying error is discarded.
fn decrypt_dnode_payload(
    key: &[u8; crate::crypto::AES_KEYLEN],
    payload: &[u8],
) -> Result<Vec<u8>, NetError> {
    match crate::crypto::Crypto::aes_decrypt(payload, key) {
        Ok(plaintext) => Ok(plaintext),
        Err(_) => Err(NetError::Dnode("dnode payload decrypt failed".into())),
    }
}
/// Reports whether `kind` is one of the gossip frame types.
///
/// [`drive_dnode_parser`] hands such frames to [`handle_gossip_frame`] instead
/// of the request path.
fn is_gossip_ty(kind: DmsgType) -> bool {
    matches!(
        kind,
        DmsgType::GossipShutdown
            | DmsgType::GossipSyn
            | DmsgType::GossipSynReply
            | DmsgType::GossipAck
            | DmsgType::GossipDigestSyn
            | DmsgType::GossipDigestAck
            | DmsgType::GossipDigestAck2
    )
}
/// Applies a gossip frame received from a peer to the handler's gossip state.
///
/// The payload is read as a UTF-8 peer name with surrounding whitespace
/// trimmed. `GossipShutdown` marks that peer down; any other type records a
/// heartbeat for it at the current instant.
///
/// The frame is silently ignored in three cases: the handler has no gossip
/// state, the payload is not valid UTF-8, or the trimmed name is empty.
fn handle_gossip_frame(handler: &ClientHandler, ty: DmsgType, payload: &[u8]) {
    let Some(gossip) = handler.gossip() else {
        return;
    };
    let Some(peer_name) = std::str::from_utf8(payload)
        .ok()
        .map(str::trim)
        .filter(|name| !name.is_empty())
    else {
        return;
    };
    let observed_at = std::time::Instant::now();
    if matches!(ty, DmsgType::GossipShutdown) {
        gossip.mark_down_pname(peer_name);
    } else {
        gossip.record_heartbeat_pname(peer_name, observed_at);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::reactor::{ConnRole, TcpTransport};
    use tokio::net::{TcpListener, TcpStream};

    /// A `Conn` wrapping a live TCP transport can be constructed and dropped.
    #[tokio::test]
    async fn build_and_drop() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let accept = tokio::spawn(async move {
            let (s, _) = listener.accept().await.unwrap();
            drop(s);
        });
        let s = TcpStream::connect(addr).await.unwrap();
        let conn = Conn::new(
            Box::new(TcpTransport::new(s, ConnRole::DnodePeerClient)),
            ConnRole::DnodePeerClient,
        );
        drop(conn);
        // Join the accept task so a panic on the listener side fails this test
        // instead of being silently discarded with a detached task.
        accept.await.unwrap();
    }

    /// Every gossip frame type is routed to the gossip handler; request and
    /// response frame types are not.
    #[test]
    fn gossip_types_are_classified() {
        for ty in [
            DmsgType::GossipSyn,
            DmsgType::GossipSynReply,
            DmsgType::GossipAck,
            DmsgType::GossipDigestSyn,
            DmsgType::GossipDigestAck,
            DmsgType::GossipDigestAck2,
            DmsgType::GossipShutdown,
        ] {
            assert!(is_gossip_ty(ty));
        }
        assert!(!is_gossip_ty(DmsgType::Res));
        assert!(!is_gossip_ty(DmsgType::ReqForward));
    }
}