use std::collections::HashSet;
use std::hash::Hash;
#[cfg(feature = "gssapi")]
use std::sync::RwLock;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::adapters::{EntriesOnly, IntoAdapterVec};
use crate::controls_impl::IntoRawControlVec;
use crate::exop::Exop;
use crate::exop_impl::construct_exop;
use crate::protocol::{LdapOp, MaybeControls, MiscSender, ResultSender};
use crate::result::{
CompareResult, ExopResult, LdapError, LdapResult, LdapResultExt, Result, SearchResult,
};
use crate::search::{Scope, SearchOptions, SearchStream};
use crate::RequestId;
use lber::common::TagClass;
use lber::structures::{Boolean, Enumerated, Integer, Null, OctetString, Sequence, Set, Tag};
#[cfg(feature = "gssapi")]
use cross_krb5::{ClientCtx, InitiateFlags, K5Ctx, Step};
use tokio::sync::{mpsc, oneshot};
use tokio::time;
#[derive(Clone, Debug)]
pub(crate) struct SaslCreds(pub Option<Vec<u8>>);
#[derive(Clone, Debug, PartialEq)]
pub enum Mod<S: AsRef<[u8]> + Eq + Hash> {
Add(S, HashSet<S>),
Delete(S, HashSet<S>),
Replace(S, HashSet<S>),
Increment(S, S),
}
#[derive(Debug)]
pub struct Ldap {
pub(crate) msgmap: Arc<Mutex<(RequestId, HashSet<RequestId>)>>,
pub(crate) tx: mpsc::UnboundedSender<(RequestId, LdapOp, Tag, MaybeControls, ResultSender)>,
pub(crate) id_scrub_tx: mpsc::UnboundedSender<RequestId>,
pub(crate) misc_tx: mpsc::UnboundedSender<MiscSender>,
pub(crate) last_id: RequestId,
#[cfg(feature = "gssapi")]
pub(crate) sasl_param: Arc<RwLock<(bool, u32)>>, #[cfg(feature = "gssapi")]
pub(crate) client_ctx: Arc<Mutex<Option<ClientCtx>>>,
#[cfg(feature = "gssapi")]
pub(crate) tls_endpoint_token: Arc<Option<Vec<u8>>>,
pub(crate) has_tls: bool,
pub timeout: Option<Duration>,
pub controls: MaybeControls,
pub search_opts: Option<SearchOptions>,
}
impl Clone for Ldap {
fn clone(&self) -> Self {
Ldap {
msgmap: self.msgmap.clone(),
tx: self.tx.clone(),
id_scrub_tx: self.id_scrub_tx.clone(),
misc_tx: self.misc_tx.clone(),
#[cfg(feature = "gssapi")]
sasl_param: self.sasl_param.clone(),
#[cfg(feature = "gssapi")]
client_ctx: self.client_ctx.clone(),
#[cfg(feature = "gssapi")]
tls_endpoint_token: self.tls_endpoint_token.clone(),
has_tls: self.has_tls,
last_id: 0,
timeout: None,
controls: None,
search_opts: None,
}
}
}
fn sasl_bind_req(mech: &str, creds: Option<&[u8]>) -> Tag {
let mut inner_vec = vec![Tag::OctetString(OctetString {
inner: Vec::from(mech),
..Default::default()
})];
if let Some(creds) = creds {
inner_vec.push(Tag::OctetString(OctetString {
inner: creds.to_vec(),
..Default::default()
}));
}
Tag::Sequence(Sequence {
id: 0,
class: TagClass::Application,
inner: vec![
Tag::Integer(Integer {
inner: 3,
..Default::default()
}),
Tag::OctetString(OctetString {
inner: Vec::new(),
..Default::default()
}),
Tag::Sequence(Sequence {
id: 3,
class: TagClass::Context,
inner: inner_vec,
}),
],
})
}
impl Ldap {
fn next_msgid(&mut self) -> i32 {
let mut msgmap = self.msgmap.lock().expect("msgmap mutex (inc id)");
let last_ldap_id = msgmap.0;
let mut next_ldap_id = last_ldap_id;
loop {
if next_ldap_id == std::i32::MAX {
next_ldap_id = 1;
} else {
next_ldap_id += 1;
}
if !msgmap.1.contains(&next_ldap_id) {
break;
}
assert_ne!(
next_ldap_id, last_ldap_id,
"LDAP message id wraparound with no free slots"
);
}
msgmap.0 = next_ldap_id;
msgmap.1.insert(next_ldap_id);
next_ldap_id
}
pub(crate) async fn op_call(
&mut self,
op: LdapOp,
req: Tag,
) -> Result<(LdapResult, Exop, SaslCreds)> {
let id = self.next_msgid();
self.last_id = id;
let (tx, rx) = oneshot::channel();
self.tx.send((id, op, req, self.controls.take(), tx))?;
let response = if let Some(timeout) = self.timeout.take() {
let res = time::timeout(timeout, rx).await;
if res.is_err() {
self.id_scrub_tx.send(self.last_id)?;
}
res?
} else {
rx.await
}?;
let (ldap_ext, controls) = (LdapResultExt::from(response.0), response.1);
let (mut result, exop, sasl_creds) = (ldap_ext.0, ldap_ext.1, ldap_ext.2);
result.ctrls = controls;
Ok((result, exop, sasl_creds))
}
pub fn with_search_options(&mut self, opts: SearchOptions) -> &mut Self {
self.search_opts = Some(opts);
self
}
pub fn with_controls<V: IntoRawControlVec>(&mut self, ctrls: V) -> &mut Self {
self.controls = Some(ctrls.into());
self
}
pub fn with_timeout(&mut self, duration: Duration) -> &mut Self {
self.timeout = Some(duration);
self
}
pub async fn simple_bind(&mut self, bind_dn: &str, bind_pw: &str) -> Result<LdapResult> {
let req = Tag::Sequence(Sequence {
id: 0,
class: TagClass::Application,
inner: vec![
Tag::Integer(Integer {
inner: 3,
..Default::default()
}),
Tag::OctetString(OctetString {
inner: Vec::from(bind_dn),
..Default::default()
}),
Tag::OctetString(OctetString {
id: 0,
class: TagClass::Context,
inner: Vec::from(bind_pw),
}),
],
});
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
pub async fn sasl_external_bind(&mut self) -> Result<LdapResult> {
let req = sasl_bind_req("EXTERNAL", Some(b""));
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
#[cfg_attr(docsrs, doc(cfg(feature = "gssapi")))]
#[cfg(feature = "gssapi")]
pub async fn sasl_gssapi_bind(&mut self, server_fqdn: &str) -> Result<LdapResult> {
const LDAP_RESULT_SASL_BIND_IN_PROGRESS: u32 = 14;
const GSSAUTH_P_NONE: u8 = 1;
const GSSAUTH_P_PRIVACY: u8 = 4;
let mut spn = String::from("ldap/");
spn.push_str(server_fqdn);
let cti = if self.has_tls {
let cbt = {
let mut cbt = Vec::from(&b"tls-server-end-point:"[..]);
if let Some(ref token) = self.tls_endpoint_token.as_ref() {
cbt.extend(token);
Some(cbt)
} else {
None
}
};
ClientCtx::new(InitiateFlags::empty(), None, &spn, cbt.as_deref())
} else {
ClientCtx::new(InitiateFlags::empty(), None, &spn, None)
};
let (client_ctx, token) =
cti.map_err(|e| LdapError::GssapiOperationError(format!("{:#}", e)))?;
let req = sasl_bind_req("GSSAPI", Some(&token));
let ans = self.op_call(LdapOp::Single, req).await?;
if (ans.0).rc != LDAP_RESULT_SASL_BIND_IN_PROGRESS {
return Ok(ans.0);
}
let token = match (ans.2).0 {
Some(token) => token,
_ => return Err(LdapError::NoGssapiToken),
};
let step = client_ctx
.step(&token)
.map_err(|e| LdapError::GssapiOperationError(format!("{:#}", e)))?;
let mut client_ctx = match step {
Step::Finished((ctx, None)) => ctx,
_ => {
return Err(LdapError::GssapiOperationError(String::from(
"GSSAPI exchange not finished or has an additional token",
)))
}
};
let req = sasl_bind_req("GSSAPI", None);
let ans = self.op_call(LdapOp::Single, req).await?;
if (ans.0).rc != LDAP_RESULT_SASL_BIND_IN_PROGRESS {
return Ok(ans.0);
}
let token = match (ans.2).0 {
Some(token) => token,
_ => return Err(LdapError::NoGssapiToken),
};
let mut buf = client_ctx
.unwrap(&token)
.map_err(|e| LdapError::GssapiOperationError(format!("{:#}", e)))?;
let needed_layer = if self.has_tls {
GSSAUTH_P_NONE
} else {
GSSAUTH_P_PRIVACY
};
if buf[0] | needed_layer == 0 {
return Err(LdapError::GssapiOperationError(format!(
"no appropriate security layer offered: needed {}, mask {}",
needed_layer, buf[0]
)));
}
let recv_max_size = (0x9FFFB8u32 | (needed_layer as u32) << 24).to_be_bytes();
let size_msg = client_ctx
.wrap(true, &recv_max_size)
.map_err(|e| LdapError::GssapiOperationError(format!("{:#}", e)))?;
let req = sasl_bind_req("GSSAPI", Some(&size_msg));
let res = self.op_call(LdapOp::Single, req).await?.0;
if res.rc == 0 {
if needed_layer == GSSAUTH_P_PRIVACY {
buf[0] = 0;
let send_max_size =
u32::from_be_bytes((&buf[..]).try_into().expect("send max size"));
let mut sasl_param = self.sasl_param.write().expect("sasl param");
sasl_param.0 = true;
sasl_param.1 = send_max_size;
}
let client_opt = &mut *self.client_ctx.lock().unwrap();
client_opt.replace(client_ctx);
}
Ok(res)
}
pub async fn search<'a, S: AsRef<str> + Send + Sync + 'a, A: AsRef<[S]> + Send + Sync + 'a>(
&mut self,
base: &str,
scope: Scope,
filter: &str,
attrs: A,
) -> Result<SearchResult> {
let mut stream = self
.streaming_search_with(EntriesOnly::new(), base, scope, filter, attrs)
.await?;
let mut re_vec = vec![];
while let Some(entry) = stream.next().await? {
re_vec.push(entry);
}
let res = stream.finish().await;
Ok(SearchResult(re_vec, res))
}
pub async fn streaming_search<
'a,
S: AsRef<str> + Send + Sync + 'a,
A: AsRef<[S]> + Send + Sync + 'a,
>(
&mut self,
base: &str,
scope: Scope,
filter: &str,
attrs: A,
) -> Result<SearchStream<'a, S, A>> {
self.streaming_search_with(vec![], base, scope, filter, attrs)
.await
}
pub async fn streaming_search_with<
'a,
V: IntoAdapterVec<'a, S, A>,
S: AsRef<str> + Send + Sync + 'a,
A: AsRef<[S]> + Send + Sync + 'a,
>(
&mut self,
adapters: V,
base: &str,
scope: Scope,
filter: &str,
attrs: A,
) -> Result<SearchStream<'a, S, A>> {
let mut ldap = self.clone();
ldap.controls = self.controls.take();
ldap.timeout = self.timeout.take();
ldap.search_opts = self.search_opts.take();
let mut stream = SearchStream::new(ldap, adapters.into());
stream.start(base, scope, filter, attrs).await?;
Ok(stream)
}
pub async fn add<S: AsRef<[u8]> + Eq + Hash>(
&mut self,
dn: &str,
attrs: Vec<(S, HashSet<S>)>,
) -> Result<LdapResult> {
let mut any_empty = false;
let req = Tag::Sequence(Sequence {
id: 8,
class: TagClass::Application,
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(dn.as_bytes()),
..Default::default()
}),
Tag::Sequence(Sequence {
inner: attrs
.into_iter()
.map(|(name, vals)| {
if vals.is_empty() {
any_empty = true;
}
Tag::Sequence(Sequence {
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(name.as_ref()),
..Default::default()
}),
Tag::Set(Set {
inner: vals
.into_iter()
.map(|v| {
Tag::OctetString(OctetString {
inner: Vec::from(v.as_ref()),
..Default::default()
})
})
.collect(),
..Default::default()
}),
],
..Default::default()
})
})
.collect(),
..Default::default()
}),
],
});
if any_empty {
return Err(LdapError::AddNoValues);
}
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
pub async fn compare<B: AsRef<[u8]>>(
&mut self,
dn: &str,
attr: &str,
val: B,
) -> Result<CompareResult> {
let req = Tag::Sequence(Sequence {
id: 14,
class: TagClass::Application,
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(dn.as_bytes()),
..Default::default()
}),
Tag::Sequence(Sequence {
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(attr.as_bytes()),
..Default::default()
}),
Tag::OctetString(OctetString {
inner: Vec::from(val.as_ref()),
..Default::default()
}),
],
..Default::default()
}),
],
});
Ok(CompareResult(self.op_call(LdapOp::Single, req).await?.0))
}
pub async fn delete(&mut self, dn: &str) -> Result<LdapResult> {
let req = Tag::OctetString(OctetString {
id: 10,
class: TagClass::Application,
inner: Vec::from(dn.as_bytes()),
});
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
pub async fn modify<S: AsRef<[u8]> + Eq + Hash>(
&mut self,
dn: &str,
mods: Vec<Mod<S>>,
) -> Result<LdapResult> {
let mut any_add_empty = false;
let req = Tag::Sequence(Sequence {
id: 6,
class: TagClass::Application,
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(dn.as_bytes()),
..Default::default()
}),
Tag::Sequence(Sequence {
inner: mods
.into_iter()
.map(|m| {
let mut is_add = false;
let (num, attr, set) = match m {
Mod::Add(attr, set) => {
is_add = true;
(0, attr, set)
}
Mod::Delete(attr, set) => (1, attr, set),
Mod::Replace(attr, set) => (2, attr, set),
Mod::Increment(attr, val) => (3, attr, HashSet::from([val])),
};
if set.is_empty() && is_add {
any_add_empty = true;
}
let op = Tag::Enumerated(Enumerated {
inner: num,
..Default::default()
});
let part_attr = Tag::Sequence(Sequence {
inner: vec![
Tag::OctetString(OctetString {
inner: Vec::from(attr.as_ref()),
..Default::default()
}),
Tag::Set(Set {
inner: set
.into_iter()
.map(|val| {
Tag::OctetString(OctetString {
inner: Vec::from(val.as_ref()),
..Default::default()
})
})
.collect(),
..Default::default()
}),
],
..Default::default()
});
Tag::Sequence(Sequence {
inner: vec![op, part_attr],
..Default::default()
})
})
.collect(),
..Default::default()
}),
],
});
if any_add_empty {
return Err(LdapError::AddNoValues);
}
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
pub async fn modifydn(
&mut self,
dn: &str,
rdn: &str,
delete_old: bool,
new_sup: Option<&str>,
) -> Result<LdapResult> {
let mut params = vec![
Tag::OctetString(OctetString {
inner: Vec::from(dn.as_bytes()),
..Default::default()
}),
Tag::OctetString(OctetString {
inner: Vec::from(rdn.as_bytes()),
..Default::default()
}),
Tag::Boolean(Boolean {
inner: delete_old,
..Default::default()
}),
];
if let Some(new_sup) = new_sup {
params.push(Tag::OctetString(OctetString {
id: 0,
class: TagClass::Context,
inner: Vec::from(new_sup.as_bytes()),
}));
}
let req = Tag::Sequence(Sequence {
id: 12,
class: TagClass::Application,
inner: params,
});
Ok(self.op_call(LdapOp::Single, req).await?.0)
}
pub async fn extended<E>(&mut self, exop: E) -> Result<ExopResult>
where
E: Into<Exop>,
{
let req = Tag::Sequence(Sequence {
id: 23,
class: TagClass::Application,
inner: construct_exop(exop.into()),
});
self.op_call(LdapOp::Single, req)
.await
.map(|et| ExopResult(et.1, et.0))
}
pub async fn unbind(&mut self) -> Result<()> {
let req = Tag::Null(Null {
id: 2,
class: TagClass::Application,
inner: (),
});
Ok(self.op_call(LdapOp::Unbind, req).await.map(|_| ())?)
}
pub fn last_id(&mut self) -> RequestId {
self.last_id
}
pub async fn abandon(&mut self, msgid: RequestId) -> Result<()> {
let req = Tag::Integer(Integer {
id: 16,
class: TagClass::Application,
inner: msgid as i64,
});
Ok(self
.op_call(LdapOp::Abandon(msgid), req)
.await
.map(|_| ())?)
}
pub fn is_closed(&mut self) -> bool {
self.tx.is_closed()
}
pub async fn get_peer_certificate(&mut self) -> Result<Option<Vec<u8>>> {
#[cfg(any(feature = "tls-native", feature = "tls-rustls"))]
{
let (tx, rx) = oneshot::channel();
self.misc_tx.send(MiscSender::Cert(tx))?;
Ok(rx.await?)
}
#[cfg(not(any(feature = "tls-native", feature = "tls-rustls")))]
{
Ok(None)
}
}
}