Skip to main content

ldap_client_proto/
dn.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! RFC 4514 Distinguished Name parser.
4
5use std::fmt;
6
7use crate::ProtoError;
8
/// A parsed distinguished name: a sequence of relative distinguished names.
///
/// [`Dn::parse`] returns a value with an empty `rdns` vector for input that
/// is empty after trimming.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Dn {
    /// The RDNs in the order they appear in the string form; joined with `,`
    /// when the DN is formatted.
    pub rdns: Vec<Rdn>,
}
13
/// A single relative distinguished name: the attribute/value pairs that make
/// up one element of a [`Dn`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Rdn {
    /// `(attribute type, value)` pairs in the order they were parsed; joined
    /// with `+` when formatted. Values are escaped with `escape_dn_value`
    /// at formatting time, not when stored.
    pub components: Vec<(String, String)>,
}
18
19impl Dn {
20    pub fn parse(input: &str) -> Result<Self, ProtoError> {
21        let input = input.trim();
22        if input.is_empty() {
23            return Ok(Dn { rdns: Vec::new() });
24        }
25
26        let mut rdns = Vec::new();
27        let mut remaining = input;
28        loop {
29            let (rdn, rest) = parse_rdn(remaining)?;
30            rdns.push(rdn);
31            if rest.is_empty() {
32                break;
33            }
34            if let Some(r) = rest.strip_prefix(',') {
35                remaining = r;
36            } else {
37                return Err(ProtoError::Protocol(format!(
38                    "expected ',' or end of DN, got {:?}",
39                    &rest[..rest.len().min(10)]
40                )));
41            }
42        }
43        Ok(Dn { rdns })
44    }
45
46    pub fn is_empty(&self) -> bool {
47        self.rdns.is_empty()
48    }
49}
50
51fn parse_rdn(input: &str) -> Result<(Rdn, &str), ProtoError> {
52    let mut components = Vec::new();
53    let mut remaining = input;
54    loop {
55        let (attr, value, rest) = parse_ava(remaining)?;
56        components.push((attr, value));
57        if let Some(r) = rest.strip_prefix('+') {
58            remaining = r;
59        } else {
60            return Ok((Rdn { components }, rest));
61        }
62    }
63}
64
65fn parse_ava(input: &str) -> Result<(String, String, &str), ProtoError> {
66    // Only search for '=' before any unescaped ',' or '+' (those are RDN/AVA separators).
67    let limit = find_unescaped_separator(input);
68    let eq_pos = input[..limit]
69        .find('=')
70        .ok_or_else(|| ProtoError::Protocol("expected '=' in attribute value assertion".into()))?;
71    let attr = input[..eq_pos].trim().to_string();
72    if attr.is_empty() {
73        return Err(ProtoError::Protocol("empty attribute type".into()));
74    }
75    let rest = &input[eq_pos + 1..];
76
77    if let Some(hex_rest) = rest.strip_prefix('#') {
78        // Hex-encoded BER value (RFC 4514 §2.4)
79        let end = hex_rest.find([',', '+']).unwrap_or(hex_rest.len());
80        let hex = &hex_rest[..end];
81        if hex.is_empty() || hex.len() % 2 != 0 || !hex.bytes().all(|b| b.is_ascii_hexdigit()) {
82            return Err(ProtoError::Protocol(
83                "invalid hex-string in DN value: expected even number of hex digits after '#'"
84                    .into(),
85            ));
86        }
87        let value = format!("#{hex}");
88        Ok((attr, value, &hex_rest[end..]))
89    } else if let Some(after_quote) = rest.strip_prefix('"') {
90        // Quoted string (legacy, but we should parse it)
91        let end = after_quote
92            .find('"')
93            .ok_or_else(|| ProtoError::Protocol("unterminated quoted string in DN".into()))?;
94        let value = after_quote[..end].to_string();
95        Ok((attr, value, &after_quote[end + 1..]))
96    } else {
97        let (value, rest) = parse_dn_value(rest)?;
98        Ok((attr, value, rest))
99    }
100}
101
/// Returns the byte offset of the first `,` or `+` in `input` that is not
/// escaped by a backslash, or `input.len()` when there is none.
///
/// A backslash hides either the two hex digits that follow it or, failing
/// that, the single (possibly multi-byte) character that follows it.
fn find_unescaped_separator(input: &str) -> usize {
    let raw = input.as_bytes();
    let mut pos = 0;

    while let Some(&current) = raw.get(pos) {
        if current == b',' || current == b'+' {
            return pos;
        }
        if current != b'\\' {
            pos += 1;
            continue;
        }

        // Step over the backslash, then over whatever it escapes.
        pos += 1;
        let escapes_hex_pair = matches!(
            raw.get(pos..pos + 2),
            Some([hi, lo]) if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
        );
        if escapes_hex_pair {
            pos += 2;
        } else if let Some(escaped) = input[pos..].chars().next() {
            // Advance by the character's full width so `pos` stays on a
            // char boundary.
            pos += escaped.len_utf8();
        }
    }

    raw.len()
}
126
127fn parse_dn_value(input: &str) -> Result<(String, &str), ProtoError> {
128    let mut out = String::new();
129    let bytes = input.as_bytes();
130    let mut i = 0;
131    // Track position of last non-space or escaped character for trailing-space trim.
132    // RFC 4514 §3: trailing spaces are trimmed only when unescaped.
133    let mut last_non_trimmable = 0;
134
135    while i < bytes.len() {
136        match bytes[i] {
137            b',' | b'+' => break,
138            b'\\' => {
139                i += 1;
140                if i >= bytes.len() {
141                    break;
142                }
143                // Hex pair?
144                if i + 1 < bytes.len()
145                    && bytes[i].is_ascii_hexdigit()
146                    && bytes[i + 1].is_ascii_hexdigit()
147                    && let Ok(byte) =
148                        u8::from_str_radix(std::str::from_utf8(&bytes[i..i + 2]).unwrap_or(""), 16)
149                {
150                    // Accumulate raw bytes for multi-byte UTF-8
151                    let mut raw = vec![byte];
152                    i += 2;
153                    while i + 2 < bytes.len()
154                        && bytes[i] == b'\\'
155                        && bytes[i + 1].is_ascii_hexdigit()
156                        && bytes[i + 2].is_ascii_hexdigit()
157                    {
158                        if let Ok(b) = u8::from_str_radix(
159                            std::str::from_utf8(&bytes[i + 1..i + 3]).unwrap_or(""),
160                            16,
161                        ) {
162                            // Only continue if this looks like a continuation byte
163                            if b & 0xC0 != 0x80 {
164                                break;
165                            }
166                            raw.push(b);
167                            i += 3;
168                        } else {
169                            break;
170                        }
171                    }
172                    let decoded = String::from_utf8(raw).map_err(|e| {
173                        ProtoError::Protocol(format!("invalid UTF-8 in DN value: {e}"))
174                    })?;
175                    out.push_str(&decoded);
176                    last_non_trimmable = out.len();
177                    continue;
178                }
179                // Escaped special character (always ASCII per RFC 4514)
180                out.push(bytes[i] as char);
181                last_non_trimmable = out.len();
182                i += 1;
183            }
184            _ => {
185                // Decode full UTF-8 character to handle multi-byte sequences.
186                let ch = input[i..].chars().next().unwrap();
187                out.push(ch);
188                if ch != ' ' {
189                    last_non_trimmable = out.len();
190                }
191                i += ch.len_utf8();
192            }
193        }
194    }
195
196    // Trim unescaped trailing spaces (per RFC 4514 §3).
197    out.truncate(last_non_trimmable);
198    Ok((out, &input[i..]))
199}
200
/// Escapes an attribute value for the string form of a DN (RFC 4514 §2.4).
///
/// A backslash is inserted before `"`, `+`, `,`, `;`, `<`, `>` and `\`
/// wherever they occur, before a `#` at the start of the value, and before a
/// space at the start or the end. A NUL character is written as `\00`.
/// All other characters are copied unchanged.
pub fn escape_dn_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());

    for (offset, c) in value.char_indices() {
        if c == '\0' {
            escaped.push_str("\\00");
            continue;
        }

        let at_start = offset == 0;
        let at_end = offset + c.len_utf8() == value.len();
        let always_special = matches!(c, '"' | '+' | ',' | ';' | '<' | '>' | '\\');
        let special_by_position =
            (c == '#' && at_start) || (c == ' ' && (at_start || at_end));

        if always_special || special_by_position {
            escaped.push('\\');
        }
        escaped.push(c);
    }

    escaped
}
230
231impl fmt::Display for Dn {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        for (i, rdn) in self.rdns.iter().enumerate() {
234            if i > 0 {
235                f.write_str(",")?;
236            }
237            write!(f, "{rdn}")?;
238        }
239        Ok(())
240    }
241}
242
243impl fmt::Display for Rdn {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        for (i, (attr, value)) in self.components.iter().enumerate() {
246            if i > 0 {
247                f.write_str("+")?;
248            }
249            write!(f, "{}={}", attr, escape_dn_value(value))?;
250        }
251        Ok(())
252    }
253}
254
255impl std::str::FromStr for Dn {
256    type Err = ProtoError;
257    fn from_str(s: &str) -> Result<Self, Self::Err> {
258        Self::parse(s)
259    }
260}