1use std::fmt;
6
7use crate::ProtoError;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub struct Dn {
11 pub rdns: Vec<Rdn>,
12}
13
/// A relative distinguished name: one or more `(attribute type, value)` pairs.
///
/// In string form the pairs are joined with `+`, e.g. `cn=a+uid=b`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Rdn {
    /// Attribute type / value pairs in textual order.
    pub components: Vec<(String, String)>,
}
18
19impl Dn {
20 pub fn parse(input: &str) -> Result<Self, ProtoError> {
21 let input = input.trim();
22 if input.is_empty() {
23 return Ok(Dn { rdns: Vec::new() });
24 }
25
26 let mut rdns = Vec::new();
27 let mut remaining = input;
28 loop {
29 let (rdn, rest) = parse_rdn(remaining)?;
30 rdns.push(rdn);
31 if rest.is_empty() {
32 break;
33 }
34 if let Some(r) = rest.strip_prefix(',') {
35 remaining = r;
36 } else {
37 return Err(ProtoError::Protocol(format!(
38 "expected ',' or end of DN, got {:?}",
39 &rest[..rest.len().min(10)]
40 )));
41 }
42 }
43 Ok(Dn { rdns })
44 }
45
46 pub fn is_empty(&self) -> bool {
47 self.rdns.is_empty()
48 }
49}
50
51fn parse_rdn(input: &str) -> Result<(Rdn, &str), ProtoError> {
52 let mut components = Vec::new();
53 let mut remaining = input;
54 loop {
55 let (attr, value, rest) = parse_ava(remaining)?;
56 components.push((attr, value));
57 if let Some(r) = rest.strip_prefix('+') {
58 remaining = r;
59 } else {
60 return Ok((Rdn { components }, rest));
61 }
62 }
63}
64
65fn parse_ava(input: &str) -> Result<(String, String, &str), ProtoError> {
66 let limit = find_unescaped_separator(input);
68 let eq_pos = input[..limit]
69 .find('=')
70 .ok_or_else(|| ProtoError::Protocol("expected '=' in attribute value assertion".into()))?;
71 let attr = input[..eq_pos].trim().to_string();
72 if attr.is_empty() {
73 return Err(ProtoError::Protocol("empty attribute type".into()));
74 }
75 let rest = &input[eq_pos + 1..];
76
77 if let Some(hex_rest) = rest.strip_prefix('#') {
78 let end = hex_rest.find([',', '+']).unwrap_or(hex_rest.len());
80 let hex = &hex_rest[..end];
81 if hex.is_empty() || hex.len() % 2 != 0 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
82 return Err(ProtoError::Protocol(
83 "invalid hex-string in DN value: expected even number of hex digits after '#'"
84 .into(),
85 ));
86 }
87 let value = format!("#{hex}");
88 Ok((attr, value, &hex_rest[end..]))
89 } else if let Some(after_quote) = rest.strip_prefix('"') {
90 let end = after_quote
92 .find('"')
93 .ok_or_else(|| ProtoError::Protocol("unterminated quoted string in DN".into()))?;
94 let value = after_quote[..end].to_string();
95 Ok((attr, value, &after_quote[end + 1..]))
96 } else {
97 let (value, rest) = parse_dn_value(rest)?;
98 Ok((attr, value, rest))
99 }
100}
101
/// Returns the byte offset of the first `,` or `+` in `input` that is not escaped
/// by a preceding `\`, or `input.len()` if there is none.
///
/// Quoted strings are not recognised here; a separator inside `"..."` is still
/// reported.
fn find_unescaped_separator(input: &str) -> usize {
    let mut scan = input.char_indices();
    while let Some((offset, ch)) = scan.next() {
        match ch {
            ',' | '+' => return offset,
            // Skip whatever the backslash escapes. For a `\XX` hex pair only the first
            // digit is skipped here; the second is an ordinary character, never a
            // separator, so the result is the same.
            '\\' => {
                scan.next();
            }
            _ => {}
        }
    }
    input.len()
}
126
127fn parse_dn_value(input: &str) -> Result<(String, &str), ProtoError> {
128 let mut out = String::new();
129 let bytes = input.as_bytes();
130 let mut i = 0;
131 let mut last_non_trimmable = 0;
134
135 while i < bytes.len() {
136 match bytes[i] {
137 b',' | b'+' => break,
138 b'\\' => {
139 i += 1;
140 if i >= bytes.len() {
141 break;
142 }
143 if i + 1 < bytes.len()
145 && bytes[i].is_ascii_hexdigit()
146 && bytes[i + 1].is_ascii_hexdigit()
147 && let Ok(byte) =
148 u8::from_str_radix(std::str::from_utf8(&bytes[i..i + 2]).unwrap_or(""), 16)
149 {
150 let mut raw = vec![byte];
152 i += 2;
153 while i + 2 < bytes.len()
154 && bytes[i] == b'\\'
155 && bytes[i + 1].is_ascii_hexdigit()
156 && bytes[i + 2].is_ascii_hexdigit()
157 {
158 if let Ok(b) = u8::from_str_radix(
159 std::str::from_utf8(&bytes[i + 1..i + 3]).unwrap_or(""),
160 16,
161 ) {
162 if b & 0xC0 != 0x80 {
164 break;
165 }
166 raw.push(b);
167 i += 3;
168 } else {
169 break;
170 }
171 }
172 let decoded = String::from_utf8(raw).map_err(|e| {
173 ProtoError::Protocol(format!("invalid UTF-8 in DN value: {e}"))
174 })?;
175 out.push_str(&decoded);
176 last_non_trimmable = out.len();
177 continue;
178 }
179 out.push(bytes[i] as char);
181 last_non_trimmable = out.len();
182 i += 1;
183 }
184 _ => {
185 let ch = input[i..].chars().next().unwrap();
187 out.push(ch);
188 if ch != ' ' {
189 last_non_trimmable = out.len();
190 }
191 i += ch.len_utf8();
192 }
193 }
194 }
195
196 out.truncate(last_non_trimmable);
198 Ok((out, &input[i..]))
199}
200
/// Escapes `value` for use as an attribute value in a string-form DN.
///
/// A backslash is inserted before:
/// * any of `"`, `+`, `,`, `;`, `<`, `>` and `\`;
/// * a `#` in the first position;
/// * a space in the first or last position.
///
/// NUL is written as the hex escape `\00`. Every other character is copied unchanged.
pub fn escape_dn_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for (pos, ch) in value.char_indices() {
        let at_start = pos == 0;
        let at_end = pos + ch.len_utf8() == value.len();
        if ch == '\0' {
            escaped.push_str("\\00");
            continue;
        }
        let needs_backslash = matches!(ch, '"' | '+' | ',' | ';' | '<' | '>' | '\\')
            || (ch == '#' && at_start)
            || (ch == ' ' && (at_start || at_end));
        if needs_backslash {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
230
231impl fmt::Display for Dn {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 for (i, rdn) in self.rdns.iter().enumerate() {
234 if i > 0 {
235 f.write_str(",")?;
236 }
237 write!(f, "{rdn}")?;
238 }
239 Ok(())
240 }
241}
242
243impl fmt::Display for Rdn {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 for (i, (attr, value)) in self.components.iter().enumerate() {
246 if i > 0 {
247 f.write_str("+")?;
248 }
249 write!(f, "{}={}", attr, escape_dn_value(value))?;
250 }
251 Ok(())
252 }
253}
254
255impl std::str::FromStr for Dn {
256 type Err = ProtoError;
257 fn from_str(s: &str) -> Result<Self, Self::Err> {
258 Self::parse(s)
259 }
260}