1use crate::{
21 error::{ClientError, ClientResult},
22 Params,
23};
24use bytes::{Buf, BufMut, Bytes, BytesMut};
25use std::{
26 borrow::Cow,
27 cmp::min,
28 collections::HashMap,
29 fmt::{self, Debug, Display},
30 mem::size_of,
31 ops::{Deref, DerefMut},
32};
33use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
34
35pub(crate) const VERSION_1: u8 = 1;
37pub(crate) const MAX_LENGTH: usize = 0xffff;
39pub(crate) const HEADER_LEN: usize = size_of::<Header>();
41
/// FastCGI record type, stored in the second byte of every record header.
///
/// The discriminants are the numeric type codes used on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RequestType {
    /// `FCGI_BEGIN_REQUEST` (code 1).
    BeginRequest = 1,
    /// `FCGI_ABORT_REQUEST` (code 2).
    AbortRequest = 2,
    /// `FCGI_END_REQUEST` (code 3).
    EndRequest = 3,
    /// `FCGI_PARAMS` (code 4).
    Params = 4,
    /// `FCGI_STDIN` (code 5).
    Stdin = 5,
    /// `FCGI_STDOUT` (code 6).
    Stdout = 6,
    /// `FCGI_STDERR` (code 7).
    Stderr = 7,
    /// `FCGI_DATA` (code 8).
    Data = 8,
    /// `FCGI_GET_VALUES` (code 9).
    GetValues = 9,
    /// `FCGI_GET_VALUES_RESULT` (code 10).
    GetValuesResult = 10,
    /// `FCGI_UNKNOWN_TYPE` (code 11); also used for any unrecognised code.
    UnknownType = 11,
}

impl RequestType {
    /// Maps a wire type code to a variant.
    ///
    /// Codes 1..=10 map to their variant; every other code (including 11)
    /// becomes [`RequestType::UnknownType`], so the original byte is not kept.
    fn from_u8(u: u8) -> Self {
        match u {
            1 => RequestType::BeginRequest,
            2 => RequestType::AbortRequest,
            3 => RequestType::EndRequest,
            4 => RequestType::Params,
            5 => RequestType::Stdin,
            6 => RequestType::Stdout,
            7 => RequestType::Stderr,
            8 => RequestType::Data,
            9 => RequestType::GetValues,
            10 => RequestType::GetValuesResult,
            _ => RequestType::UnknownType,
        }
    }
}

impl Display for RequestType {
    /// Formats the numeric type code, e.g. `6` for [`RequestType::Stdout`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        write!(f, "{}", *self as u8)
    }
}
98
99#[derive(Debug, Clone)]
100pub(crate) struct Header {
101 pub(crate) version: u8,
103 pub(crate) r#type: RequestType,
105 pub(crate) request_id: u16,
107 pub(crate) content_length: u16,
109 pub(crate) padding_length: u8,
111 pub(crate) reserved: u8,
113}
114
115impl Header {
116 pub(crate) async fn write_to_stream_batches<F, R, W>(
126 r#type: RequestType, request_id: u16, writer: &mut W, content: &mut R,
127 before_write: Option<F>,
128 ) -> io::Result<()>
129 where
130 F: Fn(Header) -> Header,
131 R: AsyncRead + Unpin,
132 W: AsyncWrite + Unpin,
133 {
134 let mut buf = vec![0u8; MAX_LENGTH];
135 let mut had_written = false;
136
137 loop {
138 let read = content.read(&mut buf).await?;
139 if had_written && read == 0 {
140 break;
141 }
142
143 let buf = &buf[..read];
144 let mut header = Self::new(r#type.clone(), request_id, buf);
145 if let Some(ref f) = before_write {
146 header = f(header);
147 }
148 header.write_to_stream(writer, buf).await?;
149
150 had_written = true;
151 }
152 Ok(())
153 }
154
155 fn new(r#type: RequestType, request_id: u16, content: &[u8]) -> Self {
163 let content_length = min(content.len(), MAX_LENGTH) as u16;
164 Self {
165 version: VERSION_1,
166 r#type,
167 request_id,
168 content_length,
169 padding_length: (-(content_length as i16) & 7) as u8,
170 reserved: 0,
171 }
172 }
173
174 async fn write_to_stream<W: AsyncWrite + Unpin>(
181 self, writer: &mut W, content: &[u8],
182 ) -> io::Result<()> {
183 let mut buf: Bytes = (&self).into();
184
185 writer.write_all_buf(&mut buf).await?;
186 writer.write_all(content).await?;
187
188 if self.padding_length > 0 {
189 let padding = [0u8; 7]; writer
191 .write_all(&padding[..self.padding_length as usize])
192 .await?;
193 }
194 Ok(())
195 }
196
197 pub(crate) async fn new_from_stream<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
203 let mut buf = BytesMut::zeroed(HEADER_LEN);
204 reader.read_exact(&mut buf).await?;
205 Ok(Self::from(buf))
206 }
207
208 pub(crate) async fn read_content_from_stream<R: AsyncRead + Unpin>(
214 &self, reader: &mut R,
215 ) -> io::Result<BytesMut> {
216 let mut buf = BytesMut::zeroed(self.content_length as usize);
217 reader.read_exact(&mut buf).await?;
218 let mut padding_buf = BytesMut::zeroed(self.padding_length as usize);
219 reader.read_exact(&mut padding_buf).await?;
220 Ok(buf)
221 }
222}
223
224impl Into<Bytes> for &Header {
225 fn into(self) -> Bytes {
226 let mut buf = BytesMut::with_capacity(HEADER_LEN);
227 buf.put_u8(self.version);
228 buf.put_u8(self.r#type as u8);
229 buf.put_u16(self.request_id);
230 buf.put_u16(self.content_length);
231 buf.put_u8(self.padding_length);
232 buf.put_u8(self.reserved);
233 buf.freeze()
234 }
235}
236
237impl From<BytesMut> for Header {
238 fn from(mut buf: BytesMut) -> Self {
244 Self {
245 version: buf.get_u8(),
246 r#type: RequestType::from_u8(buf.get_u8()),
247 request_id: buf.get_u16(),
248 content_length: buf.get_u16(),
249 padding_length: buf.get_u8(),
250 reserved: buf.get_u8(),
251 }
252 }
253}
254
/// Role requested of the application in a BEGIN_REQUEST body.
///
/// Encoded on the wire as a big-endian `u16` holding the discriminant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u16)]
// Presumably not every variant is constructed inside this crate; the allow
// keeps the full set available to users without dead-code warnings.
#[allow(dead_code)]
pub enum Role {
    /// `FCGI_RESPONDER` (code 1).
    Responder = 1,
    /// `FCGI_AUTHORIZER` (code 2).
    Authorizer = 2,
    /// `FCGI_FILTER` (code 3).
    Filter = 3,
}
267
268#[derive(Debug)]
270pub(crate) struct BeginRequest {
271 pub(crate) role: Role,
273 pub(crate) flags: u8,
275 pub(crate) reserved: [u8; 5],
277}
278
279impl BeginRequest {
280 pub(crate) fn new(role: Role, keep_alive: bool) -> Self {
287 Self {
288 role,
289 flags: keep_alive as u8,
290 reserved: [0; 5],
291 }
292 }
293
294 pub(crate) fn to_content(&self) -> BytesMut {
296 let mut buf = BytesMut::with_capacity(8);
297 buf.put_u16(self.role as u16);
298 buf.put_u8(self.flags);
299 buf.put_slice(&self.reserved);
300 buf
301 }
302}
303
304#[derive(Debug)]
306pub(crate) struct BeginRequestRec {
307 pub(crate) header: Header,
309 pub(crate) begin_request: BeginRequest,
311 pub(crate) content: BytesMut,
313}
314
315impl BeginRequestRec {
316 pub(crate) fn new(request_id: u16, role: Role, keep_alive: bool) -> Self {
324 let begin_request = BeginRequest::new(role, keep_alive);
325 let content = begin_request.to_content();
326 let header = Header::new(RequestType::BeginRequest, request_id, &content);
327 Self {
328 header,
329 begin_request,
330 content,
331 }
332 }
333
334 pub(crate) async fn write_to_stream<W: AsyncWrite + Unpin>(
340 self, writer: &mut W,
341 ) -> io::Result<()> {
342 self.header.write_to_stream(writer, &self.content).await
343 }
344}
345
346#[derive(Debug, Clone, Copy)]
348pub enum ParamLength {
349 Short(u8),
351 Long(u32),
353}
354
355impl ParamLength {
356 pub fn new(length: usize) -> Self {
362 if length < 128 {
363 ParamLength::Short(length as u8)
364 } else {
365 let mut length = length;
366 length |= 1 << 31;
367 ParamLength::Long(length as u32)
368 }
369 }
370
371 pub fn content(self) -> BytesMut {
373 match self {
374 ParamLength::Short(l) => {
375 let mut buf = BytesMut::with_capacity(1);
376 buf.put_u8(l);
377 buf
378 }
379 ParamLength::Long(l) => {
380 let mut buf = BytesMut::with_capacity(4);
381 buf.put_u32(l);
382 buf
383 }
384 }
385 }
386}
387
388#[derive(Debug)]
390pub struct ParamPair<'a> {
391 name_length: ParamLength,
393 value_length: ParamLength,
395 name_data: Cow<'a, str>,
397 value_data: Cow<'a, str>,
399}
400
401impl<'a> ParamPair<'a> {
402 fn new(name: Cow<'a, str>, value: Cow<'a, str>) -> Self {
409 let name_length = ParamLength::new(name.len());
410 let value_length = ParamLength::new(value.len());
411 Self {
412 name_length,
413 value_length,
414 name_data: name,
415 value_data: value,
416 }
417 }
418
419 fn write_to_buf(&self, buf: &mut BytesMut) {
425 let name_len = self.name_length.content();
426 buf.extend_from_slice(&name_len);
427 let value_len = self.value_length.content();
428 buf.extend_from_slice(&value_len);
429 buf.extend_from_slice(self.name_data.as_bytes());
430 buf.extend_from_slice(self.value_data.as_bytes());
431 }
432}
433
434#[derive(Debug)]
436pub(crate) struct ParamPairs<'a>(Vec<ParamPair<'a>>);
437
438impl<'a> ParamPairs<'a> {
439 pub(crate) fn new(params: Params<'a>) -> Self {
445 let mut param_pairs = Vec::new();
446 let params: HashMap<Cow<'a, str>, Cow<'a, str>> = params.into();
447 for (name, value) in params.into_iter() {
448 let param_pair = ParamPair::new(name, value);
449 param_pairs.push(param_pair);
450 }
451
452 Self(param_pairs)
453 }
454
455 pub(crate) fn to_content(&self) -> Bytes {
457 let mut buf = BytesMut::new();
458
459 for param_pair in self.iter() {
460 param_pair.write_to_buf(&mut buf);
461 }
462
463 buf.freeze()
464 }
465}
466
467impl<'a> Deref for ParamPairs<'a> {
468 type Target = Vec<ParamPair<'a>>;
469
470 fn deref(&self) -> &Self::Target {
471 &self.0
472 }
473}
474
475impl<'a> DerefMut for ParamPairs<'a> {
476 fn deref_mut(&mut self) -> &mut Self::Target {
477 &mut self.0
478 }
479}
480
/// Protocol-level outcome carried in an END_REQUEST body (one byte on the wire).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum ProtocolStatus {
    /// `FCGI_REQUEST_COMPLETE` (code 0).
    RequestComplete = 0,
    /// `FCGI_CANT_MPX_CONN` (code 1).
    CantMpxConn = 1,
    /// `FCGI_OVERLOADED` (code 2).
    Overloaded = 2,
    /// `FCGI_UNKNOWN_ROLE` (code 3).
    UnknownRole = 3,
}
494
495impl ProtocolStatus {
496 pub fn from_u8(u: u8) -> Self {
502 match u {
503 0 => ProtocolStatus::RequestComplete,
504 1 => ProtocolStatus::CantMpxConn,
505 2 => ProtocolStatus::Overloaded,
506 _ => ProtocolStatus::UnknownRole,
507 }
508 }
509
510 pub(crate) fn convert_to_client_result(self, app_status: u32) -> ClientResult<()> {
516 match self {
517 ProtocolStatus::RequestComplete => Ok(()),
518 _ => Err(ClientError::new_end_request_with_protocol_status(
519 self, app_status,
520 )),
521 }
522 }
523}
524
525#[derive(Debug)]
527pub struct EndRequest {
528 pub(crate) app_status: u32,
530 pub(crate) protocol_status: ProtocolStatus,
532 #[allow(dead_code)]
534 reserved: [u8; 3],
535}
536
537impl From<BytesMut> for EndRequest {
538 fn from(mut buf: BytesMut) -> Self {
539 let app_status = buf.get_u32();
540 let protocol_status = ProtocolStatus::from_u8(buf.get_u8());
541 let mut reserved = [0u8; 3];
542 buf.copy_to_slice(&mut reserved);
543
544 Self {
545 app_status,
546 protocol_status,
547 reserved,
548 }
549 }
550}
551
552#[derive(Debug)]
554pub(crate) struct EndRequestRec {
555 #[allow(dead_code)]
557 header: Header,
558 pub(crate) end_request: EndRequest,
560}
561
562impl EndRequestRec {
563 pub(crate) async fn from_header<R: AsyncRead + Unpin>(
570 header: &Header, reader: &mut R,
571 ) -> io::Result<Self> {
572 let header = header.clone();
573 let content = header.read_content_from_stream(reader).await?;
574 Ok(Self::new_from_buf(header, content))
575 }
576
577 pub(crate) fn new_from_buf(header: Header, buf: BytesMut) -> Self {
584 Self {
585 header,
586 end_request: EndRequest::from(buf),
587 }
588 }
589}