Skip to main content

ldap_client/
client.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
5use std::time::Duration;
6
7use futures_util::{SinkExt, StreamExt};
8use rustls::ClientConfig;
9use rustls_pki_types::ServerName;
10use secrecy::{ExposeSecret, SecretString};
11use tokio::net::TcpStream;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14use tracing::debug;
15
16use ldap_client_ber::LdapCodec;
17use ldap_client_proto::{
18    AddRequest, BindAuthentication, BindRequest, CompareRequest, Control, DerefAliases,
19    ExtendedRequest, ExtendedResponse, Filter, LdapMessage, LdapOperation,
20    LdapResult as ProtoLdapResult, LdapScheme, LdapUrl, MessageId, ModifyDnRequest, ModifyRequest,
21    PAGED_RESULTS_OID, PagedResultsControl, ResultCode, SearchRequest, SearchResultEntry,
22    SearchScope,
23};
24
25use crate::Error;
26use crate::conn::{self, LdapStream};
27
/// OID sent in the StartTLS extended request (see `perform_start_tls`).
const STARTTLS_OID: &str = "1.3.6.1.4.1.1466.20037";
/// OID of the Notice of Disconnection unsolicited notification.
const NOTICE_OF_DISCONNECTION_OID: &str = "1.3.6.1.4.1.1466.20036";
/// Default for both the connect timeout and the request timeout (30 s).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default maximum LDAP message size accepted by the codec (10 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: u32 = 10 * 1024 * 1024;
// NOTE(review): not referenced in the visible part of the file; presumably caps
// the number of entries collected per search — confirm against its users.
const MAX_SEARCH_ENTRIES: usize = 500_000;

/// Handler for unsolicited notifications other than Notice of Disconnection.
pub type UnsolicitedHandler = Arc<dyn Fn(&ExtendedResponse) + Send + Sync>;
36
37fn default_unsolicited_handler() -> UnsolicitedHandler {
38    Arc::new(|resp| {
39        tracing::debug!(
40            oid = resp.oid.as_deref().unwrap_or("<none>"),
41            "received unsolicited notification from server"
42        );
43    })
44}
45
/// Everything [`Client::search_full`] collected from one search: the entries,
/// the referral URLs and the response controls.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Entries returned by the search.
    pub entries: Vec<SearchResultEntry>,
    /// Referral URLs gathered while collecting the results.
    pub referrals: Vec<String>,
    /// Response controls returned with the results (for example a
    /// paged-results control).
    pub controls: Vec<Control>,
}
52
/// How the client secures its LDAP connection.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Transport {
    /// Unencrypted TCP (selected for `ldap://` URLs).
    Plain,
    /// TLS from the first byte (selected for `ldaps://` URLs).
    Tls,
    /// Plain TCP upgraded to TLS through the StartTLS extended operation.
    StartTls,
}
59
/// Policy for what the client does when the server answers with a referral.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ReferralPolicy {
    /// Treat the referral as success and drop it (the default policy).
    #[default]
    Ignore,
    /// Hand the referral back to the caller as `Error::Referral`.
    Return,
    /// Chase referrals automatically, giving up after `hop_limit` hops.
    Follow { hop_limit: u8 },
}

impl ReferralPolicy {
    /// Hop budget used by [`ReferralPolicy::follow`].
    const DEFAULT_HOP_LIMIT: u8 = 10;

    /// Shorthand for [`ReferralPolicy::Follow`] with a limit of 10 hops.
    pub fn follow() -> Self {
        ReferralPolicy::Follow {
            hop_limit: Self::DEFAULT_HOP_LIMIT,
        }
    }
}
78
/// Credentials for [`Client::bind`].
///
/// The `Simple` variant only borrows its DN and password, so a value of this
/// type cannot outlive them.
pub enum BindCredentials<'a> {
    /// Simple bind with a DN and password.
    Simple {
        /// Bind DN.
        dn: &'a str,
        /// Bind password, kept in a [`SecretString`].
        password: &'a SecretString,
    },
    /// Re-bind using the pre-configured service account.
    ServiceAccount,
    /// SASL EXTERNAL bind (client certificate authentication).
    SaslExternal,
}
91
/// Builder for [`Client`] connections.
///
/// Create one with [`ClientBuilder::new`] or [`ClientBuilder::from_url`],
/// adjust settings with the chained setters, then call
/// [`ClientBuilder::connect`].
pub struct ClientBuilder {
    host: String,
    port: u16,
    /// Connection security; `Transport::Plain` unless changed.
    transport: Transport,
    /// TLS settings; `None` falls back to `conn::default_tls_config()` in `connect`.
    tls_config: Option<Arc<ClientConfig>>,
    /// Timeout for the TCP connect and for the TLS handshake.
    connect_timeout: Duration,
    /// Timeout for each server response (including the StartTLS response).
    request_timeout: Duration,
    /// Maximum accepted LDAP message size, passed to `LdapCodec`.
    max_message_size: u32,
    /// Default base DN, substituted when a search is given an empty base DN.
    base_dn: Option<String>,
    /// Service account DN used by `Client::rebind_service_account` and on reconnect.
    service_account_dn: Option<String>,
    /// Password for the service account.
    service_account_password: Option<SecretString>,
    /// How referral results are handled.
    referral_policy: ReferralPolicy,
    /// Callback for unsolicited notifications other than Notice of Disconnection.
    unsolicited_handler: UnsolicitedHandler,
}
106
107impl ClientBuilder {
108    pub fn new(host: impl Into<String>, port: u16) -> Self {
109        Self {
110            host: host.into(),
111            port,
112            transport: Transport::Plain,
113            tls_config: None,
114            connect_timeout: DEFAULT_TIMEOUT,
115            request_timeout: DEFAULT_TIMEOUT,
116            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
117            base_dn: None,
118            service_account_dn: None,
119            service_account_password: None,
120            referral_policy: ReferralPolicy::default(),
121            unsolicited_handler: default_unsolicited_handler(),
122        }
123    }
124
125    pub fn from_url(url: &str) -> Result<Self, Error> {
126        let parsed = LdapUrl::parse(url).map_err(|e| Error::InvalidUrl(format!("{e}")))?;
127
128        let transport = match parsed.scheme {
129            LdapScheme::Ldap => Transport::Plain,
130            LdapScheme::Ldaps => Transport::Tls,
131        };
132
133        let port = parsed.effective_port();
134        Ok(Self {
135            host: parsed.host,
136            port,
137            transport,
138            tls_config: None,
139            connect_timeout: DEFAULT_TIMEOUT,
140            request_timeout: DEFAULT_TIMEOUT,
141            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
142            base_dn: parsed.base_dn,
143            service_account_dn: None,
144            service_account_password: None,
145            referral_policy: ReferralPolicy::default(),
146            unsolicited_handler: default_unsolicited_handler(),
147        })
148    }
149
150    pub fn transport(mut self, transport: Transport) -> Self {
151        self.transport = transport;
152        self
153    }
154
155    /// Set the TLS configuration from a pre-built [`rustls::ClientConfig`].
156    ///
157    /// For a higher-level API, see [`tls`](Self::tls).
158    pub fn tls_config(mut self, config: Arc<ClientConfig>) -> Self {
159        self.tls_config = Some(config);
160        self
161    }
162
163    /// Build and set the TLS configuration from a [`TlsConfig`](crate::tls_config::TlsConfig).
164    ///
165    /// Returns `Err` if certificate loading fails.
166    pub fn tls(mut self, config: crate::tls_config::TlsConfig) -> Result<Self, Error> {
167        self.tls_config = Some(config.build()?);
168        Ok(self)
169    }
170
171    /// Set both connect and request timeouts.
172    pub fn timeout(mut self, timeout: Duration) -> Self {
173        self.connect_timeout = timeout;
174        self.request_timeout = timeout;
175        self
176    }
177
178    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
179        self.connect_timeout = timeout;
180        self
181    }
182
183    pub fn request_timeout(mut self, timeout: Duration) -> Self {
184        self.request_timeout = timeout;
185        self
186    }
187
188    pub fn base_dn(mut self, base_dn: impl Into<String>) -> Self {
189        self.base_dn = Some(base_dn.into());
190        self
191    }
192
193    pub fn service_account(mut self, dn: impl Into<String>, password: SecretString) -> Self {
194        self.service_account_dn = Some(dn.into());
195        self.service_account_password = Some(password);
196        self
197    }
198
199    pub fn referral_policy(mut self, policy: ReferralPolicy) -> Self {
200        self.referral_policy = policy;
201        self
202    }
203
204    /// Set the maximum accepted LDAP message size (10 MiB by default).
205    pub fn max_message_size(mut self, max: u32) -> Self {
206        self.max_message_size = max;
207        self
208    }
209
210    /// Register a callback for unsolicited notifications (message-id 0) other
211    /// than the Notice of Disconnection. The default handler logs the event at
212    /// `debug` level.
213    pub fn on_unsolicited_notification(
214        mut self,
215        handler: impl Fn(&ExtendedResponse) + Send + Sync + 'static,
216    ) -> Self {
217        self.unsolicited_handler = Arc::new(handler);
218        self
219    }
220
221    pub async fn connect(self) -> Result<Client, Error> {
222        let addr = format_addr(&self.host, self.port);
223        debug!(addr = %addr, transport = ?self.transport, "connecting");
224
225        let tcp = match tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr)).await
226        {
227            Ok(Ok(tcp)) => tcp,
228            Ok(Err(e)) => return Err(Error::Io(e)),
229            Err(_) => return Err(Error::Timeout),
230        };
231        tcp.set_nodelay(true)?;
232
233        let tls_config = self
234            .tls_config
235            .clone()
236            .unwrap_or_else(|| Arc::new(conn::default_tls_config()));
237
238        let stream = match self.transport {
239            Transport::Plain => LdapStream::Plain(tcp),
240            Transport::Tls | Transport::StartTls => {
241                let server_name = ServerName::try_from(self.host.clone())
242                    .map_err(|e| Error::InvalidUrl(format!("invalid server name: {e}")))?;
243
244                if self.transport == Transport::Tls {
245                    conn::upgrade_to_tls(tcp, server_name, tls_config.clone(), self.connect_timeout)
246                        .await?
247                } else {
248                    perform_start_tls(
249                        tcp,
250                        server_name,
251                        tls_config.clone(),
252                        self.request_timeout,
253                        self.max_message_size,
254                        self.connect_timeout,
255                    )
256                    .await?
257                }
258            }
259        };
260
261        let start_id = if self.transport == Transport::StartTls {
262            2
263        } else {
264            1
265        };
266        let codec = LdapCodec::new().with_max_message_size(self.max_message_size);
267        Ok(Client {
268            framed: Mutex::new(Framed::new(stream, codec)),
269            next_id: AtomicI32::new(start_id),
270            connected: AtomicBool::new(true),
271            request_timeout: self.request_timeout,
272            max_message_size: self.max_message_size,
273            base_dn: self.base_dn,
274            referral_policy: self.referral_policy,
275            last_reconnect: Mutex::new(None),
276            unsolicited_handler: self.unsolicited_handler,
277            host: self.host,
278            port: self.port,
279            transport: self.transport,
280            tls_config,
281            connect_timeout: self.connect_timeout,
282            service_account_dn: self.service_account_dn,
283            service_account_password: self.service_account_password,
284        })
285    }
286}
287
288async fn perform_start_tls(
289    tcp: TcpStream,
290    server_name: ServerName<'static>,
291    tls_config: Arc<ClientConfig>,
292    timeout: Duration,
293    max_message_size: u32,
294    tls_timeout: Duration,
295) -> Result<LdapStream, Error> {
296    let mut framed = Framed::new(
297        tcp,
298        LdapCodec::new().with_max_message_size(max_message_size),
299    );
300
301    let msg = LdapMessage {
302        message_id: MessageId(1),
303        operation: LdapOperation::ExtendedRequest(ExtendedRequest {
304            oid: STARTTLS_OID.to_string(),
305            value: None,
306        }),
307        controls: vec![],
308    };
309    framed.send(msg.encode()).await.map_err(ber_to_io)?;
310
311    let response = match tokio::time::timeout(timeout, framed.next()).await {
312        Ok(Some(Ok(frame))) => LdapMessage::decode(&frame).map_err(Error::Proto)?,
313        Ok(Some(Err(e))) => return Err(ber_to_io(e)),
314        Ok(None) => return Err(Error::ConnectionClosed),
315        Err(_) => return Err(Error::Timeout),
316    };
317
318    match response.operation {
319        LdapOperation::ExtendedResponse(resp) if resp.result.code.is_success() => {}
320        LdapOperation::ExtendedResponse(resp) => {
321            return Err(Error::StartTls(resp.result.diagnostic_message));
322        }
323        _ => return Err(Error::StartTls("unexpected response".into())),
324    }
325
326    let parts = framed.into_parts();
327    if !parts.read_buf.is_empty() || !parts.write_buf.is_empty() {
328        return Err(Error::StartTls(
329            "unexpected buffered data before TLS handshake".into(),
330        ));
331    }
332    let tcp = parts.io;
333    conn::upgrade_to_tls(tcp, server_name, tls_config, tls_timeout).await
334}
335
/// LDAP message framing over the (plain or TLS) connection stream.
type FramedLdap = Framed<LdapStream, LdapCodec>;

/// Minimum spacing between `Client::reconnect` attempts; a call arriving
/// sooner sleeps for the remainder before connecting.
const MIN_RECONNECT_INTERVAL: Duration = Duration::from_secs(1);
339
/// An LDAP connection plus the settings needed to re-establish it.
///
/// Built by [`ClientBuilder::connect`]. All requests share one connection
/// behind an async mutex that is held for each send/receive exchange, so
/// operations issued concurrently are serialized.
pub struct Client {
    /// Framed connection; locked for the duration of each request/response exchange.
    framed: Mutex<FramedLdap>,
    /// Next message ID to hand out (see `next_message_id`).
    next_id: AtomicI32,
    /// Connection-health flag reported by `is_connected`.
    connected: AtomicBool,
    /// How long to wait for each server response.
    request_timeout: Duration,
    /// Maximum accepted LDAP message size for the codec.
    max_message_size: u32,
    /// Default base DN substituted for an empty search base.
    base_dn: Option<String>,
    /// How referral results are handled.
    referral_policy: ReferralPolicy,
    /// When the last `reconnect` attempt started; used to rate-limit reconnects.
    last_reconnect: Mutex<Option<tokio::time::Instant>>,
    /// Callback for unsolicited notifications other than Notice of Disconnection.
    unsolicited_handler: UnsolicitedHandler,
    // Fields stored for reconnect.
    host: String,
    port: u16,
    transport: Transport,
    tls_config: Arc<ClientConfig>,
    connect_timeout: Duration,
    service_account_dn: Option<String>,
    service_account_password: Option<SecretString>,
}
359
/// Format `host` and `port` as a `host:port` string for `TcpStream::connect`.
///
/// A host containing `:` (an IPv6 literal) is wrapped in brackets, so `::1`
/// becomes `[::1]:389`. A host that is already bracketed, such as `[::1]`, is
/// used as-is rather than wrapped a second time, which would produce the
/// invalid address `[[::1]]:389`.
fn format_addr(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    if host.contains(':') && !already_bracketed {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
367
368impl Client {
369    fn next_message_id(&self) -> MessageId {
370        let id = self
371            .next_id
372            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
373                if n > 0 && n < i32::MAX - 1 {
374                    Some(n + 1)
375                } else {
376                    Some(2) // wrap-around: store 2, current caller keeps n
377                }
378            });
379        // fetch_update with an always-Some closure never fails.
380        MessageId(id.unwrap())
381    }
382
383    fn resolve_base_dn(&self, base_dn: String) -> String {
384        if base_dn.is_empty()
385            && let Some(default) = &self.base_dn
386        {
387            return default.clone();
388        }
389        base_dn
390    }
391
    /// Report whether the client currently considers its connection usable.
    ///
    /// This only reads the `connected` flag; it does not probe the socket.
    /// The flag is set to `true` by `ClientBuilder::connect` and `reconnect`,
    /// and cleared by `reconnect` when the service-account re-bind fails. It
    /// is also handed to `send_msg` / `recv_msg`, which presumably clear it on
    /// I/O failure (not visible here — confirm).
    pub fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Relaxed)
    }
395
396    pub async fn reconnect(&self) -> Result<(), Error> {
397        {
398            let mut last = self.last_reconnect.lock().await;
399            if let Some(prev) = *last {
400                let elapsed = prev.elapsed();
401                if elapsed < MIN_RECONNECT_INTERVAL {
402                    tokio::time::sleep(MIN_RECONNECT_INTERVAL - elapsed).await;
403                }
404            }
405            *last = Some(tokio::time::Instant::now());
406        }
407
408        let addr = format_addr(&self.host, self.port);
409        debug!(addr = %addr, transport = ?self.transport, "reconnecting");
410
411        let tcp = match tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr)).await
412        {
413            Ok(Ok(tcp)) => tcp,
414            Ok(Err(e)) => return Err(Error::Io(e)),
415            Err(_) => return Err(Error::Timeout),
416        };
417        tcp.set_nodelay(true)?;
418
419        let stream = match self.transport {
420            Transport::Plain => LdapStream::Plain(tcp),
421            Transport::Tls | Transport::StartTls => {
422                let server_name = ServerName::try_from(self.host.clone())
423                    .map_err(|e| Error::InvalidUrl(format!("invalid server name: {e}")))?;
424
425                if self.transport == Transport::Tls {
426                    conn::upgrade_to_tls(
427                        tcp,
428                        server_name,
429                        self.tls_config.clone(),
430                        self.connect_timeout,
431                    )
432                    .await?
433                } else {
434                    perform_start_tls(
435                        tcp,
436                        server_name,
437                        self.tls_config.clone(),
438                        self.request_timeout,
439                        self.max_message_size,
440                        self.connect_timeout,
441                    )
442                    .await?
443                }
444            }
445        };
446
447        let start_id = if self.transport == Transport::StartTls {
448            2
449        } else {
450            1
451        };
452
453        let mut framed = self.framed.lock().await;
454        *framed = Framed::new(
455            stream,
456            LdapCodec::new().with_max_message_size(self.max_message_size),
457        );
458        self.next_id.store(start_id, Ordering::Relaxed);
459        self.connected.store(true, Ordering::Relaxed);
460        drop(framed);
461
462        if self.service_account_dn.is_some()
463            && let Err(e) = self.rebind_service_account().await
464        {
465            self.connected.store(false, Ordering::Relaxed);
466            return Err(e);
467        }
468        Ok(())
469    }
470
471    pub async fn rebind_service_account(&self) -> Result<(), Error> {
472        let dn = self.service_account_dn.as_deref().ok_or_else(|| {
473            Error::Proto(ldap_client_proto::ProtoError::Protocol(
474                "no service account configured".into(),
475            ))
476        })?;
477        let password = self.service_account_password.as_ref().ok_or_else(|| {
478            Error::Proto(ldap_client_proto::ProtoError::Protocol(
479                "no service account password configured".into(),
480            ))
481        })?;
482        self.simple_bind(dn, password).await
483    }
484
    /// Send `operation` without request controls and return the response
    /// message; shorthand for `request_with_controls(operation, vec![])`.
    async fn request(&self, operation: LdapOperation) -> Result<LdapMessage, Error> {
        self.request_with_controls(operation, vec![]).await
    }
488
489    async fn request_with_controls(
490        &self,
491        operation: LdapOperation,
492        controls: Vec<Control>,
493    ) -> Result<LdapMessage, Error> {
494        let message_id = self.next_message_id();
495        let msg = LdapMessage {
496            message_id,
497            operation,
498            controls,
499        };
500        let data = msg.encode();
501
502        let mut framed = self.framed.lock().await;
503        send_msg(&mut framed, data, &self.connected).await?;
504        recv_msg(
505            &mut framed,
506            self.request_timeout,
507            &self.connected,
508            &self.unsolicited_handler,
509        )
510        .await
511    }
512
513    pub async fn simple_bind(
514        &self,
515        dn: impl Into<String>,
516        password: &SecretString,
517    ) -> Result<(), Error> {
518        if self.transport == Transport::Plain {
519            tracing::warn!(
520                "simple bind over plain (unencrypted) connection; credentials are sent in cleartext"
521            );
522        }
523        let op = LdapOperation::BindRequest(BindRequest {
524            version: 3,
525            name: dn.into(),
526            authentication: BindAuthentication::Simple(zeroize::Zeroizing::new(
527                password.expose_secret().as_bytes().to_vec(),
528            )),
529        });
530
531        let response = self.request(op).await?;
532        match response.operation {
533            LdapOperation::BindResponse(resp) => check_result(&resp.result, self.referral_policy),
534            _ => Err(unexpected_response("BindResponse")),
535        }
536    }
537
538    pub async fn search(
539        &self,
540        base_dn: impl Into<String>,
541        scope: SearchScope,
542        filter: Filter,
543        attrs: Vec<String>,
544    ) -> Result<Vec<SearchResultEntry>, Error> {
545        let (entries, _controls) = self
546            .search_with_controls(base_dn, scope, filter, attrs, vec![])
547            .await?;
548        Ok(entries)
549    }
550
551    pub async fn search_with_controls(
552        &self,
553        base_dn: impl Into<String>,
554        scope: SearchScope,
555        filter: Filter,
556        attrs: Vec<String>,
557        controls: Vec<Control>,
558    ) -> Result<(Vec<SearchResultEntry>, Vec<Control>), Error> {
559        let message_id = self.next_message_id();
560        let msg = LdapMessage {
561            message_id,
562            operation: LdapOperation::SearchRequest(SearchRequest {
563                base_dn: self.resolve_base_dn(base_dn.into()),
564                scope,
565                deref_aliases: DerefAliases::NeverDerefAliases,
566                size_limit: 0,
567                time_limit: 0,
568                types_only: false,
569                filter,
570                attributes: attrs,
571            }),
572            controls,
573        };
574        let data = msg.encode();
575
576        let mut framed = self.framed.lock().await;
577        send_msg(&mut framed, data, &self.connected).await?;
578
579        let mut entries = Vec::new();
580        let collected = collect_search_results(
581            &mut framed,
582            self.request_timeout,
583            &self.connected,
584            &mut entries,
585            self.referral_policy,
586            &self.unsolicited_handler,
587        )
588        .await?;
589
590        Ok((entries, collected.controls))
591    }
592
593    pub async fn search_full(
594        &self,
595        base_dn: impl Into<String>,
596        scope: SearchScope,
597        filter: Filter,
598        attrs: Vec<String>,
599        controls: Vec<Control>,
600    ) -> Result<SearchResult, Error> {
601        let message_id = self.next_message_id();
602        let msg = LdapMessage {
603            message_id,
604            operation: LdapOperation::SearchRequest(SearchRequest {
605                base_dn: self.resolve_base_dn(base_dn.into()),
606                scope,
607                deref_aliases: DerefAliases::NeverDerefAliases,
608                size_limit: 0,
609                time_limit: 0,
610                types_only: false,
611                filter,
612                attributes: attrs,
613            }),
614            controls,
615        };
616        let data = msg.encode();
617
618        let mut framed = self.framed.lock().await;
619        send_msg(&mut framed, data, &self.connected).await?;
620
621        let mut entries = Vec::new();
622        let collected = collect_search_results(
623            &mut framed,
624            self.request_timeout,
625            &self.connected,
626            &mut entries,
627            self.referral_policy,
628            &self.unsolicited_handler,
629        )
630        .await?;
631
632        Ok(SearchResult {
633            entries,
634            referrals: collected.referral_urls,
635            controls: collected.controls,
636        })
637    }
638
639    pub async fn search_paged(
640        &self,
641        base_dn: &str,
642        scope: SearchScope,
643        filter: Filter,
644        attrs: Vec<String>,
645        page_size: i32,
646    ) -> Result<Vec<SearchResultEntry>, Error> {
647        const MAX_PAGED_ROUNDS: usize = 100_000;
648        let resolved_base = self.resolve_base_dn(base_dn.to_string());
649        let mut all_entries = Vec::new();
650        let mut cookie = Vec::new();
651        let mut prev_cookie = Vec::new();
652
653        for _ in 0..MAX_PAGED_ROUNDS {
654            let paged =
655                PagedResultsControl::new(page_size).with_cookie(std::mem::take(&mut cookie));
656            let controls = vec![paged.to_control()];
657
658            let message_id = self.next_message_id();
659            let msg = LdapMessage {
660                message_id,
661                operation: LdapOperation::SearchRequest(SearchRequest {
662                    base_dn: resolved_base.clone(),
663                    scope,
664                    deref_aliases: DerefAliases::NeverDerefAliases,
665                    size_limit: 0,
666                    time_limit: 0,
667                    types_only: false,
668                    filter: filter.clone(),
669                    attributes: attrs.clone(),
670                }),
671                controls,
672            };
673            let data = msg.encode();
674
675            let mut framed = self.framed.lock().await;
676            send_msg(&mut framed, data, &self.connected).await?;
677
678            let collected = collect_search_results(
679                &mut framed,
680                self.request_timeout,
681                &self.connected,
682                &mut all_entries,
683                self.referral_policy,
684                &self.unsolicited_handler,
685            )
686            .await?;
687
688            let new_cookie = collected
689                .controls
690                .iter()
691                .find(|c| c.oid == PAGED_RESULTS_OID)
692                .and_then(|c| PagedResultsControl::from_control(c).ok())
693                .map(|p| p.cookie);
694
695            match new_cookie {
696                Some(c) if !c.is_empty() => {
697                    if c == prev_cookie {
698                        // Server returned the same cookie twice. Abort to prevent an infinite loop.
699                        break;
700                    }
701                    prev_cookie = c.clone();
702                    cookie = c;
703                }
704                _ => break,
705            }
706        }
707
708        Ok(all_entries)
709    }
710
711    pub fn search_paged_stream(
712        &self,
713        base_dn: &str,
714        scope: SearchScope,
715        filter: Filter,
716        attrs: Vec<String>,
717        page_size: i32,
718    ) -> PagedSearch<'_> {
719        PagedSearch {
720            client: self,
721            base_dn: self.resolve_base_dn(base_dn.to_string()),
722            scope,
723            filter,
724            attrs,
725            page_size,
726            cookie: Vec::new(),
727            done: false,
728        }
729    }
730
731    pub async fn add(
732        &self,
733        dn: impl Into<String>,
734        attrs: Vec<ldap_client_proto::PartialAttribute>,
735    ) -> Result<(), Error> {
736        let dn = dn.into();
737        let mut chased: Option<Client> = None;
738        loop {
739            let client = chased.as_ref().unwrap_or(self);
740            let op = LdapOperation::AddRequest(AddRequest {
741                dn: dn.clone(),
742                attributes: attrs.clone(),
743            });
744            let response = client.request(op).await?;
745            match response.operation {
746                LdapOperation::AddResponse(result) => match try_chase(client, &result).await {
747                    Chase::Ok => return Ok(()),
748                    Chase::Follow(c) => {
749                        chased = Some(*c);
750                        continue;
751                    }
752                    Chase::Err(e) => return Err(e),
753                },
754                _ => return Err(unexpected_response("AddResponse")),
755            }
756        }
757    }
758
759    pub async fn modify(
760        &self,
761        dn: impl Into<String>,
762        changes: Vec<ldap_client_proto::Modification>,
763    ) -> Result<(), Error> {
764        let dn = dn.into();
765        let mut chased: Option<Client> = None;
766        loop {
767            let client = chased.as_ref().unwrap_or(self);
768            let op = LdapOperation::ModifyRequest(ModifyRequest {
769                dn: dn.clone(),
770                changes: changes.clone(),
771            });
772            let response = client.request(op).await?;
773            match response.operation {
774                LdapOperation::ModifyResponse(result) => match try_chase(client, &result).await {
775                    Chase::Ok => return Ok(()),
776                    Chase::Follow(c) => {
777                        chased = Some(*c);
778                        continue;
779                    }
780                    Chase::Err(e) => return Err(e),
781                },
782                _ => return Err(unexpected_response("ModifyResponse")),
783            }
784        }
785    }
786
787    pub async fn delete(&self, dn: impl Into<String>) -> Result<(), Error> {
788        let dn = dn.into();
789        let mut chased: Option<Client> = None;
790        loop {
791            let client = chased.as_ref().unwrap_or(self);
792            let op = LdapOperation::DeleteRequest(dn.clone());
793            let response = client.request(op).await?;
794            match response.operation {
795                LdapOperation::DeleteResponse(result) => match try_chase(client, &result).await {
796                    Chase::Ok => return Ok(()),
797                    Chase::Follow(c) => {
798                        chased = Some(*c);
799                        continue;
800                    }
801                    Chase::Err(e) => return Err(e),
802                },
803                _ => return Err(unexpected_response("DeleteResponse")),
804            }
805        }
806    }
807
808    pub async fn compare(
809        &self,
810        dn: impl Into<String>,
811        attr: impl Into<String>,
812        value: impl AsRef<[u8]>,
813    ) -> Result<bool, Error> {
814        let dn = dn.into();
815        let attr = attr.into();
816        let value = value.as_ref().to_vec();
817        let mut chased: Option<Client> = None;
818        loop {
819            let client = chased.as_ref().unwrap_or(self);
820            let op = LdapOperation::CompareRequest(CompareRequest {
821                dn: dn.clone(),
822                attr: attr.clone(),
823                value: value.clone(),
824            });
825            let response = client.request(op).await?;
826            match response.operation {
827                LdapOperation::CompareResponse(result) => {
828                    use ldap_client_proto::ResultCode;
829                    match result.code {
830                        ResultCode::CompareTrue => return Ok(true),
831                        ResultCode::CompareFalse => return Ok(false),
832                        _ => match try_chase(client, &result).await {
833                            Chase::Ok => return Err(Error::ldap(&result)),
834                            Chase::Follow(c) => {
835                                chased = Some(*c);
836                                continue;
837                            }
838                            Chase::Err(e) => return Err(e),
839                        },
840                    }
841                }
842                _ => return Err(unexpected_response("CompareResponse")),
843            }
844        }
845    }
846
847    pub async fn modify_dn(
848        &self,
849        dn: impl Into<String>,
850        new_rdn: impl Into<String>,
851        delete_old_rdn: bool,
852        new_superior: Option<String>,
853    ) -> Result<(), Error> {
854        let dn = dn.into();
855        let new_rdn = new_rdn.into();
856        let mut chased: Option<Client> = None;
857        loop {
858            let client = chased.as_ref().unwrap_or(self);
859            let op = LdapOperation::ModifyDnRequest(ModifyDnRequest {
860                dn: dn.clone(),
861                new_rdn: new_rdn.clone(),
862                delete_old_rdn,
863                new_superior: new_superior.clone(),
864            });
865            let response = client.request(op).await?;
866            match response.operation {
867                LdapOperation::ModifyDnResponse(result) => match try_chase(client, &result).await {
868                    Chase::Ok => return Ok(()),
869                    Chase::Follow(c) => {
870                        chased = Some(*c);
871                        continue;
872                    }
873                    Chase::Err(e) => return Err(e),
874                },
875                _ => return Err(unexpected_response("ModifyDnResponse")),
876            }
877        }
878    }
879
880    pub async fn extended(
881        &self,
882        oid: impl Into<String>,
883        value: Option<Vec<u8>>,
884    ) -> Result<ldap_client_proto::ExtendedResponse, Error> {
885        let oid = oid.into();
886        let mut chased: Option<Client> = None;
887        loop {
888            let client = chased.as_ref().unwrap_or(self);
889            let op = LdapOperation::ExtendedRequest(ExtendedRequest {
890                oid: oid.clone(),
891                value: value.clone(),
892            });
893            let response = client.request(op).await?;
894            match response.operation {
895                LdapOperation::ExtendedResponse(resp) => {
896                    match try_chase(client, &resp.result).await {
897                        Chase::Ok => return Ok(resp),
898                        Chase::Follow(c) => {
899                            chased = Some(*c);
900                            continue;
901                        }
902                        Chase::Err(e) => return Err(e),
903                    }
904                }
905                _ => return Err(unexpected_response("ExtendedResponse")),
906            }
907        }
908    }
909
910    pub async fn who_am_i(&self) -> Result<Option<String>, Error> {
911        let resp = self.extended("1.3.6.1.4.1.4203.1.11.3", None).await?;
912        Ok(resp.value.map(|v| String::from_utf8_lossy(&v).into_owned()))
913    }
914
915    pub async fn search_one(
916        &self,
917        base_dn: impl Into<String>,
918        scope: SearchScope,
919        filter: Filter,
920        attrs: Vec<String>,
921    ) -> Result<Option<SearchResultEntry>, Error> {
922        let message_id = self.next_message_id();
923        let msg = LdapMessage {
924            message_id,
925            operation: LdapOperation::SearchRequest(SearchRequest {
926                base_dn: self.resolve_base_dn(base_dn.into()),
927                scope,
928                deref_aliases: DerefAliases::NeverDerefAliases,
929                size_limit: 2,
930                time_limit: self.request_timeout.as_secs() as i32,
931                types_only: false,
932                filter,
933                attributes: attrs,
934            }),
935            controls: vec![],
936        };
937        let data = msg.encode();
938
939        let mut framed = self.framed.lock().await;
940        send_msg(&mut framed, data, &self.connected).await?;
941
942        let mut entries = Vec::new();
943        let result = collect_search_results(
944            &mut framed,
945            self.request_timeout,
946            &self.connected,
947            &mut entries,
948            self.referral_policy,
949            &self.unsolicited_handler,
950        )
951        .await;
952
953        // SizeLimitExceeded means the server stopped because of our size_limit=2,
954        // which indicates multiple results exist.
955        if let Err(Error::Ldap { code, .. }) = &result
956            && *code == ResultCode::SizeLimitExceeded
957        {
958            return Err(Error::MultipleResults);
959        }
960        result?;
961
962        match entries.len() {
963            0 => Ok(None),
964            1 => Ok(Some(entries.into_iter().next().unwrap())),
965            _ => Err(Error::MultipleResults),
966        }
967    }
968
969    pub async fn root_dse(&self) -> Result<SearchResultEntry, Error> {
970        let entries = self
971            .search(
972                "",
973                SearchScope::BaseObject,
974                Filter::present("objectClass"),
975                vec!["*".into(), "+".into()],
976            )
977            .await?;
978        entries.into_iter().next().ok_or_else(|| {
979            Error::Proto(ldap_client_proto::ProtoError::Protocol(
980                "root DSE not found".into(),
981            ))
982        })
983    }
984
985    pub async fn sasl_external_bind(&self) -> Result<(), Error> {
986        let op = LdapOperation::BindRequest(BindRequest {
987            version: 3,
988            name: String::new(),
989            authentication: BindAuthentication::Sasl {
990                mechanism: "EXTERNAL".into(),
991                credentials: None,
992            },
993        });
994        let response = self.request(op).await?;
995        match response.operation {
996            LdapOperation::BindResponse(resp) => check_result(&resp.result, self.referral_policy),
997            _ => Err(unexpected_response("BindResponse")),
998        }
999    }
1000
1001    /// Retrieve all values of a multi-valued attribute using Active Directory
1002    /// range retrieval.
1003    ///
1004    /// AD limits the number of values returned per request (typically 1500).
1005    /// This method loops, requesting `attr;range=N-*` with `BaseObject` scope
1006    /// until all values are collected.
1007    pub async fn search_range(
1008        &self,
1009        base_dn: &str,
1010        filter: Filter,
1011        attr: &str,
1012    ) -> Result<Vec<Vec<u8>>, Error> {
1013        const MAX_RANGE_ROUNDS: usize = 100_000;
1014        let resolved_base = self.resolve_base_dn(base_dn.to_string());
1015        let mut all_values: Vec<Vec<u8>> = Vec::new();
1016        let mut range_start: u32 = 0;
1017
1018        for _ in 0..MAX_RANGE_ROUNDS {
1019            let range_attr = format!("{attr};range={range_start}-*");
1020            let entries = self
1021                .search(
1022                    resolved_base.clone(),
1023                    SearchScope::BaseObject,
1024                    filter.clone(),
1025                    vec![range_attr],
1026                )
1027                .await?;
1028
1029            let entry = match entries.into_iter().next() {
1030                Some(e) => e,
1031                None => break,
1032            };
1033
1034            // Find the response attribute that has range info.
1035            let mut found = false;
1036            for pa in &entry.attributes {
1037                if let Some((base, _start, end)) = parse_range_option(&pa.name)
1038                    && base.eq_ignore_ascii_case(attr)
1039                {
1040                    all_values.extend(pa.values.iter().cloned());
1041                    found = true;
1042                    match end {
1043                        None => {
1044                            // `*` means this is the last chunk.
1045                            return Ok(all_values);
1046                        }
1047                        Some(e) => {
1048                            let next = e.saturating_add(1);
1049                            if next <= range_start {
1050                                // No progress — avoid infinite loop.
1051                                return Ok(all_values);
1052                            }
1053                            range_start = next;
1054                        }
1055                    }
1056                }
1057            }
1058            if !found {
1059                // No range option in response — the attribute may be small enough
1060                // to be returned entirely.
1061                for pa in entry.attributes {
1062                    let base = pa.name.split(';').next().unwrap_or(&pa.name);
1063                    if base.eq_ignore_ascii_case(attr) {
1064                        all_values.extend(pa.values);
1065                    }
1066                }
1067                break;
1068            }
1069        }
1070
1071        Ok(all_values)
1072    }
1073
1074    /// Connect to a referral server, optionally binding with the service account.
1075    ///
1076    /// Tries each URL in order, returning the first successful connection.
1077    /// The returned client has its hop limit decremented by one.
1078    async fn connect_referral(&self, urls: &[String], hop_limit: u8) -> Result<Client, Error> {
1079        if hop_limit == 0 {
1080            return Err(Error::ReferralHopLimitExceeded);
1081        }
1082
1083        let mut last_err = None;
1084        for raw_url in urls {
1085            let referral_url = match LdapUrl::parse(raw_url) {
1086                Ok(u) => u,
1087                Err(_) => continue,
1088            };
1089            let transport = match referral_url.scheme {
1090                LdapScheme::Ldap => Transport::Plain,
1091                LdapScheme::Ldaps => Transport::Tls,
1092            };
1093            if self.transport != Transport::Plain && transport == Transport::Plain {
1094                debug!(url = %raw_url, "skipping referral that would downgrade from TLS to plain");
1095                continue;
1096            }
1097            let mut builder =
1098                ClientBuilder::new(referral_url.host.clone(), referral_url.effective_port())
1099                    .transport(transport)
1100                    .tls_config(self.tls_config.clone())
1101                    .connect_timeout(self.connect_timeout)
1102                    .request_timeout(self.request_timeout)
1103                    .max_message_size(self.max_message_size)
1104                    .referral_policy(ReferralPolicy::Follow {
1105                        hop_limit: hop_limit - 1,
1106                    });
1107            builder.unsolicited_handler = self.unsolicited_handler.clone();
1108            if let Some(dn) = &self.service_account_dn
1109                && let Some(pw) = &self.service_account_password
1110            {
1111                builder = builder.service_account(dn.clone(), pw.clone());
1112            }
1113            match builder.connect().await {
1114                Ok(client) => {
1115                    // Bind with service account if configured.
1116                    if client.service_account_dn.is_some()
1117                        && let Err(e) = client.rebind_service_account().await
1118                    {
1119                        last_err = Some(e);
1120                        continue;
1121                    }
1122                    return Ok(client);
1123                }
1124                Err(e) => {
1125                    last_err = Some(e);
1126                }
1127            }
1128        }
1129
1130        Err(last_err.unwrap_or(Error::InvalidUrl("no valid referral URLs".into())))
1131    }
1132
1133    /// Bind using one of the supported credential types.
1134    pub async fn bind(&self, credentials: BindCredentials<'_>) -> Result<(), Error> {
1135        match credentials {
1136            BindCredentials::Simple { dn, password } => self.simple_bind(dn, password).await,
1137            BindCredentials::ServiceAccount => self.rebind_service_account().await,
1138            BindCredentials::SaslExternal => self.sasl_external_bind().await,
1139        }
1140    }
1141
1142    pub async fn unbind(&self) -> Result<(), Error> {
1143        let message_id = self.next_message_id();
1144        let msg = LdapMessage {
1145            message_id,
1146            operation: LdapOperation::UnbindRequest,
1147            controls: vec![],
1148        };
1149        let data = msg.encode();
1150        let mut framed = self.framed.lock().await;
1151        send_msg(&mut framed, data, &self.connected).await
1152    }
1153}
1154
1155/// Incremental paged search that yields one page of results at a time.
1156///
1157/// Created via [`Client::search_paged_stream`]. Call [`next_page`](PagedSearch::next_page)
1158/// repeatedly to fetch pages. If you stop before exhausting results, call
1159/// [`cancel`](PagedSearch::cancel) to release the server-side cookie.
1160pub struct PagedSearch<'a> {
1161    client: &'a Client,
1162    base_dn: String,
1163    scope: SearchScope,
1164    filter: Filter,
1165    attrs: Vec<String>,
1166    page_size: i32,
1167    cookie: Vec<u8>,
1168    done: bool,
1169}
1170
1171impl<'a> PagedSearch<'a> {
1172    /// Fetch the next page of results.
1173    ///
1174    /// Returns `Ok(Some(entries))` for each page, `Ok(None)` when all pages
1175    /// have been consumed.
1176    pub async fn next_page(&mut self) -> Result<Option<Vec<SearchResultEntry>>, Error> {
1177        if self.done {
1178            return Ok(None);
1179        }
1180
1181        let paged =
1182            PagedResultsControl::new(self.page_size).with_cookie(std::mem::take(&mut self.cookie));
1183        let controls = vec![paged.to_control()];
1184
1185        let message_id = self.client.next_message_id();
1186        let msg = LdapMessage {
1187            message_id,
1188            operation: LdapOperation::SearchRequest(SearchRequest {
1189                base_dn: self.base_dn.clone(),
1190                scope: self.scope,
1191                deref_aliases: DerefAliases::NeverDerefAliases,
1192                size_limit: 0,
1193                time_limit: 0,
1194                types_only: false,
1195                filter: self.filter.clone(),
1196                attributes: self.attrs.clone(),
1197            }),
1198            controls,
1199        };
1200        let data = msg.encode();
1201
1202        let mut framed = self.client.framed.lock().await;
1203        send_msg(&mut framed, data, &self.client.connected).await?;
1204
1205        let mut entries = Vec::new();
1206        let collected = collect_search_results(
1207            &mut framed,
1208            self.client.request_timeout,
1209            &self.client.connected,
1210            &mut entries,
1211            self.client.referral_policy,
1212            &self.client.unsolicited_handler,
1213        )
1214        .await?;
1215
1216        let new_cookie = collected
1217            .controls
1218            .iter()
1219            .find(|c| c.oid == PAGED_RESULTS_OID)
1220            .and_then(|c| PagedResultsControl::from_control(c).ok())
1221            .map(|p| p.cookie);
1222
1223        match new_cookie {
1224            Some(c) if !c.is_empty() => self.cookie = c,
1225            _ => self.done = true,
1226        }
1227
1228        Ok(Some(entries))
1229    }
1230
1231    /// Send an abandon request (page size 0) to release the server cookie.
1232    ///
1233    /// Call this if you stop iterating before all pages are consumed. If you
1234    /// don't, the server cookie will eventually time out on its own (~120 s).
1235    pub async fn cancel(&mut self) -> Result<(), Error> {
1236        if self.done {
1237            return Ok(());
1238        }
1239        self.done = true;
1240
1241        let paged = PagedResultsControl::new(0).with_cookie(std::mem::take(&mut self.cookie));
1242        let controls = vec![paged.to_control()];
1243
1244        let message_id = self.client.next_message_id();
1245        let msg = LdapMessage {
1246            message_id,
1247            operation: LdapOperation::SearchRequest(SearchRequest {
1248                base_dn: self.base_dn.clone(),
1249                scope: self.scope,
1250                deref_aliases: DerefAliases::NeverDerefAliases,
1251                size_limit: 0,
1252                time_limit: 0,
1253                types_only: false,
1254                filter: self.filter.clone(),
1255                attributes: self.attrs.clone(),
1256            }),
1257            controls,
1258        };
1259        let data = msg.encode();
1260
1261        let mut framed = self.client.framed.lock().await;
1262        send_msg(&mut framed, data, &self.client.connected).await?;
1263
1264        let mut entries = Vec::new();
1265        collect_search_results(
1266            &mut framed,
1267            self.client.request_timeout,
1268            &self.client.connected,
1269            &mut entries,
1270            self.client.referral_policy,
1271            &self.client.unsolicited_handler,
1272        )
1273        .await?;
1274
1275        Ok(())
1276    }
1277
1278    /// Returns `true` once all pages have been consumed or `cancel` was called.
1279    pub fn is_done(&self) -> bool {
1280        self.done
1281    }
1282}
1283
1284/// Result of collecting search response messages.
1285struct CollectedSearch {
1286    controls: Vec<Control>,
1287    referral_urls: Vec<String>,
1288}
1289
1290async fn collect_search_results(
1291    framed: &mut FramedLdap,
1292    timeout: Duration,
1293    connected: &AtomicBool,
1294    entries: &mut Vec<SearchResultEntry>,
1295    referral_policy: ReferralPolicy,
1296    unsolicited_handler: &UnsolicitedHandler,
1297) -> Result<CollectedSearch, Error> {
1298    let mut referral_urls = Vec::new();
1299    loop {
1300        let response = recv_msg(framed, timeout, connected, unsolicited_handler).await?;
1301        match response.operation {
1302            LdapOperation::SearchResultEntry(entry) => {
1303                if entries.len() >= MAX_SEARCH_ENTRIES {
1304                    return Err(Error::SearchEntryLimitExceeded(MAX_SEARCH_ENTRIES));
1305                }
1306                entries.push(entry);
1307            }
1308            LdapOperation::SearchResultDone(result) => {
1309                // Collect referral URLs from the Done result before checking.
1310                if result.code.is_referral() {
1311                    referral_urls.extend(result.referral.iter().cloned());
1312                }
1313                check_result(&result, referral_policy)?;
1314                return Ok(CollectedSearch {
1315                    controls: response.controls,
1316                    referral_urls,
1317                });
1318            }
1319            LdapOperation::SearchResultReference(urls) => {
1320                referral_urls.extend(urls);
1321            }
1322            _ => return Err(unexpected_response("SearchResult*")),
1323        }
1324    }
1325}
1326
1327async fn send_msg(
1328    framed: &mut FramedLdap,
1329    data: Vec<u8>,
1330    connected: &AtomicBool,
1331) -> Result<(), Error> {
1332    framed.send(data).await.map_err(|e| {
1333        connected.store(false, Ordering::Relaxed);
1334        ber_to_io(e)
1335    })
1336}
1337
1338async fn recv_msg(
1339    framed: &mut FramedLdap,
1340    timeout: Duration,
1341    connected: &AtomicBool,
1342    unsolicited_handler: &UnsolicitedHandler,
1343) -> Result<LdapMessage, Error> {
1344    loop {
1345        let msg = match tokio::time::timeout(timeout, framed.next()).await {
1346            Ok(Some(Ok(frame))) => LdapMessage::decode(&frame).map_err(Error::Proto)?,
1347            Ok(Some(Err(e))) => {
1348                connected.store(false, Ordering::Relaxed);
1349                return Err(ber_to_io(e));
1350            }
1351            Ok(None) => {
1352                connected.store(false, Ordering::Relaxed);
1353                return Err(Error::ConnectionClosed);
1354            }
1355            Err(_) => {
1356                // After a timeout the connection is desynchronized: the server
1357                // may still send the response later. Force a reconnect.
1358                connected.store(false, Ordering::Relaxed);
1359                return Err(Error::Timeout);
1360            }
1361        };
1362
1363        // Handle unsolicited notifications (message_id == 0).
1364        if msg.message_id == MessageId(0) {
1365            if let LdapOperation::ExtendedResponse(ref resp) = msg.operation {
1366                if resp.oid.as_deref() == Some(NOTICE_OF_DISCONNECTION_OID) {
1367                    connected.store(false, Ordering::Relaxed);
1368                    return Err(Error::ConnectionClosed);
1369                }
1370                unsolicited_handler(resp);
1371            }
1372            continue;
1373        }
1374
1375        return Ok(msg);
1376    }
1377}
1378
1379fn ber_to_io(e: ldap_client_ber::BerError) -> Error {
1380    match e {
1381        ldap_client_ber::BerError::Io(io) => Error::Io(io),
1382        other => Error::Ber(other),
1383    }
1384}
1385
1386fn check_result(result: &ProtoLdapResult, referral_policy: ReferralPolicy) -> Result<(), Error> {
1387    if result.code.is_success() {
1388        return Ok(());
1389    }
1390    if result.code.is_referral() {
1391        return match referral_policy {
1392            ReferralPolicy::Ignore => Ok(()),
1393            ReferralPolicy::Return | ReferralPolicy::Follow { .. } => Err(Error::Referral {
1394                urls: result.referral.clone(),
1395                result: result.clone(),
1396            }),
1397        };
1398    }
1399    Err(Error::ldap(result))
1400}
1401
1402/// Outcome of checking an LDAP result and potentially chasing a referral.
1403enum Chase {
1404    Ok,
1405    Follow(Box<Client>),
1406    Err(Error),
1407}
1408
1409/// Check the LDAP result: if success, return `Chase::Ok`. If a referral
1410/// and the client's policy is `Follow`, connect to the referral and return
1411/// `Chase::Follow`. Otherwise return `Chase::Err`.
1412async fn try_chase(client: &Client, result: &ProtoLdapResult) -> Chase {
1413    if result.code.is_success() {
1414        return Chase::Ok;
1415    }
1416    if result.code.is_referral() {
1417        return match client.referral_policy {
1418            ReferralPolicy::Ignore => Chase::Ok,
1419            ReferralPolicy::Return => Chase::Err(Error::Referral {
1420                urls: result.referral.clone(),
1421                result: result.clone(),
1422            }),
1423            ReferralPolicy::Follow { hop_limit } => {
1424                match client.connect_referral(&result.referral, hop_limit).await {
1425                    Ok(c) => Chase::Follow(Box::new(c)),
1426                    Err(e) => Chase::Err(e),
1427                }
1428            }
1429        };
1430    }
1431    Chase::Err(Error::ldap(result))
1432}
1433
1434fn unexpected_response(expected: &str) -> Error {
1435    Error::Proto(ldap_client_proto::ProtoError::Protocol(format!(
1436        "unexpected response, expected {expected}"
1437    )))
1438}
1439
/// Parse an AD range option from an attribute name.
///
/// `"member;range=0-1499"` → `Some(("member", 0, Some(1499)))`
/// `"member;range=1500-*"` → `Some(("member", 1500, None))`
///
/// The returned base is the attribute type, i.e. the text before the first
/// `;`. The `range=` option is searched for only among the options that follow
/// it. It is matched case-insensitively, because attribute options are case
/// insensitive (RFC 4512).
///
/// Returns `None` when there is no range option, or when its bounds are not a
/// `u32` start and a `u32`-or-`*` end.
pub fn parse_range_option(attr_name: &str) -> Option<(&str, u32, Option<u32>)> {
    const RANGE_PREFIX: &str = "range=";

    let mut segments = attr_name.split(';');
    // `split` always yields at least one segment: the attribute type itself.
    let base = segments.next().unwrap_or(attr_name);
    let range_part = segments.find_map(|option| {
        // `get` returns None for too-short options or a non-char boundary.
        let head = option.get(..RANGE_PREFIX.len())?;
        head.eq_ignore_ascii_case(RANGE_PREFIX)
            .then(|| &option[RANGE_PREFIX.len()..])
    })?;
    let (start_s, end_s) = range_part.split_once('-')?;
    let start: u32 = start_s.parse().ok()?;
    let end = match end_s {
        "*" => None,
        digits => Some(digits.parse().ok()?),
    };
    Some((base, start, end))
}