1use std::{io, net::ToSocketAddrs, time::Duration};
4
5use futures::{
6 channel::mpsc::{self, Receiver, Sender},
7 future,
8 sink::SinkExt,
9 StreamExt, TryStreamExt,
10};
11use log::{debug, error};
12use rasn_ldap::LdapMessage;
13use tokio::{
14 io::{AsyncRead, AsyncWrite},
15 net::TcpStream,
16};
17
18use crate::{
19 codec::LdapCodec,
20 error::Error,
21 options::{TlsKind, TlsOptions},
22 TlsBackend,
23};
24
/// Buffer size of the bounded incoming and outgoing message queues created by `make_channel`.
const CHANNEL_SIZE: usize = 1 << 10;
/// Upper bound on the duration of a single TCP connection attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
27
/// Sending half returned by [`LdapChannel::connect`]; messages sent here are encoded and
/// written to the connection by a background task.
pub type LdapMessageSender = Sender<LdapMessage>;
/// Receiving half returned by [`LdapChannel::connect`]; yields messages decoded from the
/// connection. The stream ends once the background task has stopped.
pub type LdapMessageReceiver = Receiver<LdapMessage>;
30
31trait TlsStream: AsyncRead + AsyncWrite + Unpin + Send {}
32
33#[cfg(feature = "tls-native-tls")]
34impl<T: AsyncRead + AsyncWrite + Unpin + Send> TlsStream for tokio_native_tls::TlsStream<T> {}
35
36#[cfg(feature = "tls-rustls")]
37impl<T: AsyncRead + AsyncWrite + Unpin + Send> TlsStream for tokio_rustls::client::TlsStream<T> {}
38
/// Converts any error value (or message string) into an [`io::Error`] of kind
/// [`io::ErrorKind::InvalidData`].
///
/// Used to give the codec stream and the message queue a common error type so one can be
/// forwarded into the other, and to build ad-hoc I/O errors from plain strings.
fn io_error<E>(source: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, source)
}
45
46fn make_channel<S>(stream: S) -> (LdapMessageSender, LdapMessageReceiver)
47where
48 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
49{
50 let framed = tokio_util::codec::Framed::new(stream, LdapCodec);
52
53 let (tx_in, rx_in) = mpsc::channel(CHANNEL_SIZE);
57
58 let (tx_out, rx_out) = mpsc::channel(CHANNEL_SIZE);
62
63 let channel = async move {
64 let (mut sink, stream) = framed.split();
66
67 let mut rx = rx_out.map(Ok::<_, Error>);
69
70 let to_wire = sink.send_all(&mut rx);
72
73 let mut tx = tx_in.sink_map_err(io_error);
75
76 let from_wire = stream.map_err(io_error).forward(&mut tx);
78
79 future::select(to_wire, from_wire).await;
81 };
82
83 tokio::spawn(channel);
85
86 (tx_out, rx_in)
88}
89
/// Errors that can occur while establishing an [`LdapChannel`] connection.
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    /// An I/O error: name resolution, TCP connect, or (rustls backend) TLS handshake failure.
    #[error(transparent)]
    IoError(#[from] io::Error),

    /// The TCP connection attempt did not complete within the connect timeout.
    #[error(transparent)]
    ConnectTimeout(#[from] tokio::time::error::Elapsed),

    /// STARTTLS negotiation failed for one of these reasons:
    /// - the request could not be sent,
    /// - the server did not reply in time, or
    /// - the server did not answer with a successful extended response.
    #[error("STARTTLS failed")]
    StartTlsFailed,

    /// TLS handshake error reported by the native-tls backend.
    #[cfg(feature = "tls-native-tls")]
    #[error(transparent)]
    NativeTls(#[from] native_tls::Error),

    /// Error reported by the rustls library.
    #[cfg(feature = "tls-rustls")]
    #[error(transparent)]
    Rustls(#[from] rustls::Error),

    /// The TLS server name is not a valid DNS name (rustls backend).
    #[cfg(feature = "tls-rustls")]
    #[error(transparent)]
    DnsName(#[from] rustls_pki_types::InvalidDnsNameError),
}

/// Result type returned by the channel-establishment functions.
pub type ChannelResult<T> = Result<T, ChannelError>;
116
/// Connection parameters (host and port) of an LDAP server.
///
/// Build one with [`LdapChannel::for_client`] and open it with [`LdapChannel::connect`].
#[derive(Debug, Clone)]
pub struct LdapChannel {
    /// Host name or IP address of the server; also the default TLS server name.
    address: String,
    /// TCP port of the server.
    port: u16,
}
122
123impl LdapChannel {
124 pub fn for_client<S>(address: S, port: u16) -> Self
126 where
127 S: AsRef<str>,
128 {
129 LdapChannel {
130 address: address.as_ref().to_owned(),
131 port,
132 }
133 }
134
135 pub async fn connect(self, tls_options: TlsOptions) -> ChannelResult<(LdapMessageSender, LdapMessageReceiver)> {
138 let mut addrs = (self.address.as_ref(), self.port).to_socket_addrs()?;
139 let address = addrs.next().ok_or_else(|| io_error("Address resolution error"))?;
140
141 let stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&address)).await??;
143
144 debug!("Connection established to {}", address);
145
146 let channel = match tls_options.kind {
147 TlsKind::Plain => make_channel(stream),
148 #[cfg(tls)]
149 TlsKind::Tls => make_channel(self.tls_connect(tls_options, stream).await?),
150 #[cfg(tls)]
151 TlsKind::StartTls => make_channel(self.starttls_connect(tls_options, stream).await?),
152 };
153 Ok(channel)
154 }
155
156 async fn tls_connect<S>(&self, tls_options: TlsOptions, stream: S) -> ChannelResult<Box<dyn TlsStream>>
157 where
158 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
159 {
160 match tls_options.backend.unwrap_or_default() {
161 #[cfg(feature = "tls-native-tls")]
162 TlsBackend::Native(connector) => Ok(Box::new(
163 self.tls_connect_native_tls(tls_options.domain_name, connector, stream)
164 .await?,
165 )),
166 #[cfg(feature = "tls-rustls")]
167 TlsBackend::Rustls(client_config) => Ok(Box::new(
168 self.tls_connect_rustls(tls_options.domain_name, client_config, stream)
169 .await?,
170 )),
171 }
172 }
173
174 #[cfg(tls)]
175 async fn starttls_connect<S>(
176 &self,
177 tls_options: TlsOptions,
178 mut stream: S,
179 ) -> ChannelResult<impl AsyncRead + AsyncWrite + Unpin + Send>
180 where
181 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
182 {
183 use log::warn;
184 use rasn_ldap::{ExtendedRequest, ProtocolOp, ResultCode};
185
186 const STARTTLS_TIMEOUT: Duration = Duration::from_secs(30);
187
188 debug!("Begin STARTTLS negotiation");
189 let mut framed = tokio_util::codec::Framed::new(&mut stream, LdapCodec);
190 let req = ExtendedRequest {
191 request_name: crate::oid::STARTTLS_OID.into(),
192 request_value: None,
193 };
194 framed
195 .send(LdapMessage::new(1, ProtocolOp::ExtendedReq(req)))
196 .await
197 .map_err(|_| ChannelError::StartTlsFailed)?;
198 match tokio::time::timeout(STARTTLS_TIMEOUT, framed.next()).await {
199 Ok(Some(Ok(item))) => match item.protocol_op {
200 ProtocolOp::ExtendedResp(resp) if resp.result_code == ResultCode::Success && item.message_id == 1 => {
201 debug!("End STARTTLS negotiation, switching protocols");
202 return self.tls_connect(tls_options, stream).await;
203 }
204 _ => {
205 warn!("STARTTLS negotiation failed");
206 }
207 },
208 Err(_) => {
209 warn!("Timeout occurred while waiting for STARTTLS reply");
210 }
211 _ => {
212 warn!("Unexpected response while waiting for STARTTLS reply");
213 }
214 }
215 Err(ChannelError::StartTlsFailed)
216 }
217
218 #[cfg(feature = "tls-native-tls")]
219 async fn tls_connect_native_tls<S>(
220 &self,
221 domain_name: Option<String>,
222 tls_connector: native_tls::TlsConnector,
223 stream: S,
224 ) -> ChannelResult<tokio_native_tls::TlsStream<S>>
225 where
226 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
227 {
228 let domain = domain_name.as_deref().unwrap_or(&self.address);
229
230 debug!("Performing TLS handshake using native-tls, SNI: {}", domain);
231
232 let tokio_connector = tokio_native_tls::TlsConnector::from(tls_connector);
233
234 let stream = tokio_connector
235 .connect(domain, stream)
236 .await
237 .map_err(ChannelError::NativeTls)?;
238
239 debug!("TLS handshake succeeded!");
240
241 Ok(stream)
242 }
243
244 #[cfg(feature = "tls-rustls")]
245 async fn tls_connect_rustls<S>(
246 &self,
247 domain_name: Option<String>,
248 client_config: rustls::ClientConfig,
249 stream: S,
250 ) -> ChannelResult<tokio_rustls::client::TlsStream<S>>
251 where
252 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
253 {
254 use rustls_pki_types::ServerName;
255 use std::sync::Arc;
256
257 let domain = ServerName::try_from(domain_name.as_deref().unwrap_or(&self.address).to_owned())?;
258
259 debug!("Performing TLS handshake using rustls, SNI: {:?}", domain);
260
261 let tokio_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
262 let stream = tokio_connector.connect(domain, stream).await?;
263
264 debug!("TLS handshake succeeded!");
265
266 Ok(stream)
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use rasn_ldap::{ProtocolOp, UnbindRequest};
    use tokio::net::TcpListener;
    use tokio_util::codec::Framed;

    use super::*;

    /// Upper bound on how long a test waits for the echo server before failing.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    fn new_msg() -> LdapMessage {
        LdapMessage::new(1, ProtocolOp::UnbindRequest(UnbindRequest))
    }

    /// Starts a one-shot echo server on an OS-assigned loopback port and returns its address.
    ///
    /// The server accepts one connection, echoes back the first `num_msgs` messages it decodes
    /// and then drops the connection. Binding to port 0 avoids collisions with other processes
    /// and with tests running in parallel.
    async fn start_server(num_msgs: usize) -> SocketAddr {
        let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = tcp.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((stream, _)) = tcp.accept().await {
                let framed = Framed::new(stream, LdapCodec);
                let (mut sink, stream) = framed.split();
                sink.send_all(&mut stream.take(num_msgs)).await.unwrap();
            }
        });

        address
    }

    #[tokio::test]
    async fn test_connection_success() {
        let address = start_server(2).await;

        let (mut sender, mut receiver) = LdapChannel::for_client("127.0.0.1", address.port())
            .connect(TlsOptions::default())
            .await
            .unwrap();
        let msg = new_msg();

        sender.send(msg.clone()).await.unwrap();
        sender.send(msg.clone()).await.unwrap();

        // Bounded so a misbehaving server fails the test instead of hanging it.
        let received = tokio::time::timeout(TEST_TIMEOUT, async {
            let mut count = 0;
            while let Some(m) = receiver.next().await {
                assert_eq!(msg, m);
                count += 1;
            }
            count
        })
        .await
        .expect("echo server did not close the connection in time");

        assert_eq!(received, 2);
    }

    #[tokio::test]
    async fn test_connection_fail() {
        // Reserve a free port, then release it so that nothing is listening there.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let res = LdapChannel::for_client("127.0.0.1", port)
            .connect(TlsOptions::default())
            .await;

        assert!(res.is_err());
    }
}