use std::borrow::Cow;
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use crate::filter::Unescaper;
use crate::result::{LdapError, Result};
use crate::search::Scope;
use percent_encoding::percent_decode_str;
use url::Url;
/// Escape a search filter literal.
///
/// Every backslash, asterisk, opening or closing parenthesis and NUL byte in
/// `lit` is replaced by a backslash followed by the two lowercase hexadecimal
/// digits of the byte. All other bytes are copied through unchanged.
///
/// When nothing has to be escaped, the input is handed back as-is and no
/// allocation takes place.
pub fn ldap_escape<'a, S: Into<Cow<'a, str>>>(lit: S) -> Cow<'a, str> {
    const HEX: &[u8; 16] = b"0123456789abcdef";

    fn is_special(byte: u8) -> bool {
        match byte {
            b'\\' | b'*' | b'(' | b')' | 0 => true,
            _ => false,
        }
    }

    let lit = lit.into();
    let first_special = lit.bytes().position(is_special);
    let start = match first_special {
        Some(pos) => pos,
        // Nothing to escape: return the caller's value untouched.
        None => return lit,
    };

    let mut escaped = Vec::with_capacity(lit.len() + 12);
    // Everything before the first special byte is copied verbatim.
    escaped.extend_from_slice(&lit.as_bytes()[..start]);
    for &byte in &lit.as_bytes()[start..] {
        if is_special(byte) {
            escaped.push(b'\\');
            escaped.push(HEX[usize::from(byte >> 4)]);
            escaped.push(HEX[usize::from(byte & 0xF)]);
        } else {
            escaped.push(byte);
        }
    }
    // Only ASCII bytes were replaced by ASCII sequences, so the buffer is still valid UTF-8.
    Cow::Owned(String::from_utf8(escaped).expect("ldap escaped"))
}
/// Escape an attribute value for use inside a distinguished name.
///
/// The bytes `"`, `+`, `,`, `;`, `<`, `=`, `>`, `\` and NUL are always
/// escaped. A space or `#` is escaped when it is the first byte of the value,
/// and a space is escaped when it is the last byte. Each escaped byte becomes
/// a backslash followed by its two lowercase hexadecimal digits.
///
/// When nothing has to be escaped, the input is handed back as-is and no
/// allocation takes place.
pub fn dn_escape<'a, S: Into<Cow<'a, str>>>(val: S) -> Cow<'a, str> {
    const HEX: &[u8; 16] = b"0123456789abcdef";

    let val = val.into();
    let len = val.len();
    // Position-aware predicate: some bytes matter only at the edges of the value.
    let must_escape = |idx: usize, byte: u8| -> bool {
        match byte {
            b'"' | b'+' | b',' | b';' | b'<' | b'=' | b'>' | b'\\' | 0 => true,
            b'#' => idx == 0,
            b' ' => idx == 0 || idx + 1 == len,
            _ => false,
        }
    };

    let first_hit = val
        .bytes()
        .enumerate()
        .position(|(idx, byte)| must_escape(idx, byte));
    let start = match first_hit {
        Some(pos) => pos,
        // Nothing to escape: return the caller's value untouched.
        None => return val,
    };

    let mut escaped = Vec::with_capacity(len + 12);
    // Everything before the first escaped byte is copied verbatim.
    escaped.extend_from_slice(&val.as_bytes()[..start]);
    for (idx, byte) in val.bytes().enumerate().skip(start) {
        if must_escape(idx, byte) {
            escaped.push(b'\\');
            escaped.push(HEX[usize::from(byte >> 4)]);
            escaped.push(HEX[usize::from(byte & 0xF)]);
        } else {
            escaped.push(byte);
        }
    }
    // Only ASCII bytes were replaced by ASCII sequences, so the buffer is still valid UTF-8.
    Cow::Owned(String::from_utf8(escaped).expect("dn escaped"))
}
/// An extension found in the extensions component of an LDAP URL.
///
/// The `PartialEq` and `Hash` implementations for this type look only at the
/// variant and ignore any carried value, so a `HashSet` of extensions holds at
/// most one entry per variant.
#[derive(Clone, Debug)]
pub enum LdapUrlExt<'a> {
/// Value of the `bindname` extension; `get_url_params` matches the name
/// without regard to ASCII case.
Bindname(Cow<'a, str>),
/// Value of the `x-bindpw` extension; `get_url_params` matches the name
/// without regard to ASCII case.
XBindpw(Cow<'a, str>),
/// Value of the extension identified by the OID `1.3.6.1.4.1.10094.1.5.1`.
Credentials(Cow<'a, str>),
/// Value of the extension identified by the OID `1.3.6.1.4.1.10094.1.5.2`.
SaslMech(Cow<'a, str>),
/// Presence of the extension identified by the OID `1.3.6.1.4.1.1466.20037`;
/// it carries no value.
StartTLS,
/// An extension that is not one of the above. `get_url_params` uses this
/// variant to describe an unrecognized critical extension in its error, and
/// never stores it in the returned set.
Unknown(Cow<'a, str>),
}
impl<'a> PartialEq for LdapUrlExt<'a> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(LdapUrlExt::Bindname(_), LdapUrlExt::Bindname(_)) => true,
(LdapUrlExt::XBindpw(_), LdapUrlExt::XBindpw(_)) => true,
(LdapUrlExt::Credentials(_), LdapUrlExt::Credentials(_)) => true,
(LdapUrlExt::SaslMech(_), LdapUrlExt::SaslMech(_)) => true,
(LdapUrlExt::StartTLS, LdapUrlExt::StartTLS) => true,
(LdapUrlExt::Unknown(_), LdapUrlExt::Unknown(_)) => true,
_ => false,
}
}
}
// Marker impl: together with the `Hash` impl, this lets `LdapUrlExt` be
// stored in a `HashSet`, as `LdapUrlParams::extensions` does.
impl<'a> Eq for LdapUrlExt<'a> {}
impl<'a> Hash for LdapUrlExt<'a> {
    /// Feed the name of the variant into the hasher. The carried value is
    /// deliberately left out, which keeps hashing in agreement with the
    /// variant-only equality of this type.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let variant_name: &str = match *self {
            LdapUrlExt::Bindname(..) => "Bindname",
            LdapUrlExt::XBindpw(..) => "XBindpw",
            LdapUrlExt::Credentials(..) => "Credentials",
            LdapUrlExt::SaslMech(..) => "SaslMech",
            LdapUrlExt::StartTLS => "StartTLS",
            LdapUrlExt::Unknown(..) => "Unknown",
        };
        variant_name.hash(state);
    }
}
/// Components of an LDAP URL, as extracted by `get_url_params`.
#[derive(Clone, Debug)]
pub struct LdapUrlParams<'a> {
/// Percent-decoded URL path with a single leading `/` removed.
pub base: Cow<'a, str>,
/// Comma-separated elements of the first query component, not
/// percent-decoded; `["*"]` when that component is absent or empty.
pub attrs: Vec<&'a str>,
/// Search scope from the second query component; `Scope::Subtree` when that
/// component is absent or empty.
pub scope: Scope,
/// Percent-decoded third query component; `(objectClass=*)` when that
/// component is absent or empty.
pub filter: Cow<'a, str>,
/// Recognized extensions from the fourth query component, at most one per
/// `LdapUrlExt` variant.
pub extensions: HashSet<LdapUrlExt<'a>>,
}
/// Check whether `t`, converted to ASCII lowercase, is byte-for-byte equal to `s`.
///
/// Only `t` is lowercased; `s` is compared as given, so callers are expected
/// to pass an already-lowercase `s`.
#[inline]
fn ascii_lc_equal(s: &str, t: &str) -> bool {
    s.len() == t.len()
        && s.bytes()
            .zip(t.bytes())
            .all(|(expected, actual)| expected == actual.to_ascii_lowercase())
}
pub fn get_url_params(url: &Url) -> Result<LdapUrlParams<'_>> {
let mut base = url.path();
if base.chars().next().unwrap_or('\0') == '/' {
base = &base[1..];
}
let base = percent_decode_str(base)
.decode_utf8()
.map_err(|_| LdapError::DecodingUTF8)?;
let mut query = url.query().unwrap_or("").splitn(4, '?');
let attrs = match query.next() {
Some("") | None => vec!["*"],
Some(alist) => alist.split(',').collect(),
};
let scope = match query.next() {
Some("") | None => Scope::Subtree,
Some(scope_str) => match scope_str {
"base" => Scope::Base,
"one" => Scope::OneLevel,
"sub" => Scope::Subtree,
any => return Err(LdapError::InvalidScopeString(any.into())),
},
};
let filter = match query.next() {
Some("") | None => "(objectClass=*)",
Some(filter) => filter,
};
let filter = percent_decode_str(filter)
.decode_utf8()
.map_err(|_| LdapError::DecodingUTF8)?;
let extensions = match query.next() {
Some("") | None => HashSet::new(),
Some(exts) => {
let mut ext_set = HashSet::new();
for ext in exts.split(',') {
let (crit, id, val) = {
let mut crit = false;
let mut idv = ext.splitn(2, '=');
let mut id = idv.next().unwrap_or("");
if id != "" && &id[..1] == "!" {
id = &id[1..];
crit = true;
}
let val = idv.next();
(
crit,
id,
percent_decode_str(val.unwrap_or(""))
.decode_utf8()
.map_err(|_| LdapError::DecodingUTF8)?,
)
};
let ext = match id {
"1.3.6.1.4.1.10094.1.5.1" => LdapUrlExt::Credentials(val),
"1.3.6.1.4.1.10094.1.5.2" => LdapUrlExt::SaslMech(val),
"1.3.6.1.4.1.1466.20037" => LdapUrlExt::StartTLS,
ext => {
if ascii_lc_equal("bindname", ext) {
LdapUrlExt::Bindname(val)
} else if ascii_lc_equal("x-bindpw", ext) {
LdapUrlExt::XBindpw(val)
} else if crit {
return Err(LdapError::UnrecognizedCriticalExtension(format!(
"{:?}",
LdapUrlExt::Unknown(ext.into())
)));
} else {
LdapUrlExt::Unknown("".into())
}
}
};
if ext != LdapUrlExt::Unknown("".into()) {
ext_set.insert(ext);
}
}
ext_set
}
};
Ok(LdapUrlParams {
base,
attrs,
scope,
filter,
extensions,
})
}
/// Reverse backslash escaping in `val`.
///
/// Each input byte is fed through an `Unescaper`; the bytes it yields as
/// values make up the result. When the input never puts the unescaper into
/// its escape-start state, `val` is returned as-is without allocating.
///
/// # Errors
///
/// Returns `LdapError::DecodingUTF8` if an escape sequence was started and
/// the unescaper does not end in a value state, or if the unescaped bytes are
/// not valid UTF-8.
pub fn ldap_str_unescape<'a, S: Into<Cow<'a, str>>>(val: S) -> Result<Cow<'a, str>> {
    let val = val.into();
    let mut decoded: Option<Vec<u8>> = None;
    let mut state = Unescaper::Value(0);
    for (pos, &byte) in val.as_bytes().iter().enumerate() {
        state = state.feed(byte);
        match state {
            Unescaper::WantFirst => {
                // First escape seen: start an owned buffer seeded with the
                // untouched prefix of the input.
                if decoded.is_none() {
                    let mut buf = Vec::with_capacity(val.len() + 12);
                    buf.extend_from_slice(val[..pos].as_bytes());
                    decoded = Some(buf);
                }
            }
            Unescaper::Value(plain) => {
                if let Some(buf) = decoded.as_mut() {
                    buf.push(plain);
                }
            }
            _ => (),
        }
    }
    let buf = match decoded {
        Some(buf) => buf,
        // No escape sequence encountered: hand the input back unchanged.
        None => return Ok(val),
    };
    match state {
        Unescaper::Value(_) => String::from_utf8(buf)
            .map(Cow::Owned)
            .map_err(|_| LdapError::DecodingUTF8),
        _ => Err(LdapError::DecodingUTF8),
    }
}
/// Unit tests for `dn_escape`, concentrating on the position-dependent
/// handling of spaces and of a leading `#`.
#[cfg(test)]
mod test {
    use super::dn_escape;

    #[test]
    fn dn_esc_leading_space() {
        assert_eq!(dn_escape(" foo"), "\\20foo");
    }

    #[test]
    fn dn_esc_trailing_space() {
        assert_eq!(dn_escape("foo "), "foo\\20");
    }

    #[test]
    fn dn_esc_inner_space() {
        assert_eq!(dn_escape("f o o"), "f o o");
    }

    #[test]
    fn dn_esc_single_space() {
        assert_eq!(dn_escape(" "), "\\20");
    }

    #[test]
    fn dn_esc_two_spaces() {
        // Two spaces: one is leading, the other trailing, so both are escaped.
        assert_eq!(dn_escape("  "), "\\20\\20");
    }

    #[test]
    fn dn_esc_three_spaces() {
        // Three spaces: only the outer two are escaped, the middle one is kept.
        assert_eq!(dn_escape("   "), "\\20 \\20");
    }

    #[test]
    fn dn_esc_leading_hash() {
        assert_eq!(dn_escape("#rust"), "\\23rust");
    }
}