1use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
5use std::time::Duration;
6
7use futures_util::{SinkExt, StreamExt};
8use rustls::ClientConfig;
9use rustls_pki_types::ServerName;
10use secrecy::{ExposeSecret, SecretString};
11use tokio::net::TcpStream;
12use tokio::sync::Mutex;
13use tokio_util::codec::Framed;
14use tracing::debug;
15
16use ldap_client_ber::LdapCodec;
17use ldap_client_proto::{
18 AddRequest, BindAuthentication, BindRequest, CompareRequest, Control, DerefAliases,
19 ExtendedRequest, ExtendedResponse, Filter, LdapMessage, LdapOperation,
20 LdapResult as ProtoLdapResult, LdapScheme, LdapUrl, MessageId, ModifyDnRequest, ModifyRequest,
21 PAGED_RESULTS_OID, PagedResultsControl, ResultCode, SearchRequest, SearchResultEntry,
22 SearchScope,
23};
24
25use crate::Error;
26use crate::conn::{self, LdapStream};
27
/// OID of the StartTLS extended operation (RFC 4511 §4.14).
const STARTTLS_OID: &str = "1.3.6.1.4.1.1466.20037";
/// OID of the Notice of Disconnection unsolicited notification (RFC 4511 §4.4.1).
const NOTICE_OF_DISCONNECTION_OID: &str = "1.3.6.1.4.1.1466.20036";
/// Default for both the connect timeout and the per-request timeout.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Default limit handed to `LdapCodec::with_max_message_size` (10 MiB, assuming the codec
/// measures bytes — TODO confirm in `ldap_client_ber`).
const DEFAULT_MAX_MESSAGE_SIZE: u32 = 10 * 1024 * 1024;
/// Upper bound on the length of the entry vector that `collect_search_results` fills.
/// `Client::search_paged` reuses one vector for every page, so there the bound spans
/// the whole paged search.
const MAX_SEARCH_ENTRIES: usize = 500_000;

/// Callback invoked with unsolicited notifications (extended responses the server sends
/// without a matching request). The builder's default handler only logs them.
pub type UnsolicitedHandler = Arc<dyn Fn(&ExtendedResponse) + Send + Sync>;
36
37fn default_unsolicited_handler() -> UnsolicitedHandler {
38 Arc::new(|resp| {
39 tracing::debug!(
40 oid = resp.oid.as_deref().unwrap_or("<none>"),
41 "received unsolicited notification from server"
42 );
43 })
44}
45
/// Complete output of [`Client::search_full`]: matching entries, referral URLs gathered
/// while reading the results, and the response controls.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Entries returned by the server.
    pub entries: Vec<SearchResultEntry>,
    /// Referral URLs collected by `collect_search_results`. Whether any are collected
    /// presumably depends on the client's [`ReferralPolicy`] — TODO confirm.
    pub referrals: Vec<String>,
    /// Response controls returned with the search (e.g. the paged-results control).
    pub controls: Vec<Control>,
}
52
/// How the TCP connection to the directory server is secured.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Transport {
    /// Unencrypted LDAP. `Client::simple_bind` logs a cleartext-credentials warning on it.
    Plain,
    /// TLS negotiated immediately after the TCP connect (LDAPS).
    Tls,
    /// Plain TCP upgraded to TLS with the StartTLS extended operation before any other
    /// request is sent.
    StartTls,
}
59
/// What the client does when the server answers with a referral.
///
/// The handling itself lives in `check_result` / `try_chase` / `collect_search_results`,
/// which are outside this view. The notes on `Ignore` and `Return` are inferred from their
/// names — TODO confirm.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ReferralPolicy {
    /// The default. Presumably treats a referral like any other non-success result.
    #[default]
    Ignore,
    /// Presumably hands referral URLs back to the caller (e.g. `SearchResult::referrals`).
    Return,
    /// Connect to the referred server and retry there. `hop_limit` is decremented on each
    /// hop; a hop requested with a limit of 0 fails with `Error::ReferralHopLimitExceeded`.
    Follow { hop_limit: u8 },
}
71
72impl ReferralPolicy {
73 pub fn follow() -> Self {
75 Self::Follow { hop_limit: 10 }
76 }
77}
78
/// Credentials accepted by [`Client::bind`].
pub enum BindCredentials<'a> {
    /// LDAP simple bind; dispatched to [`Client::simple_bind`].
    Simple {
        /// DN to bind as.
        dn: &'a str,
        /// Password; only exposed while the bind request is being built.
        password: &'a SecretString,
    },
    /// Bind as the account configured with [`ClientBuilder::service_account`].
    ServiceAccount,
    /// SASL `EXTERNAL` bind with an empty name and no credentials. Presumably this relies on
    /// a TLS client certificate — confirm against the TLS configuration in use.
    SaslExternal,
}
91
/// Configures and opens a [`Client`]. Create one with [`ClientBuilder::new`] or
/// [`ClientBuilder::from_url`], adjust it with the chained setters, then call
/// [`ClientBuilder::connect`].
pub struct ClientBuilder {
    host: String,
    port: u16,
    transport: Transport,
    /// Explicit rustls config. When `None`, `connect` falls back to
    /// `conn::default_tls_config()`.
    tls_config: Option<Arc<ClientConfig>>,
    /// Bound on the TCP connect and on the TLS handshake.
    connect_timeout: Duration,
    /// How long to wait for each server response (including the StartTLS reply).
    request_timeout: Duration,
    /// Limit passed to the LDAP codec.
    max_message_size: u32,
    /// Default base DN substituted for an empty base DN in searches.
    base_dn: Option<String>,
    /// Service-account DN. Stored for `Client::rebind_service_account`; `connect` itself
    /// does not bind.
    service_account_dn: Option<String>,
    /// Service-account password paired with `service_account_dn`.
    service_account_password: Option<SecretString>,
    referral_policy: ReferralPolicy,
    unsolicited_handler: UnsolicitedHandler,
}
106
107impl ClientBuilder {
108 pub fn new(host: impl Into<String>, port: u16) -> Self {
109 Self {
110 host: host.into(),
111 port,
112 transport: Transport::Plain,
113 tls_config: None,
114 connect_timeout: DEFAULT_TIMEOUT,
115 request_timeout: DEFAULT_TIMEOUT,
116 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
117 base_dn: None,
118 service_account_dn: None,
119 service_account_password: None,
120 referral_policy: ReferralPolicy::default(),
121 unsolicited_handler: default_unsolicited_handler(),
122 }
123 }
124
125 pub fn from_url(url: &str) -> Result<Self, Error> {
126 let parsed = LdapUrl::parse(url).map_err(|e| Error::InvalidUrl(format!("{e}")))?;
127
128 let transport = match parsed.scheme {
129 LdapScheme::Ldap => Transport::Plain,
130 LdapScheme::Ldaps => Transport::Tls,
131 };
132
133 let port = parsed.effective_port();
134 Ok(Self {
135 host: parsed.host,
136 port,
137 transport,
138 tls_config: None,
139 connect_timeout: DEFAULT_TIMEOUT,
140 request_timeout: DEFAULT_TIMEOUT,
141 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
142 base_dn: parsed.base_dn,
143 service_account_dn: None,
144 service_account_password: None,
145 referral_policy: ReferralPolicy::default(),
146 unsolicited_handler: default_unsolicited_handler(),
147 })
148 }
149
150 pub fn transport(mut self, transport: Transport) -> Self {
151 self.transport = transport;
152 self
153 }
154
155 pub fn tls_config(mut self, config: Arc<ClientConfig>) -> Self {
159 self.tls_config = Some(config);
160 self
161 }
162
163 pub fn tls(mut self, config: crate::tls_config::TlsConfig) -> Result<Self, Error> {
167 self.tls_config = Some(config.build()?);
168 Ok(self)
169 }
170
171 pub fn timeout(mut self, timeout: Duration) -> Self {
173 self.connect_timeout = timeout;
174 self.request_timeout = timeout;
175 self
176 }
177
178 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
179 self.connect_timeout = timeout;
180 self
181 }
182
183 pub fn request_timeout(mut self, timeout: Duration) -> Self {
184 self.request_timeout = timeout;
185 self
186 }
187
188 pub fn base_dn(mut self, base_dn: impl Into<String>) -> Self {
189 self.base_dn = Some(base_dn.into());
190 self
191 }
192
193 pub fn service_account(mut self, dn: impl Into<String>, password: SecretString) -> Self {
194 self.service_account_dn = Some(dn.into());
195 self.service_account_password = Some(password);
196 self
197 }
198
199 pub fn referral_policy(mut self, policy: ReferralPolicy) -> Self {
200 self.referral_policy = policy;
201 self
202 }
203
204 pub fn max_message_size(mut self, max: u32) -> Self {
206 self.max_message_size = max;
207 self
208 }
209
210 pub fn on_unsolicited_notification(
214 mut self,
215 handler: impl Fn(&ExtendedResponse) + Send + Sync + 'static,
216 ) -> Self {
217 self.unsolicited_handler = Arc::new(handler);
218 self
219 }
220
221 pub async fn connect(self) -> Result<Client, Error> {
222 let addr = format_addr(&self.host, self.port);
223 debug!(addr = %addr, transport = ?self.transport, "connecting");
224
225 let tcp = match tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr)).await
226 {
227 Ok(Ok(tcp)) => tcp,
228 Ok(Err(e)) => return Err(Error::Io(e)),
229 Err(_) => return Err(Error::Timeout),
230 };
231 tcp.set_nodelay(true)?;
232
233 let tls_config = self
234 .tls_config
235 .clone()
236 .unwrap_or_else(|| Arc::new(conn::default_tls_config()));
237
238 let stream = match self.transport {
239 Transport::Plain => LdapStream::Plain(tcp),
240 Transport::Tls | Transport::StartTls => {
241 let server_name = ServerName::try_from(self.host.clone())
242 .map_err(|e| Error::InvalidUrl(format!("invalid server name: {e}")))?;
243
244 if self.transport == Transport::Tls {
245 conn::upgrade_to_tls(tcp, server_name, tls_config.clone(), self.connect_timeout)
246 .await?
247 } else {
248 perform_start_tls(
249 tcp,
250 server_name,
251 tls_config.clone(),
252 self.request_timeout,
253 self.max_message_size,
254 self.connect_timeout,
255 )
256 .await?
257 }
258 }
259 };
260
261 let start_id = if self.transport == Transport::StartTls {
262 2
263 } else {
264 1
265 };
266 let codec = LdapCodec::new().with_max_message_size(self.max_message_size);
267 Ok(Client {
268 framed: Mutex::new(Framed::new(stream, codec)),
269 next_id: AtomicI32::new(start_id),
270 connected: AtomicBool::new(true),
271 request_timeout: self.request_timeout,
272 max_message_size: self.max_message_size,
273 base_dn: self.base_dn,
274 referral_policy: self.referral_policy,
275 last_reconnect: Mutex::new(None),
276 unsolicited_handler: self.unsolicited_handler,
277 host: self.host,
278 port: self.port,
279 transport: self.transport,
280 tls_config,
281 connect_timeout: self.connect_timeout,
282 service_account_dn: self.service_account_dn,
283 service_account_password: self.service_account_password,
284 })
285 }
286}
287
288async fn perform_start_tls(
289 tcp: TcpStream,
290 server_name: ServerName<'static>,
291 tls_config: Arc<ClientConfig>,
292 timeout: Duration,
293 max_message_size: u32,
294 tls_timeout: Duration,
295) -> Result<LdapStream, Error> {
296 let mut framed = Framed::new(
297 tcp,
298 LdapCodec::new().with_max_message_size(max_message_size),
299 );
300
301 let msg = LdapMessage {
302 message_id: MessageId(1),
303 operation: LdapOperation::ExtendedRequest(ExtendedRequest {
304 oid: STARTTLS_OID.to_string(),
305 value: None,
306 }),
307 controls: vec![],
308 };
309 framed.send(msg.encode()).await.map_err(ber_to_io)?;
310
311 let response = match tokio::time::timeout(timeout, framed.next()).await {
312 Ok(Some(Ok(frame))) => LdapMessage::decode(&frame).map_err(Error::Proto)?,
313 Ok(Some(Err(e))) => return Err(ber_to_io(e)),
314 Ok(None) => return Err(Error::ConnectionClosed),
315 Err(_) => return Err(Error::Timeout),
316 };
317
318 match response.operation {
319 LdapOperation::ExtendedResponse(resp) if resp.result.code.is_success() => {}
320 LdapOperation::ExtendedResponse(resp) => {
321 return Err(Error::StartTls(resp.result.diagnostic_message));
322 }
323 _ => return Err(Error::StartTls("unexpected response".into())),
324 }
325
326 let parts = framed.into_parts();
327 if !parts.read_buf.is_empty() || !parts.write_buf.is_empty() {
328 return Err(Error::StartTls(
329 "unexpected buffered data before TLS handshake".into(),
330 ));
331 }
332 let tcp = parts.io;
333 conn::upgrade_to_tls(tcp, server_name, tls_config, tls_timeout).await
334}
335
/// The client's message stream: the plain or TLS connection framed with the LDAP codec.
type FramedLdap = Framed<LdapStream, LdapCodec>;

/// Minimum spacing between the starts of two successive `Client::reconnect` calls.
const MIN_RECONNECT_INTERVAL: Duration = Duration::from_secs(1);
339
/// An LDAP connection created by [`ClientBuilder::connect`].
///
/// Every operation locks `framed` for its whole send/receive exchange, so operations on
/// one client run one at a time over the single connection.
pub struct Client {
    /// The framed connection; replaced wholesale by `reconnect`.
    framed: Mutex<FramedLdap>,
    /// Next message ID to hand out (see `next_message_id`).
    next_id: AtomicI32,
    /// Whether the connection is considered usable (see `is_connected`).
    connected: AtomicBool,
    /// How long to wait for each server response.
    request_timeout: Duration,
    /// Codec limit, reapplied when reconnecting.
    max_message_size: u32,
    /// Default base DN substituted for an empty base DN in searches.
    base_dn: Option<String>,
    referral_policy: ReferralPolicy,
    /// Start time of the last reconnect, used to rate-limit `reconnect`.
    last_reconnect: Mutex<Option<tokio::time::Instant>>,
    unsolicited_handler: UnsolicitedHandler,
    // Connection parameters retained for `reconnect` and for following referrals.
    host: String,
    port: u16,
    transport: Transport,
    tls_config: Arc<ClientConfig>,
    connect_timeout: Duration,
    /// Service-account credentials, re-bound after `reconnect` and on referral servers.
    service_account_dn: Option<String>,
    service_account_password: Option<SecretString>,
}
359
/// Formats `host` and `port` as a socket-address string for `TcpStream::connect`.
///
/// IPv6 literals (any host containing `:`) are wrapped in brackets, as in `[::1]:389`.
/// A host that is already bracketed is left as is, so `"[::1]"` does not become the
/// unparseable `"[[::1]]:389"`.
fn format_addr(host: &str, port: u16) -> String {
    let already_bracketed = host.starts_with('[') && host.ends_with(']');
    if host.contains(':') && !already_bracketed {
        format!("[{host}]:{port}")
    } else {
        format!("{host}:{port}")
    }
}
367
368impl Client {
369 fn next_message_id(&self) -> MessageId {
370 let id = self
371 .next_id
372 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
373 if n > 0 && n < i32::MAX - 1 {
374 Some(n + 1)
375 } else {
376 Some(2) }
378 });
379 MessageId(id.unwrap())
381 }
382
383 fn resolve_base_dn(&self, base_dn: String) -> String {
384 if base_dn.is_empty()
385 && let Some(default) = &self.base_dn
386 {
387 return default.clone();
388 }
389 base_dn
390 }
391
392 pub fn is_connected(&self) -> bool {
393 self.connected.load(Ordering::Relaxed)
394 }
395
396 pub async fn reconnect(&self) -> Result<(), Error> {
397 {
398 let mut last = self.last_reconnect.lock().await;
399 if let Some(prev) = *last {
400 let elapsed = prev.elapsed();
401 if elapsed < MIN_RECONNECT_INTERVAL {
402 tokio::time::sleep(MIN_RECONNECT_INTERVAL - elapsed).await;
403 }
404 }
405 *last = Some(tokio::time::Instant::now());
406 }
407
408 let addr = format_addr(&self.host, self.port);
409 debug!(addr = %addr, transport = ?self.transport, "reconnecting");
410
411 let tcp = match tokio::time::timeout(self.connect_timeout, TcpStream::connect(&addr)).await
412 {
413 Ok(Ok(tcp)) => tcp,
414 Ok(Err(e)) => return Err(Error::Io(e)),
415 Err(_) => return Err(Error::Timeout),
416 };
417 tcp.set_nodelay(true)?;
418
419 let stream = match self.transport {
420 Transport::Plain => LdapStream::Plain(tcp),
421 Transport::Tls | Transport::StartTls => {
422 let server_name = ServerName::try_from(self.host.clone())
423 .map_err(|e| Error::InvalidUrl(format!("invalid server name: {e}")))?;
424
425 if self.transport == Transport::Tls {
426 conn::upgrade_to_tls(
427 tcp,
428 server_name,
429 self.tls_config.clone(),
430 self.connect_timeout,
431 )
432 .await?
433 } else {
434 perform_start_tls(
435 tcp,
436 server_name,
437 self.tls_config.clone(),
438 self.request_timeout,
439 self.max_message_size,
440 self.connect_timeout,
441 )
442 .await?
443 }
444 }
445 };
446
447 let start_id = if self.transport == Transport::StartTls {
448 2
449 } else {
450 1
451 };
452
453 let mut framed = self.framed.lock().await;
454 *framed = Framed::new(
455 stream,
456 LdapCodec::new().with_max_message_size(self.max_message_size),
457 );
458 self.next_id.store(start_id, Ordering::Relaxed);
459 self.connected.store(true, Ordering::Relaxed);
460 drop(framed);
461
462 if self.service_account_dn.is_some()
463 && let Err(e) = self.rebind_service_account().await
464 {
465 self.connected.store(false, Ordering::Relaxed);
466 return Err(e);
467 }
468 Ok(())
469 }
470
471 pub async fn rebind_service_account(&self) -> Result<(), Error> {
472 let dn = self.service_account_dn.as_deref().ok_or_else(|| {
473 Error::Proto(ldap_client_proto::ProtoError::Protocol(
474 "no service account configured".into(),
475 ))
476 })?;
477 let password = self.service_account_password.as_ref().ok_or_else(|| {
478 Error::Proto(ldap_client_proto::ProtoError::Protocol(
479 "no service account password configured".into(),
480 ))
481 })?;
482 self.simple_bind(dn, password).await
483 }
484
485 async fn request(&self, operation: LdapOperation) -> Result<LdapMessage, Error> {
486 self.request_with_controls(operation, vec![]).await
487 }
488
489 async fn request_with_controls(
490 &self,
491 operation: LdapOperation,
492 controls: Vec<Control>,
493 ) -> Result<LdapMessage, Error> {
494 let message_id = self.next_message_id();
495 let msg = LdapMessage {
496 message_id,
497 operation,
498 controls,
499 };
500 let data = msg.encode();
501
502 let mut framed = self.framed.lock().await;
503 send_msg(&mut framed, data, &self.connected).await?;
504 recv_msg(
505 &mut framed,
506 self.request_timeout,
507 &self.connected,
508 &self.unsolicited_handler,
509 )
510 .await
511 }
512
513 pub async fn simple_bind(
514 &self,
515 dn: impl Into<String>,
516 password: &SecretString,
517 ) -> Result<(), Error> {
518 if self.transport == Transport::Plain {
519 tracing::warn!(
520 "simple bind over plain (unencrypted) connection; credentials are sent in cleartext"
521 );
522 }
523 let op = LdapOperation::BindRequest(BindRequest {
524 version: 3,
525 name: dn.into(),
526 authentication: BindAuthentication::Simple(zeroize::Zeroizing::new(
527 password.expose_secret().as_bytes().to_vec(),
528 )),
529 });
530
531 let response = self.request(op).await?;
532 match response.operation {
533 LdapOperation::BindResponse(resp) => check_result(&resp.result, self.referral_policy),
534 _ => Err(unexpected_response("BindResponse")),
535 }
536 }
537
538 pub async fn search(
539 &self,
540 base_dn: impl Into<String>,
541 scope: SearchScope,
542 filter: Filter,
543 attrs: Vec<String>,
544 ) -> Result<Vec<SearchResultEntry>, Error> {
545 let (entries, _controls) = self
546 .search_with_controls(base_dn, scope, filter, attrs, vec![])
547 .await?;
548 Ok(entries)
549 }
550
551 pub async fn search_with_controls(
552 &self,
553 base_dn: impl Into<String>,
554 scope: SearchScope,
555 filter: Filter,
556 attrs: Vec<String>,
557 controls: Vec<Control>,
558 ) -> Result<(Vec<SearchResultEntry>, Vec<Control>), Error> {
559 let message_id = self.next_message_id();
560 let msg = LdapMessage {
561 message_id,
562 operation: LdapOperation::SearchRequest(SearchRequest {
563 base_dn: self.resolve_base_dn(base_dn.into()),
564 scope,
565 deref_aliases: DerefAliases::NeverDerefAliases,
566 size_limit: 0,
567 time_limit: 0,
568 types_only: false,
569 filter,
570 attributes: attrs,
571 }),
572 controls,
573 };
574 let data = msg.encode();
575
576 let mut framed = self.framed.lock().await;
577 send_msg(&mut framed, data, &self.connected).await?;
578
579 let mut entries = Vec::new();
580 let collected = collect_search_results(
581 &mut framed,
582 self.request_timeout,
583 &self.connected,
584 &mut entries,
585 self.referral_policy,
586 &self.unsolicited_handler,
587 )
588 .await?;
589
590 Ok((entries, collected.controls))
591 }
592
593 pub async fn search_full(
594 &self,
595 base_dn: impl Into<String>,
596 scope: SearchScope,
597 filter: Filter,
598 attrs: Vec<String>,
599 controls: Vec<Control>,
600 ) -> Result<SearchResult, Error> {
601 let message_id = self.next_message_id();
602 let msg = LdapMessage {
603 message_id,
604 operation: LdapOperation::SearchRequest(SearchRequest {
605 base_dn: self.resolve_base_dn(base_dn.into()),
606 scope,
607 deref_aliases: DerefAliases::NeverDerefAliases,
608 size_limit: 0,
609 time_limit: 0,
610 types_only: false,
611 filter,
612 attributes: attrs,
613 }),
614 controls,
615 };
616 let data = msg.encode();
617
618 let mut framed = self.framed.lock().await;
619 send_msg(&mut framed, data, &self.connected).await?;
620
621 let mut entries = Vec::new();
622 let collected = collect_search_results(
623 &mut framed,
624 self.request_timeout,
625 &self.connected,
626 &mut entries,
627 self.referral_policy,
628 &self.unsolicited_handler,
629 )
630 .await?;
631
632 Ok(SearchResult {
633 entries,
634 referrals: collected.referral_urls,
635 controls: collected.controls,
636 })
637 }
638
639 pub async fn search_paged(
640 &self,
641 base_dn: &str,
642 scope: SearchScope,
643 filter: Filter,
644 attrs: Vec<String>,
645 page_size: i32,
646 ) -> Result<Vec<SearchResultEntry>, Error> {
647 const MAX_PAGED_ROUNDS: usize = 100_000;
648 let resolved_base = self.resolve_base_dn(base_dn.to_string());
649 let mut all_entries = Vec::new();
650 let mut cookie = Vec::new();
651 let mut prev_cookie = Vec::new();
652
653 for _ in 0..MAX_PAGED_ROUNDS {
654 let paged =
655 PagedResultsControl::new(page_size).with_cookie(std::mem::take(&mut cookie));
656 let controls = vec![paged.to_control()];
657
658 let message_id = self.next_message_id();
659 let msg = LdapMessage {
660 message_id,
661 operation: LdapOperation::SearchRequest(SearchRequest {
662 base_dn: resolved_base.clone(),
663 scope,
664 deref_aliases: DerefAliases::NeverDerefAliases,
665 size_limit: 0,
666 time_limit: 0,
667 types_only: false,
668 filter: filter.clone(),
669 attributes: attrs.clone(),
670 }),
671 controls,
672 };
673 let data = msg.encode();
674
675 let mut framed = self.framed.lock().await;
676 send_msg(&mut framed, data, &self.connected).await?;
677
678 let collected = collect_search_results(
679 &mut framed,
680 self.request_timeout,
681 &self.connected,
682 &mut all_entries,
683 self.referral_policy,
684 &self.unsolicited_handler,
685 )
686 .await?;
687
688 let new_cookie = collected
689 .controls
690 .iter()
691 .find(|c| c.oid == PAGED_RESULTS_OID)
692 .and_then(|c| PagedResultsControl::from_control(c).ok())
693 .map(|p| p.cookie);
694
695 match new_cookie {
696 Some(c) if !c.is_empty() => {
697 if c == prev_cookie {
698 break;
700 }
701 prev_cookie = c.clone();
702 cookie = c;
703 }
704 _ => break,
705 }
706 }
707
708 Ok(all_entries)
709 }
710
711 pub fn search_paged_stream(
712 &self,
713 base_dn: &str,
714 scope: SearchScope,
715 filter: Filter,
716 attrs: Vec<String>,
717 page_size: i32,
718 ) -> PagedSearch<'_> {
719 PagedSearch {
720 client: self,
721 base_dn: self.resolve_base_dn(base_dn.to_string()),
722 scope,
723 filter,
724 attrs,
725 page_size,
726 cookie: Vec::new(),
727 done: false,
728 }
729 }
730
731 pub async fn add(
732 &self,
733 dn: impl Into<String>,
734 attrs: Vec<ldap_client_proto::PartialAttribute>,
735 ) -> Result<(), Error> {
736 let dn = dn.into();
737 let mut chased: Option<Client> = None;
738 loop {
739 let client = chased.as_ref().unwrap_or(self);
740 let op = LdapOperation::AddRequest(AddRequest {
741 dn: dn.clone(),
742 attributes: attrs.clone(),
743 });
744 let response = client.request(op).await?;
745 match response.operation {
746 LdapOperation::AddResponse(result) => match try_chase(client, &result).await {
747 Chase::Ok => return Ok(()),
748 Chase::Follow(c) => {
749 chased = Some(*c);
750 continue;
751 }
752 Chase::Err(e) => return Err(e),
753 },
754 _ => return Err(unexpected_response("AddResponse")),
755 }
756 }
757 }
758
759 pub async fn modify(
760 &self,
761 dn: impl Into<String>,
762 changes: Vec<ldap_client_proto::Modification>,
763 ) -> Result<(), Error> {
764 let dn = dn.into();
765 let mut chased: Option<Client> = None;
766 loop {
767 let client = chased.as_ref().unwrap_or(self);
768 let op = LdapOperation::ModifyRequest(ModifyRequest {
769 dn: dn.clone(),
770 changes: changes.clone(),
771 });
772 let response = client.request(op).await?;
773 match response.operation {
774 LdapOperation::ModifyResponse(result) => match try_chase(client, &result).await {
775 Chase::Ok => return Ok(()),
776 Chase::Follow(c) => {
777 chased = Some(*c);
778 continue;
779 }
780 Chase::Err(e) => return Err(e),
781 },
782 _ => return Err(unexpected_response("ModifyResponse")),
783 }
784 }
785 }
786
787 pub async fn delete(&self, dn: impl Into<String>) -> Result<(), Error> {
788 let dn = dn.into();
789 let mut chased: Option<Client> = None;
790 loop {
791 let client = chased.as_ref().unwrap_or(self);
792 let op = LdapOperation::DeleteRequest(dn.clone());
793 let response = client.request(op).await?;
794 match response.operation {
795 LdapOperation::DeleteResponse(result) => match try_chase(client, &result).await {
796 Chase::Ok => return Ok(()),
797 Chase::Follow(c) => {
798 chased = Some(*c);
799 continue;
800 }
801 Chase::Err(e) => return Err(e),
802 },
803 _ => return Err(unexpected_response("DeleteResponse")),
804 }
805 }
806 }
807
808 pub async fn compare(
809 &self,
810 dn: impl Into<String>,
811 attr: impl Into<String>,
812 value: impl AsRef<[u8]>,
813 ) -> Result<bool, Error> {
814 let dn = dn.into();
815 let attr = attr.into();
816 let value = value.as_ref().to_vec();
817 let mut chased: Option<Client> = None;
818 loop {
819 let client = chased.as_ref().unwrap_or(self);
820 let op = LdapOperation::CompareRequest(CompareRequest {
821 dn: dn.clone(),
822 attr: attr.clone(),
823 value: value.clone(),
824 });
825 let response = client.request(op).await?;
826 match response.operation {
827 LdapOperation::CompareResponse(result) => {
828 use ldap_client_proto::ResultCode;
829 match result.code {
830 ResultCode::CompareTrue => return Ok(true),
831 ResultCode::CompareFalse => return Ok(false),
832 _ => match try_chase(client, &result).await {
833 Chase::Ok => return Err(Error::ldap(&result)),
834 Chase::Follow(c) => {
835 chased = Some(*c);
836 continue;
837 }
838 Chase::Err(e) => return Err(e),
839 },
840 }
841 }
842 _ => return Err(unexpected_response("CompareResponse")),
843 }
844 }
845 }
846
847 pub async fn modify_dn(
848 &self,
849 dn: impl Into<String>,
850 new_rdn: impl Into<String>,
851 delete_old_rdn: bool,
852 new_superior: Option<String>,
853 ) -> Result<(), Error> {
854 let dn = dn.into();
855 let new_rdn = new_rdn.into();
856 let mut chased: Option<Client> = None;
857 loop {
858 let client = chased.as_ref().unwrap_or(self);
859 let op = LdapOperation::ModifyDnRequest(ModifyDnRequest {
860 dn: dn.clone(),
861 new_rdn: new_rdn.clone(),
862 delete_old_rdn,
863 new_superior: new_superior.clone(),
864 });
865 let response = client.request(op).await?;
866 match response.operation {
867 LdapOperation::ModifyDnResponse(result) => match try_chase(client, &result).await {
868 Chase::Ok => return Ok(()),
869 Chase::Follow(c) => {
870 chased = Some(*c);
871 continue;
872 }
873 Chase::Err(e) => return Err(e),
874 },
875 _ => return Err(unexpected_response("ModifyDnResponse")),
876 }
877 }
878 }
879
880 pub async fn extended(
881 &self,
882 oid: impl Into<String>,
883 value: Option<Vec<u8>>,
884 ) -> Result<ldap_client_proto::ExtendedResponse, Error> {
885 let oid = oid.into();
886 let mut chased: Option<Client> = None;
887 loop {
888 let client = chased.as_ref().unwrap_or(self);
889 let op = LdapOperation::ExtendedRequest(ExtendedRequest {
890 oid: oid.clone(),
891 value: value.clone(),
892 });
893 let response = client.request(op).await?;
894 match response.operation {
895 LdapOperation::ExtendedResponse(resp) => {
896 match try_chase(client, &resp.result).await {
897 Chase::Ok => return Ok(resp),
898 Chase::Follow(c) => {
899 chased = Some(*c);
900 continue;
901 }
902 Chase::Err(e) => return Err(e),
903 }
904 }
905 _ => return Err(unexpected_response("ExtendedResponse")),
906 }
907 }
908 }
909
910 pub async fn who_am_i(&self) -> Result<Option<String>, Error> {
911 let resp = self.extended("1.3.6.1.4.1.4203.1.11.3", None).await?;
912 Ok(resp.value.map(|v| String::from_utf8_lossy(&v).into_owned()))
913 }
914
915 pub async fn search_one(
916 &self,
917 base_dn: impl Into<String>,
918 scope: SearchScope,
919 filter: Filter,
920 attrs: Vec<String>,
921 ) -> Result<Option<SearchResultEntry>, Error> {
922 let message_id = self.next_message_id();
923 let msg = LdapMessage {
924 message_id,
925 operation: LdapOperation::SearchRequest(SearchRequest {
926 base_dn: self.resolve_base_dn(base_dn.into()),
927 scope,
928 deref_aliases: DerefAliases::NeverDerefAliases,
929 size_limit: 2,
930 time_limit: self.request_timeout.as_secs() as i32,
931 types_only: false,
932 filter,
933 attributes: attrs,
934 }),
935 controls: vec![],
936 };
937 let data = msg.encode();
938
939 let mut framed = self.framed.lock().await;
940 send_msg(&mut framed, data, &self.connected).await?;
941
942 let mut entries = Vec::new();
943 let result = collect_search_results(
944 &mut framed,
945 self.request_timeout,
946 &self.connected,
947 &mut entries,
948 self.referral_policy,
949 &self.unsolicited_handler,
950 )
951 .await;
952
953 if let Err(Error::Ldap { code, .. }) = &result
956 && *code == ResultCode::SizeLimitExceeded
957 {
958 return Err(Error::MultipleResults);
959 }
960 result?;
961
962 match entries.len() {
963 0 => Ok(None),
964 1 => Ok(Some(entries.into_iter().next().unwrap())),
965 _ => Err(Error::MultipleResults),
966 }
967 }
968
969 pub async fn root_dse(&self) -> Result<SearchResultEntry, Error> {
970 let entries = self
971 .search(
972 "",
973 SearchScope::BaseObject,
974 Filter::present("objectClass"),
975 vec!["*".into(), "+".into()],
976 )
977 .await?;
978 entries.into_iter().next().ok_or_else(|| {
979 Error::Proto(ldap_client_proto::ProtoError::Protocol(
980 "root DSE not found".into(),
981 ))
982 })
983 }
984
985 pub async fn sasl_external_bind(&self) -> Result<(), Error> {
986 let op = LdapOperation::BindRequest(BindRequest {
987 version: 3,
988 name: String::new(),
989 authentication: BindAuthentication::Sasl {
990 mechanism: "EXTERNAL".into(),
991 credentials: None,
992 },
993 });
994 let response = self.request(op).await?;
995 match response.operation {
996 LdapOperation::BindResponse(resp) => check_result(&resp.result, self.referral_policy),
997 _ => Err(unexpected_response("BindResponse")),
998 }
999 }
1000
1001 pub async fn search_range(
1008 &self,
1009 base_dn: &str,
1010 filter: Filter,
1011 attr: &str,
1012 ) -> Result<Vec<Vec<u8>>, Error> {
1013 const MAX_RANGE_ROUNDS: usize = 100_000;
1014 let resolved_base = self.resolve_base_dn(base_dn.to_string());
1015 let mut all_values: Vec<Vec<u8>> = Vec::new();
1016 let mut range_start: u32 = 0;
1017
1018 for _ in 0..MAX_RANGE_ROUNDS {
1019 let range_attr = format!("{attr};range={range_start}-*");
1020 let entries = self
1021 .search(
1022 resolved_base.clone(),
1023 SearchScope::BaseObject,
1024 filter.clone(),
1025 vec![range_attr],
1026 )
1027 .await?;
1028
1029 let entry = match entries.into_iter().next() {
1030 Some(e) => e,
1031 None => break,
1032 };
1033
1034 let mut found = false;
1036 for pa in &entry.attributes {
1037 if let Some((base, _start, end)) = parse_range_option(&pa.name)
1038 && base.eq_ignore_ascii_case(attr)
1039 {
1040 all_values.extend(pa.values.iter().cloned());
1041 found = true;
1042 match end {
1043 None => {
1044 return Ok(all_values);
1046 }
1047 Some(e) => {
1048 let next = e.saturating_add(1);
1049 if next <= range_start {
1050 return Ok(all_values);
1052 }
1053 range_start = next;
1054 }
1055 }
1056 }
1057 }
1058 if !found {
1059 for pa in entry.attributes {
1062 let base = pa.name.split(';').next().unwrap_or(&pa.name);
1063 if base.eq_ignore_ascii_case(attr) {
1064 all_values.extend(pa.values);
1065 }
1066 }
1067 break;
1068 }
1069 }
1070
1071 Ok(all_values)
1072 }
1073
1074 async fn connect_referral(&self, urls: &[String], hop_limit: u8) -> Result<Client, Error> {
1079 if hop_limit == 0 {
1080 return Err(Error::ReferralHopLimitExceeded);
1081 }
1082
1083 let mut last_err = None;
1084 for raw_url in urls {
1085 let referral_url = match LdapUrl::parse(raw_url) {
1086 Ok(u) => u,
1087 Err(_) => continue,
1088 };
1089 let transport = match referral_url.scheme {
1090 LdapScheme::Ldap => Transport::Plain,
1091 LdapScheme::Ldaps => Transport::Tls,
1092 };
1093 if self.transport != Transport::Plain && transport == Transport::Plain {
1094 debug!(url = %raw_url, "skipping referral that would downgrade from TLS to plain");
1095 continue;
1096 }
1097 let mut builder =
1098 ClientBuilder::new(referral_url.host.clone(), referral_url.effective_port())
1099 .transport(transport)
1100 .tls_config(self.tls_config.clone())
1101 .connect_timeout(self.connect_timeout)
1102 .request_timeout(self.request_timeout)
1103 .max_message_size(self.max_message_size)
1104 .referral_policy(ReferralPolicy::Follow {
1105 hop_limit: hop_limit - 1,
1106 });
1107 builder.unsolicited_handler = self.unsolicited_handler.clone();
1108 if let Some(dn) = &self.service_account_dn
1109 && let Some(pw) = &self.service_account_password
1110 {
1111 builder = builder.service_account(dn.clone(), pw.clone());
1112 }
1113 match builder.connect().await {
1114 Ok(client) => {
1115 if client.service_account_dn.is_some()
1117 && let Err(e) = client.rebind_service_account().await
1118 {
1119 last_err = Some(e);
1120 continue;
1121 }
1122 return Ok(client);
1123 }
1124 Err(e) => {
1125 last_err = Some(e);
1126 }
1127 }
1128 }
1129
1130 Err(last_err.unwrap_or(Error::InvalidUrl("no valid referral URLs".into())))
1131 }
1132
1133 pub async fn bind(&self, credentials: BindCredentials<'_>) -> Result<(), Error> {
1135 match credentials {
1136 BindCredentials::Simple { dn, password } => self.simple_bind(dn, password).await,
1137 BindCredentials::ServiceAccount => self.rebind_service_account().await,
1138 BindCredentials::SaslExternal => self.sasl_external_bind().await,
1139 }
1140 }
1141
1142 pub async fn unbind(&self) -> Result<(), Error> {
1143 let message_id = self.next_message_id();
1144 let msg = LdapMessage {
1145 message_id,
1146 operation: LdapOperation::UnbindRequest,
1147 controls: vec![],
1148 };
1149 let data = msg.encode();
1150 let mut framed = self.framed.lock().await;
1151 send_msg(&mut framed, data, &self.connected).await
1152 }
1153}
1154
/// A paged search created by `Client::search_paged_stream`, driven one page at a time with
/// `next_page`. It borrows the client, so every page travels over the client's connection.
pub struct PagedSearch<'a> {
    client: &'a Client,
    /// Base DN, already resolved against the client's default base DN.
    base_dn: String,
    scope: SearchScope,
    filter: Filter,
    attrs: Vec<String>,
    /// Page size sent in the paged-results control.
    page_size: i32,
    /// Paged-results cookie from the last page; empty before the first page.
    cookie: Vec<u8>,
    /// Set once the last page has been seen or the search was cancelled.
    done: bool,
}
1170
1171impl<'a> PagedSearch<'a> {
1172 pub async fn next_page(&mut self) -> Result<Option<Vec<SearchResultEntry>>, Error> {
1177 if self.done {
1178 return Ok(None);
1179 }
1180
1181 let paged =
1182 PagedResultsControl::new(self.page_size).with_cookie(std::mem::take(&mut self.cookie));
1183 let controls = vec![paged.to_control()];
1184
1185 let message_id = self.client.next_message_id();
1186 let msg = LdapMessage {
1187 message_id,
1188 operation: LdapOperation::SearchRequest(SearchRequest {
1189 base_dn: self.base_dn.clone(),
1190 scope: self.scope,
1191 deref_aliases: DerefAliases::NeverDerefAliases,
1192 size_limit: 0,
1193 time_limit: 0,
1194 types_only: false,
1195 filter: self.filter.clone(),
1196 attributes: self.attrs.clone(),
1197 }),
1198 controls,
1199 };
1200 let data = msg.encode();
1201
1202 let mut framed = self.client.framed.lock().await;
1203 send_msg(&mut framed, data, &self.client.connected).await?;
1204
1205 let mut entries = Vec::new();
1206 let collected = collect_search_results(
1207 &mut framed,
1208 self.client.request_timeout,
1209 &self.client.connected,
1210 &mut entries,
1211 self.client.referral_policy,
1212 &self.client.unsolicited_handler,
1213 )
1214 .await?;
1215
1216 let new_cookie = collected
1217 .controls
1218 .iter()
1219 .find(|c| c.oid == PAGED_RESULTS_OID)
1220 .and_then(|c| PagedResultsControl::from_control(c).ok())
1221 .map(|p| p.cookie);
1222
1223 match new_cookie {
1224 Some(c) if !c.is_empty() => self.cookie = c,
1225 _ => self.done = true,
1226 }
1227
1228 Ok(Some(entries))
1229 }
1230
1231 pub async fn cancel(&mut self) -> Result<(), Error> {
1236 if self.done {
1237 return Ok(());
1238 }
1239 self.done = true;
1240
1241 let paged = PagedResultsControl::new(0).with_cookie(std::mem::take(&mut self.cookie));
1242 let controls = vec![paged.to_control()];
1243
1244 let message_id = self.client.next_message_id();
1245 let msg = LdapMessage {
1246 message_id,
1247 operation: LdapOperation::SearchRequest(SearchRequest {
1248 base_dn: self.base_dn.clone(),
1249 scope: self.scope,
1250 deref_aliases: DerefAliases::NeverDerefAliases,
1251 size_limit: 0,
1252 time_limit: 0,
1253 types_only: false,
1254 filter: self.filter.clone(),
1255 attributes: self.attrs.clone(),
1256 }),
1257 controls,
1258 };
1259 let data = msg.encode();
1260
1261 let mut framed = self.client.framed.lock().await;
1262 send_msg(&mut framed, data, &self.client.connected).await?;
1263
1264 let mut entries = Vec::new();
1265 collect_search_results(
1266 &mut framed,
1267 self.client.request_timeout,
1268 &self.client.connected,
1269 &mut entries,
1270 self.client.referral_policy,
1271 &self.client.unsolicited_handler,
1272 )
1273 .await?;
1274
1275 Ok(())
1276 }
1277
1278 pub fn is_done(&self) -> bool {
1280 self.done
1281 }
1282}
1283
1284struct CollectedSearch {
1286 controls: Vec<Control>,
1287 referral_urls: Vec<String>,
1288}
1289
1290async fn collect_search_results(
1291 framed: &mut FramedLdap,
1292 timeout: Duration,
1293 connected: &AtomicBool,
1294 entries: &mut Vec<SearchResultEntry>,
1295 referral_policy: ReferralPolicy,
1296 unsolicited_handler: &UnsolicitedHandler,
1297) -> Result<CollectedSearch, Error> {
1298 let mut referral_urls = Vec::new();
1299 loop {
1300 let response = recv_msg(framed, timeout, connected, unsolicited_handler).await?;
1301 match response.operation {
1302 LdapOperation::SearchResultEntry(entry) => {
1303 if entries.len() >= MAX_SEARCH_ENTRIES {
1304 return Err(Error::SearchEntryLimitExceeded(MAX_SEARCH_ENTRIES));
1305 }
1306 entries.push(entry);
1307 }
1308 LdapOperation::SearchResultDone(result) => {
1309 if result.code.is_referral() {
1311 referral_urls.extend(result.referral.iter().cloned());
1312 }
1313 check_result(&result, referral_policy)?;
1314 return Ok(CollectedSearch {
1315 controls: response.controls,
1316 referral_urls,
1317 });
1318 }
1319 LdapOperation::SearchResultReference(urls) => {
1320 referral_urls.extend(urls);
1321 }
1322 _ => return Err(unexpected_response("SearchResult*")),
1323 }
1324 }
1325}
1326
1327async fn send_msg(
1328 framed: &mut FramedLdap,
1329 data: Vec<u8>,
1330 connected: &AtomicBool,
1331) -> Result<(), Error> {
1332 framed.send(data).await.map_err(|e| {
1333 connected.store(false, Ordering::Relaxed);
1334 ber_to_io(e)
1335 })
1336}
1337
1338async fn recv_msg(
1339 framed: &mut FramedLdap,
1340 timeout: Duration,
1341 connected: &AtomicBool,
1342 unsolicited_handler: &UnsolicitedHandler,
1343) -> Result<LdapMessage, Error> {
1344 loop {
1345 let msg = match tokio::time::timeout(timeout, framed.next()).await {
1346 Ok(Some(Ok(frame))) => LdapMessage::decode(&frame).map_err(Error::Proto)?,
1347 Ok(Some(Err(e))) => {
1348 connected.store(false, Ordering::Relaxed);
1349 return Err(ber_to_io(e));
1350 }
1351 Ok(None) => {
1352 connected.store(false, Ordering::Relaxed);
1353 return Err(Error::ConnectionClosed);
1354 }
1355 Err(_) => {
1356 connected.store(false, Ordering::Relaxed);
1359 return Err(Error::Timeout);
1360 }
1361 };
1362
1363 if msg.message_id == MessageId(0) {
1365 if let LdapOperation::ExtendedResponse(ref resp) = msg.operation {
1366 if resp.oid.as_deref() == Some(NOTICE_OF_DISCONNECTION_OID) {
1367 connected.store(false, Ordering::Relaxed);
1368 return Err(Error::ConnectionClosed);
1369 }
1370 unsolicited_handler(resp);
1371 }
1372 continue;
1373 }
1374
1375 return Ok(msg);
1376 }
1377}
1378
1379fn ber_to_io(e: ldap_client_ber::BerError) -> Error {
1380 match e {
1381 ldap_client_ber::BerError::Io(io) => Error::Io(io),
1382 other => Error::Ber(other),
1383 }
1384}
1385
1386fn check_result(result: &ProtoLdapResult, referral_policy: ReferralPolicy) -> Result<(), Error> {
1387 if result.code.is_success() {
1388 return Ok(());
1389 }
1390 if result.code.is_referral() {
1391 return match referral_policy {
1392 ReferralPolicy::Ignore => Ok(()),
1393 ReferralPolicy::Return | ReferralPolicy::Follow { .. } => Err(Error::Referral {
1394 urls: result.referral.clone(),
1395 result: result.clone(),
1396 }),
1397 };
1398 }
1399 Err(Error::ldap(result))
1400}
1401
1402enum Chase {
1404 Ok,
1405 Follow(Box<Client>),
1406 Err(Error),
1407}
1408
1409async fn try_chase(client: &Client, result: &ProtoLdapResult) -> Chase {
1413 if result.code.is_success() {
1414 return Chase::Ok;
1415 }
1416 if result.code.is_referral() {
1417 return match client.referral_policy {
1418 ReferralPolicy::Ignore => Chase::Ok,
1419 ReferralPolicy::Return => Chase::Err(Error::Referral {
1420 urls: result.referral.clone(),
1421 result: result.clone(),
1422 }),
1423 ReferralPolicy::Follow { hop_limit } => {
1424 match client.connect_referral(&result.referral, hop_limit).await {
1425 Ok(c) => Chase::Follow(Box::new(c)),
1426 Err(e) => Chase::Err(e),
1427 }
1428 }
1429 };
1430 }
1431 Chase::Err(Error::ldap(result))
1432}
1433
1434fn unexpected_response(expected: &str) -> Error {
1435 Error::Proto(ldap_client_proto::ProtoError::Protocol(format!(
1436 "unexpected response, expected {expected}"
1437 )))
1438}
1439
1440pub fn parse_range_option(attr_name: &str) -> Option<(&str, u32, Option<u32>)> {
1445 let base = attr_name.split(';').next().unwrap_or(attr_name);
1446 let range_part = attr_name
1447 .split(';')
1448 .find_map(|part| part.strip_prefix("range="))?;
1449 let (start_s, end_s) = range_part.split_once('-')?;
1450 let start: u32 = start_s.parse().ok()?;
1451 let end = if end_s == "*" {
1452 None
1453 } else {
1454 Some(end_s.parse().ok()?)
1455 };
1456 Some((base, start, end))
1457}