use std::borrow::Cow;
use crate::coding::{Decode, DecodeError, Encode, EncodeError};
use super::Message;
use super::Version;
/// Newtype wrapper around a `u64` request identifier.
///
/// IDs advance in steps of 2 via [`RequestId::increment`].
/// NOTE(review): the step of 2 presumably keeps each endpoint on its own
/// even/odd sequence — confirm against the protocol spec and against the code
/// that chooses the starting value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RequestId(pub u64);

impl RequestId {
    /// Returns the current ID, then advances `self` by 2.
    ///
    /// # Panics
    ///
    /// With default overflow checks (debug builds), panics if `self.0 + 2`
    /// overflows `u64`; with overflow checks disabled it wraps instead.
    pub fn increment(&mut self) -> RequestId {
        let issued = *self;
        self.0 += 2;
        issued
    }
}

/// Formats the ID as its bare decimal value (e.g. `42`).
///
/// Formatter flags such as width or fill are not forwarded to the inner number.
impl std::fmt::Display for RequestId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let RequestId(raw) = self;
        write!(f, "{}", raw)
    }
}
impl Encode<Version> for RequestId {
fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
self.0.encode(w, version)?;
Ok(())
}
}
impl Decode<Version> for RequestId {
fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
let request_id = u64::decode(r, version)?;
Ok(Self(request_id))
}
}
/// Control message (type `0x15`) whose entire payload is one [`RequestId`].
///
/// NOTE(review): by its name this presumably advertises the request-ID limit
/// the peer may use — confirm against the protocol spec.
#[derive(Clone, Debug)]
pub struct MaxRequestId {
    pub request_id: RequestId,
}

impl Message for MaxRequestId {
    const ID: u64 = 0x15;

    /// Writes only the request ID; the layout is the same for every version.
    fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        self.request_id.encode(w, version)
    }

    /// Reads a single request ID.
    fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        RequestId::decode(r, version).map(|request_id| Self { request_id })
    }
}
/// Control message (type `0x1a`) whose entire payload is one [`RequestId`].
///
/// NOTE(review): by its name this is presumably sent when the sender is held
/// back by the peer's request-ID limit — confirm against the protocol spec.
#[derive(Clone, Debug)]
pub struct RequestsBlocked {
    pub request_id: RequestId,
}

impl Message for RequestsBlocked {
    const ID: u64 = 0x1a;

    /// Writes only the request ID; the layout is the same for every version.
    fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        let Self { request_id } = *self;
        request_id.encode(w, version)
    }

    /// Reads a single request ID.
    fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        let decoded = RequestId::decode(r, version)?;
        Ok(Self {
            request_id: decoded,
        })
    }
}
/// Positive reply to a request (control message type `0x07`).
///
/// Wire layout as implemented below:
/// - Draft14–16: `request_id`, then the (empty) parameter list
/// - Draft17: the (empty) parameter list only
///
/// NOTE(review): `encode_params!`/`decode_params!` are invoked with no
/// entries; presumably they write/read an empty parameter set — see the
/// macro definitions.
#[derive(Clone, Debug)]
pub struct RequestOk {
    /// Must be `Some` for Draft14–16 and `None` for Draft17; encoding panics otherwise.
    pub request_id: Option<RequestId>,
}

impl Message for RequestOk {
    const ID: u64 = 0x07;

    /// # Panics
    ///
    /// Panics if `request_id` is `None` for Draft14–16, or `Some` for Draft17.
    fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        match version {
            Version::Draft17 => {
                assert!(self.request_id.is_none(), "request_id must be None for draft17");
            }
            Version::Draft14 | Version::Draft15 | Version::Draft16 => {
                let id = self.request_id.expect("request_id required for draft14-16");
                id.encode(w, version)?;
            }
        }
        encode_params!(w, version,);
        Ok(())
    }

    /// Reads the request ID for Draft14–16 (yielding `None` for Draft17),
    /// then the parameter list.
    fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        let request_id = match version {
            Version::Draft17 => None,
            Version::Draft14 | Version::Draft15 | Version::Draft16 => {
                Some(RequestId::decode(r, version)?)
            }
        };
        decode_params!(r, version,);
        Ok(Self { request_id })
    }
}
/// Error reply to a request (control message type `0x05`).
///
/// Wire layout by version, as implemented below:
/// - Draft14/15: `request_id`, `error_code`, `reason_phrase`
/// - Draft16:    `request_id`, `error_code`, `retry_interval`, `reason_phrase`
/// - Draft17:    `error_code`, `retry_interval`, `reason_phrase`
#[derive(Clone, Debug)]
pub struct RequestError<'a> {
    /// Must be `Some` for Draft14–16 and `None` for Draft17; encoding panics otherwise.
    pub request_id: Option<RequestId>,
    pub error_code: u64,
    pub reason_phrase: Cow<'a, str>,
    /// Only on the wire for Draft16 and Draft17. For Draft14/15 it is not
    /// written when encoding (any non-zero value is lost) and decodes as `0`.
    pub retry_interval: u64,
}

impl Message for RequestError<'_> {
    const ID: u64 = 0x05;

    /// # Panics
    ///
    /// Panics if `request_id` is `None` for Draft14–16, or `Some` for Draft17.
    fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        // Exhaustive matches (no `_` arm) mirror `decode_msg`: adding a new
        // `Version` variant forces both directions to choose a layout together,
        // instead of encode silently falling into one branch.
        match version {
            Version::Draft14 | Version::Draft15 | Version::Draft16 => {
                self.request_id
                    .expect("request_id required for draft14-16")
                    .encode(w, version)?;
            }
            Version::Draft17 => {
                assert!(self.request_id.is_none(), "request_id must be None for draft17");
            }
        }
        self.error_code.encode(w, version)?;
        match version {
            Version::Draft16 | Version::Draft17 => self.retry_interval.encode(w, version)?,
            // No retry-interval field before Draft16; the value is dropped.
            Version::Draft14 | Version::Draft15 => {}
        }
        self.reason_phrase.encode(w, version)?;
        Ok(())
    }

    /// Decodes the fields in the version-specific order documented on the struct.
    /// For Draft17 `request_id` is `None`; for Draft14/15 `retry_interval` is `0`.
    fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        let request_id = match version {
            Version::Draft14 | Version::Draft15 | Version::Draft16 => {
                Some(RequestId::decode(r, version)?)
            }
            Version::Draft17 => None,
        };
        let error_code = u64::decode(r, version)?;
        let retry_interval = match version {
            Version::Draft16 | Version::Draft17 => u64::decode(r, version)?,
            Version::Draft14 | Version::Draft15 => 0,
        };
        let reason_phrase = Cow::<str>::decode(r, version)?;
        Ok(Self {
            request_id,
            error_code,
            reason_phrase,
            retry_interval,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Every protocol version handled by this module, for version-sweeping tests.
    const ALL_VERSIONS: [Version; 4] = [
        Version::Draft14,
        Version::Draft15,
        Version::Draft16,
        Version::Draft17,
    ];

    fn encode_message<M: Message>(msg: &M, version: Version) -> Vec<u8> {
        let mut buf = BytesMut::new();
        msg.encode_msg(&mut buf, version).unwrap();
        buf.to_vec()
    }

    fn decode_message<M: Message>(bytes: &[u8], version: Version) -> Result<M, DecodeError> {
        let mut buf = bytes::Bytes::from(bytes.to_vec());
        M::decode_msg(&mut buf, version)
    }

    #[test]
    fn test_request_id_increment_steps_by_two() {
        let mut id = RequestId(1);
        assert_eq!(id.increment(), RequestId(1));
        assert_eq!(id.increment(), RequestId(3));
        assert_eq!(id, RequestId(5));
    }

    #[test]
    fn test_max_request_id_round_trip_all_versions() {
        for version in ALL_VERSIONS {
            let msg = MaxRequestId {
                request_id: RequestId(7),
            };
            let encoded = encode_message(&msg, version);
            let decoded: MaxRequestId = decode_message(&encoded, version).unwrap();
            assert_eq!(decoded.request_id, RequestId(7));
        }
    }

    #[test]
    fn test_requests_blocked_round_trip_all_versions() {
        for version in ALL_VERSIONS {
            let msg = RequestsBlocked {
                request_id: RequestId(12),
            };
            let encoded = encode_message(&msg, version);
            let decoded: RequestsBlocked = decode_message(&encoded, version).unwrap();
            assert_eq!(decoded.request_id, RequestId(12));
        }
    }

    #[test]
    fn test_request_ok_round_trip() {
        let msg = RequestOk {
            request_id: Some(RequestId(42)),
        };
        let encoded = encode_message(&msg, Version::Draft15);
        let decoded: RequestOk = decode_message(&encoded, Version::Draft15).unwrap();
        assert_eq!(decoded.request_id, Some(RequestId(42)));
    }

    #[test]
    fn test_request_error_round_trip() {
        let msg = RequestError {
            request_id: Some(RequestId(99)),
            error_code: 500,
            reason_phrase: "Internal error".into(),
            retry_interval: 0,
        };
        let encoded = encode_message(&msg, Version::Draft15);
        let decoded: RequestError = decode_message(&encoded, Version::Draft15).unwrap();
        assert_eq!(decoded.request_id, Some(RequestId(99)));
        assert_eq!(decoded.error_code, 500);
        assert_eq!(decoded.reason_phrase, "Internal error");
        assert_eq!(decoded.retry_interval, 0);
    }

    #[test]
    fn test_request_error_v16_retry_interval() {
        let msg = RequestError {
            request_id: Some(RequestId(99)),
            error_code: 500,
            reason_phrase: "Internal error".into(),
            retry_interval: 5000,
        };
        let encoded = encode_message(&msg, Version::Draft16);
        let decoded: RequestError = decode_message(&encoded, Version::Draft16).unwrap();
        assert_eq!(decoded.request_id, Some(RequestId(99)));
        assert_eq!(decoded.error_code, 500);
        assert_eq!(decoded.reason_phrase, "Internal error");
        assert_eq!(decoded.retry_interval, 5000);
    }

    /// Before Draft16 the retry interval has no wire field: a non-zero value
    /// must not change the encoded bytes and decodes back as 0.
    #[test]
    fn test_request_error_pre_v16_drops_retry_interval() {
        let with_retry = RequestError {
            request_id: Some(RequestId(3)),
            error_code: 1,
            reason_phrase: "busy".into(),
            retry_interval: 5000,
        };
        let without_retry = RequestError {
            retry_interval: 0,
            ..with_retry.clone()
        };
        for version in [Version::Draft14, Version::Draft15] {
            let encoded = encode_message(&with_retry, version);
            assert_eq!(encoded, encode_message(&without_retry, version));
            let decoded: RequestError = decode_message(&encoded, version).unwrap();
            assert_eq!(decoded.request_id, Some(RequestId(3)));
            assert_eq!(decoded.error_code, 1);
            assert_eq!(decoded.reason_phrase, "busy");
            assert_eq!(decoded.retry_interval, 0);
        }
    }

    #[test]
    fn test_request_ok_v17_round_trip() {
        let msg = RequestOk { request_id: None };
        let encoded = encode_message(&msg, Version::Draft17);
        let decoded: RequestOk = decode_message(&encoded, Version::Draft17).unwrap();
        assert_eq!(decoded.request_id, None);
    }

    #[test]
    fn test_request_error_v17_round_trip() {
        let msg = RequestError {
            request_id: None,
            error_code: 500,
            reason_phrase: "Internal error".into(),
            retry_interval: 3000,
        };
        let encoded = encode_message(&msg, Version::Draft17);
        let decoded: RequestError = decode_message(&encoded, Version::Draft17).unwrap();
        assert_eq!(decoded.request_id, None);
        assert_eq!(decoded.error_code, 500);
        assert_eq!(decoded.reason_phrase, "Internal error");
        assert_eq!(decoded.retry_interval, 3000);
    }

    #[test]
    #[should_panic(expected = "request_id required for draft14-16")]
    fn test_request_ok_missing_request_id_panics_pre_v17() {
        let msg = RequestOk { request_id: None };
        let _ = encode_message(&msg, Version::Draft16);
    }

    #[test]
    #[should_panic(expected = "request_id must be None for draft17")]
    fn test_request_ok_request_id_panics_v17() {
        let msg = RequestOk {
            request_id: Some(RequestId(1)),
        };
        let _ = encode_message(&msg, Version::Draft17);
    }

    #[test]
    #[should_panic(expected = "request_id must be None for draft17")]
    fn test_request_error_request_id_panics_v17() {
        let msg = RequestError {
            request_id: Some(RequestId(1)),
            error_code: 0,
            reason_phrase: "".into(),
            retry_interval: 0,
        };
        let _ = encode_message(&msg, Version::Draft17);
    }
}