1use crate::{
21 error::{ClientError, ClientResult},
22 Params,
23};
24use std::{
25 borrow::Cow,
26 cmp::min,
27 collections::HashMap,
28 fmt::{self, Debug, Display},
29 mem::size_of,
30 ops::{Deref, DerefMut},
31};
32use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
33
/// Protocol version stamped into headers built by `Header::new`.
pub(crate) const VERSION_1: u8 = 1;

/// Largest content length a single record can carry (`content_length` is a `u16`).
pub(crate) const MAX_LENGTH: usize = 0xffff;

/// Size in bytes of an encoded record header on the wire.
///
/// Wire layout: version (u8), type (u8), request id (u16), content length (u16),
/// padding length (u8), reserved (u8). This is computed from those field widths
/// rather than `size_of::<Header>()`, because the in-memory layout of a
/// `repr(Rust)` struct is not guaranteed to match the protocol's header size.
pub(crate) const HEADER_LEN: usize = 2 * size_of::<u16>() + 4 * size_of::<u8>();
40
/// FastCGI record type, carried in the second byte of every record header.
///
/// The discriminants are the on-the-wire values. Fieldless, so it is `Copy`
/// like the sibling `Role` enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RequestType {
    /// Wire value `1` (`FCGI_BEGIN_REQUEST`).
    BeginRequest = 1,
    /// Wire value `2` (`FCGI_ABORT_REQUEST`).
    AbortRequest = 2,
    /// Wire value `3` (`FCGI_END_REQUEST`).
    EndRequest = 3,
    /// Wire value `4` (`FCGI_PARAMS`).
    Params = 4,
    /// Wire value `5` (`FCGI_STDIN`).
    Stdin = 5,
    /// Wire value `6` (`FCGI_STDOUT`).
    Stdout = 6,
    /// Wire value `7` (`FCGI_STDERR`).
    Stderr = 7,
    /// Wire value `8` (`FCGI_DATA`).
    Data = 8,
    /// Wire value `9` (`FCGI_GET_VALUES`).
    GetValues = 9,
    /// Wire value `10` (`FCGI_GET_VALUES_RESULT`).
    GetValuesResult = 10,
    /// Wire value `11` (`FCGI_UNKNOWN_TYPE`); also the decoding of any unrecognised byte.
    UnknownType = 11,
}

impl RequestType {
    /// Decodes a wire byte.
    ///
    /// Any value other than `1..=10` (including `0` and out-of-range values)
    /// decodes as [`RequestType::UnknownType`].
    fn from_u8(u: u8) -> Self {
        match u {
            1 => RequestType::BeginRequest,
            2 => RequestType::AbortRequest,
            3 => RequestType::EndRequest,
            4 => RequestType::Params,
            5 => RequestType::Stdin,
            6 => RequestType::Stdout,
            7 => RequestType::Stderr,
            8 => RequestType::Data,
            9 => RequestType::GetValues,
            10 => RequestType::GetValuesResult,
            _ => RequestType::UnknownType,
        }
    }
}

impl Display for RequestType {
    /// Formats the type as its numeric wire value (e.g. `Stdout` prints `6`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        Display::fmt(&(*self as u8), f)
    }
}
97
/// Header that precedes every FastCGI record.
///
/// Field order matches the wire order used by `Header::write_to_stream` and
/// `Header::new_from_buf`; `request_id` and `content_length` are big-endian
/// on the wire.
#[derive(Debug, Clone)]
pub(crate) struct Header {
    /// Protocol version (`VERSION_1` for headers built by `Header::new`).
    pub(crate) version: u8,
    /// Record type.
    pub(crate) r#type: RequestType,
    /// Identifier of the request this record belongs to.
    pub(crate) request_id: u16,
    /// Number of content bytes that follow the header.
    pub(crate) content_length: u16,
    /// Number of padding bytes that follow the content (written as zeros).
    pub(crate) padding_length: u8,
    /// Reserved byte (set to `0` by `Header::new`).
    pub(crate) reserved: u8,
}
113
114impl Header {
115 pub(crate) async fn write_to_stream_batches<F, R, W>(
125 r#type: RequestType, request_id: u16, writer: &mut W, content: &mut R,
126 before_write: Option<F>,
127 ) -> io::Result<()>
128 where
129 F: Fn(Header) -> Header,
130 R: AsyncRead + Unpin,
131 W: AsyncWrite + Unpin,
132 {
133 let mut buf: [u8; MAX_LENGTH] = [0; MAX_LENGTH];
134 let mut had_written = false;
135
136 loop {
137 let read = content.read(&mut buf).await?;
138 if had_written && read == 0 {
139 break;
140 }
141
142 let buf = &buf[..read];
143 let mut header = Self::new(r#type.clone(), request_id, buf);
144 if let Some(ref f) = before_write {
145 header = f(header);
146 }
147 header.write_to_stream(writer, buf).await?;
148
149 had_written = true;
150 }
151 Ok(())
152 }
153
154 fn new(r#type: RequestType, request_id: u16, content: &[u8]) -> Self {
162 let content_length = min(content.len(), MAX_LENGTH) as u16;
163 Self {
164 version: VERSION_1,
165 r#type,
166 request_id,
167 content_length,
168 padding_length: (-(content_length as i16) & 7) as u8,
169 reserved: 0,
170 }
171 }
172
173 async fn write_to_stream<W: AsyncWrite + Unpin>(
180 self, writer: &mut W, content: &[u8],
181 ) -> io::Result<()> {
182 let mut buf: Vec<u8> = Vec::new();
183 buf.push(self.version);
184 buf.push(self.r#type as u8);
185 buf.write_u16(self.request_id).await?;
186 buf.write_u16(self.content_length).await?;
187 buf.push(self.padding_length);
188 buf.push(self.reserved);
189
190 writer.write_all(&buf).await?;
191 writer.write_all(content).await?;
192 writer
193 .write_all(&vec![0; self.padding_length as usize])
194 .await?;
195
196 Ok(())
197 }
198
199 pub(crate) async fn new_from_stream<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
205 let mut buf: [u8; HEADER_LEN] = [0; HEADER_LEN];
206 reader.read_exact(&mut buf).await?;
207
208 Ok(Self::new_from_buf(&buf))
209 }
210
211 #[inline]
217 pub(crate) fn new_from_buf(buf: &[u8; HEADER_LEN]) -> Self {
218 Self {
219 version: buf[0],
220 r#type: RequestType::from_u8(buf[1]),
221 request_id: be_buf_to_u16(&buf[2..4]),
222 content_length: be_buf_to_u16(&buf[4..6]),
223 padding_length: buf[6],
224 reserved: buf[7],
225 }
226 }
227
228 pub(crate) async fn read_content_from_stream<R: AsyncRead + Unpin>(
234 &self, reader: &mut R,
235 ) -> io::Result<Vec<u8>> {
236 let mut buf = vec![0; self.content_length as usize];
237 reader.read_exact(&mut buf).await?;
238 let mut padding_buf = vec![0; self.padding_length as usize];
239 reader.read_exact(&mut padding_buf).await?;
240 Ok(buf)
241 }
242}
243
/// Role requested of the FastCGI application in a BeginRequest record.
///
/// The discriminants are the wire values; `BeginRequest::to_content` encodes
/// them as a big-endian `u16`.
#[derive(Debug, Clone, Copy)]
#[repr(u16)]
// NOTE(review): presumably not every variant is constructed inside this crate,
// hence the dead-code allowance — confirm.
#[allow(dead_code)]
pub enum Role {
    /// Wire value `1` (`FCGI_RESPONDER`).
    Responder = 1,
    /// Wire value `2` (`FCGI_AUTHORIZER`).
    Authorizer = 2,
    /// Wire value `3` (`FCGI_FILTER`).
    Filter = 3,
}
256
257#[derive(Debug)]
259pub(crate) struct BeginRequest {
260 pub(crate) role: Role,
262 pub(crate) flags: u8,
264 pub(crate) reserved: [u8; 5],
266}
267
268impl BeginRequest {
269 pub(crate) fn new(role: Role, keep_alive: bool) -> Self {
276 Self {
277 role,
278 flags: keep_alive as u8,
279 reserved: [0; 5],
280 }
281 }
282
283 pub(crate) async fn to_content(&self) -> io::Result<Vec<u8>> {
285 let mut buf: Vec<u8> = Vec::new();
286 buf.write_u16(self.role as u16).await?;
287 buf.push(self.flags);
288 buf.extend_from_slice(&self.reserved);
289 Ok(buf)
290 }
291}
292
293pub(crate) struct BeginRequestRec {
295 pub(crate) header: Header,
297 pub(crate) begin_request: BeginRequest,
299 pub(crate) content: Vec<u8>,
301}
302
303impl BeginRequestRec {
304 pub(crate) async fn new(request_id: u16, role: Role, keep_alive: bool) -> io::Result<Self> {
312 let begin_request = BeginRequest::new(role, keep_alive);
313 let content = begin_request.to_content().await?;
314 let header = Header::new(RequestType::BeginRequest, request_id, &content);
315 Ok(Self {
316 header,
317 begin_request,
318 content,
319 })
320 }
321
322 pub(crate) async fn write_to_stream<W: AsyncWrite + Unpin>(
328 self, writer: &mut W,
329 ) -> io::Result<()> {
330 self.header.write_to_stream(writer, &self.content).await
331 }
332}
333
334impl Debug for BeginRequestRec {
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
336 Debug::fmt(
337 &format!(
338 "BeginRequestRec {{header: {:?}, begin_request: {:?}}}",
339 self.header, self.begin_request
340 ),
341 f,
342 )
343 }
344}
345
/// Length prefix of a FastCGI name-value pair.
///
/// Lengths below 128 use a single byte. Longer lengths use four big-endian bytes
/// with the high bit set, which leaves 31 bits for the length itself.
#[derive(Debug, Clone, Copy)]
pub enum ParamLength {
    /// One-byte form, for lengths `0..=127`.
    Short(u8),
    /// Four-byte form: the length with bit 31 set.
    Long(u32),
}

impl ParamLength {
    /// Largest length the four-byte form can represent (31 bits).
    const MAX_LONG: usize = 0x7fff_ffff;

    /// Chooses the encoding for `length`.
    ///
    /// # Panics
    /// Panics if `length` exceeds `0x7fff_ffff`. Such a length cannot be encoded
    /// in 31 bits, and truncating it would silently corrupt the params stream.
    pub fn new(length: usize) -> Self {
        if length < 128 {
            ParamLength::Short(length as u8)
        } else {
            assert!(
                length <= Self::MAX_LONG,
                "FastCGI param length {} exceeds the 31-bit maximum {}",
                length,
                Self::MAX_LONG
            );
            // Lossless: the assert bounds `length` to 31 bits.
            ParamLength::Long(length as u32 | 0x8000_0000)
        }
    }

    /// Serializes the prefix: one byte for `Short`, four big-endian bytes for `Long`.
    ///
    /// # Errors
    /// Never fails; the `Result` is kept for interface compatibility.
    pub async fn content(self) -> io::Result<Vec<u8>> {
        Ok(match self {
            ParamLength::Short(l) => vec![l],
            ParamLength::Long(l) => l.to_be_bytes().to_vec(),
        })
    }
}
381
382#[derive(Debug)]
384pub struct ParamPair<'a> {
385 name_length: ParamLength,
387 value_length: ParamLength,
389 name_data: Cow<'a, str>,
391 value_data: Cow<'a, str>,
393}
394
395impl<'a> ParamPair<'a> {
396 fn new(name: Cow<'a, str>, value: Cow<'a, str>) -> Self {
403 let name_length = ParamLength::new(name.len());
404 let value_length = ParamLength::new(value.len());
405 Self {
406 name_length,
407 value_length,
408 name_data: name,
409 value_data: value,
410 }
411 }
412
413 async fn write_to_stream<W: AsyncWrite + Unpin>(&self, writer: &mut W) -> io::Result<()> {
419 writer.write_all(&self.name_length.content().await?).await?;
420 writer
421 .write_all(&self.value_length.content().await?)
422 .await?;
423 writer.write_all(self.name_data.as_bytes()).await?;
424 writer.write_all(self.value_data.as_bytes()).await?;
425 Ok(())
426 }
427}
428
429#[derive(Debug)]
431pub(crate) struct ParamPairs<'a>(Vec<ParamPair<'a>>);
432
433impl<'a> ParamPairs<'a> {
434 pub(crate) fn new(params: Params<'a>) -> Self {
440 let mut param_pairs = Vec::new();
441 let params: HashMap<Cow<'a, str>, Cow<'a, str>> = params.into();
442 for (name, value) in params.into_iter() {
443 let param_pair = ParamPair::new(name, value);
444 param_pairs.push(param_pair);
445 }
446
447 Self(param_pairs)
448 }
449
450 pub(crate) async fn to_content(&self) -> io::Result<Vec<u8>> {
452 let mut buf: Vec<u8> = Vec::new();
453
454 for param_pair in self.iter() {
455 param_pair.write_to_stream(&mut buf).await?;
456 }
457
458 Ok(buf)
459 }
460}
461
462impl<'a> Deref for ParamPairs<'a> {
463 type Target = Vec<ParamPair<'a>>;
464
465 fn deref(&self) -> &Self::Target {
466 &self.0
467 }
468}
469
470impl<'a> DerefMut for ParamPairs<'a> {
471 fn deref_mut(&mut self) -> &mut Self::Target {
472 &mut self.0
473 }
474}
475
476#[derive(Debug)]
478#[repr(u8)]
479pub enum ProtocolStatus {
480 RequestComplete = 0,
482 CantMpxConn = 1,
484 Overloaded = 2,
486 UnknownRole = 3,
488}
489
490impl ProtocolStatus {
491 pub fn from_u8(u: u8) -> Self {
497 match u {
498 0 => ProtocolStatus::RequestComplete,
499 1 => ProtocolStatus::CantMpxConn,
500 2 => ProtocolStatus::Overloaded,
501 _ => ProtocolStatus::UnknownRole,
502 }
503 }
504
505 pub(crate) fn convert_to_client_result(self, app_status: u32) -> ClientResult<()> {
511 match self {
512 ProtocolStatus::RequestComplete => Ok(()),
513 _ => Err(ClientError::new_end_request_with_protocol_status(
514 self, app_status,
515 )),
516 }
517 }
518}
519
/// Decoded body of an EndRequest record.
#[derive(Debug)]
pub struct EndRequest {
    /// Application status, decoded big-endian from body bytes `0..4`.
    pub(crate) app_status: u32,
    /// Protocol status, decoded from body byte `4`.
    pub(crate) protocol_status: ProtocolStatus,
    /// Reserved body bytes `5..8`.
    // NOTE(review): presumably kept only so the decoded body is complete and never
    // read, hence the dead-code allowance — confirm.
    #[allow(dead_code)]
    reserved: [u8; 3],
}
531
/// An EndRequest record: the header it was read with and its decoded body.
#[derive(Debug)]
pub(crate) struct EndRequestRec {
    /// Header the record was read with.
    // NOTE(review): presumably retained for `Debug` output only and not read
    // elsewhere, hence the dead-code allowance — confirm.
    #[allow(dead_code)]
    header: Header,
    /// Decoded body.
    pub(crate) end_request: EndRequest,
}
541
542impl EndRequestRec {
543 pub(crate) async fn from_header<R: AsyncRead + Unpin>(
550 header: &Header, reader: &mut R,
551 ) -> io::Result<Self> {
552 let header = header.clone();
553 let content = &*header.read_content_from_stream(reader).await?;
554 Ok(Self::new_from_buf(header, content))
555 }
556
557 pub(crate) fn new_from_buf(header: Header, buf: &[u8]) -> Self {
564 let app_status = u32::from_be_bytes(<[u8; 4]>::try_from(&buf[0..4]).unwrap());
565 let protocol_status =
566 ProtocolStatus::from_u8(u8::from_be_bytes(<[u8; 1]>::try_from(&buf[4..5]).unwrap()));
567 let reserved = <[u8; 3]>::try_from(&buf[5..8]).unwrap();
568 Self {
569 header,
570 end_request: EndRequest {
571 app_status,
572 protocol_status,
573 reserved,
574 },
575 }
576 }
577}
578
/// Decodes a big-endian `u16` from a slice of exactly two bytes.
///
/// # Panics
/// Panics if `buf.len() != 2`.
fn be_buf_to_u16(buf: &[u8]) -> u16 {
    let pair: [u8; 2] = buf.try_into().unwrap();
    u16::from_be_bytes(pair)
}