1use std::{
4 collections::VecDeque,
5 convert::{TryFrom, TryInto},
6 pin::Pin,
7 sync::{
8 atomic::{AtomicBool, AtomicU32, Ordering},
9 Arc,
10 },
11 task::{Context, Poll},
12};
13
14use futures::{future::BoxFuture, Future, Stream, TryStreamExt};
15use parking_lot::RwLock;
16use rasn_ldap::{
17 AuthenticationChoice, BindRequest, BindResponse, Controls, ExtendedRequest, LdapMessage, LdapResult, ProtocolOp,
18 ResultCode, SaslCredentials, SearchResultDone, UnbindRequest,
19};
20
21use crate::{
22 conn::{LdapConnection, MessageStream},
23 controls::SimplePagedResultsControl,
24 error::Error,
25 oid,
26 options::TlsOptions,
27 request::SearchRequest,
28 Attribute, ModifyRequest, SearchEntry,
29};
30
31pub type Result<T> = std::result::Result<T, Error>;
32
33fn check_result(result: LdapResult) -> Result<()> {
34 if result.result_code == ResultCode::Success || result.result_code == ResultCode::SaslBindInProgress {
35 Ok(())
36 } else {
37 Err(Error::OperationFailed(result.into()))
38 }
39}
40
41pub struct LdapClientBuilder {
43 address: String,
44 port: u16,
45 tls_options: TlsOptions,
46}
47
48impl LdapClientBuilder {
49 pub fn port(mut self, port: u16) -> Self {
51 self.port = port;
52 self
53 }
54
55 pub fn tls_options(mut self, options: TlsOptions) -> Self {
57 self.tls_options = options;
58 self
59 }
60
61 pub async fn connect(self) -> Result<LdapClient> {
63 LdapClient::connect(self.address, self.port, self.tls_options).await
64 }
65}
66
67#[derive(Clone)]
69pub struct LdapClient {
70 connection: LdapConnection,
71 id_counter: Arc<AtomicU32>,
72}
73
74impl LdapClient {
75 pub fn builder<A: AsRef<str>>(address: A) -> LdapClientBuilder {
77 LdapClientBuilder {
78 address: address.as_ref().to_owned(),
79 port: 389,
80 tls_options: TlsOptions::default(),
81 }
82 }
83
84 pub(crate) async fn connect<A>(address: A, port: u16, tls_options: TlsOptions) -> Result<Self>
85 where
86 A: AsRef<str>,
87 {
88 let connection = LdapConnection::connect(address, port, tls_options).await?;
89 Ok(Self {
90 connection,
91 id_counter: Arc::new(AtomicU32::new(2)), })
93 }
94
95 fn new_id(&self) -> u32 {
96 self.id_counter.fetch_add(1, Ordering::SeqCst)
97 }
98
99 async fn do_bind(&mut self, req: BindRequest) -> Result<BindResponse> {
100 let id = self.new_id();
101 let msg = LdapMessage::new(id, ProtocolOp::BindRequest(req));
102
103 let item = self.connection.send_recv(msg).await?;
104
105 match item.protocol_op {
106 ProtocolOp::BindResponse(resp) => {
107 let result = resp.clone();
108 check_result(LdapResult::new(
109 resp.result_code,
110 resp.matched_dn,
111 resp.diagnostic_message,
112 ))?;
113 Ok(result)
114 }
115 _ => Err(Error::InvalidResponse),
116 }
117 }
118
119 fn new_sasl_bind_req(&self, mech: &str, creds: Option<&[u8]>) -> BindRequest {
120 let auth_choice =
121 AuthenticationChoice::Sasl(SaslCredentials::new(mech.into(), creds.map(|c| c.to_vec().into())));
122 BindRequest::new(3, String::new().into(), auth_choice)
123 }
124
125 pub async fn simple_bind<U, P>(&mut self, username: U, password: P) -> Result<()>
127 where
128 U: AsRef<str>,
129 P: AsRef<str>,
130 {
131 let auth_choice = AuthenticationChoice::Simple(password.as_ref().to_owned().into());
132 let req = BindRequest::new(3, username.as_ref().to_owned().into(), auth_choice);
133 self.do_bind(req).await?;
134 Ok(())
135 }
136
137 pub async fn sasl_external_bind(&mut self) -> Result<()> {
139 let req = self.new_sasl_bind_req("EXTERNAL", None);
140 self.do_bind(req).await?;
141 Ok(())
142 }
143
144 #[cfg(feature = "gssapi")]
145 pub async fn sasl_gssapi_bind<S: AsRef<str>>(&mut self, realm: S) -> Result<()> {
150 use cross_krb5::{ClientCtx, InitiateFlags, K5Ctx, Step};
152
153 const SASL_RECV_MAX_SIZE: u32 = 0x0200_0000;
154
155 let spn = format!("ldap/{}", realm.as_ref());
156
157 let (client_ctx, token) =
158 ClientCtx::new(InitiateFlags::empty(), None, &spn, None).map_err(|e| Error::GssApiError(e.to_string()))?;
159
160 let req = self.new_sasl_bind_req("GSSAPI", Some(token.as_ref()));
161 let response = self.do_bind(req).await?;
162
163 let token = match response.server_sasl_creds {
164 Some(token) => token,
165 _ => return Err(Error::NoSaslCredentials),
166 };
167
168 let step = client_ctx
169 .step(&token)
170 .map_err(|e| Error::GssApiError(format!("{}", e)))?;
171
172 let mut client_ctx = match step {
173 Step::Finished((ctx, None)) => ctx,
174 _ => {
175 return Err(Error::GssApiError(
176 "GSSAPI exchange not finished or has an additional token".to_owned(),
177 ))
178 }
179 };
180
181 let req = self.new_sasl_bind_req("GSSAPI", None);
182 let response = self.do_bind(req).await?;
183
184 if response.server_sasl_creds.is_none() {
185 return Err(Error::NoSaslCredentials);
186 }
187
188 let recv_max_size = SASL_RECV_MAX_SIZE.to_be_bytes();
189 let size_msg = client_ctx
190 .wrap(true, &recv_max_size)
191 .map_err(|e| Error::GssApiError(format!("{}", e)))?;
192
193 let req = self.new_sasl_bind_req("GSSAPI", Some(size_msg.as_ref()));
194 self.do_bind(req).await?;
195
196 Ok(())
197 }
198
199 pub async fn unbind(&mut self) -> Result<()> {
201 let id = self.new_id();
202
203 let msg = LdapMessage::new(id, ProtocolOp::UnbindRequest(UnbindRequest));
204 self.connection.send(msg).await?;
205
206 Ok(())
207 }
208
209 pub async fn whoami(&mut self) -> Result<Option<String>> {
211 let id = self.new_id();
212
213 let msg = LdapMessage::new(
214 id,
215 ProtocolOp::ExtendedReq(ExtendedRequest {
216 request_name: oid::WHOAMI_OID.into(),
217 request_value: None,
218 }),
219 );
220
221 let resp = self.connection.send_recv(msg).await?;
222
223 match resp.protocol_op {
224 ProtocolOp::ExtendedResp(resp) => {
225 check_result(LdapResult::new(
226 resp.result_code,
227 resp.matched_dn,
228 resp.diagnostic_message,
229 ))?;
230 Ok(resp.response_value.map(|v| String::from_utf8_lossy(&v).into_owned()))
231 }
232 _ => Err(Error::InvalidResponse),
233 }
234 }
235
236 pub async fn search(&mut self, request: SearchRequest) -> Result<SearchEntries> {
238 let id = self.new_id();
239
240 let msg = LdapMessage::new(id, ProtocolOp::SearchRequest(request.into()));
241 let stream = self.connection.send_recv_stream(msg).await?;
242
243 Ok(SearchEntries {
244 inner: stream,
245 page_control: None,
246 page_finished: Arc::new(AtomicBool::new(false)),
247 })
248 }
249
250 pub async fn search_one(&mut self, request: SearchRequest) -> Result<Option<SearchEntry>> {
252 let entries = self.search(request).await?;
253 let mut attrs = entries.try_collect::<VecDeque<_>>().await?;
254 Ok(attrs.pop_front())
255 }
256
257 pub fn search_paged(&mut self, request: SearchRequest, page_size: u32) -> Pages {
259 Pages {
260 page_control: Arc::new(RwLock::new(SimplePagedResultsControl::new(page_size))),
261 page_finished: Arc::new(AtomicBool::new(true)),
262 client: self.clone(),
263 request,
264 page_size,
265 inner: None,
266 }
267 }
268
269 pub async fn modify(&mut self, request: ModifyRequest) -> Result<()> {
271 let id = self.new_id();
272
273 let msg = LdapMessage::new(id, ProtocolOp::ModifyRequest(request.into()));
274 let resp = self.connection.send_recv(msg).await?;
275
276 match resp.protocol_op {
277 ProtocolOp::ModifyResponse(resp) => {
278 check_result(LdapResult::new(
279 resp.0.result_code,
280 resp.0.matched_dn,
281 resp.0.diagnostic_message,
282 ))?;
283 Ok(())
284 }
285 _ => Err(Error::InvalidResponse),
286 }
287 }
288
289 pub async fn add<S, I>(&mut self, dn: S, attributes: I) -> Result<()>
291 where
292 S: AsRef<str>,
293 I: IntoIterator<Item = Attribute>,
294 {
295 let id = self.new_id();
296
297 let msg = LdapMessage::new(
298 id,
299 ProtocolOp::AddRequest(rasn_ldap::AddRequest {
300 entry: dn.as_ref().to_owned().into(),
301 attributes: attributes.into_iter().map(Into::into).collect(),
302 }),
303 );
304 let resp = self.connection.send_recv(msg).await?;
305
306 match resp.protocol_op {
307 ProtocolOp::AddResponse(resp) => {
308 check_result(LdapResult::new(
309 resp.0.result_code,
310 resp.0.matched_dn,
311 resp.0.diagnostic_message,
312 ))?;
313 Ok(())
314 }
315 _ => Err(Error::InvalidResponse),
316 }
317 }
318
319 pub async fn delete<S: AsRef<str>>(&mut self, dn: S) -> Result<()> {
321 let id = self.new_id();
322
323 let msg = LdapMessage::new(
324 id,
325 ProtocolOp::DelRequest(rasn_ldap::DelRequest(dn.as_ref().to_owned().into())),
326 );
327 let resp = self.connection.send_recv(msg).await?;
328
329 match resp.protocol_op {
330 ProtocolOp::DelResponse(resp) => {
331 check_result(LdapResult::new(
332 resp.0.result_code,
333 resp.0.matched_dn,
334 resp.0.diagnostic_message,
335 ))?;
336 Ok(())
337 }
338 _ => Err(Error::InvalidResponse),
339 }
340 }
341}
342
343pub struct Pages {
345 page_control: Arc<RwLock<SimplePagedResultsControl>>,
346 page_finished: Arc<AtomicBool>,
347 client: LdapClient,
348 request: SearchRequest,
349 page_size: u32,
350 inner: Option<BoxFuture<'static, Result<SearchEntries>>>,
351}
352
353impl Pages {
354 fn is_page_finished(&self) -> bool {
355 self.page_finished.load(Ordering::SeqCst)
356 }
357}
358
359impl Stream for Pages {
360 type Item = Result<SearchEntries>;
361
362 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
363 if !self.page_control.read().has_entries() {
364 return Poll::Ready(None);
365 }
366
367 if self.inner.is_none() {
368 if !self.is_page_finished() {
369 return Poll::Ready(None);
370 }
371
372 let mut client = self.client.clone();
373 let request = self.request.clone();
374 let control_ref = self.page_control.clone();
375 let page_size = self.page_size;
376 let page_finished = self.page_finished.clone();
377
378 self.page_finished.store(false, Ordering::SeqCst);
379
380 let fut = async move {
381 let id = client.new_id();
382
383 let mut msg = LdapMessage::new(id, ProtocolOp::SearchRequest(request.into()));
384 msg.controls = Some(vec![control_ref.read().clone().with_size(page_size).try_into()?]);
385
386 let stream = client.connection.send_recv_stream(msg).await?;
387 Ok(SearchEntries {
388 inner: stream,
389 page_control: Some(control_ref),
390 page_finished,
391 })
392 };
393 self.inner = Some(Box::pin(fut));
394 }
395
396 match Pin::new(self.inner.as_mut().unwrap()).poll(cx) {
397 Poll::Pending => Poll::Pending,
398 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
399 Poll::Ready(Ok(entries)) => {
400 self.inner = None;
401 Poll::Ready(Some(Ok(entries)))
402 }
403 }
404 }
405}
406
407pub struct SearchEntries {
409 inner: MessageStream,
410 page_control: Option<Arc<RwLock<SimplePagedResultsControl>>>,
411 page_finished: Arc<AtomicBool>,
412}
413
414impl SearchEntries {
415 fn search_done(
416 self: Pin<&mut Self>,
417 controls: Option<Controls>,
418 done: SearchResultDone,
419 ) -> Poll<Option<Result<SearchEntry>>> {
420 self.page_finished.store(true, Ordering::SeqCst);
421
422 if done.0.result_code == ResultCode::Success {
423 if let Some(ref control_ref) = self.page_control {
424 let page_control = controls.and_then(|controls| {
425 controls
426 .into_iter()
427 .find(|c| c.control_type == SimplePagedResultsControl::OID)
428 .and_then(|c| SimplePagedResultsControl::try_from(c).ok())
429 });
430
431 if let Some(page_control) = page_control {
432 *control_ref.write() = page_control;
433 Poll::Ready(None)
434 } else {
435 Poll::Ready(Some(Err(Error::InvalidResponse)))
436 }
437 } else {
438 Poll::Ready(None)
439 }
440 } else {
441 Poll::Ready(Some(Err(Error::OperationFailed(done.0.into()))))
442 }
443 }
444}
445
446impl Stream for SearchEntries {
447 type Item = Result<SearchEntry>;
448
449 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
450 loop {
451 let rc = match Pin::new(&mut self.inner).poll_next(cx) {
452 Poll::Pending => Poll::Pending,
453 Poll::Ready(None) => Poll::Ready(Some(Err(Error::ConnectionClosed))),
454 Poll::Ready(Some(msg)) => match msg.protocol_op {
455 ProtocolOp::SearchResEntry(item) => Poll::Ready(Some(Ok(item.into()))),
456 ProtocolOp::SearchResRef(_) => continue,
457 ProtocolOp::SearchResDone(done) => self.search_done(msg.controls, done),
458 _ => Poll::Ready(Some(Err(Error::InvalidResponse))),
459 },
460 };
461 return rc;
462 }
463 }
464}