moq-lite 0.15.5

Media over QUIC - Transport (Lite)
Documentation
use std::borrow::Cow;

use crate::coding::{Decode, DecodeError, Encode, EncodeError};

use super::Message;

use super::Version;

#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RequestId(pub u64);

impl RequestId {
	/// Returns the previous request ID and advances by 2.
	///
	/// IDs increment by 2 so peers keep parity separation:
	/// clients use even IDs and servers use odd IDs.
	pub fn increment(&mut self) -> RequestId {
		let prev = self.0;
		self.0 += 2;
		RequestId(prev)
	}
}

impl std::fmt::Display for RequestId {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "{}", self.0)
	}
}

impl Encode<Version> for RequestId {
	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
		self.0.encode(w, version)?;
		Ok(())
	}
}

impl Decode<Version> for RequestId {
	fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
		let request_id = u64::decode(r, version)?;
		Ok(Self(request_id))
	}
}

#[derive(Clone, Debug)]
pub struct MaxRequestId {
	pub request_id: RequestId,
}

impl Message for MaxRequestId {
	const ID: u64 = 0x15;

	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
		self.request_id.encode(w, version)?;
		Ok(())
	}

	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
		let request_id = RequestId::decode(r, version)?;
		Ok(Self { request_id })
	}
}

#[derive(Clone, Debug)]
pub struct RequestsBlocked {
	pub request_id: RequestId,
}

impl Message for RequestsBlocked {
	const ID: u64 = 0x1a;

	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
		self.request_id.encode(w, version)?;
		Ok(())
	}

	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
		let request_id = RequestId::decode(r, version)?;
		Ok(Self { request_id })
	}
}

/// REQUEST_OK (0x07 in v15) - Generic success response for any request.
/// Replaces PublishNamespaceOk, SubscribeNamespaceOk in v15.
/// Also used as response to SubscribeUpdate and TrackStatus in v15.
#[derive(Clone, Debug)]
pub struct RequestOk {
	pub request_id: Option<RequestId>,
}

impl Message for RequestOk {
	const ID: u64 = 0x07;

	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
		if version != Version::Draft17 {
			self.request_id
				.expect("request_id required for draft14-16")
				.encode(w, version)?;
		} else {
			assert!(self.request_id.is_none(), "request_id must be None for draft17");
		}
		encode_params!(w, version,);
		Ok(())
	}

	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
		let request_id = if version == Version::Draft17 {
			None
		} else {
			Some(RequestId::decode(r, version)?)
		};
		decode_params!(r, version,);
		Ok(Self { request_id })
	}
}

/// REQUEST_ERROR (0x05 in v15) - Generic error response for any request.
/// Replaces SubscribeError, PublishError, PublishNamespaceError,
/// SubscribeNamespaceError, FetchError in v15.
#[derive(Clone, Debug)]
pub struct RequestError<'a> {
	pub request_id: Option<RequestId>,
	pub error_code: u64,
	pub reason_phrase: Cow<'a, str>,
	/// v16+: retry interval in milliseconds
	pub retry_interval: u64,
}

impl Message for RequestError<'_> {
	const ID: u64 = 0x05;

	fn encode_msg<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
		if version != Version::Draft17 {
			self.request_id
				.expect("request_id required for draft14-16")
				.encode(w, version)?;
		} else {
			assert!(self.request_id.is_none(), "request_id must be None for draft17");
		}
		self.error_code.encode(w, version)?;
		if version == Version::Draft16 || version == Version::Draft17 {
			self.retry_interval.encode(w, version)?;
		}
		self.reason_phrase.encode(w, version)?;
		Ok(())
	}

	fn decode_msg<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
		let request_id = if version == Version::Draft17 {
			None
		} else {
			Some(RequestId::decode(r, version)?)
		};
		let error_code = u64::decode(r, version)?;
		let retry_interval = match version {
			Version::Draft16 | Version::Draft17 => u64::decode(r, version)?,
			Version::Draft14 | Version::Draft15 => 0,
		};
		let reason_phrase = Cow::<str>::decode(r, version)?;
		Ok(Self {
			request_id,
			error_code,
			reason_phrase,
			retry_interval,
		})
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use bytes::BytesMut;

	fn encode_message<M: Message>(msg: &M, version: Version) -> Vec<u8> {
		let mut buf = BytesMut::new();
		msg.encode_msg(&mut buf, version).unwrap();
		buf.to_vec()
	}

	fn decode_message<M: Message>(bytes: &[u8], version: Version) -> Result<M, DecodeError> {
		let mut buf = bytes::Bytes::from(bytes.to_vec());
		M::decode_msg(&mut buf, version)
	}

	#[test]
	fn test_request_ok_round_trip() {
		let msg = RequestOk {
			request_id: Some(RequestId(42)),
		};

		let encoded = encode_message(&msg, Version::Draft15);
		let decoded: RequestOk = decode_message(&encoded, Version::Draft15).unwrap();

		assert_eq!(decoded.request_id, Some(RequestId(42)));
	}

	#[test]
	fn test_request_error_round_trip() {
		let msg = RequestError {
			request_id: Some(RequestId(99)),
			error_code: 500,
			reason_phrase: "Internal error".into(),
			retry_interval: 0,
		};

		let encoded = encode_message(&msg, Version::Draft15);
		let decoded: RequestError = decode_message(&encoded, Version::Draft15).unwrap();

		assert_eq!(decoded.request_id, Some(RequestId(99)));
		assert_eq!(decoded.error_code, 500);
		assert_eq!(decoded.reason_phrase, "Internal error");
		assert_eq!(decoded.retry_interval, 0);
	}

	#[test]
	fn test_request_error_v16_retry_interval() {
		let msg = RequestError {
			request_id: Some(RequestId(99)),
			error_code: 500,
			reason_phrase: "Internal error".into(),
			retry_interval: 5000,
		};

		let encoded = encode_message(&msg, Version::Draft16);
		let decoded: RequestError = decode_message(&encoded, Version::Draft16).unwrap();

		assert_eq!(decoded.request_id, Some(RequestId(99)));
		assert_eq!(decoded.error_code, 500);
		assert_eq!(decoded.reason_phrase, "Internal error");
		assert_eq!(decoded.retry_interval, 5000);
	}

	#[test]
	fn test_request_ok_v17_round_trip() {
		let msg = RequestOk { request_id: None };

		let encoded = encode_message(&msg, Version::Draft17);
		let decoded: RequestOk = decode_message(&encoded, Version::Draft17).unwrap();

		assert_eq!(decoded.request_id, None);
	}

	#[test]
	fn test_request_error_v17_round_trip() {
		let msg = RequestError {
			request_id: None,
			error_code: 500,
			reason_phrase: "Internal error".into(),
			retry_interval: 3000,
		};

		let encoded = encode_message(&msg, Version::Draft17);
		let decoded: RequestError = decode_message(&encoded, Version::Draft17).unwrap();

		assert_eq!(decoded.request_id, None);
		assert_eq!(decoded.error_code, 500);
		assert_eq!(decoded.reason_phrase, "Internal error");
		assert_eq!(decoded.retry_interval, 3000);
	}
}