Skip to main content

ldap_client_proto/
url.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! RFC 4516 LDAP URL parsing ([`LdapUrl::parse`]) and formatting (`Display`).
4
5use std::fmt;
6
7use crate::ProtoError;
8use crate::message::SearchScope;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct LdapUrl {
12    pub scheme: LdapScheme,
13    pub host: String,
14    pub port: Option<u16>,
15    pub base_dn: Option<String>,
16    pub attributes: Vec<String>,
17    pub scope: Option<SearchScope>,
18    pub filter: Option<String>,
19}
20
/// URL scheme of an [`LdapUrl`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LdapScheme {
    /// `ldap://`; [`LdapUrl::effective_port`] defaults to 389.
    Ldap,
    /// `ldaps://`; [`LdapUrl::effective_port`] defaults to 636.
    Ldaps,
}
26
27impl LdapUrl {
28    pub fn parse(input: &str) -> Result<Self, ProtoError> {
29        let (scheme, rest) = if let Some(r) = input.strip_prefix("ldaps://") {
30            (LdapScheme::Ldaps, r)
31        } else if let Some(r) = input.strip_prefix("ldap://") {
32            (LdapScheme::Ldap, r)
33        } else {
34            return Err(ProtoError::Protocol(
35                "LDAP URL must start with ldap:// or ldaps://".into(),
36            ));
37        };
38
39        // Split host[:port] from the path at the first '/'
40        let (hostport, path) = match rest.find('/') {
41            Some(pos) => (&rest[..pos], &rest[pos + 1..]),
42            None => (rest, ""),
43        };
44
45        // Parse host and port
46        let (host, port) = if hostport.starts_with('[') {
47            // IPv6
48            let bracket_end = hostport
49                .find(']')
50                .ok_or_else(|| ProtoError::Protocol("unterminated IPv6 address".into()))?;
51            let h = &hostport[1..bracket_end];
52            let after = &hostport[bracket_end + 1..];
53            let p = if let Some(port_str) = after.strip_prefix(':') {
54                Some(
55                    port_str
56                        .parse::<u16>()
57                        .map_err(|e| ProtoError::Protocol(format!("invalid port: {e}")))?,
58                )
59            } else {
60                None
61            };
62            (h.to_string(), p)
63        } else {
64            match hostport.rsplit_once(':') {
65                Some((h, p)) => {
66                    let port = p
67                        .parse::<u16>()
68                        .map_err(|e| ProtoError::Protocol(format!("invalid port: {e}")))?;
69                    (h.to_string(), Some(port))
70                }
71                None => (hostport.to_string(), None),
72            }
73        };
74
75        if host.is_empty() {
76            return Err(ProtoError::Protocol("missing host in LDAP URL".into()));
77        }
78
79        // Parse path components: base_dn?attributes?scope?filter
80        let parts: Vec<&str> = path.splitn(4, '?').collect();
81
82        let base_dn = parts
83            .first()
84            .filter(|s| !s.is_empty())
85            .map(|s| percent_decode(s));
86
87        let attributes = parts
88            .get(1)
89            .filter(|s| !s.is_empty())
90            .map(|s| {
91                s.split(',')
92                    .filter(|a| !a.is_empty())
93                    .map(percent_decode)
94                    .collect()
95            })
96            .unwrap_or_default();
97
98        let scope = parts
99            .get(2)
100            .filter(|s| !s.is_empty())
101            .map(|s| parse_scope(s))
102            .transpose()?;
103
104        let filter = parts
105            .get(3)
106            .filter(|s| !s.is_empty())
107            .map(|s| percent_decode(s));
108
109        Ok(LdapUrl {
110            scheme,
111            host,
112            port,
113            base_dn,
114            attributes,
115            scope,
116            filter,
117        })
118    }
119
120    pub fn effective_port(&self) -> u16 {
121        self.port.unwrap_or(match self.scheme {
122            LdapScheme::Ldap => 389,
123            LdapScheme::Ldaps => 636,
124        })
125    }
126}
127
128fn parse_scope(s: &str) -> Result<SearchScope, ProtoError> {
129    match s.to_ascii_lowercase().as_str() {
130        "base" => Ok(SearchScope::BaseObject),
131        "one" => Ok(SearchScope::SingleLevel),
132        "sub" => Ok(SearchScope::WholeSubtree),
133        _ => Err(ProtoError::Protocol(format!("unknown scope: {s}"))),
134    }
135}
136
/// Decodes `%XX` percent-escapes in `s`.
///
/// A `%` is decoded only when it is immediately followed by two ASCII hex digits. Any
/// other `%` is copied through literally: a truncated escape, non-hex characters, or a
/// sign as in `%+5`. Decoded bytes that are not valid UTF-8 are replaced with U+FFFD.
fn percent_decode(s: &str) -> String {
    let bytes = s.as_bytes();
    // `char::to_digit(16)` accepts only 0-9, a-f and A-F. Unlike `u8::from_str_radix`,
    // it rejects a leading '+' sign.
    let hex_at = |idx: usize| bytes.get(idx).and_then(|&b| char::from(b).to_digit(16));

    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            if let (Some(hi), Some(lo)) = (hex_at(i + 1), hex_at(i + 2)) {
                // Both digits are < 16, so the combined value always fits in a byte.
                out.push((hi * 16 + lo) as u8);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
156
/// Percent-encodes `s` for use in an LDAP URL component.
///
/// ASCII letters, digits and `-._~=,` are kept literally. Every other byte, including
/// each byte of a multi-byte UTF-8 character, becomes `%XX` with uppercase hex digits.
fn percent_encode(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    const UNESCAPED_PUNCT: &[u8] = b"-._~=,";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || UNESCAPED_PUNCT.contains(&byte) {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}
169
170impl fmt::Display for LdapUrl {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        let scheme = match self.scheme {
173            LdapScheme::Ldap => "ldap",
174            LdapScheme::Ldaps => "ldaps",
175        };
176        write!(f, "{scheme}://")?;
177
178        if self.host.contains(':') {
179            write!(f, "[{}]", self.host)?;
180        } else {
181            write!(f, "{}", self.host)?;
182        }
183
184        if let Some(port) = self.port {
185            write!(f, ":{port}")?;
186        }
187
188        write!(f, "/")?;
189
190        if let Some(dn) = &self.base_dn {
191            write!(f, "{}", percent_encode(dn))?;
192        }
193
194        // Only print subsequent fields if there's something to show
195        let has_attrs = !self.attributes.is_empty();
196        let has_scope = self.scope.is_some();
197        let has_filter = self.filter.is_some();
198
199        if has_attrs || has_scope || has_filter {
200            write!(f, "?")?;
201            if has_attrs {
202                let attrs: Vec<String> =
203                    self.attributes.iter().map(|a| percent_encode(a)).collect();
204                write!(f, "{}", attrs.join(","))?;
205            }
206        }
207
208        if has_scope || has_filter {
209            write!(f, "?")?;
210            if let Some(scope) = &self.scope {
211                let scope_str = match scope {
212                    SearchScope::BaseObject => "base",
213                    SearchScope::SingleLevel => "one",
214                    SearchScope::WholeSubtree => "sub",
215                };
216                write!(f, "{scope_str}")?;
217            }
218        }
219
220        if has_filter {
221            write!(f, "?")?;
222            if let Some(filter) = &self.filter {
223                write!(f, "{}", percent_encode(filter))?;
224            }
225        }
226
227        Ok(())
228    }
229}
230
231impl std::str::FromStr for LdapUrl {
232    type Err = ProtoError;
233    fn from_str(s: &str) -> Result<Self, Self::Err> {
234        Self::parse(s)
235    }
236}