1#![deny(warnings)]
2#![warn(unused_extern_crates)]
3#![deny(clippy::todo)]
4#![deny(clippy::unimplemented)]
5#![deny(clippy::unwrap_used)]
6#![deny(clippy::panic)]
7#![deny(clippy::await_holding_lock)]
8#![deny(clippy::needless_pass_by_value)]
9#![deny(clippy::trivially_copy_pass_by_ref)]
10#![allow(clippy::expect_used)]
12
13use base64::{engine::general_purpose, Engine as _};
14use futures_util::sink::SinkExt;
15use futures_util::stream::StreamExt;
16use ldap3_proto::proto::*;
17use ldap3_proto::LdapCodec;
18use rustls_platform_verifier::ConfigVerifierExt;
19use serde::{Deserialize, Serialize};
20use std::collections::{BTreeMap, BTreeSet};
21use std::fmt;
22use std::fs::File;
23use std::io::Read;
24use std::path::Path;
25use std::sync::Arc;
26use tokio::io::{ReadHalf, WriteHalf};
27use tokio::net::TcpStream;
28use tokio::time;
29use tokio_rustls::{
30 client::TlsStream,
31 rustls::client::danger::*,
32 rustls::client::ClientConfig,
33 rustls::pki_types::{pem::PemObject, CertificateDer, ServerName, UnixTime},
34 rustls::Error as RustlsError,
35 rustls::RootCertStore,
36 rustls::{DigitallySignedStruct, SignatureScheme},
37 TlsConnector,
38};
39
40use tokio_util::codec::{FramedRead, FramedWrite};
41use tracing::{error, info, trace, warn};
42use url::{Host, Url};
43use uuid::Uuid;
44
45pub use ldap3_proto::filter;
46pub use ldap3_proto::proto;
47pub use search::LdapSearchResult;
48pub use syncrepl::{LdapSyncRepl, LdapSyncReplEntry, LdapSyncStateValue};
49pub use tokio::time::Duration;
50
51mod addirsync;
52mod search;
53mod syncrepl;
54
55#[non_exhaustive]
56#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
57#[repr(i32)]
58pub enum LdapError {
59 InvalidUrl = -1,
60 LdapiNotSupported = -2,
61 UseCldapTool = -3,
62 ResolverError = -4,
63 ConnectError = -5,
64 TlsError = -6,
65 PasswordNotFound = -7,
66 AnonymousInvalidState = -8,
67 TransportWriteError = -9,
68 TransportReadError = -10,
69 InvalidProtocolState = -11,
70 FileIOError = -12,
71
72 UnavailableCriticalExtension = 12,
73 InvalidCredentials = 49,
74 InsufficentAccessRights = 50,
75 UnwillingToPerform = 53,
76 EsyncRefreshRequired = 4096,
77 NotImplemented = 9999,
78}
79
80impl From<LdapResultCode> for LdapError {
81 fn from(code: LdapResultCode) -> Self {
82 match code {
83 LdapResultCode::InvalidCredentials => LdapError::InvalidCredentials,
84 LdapResultCode::InsufficentAccessRights => LdapError::InsufficentAccessRights,
85 LdapResultCode::EsyncRefreshRequired => LdapError::EsyncRefreshRequired,
86 LdapResultCode::UnavailableCriticalExtension => LdapError::UnavailableCriticalExtension,
87 LdapResultCode::UnwillingToPerform => LdapError::UnwillingToPerform,
88 err => {
89 error!("{:?} not implemented yet!!", err);
90 LdapError::NotImplemented
91 }
92 }
93 }
94}
95
96impl fmt::Display for LdapError {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 match self {
99 LdapError::InvalidUrl => write!(f, "Invalid URL"),
100 LdapError::LdapiNotSupported => write!(f, "Ldapi Not Supported"),
101 LdapError::UseCldapTool => write!(f, "Use cldap tool for cldap:// urls"),
102 LdapError::ResolverError => write!(f, "Failed to resolve hostname or invalid ip"),
103 LdapError::ConnectError => write!(f, "Failed to connect to host"),
104 LdapError::TlsError => write!(f, "Failed to establish TLS"),
105 LdapError::PasswordNotFound => write!(f, "No password available for bind"),
106 LdapError::AnonymousInvalidState => write!(f, "Invalid Anonymous bind state"),
107 LdapError::InvalidProtocolState => {
108 write!(f, "The LDAP server sent a response we did not expect")
109 }
110 LdapError::FileIOError => {
111 write!(f, "An error occurred while accessing a file")
112 }
113 LdapError::TransportReadError => {
114 write!(f, "An error occurred reading from the transport")
115 }
116 LdapError::TransportWriteError => {
117 write!(f, "An error occurred writing to the transport")
118 }
119 LdapError::UnavailableCriticalExtension => write!(f, "An extension marked as critical was not available"),
120 LdapError::InvalidCredentials => write!(f, "Invalid DN or Password"),
121 LdapError::InsufficentAccessRights => write!(f, "Insufficient Access"),
122 LdapError::UnwillingToPerform => write!(f, "Too many failures, server is unwilling to perform the operation."),
123 LdapError::EsyncRefreshRequired => write!(f, "An initial content sync is required. The current cookie should be considered invalid."),
124 LdapError::NotImplemented => write!(f, "An error occurred, but we haven't implemented code to handle this error yet.")
125 }
126 }
127}
128
129pub type LdapResult<T> = Result<T, LdapError>;
130
131enum LdapReadTransport {
132 Plain(FramedRead<ReadHalf<TcpStream>, LdapCodec>),
133 Tls(FramedRead<ReadHalf<TlsStream<TcpStream>>, LdapCodec>),
134}
135
136impl fmt::Debug for LdapReadTransport {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 match self {
139 LdapReadTransport::Plain(_) => f
140 .debug_struct("LdapReadTransport")
141 .field("type", &"plain")
142 .finish(),
143 LdapReadTransport::Tls(_) => f
144 .debug_struct("LdapReadTransport")
145 .field("type", &"tls")
146 .finish(),
147 }
148 }
149}
150
151enum LdapWriteTransport {
152 Plain(FramedWrite<WriteHalf<TcpStream>, LdapCodec>),
153 Tls(FramedWrite<WriteHalf<TlsStream<TcpStream>>, LdapCodec>),
154}
155
156impl fmt::Debug for LdapWriteTransport {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 match self {
159 LdapWriteTransport::Plain(_) => f
160 .debug_struct("LdapWriteTransport")
161 .field("type", &"plain")
162 .finish(),
163 LdapWriteTransport::Tls(_) => f
164 .debug_struct("LdapWriteTransport")
165 .field("type", &"tls")
166 .finish(),
167 }
168 }
169}
170
171impl LdapWriteTransport {
172 async fn send(&mut self, msg: LdapMsg) -> LdapResult<()> {
173 match self {
174 LdapWriteTransport::Plain(f) => f.send(msg).await.map_err(|e| {
175 info!(?e, "transport error");
176 LdapError::TransportWriteError
177 }),
178 LdapWriteTransport::Tls(f) => f.send(msg).await.map_err(|e| {
179 info!(?e, "transport error");
180 LdapError::TransportWriteError
181 }),
182 }
183 }
184}
185
186impl LdapReadTransport {
187 async fn next(&mut self) -> LdapResult<LdapMsg> {
188 match self {
189 LdapReadTransport::Plain(f) => f.next().await.transpose().map_err(|e| {
190 info!(?e, "transport error");
191 LdapError::TransportReadError
192 })?,
193 LdapReadTransport::Tls(f) => f.next().await.transpose().map_err(|e| {
194 info!(?e, "transport error");
195 LdapError::TransportReadError
196 })?,
197 }
198 .ok_or_else(|| {
199 info!("connection closed");
200 LdapError::TransportReadError
201 })
202 }
203}
204
205#[derive(Debug, Deserialize, Serialize)]
206pub struct LdapEntry {
207 pub dn: String,
208 pub attrs: BTreeMap<String, BTreeSet<String>>,
209}
210
211impl LdapEntry {
212 pub fn get_ava_single(&self, attr: &str) -> Option<&str> {
213 if let Some(ava) = self.attrs.get(attr) {
214 if ava.len() == 1 {
215 ava.iter().next().map(String::as_ref)
216 } else {
217 None
218 }
219 } else {
220 None
221 }
222 }
223
224 pub fn remove_ava_single(&mut self, attr: &str) -> Option<String> {
225 if let Some(ava) = self.attrs.remove(attr) {
226 if ava.len() == 1 {
227 ava.into_iter().next()
228 } else {
229 None
230 }
231 } else {
232 None
233 }
234 }
235
236 pub fn remove_ava(&mut self, attr: &str) -> Option<BTreeSet<String>> {
237 self.attrs.remove(attr)
238 }
239}
240
241impl From<LdapSearchResultEntry> for LdapEntry {
242 fn from(ent: LdapSearchResultEntry) -> Self {
243 let LdapSearchResultEntry { dn, attributes } = ent;
244
245 let attrs = attributes
246 .into_iter()
247 .map(|LdapPartialAttribute { atype, vals }| {
248 let atype = atype.to_lowercase();
249
250 let lower = atype == "objectclass";
251
252 let va = vals
253 .into_iter()
254 .map(|bin| {
255 std::str::from_utf8(&bin)
256 .map(|s| {
257 if lower {
258 s.to_lowercase()
259 } else {
260 s.to_string()
261 }
262 })
263 .unwrap_or_else(|_| general_purpose::URL_SAFE.encode(&bin))
264 })
265 .collect();
266 (atype, va)
267 })
268 .collect();
269
270 LdapEntry { dn, attrs }
271 }
272}
273
274pub struct LdapClientBuilder<'a> {
275 url: &'a Url,
276 timeout: Duration,
277 cas: Vec<&'a Path>,
278 verify: bool,
279 max_ber_size: Option<usize>,
281}
282
283impl<'a> LdapClientBuilder<'a> {
284 pub fn new(url: &'a Url) -> Self {
285 LdapClientBuilder {
286 url,
287 timeout: Duration::from_secs(30),
288 cas: Vec::new(),
289 verify: true,
290 max_ber_size: None,
291 }
292 }
293
294 pub fn set_timeout(self, timeout: Duration) -> Self {
295 Self { timeout, ..self }
296 }
297
298 pub fn add_tls_ca<T>(mut self, ca: &'a T) -> Self
299 where
300 T: AsRef<Path>,
301 {
302 self.cas.push(ca.as_ref());
303 self
304 }
305
306 pub fn danger_accept_invalid_certs(self, accept_invalid_certs: bool) -> Self {
307 Self {
308 verify: !accept_invalid_certs,
309 ..self
310 }
311 }
312
313 pub fn max_ber_size(self, max_ber_size: Option<usize>) -> Self {
315 Self {
316 max_ber_size,
317 ..self
318 }
319 }
320
321 #[tracing::instrument(level = "debug", skip_all)]
322 pub async fn build(self) -> LdapResult<LdapClient> {
323 let LdapClientBuilder {
324 url,
325 timeout,
326 cas,
327 verify,
328 max_ber_size,
329 } = self;
330
331 info!(%url);
332 info!(?timeout);
333
334 let need_tls = match url.scheme() {
337 "ldapi" => return Err(LdapError::LdapiNotSupported),
338 "cldap" => return Err(LdapError::UseCldapTool),
339 "ldap" => false,
340 "ldaps" => true,
341 _ => return Err(LdapError::InvalidUrl),
342 };
343
344 info!(%need_tls);
345 let addrs = url
352 .socket_addrs(|| Some(if need_tls { 636 } else { 389 }))
353 .map_err(|e| {
354 info!(?e, "resolver error");
355 LdapError::ResolverError
356 })?;
357
358 if addrs.is_empty() {
359 return Err(LdapError::ResolverError);
360 }
361
362 addrs.iter().for_each(|address| info!(?address));
363
364 let mut aiter = addrs.into_iter();
365
366 let tcpstream = loop {
368 if let Some(addr) = aiter.next() {
369 let sleep = time::sleep(timeout);
370 tokio::pin!(sleep);
371 tokio::select! {
372 maybe_stream = TcpStream::connect(addr) => {
373 match maybe_stream {
374 Ok(t) => {
375 info!(?addr, "connection established");
376 break t;
377 }
378 Err(e) => {
379 info!(?addr, ?e, "error");
380 continue;
381 }
382 }
383 }
384 _ = &mut sleep => {
385 info!(?addr, "timeout");
386 continue;
387 }
388 }
389 } else {
390 return Err(LdapError::ConnectError);
391 }
392 };
393
394 let max_ber_size = max_ber_size.unwrap_or(ldap3_proto::DEFAULT_MAX_BER_SIZE);
396
397 let (write_transport, read_transport) = if need_tls {
399 let tls_client_config = if !cas.is_empty() {
401 let mut cert_store = RootCertStore::empty();
402 for ca in cas.iter() {
403 let mut file = File::open(ca).map_err(|e| {
404 error!(?e, "Unable to open {:?}", ca);
405 LdapError::FileIOError
406 })?;
407
408 let mut pem = Vec::new();
409 file.read_to_end(&mut pem).map_err(|e| {
410 error!(?e, "Unable to read {:?}", ca);
411 LdapError::FileIOError
412 })?;
413
414 let ca_cert = CertificateDer::from_pem_slice(pem.as_slice()).map_err(|e| {
415 error!(?e, "rustls");
416 LdapError::TlsError
417 })?;
418
419 cert_store
420 .add(ca_cert)
421 .map(|()| {
422 info!("Added {:?} to cert store", ca);
423 })
424 .map_err(|e| {
425 error!(?e, "rustls");
426 LdapError::TlsError
427 })?;
428 }
429
430 ClientConfig::builder()
431 .with_root_certificates(cert_store)
432 .with_no_client_auth()
433 } else if !verify {
434 warn!("⚠️ CERTIFICATE VERIFICATION IS DISABLED. THIS IS DANGEROUS!!!!");
435 let yolo_cert_validator = Arc::new(YoloCertValidator);
436
437 ClientConfig::builder()
438 .dangerous()
439 .with_custom_certificate_verifier(yolo_cert_validator)
440 .with_no_client_auth()
441 } else {
442 ClientConfig::with_platform_verifier()
444 };
445
446 let tls_connector = TlsConnector::from(Arc::new(tls_client_config));
447
448 let server_name = match url.host() {
449 Some(Host::Domain(name)) => {
450 ServerName::try_from(name.to_owned()).map_err(|err| {
451 error!(?err, "server name invalid");
452 LdapError::TlsError
453 })?
454 }
455 Some(Host::Ipv4(addr)) => ServerName::from(addr),
456 Some(Host::Ipv6(addr)) => ServerName::from(addr),
457 None => {
458 error!("url invalid");
459 return Err(LdapError::TlsError);
460 }
461 };
462
463 let tlsstream = tls_connector
464 .connect(
465 server_name,
466 tcpstream,
468 )
469 .await
470 .map_err(|e| {
471 error!(?e, "rustls");
472 LdapError::TlsError
473 })?;
474
475 info!("tls configured");
476
477 let (r, w) = tokio::io::split(tlsstream);
478 (
479 LdapWriteTransport::Tls(FramedWrite::new(w, LdapCodec::default())),
480 LdapReadTransport::Tls(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
481 )
482 } else {
483 let (r, w) = tokio::io::split(tcpstream);
484 (
485 LdapWriteTransport::Plain(FramedWrite::new(w, LdapCodec::default())),
486 LdapReadTransport::Plain(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
487 )
488 };
489
490 let msg_counter = 1;
491
492 Ok(LdapClient {
494 read_transport,
495 write_transport,
496 msg_counter,
497 })
498 }
499}
500
501#[derive(Debug)]
502pub struct LdapClient {
503 read_transport: LdapReadTransport,
504 write_transport: LdapWriteTransport,
505 msg_counter: i32,
506}
507
508impl LdapClient {
509 fn get_next_msgid(&mut self) -> i32 {
510 let msgid = self.msg_counter;
511 self.msg_counter += 1;
512 msgid
513 }
514
515 #[tracing::instrument(level = "debug", skip_all)]
516 pub async fn bind<S: Into<String>>(&mut self, dn: S, pw: S) -> LdapResult<()> {
517 let dn = dn.into();
518 info!(%dn);
519 let msgid = self.get_next_msgid();
520
521 let msg = LdapMsg {
522 msgid,
523 op: LdapOp::BindRequest(LdapBindRequest {
524 dn,
525 cred: LdapBindCred::Simple(pw.into()),
526 }),
527 ctrl: vec![],
528 };
529
530 self.write_transport.send(msg).await?;
531
532 self.read_transport
534 .next()
535 .await
536 .and_then(|msg| match msg.op {
537 LdapOp::BindResponse(res) => {
538 if res.res.code == LdapResultCode::Success {
539 info!("bind success");
540 Ok(())
541 } else {
542 info!(?res.res.code);
543 Err(LdapError::from(res.res.code))
544 }
545 }
546 op => {
547 trace!(?op);
548 Err(LdapError::InvalidProtocolState)
549 }
550 })
551 }
552
553 #[tracing::instrument(level = "debug", skip_all)]
554 pub async fn whoami(&mut self) -> LdapResult<Option<String>> {
555 let msgid = self.get_next_msgid();
556
557 let msg = LdapMsg {
558 msgid,
559 op: LdapOp::ExtendedRequest(Into::into(LdapWhoamiRequest {})),
560 ctrl: vec![],
561 };
562
563 self.write_transport.send(msg).await?;
564
565 self.read_transport
566 .next()
567 .await
568 .and_then(|msg| match msg.op {
569 LdapOp::ExtendedResponse(ler) => LdapWhoamiResponse::try_from(&ler)
570 .map_err(|_| LdapError::InvalidProtocolState)
571 .map(|res| res.dn),
572 op => {
573 trace!(?op);
574 Err(LdapError::InvalidProtocolState)
575 }
576 })
577 }
578}
579
580#[derive(Debug)]
581struct YoloCertValidator;
583
584impl ServerCertVerifier for YoloCertValidator {
585 fn verify_server_cert(
586 &self,
587 _end_entity: &CertificateDer<'_>,
588 _intermediates: &[CertificateDer<'_>],
589 _server_name: &ServerName<'_>,
590 _ocsp_response: &[u8],
591 _now: UnixTime,
592 ) -> Result<ServerCertVerified, RustlsError> {
593 Ok(ServerCertVerified::assertion())
595 }
596
597 fn verify_tls12_signature(
598 &self,
599 _message: &[u8],
600 _cert: &CertificateDer<'_>,
601 _dss: &DigitallySignedStruct,
602 ) -> Result<HandshakeSignatureValid, RustlsError> {
603 Ok(HandshakeSignatureValid::assertion())
604 }
605
606 fn verify_tls13_signature(
607 &self,
608 _message: &[u8],
609 _cert: &CertificateDer<'_>,
610 _dss: &DigitallySignedStruct,
611 ) -> Result<HandshakeSignatureValid, RustlsError> {
612 Ok(HandshakeSignatureValid::assertion())
613 }
614
615 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
616 vec![
617 SignatureScheme::RSA_PKCS1_SHA384,
618 SignatureScheme::ECDSA_NISTP256_SHA256,
619 SignatureScheme::RSA_PKCS1_SHA256,
620 SignatureScheme::ECDSA_NISTP384_SHA384,
621 SignatureScheme::RSA_PKCS1_SHA512,
622 SignatureScheme::ECDSA_NISTP521_SHA512,
623 ]
624 }
625}
626
/// Checks the builder defaults and that each setter changes only its own
/// field.
#[test]
fn test_ldapclient_builder() {
    // `expect` rather than `unwrap`: the crate denies `clippy::unwrap_used`
    // and explicitly allows `expect_used`.
    let url = Url::parse("ldap://ldap.example.com:389").expect("static test URL must parse");
    let client = LdapClientBuilder::new(&url).max_ber_size(Some(1234567));
    assert_eq!(client.timeout, Duration::from_secs(30));
    let client = client.set_timeout(Duration::from_secs(60));
    assert_eq!(client.timeout, Duration::from_secs(60));
    assert_eq!(client.cas.len(), 0);
    assert_eq!(client.max_ber_size, Some(1234567));
    assert!(client.verify);

    let ca_path = "test.pem".to_string();
    let client = client.add_tls_ca(&ca_path);
    assert_eq!(client.cas.len(), 1);

    let badssl_client = client.danger_accept_invalid_certs(true);
    assert!(!badssl_client.verify);
}