use serde::{Deserialize, Serialize};
/// Compile-time form of a pattern segment, backed by `'static` byte slices.
///
/// Being `Copy` with only `'static` data, it can presumably be stored in
/// constants for built-in patterns. Convert it into an owned `Segment` via
/// `From<&BuiltinSegment>`.
#[derive(Clone, Copy, Debug)]
pub(crate) enum BuiltinSegment {
    /// Fixed bytes, copied verbatim into `Segment::Literal`.
    Literal(&'static [u8]),
    /// A charset plus `min`/`max` length bounds, copied into `Segment::Variable`.
    Variable {
        charset: CharsetName,
        min: usize,
        max: usize,
    },
    /// Fixed value tagged with a charset; mirrors `Segment::Opaque`.
    // The lint allow suggests no built-in pattern constructs this variant yet;
    // it is kept so built-ins can express every `Segment` shape.
    #[allow(dead_code)]
    Opaque {
        value: &'static [u8],
        charset: CharsetName,
    },
}
/// Owned, runtime form of a pattern segment.
///
/// Built from a static `BuiltinSegment` via `From`, or from a user-supplied
/// `SegmentDef` via `Segment::from_def`. `PartialEq`/`Eq` are derived so
/// segments can be compared directly instead of destructured by hand.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum Segment {
    /// Fixed bytes.
    Literal(Vec<u8>),
    /// A charset plus `min`/`max` length bounds.
    Variable {
        charset: CharsetName,
        min: usize,
        max: usize,
    },
    /// Fixed value tagged with a charset (given explicitly, or inferred by
    /// `detect_charset_name` when built through `from_def`).
    Opaque {
        value: Vec<u8>,
        charset: CharsetName,
    },
}
impl From<&BuiltinSegment> for Segment {
fn from(b: &BuiltinSegment) -> Self {
match b {
BuiltinSegment::Literal(bytes) => Segment::Literal(bytes.to_vec()),
BuiltinSegment::Variable { charset, min, max } => Segment::Variable {
charset: *charset,
min: *min,
max: *max,
},
BuiltinSegment::Opaque { value, charset } => Segment::Opaque {
value: value.to_vec(),
charset: *charset,
},
}
}
}
impl Segment {
pub(crate) fn from_def(def: &SegmentDef) -> Result<Self, SegmentDefError> {
match def {
SegmentDef::Literal { value } => Ok(Segment::Literal(value.as_bytes().to_vec())),
SegmentDef::Variable { charset, min, max } => {
let charset_name = CharsetName::from_name(charset).ok_or_else(|| {
SegmentDefError::UnknownCharset {
index: 0,
name: charset.clone(),
}
})?;
Ok(Segment::Variable {
charset: charset_name,
min: *min,
max: *max,
})
}
SegmentDef::Opaque { value, charset } => {
let value_bytes = value.as_bytes().to_vec();
let charset_name = match charset {
Some(name) => CharsetName::from_name(name).ok_or_else(|| {
SegmentDefError::UnknownCharset {
index: 0,
name: name.clone(),
}
})?,
None => detect_charset_name(value_bytes.as_slice()),
};
Ok(Segment::Opaque {
value: value_bytes,
charset: charset_name,
})
}
}
}
pub(crate) fn try_to_def(&self) -> Result<SegmentDef, SegmentDefError> {
match self {
Segment::Literal(bytes) => Ok(SegmentDef::Literal {
value: std::str::from_utf8(bytes)
.map_err(|_| SegmentDefError::NonUtf8Bytes)?
.to_owned(),
}),
Segment::Variable { charset, min, max } => Ok(SegmentDef::Variable {
charset: charset.as_str().to_string(),
min: *min,
max: *max,
}),
Segment::Opaque { value, charset } => Ok(SegmentDef::Opaque {
value: std::str::from_utf8(value)
.map_err(|_| SegmentDefError::NonUtf8Bytes)?
.to_owned(),
charset: Some(charset.as_str().to_owned()),
}),
}
}
pub(crate) fn to_def(&self) -> SegmentDef {
self.try_to_def()
.expect("built-in segment bytes are always ASCII")
}
}
/// Details captured for a single pattern match.
///
/// NOTE(review): field meanings below are inferred from their names only;
/// confirm against the matcher that constructs this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct MatchCapture {
    /// Presumably the (likely exclusive) end offset of the match in the input.
    pub(crate) end: usize,
    /// Presumably the matched length of each variable segment, in segment order.
    pub(crate) variable_lengths: Vec<usize>,
}
/// Name of a built-in character set.
///
/// Serialises (via serde `snake_case`) to the same strings that
/// `CharsetName::as_str` returns and `CharsetName::from_name` accepts, e.g.
/// `url_safe_base64`. `CharsetName::resolve` maps a name to its `Charset`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum CharsetName {
/// `alphanumeric`; `detect_charset_name` picks it for ASCII letters and digits
/// that fit no narrower candidate.
Alphanumeric,
/// `url_safe_base64`; detected for ASCII alphanumerics plus `-` and `_`.
UrlSafeBase64,
/// `uppercase_alphanumeric`; detected for ASCII uppercase letters and digits.
UppercaseAlphanumeric,
/// `digits`; detected when every byte is an ASCII digit.
Digits,
/// `hex_lower`; detected when every byte is in `0-9` or `a-f`.
HexLower,
/// `wide`; the fallback `detect_charset_name` uses when no other set fits.
Wide,
}
impl CharsetName {
pub(crate) fn resolve(&self) -> &'static crate::fake::Charset {
match self {
CharsetName::Alphanumeric => crate::fake::alphanumeric_ref(),
CharsetName::UrlSafeBase64 => crate::fake::url_safe_base64_ref(),
CharsetName::UppercaseAlphanumeric => crate::fake::uppercase_alphanumeric_ref(),
CharsetName::Digits => crate::fake::digits_ref(),
CharsetName::HexLower => crate::fake::hex_lower_ref(),
CharsetName::Wide => crate::fake::wide_ref(),
}
}
pub(crate) fn from_name(name: &str) -> Option<Self> {
match name {
"alphanumeric" => Some(CharsetName::Alphanumeric),
"url_safe_base64" => Some(CharsetName::UrlSafeBase64),
"uppercase_alphanumeric" => Some(CharsetName::UppercaseAlphanumeric),
"digits" => Some(CharsetName::Digits),
"hex_lower" => Some(CharsetName::HexLower),
"wide" => Some(CharsetName::Wide),
_ => None,
}
}
pub(crate) fn as_str(&self) -> &'static str {
match self {
CharsetName::Alphanumeric => "alphanumeric",
CharsetName::UrlSafeBase64 => "url_safe_base64",
CharsetName::UppercaseAlphanumeric => "uppercase_alphanumeric",
CharsetName::Digits => "digits",
CharsetName::HexLower => "hex_lower",
CharsetName::Wide => "wide",
}
}
}
/// User-facing, serialisable description of one pattern segment.
///
/// Serialised as an internally tagged object whose `type` field holds the
/// lowercase variant name, e.g. `{"type":"literal","value":"sk-"}`. Check a
/// whole list with `validate_segment_defs`, then convert each entry with
/// `Segment::from_def`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum SegmentDef {
/// Fixed text; its UTF-8 bytes become a `Segment::Literal`.
Literal {
value: String,
},
/// A charset name plus length bounds (whether the bounds are inclusive is
/// decided by the matcher, not shown here — confirm there).
Variable {
/// Charset name as accepted by `CharsetName::from_name`, e.g. `"alphanumeric"`.
charset: String,
/// Lower length bound; `validate_segment_defs` requires it to be at least 1.
min: usize,
/// Upper length bound; `validate_segment_defs` requires `min <= max`.
max: usize,
},
/// A fixed value tagged with a charset.
Opaque {
value: String,
/// Optional charset name. When absent, `Segment::from_def` infers one from
/// `value` via `detect_charset_name`; `None` is omitted when serialising.
#[serde(default, skip_serializing_if = "Option::is_none")]
charset: Option<String>,
},
}
/// Checks a user-supplied segment list before it is turned into runtime segments.
///
/// Checks run in this order, and the first violation is returned:
/// 1. the list is non-empty;
/// 2. the first segment is a literal or opaque value of at least 2 bytes;
/// 3. each segment, in list order:
///    - a variable segment names a known charset and has `1 <= min <= max`;
///    - an opaque segment with an explicit charset names a known one;
/// 4. at least one variable segment is present.
///
/// # Errors
///
/// Returns the `SegmentDefError` variant for the first failed check.
/// Per-segment errors carry the zero-based `index` of the offending segment.
pub(crate) fn validate_segment_defs(defs: &[SegmentDef]) -> Result<(), SegmentDefError> {
    // The first segment must be fixed text. Per the error message, very short
    // prefixes cause excessive false-positive AC hits.
    let first_value = match defs.first() {
        None => return Err(SegmentDefError::EmptySegmentList),
        Some(SegmentDef::Variable { .. }) => return Err(SegmentDefError::FirstSegmentVariable),
        Some(SegmentDef::Literal { value }) | Some(SegmentDef::Opaque { value, .. }) => value,
    };
    if first_value.len() < 2 {
        return Err(SegmentDefError::FirstSegmentTooShort {
            len: first_value.len(),
        });
    }

    for (index, def) in defs.iter().enumerate() {
        match def {
            SegmentDef::Variable { charset, min, max } => {
                if CharsetName::from_name(charset).is_none() {
                    return Err(SegmentDefError::UnknownCharset {
                        index,
                        name: charset.clone(),
                    });
                }
                let (min, max) = (*min, *max);
                if min > max {
                    return Err(SegmentDefError::MinExceedsMax { index, min, max });
                }
                if min == 0 {
                    return Err(SegmentDefError::MinTooSmall { index });
                }
            }
            SegmentDef::Opaque {
                charset: Some(name),
                ..
            } => {
                if CharsetName::from_name(name).is_none() {
                    return Err(SegmentDefError::UnknownCharset {
                        index,
                        name: name.clone(),
                    });
                }
            }
            SegmentDef::Opaque { charset: None, .. } | SegmentDef::Literal { .. } => {}
        }
    }

    if defs
        .iter()
        .any(|def| matches!(def, SegmentDef::Variable { .. }))
    {
        Ok(())
    } else {
        Err(SegmentDefError::NoVariableSegment)
    }
}
/// Infers a charset for an opaque value from the bytes it contains.
///
/// Candidates are tried in a fixed priority order: `Digits`, `HexLower`,
/// `UppercaseAlphanumeric`, `Alphanumeric`, `UrlSafeBase64`. The first one
/// that covers every byte wins, and `Wide` is the fallback. An empty slice is
/// covered by every candidate and therefore yields `Digits`.
pub(crate) fn detect_charset_name(bytes: &[u8]) -> CharsetName {
    // Single pass: each flag records whether every byte seen so far belongs to
    // that candidate's alphabet.
    let mut digits = true;
    let mut hex_lower = true;
    let mut upper_alnum = true;
    let mut alnum = true;
    let mut url_safe = true;
    for &byte in bytes {
        digits &= byte.is_ascii_digit();
        hex_lower &= matches!(byte, b'0'..=b'9' | b'a'..=b'f');
        upper_alnum &= byte.is_ascii_uppercase() || byte.is_ascii_digit();
        alnum &= byte.is_ascii_alphanumeric();
        url_safe &= byte.is_ascii_alphanumeric() || byte == b'-' || byte == b'_';
    }

    if digits {
        CharsetName::Digits
    } else if hex_lower {
        CharsetName::HexLower
    } else if upper_alnum {
        CharsetName::UppercaseAlphanumeric
    } else if alnum {
        CharsetName::Alphanumeric
    } else if url_safe {
        CharsetName::UrlSafeBase64
    } else {
        CharsetName::Wide
    }
}
/// Error returned when a segment definition list is invalid, or a segment
/// cannot be converted between `Segment` and `SegmentDef`.
///
/// `index` fields are zero-based positions in the definition list as
/// reported by `validate_segment_defs`. `Segment::from_def` always reports
/// `index: 0` because it only sees one definition.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SegmentDefError {
/// No `SegmentDef::Variable` appears anywhere in the list.
#[error("segment list must contain at least one variable segment")]
NoVariableSegment,
/// A variable or opaque segment names a charset that `CharsetName::from_name`
/// does not recognise.
// NOTE(review): the valid-name list in this message is hard-coded; keep it in
// sync with the `CharsetName` variants.
#[error(
"unknown charset \"{name}\" in segment {index}; valid: alphanumeric, url_safe_base64, uppercase_alphanumeric, digits, hex_lower, wide (variable and opaque segments)"
)]
UnknownCharset {
/// Position of the offending segment.
index: usize,
/// The unrecognised charset name, as supplied.
name: String,
},
/// A variable segment's `min` is greater than its `max`.
#[error("segment {index}: min ({min}) must not exceed max ({max})")]
MinExceedsMax {
index: usize,
min: usize,
max: usize,
},
/// A variable segment has `min == 0`.
#[error("segment {index}: min must be at least 1")]
MinTooSmall {
index: usize,
},
/// The first segment is a variable segment; it must be literal or opaque.
#[error("first segment must be literal or opaque, not variable")]
FirstSegmentVariable,
/// The definition list is empty.
#[error("segment list must not be empty")]
EmptySegmentList,
/// The first segment's value is shorter than 2 bytes.
#[error(
"first segment value is {len} byte(s); minimum 2 bytes to avoid excessive false-positive AC hits"
)]
FirstSegmentTooShort {
/// Byte length of the first segment's value.
len: usize,
},
/// A segment's bytes are not valid UTF-8, so they cannot become a `String`.
#[error("segment value contains non-UTF-8 bytes")]
NonUtf8Bytes,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant must serialise to its snake_case name and parse back.
    /// (Previously `Wide` was missing from this table.)
    #[test]
    fn charset_name_serde_round_trip() {
        let names = [
            (CharsetName::Alphanumeric, "\"alphanumeric\""),
            (CharsetName::UrlSafeBase64, "\"url_safe_base64\""),
            (
                CharsetName::UppercaseAlphanumeric,
                "\"uppercase_alphanumeric\"",
            ),
            (CharsetName::Digits, "\"digits\""),
            (CharsetName::HexLower, "\"hex_lower\""),
            (CharsetName::Wide, "\"wide\""),
        ];
        for (variant, expected_json) in &names {
            let json = serde_json::to_string(variant).unwrap();
            assert_eq!(&json, expected_json);
            let parsed: CharsetName = serde_json::from_str(&json).unwrap();
            assert_eq!(&parsed, variant);
        }
    }

    /// `from_name` must accept exactly the strings produced by `as_str`.
    #[test]
    fn charset_name_from_name_inverts_as_str() {
        let variants = [
            CharsetName::Alphanumeric,
            CharsetName::UrlSafeBase64,
            CharsetName::UppercaseAlphanumeric,
            CharsetName::Digits,
            CharsetName::HexLower,
            CharsetName::Wide,
        ];
        for &variant in &variants {
            assert_eq!(CharsetName::from_name(variant.as_str()), Some(variant));
        }
        assert_eq!(CharsetName::from_name("bogus"), None);
        assert_eq!(CharsetName::from_name("Alphanumeric"), None);
    }

    #[test]
    fn segment_def_tagged_serde() {
        let lit = SegmentDef::Literal {
            value: "sk-".into(),
        };
        let var = SegmentDef::Variable {
            charset: "alphanumeric".into(),
            min: 32,
            max: 48,
        };
        let json_lit = serde_json::to_string(&lit).unwrap();
        assert!(json_lit.contains("\"type\":\"literal\""));
        assert!(json_lit.contains("\"value\":\"sk-\""));
        let json_var = serde_json::to_string(&var).unwrap();
        assert!(json_var.contains("\"type\":\"variable\""));
        assert!(json_var.contains("\"charset\":\"alphanumeric\""));
        let round: SegmentDef = serde_json::from_str(&json_lit).unwrap();
        assert_eq!(round, lit);
        let round: SegmentDef = serde_json::from_str(&json_var).unwrap();
        assert_eq!(round, var);
    }

    /// An opaque segment without a charset omits the field and reads back as `None`.
    #[test]
    fn segment_def_opaque_skips_missing_charset() {
        let opaque = SegmentDef::Opaque {
            value: "x".into(),
            charset: None,
        };
        let json = serde_json::to_string(&opaque).unwrap();
        assert!(json.contains("\"type\":\"opaque\""));
        assert!(!json.contains("charset"));
        let round: SegmentDef = serde_json::from_str(&json).unwrap();
        assert_eq!(round, opaque);
    }

    #[test]
    fn validate_segment_defs_rejects_no_variable() {
        let defs = vec![SegmentDef::Literal {
            value: "prefix".into(),
        }];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(err.to_string().contains("at least one variable segment"));
    }

    #[test]
    fn validate_segment_defs_rejects_empty_list() {
        let err = validate_segment_defs(&[]).unwrap_err();
        assert!(matches!(err, SegmentDefError::EmptySegmentList));
    }

    #[test]
    fn validate_segment_defs_rejects_variable_first() {
        let defs = vec![SegmentDef::Variable {
            charset: "digits".into(),
            min: 1,
            max: 4,
        }];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(matches!(err, SegmentDefError::FirstSegmentVariable));
    }

    #[test]
    fn validate_segment_defs_rejects_unknown_charset() {
        let defs = vec![
            SegmentDef::Literal {
                value: "prefix_".into(),
            },
            SegmentDef::Variable {
                charset: "bogus".into(),
                min: 1,
                max: 10,
            },
        ];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(err.to_string().contains("unknown charset \"bogus\""));
    }

    #[test]
    fn validate_segment_defs_rejects_unknown_opaque_charset() {
        let defs = vec![
            SegmentDef::Opaque {
                value: "sk".into(),
                charset: Some("bogus".into()),
            },
            SegmentDef::Variable {
                charset: "digits".into(),
                min: 1,
                max: 4,
            },
        ];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(matches!(
            err,
            SegmentDefError::UnknownCharset { index: 0, .. }
        ));
    }

    #[test]
    fn validate_segment_defs_rejects_bad_bounds() {
        let prefix = SegmentDef::Literal { value: "sk".into() };
        let defs = vec![
            prefix.clone(),
            SegmentDef::Variable {
                charset: "digits".into(),
                min: 5,
                max: 4,
            },
        ];
        assert!(matches!(
            validate_segment_defs(&defs).unwrap_err(),
            SegmentDefError::MinExceedsMax {
                index: 1,
                min: 5,
                max: 4
            }
        ));
        let defs = vec![
            prefix,
            SegmentDef::Variable {
                charset: "digits".into(),
                min: 0,
                max: 4,
            },
        ];
        assert!(matches!(
            validate_segment_defs(&defs).unwrap_err(),
            SegmentDefError::MinTooSmall { index: 1 }
        ));
    }

    #[test]
    fn builtin_to_owned_conversion() {
        let builtin_lit = BuiltinSegment::Literal(b"sk-ant-");
        let owned = Segment::from(&builtin_lit);
        match owned {
            Segment::Literal(v) => assert_eq!(v, b"sk-ant-"),
            _ => panic!("expected Literal"),
        }
        let builtin_var = BuiltinSegment::Variable {
            charset: CharsetName::Alphanumeric,
            min: 10,
            max: 20,
        };
        let owned = Segment::from(&builtin_var);
        match owned {
            Segment::Variable { charset, min, max } => {
                assert_eq!(charset, CharsetName::Alphanumeric);
                assert_eq!(min, 10);
                assert_eq!(max, 20);
            }
            _ => panic!("expected Variable"),
        }
    }

    /// Without an explicit charset, `from_def` infers one; `to_def` then makes it explicit.
    #[test]
    fn from_def_infers_opaque_charset() {
        let def = SegmentDef::Opaque {
            value: "1234".into(),
            charset: None,
        };
        match Segment::from_def(&def).unwrap() {
            Segment::Opaque { value, charset } => {
                assert_eq!(value, b"1234");
                assert_eq!(charset, CharsetName::Digits);
            }
            other => panic!("expected Opaque, got {:?}", other),
        }

        let round = Segment::from_def(&SegmentDef::Opaque {
            value: "abc".into(),
            charset: None,
        })
        .unwrap()
        .to_def();
        assert_eq!(
            round,
            SegmentDef::Opaque {
                value: "abc".into(),
                charset: Some("hex_lower".into()),
            }
        );
    }

    #[test]
    fn detect_charset_name_uses_priority_order() {
        assert_eq!(detect_charset_name(b""), CharsetName::Digits);
        assert_eq!(detect_charset_name(b"0123"), CharsetName::Digits);
        assert_eq!(detect_charset_name(b"deadbeef01"), CharsetName::HexLower);
        assert_eq!(
            detect_charset_name(b"ABC123"),
            CharsetName::UppercaseAlphanumeric
        );
        assert_eq!(detect_charset_name(b"aBc123xyz"), CharsetName::Alphanumeric);
        assert_eq!(detect_charset_name(b"ab-c_d"), CharsetName::UrlSafeBase64);
        assert_eq!(detect_charset_name(b"a+b/c="), CharsetName::Wide);
    }

    #[test]
    fn validate_segment_defs_rejects_empty_first_literal() {
        let defs = vec![
            SegmentDef::Literal {
                value: String::new(),
            },
            SegmentDef::Variable {
                charset: "alphanumeric".into(),
                min: 8,
                max: 8,
            },
        ];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(matches!(
            err,
            SegmentDefError::FirstSegmentTooShort { len: 0 }
        ));
        assert!(err.to_string().contains("0 byte"));
    }

    #[test]
    fn validate_segment_defs_rejects_one_byte_first_literal() {
        let defs = vec![
            SegmentDef::Literal { value: "A".into() },
            SegmentDef::Variable {
                charset: "alphanumeric".into(),
                min: 8,
                max: 8,
            },
        ];
        let err = validate_segment_defs(&defs).unwrap_err();
        assert!(matches!(
            err,
            SegmentDefError::FirstSegmentTooShort { len: 1 }
        ));
    }

    #[test]
    fn validate_segment_defs_accepts_two_byte_first_literal() {
        let defs = vec![
            SegmentDef::Literal { value: "sk".into() },
            SegmentDef::Variable {
                charset: "alphanumeric".into(),
                min: 8,
                max: 8,
            },
        ];
        assert!(validate_segment_defs(&defs).is_ok());
    }
}