Skip to main content

ldap3_client/
lib.rs

1#![deny(warnings)]
2#![warn(unused_extern_crates)]
3#![deny(clippy::todo)]
4#![deny(clippy::unimplemented)]
5#![deny(clippy::unwrap_used)]
6#![deny(clippy::panic)]
7#![deny(clippy::await_holding_lock)]
8#![deny(clippy::needless_pass_by_value)]
9#![deny(clippy::trivially_copy_pass_by_ref)]
10// We allow expect since it forces good error messages at the least.
11#![allow(clippy::expect_used)]
12
13use base64::{engine::general_purpose, Engine as _};
14use futures_util::sink::SinkExt;
15use futures_util::stream::StreamExt;
16use ldap3_proto::proto::*;
17use ldap3_proto::LdapCodec;
18use rustls_platform_verifier::ConfigVerifierExt;
19use serde::{Deserialize, Serialize};
20use std::borrow::Borrow;
21use std::collections::{BTreeMap, BTreeSet};
22use std::fmt;
23use std::sync::Arc;
24use tokio::io::{ReadHalf, WriteHalf};
25use tokio::net::TcpStream;
26use tokio::time;
27
28use tokio_rustls::{
29    client::TlsStream,
30    rustls::client::{danger::*, ClientConfig},
31    rustls::pki_types::{CertificateDer, ServerName, UnixTime},
32    rustls::Error as RustlsError,
33    rustls::{DigitallySignedStruct, SignatureScheme},
34    TlsConnector,
35};
36
37use tokio_util::codec::{FramedRead, FramedWrite};
38use tracing::{error, info, trace, warn};
39use url::{Host, Url};
40use uuid::Uuid;
41
42pub use ldap3_proto::filter;
43pub use ldap3_proto::proto;
44pub use search::LdapSearchResult;
45pub use syncrepl::{LdapSyncRepl, LdapSyncReplEntry, LdapSyncStateValue};
46pub use tokio::time::Duration;
47
48mod addirsync;
49mod search;
50mod syncrepl;
51
/// Errors returned by the LDAP client.
///
/// Negative discriminants are client-side failures (URL handling, connection, TLS,
/// transport, protocol state). Positive discriminants use the numeric value of the
/// corresponding LDAP result code returned by the server, except `NotImplemented`, which is
/// the catch-all produced by `From<LdapResultCode>` for codes without a dedicated variant.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Deserialize, Serialize)]
#[repr(i32)]
pub enum LdapError {
    /// The URL scheme is not one the client understands.
    InvalidUrl = -1,
    /// `ldapi://` URLs are rejected; local-socket LDAP is not supported.
    LdapiNotSupported = -2,
    /// `cldap://` URLs are rejected; a separate cldap tool must be used.
    UseCldapTool = -3,
    /// The host could not be resolved to any socket address.
    ResolverError = -4,
    /// No resolved address accepted a TCP connection.
    ConnectError = -5,
    /// TLS could not be configured or established.
    TlsError = -6,
    /// No password was available for a bind.
    PasswordNotFound = -7,
    /// Invalid anonymous bind state.
    AnonymousInvalidState = -8,
    /// Writing a message to the connection failed.
    TransportWriteError = -9,
    /// Reading a message failed, or the connection was closed.
    TransportReadError = -10,
    /// The server sent a response the client did not expect.
    InvalidProtocolState = -11,
    /// Accessing a file failed.
    FileIOError = -12,

    /// An extension marked as critical was not available on the server.
    UnavailableCriticalExtension = 12,
    /// The bind DN or password was rejected.
    InvalidCredentials = 49,
    /// The bound identity lacks access rights for the operation.
    InsufficentAccessRights = 50,
    /// The server refused to perform the operation.
    UnwillingToPerform = 53,
    /// A full content refresh is required; the current sync cookie should be considered invalid.
    EsyncRefreshRequired = 4096,
    /// The server returned a result code this client has no dedicated variant for.
    NotImplemented = 9999,
}
76
77impl From<LdapResultCode> for LdapError {
78    fn from(code: LdapResultCode) -> Self {
79        match code {
80            LdapResultCode::InvalidCredentials => LdapError::InvalidCredentials,
81            LdapResultCode::InsufficentAccessRights => LdapError::InsufficentAccessRights,
82            LdapResultCode::EsyncRefreshRequired => LdapError::EsyncRefreshRequired,
83            LdapResultCode::UnavailableCriticalExtension => LdapError::UnavailableCriticalExtension,
84            LdapResultCode::UnwillingToPerform => LdapError::UnwillingToPerform,
85            err => {
86                error!("{:?} not implemented yet!!", err);
87                LdapError::NotImplemented
88            }
89        }
90    }
91}
92
93impl fmt::Display for LdapError {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            LdapError::InvalidUrl => write!(f, "Invalid URL"),
97            LdapError::LdapiNotSupported => write!(f, "Ldapi Not Supported"),
98            LdapError::UseCldapTool => write!(f, "Use cldap tool for cldap:// urls"),
99            LdapError::ResolverError => write!(f, "Failed to resolve hostname or invalid ip"),
100            LdapError::ConnectError => write!(f, "Failed to connect to host"),
101            LdapError::TlsError => write!(f, "Failed to establish TLS"),
102            LdapError::PasswordNotFound => write!(f, "No password available for bind"),
103            LdapError::AnonymousInvalidState => write!(f, "Invalid Anonymous bind state"),
104            LdapError::InvalidProtocolState => {
105                write!(f, "The LDAP server sent a response we did not expect")
106            }
107            LdapError::FileIOError => {
108                write!(f, "An error occurred while accessing a file")
109            }
110            LdapError::TransportReadError => {
111                write!(f, "An error occurred reading from the transport")
112            }
113            LdapError::TransportWriteError => {
114                write!(f, "An error occurred writing to the transport")
115            }
116            LdapError::UnavailableCriticalExtension => write!(f, "An extension marked as critical was not available"),
117            LdapError::InvalidCredentials => write!(f, "Invalid DN or Password"),
118            LdapError::InsufficentAccessRights => write!(f, "Insufficient Access"),
119            LdapError::UnwillingToPerform => write!(f, "Too many failures, server is unwilling to perform the operation."),
120            LdapError::EsyncRefreshRequired => write!(f, "An initial content sync is required. The current cookie should be considered invalid."),
121            LdapError::NotImplemented => write!(f, "An error occurred, but we haven't implemented code to handle this error yet.")
122        }
123    }
124}
125
126pub type LdapResult<T> = Result<T, LdapError>;
127
/// Read half of a client connection, decoding [`LdapMsg`]s with [`LdapCodec`].
enum LdapReadTransport {
    /// Unencrypted TCP, used for `ldap://` URLs.
    Plain(FramedRead<ReadHalf<TcpStream>, LdapCodec>),
    /// TLS over TCP, used for `ldaps://` URLs.
    Tls(FramedRead<ReadHalf<TlsStream<TcpStream>>, LdapCodec>),
}
132
133impl fmt::Debug for LdapReadTransport {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        match self {
136            LdapReadTransport::Plain(_) => f
137                .debug_struct("LdapReadTransport")
138                .field("type", &"plain")
139                .finish(),
140            LdapReadTransport::Tls(_) => f
141                .debug_struct("LdapReadTransport")
142                .field("type", &"tls")
143                .finish(),
144        }
145    }
146}
147
/// Write half of a client connection, encoding [`LdapMsg`]s with [`LdapCodec`].
enum LdapWriteTransport {
    /// Unencrypted TCP, used for `ldap://` URLs.
    Plain(FramedWrite<WriteHalf<TcpStream>, LdapCodec>),
    /// TLS over TCP, used for `ldaps://` URLs.
    Tls(FramedWrite<WriteHalf<TlsStream<TcpStream>>, LdapCodec>),
}
152
153impl fmt::Debug for LdapWriteTransport {
154    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
155        match self {
156            LdapWriteTransport::Plain(_) => f
157                .debug_struct("LdapWriteTransport")
158                .field("type", &"plain")
159                .finish(),
160            LdapWriteTransport::Tls(_) => f
161                .debug_struct("LdapWriteTransport")
162                .field("type", &"tls")
163                .finish(),
164        }
165    }
166}
167
168impl LdapWriteTransport {
169    async fn send(&mut self, msg: LdapMsg) -> LdapResult<()> {
170        match self {
171            LdapWriteTransport::Plain(f) => f.send(msg).await.map_err(|e| {
172                info!(?e, "transport error");
173                LdapError::TransportWriteError
174            }),
175            LdapWriteTransport::Tls(f) => f.send(msg).await.map_err(|e| {
176                info!(?e, "transport error");
177                LdapError::TransportWriteError
178            }),
179        }
180    }
181}
182
183impl LdapReadTransport {
184    async fn next(&mut self) -> LdapResult<LdapMsg> {
185        match self {
186            LdapReadTransport::Plain(f) => f.next().await.transpose().map_err(|e| {
187                info!(?e, "transport error");
188                LdapError::TransportReadError
189            })?,
190            LdapReadTransport::Tls(f) => f.next().await.transpose().map_err(|e| {
191                info!(?e, "transport error");
192                LdapError::TransportReadError
193            })?,
194        }
195        .ok_or_else(|| {
196            info!("connection closed");
197            LdapError::TransportReadError
198        })
199    }
200}
201
/// A directory entry with attribute values decoded to strings.
///
/// When built via `From<LdapSearchResultEntry>`, attribute names are lowercased, values of
/// `objectclass` are lowercased, and values that are not valid UTF-8 are stored as URL-safe
/// base64 text.
#[derive(Debug, Deserialize, Serialize)]
pub struct LdapEntry {
    /// Distinguished name of the entry.
    pub dn: String,
    /// Attribute name mapped to its set of values.
    pub attrs: BTreeMap<String, BTreeSet<String>>,
}
207
208impl LdapEntry {
209    pub fn get_ava_single(&self, attr: &str) -> Option<&str> {
210        if let Some(ava) = self.attrs.get(attr) {
211            if ava.len() == 1 {
212                ava.iter().next().map(String::as_ref)
213            } else {
214                None
215            }
216        } else {
217            None
218        }
219    }
220
221    pub fn remove_ava_single(&mut self, attr: &str) -> Option<String> {
222        if let Some(ava) = self.attrs.remove(attr) {
223            if ava.len() == 1 {
224                ava.into_iter().next()
225            } else {
226                None
227            }
228        } else {
229            None
230        }
231    }
232
233    pub fn remove_ava(&mut self, attr: &str) -> Option<BTreeSet<String>> {
234        self.attrs.remove(attr)
235    }
236}
237
238impl From<LdapSearchResultEntry> for LdapEntry {
239    fn from(ent: LdapSearchResultEntry) -> Self {
240        let LdapSearchResultEntry { dn, attributes } = ent;
241
242        let attrs = attributes
243            .into_iter()
244            .map(|LdapPartialAttribute { atype, vals }| {
245                let atype = atype.to_lowercase();
246
247                let lower = atype == "objectclass";
248
249                let va = vals
250                    .into_iter()
251                    .map(|bin| {
252                        std::str::from_utf8(&bin)
253                            .map(|s| {
254                                if lower {
255                                    s.to_lowercase()
256                                } else {
257                                    s.to_string()
258                                }
259                            })
260                            .unwrap_or_else(|_| general_purpose::URL_SAFE.encode(&bin))
261                    })
262                    .collect();
263                (atype, va)
264            })
265            .collect();
266
267        LdapEntry { dn, attrs }
268    }
269}
270
/// Builder for an [`LdapClient`] connected to an `ldap://` or `ldaps://` URL.
#[derive(Debug, Clone)]
pub struct LdapClientBuilder<U> {
    /// Target URL; the builder's methods require `U: Borrow<Url>`.
    url: U,
    /// Timeout used while establishing the connection.
    timeout: Duration,
    /// The maximum LDAP packet size parsed during decoding.
    /// `None` means `ldap3_proto::DEFAULT_MAX_BER_SIZE` is used.
    max_ber_size: Option<usize>,
    /// TLS configuration for `ldaps://`; `None` means the platform verifier is used.
    rustls_client: Option<Arc<ClientConfig>>,
}
279
280impl<U: Borrow<Url>> LdapClientBuilder<U> {
281    pub fn new(url: U) -> Self {
282        Self {
283            url,
284            timeout: Duration::from_secs(30),
285            max_ber_size: None,
286            rustls_client: None,
287        }
288    }
289
290    pub fn set_tls_config(&mut self, config: Option<Arc<ClientConfig>>) {
291        self.rustls_client = config
292    }
293
294    /// set the rustls [`ClientConfig`] used to handle ldaps connections
295    ///
296    /// if not set uses the platform verifier
297    pub fn with_tls_config(mut self, config: ClientConfig) -> Self {
298        self.set_tls_config(Some(Arc::new(config)));
299        self
300    }
301
302    /// set rustls [`ClientConfig`] to one that does not verify the server certificate
303    pub fn danger_accept_invalid_certs(self) -> Self {
304        warn!("⚠️ CERTIFICATE VERIFICATION IS DISABLED. THIS IS DANGEROUS!!!!");
305        let yolo_cert_validator = Arc::new(YoloCertValidator);
306
307        let client_config = ClientConfig::builder()
308            .dangerous()
309            .with_custom_certificate_verifier(yolo_cert_validator)
310            .with_no_client_auth();
311        self.with_tls_config(client_config)
312    }
313
314    pub fn set_timeout(self, timeout: Duration) -> Self {
315        Self { timeout, ..self }
316    }
317
318    /// Set the maximum size of a decoded message
319    pub fn max_ber_size(self, max_ber_size: Option<usize>) -> Self {
320        Self {
321            max_ber_size,
322            ..self
323        }
324    }
325
326    #[tracing::instrument(level = "debug", skip_all)]
327    pub async fn build(self) -> LdapResult<LdapClient> {
328        let LdapClientBuilder {
329            url,
330            timeout,
331            max_ber_size,
332            rustls_client,
333        } = self;
334
335        let url = url.borrow();
336
337        info!(%url);
338        info!(?timeout);
339
340        // Check the scheme is ldap or ldaps
341        // for now, no ldapi support.
342        let need_tls = match url.scheme() {
343            "ldapi" => return Err(LdapError::LdapiNotSupported),
344            "cldap" => return Err(LdapError::UseCldapTool),
345            "ldap" => false,
346            "ldaps" => true,
347            _ => return Err(LdapError::InvalidUrl),
348        };
349
350        info!(%need_tls);
351        // get domain + port
352
353        // Do we have query params? Can we use them?
354        // https://ldap.com/ldap-urls/
355
356        // resolve to a set of socket addrs.
357        let addrs = url
358            .socket_addrs(|| Some(if need_tls { 636 } else { 389 }))
359            .map_err(|e| {
360                info!(?e, "resolver error");
361                LdapError::ResolverError
362            })?;
363
364        if addrs.is_empty() {
365            return Err(LdapError::ResolverError);
366        }
367
368        addrs.iter().for_each(|address| info!(?address));
369
370        let mut aiter = addrs.into_iter();
371
372        // Try for each to open, with a timeout.
373        let tcpstream = loop {
374            if let Some(addr) = aiter.next() {
375                let sleep = time::sleep(timeout);
376                tokio::pin!(sleep);
377                tokio::select! {
378                    maybe_stream = TcpStream::connect(addr) => {
379                        match maybe_stream {
380                            Ok(t) => {
381                                info!(?addr, "connection established");
382                                break t;
383                            }
384                            Err(e) => {
385                                info!(?addr, ?e, "error");
386                                continue;
387                            }
388                        }
389                    }
390                    _ = &mut sleep => {
391                        info!(?addr, "timeout");
392                        continue;
393                    }
394                }
395            } else {
396                return Err(LdapError::ConnectError);
397            }
398        };
399
400        // If they didn't set it in the builder then set it to the default
401        let max_ber_size = max_ber_size.unwrap_or(ldap3_proto::DEFAULT_MAX_BER_SIZE);
402
403        // If ldaps - start rustls
404        let (write_transport, read_transport) = if need_tls {
405            let tls_client_config = if let Some(client_config) = rustls_client {
406                client_config
407            } else {
408                Arc::new(ClientConfig::with_platform_verifier().map_err(|e| {
409                    error!(?e, "rustls");
410                    LdapError::TlsError
411                })?)
412            };
413
414            let tls_connector = TlsConnector::from(tls_client_config);
415
416            let server_name = match url.host() {
417                Some(Host::Domain(name)) => {
418                    ServerName::try_from(name.to_owned()).map_err(|err| {
419                        error!(?err, "server name invalid");
420                        LdapError::TlsError
421                    })?
422                }
423                Some(Host::Ipv4(addr)) => ServerName::from(addr),
424                Some(Host::Ipv6(addr)) => ServerName::from(addr),
425                None => {
426                    error!("url invalid");
427                    return Err(LdapError::TlsError);
428                }
429            };
430
431            let tlsstream = tls_connector
432                .connect(
433                    server_name,
434                    // Pin::new(&mut tcpstream)
435                    tcpstream,
436                )
437                .await
438                .map_err(|e| {
439                    error!(?e, "rustls");
440                    LdapError::TlsError
441                })?;
442
443            info!("tls configured");
444
445            let (r, w) = tokio::io::split(tlsstream);
446            (
447                LdapWriteTransport::Tls(FramedWrite::new(w, LdapCodec::default())),
448                LdapReadTransport::Tls(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
449            )
450        } else {
451            let (r, w) = tokio::io::split(tcpstream);
452            (
453                LdapWriteTransport::Plain(FramedWrite::new(w, LdapCodec::default())),
454                LdapReadTransport::Plain(FramedRead::new(r, LdapCodec::new(Some(max_ber_size)))),
455            )
456        };
457
458        let msg_counter = 1;
459
460        // Good to go - return ok!
461        Ok(LdapClient {
462            read_transport,
463            write_transport,
464            msg_counter,
465        })
466    }
467}
468
/// An established LDAP connection, split into separately framed read and write halves.
///
/// Created by [`LdapClientBuilder::build`].
#[derive(Debug)]
pub struct LdapClient {
    /// Decodes messages received from the server.
    read_transport: LdapReadTransport,
    /// Encodes messages sent to the server.
    write_transport: LdapWriteTransport,
    /// Message id for the next request; `build` initialises it to 1.
    msg_counter: i32,
}
475
476impl LdapClient {
477    fn get_next_msgid(&mut self) -> i32 {
478        let msgid = self.msg_counter;
479        self.msg_counter += 1;
480        msgid
481    }
482
483    #[tracing::instrument(level = "debug", skip_all)]
484    pub async fn bind<S: Into<String>>(&mut self, dn: S, pw: S) -> LdapResult<()> {
485        let dn = dn.into();
486        info!(%dn);
487        let msgid = self.get_next_msgid();
488
489        let msg = LdapMsg {
490            msgid,
491            op: LdapOp::BindRequest(LdapBindRequest {
492                dn,
493                cred: LdapBindCred::Simple(pw.into()),
494            }),
495            ctrl: vec![],
496        };
497
498        self.write_transport.send(msg).await?;
499
500        // Get the response
501        self.read_transport
502            .next()
503            .await
504            .and_then(|msg| match msg.op {
505                LdapOp::BindResponse(res) => {
506                    if res.res.code == LdapResultCode::Success {
507                        info!("bind success");
508                        Ok(())
509                    } else {
510                        info!(?res.res.code);
511                        Err(LdapError::from(res.res.code))
512                    }
513                }
514                op => {
515                    trace!(?op);
516                    Err(LdapError::InvalidProtocolState)
517                }
518            })
519    }
520
521    #[tracing::instrument(level = "debug", skip_all)]
522    pub async fn whoami(&mut self) -> LdapResult<Option<String>> {
523        let msgid = self.get_next_msgid();
524
525        let msg = LdapMsg {
526            msgid,
527            op: LdapOp::ExtendedRequest(Into::into(LdapWhoamiRequest {})),
528            ctrl: vec![],
529        };
530
531        self.write_transport.send(msg).await?;
532
533        self.read_transport
534            .next()
535            .await
536            .and_then(|msg| match msg.op {
537                LdapOp::ExtendedResponse(ler) => LdapWhoamiResponse::try_from(&ler)
538                    .map_err(|_| LdapError::InvalidProtocolState)
539                    .map(|res| res.dn),
540                op => {
541                    trace!(?op);
542                    Err(LdapError::InvalidProtocolState)
543                }
544            })
545    }
546}
547
548#[derive(Debug)]
549/// This should never be used for anything but testing, as it does no verification!
550struct YoloCertValidator;
551
552impl ServerCertVerifier for YoloCertValidator {
553    fn verify_server_cert(
554        &self,
555        _end_entity: &CertificateDer<'_>,
556        _intermediates: &[CertificateDer<'_>],
557        _server_name: &ServerName<'_>,
558        _ocsp_response: &[u8],
559        _now: UnixTime,
560    ) -> Result<ServerCertVerified, RustlsError> {
561        // Yolo.
562        Ok(ServerCertVerified::assertion())
563    }
564
565    fn verify_tls12_signature(
566        &self,
567        _message: &[u8],
568        _cert: &CertificateDer<'_>,
569        _dss: &DigitallySignedStruct,
570    ) -> Result<HandshakeSignatureValid, RustlsError> {
571        Ok(HandshakeSignatureValid::assertion())
572    }
573
574    fn verify_tls13_signature(
575        &self,
576        _message: &[u8],
577        _cert: &CertificateDer<'_>,
578        _dss: &DigitallySignedStruct,
579    ) -> Result<HandshakeSignatureValid, RustlsError> {
580        Ok(HandshakeSignatureValid::assertion())
581    }
582
583    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
584        vec![
585            SignatureScheme::RSA_PKCS1_SHA384,
586            SignatureScheme::ECDSA_NISTP256_SHA256,
587            SignatureScheme::RSA_PKCS1_SHA256,
588            SignatureScheme::ECDSA_NISTP384_SHA384,
589            SignatureScheme::RSA_PKCS1_SHA512,
590            SignatureScheme::ECDSA_NISTP521_SHA512,
591        ]
592    }
593}
594
/// Exercises the builder's defaults and setters.
///
/// Doesn't test the actual *build* step because that requires a live LDAP server.
#[test]
fn test_ldapclient_builder() {
    // `expect` rather than `unwrap`: the crate denies `clippy::unwrap_used` but explicitly
    // allows `expect`.
    let url = Url::parse("ldap://ldap.example.com:389").expect("static test URL must parse");

    let client = LdapClientBuilder::new(&url);
    assert_eq!(client.timeout, Duration::from_secs(30));
    assert!(client.max_ber_size.is_none());
    assert!(client.rustls_client.is_none());

    let client = client.max_ber_size(Some(1234567));
    let mut client = client.set_timeout(Duration::from_secs(60));
    assert_eq!(client.timeout, Duration::from_secs(60));
    assert_eq!(client.max_ber_size, Some(1234567));
    assert!(client.rustls_client.is_none());

    // Clearing an unset TLS config leaves it unset.
    client.set_tls_config(None);
    assert!(client.rustls_client.is_none());
}