Skip to main content

ldap_client_proto/
filter.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3use ldap_client_ber::tag::Tag;
4use ldap_client_ber::{BerReader, BerWriter};
5
6use crate::ProtoError;
7
/// LDAP search filter (RFC 4515).
///
/// Each variant has both a textual form (see `to_filter_string` / `parse`)
/// and a BER form (see `encode` / `decode`) keyed by a context-class tag
/// number.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Filter {
    /// Conjunction of sub-filters; text form `(&...)`, BER tag number 0.
    And(Vec<Filter>),
    /// Disjunction of sub-filters; text form `(|...)`, BER tag number 1.
    Or(Vec<Filter>),
    /// Negation of a single sub-filter; text form `(!...)`, BER tag number 2.
    Not(Box<Filter>),
    /// Equality match `(attr=value)`; BER tag number 3.
    Eq(String, String),
    /// Approximate match `(attr~=value)`; BER tag number 8.
    Approx(String, String),
    /// Greater-or-equal match `(attr>=value)`; BER tag number 5.
    Gte(String, String),
    /// Less-or-equal match `(attr<=value)`; BER tag number 6.
    Lte(String, String),
    /// Presence test `(attr=*)`; BER tag number 7.
    Present(String),
    /// Substring match `(attr=initial*any*...*final)`; BER tag number 4.
    Substring {
        /// Attribute description being matched.
        attr: String,
        /// Text before the first `*`, if any.
        initial: Option<String>,
        /// Pieces between `*` wildcards, in order.
        any: Vec<String>,
        /// Text after the last `*`, if any.
        r#final: Option<String>,
    },
    /// Extensible match `([attr][:dn][:rule]:=value)`; BER tag number 9.
    ExtensibleMatch {
        /// Matching rule written after the attribute / `:dn`, if any.
        matching_rule: Option<String>,
        /// Attribute description, if any.
        attr: Option<String>,
        /// Assertion value (stored unescaped).
        value: String,
        /// Whether the `:dn` flag is set.
        dn_attributes: bool,
    },
}
32
33impl Filter {
34    pub fn eq(attr: impl Into<String>, value: impl Into<String>) -> Self {
35        Self::Eq(attr.into(), value.into())
36    }
37
38    pub fn present(attr: impl Into<String>) -> Self {
39        Self::Present(attr.into())
40    }
41
42    pub fn and(filters: Vec<Filter>) -> Self {
43        Self::And(filters)
44    }
45
46    pub fn or(filters: Vec<Filter>) -> Self {
47        Self::Or(filters)
48    }
49
50    #[allow(clippy::should_implement_trait)]
51    pub fn not(filter: Filter) -> Self {
52        Self::Not(Box::new(filter))
53    }
54
55    pub fn approx(attr: impl Into<String>, value: impl Into<String>) -> Self {
56        Self::Approx(attr.into(), value.into())
57    }
58
59    pub fn gte(attr: impl Into<String>, value: impl Into<String>) -> Self {
60        Self::Gte(attr.into(), value.into())
61    }
62
63    pub fn lte(attr: impl Into<String>, value: impl Into<String>) -> Self {
64        Self::Lte(attr.into(), value.into())
65    }
66
67    pub fn substring(
68        attr: impl Into<String>,
69        initial: Option<String>,
70        any: Vec<String>,
71        r#final: Option<String>,
72    ) -> Self {
73        Self::Substring {
74            attr: attr.into(),
75            initial,
76            any,
77            r#final,
78        }
79    }
80
81    pub fn extensible_match(
82        rule: Option<impl Into<String>>,
83        attr: Option<impl Into<String>>,
84        value: impl Into<String>,
85        dn_attributes: bool,
86    ) -> Self {
87        Self::ExtensibleMatch {
88            matching_rule: rule.map(Into::into),
89            attr: attr.map(Into::into),
90            value: value.into(),
91            dn_attributes,
92        }
93    }
94
95    /// Escape a value for RFC 4515 filter strings.
96    pub fn escape_value(input: &str) -> String {
97        use std::fmt::Write;
98        let mut out = String::with_capacity(input.len());
99        for ch in input.chars() {
100            match ch {
101                '*' | '(' | ')' | '\\' | '\0' => {
102                    let _ = write!(out, "\\{:02x}", ch as u32);
103                }
104                _ => out.push(ch),
105            }
106        }
107        out
108    }
109
110    /// Serialize to RFC 4515 string.
111    pub fn to_filter_string(&self) -> String {
112        match self {
113            Self::And(filters) => {
114                let inner: String = filters.iter().map(|f| f.to_filter_string()).collect();
115                format!("(&{inner})")
116            }
117            Self::Or(filters) => {
118                let inner: String = filters.iter().map(|f| f.to_filter_string()).collect();
119                format!("(|{inner})")
120            }
121            Self::Not(f) => format!("(!{})", f.to_filter_string()),
122            Self::Eq(a, v) => format!("({}={})", a, Self::escape_value(v)),
123            Self::Approx(a, v) => format!("({}~={})", a, Self::escape_value(v)),
124            Self::Gte(a, v) => format!("({}>={})", a, Self::escape_value(v)),
125            Self::Lte(a, v) => format!("({}<={})", a, Self::escape_value(v)),
126            Self::Present(a) => format!("({a}=*)"),
127            Self::Substring {
128                attr,
129                initial,
130                any,
131                r#final,
132            } => {
133                let mut val = String::new();
134                if let Some(init) = initial {
135                    val.push_str(&Self::escape_value(init));
136                }
137                val.push('*');
138                for a in any {
139                    val.push_str(&Self::escape_value(a));
140                    val.push('*');
141                }
142                if let Some(fin) = r#final {
143                    val.push_str(&Self::escape_value(fin));
144                }
145                format!("({attr}={val})")
146            }
147            Self::ExtensibleMatch {
148                matching_rule,
149                attr,
150                value,
151                dn_attributes,
152            } => {
153                let mut s = String::from("(");
154                if let Some(a) = attr {
155                    s.push_str(a);
156                }
157                if *dn_attributes {
158                    s.push_str(":dn");
159                }
160                if let Some(r) = matching_rule {
161                    s.push(':');
162                    s.push_str(r);
163                }
164                s.push_str(":=");
165                s.push_str(&Self::escape_value(value));
166                s.push(')');
167                s
168            }
169        }
170    }
171
172    /// Parse an RFC 4515 filter string.
173    pub fn parse(input: &str) -> Result<Self, ProtoError> {
174        let input = input.trim();
175        if input.is_empty() {
176            return Err(ProtoError::FilterParse("empty filter".into()));
177        }
178        let (filter, rest) = parse_filter(input, 0)?;
179        if !rest.is_empty() {
180            return Err(ProtoError::FilterParse(format!("trailing data: {rest:?}")));
181        }
182        Ok(filter)
183    }
184
185    /// Encode to BER bytes.
186    pub fn encode(&self, w: &mut BerWriter) {
187        match self {
188            Self::And(filters) => {
189                w.write_sequence(Tag::context_constructed(0), |inner| {
190                    for f in filters {
191                        f.encode(inner);
192                    }
193                });
194            }
195            Self::Or(filters) => {
196                w.write_sequence(Tag::context_constructed(1), |inner| {
197                    for f in filters {
198                        f.encode(inner);
199                    }
200                });
201            }
202            Self::Not(f) => {
203                w.write_sequence(Tag::context_constructed(2), |inner| {
204                    f.encode(inner);
205                });
206            }
207            Self::Eq(attr, value) => {
208                encode_ava(w, 3, attr, value);
209            }
210            Self::Approx(attr, value) => {
211                encode_ava(w, 8, attr, value);
212            }
213            Self::Gte(attr, value) => {
214                encode_ava(w, 5, attr, value);
215            }
216            Self::Lte(attr, value) => {
217                encode_ava(w, 6, attr, value);
218            }
219            Self::Present(attr) => {
220                w.write_octet_string(Tag::context(7), attr.as_bytes());
221            }
222            Self::Substring {
223                attr,
224                initial,
225                any,
226                r#final,
227            } => {
228                w.write_sequence(Tag::context_constructed(4), |inner| {
229                    inner.write_bytes(attr.as_bytes());
230                    inner.write_sequence(Tag::sequence(), |subseq| {
231                        if let Some(init) = initial {
232                            subseq.write_octet_string(Tag::context(0), init.as_bytes());
233                        }
234                        for a in any {
235                            subseq.write_octet_string(Tag::context(1), a.as_bytes());
236                        }
237                        if let Some(fin) = r#final {
238                            subseq.write_octet_string(Tag::context(2), fin.as_bytes());
239                        }
240                    });
241                });
242            }
243            Self::ExtensibleMatch {
244                matching_rule,
245                attr,
246                value,
247                dn_attributes,
248            } => {
249                w.write_sequence(Tag::context_constructed(9), |inner| {
250                    if let Some(rule) = matching_rule {
251                        inner.write_octet_string(Tag::context(1), rule.as_bytes());
252                    }
253                    if let Some(a) = attr {
254                        inner.write_octet_string(Tag::context(2), a.as_bytes());
255                    }
256                    inner.write_octet_string(Tag::context(3), value.as_bytes());
257                    if *dn_attributes {
258                        inner.write_octet_string(Tag::context(4), &[0xFF]);
259                    }
260                });
261            }
262        }
263    }
264
    /// Decode a filter from BER.
    ///
    /// Peeks the next tag, requires it to be context-class, and dispatches on
    /// the tag number using the same numbering `encode` writes
    /// (And=0, Or=1, Not=2, Eq=3, Substring=4, Gte=5, Lte=6, Present=7,
    /// Approx=8, ExtensibleMatch=9).
    ///
    /// All string fields are converted with `String::from_utf8_lossy`, so
    /// bytes that are not valid UTF-8 are replaced rather than rejected.
    ///
    /// # Errors
    /// Returns `BerError::UnexpectedTag` when the tag is not context-class or
    /// its number is outside 0..=9; errors from the reader are propagated.
    ///
    /// NOTE(review): nested And/Or/Not filters recurse through this function
    /// with no depth counter, whereas the string parser enforces
    /// `MAX_FILTER_DEPTH`. Deeply nested input from an untrusted peer may
    /// exhaust the stack — confirm whether the reader bounds nesting.
    pub fn decode(r: &mut BerReader<'_>) -> Result<Self, ldap_client_ber::BerError> {
        let tag = r.peek_tag()?;
        if tag.class != ldap_client_ber::Class::Context {
            return Err(ldap_client_ber::BerError::UnexpectedTag {
                expected: Tag::context(0),
                actual: tag,
            });
        }

        match tag.number {
            // And: zero or more nested filters until the sequence is exhausted.
            0 => {
                let mut filters = Vec::new();
                r.read_sequence_lax(Tag::context_constructed(0), |inner| {
                    while !inner.is_empty() {
                        filters.push(Filter::decode(inner)?);
                    }
                    Ok(())
                })?;
                Ok(Self::And(filters))
            }
            // Or: same shape as And under tag number 1.
            1 => {
                let mut filters = Vec::new();
                r.read_sequence_lax(Tag::context_constructed(1), |inner| {
                    while !inner.is_empty() {
                        filters.push(Filter::decode(inner)?);
                    }
                    Ok(())
                })?;
                Ok(Self::Or(filters))
            }
            // Not: exactly one nested filter.
            2 => {
                let f = r.read_sequence(Tag::context_constructed(2), Filter::decode)?;
                Ok(Self::Not(Box::new(f)))
            }
            // Attribute/value comparisons share one decoder.
            3 => decode_ava_ber(r, 3).map(|(a, v)| Self::Eq(a, v)),
            5 => decode_ava_ber(r, 5).map(|(a, v)| Self::Gte(a, v)),
            6 => decode_ava_ber(r, 6).map(|(a, v)| Self::Lte(a, v)),
            // Present: the attribute description is the element's content.
            7 => {
                let value = r.read_tagged_implicit_octet_string(7)?;
                Ok(Self::Present(String::from_utf8_lossy(value).into_owned()))
            }
            8 => decode_ava_ber(r, 8).map(|(a, v)| Self::Approx(a, v)),
            // Substring: attribute, then a sequence of tagged pieces.
            4 => r.read_sequence(Tag::context_constructed(4), |inner| {
                let attr = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
                let mut initial = None;
                let mut any = Vec::new();
                let mut r#final = None;

                inner.read_sequence(Tag::sequence(), |subseq| {
                    while !subseq.is_empty() {
                        let (tag, value) = subseq.read_element()?;
                        let s = String::from_utf8_lossy(value).into_owned();
                        // Only the tag number is inspected; a repeated initial/final
                        // overwrites the earlier one and unknown numbers are ignored.
                        match tag.number {
                            0 => initial = Some(s),
                            1 => any.push(s),
                            2 => r#final = Some(s),
                            _ => {}
                        }
                    }
                    Ok(())
                })?;

                Ok(Self::Substring {
                    attr,
                    initial,
                    any,
                    r#final,
                })
            }),
            // ExtensibleMatch: optional fields in any order, keyed by context tag.
            9 => r.read_sequence(Tag::context_constructed(9), |inner| {
                let mut matching_rule = None;
                let mut attr = None;
                // NOTE(review): if no [3] element is present the value stays
                // empty instead of producing an error — confirm this is intended.
                let mut value = String::new();
                let mut dn_attributes = false;

                while !inner.is_empty() {
                    let tag = inner.peek_tag()?;
                    match (tag.class, tag.number) {
                        (ldap_client_ber::Class::Context, 1) => {
                            let v = inner.read_tagged_implicit_octet_string(1)?;
                            matching_rule = Some(String::from_utf8_lossy(v).into_owned());
                        }
                        (ldap_client_ber::Class::Context, 2) => {
                            let v = inner.read_tagged_implicit_octet_string(2)?;
                            attr = Some(String::from_utf8_lossy(v).into_owned());
                        }
                        (ldap_client_ber::Class::Context, 3) => {
                            let v = inner.read_tagged_implicit_octet_string(3)?;
                            value = String::from_utf8_lossy(v).into_owned();
                        }
                        (ldap_client_ber::Class::Context, 4) => {
                            let v = inner.read_tagged_implicit_octet_string(4)?;
                            // Any non-zero first byte counts as true; empty content is false.
                            dn_attributes = v.first().is_some_and(|&b| b != 0);
                        }
                        // Unrecognised elements are consumed and discarded.
                        _ => {
                            inner.read_element()?;
                        }
                    }
                }

                Ok(Self::ExtensibleMatch {
                    matching_rule,
                    attr,
                    value,
                    dn_attributes,
                })
            }),
            _ => Err(ldap_client_ber::BerError::UnexpectedTag {
                expected: Tag::context(0),
                actual: tag,
            }),
        }
    }
379}
380
381impl std::fmt::Display for Filter {
382    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383        f.write_str(&self.to_filter_string())
384    }
385}
386
387impl std::str::FromStr for Filter {
388    type Err = ProtoError;
389    fn from_str(s: &str) -> Result<Self, Self::Err> {
390        Self::parse(s)
391    }
392}
393
394fn encode_ava(w: &mut BerWriter, tag_num: u32, attr: &str, value: &str) {
395    w.write_sequence(Tag::context_constructed(tag_num), |inner| {
396        inner.write_bytes(attr.as_bytes());
397        inner.write_bytes(value.as_bytes());
398    });
399}
400
401fn decode_ava_ber(
402    r: &mut BerReader<'_>,
403    tag_num: u32,
404) -> Result<(String, String), ldap_client_ber::BerError> {
405    r.read_sequence(Tag::context_constructed(tag_num), |inner| {
406        let attr = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
407        let value = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
408        Ok((attr, value))
409    })
410}
411
412// ---------- RFC 4515 filter string parser ----------
413
414const MAX_FILTER_DEPTH: usize = 128;
415
416fn parse_filter(input: &str, depth: usize) -> Result<(Filter, &str), ProtoError> {
417    if depth >= MAX_FILTER_DEPTH {
418        return Err(ProtoError::FilterParse("filter nesting too deep".into()));
419    }
420
421    let input = input
422        .strip_prefix('(')
423        .ok_or_else(|| ProtoError::FilterParse("expected '('".into()))?;
424
425    let (filter, rest) = parse_filter_comp(input, depth)?;
426
427    let rest = rest
428        .strip_prefix(')')
429        .ok_or_else(|| ProtoError::FilterParse("expected ')'".into()))?;
430
431    Ok((filter, rest))
432}
433
434fn parse_filter_comp(input: &str, depth: usize) -> Result<(Filter, &str), ProtoError> {
435    match input.chars().next() {
436        Some('&') => parse_filter_list(&input[1..], Filter::And, depth),
437        Some('|') => parse_filter_list(&input[1..], Filter::Or, depth),
438        Some('!') => {
439            let (f, rest) = parse_filter(&input[1..], depth + 1)?;
440            Ok((Filter::Not(Box::new(f)), rest))
441        }
442        _ => parse_item(input),
443    }
444}
445
446fn parse_filter_list(
447    mut input: &str,
448    ctor: fn(Vec<Filter>) -> Filter,
449    depth: usize,
450) -> Result<(Filter, &str), ProtoError> {
451    let mut filters = Vec::new();
452    while input.starts_with('(') {
453        let (f, rest) = parse_filter(input, depth + 1)?;
454        filters.push(f);
455        input = rest;
456    }
457    if filters.is_empty() {
458        return Err(ProtoError::FilterParse("empty filter list".into()));
459    }
460    Ok((ctor(filters), input))
461}
462
463fn parse_item(input: &str) -> Result<(Filter, &str), ProtoError> {
464    // Find the operator position.
465    let mut i = 0;
466    let bytes = input.as_bytes();
467    while i < bytes.len() && !matches!(bytes[i], b'=' | b'>' | b'<' | b'~' | b')') {
468        i += 1;
469    }
470
471    if i >= bytes.len() || bytes[i] == b')' {
472        return Err(ProtoError::FilterParse("missing operator".into()));
473    }
474
475    let attr = &input[..i];
476
477    // Check for extensible match: attr:dn:rule:= or just :rule:= patterns
478    if attr.contains(':') {
479        return parse_extensible_match(input);
480    }
481
482    let (op_len, filter_type) = match (bytes.get(i), bytes.get(i + 1)) {
483        (Some(b'>'), Some(b'=')) => (2, ">="),
484        (Some(b'<'), Some(b'=')) => (2, "<="),
485        (Some(b'~'), Some(b'=')) => (2, "~="),
486        (Some(b'='), _) => (1, "="),
487        _ => return Err(ProtoError::FilterParse("unknown operator".into())),
488    };
489
490    let value_start = i + op_len;
491    let value_end = find_value_end(&input[value_start..]);
492    let raw_value = &input[value_start..value_start + value_end];
493    let rest = &input[value_start + value_end..];
494
495    match filter_type {
496        "=" => {
497            if raw_value == "*" {
498                Ok((Filter::Present(attr.to_string()), rest))
499            } else if raw_value.contains('*') {
500                Ok((parse_substring(attr, raw_value)?, rest))
501            } else {
502                Ok((
503                    Filter::Eq(attr.to_string(), unescape_value(raw_value)?),
504                    rest,
505                ))
506            }
507        }
508        ">=" => Ok((
509            Filter::Gte(attr.to_string(), unescape_value(raw_value)?),
510            rest,
511        )),
512        "<=" => Ok((
513            Filter::Lte(attr.to_string(), unescape_value(raw_value)?),
514            rest,
515        )),
516        "~=" => Ok((
517            Filter::Approx(attr.to_string(), unescape_value(raw_value)?),
518            rest,
519        )),
520        _ => unreachable!(),
521    }
522}
523
524fn parse_extensible_match(input: &str) -> Result<(Filter, &str), ProtoError> {
525    // Format: [attr][:dn][:rule]:=value
526    let eq_pos = input
527        .find(":=")
528        .ok_or_else(|| ProtoError::FilterParse("extensible match missing ':='".into()))?;
529
530    let prefix = &input[..eq_pos];
531    let value_start = eq_pos + 2;
532    let value_end = find_value_end(&input[value_start..]);
533    let raw_value = &input[value_start..value_start + value_end];
534    let rest = &input[value_start + value_end..];
535
536    let mut attr = None;
537    let mut matching_rule = None;
538    let mut dn_attributes = false;
539
540    let parts: Vec<&str> = prefix.split(':').collect();
541    match parts.len() {
542        1 => {
543            if !parts[0].is_empty() {
544                attr = Some(parts[0].to_string());
545            }
546        }
547        2 => {
548            if !parts[0].is_empty() {
549                attr = Some(parts[0].to_string());
550            }
551            if parts[1] == "dn" {
552                dn_attributes = true;
553            } else if !parts[1].is_empty() {
554                matching_rule = Some(parts[1].to_string());
555            }
556        }
557        3 => {
558            if !parts[0].is_empty() {
559                attr = Some(parts[0].to_string());
560            }
561            if parts[1] == "dn" {
562                dn_attributes = true;
563            }
564            if !parts[2].is_empty() {
565                matching_rule = Some(parts[2].to_string());
566            }
567        }
568        _ => {
569            return Err(ProtoError::FilterParse(
570                "too many colon-separated parts in extensible match".into(),
571            ));
572        }
573    }
574
575    // RFC 4515 ยง3: at least one of attr, matching_rule, or dn_attributes must be present.
576    if attr.is_none() && matching_rule.is_none() && !dn_attributes {
577        return Err(ProtoError::FilterParse(
578            "extensible match requires at least one of attr, matching rule, or :dn:".into(),
579        ));
580    }
581
582    Ok((
583        Filter::ExtensibleMatch {
584            matching_rule,
585            attr,
586            value: unescape_value(raw_value)?,
587            dn_attributes,
588        },
589        rest,
590    ))
591}
592
593const MAX_SUBSTRING_PARTS: usize = 64;
594
595fn parse_substring(attr: &str, raw_value: &str) -> Result<Filter, ProtoError> {
596    let parts: Vec<&str> = raw_value.split('*').collect();
597    if parts.len() > MAX_SUBSTRING_PARTS {
598        return Err(ProtoError::FilterParse(
599            "substring filter has too many wildcard parts".into(),
600        ));
601    }
602    let initial = if !parts[0].is_empty() {
603        Some(unescape_value(parts[0])?)
604    } else {
605        None
606    };
607    let r#final = match parts.last().filter(|s| !s.is_empty()) {
608        Some(s) => Some(unescape_value(s)?),
609        None => None,
610    };
611    let any: Vec<String> = parts[1..parts.len() - 1]
612        .iter()
613        .filter(|s| !s.is_empty())
614        .map(|s| unescape_value(s))
615        .collect::<Result<_, _>>()?;
616
617    if initial.is_none() && any.is_empty() && r#final.is_none() {
618        return Err(ProtoError::FilterParse(
619            "substring filter has no assertions".into(),
620        ));
621    }
622
623    Ok(Filter::Substring {
624        attr: attr.to_string(),
625        initial,
626        any,
627        r#final,
628    })
629}
630
/// Return the byte length of the assertion value at the front of `input`:
/// the offset of the first `)`, or the whole length when there is none.
///
/// A `\XX` hex escape consists only of a backslash and hex digits, so it can
/// never contain a `)`; a plain search therefore finds the same position as
/// stepping over escapes one by one.
fn find_value_end(input: &str) -> usize {
    match input.find(')') {
        Some(close) => close,
        None => input.len(),
    }
}
647
648fn unescape_value(input: &str) -> Result<String, ProtoError> {
649    let mut out = Vec::with_capacity(input.len());
650    let bytes = input.as_bytes();
651    let mut i = 0;
652    while i < bytes.len() {
653        if bytes[i] == b'\\'
654            && i + 2 < bytes.len()
655            && let Ok(byte) =
656                u8::from_str_radix(std::str::from_utf8(&bytes[i + 1..i + 3]).unwrap_or(""), 16)
657        {
658            out.push(byte);
659            i += 3;
660            continue;
661        }
662        out.push(bytes[i]);
663        i += 1;
664    }
665    String::from_utf8(out)
666        .map_err(|e| ProtoError::FilterParse(format!("invalid UTF-8 in filter value: {e}")))
667}