use super::bidi::{bidi_class, BidiClass};
use super::generated::idna as gen;
use super::generated::properties::JoiningType;
use super::normalize::{canonical_combining_class, nfc};
use super::predicates::{is_mark, joining_type};
use alloc::string::String;
use alloc::vec::Vec;
/// Reasons a domain name is rejected by `to_ascii` or `to_unicode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Error {
/// A code point's IDNA status is not one of the three accepted values
/// (kept as-is, replaced by its mapping, or dropped) during mapping.
Disallowed,
/// A label is empty, has too many code points, contains a code point whose
/// status is non-zero, is not in NFC, or the label/domain exceeds the
/// 63-/253-byte ASCII limits enforced by `to_ascii`.
InvalidLabel,
/// A Punycode payload could not be decoded or encoded, or an `xn--` label
/// does not equal the re-encoding of its decoded form.
Punycode,
/// A label has `-` at both index 2 and index 3, or starts or ends with `-`.
Hyphen,
/// A label in a bidirectional domain fails the checks in `check_bidi`.
Bidi,
/// U+200C or U+200D appears without the context required by `check_joiners`.
ContextJ,
/// The first code point of a label satisfies `is_mark`.
LeadingMark,
}
// Bootstring parameters used by the Punycode encoder and decoder below.
// The values are the ones RFC 3492 section 5 assigns to Punycode.
/// Number of distinct digit values (`a`-`z` are 0-25, `0`-`9` are 26-35).
const BASE: u32 = 36;
/// Lower clamp for the per-position threshold `t`.
const TMIN: u32 = 1;
/// Upper clamp for the per-position threshold `t`.
const TMAX: u32 = 26;
/// Skew term in the final expression of `adapt`.
const SKEW: u32 = 38;
/// Divisor applied to `delta` by `adapt` on its first invocation.
const DAMP: u32 = 700;
/// Bias in effect before the first call to `adapt`.
const INITIAL_BIAS: u32 = 72;
/// Starting value of `n`; code points below it are copied literally by the encoder.
const INITIAL_N: u32 = 128;
/// Bias adaptation step of the Punycode algorithm (RFC 3492 section 6.1).
///
/// * `delta` - the delta that was just encoded or decoded.
/// * `num_points` - number of code points handled so far, including the
///   current one; must be non-zero because it is used as a divisor.
/// * `first_time` - `true` for the very first delta, which is damped by
///   `DAMP` instead of being halved.
///
/// Returns the bias to use for the next delta.
fn adapt(mut delta: u32, num_points: u32, first_time: bool) -> u32 {
    // Loop-invariant cut-off above which delta keeps being scaled down.
    let cutoff = ((BASE - TMIN) * TMAX) / 2;
    let damping = if first_time { DAMP } else { 2 };

    delta /= damping;
    delta += delta / num_points;

    let mut k = 0;
    while delta > cutoff {
        delta /= BASE - TMIN;
        k += BASE;
    }
    k + (BASE - TMIN + 1) * delta / (delta + SKEW)
}
/// Converts a Punycode digit value to its lowercase basic code point.
///
/// Values `0..=25` map to `'a'..='z'`; values from 26 upward are offset from
/// `'0'`, so `26..=35` map to `'0'..='9'`. Callers are expected to pass a
/// value below 36.
fn digit_to_basic(d: u32) -> char {
    let (first, offset) = match d {
        0..=25 => (b'a', d),
        _ => (b'0', d - 26),
    };
    char::from(first + offset as u8)
}
/// Converts a basic code point to its Punycode digit value.
///
/// ASCII letters are accepted in either case and yield `0..=25`; ASCII
/// decimal digits yield `26..=35`. Anything else returns `None`.
fn basic_to_digit(c: char) -> Option<u32> {
    let code = c as u32;
    if c.is_ascii_lowercase() {
        Some(code - 'a' as u32)
    } else if c.is_ascii_uppercase() {
        Some(code - 'A' as u32)
    } else if c.is_ascii_digit() {
        Some(code - '0' as u32 + 26)
    } else {
        None
    }
}
/// Encodes a sequence of code points as a Punycode payload (without the
/// `xn--` prefix), following the encoding procedure of RFC 3492 section 6.3.
///
/// ASCII code points are copied to the front of the output, followed by a
/// `-` delimiter when there is at least one of them. The remaining code
/// points are then emitted as variable-length integers in increasing
/// code-point order.
///
/// Returns `None` if the delta arithmetic overflows `u32`.
fn punycode_encode(input: &[char]) -> Option<String> {
    /// Appends `value` as a generalized variable-length integer, then leaves
    /// the output ready for the next delta.
    fn push_delta(out: &mut String, value: u32, bias: u32) {
        let mut q = value;
        let mut k = BASE;
        loop {
            let t = k.saturating_sub(bias).clamp(TMIN, TMAX);
            if q < t {
                break;
            }
            out.push(digit_to_basic(t + (q - t) % (BASE - t)));
            q = (q - t) / (BASE - t);
            k += BASE;
        }
        out.push(digit_to_basic(q));
    }

    // Basic (ASCII) code points go out literally, in their original order.
    let mut encoded: String = input
        .iter()
        .copied()
        .filter(|&c| (c as u32) < 0x80)
        .collect();
    // The literal prefix is pure ASCII, so its byte length is its code-point count.
    let basic_count = encoded.len() as u32;
    if basic_count > 0 {
        encoded.push('-');
    }

    let input_len = input.len() as u32;
    let mut current = INITIAL_N;
    let mut delta: u32 = 0;
    let mut bias = INITIAL_BIAS;
    let mut done = basic_count;

    while done < input_len {
        // Smallest code point not yet handled.
        let next = input
            .iter()
            .map(|&c| c as u32)
            .filter(|&cp| cp >= current)
            .min()?;
        delta = delta.checked_add((next - current).checked_mul(done + 1)?)?;
        current = next;

        for cp in input.iter().map(|&c| c as u32) {
            if cp < current {
                delta = delta.checked_add(1)?;
            } else if cp == current {
                push_delta(&mut encoded, delta, bias);
                bias = adapt(delta, done + 1, done == basic_count);
                delta = 0;
                done += 1;
            }
        }

        delta += 1;
        current += 1;
    }
    Some(encoded)
}
/// Upper bound on the number of code points in a single label.
/// `punycode_decode` gives up instead of growing its output past this size,
/// and `to_ascii` rejects any label that is longer.
const MAX_LABEL_CODE_POINTS: usize = 63;
fn punycode_decode(input: &str) -> Option<Vec<char>> {
let mut output: Vec<u32> = Vec::new();
let mut n = INITIAL_N;
let mut i: u32 = 0;
let mut bias = INITIAL_BIAS;
let bytes: Vec<char> = input.chars().collect();
let (basic_end, has_basic) = match bytes.iter().rposition(|&c| c == '-') {
Some(p) => (p, true),
None => (0, false),
};
if has_basic {
for &c in &bytes[..basic_end] {
if (c as u32) >= 0x80 {
return None;
}
if output.len() >= MAX_LABEL_CODE_POINTS {
return None;
}
output.push(c as u32);
}
}
let mut pos = if has_basic { basic_end + 1 } else { 0 };
while pos < bytes.len() {
let old_i = i;
let mut w = 1;
let mut k = BASE;
loop {
let c = *bytes.get(pos)?;
pos += 1;
let digit = basic_to_digit(c)?;
i = i.checked_add(digit.checked_mul(w)?)?;
let t = k.saturating_sub(bias).clamp(TMIN, TMAX);
if digit < t {
break;
}
w = w.checked_mul(BASE - t)?;
k += BASE;
}
if output.len() >= MAX_LABEL_CODE_POINTS {
return None;
}
let len = output.len() as u32 + 1;
bias = adapt(i - old_i, len, old_i == 0);
n = n.checked_add(i / len)?;
i %= len;
output.insert(i as usize, n);
i += 1;
}
output.into_iter().map(char::from_u32).collect()
}
fn map_and_normalize(domain: &str) -> Result<String, Error> {
let mut mapped: Vec<char> = Vec::new();
for c in domain.chars() {
match gen::idna_status(c as u32) {
0 => mapped.push(c), 1 => mapped.extend_from_slice(idna_map(c)), 2 => {} _ => return Err(Error::Disallowed), }
}
Ok(nfc(mapped.into_iter()).collect())
}
/// Returns the replacement sequence that `gen::idna_mapped` holds for `c`,
/// or an empty slice when it has none.
fn idna_map(c: char) -> &'static [char] {
    match gen::idna_mapped(c as u32) {
        Some(replacement) => replacement,
        None => &[],
    }
}
fn validate_label(label: &[char], is_bidi_domain: bool) -> Result<(), Error> {
if label.is_empty() {
return Err(Error::InvalidLabel);
}
for &c in label {
if gen::idna_status(c as u32) != 0 {
return Err(Error::InvalidLabel);
}
}
if nfc(label.iter().copied()).collect::<Vec<char>>() != label {
return Err(Error::InvalidLabel);
}
if label.len() >= 4 && label[2] == '-' && label[3] == '-' {
return Err(Error::Hyphen);
}
if label.first() == Some(&'-') || label.last() == Some(&'-') {
return Err(Error::Hyphen);
}
if is_mark(label[0]) {
return Err(Error::LeadingMark);
}
check_joiners(label)?;
if is_bidi_domain {
check_bidi(label)?;
}
Ok(())
}
/// Verifies the context of every U+200C (ZWNJ) and U+200D (ZWJ) in `label`.
///
/// Either joiner is accepted when the preceding code point has canonical
/// combining class 9. U+200C is additionally accepted when
/// `has_joining_context` holds at its position.
///
/// # Errors
///
/// Returns `Error::ContextJ` for the first joiner that satisfies neither rule.
fn check_joiners(label: &[char]) -> Result<(), Error> {
    const ZWNJ: char = '\u{200C}';
    const ZWJ: char = '\u{200D}';

    for (idx, &ch) in label.iter().enumerate() {
        if ch != ZWNJ && ch != ZWJ {
            continue;
        }
        let follows_ccc9 = idx > 0 && canonical_combining_class(label[idx - 1]) == 9;
        if follows_ccc9 {
            continue;
        }
        // Only ZWNJ has the joining-type fallback rule.
        if ch == ZWNJ && has_joining_context(label, idx) {
            continue;
        }
        return Err(Error::ContextJ);
    }
    Ok(())
}
/// Tests the joining-type context around the code point at `pos`.
///
/// Skipping `Transparent` code points, the nearest code point before `pos`
/// must be `LeftJoining` or `DualJoining`, and the nearest one after `pos`
/// must be `RightJoining` or `DualJoining`. Returns `false` if either side
/// has no such neighbour.
fn has_joining_context(label: &[char], pos: usize) -> bool {
    let preceding = label[..pos]
        .iter()
        .rev()
        .map(|&c| joining_type(c))
        .find(|jt| !matches!(jt, JoiningType::Transparent));
    let left_ok = matches!(
        preceding,
        Some(JoiningType::LeftJoining | JoiningType::DualJoining)
    );
    if !left_ok {
        return false;
    }

    let following = label[pos + 1..]
        .iter()
        .map(|&c| joining_type(c))
        .find(|jt| !matches!(jt, JoiningType::Transparent));
    matches!(
        following,
        Some(JoiningType::RightJoining | JoiningType::DualJoining)
    )
}
fn check_bidi(label: &[char]) -> Result<(), Error> {
let first = bidi_class(label[0]);
if !matches!(first, BidiClass::L | BidiClass::R | BidiClass::AL) {
return Err(Error::Bidi);
}
if first.is_rtl() {
for &c in label {
match bidi_class(c) {
BidiClass::R
| BidiClass::AL
| BidiClass::AN
| BidiClass::EN
| BidiClass::ES
| BidiClass::CS
| BidiClass::ET
| BidiClass::ON
| BidiClass::BN
| BidiClass::NSM => {}
_ => return Err(Error::Bidi),
}
}
match last_non_nsm(label) {
Some(BidiClass::R | BidiClass::AL | BidiClass::EN | BidiClass::AN) => {}
_ => return Err(Error::Bidi),
}
let has_en = label.iter().any(|&c| bidi_class(c) == BidiClass::EN);
let has_an = label.iter().any(|&c| bidi_class(c) == BidiClass::AN);
if has_en && has_an {
return Err(Error::Bidi);
}
} else {
for &c in label {
match bidi_class(c) {
BidiClass::L
| BidiClass::EN
| BidiClass::ES
| BidiClass::CS
| BidiClass::ET
| BidiClass::ON
| BidiClass::BN
| BidiClass::NSM => {}
_ => return Err(Error::Bidi),
}
}
match last_non_nsm(label) {
Some(BidiClass::L | BidiClass::EN) => {}
_ => return Err(Error::Bidi),
}
}
Ok(())
}
/// Returns the bidi class of the last code point in `label` whose class is
/// not `NSM`, or `None` if there is no such code point.
fn last_non_nsm(label: &[char]) -> Option<BidiClass> {
    for &c in label.iter().rev() {
        let class = bidi_class(c);
        if class != BidiClass::NSM {
            return Some(class);
        }
    }
    None
}
pub fn to_ascii(domain: &str) -> Result<String, Error> {
let processed = map_and_normalize(domain)?;
let labels: Vec<&str> = processed.split('.').collect();
struct LabelInfo<'a> {
original: &'a str,
is_a_label: bool,
ulabel: Vec<char>,
}
let mut infos: Vec<LabelInfo> = Vec::with_capacity(labels.len());
for label in &labels {
if label.is_empty() {
return Err(Error::InvalidLabel); }
let (is_a_label, ulabel): (bool, Vec<char>) = match label.strip_prefix("xn--") {
Some(rest) => (true, punycode_decode(rest).ok_or(Error::Punycode)?),
None => (false, label.chars().collect()),
};
if ulabel.len() > MAX_LABEL_CODE_POINTS {
return Err(Error::InvalidLabel); }
infos.push(LabelInfo {
original: label,
is_a_label,
ulabel,
});
}
let is_bidi_domain = infos.iter().any(|info| {
info.ulabel
.iter()
.any(|&c| matches!(bidi_class(c), BidiClass::R | BidiClass::AL | BidiClass::AN))
});
let mut out: Vec<String> = Vec::with_capacity(infos.len());
for info in &infos {
validate_label(&info.ulabel, is_bidi_domain)?;
let ascii = if info.is_a_label {
let encoded = punycode_encode(&info.ulabel).ok_or(Error::Punycode)?;
let mut canonical = String::from("xn--");
canonical.push_str(&encoded);
if !canonical.eq_ignore_ascii_case(info.original) {
return Err(Error::Punycode);
}
String::from(info.original)
} else if info.ulabel.iter().all(char::is_ascii) {
String::from(info.original)
} else {
let encoded = punycode_encode(&info.ulabel).ok_or(Error::Punycode)?;
let mut l = String::from("xn--");
l.push_str(&encoded);
l
};
if ascii.is_empty() || ascii.len() > 63 {
return Err(Error::InvalidLabel);
}
out.push(ascii);
}
let result = out.join(".");
if result.is_empty() {
return Err(Error::InvalidLabel);
}
if result.len() > 253 {
return Err(Error::InvalidLabel);
}
Ok(result)
}
pub fn to_unicode(domain: &str) -> Result<String, Error> {
let processed = map_and_normalize(domain)?;
let mut out: Vec<String> = Vec::new();
for label in processed.split('.') {
if let Some(rest) = label.strip_prefix("xn--") {
let decoded = punycode_decode(rest).ok_or(Error::Punycode)?;
out.push(decoded.into_iter().collect());
} else {
out.push(String::from(label));
}
}
Ok(out.join("."))
}
// Regression tests for the IDNA pipeline and the Punycode codec.
#[cfg(test)]
mod tests {
use super::*;
// U+00DF and U+03C2 must be kept rather than mapped away; the expected
// A-label for "faß" encodes the sharp s itself.
#[test]
fn deviation_chars_remain_valid() {
assert_eq!(to_ascii("faß.de").unwrap(), "xn--fa-hia.de");
assert!(to_ascii("ςoς.example").is_ok());
}
// A label starting with U+05D0 (Hebrew) followed by Latin 'a' fails the bidi checks.
#[test]
fn rtl_bidi_violation_rejected() {
assert_eq!(to_ascii("\u{05D0}a"), Err(Error::Bidi));
}
// U+200C between two Latin letters has neither a preceding ccc-9 mark nor
// a joining context.
#[test]
fn bare_zwnj_rejected() {
assert_eq!(to_ascii("a\u{200C}b.example"), Err(Error::ContextJ));
}
// U+200C directly after U+094D is accepted by the combining-class-9 rule
// in `check_joiners`.
#[test]
fn virama_zwnj_accepted() {
assert!(to_ascii("\u{0915}\u{094D}\u{200C}.example").is_ok());
}
// An `xn--` label must round-trip through decode + encode; a payload with
// trailing garbage is rejected while the canonical one is returned unchanged.
#[test]
fn non_canonical_xn_label_rejected() {
let canonical = "bcher-kva";
let decoded = punycode_decode(canonical).unwrap();
assert_eq!(decoded, alloc::vec!['b', 'ü', 'c', 'h', 'e', 'r']);
assert_eq!(punycode_encode(&decoded).unwrap(), canonical);
// Sanity check: a different code-point sequence encodes differently.
let other = punycode_encode(&alloc::vec!['b', 'ü', 'c', 'h', 'e', 'r', 'ÿ']).unwrap();
assert_ne!(other, canonical);
let reencoded = punycode_encode(&decoded).unwrap();
assert_ne!(other, reencoded);
assert_eq!(to_ascii("xn--bcher-kva0.example"), Err(Error::Punycode));
assert_eq!(
to_ascii("xn--bcher-kva.example").unwrap(),
"xn--bcher-kva.example"
);
}
// A label may not begin with a combining mark (U+0300).
#[test]
fn leading_combining_mark_rejected() {
assert_eq!(to_ascii("\u{0300}a.example"), Err(Error::LeadingMark));
}
// Hyphens at index 2 and 3, at the start, or at the end of a label are rejected.
#[test]
fn double_hyphen_positions_rejected() {
assert_eq!(to_ascii("ab--c.example"), Err(Error::Hyphen));
assert_eq!(to_ascii("-abc.example"), Err(Error::Hyphen));
assert_eq!(to_ascii("abc-.example"), Err(Error::Hyphen));
}
// Empty labels are rejected wherever they appear, including the one
// produced by a trailing dot (ASCII or U+3002 after mapping).
#[test]
fn verify_dns_length_rejects_empty_labels() {
assert_eq!(to_ascii("example.com."), Err(Error::InvalidLabel));
assert_eq!(to_ascii("鱊。"), Err(Error::InvalidLabel));
assert_eq!(to_ascii("a..b"), Err(Error::InvalidLabel));
assert_eq!(to_ascii("example.com").unwrap(), "example.com");
}
// `to_unicode` reports an undecodable `xn--` label as an error instead of
// passing it through.
#[test]
fn to_unicode_surfaces_decode_failure() {
assert!(to_unicode("xn--ll.example").is_err());
assert_eq!(to_unicode("xn--bcher-kva.de").unwrap(), "bücher.de");
}
// A short payload whose digits each insert one more code point must be cut
// off by the output cap rather than decoded in full.
#[test]
fn punycode_decode_rejects_over_long_output() {
let bomb = alloc::format!("td{}", "a".repeat(64));
assert!(punycode_decode(&bomb).is_none());
assert_eq!(
punycode_decode("bcher-kva"),
Some(alloc::vec!['b', 'ü', 'c', 'h', 'e', 'r'])
);
}
// Whatever decodes successfully never exceeds the code-point cap.
#[test]
fn punycode_decode_caps_output_length() {
for payload in ["wgv71a119e", "bcher-kva", "fa-hia"] {
if let Some(decoded) = punycode_decode(payload) {
assert!(decoded.len() <= MAX_LABEL_CODE_POINTS);
}
}
}
// Ten copies of the label joined by dots exceed the 253-byte domain limit
// (asserted below) and must be rejected.
#[test]
fn to_ascii_rejects_over_long_domain() {
let label = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; let mut domain = String::new();
for i in 0..10 {
if i > 0 {
domain.push('.');
}
domain.push_str(label);
}
assert!(domain.len() > 253);
assert_eq!(to_ascii(&domain), Err(Error::InvalidLabel));
}
// Eight copies of the same label stay within the limit and are accepted.
#[test]
fn to_ascii_accepts_domain_at_limit() {
let label = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; let mut domain = String::new();
for i in 0..8 {
if i > 0 {
domain.push('.');
}
domain.push_str(label);
}
assert!(domain.len() <= 253);
assert!(to_ascii(&domain).is_ok());
}
}