dynomite/net/
dnode_server.rs1use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::sync::mpsc;
12use tracing::Instrument as _;
13
14use crate::core::types::MsgId;
15use crate::io::reactor::ConnRole;
16use crate::msg::{Msg, MsgParseResult, MsgType};
17use crate::net::conn::Conn;
18use crate::net::dispatcher::OutboundEnvelope;
19use crate::net::server::OutboundRequest;
20use crate::net::NetError;
21use crate::proto::dnode::{dmsg_write, DmsgType, DnodeParser, ParseStep};
22
23fn is_data_plane_ty(ty: DmsgType) -> bool {
28 matches!(
29 ty,
30 DmsgType::Req | DmsgType::ReqForward | DmsgType::Res | DmsgType::Unknown
31 )
32}
33
/// Outbound side of a dnode peer connection.
///
/// Requests received on `requests` are written to the peer as dnode frames
/// (header followed by payload). Response frames read back from the peer are
/// parsed, wrapped in `OutboundEnvelope`s and forwarded to a responder channel.
pub struct DnodeServerConn {
    /// The underlying peer connection; `new` debug-asserts that its role is
    /// `ConnRole::DnodePeerServer`.
    conn: Conn,
    /// Outbound requests to be written to the peer.
    requests: mpsc::Receiver<OutboundRequest>,
    /// Data-plane requests that were written but not yet answered, stored as
    /// `(request id, request span, target peer index)`. Incoming responses are
    /// matched against this queue strictly in FIFO order.
    pending: std::collections::VecDeque<(MsgId, tracing::Span, Option<u32>)>,
}
40
41impl DnodeServerConn {
42 #[must_use]
58 pub fn new(conn: Conn, requests: mpsc::Receiver<OutboundRequest>) -> Self {
59 debug_assert!(matches!(conn.role(), ConnRole::DnodePeerServer));
60 Self {
61 conn,
62 requests,
63 pending: std::collections::VecDeque::new(),
64 }
65 }
66
67 pub async fn run(mut self) -> Result<(), NetError> {
72 let mut requests = std::mem::replace(&mut self.requests, {
73 let (_tx, rx) = mpsc::channel::<OutboundRequest>(1);
74 rx
75 });
76 self.run_with(&mut requests).await
77 }
78
79 pub async fn run_with(
86 &mut self,
87 requests: &mut mpsc::Receiver<OutboundRequest>,
88 ) -> Result<(), NetError> {
89 let mut read_buf = vec![0u8; 4096];
90 let mut accumulated = Vec::<u8>::new();
91 let mut parser = DnodeParser::new();
92 let mut pending_responder: Option<mpsc::Sender<OutboundEnvelope>> = None;
93
94 loop {
95 if self.conn.is_eof() && self.pending.is_empty() {
96 self.conn.set_done();
97 return Ok(());
98 }
99
100 tokio::select! {
101 req = requests.recv() => {
102 let Some(req) = req else { continue; };
103 let send_span = tracing::info_span!(
104 parent: &req.span,
105 "peer.send",
106 req_id = req.req_id,
107 bytes = req.bytes.len(),
108 );
109 let req_span = req.span.clone();
110 let req_bytes = req.bytes;
111 let req_id = req.req_id;
112 let req_ty = req.ty;
113 let mut header_buf = self.conn.mbuf_pool().get();
114 dmsg_write(
115 &mut header_buf,
116 req_id,
117 if matches!(req_ty, DmsgType::Unknown) { DmsgType::Req } else { req_ty },
118 0,
119 true,
120 None,
121 u32::try_from(req_bytes.len()).unwrap_or(u32::MAX),
122 )?;
123 let header_len = header_buf.readable().len();
124 let transport = self.conn.transport_mut().ok_or(NetError::Closed)?;
125 let write_res = async {
126 transport.write_all(header_buf.readable()).await?;
127 transport.write_all(&req_bytes).await?;
128 Ok::<(), std::io::Error>(())
129 }
130 .instrument(send_span)
131 .await;
132 write_res?;
133 self.conn.record_send(header_len + req_bytes.len());
134 if is_data_plane_ty(req_ty) {
135 self.pending
136 .push_back((req_id, req_span, req.target_peer_idx));
137 pending_responder = Some(req.responder);
138 } else {
139 drop(req.responder);
144 }
145 }
146 read_res = async {
147 if let Some(t) = self.conn.transport_mut() {
148 t.read(&mut read_buf).await
149 } else {
150 Ok(0)
151 }
152 } => {
153 let n = read_res?;
154 if n == 0 {
155 self.conn.set_eof();
156 continue;
157 }
158 self.conn.record_recv(n);
159 accumulated.extend_from_slice(&read_buf[..n]);
160 self.drive_response(&mut accumulated, &mut parser, &mut pending_responder).await?;
161 }
162 }
163 }
164 }
165
166 async fn drive_response(
167 &mut self,
168 accumulated: &mut Vec<u8>,
169 parser: &mut DnodeParser,
170 responder: &mut Option<mpsc::Sender<OutboundEnvelope>>,
171 ) -> Result<(), NetError> {
172 loop {
173 if accumulated.is_empty() {
174 return Ok(());
175 }
176 let step = parser.step(accumulated.as_slice());
177 match step {
178 ParseStep::NeedMore { .. } => return Ok(()),
179 ParseStep::Error { consumed } => {
180 return Err(NetError::Dnode(format!(
181 "dnode peer-server parse error after {consumed} bytes"
182 )));
183 }
184 ParseStep::HeaderDone { consumed } => {
185 let dmsg = parser.take_dmsg();
186 let plen = dmsg.plen as usize;
187 let total = consumed + plen;
188 if accumulated.len() < total {
189 parser.reset();
190 return Ok(());
191 }
192 let payload = accumulated[consumed..total].to_vec();
193 accumulated.drain(0..total);
194 parser.reset();
195
196 let (req_id, req_span, source_peer_idx) = self
198 .pending
199 .pop_front()
200 .unwrap_or_else(|| (dmsg.id, tracing::Span::current(), None));
201 let parse_span = tracing::info_span!(
202 parent: &req_span,
203 "peer.parse",
204 req_id,
205 bytes = plen,
206 );
207 let env = parse_span.in_scope(|| {
208 let mut rsp = Msg::new(req_id, MsgType::Unknown, false);
209 let pool = self.conn.mbuf_pool().clone();
210 let mut buf = pool.get();
211 buf.recv(&payload);
212 rsp.mbufs_mut().push_back(buf);
213 rsp.recompute_mlen();
214 rsp.set_dmsg(dmsg);
215 rsp.set_parse_result(MsgParseResult::Ok);
218 OutboundEnvelope {
219 req_id,
220 rsp,
221 span: req_span,
222 source_peer_idx,
223 }
224 });
225 if let Some(sender) = responder.as_ref() {
226 let _ = sender.send(env).await;
227 }
228 }
229 }
230 }
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::reactor::TcpTransport;
    use tokio::net::{TcpListener, TcpStream};

    /// A peer-server connection can be built over a real loopback TCP stream
    /// (satisfying the role assertion in `new`) and then dropped.
    #[tokio::test]
    async fn build_and_drop() {
        // Loopback listener on an OS-assigned port. Its task accepts a single
        // connection and closes it immediately.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let listen_addr = listener.local_addr().unwrap();
        let _acceptor = tokio::spawn(async move {
            let (accepted, _) = listener.accept().await.unwrap();
            drop(accepted);
        });

        let stream = TcpStream::connect(listen_addr).await.unwrap();
        let conn = Conn::new(
            Box::new(TcpTransport::new(stream, ConnRole::DnodePeerServer)),
            ConnRole::DnodePeerServer,
        );

        // The sender stays bound until the end of the test so the channel is open.
        let (_request_tx, request_rx) = mpsc::channel(1);
        let _under_test = DnodeServerConn::new(conn, request_rx);
    }
}