use alloc::string::String;
use core::{
fmt::{self, Display},
ops::Deref,
};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
use bytestring::ByteString;
/// A subject name backed by a [`ByteString`].
///
/// Values built through [`Subject::from_static`] or the `TryFrom` implementations
/// have passed subject validation: they are non-empty, at most 256 bytes long,
/// free of whitespace, and made of `.`-separated non-empty tokens in which `*`
/// and `>` may only appear as a whole token (`>` only as the last one).
/// [`Subject::from_dangerous_value`] runs that check in debug builds only.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Subject(ByteString);
impl Subject {
    /// Builds a [`Subject`] from a `'static` string, typically a literal.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not a valid subject.
    #[must_use]
    pub fn from_static(value: &'static str) -> Self {
        let raw = ByteString::from_static(value);
        Self::try_from(raw).expect("invalid Subject")
    }

    /// Wraps `value` in a [`Subject`], trusting the caller that it is valid.
    ///
    /// The value is only validated when debug assertions are enabled, in which
    /// case an invalid subject aborts with a panic. Without debug assertions
    /// the value is stored unchecked.
    #[expect(
        clippy::missing_panics_doc,
        reason = "The subject validation is only made in debug"
    )]
    #[must_use]
    pub fn from_dangerous_value(value: ByteString) -> Self {
        // Skip the validation cost entirely when debug assertions are off.
        let outcome = if cfg!(debug_assertions) {
            validate_subject(&value)
        } else {
            Ok(())
        };
        if let Err(err) = outcome {
            panic!("Subject {value:?} isn't valid {err:?}");
        }
        Self(value)
    }

    /// Returns the subject as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
impl Display for Subject {
    /// Formats the subject by delegating to the wrapped [`ByteString`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}
impl TryFrom<ByteString> for Subject {
type Error = SubjectValidateError;
fn try_from(value: ByteString) -> Result<Self, Self::Error> {
validate_subject(&value)?;
Ok(Self::from_dangerous_value(value))
}
}
impl TryFrom<String> for Subject {
type Error = SubjectValidateError;
fn try_from(value: String) -> Result<Self, Self::Error> {
validate_subject(&value)?;
Ok(Self::from_dangerous_value(value.into()))
}
}
impl From<Subject> for ByteString {
fn from(value: Subject) -> Self {
value.0
}
}
impl AsRef<[u8]> for Subject {
    /// Returns the bytes of the subject string.
    fn as_ref(&self) -> &[u8] {
        let text = self.as_str();
        text.as_bytes()
    }
}
impl AsRef<str> for Subject {
    /// Returns the subject as a string slice.
    fn as_ref(&self) -> &str {
        &self.0
    }
}
impl Deref for Subject {
    type Target = str;

    /// Dereferences to the subject's string slice.
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
impl Serialize for Subject {
    /// Serializes the subject as a plain string.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.as_str())
    }
}
impl<'de> Deserialize<'de> for Subject {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = ByteString::deserialize(deserializer)?;
s.try_into().map_err(de::Error::custom)
}
}
/// Reasons a string is rejected by subject validation.
///
/// `PartialEq`/`Eq` are only derived for test builds.
#[derive(Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub enum SubjectValidateError {
    /// The subject is the empty string.
    #[error("Subject is empty")]
    Empty,
    /// The subject is longer than 256 bytes.
    #[error("Subject is too long")]
    TooLong,
    /// The subject contains a whitespace character (as defined by
    /// `char::is_whitespace`).
    #[error("Subject contained an illegal whitespace character")]
    IllegalCharacter,
    /// One of the `.`-separated tokens is empty, e.g. because of a leading,
    /// trailing or doubled `.`.
    #[error("Subject contained a broken token")]
    BrokenToken,
    /// A `*` or `>` appears inside a longer token, or a `>` token is not the
    /// last token of the subject.
    #[error("Subject contained a broken wildcard")]
    BrokenWildcard,
}
/// Checks that `subject` is a well-formed subject.
///
/// The rules are applied in this order, and the first violation is reported:
///
/// 1. the subject is not empty;
/// 2. it is at most 256 bytes long;
/// 3. it contains no whitespace character (per `char::is_whitespace`);
/// 4. for each `.`-separated token, left to right: the token is not empty, a
///    `*` or `>` only appears as a whole token, and a `>` token is the last one.
///
/// # Errors
///
/// Returns the [`SubjectValidateError`] variant matching the first rule that
/// `subject` violates.
fn validate_subject(subject: &str) -> Result<(), SubjectValidateError> {
    /// Maximum accepted subject length, in bytes.
    const MAX_SUBJECT_LEN: usize = 256;

    if subject.is_empty() {
        return Err(SubjectValidateError::Empty);
    }
    if subject.len() > MAX_SUBJECT_LEN {
        return Err(SubjectValidateError::TooLong);
    }
    if subject.chars().any(char::is_whitespace) {
        return Err(SubjectValidateError::IllegalCharacter);
    }

    let mut tokens = subject.split('.').peekable();
    while let Some(token) = tokens.next() {
        // `split('.')` never yields a token containing a `.`, so leading,
        // trailing and consecutive dots all show up as an empty token.
        if token.is_empty() {
            return Err(SubjectValidateError::BrokenToken);
        }

        let broken_wildcard = match token {
            // `>` is only allowed as the final token.
            ">" => tokens.peek().is_some(),
            "*" => false,
            // Wildcard characters must not be mixed into a longer token.
            _ => token.contains(['*', '>']),
        };
        if broken_wildcard {
            return Err(SubjectValidateError::BrokenWildcard);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use bytestring::ByteString;

    use super::{Subject, SubjectValidateError};

    /// Well-formed subjects are accepted and round-trip unchanged.
    #[test]
    fn valid_subjects() {
        let subjects = [
            "cmd",
            "cmd.endpoint",
            "cmd.endpoint.detail",
            "cmd.*.detail",
            "cmd.*.*",
            "cmd.endpoint.>",
        ];
        for subject in subjects {
            let s = Subject::try_from(ByteString::from_static(subject)).unwrap();
            assert_eq!(subject, s.as_str());
        }
    }

    /// Pins the length boundary: 256 bytes is accepted, 257 bytes is rejected.
    #[test]
    fn length_limit() {
        let longest_valid = "a".repeat(256);
        let subject = Subject::try_from(ByteString::from(longest_valid)).unwrap();
        assert_eq!(256, subject.len());

        let shortest_invalid = "a".repeat(257);
        let err = Subject::try_from(ByteString::from(shortest_invalid)).unwrap_err();
        assert_eq!(SubjectValidateError::TooLong, err);
    }

    /// Each malformed subject is rejected with the expected error.
    #[test]
    fn invalid_subjects() {
        let subjects = [
            ("", SubjectValidateError::Empty),
            ("cmd ", SubjectValidateError::IllegalCharacter),
            ("cmd .endpoint", SubjectValidateError::IllegalCharacter),
            (" cmd", SubjectValidateError::IllegalCharacter),
            ("cmd.endpoint ", SubjectValidateError::IllegalCharacter),
            (
                "cmd.endpoint.detail ",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint.detail\r",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint.detail\n",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint.detail\t",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endp oint.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endp\roint.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endp\noint.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endp\toint.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint .detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint\r.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint\n.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (
                "cmd.endpoint\t.detail",
                SubjectValidateError::IllegalCharacter,
            ),
            (" ", SubjectValidateError::IllegalCharacter),
            ("\r", SubjectValidateError::IllegalCharacter),
            ("\n", SubjectValidateError::IllegalCharacter),
            ("\t", SubjectValidateError::IllegalCharacter),
            ("cmd..endpoint", SubjectValidateError::BrokenToken),
            (".cmd.endpoint", SubjectValidateError::BrokenToken),
            ("cmd.endpoint.", SubjectValidateError::BrokenToken),
            ("cmd.**", SubjectValidateError::BrokenWildcard),
            ("cmd.**.endpoint", SubjectValidateError::BrokenWildcard),
            ("cmd.a*.endpoint", SubjectValidateError::BrokenWildcard),
            ("cmd.*a.endpoint", SubjectValidateError::BrokenWildcard),
            ("cmd.>.endpoint", SubjectValidateError::BrokenWildcard),
            ("cmd.a>.endpoint", SubjectValidateError::BrokenWildcard),
            ("cmd.endpoint.a>", SubjectValidateError::BrokenWildcard),
            ("cmd.endpoint.>a", SubjectValidateError::BrokenWildcard),
        ];
        for (subject, expected_err) in subjects {
            let err = Subject::try_from(ByteString::from_static(subject)).unwrap_err();
            // Name the offending entry so a failure is easy to locate in the table.
            assert_eq!(expected_err, err, "subject: {subject:?}");
        }
    }
}