ldap_rs/
client.rs

1//! LDAP client module
2
3use std::{
4    collections::VecDeque,
5    convert::{TryFrom, TryInto},
6    pin::Pin,
7    sync::{
8        atomic::{AtomicBool, AtomicU32, Ordering},
9        Arc,
10    },
11    task::{Context, Poll},
12};
13
14use futures::{future::BoxFuture, Future, Stream, TryStreamExt};
15use parking_lot::RwLock;
16use rasn_ldap::{
17    AuthenticationChoice, BindRequest, BindResponse, Controls, ExtendedRequest, LdapMessage, LdapResult, ProtocolOp,
18    ResultCode, SaslCredentials, SearchResultDone, UnbindRequest,
19};
20
21use crate::{
22    conn::{LdapConnection, MessageStream},
23    controls::SimplePagedResultsControl,
24    error::Error,
25    oid,
26    options::TlsOptions,
27    request::SearchRequest,
28    Attribute, ModifyRequest, SearchEntry,
29};
30
/// Result alias used throughout this module; the error type is always the crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
32
33fn check_result(result: LdapResult) -> Result<()> {
34    if result.result_code == ResultCode::Success || result.result_code == ResultCode::SaslBindInProgress {
35        Ok(())
36    } else {
37        Err(Error::OperationFailed(result.into()))
38    }
39}
40
/// LDAP client builder
///
/// Created by `LdapClient::builder`; collects the connection parameters that
/// `connect` later hands to `LdapClient::connect`.
pub struct LdapClientBuilder {
    // Server address, stored as an owned string.
    address: String,
    // Port to connect to.
    port: u16,
    // TLS settings passed through to the connection.
    tls_options: TlsOptions,
}
47
48impl LdapClientBuilder {
49    /// Set port number, default is 389
50    pub fn port(mut self, port: u16) -> Self {
51        self.port = port;
52        self
53    }
54
55    /// Set TLS options, default is plain connection
56    pub fn tls_options(mut self, options: TlsOptions) -> Self {
57        self.tls_options = options;
58        self
59    }
60
61    /// Build client and connect
62    pub async fn connect(self) -> Result<LdapClient> {
63        LdapClient::connect(self.address, self.port, self.tls_options).await
64    }
65}
66
/// LDAP client
///
/// The client is `Clone`; because the message ID counter sits behind an `Arc`,
/// all clones draw their message IDs from the same counter.
#[derive(Clone)]
pub struct LdapClient {
    // Connection used to send requests and receive responses.
    connection: LdapConnection,
    // Source of message IDs for outgoing requests, shared between clones.
    id_counter: Arc<AtomicU32>,
}
73
74impl LdapClient {
75    /// Create client builder
76    pub fn builder<A: AsRef<str>>(address: A) -> LdapClientBuilder {
77        LdapClientBuilder {
78            address: address.as_ref().to_owned(),
79            port: 389,
80            tls_options: TlsOptions::default(),
81        }
82    }
83
84    pub(crate) async fn connect<A>(address: A, port: u16, tls_options: TlsOptions) -> Result<Self>
85    where
86        A: AsRef<str>,
87    {
88        let connection = LdapConnection::connect(address, port, tls_options).await?;
89        Ok(Self {
90            connection,
91            id_counter: Arc::new(AtomicU32::new(2)), // 1 is used by STARTTLS
92        })
93    }
94
95    fn new_id(&self) -> u32 {
96        self.id_counter.fetch_add(1, Ordering::SeqCst)
97    }
98
99    async fn do_bind(&mut self, req: BindRequest) -> Result<BindResponse> {
100        let id = self.new_id();
101        let msg = LdapMessage::new(id, ProtocolOp::BindRequest(req));
102
103        let item = self.connection.send_recv(msg).await?;
104
105        match item.protocol_op {
106            ProtocolOp::BindResponse(resp) => {
107                let result = resp.clone();
108                check_result(LdapResult::new(
109                    resp.result_code,
110                    resp.matched_dn,
111                    resp.diagnostic_message,
112                ))?;
113                Ok(result)
114            }
115            _ => Err(Error::InvalidResponse),
116        }
117    }
118
119    fn new_sasl_bind_req(&self, mech: &str, creds: Option<&[u8]>) -> BindRequest {
120        let auth_choice =
121            AuthenticationChoice::Sasl(SaslCredentials::new(mech.into(), creds.map(|c| c.to_vec().into())));
122        BindRequest::new(3, String::new().into(), auth_choice)
123    }
124
125    /// Perform simple bind operation with username and password
126    pub async fn simple_bind<U, P>(&mut self, username: U, password: P) -> Result<()>
127    where
128        U: AsRef<str>,
129        P: AsRef<str>,
130    {
131        let auth_choice = AuthenticationChoice::Simple(password.as_ref().to_owned().into());
132        let req = BindRequest::new(3, username.as_ref().to_owned().into(), auth_choice);
133        self.do_bind(req).await?;
134        Ok(())
135    }
136
137    /// Perform SASL EXTERNAL bind
138    pub async fn sasl_external_bind(&mut self) -> Result<()> {
139        let req = self.new_sasl_bind_req("EXTERNAL", None);
140        self.do_bind(req).await?;
141        Ok(())
142    }
143
144    #[cfg(feature = "gssapi")]
145    /// Perform SASL GSSAPI bind for a given server realm.
146    /// The following features are NOT implemented:
147    ///  * SASL protection over plain connection (use TLS instead)
148    ///  * Channel binding
149    pub async fn sasl_gssapi_bind<S: AsRef<str>>(&mut self, realm: S) -> Result<()> {
150        // GSSAPI code credits: https://github.com/inejge/ldap3
151        use cross_krb5::{ClientCtx, InitiateFlags, K5Ctx, Step};
152
153        const SASL_RECV_MAX_SIZE: u32 = 0x0200_0000;
154
155        let spn = format!("ldap/{}", realm.as_ref());
156
157        let (client_ctx, token) =
158            ClientCtx::new(InitiateFlags::empty(), None, &spn, None).map_err(|e| Error::GssApiError(e.to_string()))?;
159
160        let req = self.new_sasl_bind_req("GSSAPI", Some(token.as_ref()));
161        let response = self.do_bind(req).await?;
162
163        let token = match response.server_sasl_creds {
164            Some(token) => token,
165            _ => return Err(Error::NoSaslCredentials),
166        };
167
168        let step = client_ctx
169            .step(&token)
170            .map_err(|e| Error::GssApiError(format!("{}", e)))?;
171
172        let mut client_ctx = match step {
173            Step::Finished((ctx, None)) => ctx,
174            _ => {
175                return Err(Error::GssApiError(
176                    "GSSAPI exchange not finished or has an additional token".to_owned(),
177                ))
178            }
179        };
180
181        let req = self.new_sasl_bind_req("GSSAPI", None);
182        let response = self.do_bind(req).await?;
183
184        if response.server_sasl_creds.is_none() {
185            return Err(Error::NoSaslCredentials);
186        }
187
188        let recv_max_size = SASL_RECV_MAX_SIZE.to_be_bytes();
189        let size_msg = client_ctx
190            .wrap(true, &recv_max_size)
191            .map_err(|e| Error::GssApiError(format!("{}", e)))?;
192
193        let req = self.new_sasl_bind_req("GSSAPI", Some(size_msg.as_ref()));
194        self.do_bind(req).await?;
195
196        Ok(())
197    }
198
199    /// Perform unbind operation. This will instruct LDAP server to terminate the connection
200    pub async fn unbind(&mut self) -> Result<()> {
201        let id = self.new_id();
202
203        let msg = LdapMessage::new(id, ProtocolOp::UnbindRequest(UnbindRequest));
204        self.connection.send(msg).await?;
205
206        Ok(())
207    }
208
209    /// Send 'whoami' extended request (RFC4532)
210    pub async fn whoami(&mut self) -> Result<Option<String>> {
211        let id = self.new_id();
212
213        let msg = LdapMessage::new(
214            id,
215            ProtocolOp::ExtendedReq(ExtendedRequest {
216                request_name: oid::WHOAMI_OID.into(),
217                request_value: None,
218            }),
219        );
220
221        let resp = self.connection.send_recv(msg).await?;
222
223        match resp.protocol_op {
224            ProtocolOp::ExtendedResp(resp) => {
225                check_result(LdapResult::new(
226                    resp.result_code,
227                    resp.matched_dn,
228                    resp.diagnostic_message,
229                ))?;
230                Ok(resp.response_value.map(|v| String::from_utf8_lossy(&v).into_owned()))
231            }
232            _ => Err(Error::InvalidResponse),
233        }
234    }
235
236    /// Perform search operation without paging. Returns a stream of search entries
237    pub async fn search(&mut self, request: SearchRequest) -> Result<SearchEntries> {
238        let id = self.new_id();
239
240        let msg = LdapMessage::new(id, ProtocolOp::SearchRequest(request.into()));
241        let stream = self.connection.send_recv_stream(msg).await?;
242
243        Ok(SearchEntries {
244            inner: stream,
245            page_control: None,
246            page_finished: Arc::new(AtomicBool::new(false)),
247        })
248    }
249
250    /// Perform search operation without paging and return one result
251    pub async fn search_one(&mut self, request: SearchRequest) -> Result<Option<SearchEntry>> {
252        let entries = self.search(request).await?;
253        let mut attrs = entries.try_collect::<VecDeque<_>>().await?;
254        Ok(attrs.pop_front())
255    }
256
257    /// Perform search operation with paging. Returns a stream of pages
258    pub fn search_paged(&mut self, request: SearchRequest, page_size: u32) -> Pages {
259        Pages {
260            page_control: Arc::new(RwLock::new(SimplePagedResultsControl::new(page_size))),
261            page_finished: Arc::new(AtomicBool::new(true)),
262            client: self.clone(),
263            request,
264            page_size,
265            inner: None,
266        }
267    }
268
269    /// Perform modify operation
270    pub async fn modify(&mut self, request: ModifyRequest) -> Result<()> {
271        let id = self.new_id();
272
273        let msg = LdapMessage::new(id, ProtocolOp::ModifyRequest(request.into()));
274        let resp = self.connection.send_recv(msg).await?;
275
276        match resp.protocol_op {
277            ProtocolOp::ModifyResponse(resp) => {
278                check_result(LdapResult::new(
279                    resp.0.result_code,
280                    resp.0.matched_dn,
281                    resp.0.diagnostic_message,
282                ))?;
283                Ok(())
284            }
285            _ => Err(Error::InvalidResponse),
286        }
287    }
288
289    /// Perform add operation
290    pub async fn add<S, I>(&mut self, dn: S, attributes: I) -> Result<()>
291    where
292        S: AsRef<str>,
293        I: IntoIterator<Item = Attribute>,
294    {
295        let id = self.new_id();
296
297        let msg = LdapMessage::new(
298            id,
299            ProtocolOp::AddRequest(rasn_ldap::AddRequest {
300                entry: dn.as_ref().to_owned().into(),
301                attributes: attributes.into_iter().map(Into::into).collect(),
302            }),
303        );
304        let resp = self.connection.send_recv(msg).await?;
305
306        match resp.protocol_op {
307            ProtocolOp::AddResponse(resp) => {
308                check_result(LdapResult::new(
309                    resp.0.result_code,
310                    resp.0.matched_dn,
311                    resp.0.diagnostic_message,
312                ))?;
313                Ok(())
314            }
315            _ => Err(Error::InvalidResponse),
316        }
317    }
318
319    /// Perform delete operation
320    pub async fn delete<S: AsRef<str>>(&mut self, dn: S) -> Result<()> {
321        let id = self.new_id();
322
323        let msg = LdapMessage::new(
324            id,
325            ProtocolOp::DelRequest(rasn_ldap::DelRequest(dn.as_ref().to_owned().into())),
326        );
327        let resp = self.connection.send_recv(msg).await?;
328
329        match resp.protocol_op {
330            ProtocolOp::DelResponse(resp) => {
331                check_result(LdapResult::new(
332                    resp.0.result_code,
333                    resp.0.matched_dn,
334                    resp.0.diagnostic_message,
335                ))?;
336                Ok(())
337            }
338            _ => Err(Error::InvalidResponse),
339        }
340    }
341}
342
/// Pages represents a stream of paged search results
///
/// Each item of the stream is the `SearchEntries` stream for one page.
pub struct Pages {
    // Paging control shared with each page's `SearchEntries`, which overwrites it
    // with the control returned by the server when the page completes.
    page_control: Arc<RwLock<SimplePagedResultsControl>>,
    // Shared flag: cleared when a page request is prepared, set by `SearchEntries`
    // when the page's final search message has been processed.
    page_finished: Arc<AtomicBool>,
    // Client handle used to send the page requests.
    client: LdapClient,
    // Search request, cloned for every page.
    request: SearchRequest,
    // Page size applied to the control of every page request.
    page_size: u32,
    // Future of the page request currently in flight, if any.
    inner: Option<BoxFuture<'static, Result<SearchEntries>>>,
}
352
353impl Pages {
354    fn is_page_finished(&self) -> bool {
355        self.page_finished.load(Ordering::SeqCst)
356    }
357}
358
359impl Stream for Pages {
360    type Item = Result<SearchEntries>;
361
362    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
363        if !self.page_control.read().has_entries() {
364            return Poll::Ready(None);
365        }
366
367        if self.inner.is_none() {
368            if !self.is_page_finished() {
369                return Poll::Ready(None);
370            }
371
372            let mut client = self.client.clone();
373            let request = self.request.clone();
374            let control_ref = self.page_control.clone();
375            let page_size = self.page_size;
376            let page_finished = self.page_finished.clone();
377
378            self.page_finished.store(false, Ordering::SeqCst);
379
380            let fut = async move {
381                let id = client.new_id();
382
383                let mut msg = LdapMessage::new(id, ProtocolOp::SearchRequest(request.into()));
384                msg.controls = Some(vec![control_ref.read().clone().with_size(page_size).try_into()?]);
385
386                let stream = client.connection.send_recv_stream(msg).await?;
387                Ok(SearchEntries {
388                    inner: stream,
389                    page_control: Some(control_ref),
390                    page_finished,
391                })
392            };
393            self.inner = Some(Box::pin(fut));
394        }
395
396        match Pin::new(self.inner.as_mut().unwrap()).poll(cx) {
397            Poll::Pending => Poll::Pending,
398            Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
399            Poll::Ready(Ok(entries)) => {
400                self.inner = None;
401                Poll::Ready(Some(Ok(entries)))
402            }
403        }
404    }
405}
406
/// Search entries represents a stream of search results
pub struct SearchEntries {
    // Stream of raw LDAP messages received for this search.
    inner: MessageStream,
    // Paging control to update when the search completes; `None` for non-paged searches.
    page_control: Option<Arc<RwLock<SimplePagedResultsControl>>>,
    // Flag set to `true` once the final search message has been processed.
    page_finished: Arc<AtomicBool>,
}
413
414impl SearchEntries {
415    fn search_done(
416        self: Pin<&mut Self>,
417        controls: Option<Controls>,
418        done: SearchResultDone,
419    ) -> Poll<Option<Result<SearchEntry>>> {
420        self.page_finished.store(true, Ordering::SeqCst);
421
422        if done.0.result_code == ResultCode::Success {
423            if let Some(ref control_ref) = self.page_control {
424                let page_control = controls.and_then(|controls| {
425                    controls
426                        .into_iter()
427                        .find(|c| c.control_type == SimplePagedResultsControl::OID)
428                        .and_then(|c| SimplePagedResultsControl::try_from(c).ok())
429                });
430
431                if let Some(page_control) = page_control {
432                    *control_ref.write() = page_control;
433                    Poll::Ready(None)
434                } else {
435                    Poll::Ready(Some(Err(Error::InvalidResponse)))
436                }
437            } else {
438                Poll::Ready(None)
439            }
440        } else {
441            Poll::Ready(Some(Err(Error::OperationFailed(done.0.into()))))
442        }
443    }
444}
445
446impl Stream for SearchEntries {
447    type Item = Result<SearchEntry>;
448
449    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
450        loop {
451            let rc = match Pin::new(&mut self.inner).poll_next(cx) {
452                Poll::Pending => Poll::Pending,
453                Poll::Ready(None) => Poll::Ready(Some(Err(Error::ConnectionClosed))),
454                Poll::Ready(Some(msg)) => match msg.protocol_op {
455                    ProtocolOp::SearchResEntry(item) => Poll::Ready(Some(Ok(item.into()))),
456                    ProtocolOp::SearchResRef(_) => continue,
457                    ProtocolOp::SearchResDone(done) => self.search_done(msg.controls, done),
458                    _ => Poll::Ready(Some(Err(Error::InvalidResponse))),
459                },
460            };
461            return rc;
462        }
463    }
464}