use rasn::error::DecodeErrorKind;
use rasn_ldap::{LdapMessage, ProtocolOp};
use std::{
fmt::Display,
io::{ErrorKind, Read, Write},
net::{TcpStream, ToSocketAddrs},
time::Duration,
};
use crate::bind::Unbound;
/// Well-known TCP port for plain (non-TLS) LDAP.
pub const LDAP_PORT: u16 = 389;
/// Well-known TCP port for LDAP over TLS ("LDAPS").
pub const LDAPS_PORT: u16 = 636;
#[cfg(feature = "native-tls")]
pub mod native_tls;
#[cfg(feature = "rustls")]
pub mod rustls;
pub mod bind;
pub mod search;
mod unbind;
/// A blocking LDAP connection over a `Read + Write` transport.
///
/// `BindState` is carried as a type parameter (defaulting to [`Unbound`]) so
/// that which operations are available can depend on the connection's bind
/// state; the concrete states other than `Unbound` live in the `bind` module.
pub struct LdapConnection<Stream: Read + Write, BindState = Unbound> {
    /// Transport that encoded LDAP messages are written to and read from.
    stream: Stream,
    /// Message ID that will be assigned to the next outgoing request.
    next_message_id: u32,
    /// Value of the current bind-state type.
    state: BindState,
}
impl LdapConnection<TcpStream, Unbound> {
    /// Opens a plain-TCP LDAP connection to `addr` in the unbound state.
    ///
    /// The socket is created with [`TcpStream::connect`], which applies no
    /// connect timeout. Reads and writes on the resulting socket time out after
    /// 10 seconds, so an unresponsive server surfaces as an I/O error instead
    /// of blocking the caller indefinitely.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from resolving/connecting to `addr` or from
    /// configuring the socket timeouts.
    pub fn connect(addr: impl ToSocketAddrs) -> Result<LdapConnection<TcpStream, Unbound>, std::io::Error> {
        const IO_TIMEOUT: Duration = Duration::from_secs(10);

        let stream = TcpStream::connect(addr)?;
        stream.set_read_timeout(Some(IO_TIMEOUT))?;
        // Without a write timeout, a server that stops draining its receive
        // buffer could leave `write_all` blocked forever even though reads are
        // bounded.
        stream.set_write_timeout(Some(IO_TIMEOUT))?;
        Ok(LdapConnection::new_unbound(stream))
    }
}
impl<Stream: Read + Write> LdapConnection<Stream, Unbound> {
    /// Wraps an already-established `stream` in a connection in the unbound state.
    ///
    /// The stream is used exactly as given; no socket options are changed here.
    /// Message numbering starts at 1 because RFC 4511 reserves message ID 0 for
    /// unsolicited notifications.
    pub fn new_unbound(stream: Stream) -> LdapConnection<Stream, Unbound> {
        let state = Unbound { _priv: () };
        Self {
            next_message_id: 1,
            state,
            stream,
        }
    }
}
impl<Stream: Read + Write, T> LdapConnection<Stream, T> {
    /// Returns the message ID for the next request and advances the counter.
    ///
    /// RFC 4511 constrains `MessageID` to `0..=maxInt` (maxInt = 2^31 - 1) and
    /// reserves `0` for unsolicited notifications, so IDs cycle through
    /// `1..=maxInt`. The counter wraps back to 1 after `maxInt` instead of
    /// producing an out-of-range ID or overflowing `u32`, which would panic in
    /// debug builds.
    fn get_and_increase_message_id(&mut self) -> u32 {
        const MAX_MESSAGE_ID: u32 = 2_147_483_647;

        let current = self.next_message_id;
        self.next_message_id = if current >= MAX_MESSAGE_ID {
            1
        } else {
            current + 1
        };
        current
    }

    /// Sends `protocol_op` as one LDAP message and waits for its single response.
    ///
    /// The request gets a fresh message ID, is BER-encoded and written to the
    /// stream. Bytes are then read, 2048 at a time, until the accumulated data
    /// decodes as one complete `LdapMessage`. That message's `protocol_op` is
    /// returned. The read buffer is local to this call, so no received bytes
    /// are kept for later calls.
    ///
    /// `_controls` is currently ignored.
    ///
    /// # Errors
    ///
    /// - [`MessageError::Io`] if writing or reading fails (including a read
    ///   timeout configured on the stream). Also returned with
    ///   `ErrorKind::ConnectionReset` if the peer closes the connection before a
    ///   complete response arrives. Interrupted reads are retried, not reported.
    /// - [`MessageError::Message`] if the received bytes fail to decode for any
    ///   reason other than being incomplete.
    /// - [`MessageError::UnsolicitedResponse`] if the decoded response carries a
    ///   different message ID than the request (for example, an unsolicited
    ///   notification, which uses ID 0).
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be BER-encoded.
    fn send_single_message(
        &mut self,
        protocol_op: ProtocolOp,
        _controls: Option<()>,
    ) -> Result<ProtocolOp, MessageError> {
        let message_id = self.get_and_increase_message_id();
        let message = LdapMessage::new(message_id, protocol_op);
        let encoded = rasn::ber::encode(&message).expect("Failed to encode BER message");
        self.stream.write_all(&encoded).map_err(MessageError::Io)?;

        let mut buf = Vec::new();
        let mut temp_buffer = [0u8; 2048];
        loop {
            let n = match self.stream.read(&mut temp_buffer) {
                // Per the `Read::read` contract, `Interrupted` is non-fatal and
                // the read should simply be retried.
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) => return Err(MessageError::Io(e)),
                Ok(0) => {
                    return Err(MessageError::Io(std::io::Error::new(
                        ErrorKind::ConnectionReset,
                        "connection closed",
                    )));
                }
                Ok(n) => n,
            };
            buf.extend_from_slice(&temp_buffer[..n]);

            match rasn::ber::decode::<LdapMessage>(&buf) {
                Ok(res) => {
                    if res.message_id != message_id {
                        return Err(MessageError::UnsolicitedResponse);
                    }
                    return Ok(res.protocol_op);
                }
                // The response spans more than what has been read so far;
                // keep reading and retry the decode on the grown buffer.
                Err(e) if matches!(e.kind.as_ref(), DecodeErrorKind::Incomplete { .. }) => {}
                Err(e) => return Err(MessageError::Message(e)),
            }
        }
    }
}
/// Failure modes of a single request/response exchange on an `LdapConnection`.
#[derive(Debug)]
enum MessageError {
    /// Reading from or writing to the underlying stream failed, or the peer
    /// closed the connection.
    Io(std::io::Error),
    /// The bytes received could not be decoded as an LDAP message.
    Message(rasn::ber::de::DecodeError),
    /// A response arrived whose message ID does not match the request's.
    UnsolicitedResponse,
}
impl std::error::Error for MessageError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(io) => Some(io),
Self::Message(m) => Some(m),
Self::UnsolicitedResponse => None,
}
}
}
impl Display for MessageError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Io(io) => write!(f, "io: {io}"),
Self::Message(m) => write!(f, "message: {m}"),
Self::UnsolicitedResponse => write!(f, "Message IDs don't align"),
}
}
}