ldap3_client/
lib.rs

1#![deny(warnings)]
2#![warn(unused_extern_crates)]
3#![deny(clippy::todo)]
4#![deny(clippy::unimplemented)]
5#![deny(clippy::unwrap_used)]
6#![deny(clippy::panic)]
7#![deny(clippy::await_holding_lock)]
8#![deny(clippy::needless_pass_by_value)]
9#![deny(clippy::trivially_copy_pass_by_ref)]
10// We allow expect since it forces good error messages at the least.
11#![allow(clippy::expect_used)]
12
13use base64::{engine::general_purpose, Engine as _};
14use futures_util::sink::SinkExt;
15use futures_util::stream::StreamExt;
16use ldap3_proto::proto::*;
17use ldap3_proto::LdapCodec;
18use rustls_platform_verifier::ConfigVerifierExt;
19use serde::{Deserialize, Serialize};
20use std::collections::{BTreeMap, BTreeSet};
21use std::fmt;
22use std::fs::File;
23use std::io::Read;
24use std::path::Path;
25use std::sync::Arc;
26use tokio::io::{ReadHalf, WriteHalf};
27use tokio::net::TcpStream;
28use tokio::time;
29use tokio_rustls::{
30    client::TlsStream,
31    rustls::client::danger::*,
32    rustls::client::ClientConfig,
33    rustls::pki_types::{pem::PemObject, CertificateDer, ServerName, UnixTime},
34    rustls::Error as RustlsError,
35    rustls::RootCertStore,
36    rustls::{DigitallySignedStruct, SignatureScheme},
37    TlsConnector,
38};
39
40use tokio_util::codec::{FramedRead, FramedWrite};
41use tracing::{error, info, trace, warn};
42use url::{Host, Url};
43use uuid::Uuid;
44
45pub use ldap3_proto::filter;
46pub use ldap3_proto::proto;
47pub use search::LdapSearchResult;
48pub use syncrepl::{LdapSyncRepl, LdapSyncReplEntry, LdapSyncStateValue};
49pub use tokio::time::Duration;
50
51mod addirsync;
52mod search;
53mod syncrepl;
54
55#[non_exhaustive]
56#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
57#[repr(i32)]
58pub enum LdapError {
59    InvalidUrl = -1,
60    LdapiNotSupported = -2,
61    UseCldapTool = -3,
62    ResolverError = -4,
63    ConnectError = -5,
64    TlsError = -6,
65    PasswordNotFound = -7,
66    AnonymousInvalidState = -8,
67    TransportWriteError = -9,
68    TransportReadError = -10,
69    InvalidProtocolState = -11,
70    FileIOError = -12,
71
72    UnavailableCriticalExtension = 12,
73    InvalidCredentials = 49,
74    InsufficentAccessRights = 50,
75    UnwillingToPerform = 53,
76    EsyncRefreshRequired = 4096,
77    NotImplemented = 9999,
78}
79
80impl From<LdapResultCode> for LdapError {
81    fn from(code: LdapResultCode) -> Self {
82        match code {
83            LdapResultCode::InvalidCredentials => LdapError::InvalidCredentials,
84            LdapResultCode::InsufficentAccessRights => LdapError::InsufficentAccessRights,
85            LdapResultCode::EsyncRefreshRequired => LdapError::EsyncRefreshRequired,
86            LdapResultCode::UnavailableCriticalExtension => LdapError::UnavailableCriticalExtension,
87            LdapResultCode::UnwillingToPerform => LdapError::UnwillingToPerform,
88            err => {
89                error!("{:?} not implemented yet!!", err);
90                LdapError::NotImplemented
91            }
92        }
93    }
94}
95
96impl fmt::Display for LdapError {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        match self {
99            LdapError::InvalidUrl => write!(f, "Invalid URL"),
100            LdapError::LdapiNotSupported => write!(f, "Ldapi Not Supported"),
101            LdapError::UseCldapTool => write!(f, "Use cldap tool for cldap:// urls"),
102            LdapError::ResolverError => write!(f, "Failed to resolve hostname or invalid ip"),
103            LdapError::ConnectError => write!(f, "Failed to connect to host"),
104            LdapError::TlsError => write!(f, "Failed to establish TLS"),
105            LdapError::PasswordNotFound => write!(f, "No password available for bind"),
106            LdapError::AnonymousInvalidState => write!(f, "Invalid Anonymous bind state"),
107            LdapError::InvalidProtocolState => {
108                write!(f, "The LDAP server sent a response we did not expect")
109            }
110            LdapError::FileIOError => {
111                write!(f, "An error occurred while accessing a file")
112            }
113            LdapError::TransportReadError => {
114                write!(f, "An error occurred reading from the transport")
115            }
116            LdapError::TransportWriteError => {
117                write!(f, "An error occurred writing to the transport")
118            }
119            LdapError::UnavailableCriticalExtension => write!(f, "An extension marked as critical was not available"),
120            LdapError::InvalidCredentials => write!(f, "Invalid DN or Password"),
121            LdapError::InsufficentAccessRights => write!(f, "Insufficient Access"),
122            LdapError::UnwillingToPerform => write!(f, "Too many failures, server is unwilling to perform the operation."),
123            LdapError::EsyncRefreshRequired => write!(f, "An initial content sync is required. The current cookie should be considered invalid."),
124            LdapError::NotImplemented => write!(f, "An error occurred, but we haven't implemented code to handle this error yet.")
125        }
126    }
127}
128
129pub type LdapResult<T> = Result<T, LdapError>;
130
131enum LdapReadTransport {
132    Plain(FramedRead<ReadHalf<TcpStream>, LdapCodec>),
133    Tls(FramedRead<ReadHalf<TlsStream<TcpStream>>, LdapCodec>),
134}
135
136impl fmt::Debug for LdapReadTransport {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        match self {
139            LdapReadTransport::Plain(_) => f
140                .debug_struct("LdapReadTransport")
141                .field("type", &"plain")
142                .finish(),
143            LdapReadTransport::Tls(_) => f
144                .debug_struct("LdapReadTransport")
145                .field("type", &"tls")
146                .finish(),
147        }
148    }
149}
150
151enum LdapWriteTransport {
152    Plain(FramedWrite<WriteHalf<TcpStream>, LdapCodec>),
153    Tls(FramedWrite<WriteHalf<TlsStream<TcpStream>>, LdapCodec>),
154}
155
156impl fmt::Debug for LdapWriteTransport {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        match self {
159            LdapWriteTransport::Plain(_) => f
160                .debug_struct("LdapWriteTransport")
161                .field("type", &"plain")
162                .finish(),
163            LdapWriteTransport::Tls(_) => f
164                .debug_struct("LdapWriteTransport")
165                .field("type", &"tls")
166                .finish(),
167        }
168    }
169}
170
171impl LdapWriteTransport {
172    async fn send(&mut self, msg: LdapMsg) -> LdapResult<()> {
173        match self {
174            LdapWriteTransport::Plain(f) => f.send(msg).await.map_err(|e| {
175                info!(?e, "transport error");
176                LdapError::TransportWriteError
177            }),
178            LdapWriteTransport::Tls(f) => f.send(msg).await.map_err(|e| {
179                info!(?e, "transport error");
180                LdapError::TransportWriteError
181            }),
182        }
183    }
184}
185
186impl LdapReadTransport {
187    async fn next(&mut self) -> LdapResult<LdapMsg> {
188        match self {
189            LdapReadTransport::Plain(f) => f.next().await.transpose().map_err(|e| {
190                info!(?e, "transport error");
191                LdapError::TransportReadError
192            })?,
193            LdapReadTransport::Tls(f) => f.next().await.transpose().map_err(|e| {
194                info!(?e, "transport error");
195                LdapError::TransportReadError
196            })?,
197        }
198        .ok_or_else(|| {
199            info!("connection closed");
200            LdapError::TransportReadError
201        })
202    }
203}
204
205#[derive(Debug, Deserialize, Serialize)]
206pub struct LdapEntry {
207    pub dn: String,
208    pub attrs: BTreeMap<String, BTreeSet<String>>,
209}
210
211impl LdapEntry {
212    pub fn get_ava_single(&self, attr: &str) -> Option<&str> {
213        if let Some(ava) = self.attrs.get(attr) {
214            if ava.len() == 1 {
215                ava.iter().next().map(String::as_ref)
216            } else {
217                None
218            }
219        } else {
220            None
221        }
222    }
223
224    pub fn remove_ava_single(&mut self, attr: &str) -> Option<String> {
225        if let Some(ava) = self.attrs.remove(attr) {
226            if ava.len() == 1 {
227                ava.into_iter().next()
228            } else {
229                None
230            }
231        } else {
232            None
233        }
234    }
235
236    pub fn remove_ava(&mut self, attr: &str) -> Option<BTreeSet<String>> {
237        self.attrs.remove(attr)
238    }
239}
240
241impl From<LdapSearchResultEntry> for LdapEntry {
242    fn from(ent: LdapSearchResultEntry) -> Self {
243        let LdapSearchResultEntry { dn, attributes } = ent;
244
245        let attrs = attributes
246            .into_iter()
247            .map(|LdapPartialAttribute { atype, vals }| {
248                let atype = atype.to_lowercase();
249
250                let lower = atype == "objectclass";
251
252                let va = vals
253                    .into_iter()
254                    .map(|bin| {
255                        std::str::from_utf8(&bin)
256                            .map(|s| {
257                                if lower {
258                                    s.to_lowercase()
259                                } else {
260                                    s.to_string()
261                                }
262                            })
263                            .unwrap_or_else(|_| general_purpose::URL_SAFE.encode(&bin))
264                    })
265                    .collect();
266                (atype, va)
267            })
268            .collect();
269
270        LdapEntry { dn, attrs }
271    }
272}
273
274pub struct LdapClientBuilder<'a> {
275    url: &'a Url,
276    timeout: Duration,
277    cas: Vec<&'a Path>,
278    verify: bool,
279    /// The maximum LDAP packet size parsed during decoding.
280    max_ber_size: Option<usize>,
281}
282
283impl<'a> LdapClientBuilder<'a> {
284    pub fn new(url: &'a Url) -> Self {
285        LdapClientBuilder {
286            url,
287            timeout: Duration::from_secs(30),
288            cas: Vec::new(),
289            verify: true,
290            max_ber_size: None,
291        }
292    }
293
294    pub fn set_timeout(self, timeout: Duration) -> Self {
295        Self { timeout, ..self }
296    }
297
298    pub fn add_tls_ca<T>(mut self, ca: &'a T) -> Self
299    where
300        T: AsRef<Path>,
301    {
302        self.cas.push(ca.as_ref());
303        self
304    }
305
306    pub fn danger_accept_invalid_certs(self, accept_invalid_certs: bool) -> Self {
307        Self {
308            verify: !accept_invalid_certs,
309            ..self
310        }
311    }
312
313    /// Set the maximum size of a decoded message
314    pub fn max_ber_size(self, max_ber_size: Option<usize>) -> Self {
315        Self {
316            max_ber_size,
317            ..self
318        }
319    }
320
321    #[tracing::instrument(level = "debug", skip_all)]
322    pub async fn build(self) -> LdapResult<LdapClient> {
323        let LdapClientBuilder {
324            url,
325            timeout,
326            cas,
327            verify,
328            max_ber_size,
329        } = self;
330
331        info!(%url);
332        info!(?timeout);
333
334        // Check the scheme is ldap or ldaps
335        // for now, no ldapi support.
336        let need_tls = match url.scheme() {
337            "ldapi" => return Err(LdapError::LdapiNotSupported),
338            "cldap" => return Err(LdapError::UseCldapTool),
339            "ldap" => false,
340            "ldaps" => true,
341            _ => return Err(LdapError::InvalidUrl),
342        };
343
344        info!(%need_tls);
345        // get domain + port
346
347        // Do we have query params? Can we use them?
348        // https://ldap.com/ldap-urls/
349
350        // resolve to a set of socket addrs.
351        let addrs = url
352            .socket_addrs(|| Some(if need_tls { 636 } else { 389 }))
353            .map_err(|e| {
354                info!(?e, "resolver error");
355                LdapError::ResolverError
356            })?;
357
358        if addrs.is_empty() {
359            return Err(LdapError::ResolverError);
360        }
361
362        addrs.iter().for_each(|address| info!(?address));
363
364        let mut aiter = addrs.into_iter();
365
366        // Try for each to open, with a timeout.
367        let tcpstream = loop {
368            if let Some(addr) = aiter.next() {
369                let sleep = time::sleep(timeout);
370                tokio::pin!(sleep);
371                tokio::select! {
372                    maybe_stream = TcpStream::connect(addr) => {
373                        match maybe_stream {
374                            Ok(t) => {
375                                info!(?addr, "connection established");
376                                break t;
377                            }
378                            Err(e) => {
379                                info!(?addr, ?e, "error");
380                                continue;
381                            }
382                        }
383                    }
384                    _ = &mut sleep => {
385                        info!(?addr, "timeout");
386                        continue;
387                    }
388                }
389            } else {
390                return Err(LdapError::ConnectError);
391            }
392        };
393
394        // If they didn't set it in the builder then set it to the default
395        let max_ber_size = max_ber_size.unwrap_or(ldap3_proto::DEFAULT_MAX_BER_SIZE);
396
397        // If ldaps - start rustls
398        let (write_transport, read_transport) = if need_tls {
399            // What about the `verify` flag?
400            let tls_client_config = if !cas.is_empty() {
401                let mut cert_store = RootCertStore::empty();
402                for ca in cas.iter() {
403                    let mut file = File::open(ca).map_err(|e| {
404                        error!(?e, "Unable to open {:?}", ca);
405                        LdapError::FileIOError
406                    })?;
407
408                    let mut pem = Vec::new();
409                    file.read_to_end(&mut pem).map_err(|e| {
410                        error!(?e, "Unable to read {:?}", ca);
411                        LdapError::FileIOError
412                    })?;
413
414                    let ca_cert = CertificateDer::from_pem_slice(pem.as_slice()).map_err(|e| {
415                        error!(?e, "rustls");
416                        LdapError::TlsError
417                    })?;
418
419                    cert_store
420                        .add(ca_cert)
421                        .map(|()| {
422                            info!("Added {:?} to cert store", ca);
423                        })
424                        .map_err(|e| {
425                            error!(?e, "rustls");
426                            LdapError::TlsError
427                        })?;
428                }
429
430                ClientConfig::builder()
431                    .with_root_certificates(cert_store)
432                    .with_no_client_auth()
433            } else if !verify {
434                warn!("⚠️ CERTIFICATE VERIFICATION IS DISABLED. THIS IS DANGEROUS!!!!");
435                let yolo_cert_validator = Arc::new(YoloCertValidator);
436
437                ClientConfig::builder()
438                    .dangerous()
439                    .with_custom_certificate_verifier(yolo_cert_validator)
440                    .with_no_client_auth()
441            } else {
442                // Just use the system CA roots.
443                ClientConfig::with_platform_verifier()
444            };
445
446            let tls_connector = TlsConnector::from(Arc::new(tls_client_config));
447
448            let server_name = match url.host() {
449                Some(Host::Domain(name)) => {
450                    ServerName::try_from(name.to_owned()).map_err(|err| {
451                        error!(?err, "server name invalid");
452                        LdapError::TlsError
453                    })?
454                }
455                Some(Host::Ipv4(addr)) => ServerName::from(addr),
456                Some(Host::Ipv6(addr)) => ServerName::from(addr),
457                None => {
458                    error!("url invalid");
459                    return Err(LdapError::TlsError);
460                }
461            };
462
463            let tlsstream = tls_connector
464                .connect(
465                    server_name,
466                    // Pin::new(&mut tcpstream)
467                    tcpstream,
468                )
469                .await
470                .map_err(|e| {
471                    error!(?e, "rustls");
472                    LdapError::TlsError
473                })?;
474
475            info!("tls configured");
476
477            let (r, w) = tokio::io::split(tlsstream);
478            (
479                LdapWriteTransport::Tls(FramedWrite::new(w, LdapCodec::default())),
480                LdapReadTransport::Tls(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
481            )
482        } else {
483            let (r, w) = tokio::io::split(tcpstream);
484            (
485                LdapWriteTransport::Plain(FramedWrite::new(w, LdapCodec::default())),
486                LdapReadTransport::Plain(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
487            )
488        };
489
490        let msg_counter = 1;
491
492        // Good to go - return ok!
493        Ok(LdapClient {
494            read_transport,
495            write_transport,
496            msg_counter,
497        })
498    }
499}
500
501#[derive(Debug)]
502pub struct LdapClient {
503    read_transport: LdapReadTransport,
504    write_transport: LdapWriteTransport,
505    msg_counter: i32,
506}
507
508impl LdapClient {
509    fn get_next_msgid(&mut self) -> i32 {
510        let msgid = self.msg_counter;
511        self.msg_counter += 1;
512        msgid
513    }
514
515    #[tracing::instrument(level = "debug", skip_all)]
516    pub async fn bind<S: Into<String>>(&mut self, dn: S, pw: S) -> LdapResult<()> {
517        let dn = dn.into();
518        info!(%dn);
519        let msgid = self.get_next_msgid();
520
521        let msg = LdapMsg {
522            msgid,
523            op: LdapOp::BindRequest(LdapBindRequest {
524                dn,
525                cred: LdapBindCred::Simple(pw.into()),
526            }),
527            ctrl: vec![],
528        };
529
530        self.write_transport.send(msg).await?;
531
532        // Get the response
533        self.read_transport
534            .next()
535            .await
536            .and_then(|msg| match msg.op {
537                LdapOp::BindResponse(res) => {
538                    if res.res.code == LdapResultCode::Success {
539                        info!("bind success");
540                        Ok(())
541                    } else {
542                        info!(?res.res.code);
543                        Err(LdapError::from(res.res.code))
544                    }
545                }
546                op => {
547                    trace!(?op);
548                    Err(LdapError::InvalidProtocolState)
549                }
550            })
551    }
552
553    #[tracing::instrument(level = "debug", skip_all)]
554    pub async fn whoami(&mut self) -> LdapResult<Option<String>> {
555        let msgid = self.get_next_msgid();
556
557        let msg = LdapMsg {
558            msgid,
559            op: LdapOp::ExtendedRequest(Into::into(LdapWhoamiRequest {})),
560            ctrl: vec![],
561        };
562
563        self.write_transport.send(msg).await?;
564
565        self.read_transport
566            .next()
567            .await
568            .and_then(|msg| match msg.op {
569                LdapOp::ExtendedResponse(ler) => LdapWhoamiResponse::try_from(&ler)
570                    .map_err(|_| LdapError::InvalidProtocolState)
571                    .map(|res| res.dn),
572                op => {
573                    trace!(?op);
574                    Err(LdapError::InvalidProtocolState)
575                }
576            })
577    }
578}
579
580#[derive(Debug)]
581/// This should never be used for anything but testing, as it does no verification!
582struct YoloCertValidator;
583
584impl ServerCertVerifier for YoloCertValidator {
585    fn verify_server_cert(
586        &self,
587        _end_entity: &CertificateDer<'_>,
588        _intermediates: &[CertificateDer<'_>],
589        _server_name: &ServerName<'_>,
590        _ocsp_response: &[u8],
591        _now: UnixTime,
592    ) -> Result<ServerCertVerified, RustlsError> {
593        // Yolo.
594        Ok(ServerCertVerified::assertion())
595    }
596
597    fn verify_tls12_signature(
598        &self,
599        _message: &[u8],
600        _cert: &CertificateDer<'_>,
601        _dss: &DigitallySignedStruct,
602    ) -> Result<HandshakeSignatureValid, RustlsError> {
603        Ok(HandshakeSignatureValid::assertion())
604    }
605
606    fn verify_tls13_signature(
607        &self,
608        _message: &[u8],
609        _cert: &CertificateDer<'_>,
610        _dss: &DigitallySignedStruct,
611    ) -> Result<HandshakeSignatureValid, RustlsError> {
612        Ok(HandshakeSignatureValid::assertion())
613    }
614
615    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
616        vec![
617            SignatureScheme::RSA_PKCS1_SHA384,
618            SignatureScheme::ECDSA_NISTP256_SHA256,
619            SignatureScheme::RSA_PKCS1_SHA256,
620            SignatureScheme::ECDSA_NISTP384_SHA384,
621            SignatureScheme::RSA_PKCS1_SHA512,
622            SignatureScheme::ECDSA_NISTP521_SHA512,
623        ]
624    }
625}
626
/// Exercises the builder setters only; the actual *build* step requires a live LDAP
/// server.
#[test]
fn test_ldapclient_builder() {
    // The crate denies `clippy::unwrap_used` but allows `expect`.
    let url = Url::parse("ldap://ldap.example.com:389").expect("test URL must parse");
    let client = LdapClientBuilder::new(&url).max_ber_size(Some(1234567));
    assert_eq!(client.timeout, Duration::from_secs(30));
    let client = client.set_timeout(Duration::from_secs(60));
    assert_eq!(client.timeout, Duration::from_secs(60));
    assert!(client.cas.is_empty());
    assert_eq!(client.max_ber_size, Some(1234567));
    assert!(client.verify);

    let ca_path = "test.pem".to_string();
    let client = client.add_tls_ca(&ca_path);
    assert_eq!(client.cas.len(), 1);
    assert_eq!(client.cas[0], Path::new("test.pem"));

    let badssl_client = client.danger_accept_invalid_certs(true);
    assert!(!badssl_client.verify);
}