ldap3/
protocol.rs

1use std::io;
2#[cfg(feature = "gssapi")]
3use std::sync::RwLock;
4#[cfg(feature = "gssapi")]
5use std::sync::{Arc, Mutex};
6
7use crate::RequestId;
8use crate::controls::{Control, RawControl};
9use crate::controls_impl::{build_tag, parse_controls};
10use crate::search::SearchItem;
11
12use lber::common::TagClass;
13use lber::parse::parse_uint;
14use lber::structure::{PL, StructureTag};
15use lber::structures::{ASNTag, Integer, Sequence, Tag};
16use lber::universal::Types;
17use lber::write;
18
19use bytes::{Buf, BytesMut};
20#[cfg(feature = "gssapi")]
21use cross_krb5::{ClientCtx, K5Ctx};
22use tokio::sync::{mpsc, oneshot};
23use tokio_util::codec::{Decoder, Encoder};
24
/// `tokio_util` codec converting between framed bytes and LDAP messages.
///
/// In builds with the `gssapi` feature the codec also holds the state used to apply and
/// remove a SASL GSSAPI security layer; without that feature the struct carries no data.
pub(crate) struct LdapCodec {
    /// Shared security-layer parameters as `(sasl_wrap, sasl_max_send)`: whether wrapping
    /// is active, and the largest encoded message that may be sent (0 disables the check).
    #[cfg(feature = "gssapi")]
    pub(crate) sasl_param: Arc<RwLock<(bool, u32)>>,
    /// Shared GSSAPI client context used to wrap outgoing and unwrap incoming data;
    /// expected to be `Some` whenever wrapping is active.
    #[cfg(feature = "gssapi")]
    pub(crate) client_ctx: Arc<Mutex<Option<ClientCtx>>>,
    /// True while the decode buffer holds plaintext left over from an already unwrapped
    /// SASL frame that still has to be decoded.
    #[cfg(feature = "gssapi")]
    pub(crate) has_decoded_data: bool,
}
33
34pub(crate) type MaybeControls = Option<Vec<RawControl>>;
35pub(crate) type ItemSender = mpsc::UnboundedSender<(SearchItem, Vec<Control>)>;
36pub(crate) type ResultSender = oneshot::Sender<(Tag, Vec<Control>)>;
37
38#[derive(Debug)]
39pub enum MiscSender {
40    #[cfg(any(feature = "tls-native", feature = "tls-rustls"))]
41    Cert(oneshot::Sender<Option<Vec<u8>>>),
42}
43
44#[derive(Debug)]
45pub enum LdapOp {
46    Single,
47    Search(ItemSender),
48    Abandon(RequestId),
49    Unbind,
50}
51
52#[allow(clippy::type_complexity)]
53fn decode_inner(buf: &mut BytesMut) -> Result<Option<(RequestId, (Tag, Vec<Control>))>, io::Error> {
54    let decoding_error = io::Error::new(io::ErrorKind::Other, "decoding error");
55    let mut parser = lber::Parser::new();
56    let binding = parser.parse(buf);
57    let (i, tag) = match binding {
58        Err(e) if e.is_incomplete() => return Ok(None),
59        Err(_e) => return Err(decoding_error),
60        Ok((i, ref tag)) => (i, tag),
61    };
62    buf.advance(buf.len() - i.len());
63    let tag = tag.clone();
64    let mut tags = match tag
65        .match_id(Types::Sequence as u64)
66        .and_then(|t| t.expect_constructed())
67    {
68        Some(tags) => tags,
69        None => return Err(decoding_error),
70    };
71    let mut maybe_controls = tags.pop().expect("element");
72    let has_controls = match maybe_controls {
73        StructureTag {
74            id,
75            class,
76            ref payload,
77        } if class == TagClass::Context && id == 0 => match *payload {
78            PL::C(_) => true,
79            PL::P(_) => return Err(decoding_error),
80        },
81        StructureTag { id, class, .. } if class == TagClass::Context && id == 10 => {
82            // Active Directory bug workaround
83            //
84            // AD incorrectly encodes Notice of Disconnection messages. The OID of the
85            // Unsolicited Notification should be part of the ExtendedResponse sequence
86            // but AD puts it outside, where the optional controls belong. This confuses
87            // our parser, which doesn't expect the extra sequence element at the end
88            // and crashes. This match arm thus ignores the element.
89            maybe_controls = tags.pop().expect("element");
90            false
91        }
92        _ => false,
93    };
94    let (protoop, controls) = if has_controls {
95        (tags.pop().expect("element"), Some(maybe_controls))
96    } else {
97        (maybe_controls, None)
98    };
99    let controls = match controls {
100        Some(controls) => parse_controls(controls),
101        None => vec![],
102    };
103    let msgid = match parse_uint(
104        tags.pop()
105            .expect("element")
106            .match_class(TagClass::Universal)
107            .and_then(|t| t.match_id(Types::Integer as u64))
108            .and_then(|t| t.expect_primitive())
109            .expect("message id")
110            .as_slice(),
111    ) {
112        Ok((_, id)) => id as i32,
113        _ => return Err(decoding_error),
114    };
115    Ok(Some((msgid, (Tag::StructureTag(protoop), controls))))
116}
117
118impl Decoder for LdapCodec {
119    type Item = (RequestId, (Tag, Vec<Control>));
120    type Error = io::Error;
121
122    #[cfg(not(feature = "gssapi"))]
123    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
124        decode_inner(buf)
125    }
126
127    #[cfg(feature = "gssapi")]
128    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
129        const U32_SIZE: usize = std::mem::size_of::<u32>();
130
131        let sasl_wrap = { self.sasl_param.read().expect("sasl param").0 };
132        if !sasl_wrap || buf.is_empty() {
133            return decode_inner(buf);
134        }
135        if self.has_decoded_data {
136            let res = decode_inner(buf);
137            if res.is_ok() && buf.is_empty() {
138                self.has_decoded_data = false;
139            }
140            return res;
141        }
142        if buf.len() < U32_SIZE {
143            return Err(io::Error::new(io::ErrorKind::Other, "invalid SASL buffer"));
144        }
145        let sasl_len = u32::from_be_bytes(buf[0..U32_SIZE].try_into().unwrap());
146        if buf.len() - U32_SIZE < sasl_len as usize {
147            return Ok(None);
148        }
149        buf.advance(U32_SIZE);
150        let client_opt = &mut *self.client_ctx.lock().expect("client ctx lock");
151        let client_ctx = client_opt.as_mut().expect("client Option mut ref");
152        let mut decoded = client_ctx.unwrap_iov(sasl_len as usize, buf).map_err(|e| {
153            io::Error::new(io::ErrorKind::Other, format!("gss_unwrap error: {:#}", e))
154        })?;
155        let res = decode_inner(&mut decoded);
156        if res.is_ok() && !decoded.is_empty() && buf.is_empty() {
157            buf.extend(decoded);
158            self.has_decoded_data = true;
159        }
160        res
161    }
162}
163
164#[cfg(not(feature = "gssapi"))]
165#[inline]
166fn maybe_wrap(
167    _codec: &mut LdapCodec,
168    outstruct: StructureTag,
169    into: &mut BytesMut,
170) -> io::Result<()> {
171    write::encode_into(into, outstruct)?;
172    Ok(())
173}
174
/// Appends the BER encoding of `outstruct` to `into`, applying the SASL GSSAPI security
/// layer when one is active.
///
/// When wrapping is active, the encoded message is passed through the client context's
/// `wrap` and emitted as a 4-byte big-endian length followed by the wrapped bytes. The
/// `true` argument presumably requests confidentiality; see cross_krb5 `K5Ctx::wrap`.
/// Otherwise the plain BER encoding is appended directly.
///
/// # Errors
///
/// Fails if:
/// - BER encoding fails;
/// - the encoded message exceeds a non-zero `sasl_send_max`;
/// - `gss_wrap` fails;
/// - the wrapped buffer does not fit the 32-bit length prefix.
///
/// # Panics
///
/// Panics if the `sasl_param` or `client_ctx` locks are poisoned, or if wrapping is active
/// while no client context has been stored.
#[cfg(feature = "gssapi")]
fn maybe_wrap(
    codec: &mut LdapCodec,
    outstruct: StructureTag,
    into: &mut BytesMut,
) -> io::Result<()> {
    let (sasl_wrap, sasl_send_max) = *codec.sasl_param.read().expect("sasl param");
    if !sasl_wrap {
        // No security layer: encode straight into the output buffer, as the
        // non-GSSAPI build does, instead of copying through a scratch buffer.
        write::encode_into(into, outstruct)?;
        return Ok(());
    }
    let mut out_buf = BytesMut::new();
    write::encode_into(&mut out_buf, outstruct)?;
    // Check the negotiated limit before touching the GSSAPI context.
    if sasl_send_max > 0 && out_buf.len() > sasl_send_max as usize {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            format!(
                "buffer too large for GSSAPI: {} > {}",
                out_buf.len(),
                sasl_send_max
            ),
        ));
    }
    let client_opt = &mut *codec.client_ctx.lock().expect("client_ctx lock");
    let client_ctx = client_opt.as_mut().expect("client Option mut ref");
    let sasl_buf = client_ctx.wrap(true, &out_buf).map_err(|e| {
        io::Error::new(io::ErrorKind::Other, format!("gss_wrap error: {:#}", e))
    })?;
    // Refuse to emit a silently truncated length prefix.
    let sasl_len = u32::try_from(sasl_buf.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::Other,
            "wrapped buffer too large for SASL length prefix",
        )
    })?;
    let prefix = sasl_len.to_be_bytes();
    into.reserve(prefix.len() + sasl_buf.len());
    into.extend_from_slice(&prefix);
    into.extend_from_slice(&*sasl_buf);
    Ok(())
}
211
212impl Encoder<(RequestId, Tag, MaybeControls)> for LdapCodec {
213    type Error = io::Error;
214
215    fn encode(
216        &mut self,
217        msg: (RequestId, Tag, MaybeControls),
218        into: &mut BytesMut,
219    ) -> io::Result<()> {
220        let (id, tag, controls) = msg;
221        let outstruct = {
222            let mut msg = vec![
223                Tag::Integer(Integer {
224                    inner: id as i64,
225                    ..Default::default()
226                }),
227                tag,
228            ];
229            if let Some(controls) = controls {
230                msg.push(Tag::StructureTag(StructureTag {
231                    id: 0,
232                    class: TagClass::Context,
233                    payload: PL::C(controls.into_iter().map(build_tag).collect()),
234                }));
235            }
236            Tag::Sequence(Sequence {
237                inner: msg,
238                ..Default::default()
239            })
240            .into_structure()
241        };
242        maybe_wrap(self, outstruct, into)?;
243        Ok(())
244    }
245}