use std::{
collections::HashMap,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt};
use log::debug;
use parking_lot::RwLock;
use crate::{
channel::{LdapChannel, LdapMessageReceiver, LdapMessageSender},
error::Error,
oid,
rasn_ldap::{LdapMessage, ProtocolOp},
TlsOptions,
};
/// Buffer size passed to `mpsc::channel` when creating the per-request
/// channel that carries responses to a [`MessageStream`].
const CHANNEL_SIZE: usize = 1024;
/// Shared table of in-flight requests: maps an LDAP message id to the sender
/// through which responses carrying that id are forwarded.
type RequestMap = Arc<RwLock<HashMap<u32, LdapMessageSender>>>;
/// Handle to an established LDAP connection.
///
/// Cloning is cheap in the sense that the request table is behind an `Arc`,
/// so every clone shares the same set of in-flight requests.
#[derive(Clone)]
pub(crate) struct LdapConnection {
// In-flight requests keyed by message id; also read by the dispatcher task
// spawned in `connect`.
requests: RequestMap,
// Sender used to push outgoing messages into the underlying channel.
channel_sender: LdapMessageSender,
}
impl LdapConnection {
    /// Connects to `address:port` and spawns the background response dispatcher.
    ///
    /// The dispatcher reads every incoming message and forwards it to the
    /// request registered under the same message id, if there is one. It stops
    /// when the incoming channel ends or when an extended response named
    /// `oid::NOTICE_OF_DISCONNECTION_OID` arrives, and then clears the request
    /// table so that all pending [`MessageStream`]s terminate.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying channel cannot be established.
    pub(crate) async fn connect<A>(address: A, port: u16, tls_options: TlsOptions) -> Result<Self, Error>
    where
        A: AsRef<str>,
    {
        let (channel_sender, mut channel_receiver) =
            LdapChannel::for_client(address, port).connect(tls_options).await?;

        let connection = Self {
            requests: RequestMap::default(),
            channel_sender,
        };

        let requests = connection.requests.clone();

        tokio::spawn(async move {
            while let Some(msg) = channel_receiver.next().await {
                match msg.protocol_op {
                    ProtocolOp::ExtendedResp(resp)
                        if resp.response_name.as_deref() == Some(oid::NOTICE_OF_DISCONNECTION_OID) =>
                    {
                        debug!("Notice of disconnection received, exiting");
                        break;
                    }
                    _ => {
                        // Clone the sender in its own statement so the read guard is
                        // released before awaiting on the send below.
                        let sender = requests.read().get(&msg.message_id).cloned();
                        if let Some(mut sender) = sender {
                            // The receiving stream may already be gone; that is not an error.
                            let _ = sender.send(msg).await;
                        }
                    }
                }
            }

            debug!("Connection terminated");
            // Dropping every registered sender ends all pending response streams.
            requests.write().clear();
        });

        Ok(connection)
    }

    /// Sends `msg` and returns a stream of the responses carrying its message id.
    ///
    /// The response channel is registered *before* the request is written out.
    /// The dispatcher task runs concurrently, so registering afterwards would
    /// allow a fast reply to arrive while no receiver is known for its id and
    /// be silently discarded.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be handed to the outgoing
    /// channel. In that case the registration is removed again when the
    /// not-yet-returned stream is dropped.
    pub(crate) async fn send_recv_stream(&mut self, msg: LdapMessage) -> Result<MessageStream, Error> {
        let id = msg.message_id;

        let (tx, rx) = mpsc::channel(CHANNEL_SIZE);
        self.requests.write().insert(id, tx);

        // Build the stream up front: its `Drop` impl unregisters `id`, which
        // cleans up on send failure and on cancellation of this future alike.
        let stream = MessageStream {
            id,
            requests: self.requests.clone(),
            receiver: rx,
        };

        self.channel_sender.send(msg).await?;

        Ok(stream)
    }

    /// Sends `msg` without registering for any response.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be handed to the outgoing channel.
    pub(crate) async fn send(&mut self, msg: LdapMessage) -> Result<(), Error> {
        self.channel_sender.send(msg).await?;
        Ok(())
    }

    /// Sends `msg` and waits for the first response carrying its message id.
    ///
    /// # Errors
    ///
    /// Returns an error if sending fails, or a `ConnectionReset` I/O error if
    /// the response stream ends before any message is received.
    pub(crate) async fn send_recv(&mut self, msg: LdapMessage) -> Result<LdapMessage, Error> {
        let mut responses = self.send_recv_stream(msg).await?;
        let response = responses
            .next()
            .await
            .ok_or_else(|| io::Error::new(io::ErrorKind::ConnectionReset, "Connection closed"))?;
        Ok(response)
    }
}
/// Stream of the messages forwarded for one registered message id.
///
/// Dropping the stream removes its entry from the shared request table.
pub(crate) struct MessageStream {
// Message id under which this stream is registered in `requests`.
id: u32,
// Shared request table, kept so the entry can be removed on drop.
requests: RequestMap,
// Receiving half of the per-request channel.
receiver: LdapMessageReceiver,
}
impl Stream for MessageStream {
    type Item = LdapMessage;

    /// Polls the inner receiver for the next message routed to this stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // The struct is `Unpin`, so a plain mutable reference can be taken and
        // the receiver polled directly.
        let this = self.get_mut();
        this.receiver.poll_next_unpin(cx)
    }
}
impl Drop for MessageStream {
    /// Unregisters this stream's message id from the shared request table.
    fn drop(&mut self) {
        let mut pending = self.requests.write();
        pending.remove(&self.id);
    }
}