ldap_rs/
channel.rs

1//! Low-level LDAP channel operations
2
3use std::{io, net::ToSocketAddrs, time::Duration};
4
5use futures::{
6    channel::mpsc::{self, Receiver, Sender},
7    future,
8    sink::SinkExt,
9    StreamExt, TryStreamExt,
10};
11use log::{debug, error};
12use rasn_ldap::LdapMessage;
13use tokio::{
14    io::{AsyncRead, AsyncWrite},
15    net::TcpStream,
16};
17
18use crate::{
19    codec::LdapCodec,
20    error::Error,
21    options::{TlsKind, TlsOptions},
22    TlsBackend,
23};
24
/// Capacity of each of the two bounded message queues (in and out) created by `make_channel`.
const CHANNEL_SIZE: usize = 1 << 10;
/// Upper bound on how long the TCP connect performed by `LdapChannel::connect` may take.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
27
28pub type LdapMessageSender = Sender<LdapMessage>;
29pub type LdapMessageReceiver = Receiver<LdapMessage>;
30
31trait TlsStream: AsyncRead + AsyncWrite + Unpin + Send {}
32
33#[cfg(feature = "tls-native-tls")]
34impl<T: AsyncRead + AsyncWrite + Unpin + Send> TlsStream for tokio_native_tls::TlsStream<T> {}
35
36#[cfg(feature = "tls-rustls")]
37impl<T: AsyncRead + AsyncWrite + Unpin + Send> TlsStream for tokio_rustls::client::TlsStream<T> {}
38
/// Wrap any error value (or plain message such as `&str`) in an `io::Error`
/// whose kind is always `io::ErrorKind::InvalidData`.
fn io_error<E: Into<Box<dyn std::error::Error + Send + Sync>>>(e: E) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
45
46fn make_channel<S>(stream: S) -> (LdapMessageSender, LdapMessageReceiver)
47where
48    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
49{
50    // construct framed instance based on LdapCodec
51    let framed = tokio_util::codec::Framed::new(stream, LdapCodec);
52
53    // The 'in' channel:
54    // Messages received from the socket will be forwarded to tx_in
55    // and received by the external client via rx_in endpoint
56    let (tx_in, rx_in) = mpsc::channel(CHANNEL_SIZE);
57
58    // The 'out' channel:
59    // Messages sent to tx_out by external clients will be picked up on rx_out endpoint
60    // and forwarded to socket
61    let (tx_out, rx_out) = mpsc::channel(CHANNEL_SIZE);
62
63    let channel = async move {
64        // sink is the sending part, stream is the receiving part
65        let (mut sink, stream) = framed.split();
66
67        // we receive LdapMessage messages from the clients and convert to stream chunks
68        let mut rx = rx_out.map(Ok::<_, Error>);
69
70        // app -> socket
71        let to_wire = sink.send_all(&mut rx);
72
73        // convert incoming channel errors into io::Error
74        let mut tx = tx_in.sink_map_err(io_error);
75
76        // app <- socket
77        let from_wire = stream.map_err(io_error).forward(&mut tx);
78
79        // await for either of futures: terminating one side will drop the other
80        future::select(to_wire, from_wire).await;
81    };
82
83    // spawn in the background
84    tokio::spawn(channel);
85
86    // we return (tx_out, rx_in) pair so that the consumer can send and receive messages
87    (tx_out, rx_in)
88}
89
90/// LDAP channel errors
91#[derive(Debug, thiserror::Error)]
92pub enum ChannelError {
93    #[error(transparent)]
94    IoError(#[from] io::Error),
95
96    #[error(transparent)]
97    ConnectTimeout(#[from] tokio::time::error::Elapsed),
98
99    #[error("STARTTLS failed")]
100    StartTlsFailed,
101
102    #[cfg(feature = "tls-native-tls")]
103    #[error(transparent)]
104    NativeTls(#[from] native_tls::Error),
105
106    #[cfg(feature = "tls-rustls")]
107    #[error(transparent)]
108    Rustls(#[from] rustls::Error),
109
110    #[cfg(feature = "tls-rustls")]
111    #[error(transparent)]
112    DnsName(#[from] rustls_pki_types::InvalidDnsNameError),
113}
114
115pub type ChannelResult<T> = Result<T, ChannelError>;
116
/// LDAP TCP channel connector.
///
/// Holds the target server coordinates; the connection is only opened by
/// `LdapChannel::connect`.
pub struct LdapChannel {
    /// Host name or IP address of the server. It is also the TLS server name
    /// used when `TlsOptions` does not supply a domain name.
    address: String,
    /// TCP port of the server.
    port: u16,
}
122
123impl LdapChannel {
124    /// Create a client-side channel with a given server address and port
125    pub fn for_client<S>(address: S, port: u16) -> Self
126    where
127        S: AsRef<str>,
128    {
129        LdapChannel {
130            address: address.as_ref().to_owned(),
131            port,
132        }
133    }
134
135    /// Connect to a server
136    /// Returns a pair of (sender, receiver) endpoints
137    pub async fn connect(self, tls_options: TlsOptions) -> ChannelResult<(LdapMessageSender, LdapMessageReceiver)> {
138        let mut addrs = (self.address.as_ref(), self.port).to_socket_addrs()?;
139        let address = addrs.next().ok_or_else(|| io_error("Address resolution error"))?;
140
141        // TCP connect with a timeout
142        let stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&address)).await??;
143
144        debug!("Connection established to {}", address);
145
146        let channel = match tls_options.kind {
147            TlsKind::Plain => make_channel(stream),
148            #[cfg(tls)]
149            TlsKind::Tls => make_channel(self.tls_connect(tls_options, stream).await?),
150            #[cfg(tls)]
151            TlsKind::StartTls => make_channel(self.starttls_connect(tls_options, stream).await?),
152        };
153        Ok(channel)
154    }
155
156    async fn tls_connect<S>(&self, tls_options: TlsOptions, stream: S) -> ChannelResult<Box<dyn TlsStream>>
157    where
158        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
159    {
160        match tls_options.backend.unwrap_or_default() {
161            #[cfg(feature = "tls-native-tls")]
162            TlsBackend::Native(connector) => Ok(Box::new(
163                self.tls_connect_native_tls(tls_options.domain_name, connector, stream)
164                    .await?,
165            )),
166            #[cfg(feature = "tls-rustls")]
167            TlsBackend::Rustls(client_config) => Ok(Box::new(
168                self.tls_connect_rustls(tls_options.domain_name, client_config, stream)
169                    .await?,
170            )),
171        }
172    }
173
174    #[cfg(tls)]
175    async fn starttls_connect<S>(
176        &self,
177        tls_options: TlsOptions,
178        mut stream: S,
179    ) -> ChannelResult<impl AsyncRead + AsyncWrite + Unpin + Send>
180    where
181        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
182    {
183        use log::warn;
184        use rasn_ldap::{ExtendedRequest, ProtocolOp, ResultCode};
185
186        const STARTTLS_TIMEOUT: Duration = Duration::from_secs(30);
187
188        debug!("Begin STARTTLS negotiation");
189        let mut framed = tokio_util::codec::Framed::new(&mut stream, LdapCodec);
190        let req = ExtendedRequest {
191            request_name: crate::oid::STARTTLS_OID.into(),
192            request_value: None,
193        };
194        framed
195            .send(LdapMessage::new(1, ProtocolOp::ExtendedReq(req)))
196            .await
197            .map_err(|_| ChannelError::StartTlsFailed)?;
198        match tokio::time::timeout(STARTTLS_TIMEOUT, framed.next()).await {
199            Ok(Some(Ok(item))) => match item.protocol_op {
200                ProtocolOp::ExtendedResp(resp) if resp.result_code == ResultCode::Success && item.message_id == 1 => {
201                    debug!("End STARTTLS negotiation, switching protocols");
202                    return self.tls_connect(tls_options, stream).await;
203                }
204                _ => {
205                    warn!("STARTTLS negotiation failed");
206                }
207            },
208            Err(_) => {
209                warn!("Timeout occurred while waiting for STARTTLS reply");
210            }
211            _ => {
212                warn!("Unexpected response while waiting for STARTTLS reply");
213            }
214        }
215        Err(ChannelError::StartTlsFailed)
216    }
217
218    #[cfg(feature = "tls-native-tls")]
219    async fn tls_connect_native_tls<S>(
220        &self,
221        domain_name: Option<String>,
222        tls_connector: native_tls::TlsConnector,
223        stream: S,
224    ) -> ChannelResult<tokio_native_tls::TlsStream<S>>
225    where
226        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
227    {
228        let domain = domain_name.as_deref().unwrap_or(&self.address);
229
230        debug!("Performing TLS handshake using native-tls, SNI: {}", domain);
231
232        let tokio_connector = tokio_native_tls::TlsConnector::from(tls_connector);
233
234        let stream = tokio_connector
235            .connect(domain, stream)
236            .await
237            .map_err(ChannelError::NativeTls)?;
238
239        debug!("TLS handshake succeeded!");
240
241        Ok(stream)
242    }
243
244    #[cfg(feature = "tls-rustls")]
245    async fn tls_connect_rustls<S>(
246        &self,
247        domain_name: Option<String>,
248        client_config: rustls::ClientConfig,
249        stream: S,
250    ) -> ChannelResult<tokio_rustls::client::TlsStream<S>>
251    where
252        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
253    {
254        use rustls_pki_types::ServerName;
255        use std::sync::Arc;
256
257        let domain = ServerName::try_from(domain_name.as_deref().unwrap_or(&self.address).to_owned())?;
258
259        debug!("Performing TLS handshake using rustls, SNI: {:?}", domain);
260
261        let tokio_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
262        let stream = tokio_connector.connect(domain, stream).await?;
263
264        debug!("TLS handshake succeeded!");
265
266        Ok(stream)
267    }
268}
269
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use rasn_ldap::{ProtocolOp, UnbindRequest};
    use tokio::net::TcpListener;
    use tokio_util::codec::Framed;

    use super::*;

    fn new_msg() -> LdapMessage {
        LdapMessage::new(1, ProtocolOp::UnbindRequest(UnbindRequest))
    }

    /// Start a one-shot echo server on an OS-assigned loopback port.
    ///
    /// The server echoes back the first `num_msgs` messages it receives and
    /// then closes the connection. Binding to port 0 avoids collisions with
    /// other processes or with tests running in parallel.
    async fn start_server(num_msgs: usize) -> SocketAddr {
        let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = tcp.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((stream, _)) = tcp.accept().await {
                let framed = Framed::new(stream, LdapCodec);
                let (mut sink, stream) = framed.split();
                sink.send_all(&mut stream.take(num_msgs)).await.unwrap();
            }
        });

        address
    }

    #[tokio::test]
    async fn test_connection_success() {
        let address = start_server(2).await;

        let (mut sender, mut receiver) = LdapChannel::for_client(address.ip().to_string(), address.port())
            .connect(TlsOptions::default())
            .await
            .unwrap();
        let msg = new_msg();

        sender.send(msg.clone()).await.unwrap();
        sender.send(msg.clone()).await.unwrap();

        // The server closes after echoing two messages, which ends the receiver stream.
        let mut received = 0;
        while let Some(m) = receiver.next().await {
            assert_eq!(msg, m);
            received += 1;
        }
        assert_eq!(received, 2);
    }

    #[tokio::test]
    async fn test_connection_fail() {
        // Obtain a free port and release it immediately so nothing listens there.
        let address = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap()
        };

        let res = LdapChannel::for_client(address.ip().to_string(), address.port())
            .connect(TlsOptions::default())
            .await;

        assert!(res.is_err());
    }
}