1use std::fmt;
6
7use crate::ProtoError;
8use crate::message::SearchScope;
9
10#[derive(Debug, Clone, PartialEq, Eq)]
11pub struct LdapUrl {
12 pub scheme: LdapScheme,
13 pub host: String,
14 pub port: Option<u16>,
15 pub base_dn: Option<String>,
16 pub attributes: Vec<String>,
17 pub scope: Option<SearchScope>,
18 pub filter: Option<String>,
19}
20
/// The scheme prefix of an LDAP URL.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum LdapScheme {
    /// `ldap://` (default port 389).
    Ldap,
    /// `ldaps://` (default port 636).
    Ldaps,
}
26
27impl LdapUrl {
28 pub fn parse(input: &str) -> Result<Self, ProtoError> {
29 let (scheme, rest) = if let Some(r) = input.strip_prefix("ldaps://") {
30 (LdapScheme::Ldaps, r)
31 } else if let Some(r) = input.strip_prefix("ldap://") {
32 (LdapScheme::Ldap, r)
33 } else {
34 return Err(ProtoError::Protocol(
35 "LDAP URL must start with ldap:// or ldaps://".into(),
36 ));
37 };
38
39 let (hostport, path) = match rest.find('/') {
41 Some(pos) => (&rest[..pos], &rest[pos + 1..]),
42 None => (rest, ""),
43 };
44
45 let (host, port) = if hostport.starts_with('[') {
47 let bracket_end = hostport
49 .find(']')
50 .ok_or_else(|| ProtoError::Protocol("unterminated IPv6 address".into()))?;
51 let h = &hostport[1..bracket_end];
52 let after = &hostport[bracket_end + 1..];
53 let p = if let Some(port_str) = after.strip_prefix(':') {
54 Some(
55 port_str
56 .parse::<u16>()
57 .map_err(|e| ProtoError::Protocol(format!("invalid port: {e}")))?,
58 )
59 } else {
60 None
61 };
62 (h.to_string(), p)
63 } else {
64 match hostport.rsplit_once(':') {
65 Some((h, p)) => {
66 let port = p
67 .parse::<u16>()
68 .map_err(|e| ProtoError::Protocol(format!("invalid port: {e}")))?;
69 (h.to_string(), Some(port))
70 }
71 None => (hostport.to_string(), None),
72 }
73 };
74
75 if host.is_empty() {
76 return Err(ProtoError::Protocol("missing host in LDAP URL".into()));
77 }
78
79 let parts: Vec<&str> = path.splitn(4, '?').collect();
81
82 let base_dn = parts
83 .first()
84 .filter(|s| !s.is_empty())
85 .map(|s| percent_decode(s));
86
87 let attributes = parts
88 .get(1)
89 .filter(|s| !s.is_empty())
90 .map(|s| {
91 s.split(',')
92 .filter(|a| !a.is_empty())
93 .map(percent_decode)
94 .collect()
95 })
96 .unwrap_or_default();
97
98 let scope = parts
99 .get(2)
100 .filter(|s| !s.is_empty())
101 .map(|s| parse_scope(s))
102 .transpose()?;
103
104 let filter = parts
105 .get(3)
106 .filter(|s| !s.is_empty())
107 .map(|s| percent_decode(s));
108
109 Ok(LdapUrl {
110 scheme,
111 host,
112 port,
113 base_dn,
114 attributes,
115 scope,
116 filter,
117 })
118 }
119
120 pub fn effective_port(&self) -> u16 {
121 self.port.unwrap_or(match self.scheme {
122 LdapScheme::Ldap => 389,
123 LdapScheme::Ldaps => 636,
124 })
125 }
126}
127
128fn parse_scope(s: &str) -> Result<SearchScope, ProtoError> {
129 match s.to_ascii_lowercase().as_str() {
130 "base" => Ok(SearchScope::BaseObject),
131 "one" => Ok(SearchScope::SingleLevel),
132 "sub" => Ok(SearchScope::WholeSubtree),
133 _ => Err(ProtoError::Protocol(format!("unknown scope: {s}"))),
134 }
135}
136
/// Decodes `%XX` percent-escapes in `s`.
///
/// An escape is decoded only when the two characters after `%` are both
/// ASCII hex digits; otherwise the `%` is kept literally. The resulting bytes
/// are interpreted as UTF-8, with invalid sequences replaced by U+FFFD.
fn percent_decode(s: &str) -> String {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn hex_val(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        // Checking each digit explicitly (rather than `u8::from_str_radix`,
        // which accepts a leading '+') keeps inputs like "%+5" literal.
        if bytes[i] == b'%' && i + 2 < bytes.len() {
            if let (Some(hi), Some(lo)) = (hex_val(bytes[i + 1]), hex_val(bytes[i + 2])) {
                out.push((hi << 4) | lo);
                i += 3;
                continue;
            }
        }
        out.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&out).into_owned()
}
156
/// Percent-encodes `s` byte by byte, using uppercase hex (`%XX`).
///
/// ASCII alphanumerics and the characters `- . _ ~ = ,` are copied through
/// unchanged; every other byte (including each byte of a multi-byte UTF-8
/// character) is escaped.
fn percent_encode(s: &str) -> String {
    use std::fmt::Write;

    const PASSTHROUGH: &[u8] = b"-._~=,";

    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            if byte.is_ascii_alphanumeric() || PASSTHROUGH.contains(&byte) {
                encoded.push(char::from(byte));
            } else {
                // Formatting into a String cannot fail.
                let _ = write!(encoded, "%{byte:02X}");
            }
            encoded
        })
}
169
170impl fmt::Display for LdapUrl {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 let scheme = match self.scheme {
173 LdapScheme::Ldap => "ldap",
174 LdapScheme::Ldaps => "ldaps",
175 };
176 write!(f, "{scheme}://")?;
177
178 if self.host.contains(':') {
179 write!(f, "[{}]", self.host)?;
180 } else {
181 write!(f, "{}", self.host)?;
182 }
183
184 if let Some(port) = self.port {
185 write!(f, ":{port}")?;
186 }
187
188 write!(f, "/")?;
189
190 if let Some(dn) = &self.base_dn {
191 write!(f, "{}", percent_encode(dn))?;
192 }
193
194 let has_attrs = !self.attributes.is_empty();
196 let has_scope = self.scope.is_some();
197 let has_filter = self.filter.is_some();
198
199 if has_attrs || has_scope || has_filter {
200 write!(f, "?")?;
201 if has_attrs {
202 let attrs: Vec<String> =
203 self.attributes.iter().map(|a| percent_encode(a)).collect();
204 write!(f, "{}", attrs.join(","))?;
205 }
206 }
207
208 if has_scope || has_filter {
209 write!(f, "?")?;
210 if let Some(scope) = &self.scope {
211 let scope_str = match scope {
212 SearchScope::BaseObject => "base",
213 SearchScope::SingleLevel => "one",
214 SearchScope::WholeSubtree => "sub",
215 };
216 write!(f, "{scope_str}")?;
217 }
218 }
219
220 if has_filter {
221 write!(f, "?")?;
222 if let Some(filter) = &self.filter {
223 write!(f, "{}", percent_encode(filter))?;
224 }
225 }
226
227 Ok(())
228 }
229}
230
231impl std::str::FromStr for LdapUrl {
232 type Err = ProtoError;
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 Self::parse(s)
235 }
236}