1use ldap_client_ber::tag::Tag;
4use ldap_client_ber::{BerReader, BerWriter};
5
6use crate::ProtoError;
7
/// An LDAP search filter.
///
/// Assertion values are held in unescaped form: [`Filter::parse`] unescapes `\xx`
/// sequences when reading the string form, and [`Filter::to_filter_string`] re-escapes
/// them when rendering. [`Filter::encode`] writes the raw bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Filter {
    /// `(&…)` — conjunction of the sub-filters (BER tag `[0]`).
    And(Vec<Filter>),
    /// `(|…)` — disjunction of the sub-filters (BER tag `[1]`).
    Or(Vec<Filter>),
    /// `(!…)` — negation of the inner filter (BER tag `[2]`).
    Not(Box<Filter>),
    /// `(attr=value)` — equality match; fields are (attribute, value) (BER tag `[3]`).
    Eq(String, String),
    /// `(attr~=value)` — approximate match (BER tag `[8]`).
    Approx(String, String),
    /// `(attr>=value)` — greater-or-equal match (BER tag `[5]`).
    Gte(String, String),
    /// `(attr<=value)` — less-or-equal match (BER tag `[6]`).
    Lte(String, String),
    /// `(attr=*)` — presence test on the attribute (BER tag `[7]`).
    Present(String),
    /// `(attr=initial*any*…*final)` — substring match (BER tag `[4]`).
    Substring {
        /// Attribute description.
        attr: String,
        /// Segment before the first `*`, if any.
        initial: Option<String>,
        /// Segments between wildcards, in order.
        any: Vec<String>,
        /// Segment after the last `*`, if any.
        r#final: Option<String>,
    },
    /// `(attr:dn:rule:=value)` — extensible match (BER tag `[9]`).
    ExtensibleMatch {
        /// Matching rule name or OID (the `:rule` part), if given.
        matching_rule: Option<String>,
        /// Attribute description, if given.
        attr: Option<String>,
        /// Assertion value.
        value: String,
        /// Whether `:dn` is present (the `dnAttributes` flag).
        dn_attributes: bool,
    },
}
32
/// Constructors, string-form rendering/parsing, and BER encoding/decoding for [`Filter`].
impl Filter {
    /// Builds an equality filter, `(attr=value)`.
    pub fn eq(attr: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Eq(attr.into(), value.into())
    }

    /// Builds a presence filter, `(attr=*)`.
    pub fn present(attr: impl Into<String>) -> Self {
        Self::Present(attr.into())
    }

    /// Builds a conjunction, `(&…)`, over `filters`.
    pub fn and(filters: Vec<Filter>) -> Self {
        Self::And(filters)
    }

    /// Builds a disjunction, `(|…)`, over `filters`.
    pub fn or(filters: Vec<Filter>) -> Self {
        Self::Or(filters)
    }

    /// Builds a negation, `(!filter)`.
    // clippy's should_implement_trait fires because the name matches
    // `std::ops::Not::not`; an inherent constructor is intended here.
    #[allow(clippy::should_implement_trait)]
    pub fn not(filter: Filter) -> Self {
        Self::Not(Box::new(filter))
    }

    /// Builds an approximate-match filter, `(attr~=value)`.
    pub fn approx(attr: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Approx(attr.into(), value.into())
    }

    /// Builds a greater-or-equal filter, `(attr>=value)`.
    pub fn gte(attr: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Gte(attr.into(), value.into())
    }

    /// Builds a less-or-equal filter, `(attr<=value)`.
    pub fn lte(attr: impl Into<String>, value: impl Into<String>) -> Self {
        Self::Lte(attr.into(), value.into())
    }

    /// Builds a substring filter, `(attr=initial*any1*…*final)`.
    ///
    /// Segment strings are stored verbatim; escaping is applied only when rendering
    /// with [`Filter::to_filter_string`].
    pub fn substring(
        attr: impl Into<String>,
        initial: Option<String>,
        any: Vec<String>,
        r#final: Option<String>,
    ) -> Self {
        Self::Substring {
            attr: attr.into(),
            initial,
            any,
            r#final,
        }
    }

    /// Builds an extensible-match filter, `(attr:dn:rule:=value)`.
    ///
    /// `rule` becomes `matching_rule`. Because both optional parameters are
    /// `Option<impl Into<String>>`, a bare `None` needs a type hint, e.g. `None::<&str>`.
    pub fn extensible_match(
        rule: Option<impl Into<String>>,
        attr: Option<impl Into<String>>,
        value: impl Into<String>,
        dn_attributes: bool,
    ) -> Self {
        Self::ExtensibleMatch {
            matching_rule: rule.map(Into::into),
            attr: attr.map(Into::into),
            value: value.into(),
            dn_attributes,
        }
    }

    /// Escapes `input` for use as an assertion value in the string form of a filter.
    ///
    /// `*`, `(`, `)`, `\` and NUL are each replaced by `\` followed by two lowercase hex
    /// digits (e.g. `*` becomes `\2a`); every other character is copied unchanged.
    pub fn escape_value(input: &str) -> String {
        use std::fmt::Write;
        let mut out = String::with_capacity(input.len());
        for ch in input.chars() {
            match ch {
                '*' | '(' | ')' | '\\' | '\0' => {
                    // Writing into a String cannot fail, so the fmt::Result is discarded.
                    let _ = write!(out, "\\{:02x}", ch as u32);
                }
                _ => out.push(ch),
            }
        }
        out
    }

    /// Renders the filter in its parenthesised string form, e.g. `(&(cn=a)(sn=b))`.
    ///
    /// Assertion values pass through [`Filter::escape_value`]; attribute descriptions and
    /// matching-rule names are inserted verbatim.
    pub fn to_filter_string(&self) -> String {
        match self {
            Self::And(filters) => {
                let inner: String = filters.iter().map(|f| f.to_filter_string()).collect();
                format!("(&{inner})")
            }
            Self::Or(filters) => {
                let inner: String = filters.iter().map(|f| f.to_filter_string()).collect();
                format!("(|{inner})")
            }
            Self::Not(f) => format!("(!{})", f.to_filter_string()),
            Self::Eq(a, v) => format!("({}={})", a, Self::escape_value(v)),
            Self::Approx(a, v) => format!("({}~={})", a, Self::escape_value(v)),
            Self::Gte(a, v) => format!("({}>={})", a, Self::escape_value(v)),
            Self::Lte(a, v) => format!("({}<={})", a, Self::escape_value(v)),
            Self::Present(a) => format!("({a}=*)"),
            Self::Substring {
                attr,
                initial,
                any,
                r#final,
            } => {
                // Layout: [initial] '*' { any '*' } [final]
                let mut val = String::new();
                if let Some(init) = initial {
                    val.push_str(&Self::escape_value(init));
                }
                val.push('*');
                for a in any {
                    val.push_str(&Self::escape_value(a));
                    val.push('*');
                }
                if let Some(fin) = r#final {
                    val.push_str(&Self::escape_value(fin));
                }
                format!("({attr}={val})")
            }
            Self::ExtensibleMatch {
                matching_rule,
                attr,
                value,
                dn_attributes,
            } => {
                // Layout: '(' [attr] [":dn"] [':' rule] ":=" value ')'
                let mut s = String::from("(");
                if let Some(a) = attr {
                    s.push_str(a);
                }
                if *dn_attributes {
                    s.push_str(":dn");
                }
                if let Some(r) = matching_rule {
                    s.push(':');
                    s.push_str(r);
                }
                s.push_str(":=");
                s.push_str(&Self::escape_value(value));
                s.push(')');
                s
            }
        }
    }

    /// Parses a filter from its string form, e.g. `(&(objectClass=person)(cn=J*))`.
    ///
    /// Surrounding whitespace is trimmed; the outermost parentheses are required.
    ///
    /// # Errors
    /// Returns `ProtoError::FilterParse` for empty input, for input left over after the
    /// first complete filter, or for any syntax error found while parsing (including
    /// nesting deeper than `MAX_FILTER_DEPTH`).
    pub fn parse(input: &str) -> Result<Self, ProtoError> {
        let input = input.trim();
        if input.is_empty() {
            return Err(ProtoError::FilterParse("empty filter".into()));
        }
        let (filter, rest) = parse_filter(input, 0)?;
        if !rest.is_empty() {
            return Err(ProtoError::FilterParse(format!("trailing data: {rest:?}")));
        }
        Ok(filter)
    }

    /// Appends the BER encoding of this filter to `w`.
    ///
    /// Each variant uses its context-specific tag from the LDAP `Filter` CHOICE:
    /// `[0]` and, `[1]` or, `[2]` not, `[3]` equality, `[4]` substrings, `[5]` >=,
    /// `[6]` <=, `[7]` present, `[8]` approx, `[9]` extensible match.
    pub fn encode(&self, w: &mut BerWriter) {
        match self {
            Self::And(filters) => {
                w.write_sequence(Tag::context_constructed(0), |inner| {
                    for f in filters {
                        f.encode(inner);
                    }
                });
            }
            Self::Or(filters) => {
                w.write_sequence(Tag::context_constructed(1), |inner| {
                    for f in filters {
                        f.encode(inner);
                    }
                });
            }
            Self::Not(f) => {
                w.write_sequence(Tag::context_constructed(2), |inner| {
                    f.encode(inner);
                });
            }
            Self::Eq(attr, value) => {
                encode_ava(w, 3, attr, value);
            }
            Self::Approx(attr, value) => {
                encode_ava(w, 8, attr, value);
            }
            Self::Gte(attr, value) => {
                encode_ava(w, 5, attr, value);
            }
            Self::Lte(attr, value) => {
                encode_ava(w, 6, attr, value);
            }
            Self::Present(attr) => {
                // Primitive [7] whose contents are the attribute description bytes.
                w.write_octet_string(Tag::context(7), attr.as_bytes());
            }
            Self::Substring {
                attr,
                initial,
                any,
                r#final,
            } => {
                w.write_sequence(Tag::context_constructed(4), |inner| {
                    // NOTE(review): written with `write_bytes`, but `decode` reads it back
                    // with `read_octet_string`; confirm `write_bytes` emits an OCTET STRING TLV.
                    inner.write_bytes(attr.as_bytes());
                    // SEQUENCE OF CHOICE { initial [0], any [1], final [2] }.
                    inner.write_sequence(Tag::sequence(), |subseq| {
                        if let Some(init) = initial {
                            subseq.write_octet_string(Tag::context(0), init.as_bytes());
                        }
                        for a in any {
                            subseq.write_octet_string(Tag::context(1), a.as_bytes());
                        }
                        if let Some(fin) = r#final {
                            subseq.write_octet_string(Tag::context(2), fin.as_bytes());
                        }
                    });
                });
            }
            Self::ExtensibleMatch {
                matching_rule,
                attr,
                value,
                dn_attributes,
            } => {
                // matchingRule [1], type [2], matchValue [3], dnAttributes [4].
                w.write_sequence(Tag::context_constructed(9), |inner| {
                    if let Some(rule) = matching_rule {
                        inner.write_octet_string(Tag::context(1), rule.as_bytes());
                    }
                    if let Some(a) = attr {
                        inner.write_octet_string(Tag::context(2), a.as_bytes());
                    }
                    inner.write_octet_string(Tag::context(3), value.as_bytes());
                    // Only emitted when true; the single 0xFF content octet encodes TRUE.
                    if *dn_attributes {
                        inner.write_octet_string(Tag::context(4), &[0xFF]);
                    }
                });
            }
        }
    }

    /// Decodes one filter from `r`; this is the inverse of [`Filter::encode`].
    ///
    /// Octet strings are converted with `String::from_utf8_lossy`, so invalid UTF-8 is
    /// replaced rather than rejected.
    ///
    /// # Errors
    /// Returns `BerError::UnexpectedTag` if the next element is not context-specific or
    /// its tag number is outside 0–9. Any error from the underlying reader is propagated.
    ///
    /// NOTE(review): And/Or/Not recurse with no depth limit (unlike `parse`, which stops
    /// at `MAX_FILTER_DEPTH`). Deeply nested untrusted input could exhaust the stack —
    /// confirm whether callers bound message size or nesting.
    pub fn decode(r: &mut BerReader<'_>) -> Result<Self, ldap_client_ber::BerError> {
        let tag = r.peek_tag()?;
        if tag.class != ldap_client_ber::Class::Context {
            return Err(ldap_client_ber::BerError::UnexpectedTag {
                expected: Tag::context(0),
                actual: tag,
            });
        }

        match tag.number {
            0 => {
                let mut filters = Vec::new();
                r.read_sequence_lax(Tag::context_constructed(0), |inner| {
                    while !inner.is_empty() {
                        filters.push(Filter::decode(inner)?);
                    }
                    Ok(())
                })?;
                Ok(Self::And(filters))
            }
            1 => {
                let mut filters = Vec::new();
                r.read_sequence_lax(Tag::context_constructed(1), |inner| {
                    while !inner.is_empty() {
                        filters.push(Filter::decode(inner)?);
                    }
                    Ok(())
                })?;
                Ok(Self::Or(filters))
            }
            2 => {
                let f = r.read_sequence(Tag::context_constructed(2), Filter::decode)?;
                Ok(Self::Not(Box::new(f)))
            }
            3 => decode_ava_ber(r, 3).map(|(a, v)| Self::Eq(a, v)),
            5 => decode_ava_ber(r, 5).map(|(a, v)| Self::Gte(a, v)),
            6 => decode_ava_ber(r, 6).map(|(a, v)| Self::Lte(a, v)),
            7 => {
                let value = r.read_tagged_implicit_octet_string(7)?;
                Ok(Self::Present(String::from_utf8_lossy(value).into_owned()))
            }
            8 => decode_ava_ber(r, 8).map(|(a, v)| Self::Approx(a, v)),
            4 => r.read_sequence(Tag::context_constructed(4), |inner| {
                let attr = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
                let mut initial = None;
                let mut any = Vec::new();
                let mut r#final = None;

                inner.read_sequence(Tag::sequence(), |subseq| {
                    while !subseq.is_empty() {
                        let (tag, value) = subseq.read_element()?;
                        let s = String::from_utf8_lossy(value).into_owned();
                        // Unknown CHOICE tags are ignored; a repeated initial/final
                        // keeps the last occurrence.
                        match tag.number {
                            0 => initial = Some(s),
                            1 => any.push(s),
                            2 => r#final = Some(s),
                            _ => {}
                        }
                    }
                    Ok(())
                })?;

                Ok(Self::Substring {
                    attr,
                    initial,
                    any,
                    r#final,
                })
            }),
            9 => r.read_sequence(Tag::context_constructed(9), |inner| {
                let mut matching_rule = None;
                let mut attr = None;
                // Stays empty if no matchValue [3] element is present.
                let mut value = String::new();
                let mut dn_attributes = false;

                // Elements are accepted in any order; unrecognised ones are skipped.
                while !inner.is_empty() {
                    let tag = inner.peek_tag()?;
                    match (tag.class, tag.number) {
                        (ldap_client_ber::Class::Context, 1) => {
                            let v = inner.read_tagged_implicit_octet_string(1)?;
                            matching_rule = Some(String::from_utf8_lossy(v).into_owned());
                        }
                        (ldap_client_ber::Class::Context, 2) => {
                            let v = inner.read_tagged_implicit_octet_string(2)?;
                            attr = Some(String::from_utf8_lossy(v).into_owned());
                        }
                        (ldap_client_ber::Class::Context, 3) => {
                            let v = inner.read_tagged_implicit_octet_string(3)?;
                            value = String::from_utf8_lossy(v).into_owned();
                        }
                        (ldap_client_ber::Class::Context, 4) => {
                            // Any non-zero first content octet counts as TRUE.
                            let v = inner.read_tagged_implicit_octet_string(4)?;
                            dn_attributes = v.first().is_some_and(|&b| b != 0);
                        }
                        _ => {
                            inner.read_element()?;
                        }
                    }
                }

                Ok(Self::ExtensibleMatch {
                    matching_rule,
                    attr,
                    value,
                    dn_attributes,
                })
            }),
            _ => Err(ldap_client_ber::BerError::UnexpectedTag {
                expected: Tag::context(0),
                actual: tag,
            }),
        }
    }
}
380
381impl std::fmt::Display for Filter {
382 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
383 f.write_str(&self.to_filter_string())
384 }
385}
386
387impl std::str::FromStr for Filter {
388 type Err = ProtoError;
389 fn from_str(s: &str) -> Result<Self, Self::Err> {
390 Self::parse(s)
391 }
392}
393
394fn encode_ava(w: &mut BerWriter, tag_num: u32, attr: &str, value: &str) {
395 w.write_sequence(Tag::context_constructed(tag_num), |inner| {
396 inner.write_bytes(attr.as_bytes());
397 inner.write_bytes(value.as_bytes());
398 });
399}
400
401fn decode_ava_ber(
402 r: &mut BerReader<'_>,
403 tag_num: u32,
404) -> Result<(String, String), ldap_client_ber::BerError> {
405 r.read_sequence(Tag::context_constructed(tag_num), |inner| {
406 let attr = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
407 let value = String::from_utf8_lossy(inner.read_octet_string()?).into_owned();
408 Ok((attr, value))
409 })
410}
411
/// Maximum nesting depth accepted by the string-form parser.
///
/// `parse_filter` fails with "filter nesting too deep" once its `depth` argument reaches
/// this value, which bounds recursion through `&`, `|` and `!`.
const MAX_FILTER_DEPTH: usize = 128;
415
416fn parse_filter(input: &str, depth: usize) -> Result<(Filter, &str), ProtoError> {
417 if depth >= MAX_FILTER_DEPTH {
418 return Err(ProtoError::FilterParse("filter nesting too deep".into()));
419 }
420
421 let input = input
422 .strip_prefix('(')
423 .ok_or_else(|| ProtoError::FilterParse("expected '('".into()))?;
424
425 let (filter, rest) = parse_filter_comp(input, depth)?;
426
427 let rest = rest
428 .strip_prefix(')')
429 .ok_or_else(|| ProtoError::FilterParse("expected ')'".into()))?;
430
431 Ok((filter, rest))
432}
433
434fn parse_filter_comp(input: &str, depth: usize) -> Result<(Filter, &str), ProtoError> {
435 match input.chars().next() {
436 Some('&') => parse_filter_list(&input[1..], Filter::And, depth),
437 Some('|') => parse_filter_list(&input[1..], Filter::Or, depth),
438 Some('!') => {
439 let (f, rest) = parse_filter(&input[1..], depth + 1)?;
440 Ok((Filter::Not(Box::new(f)), rest))
441 }
442 _ => parse_item(input),
443 }
444}
445
446fn parse_filter_list(
447 mut input: &str,
448 ctor: fn(Vec<Filter>) -> Filter,
449 depth: usize,
450) -> Result<(Filter, &str), ProtoError> {
451 let mut filters = Vec::new();
452 while input.starts_with('(') {
453 let (f, rest) = parse_filter(input, depth + 1)?;
454 filters.push(f);
455 input = rest;
456 }
457 if filters.is_empty() {
458 return Err(ProtoError::FilterParse("empty filter list".into()));
459 }
460 Ok((ctor(filters), input))
461}
462
463fn parse_item(input: &str) -> Result<(Filter, &str), ProtoError> {
464 let mut i = 0;
466 let bytes = input.as_bytes();
467 while i < bytes.len() && !matches!(bytes[i], b'=' | b'>' | b'<' | b'~' | b')') {
468 i += 1;
469 }
470
471 if i >= bytes.len() || bytes[i] == b')' {
472 return Err(ProtoError::FilterParse("missing operator".into()));
473 }
474
475 let attr = &input[..i];
476
477 if attr.contains(':') {
479 return parse_extensible_match(input);
480 }
481
482 let (op_len, filter_type) = match (bytes.get(i), bytes.get(i + 1)) {
483 (Some(b'>'), Some(b'=')) => (2, ">="),
484 (Some(b'<'), Some(b'=')) => (2, "<="),
485 (Some(b'~'), Some(b'=')) => (2, "~="),
486 (Some(b'='), _) => (1, "="),
487 _ => return Err(ProtoError::FilterParse("unknown operator".into())),
488 };
489
490 let value_start = i + op_len;
491 let value_end = find_value_end(&input[value_start..]);
492 let raw_value = &input[value_start..value_start + value_end];
493 let rest = &input[value_start + value_end..];
494
495 match filter_type {
496 "=" => {
497 if raw_value == "*" {
498 Ok((Filter::Present(attr.to_string()), rest))
499 } else if raw_value.contains('*') {
500 Ok((parse_substring(attr, raw_value)?, rest))
501 } else {
502 Ok((
503 Filter::Eq(attr.to_string(), unescape_value(raw_value)?),
504 rest,
505 ))
506 }
507 }
508 ">=" => Ok((
509 Filter::Gte(attr.to_string(), unescape_value(raw_value)?),
510 rest,
511 )),
512 "<=" => Ok((
513 Filter::Lte(attr.to_string(), unescape_value(raw_value)?),
514 rest,
515 )),
516 "~=" => Ok((
517 Filter::Approx(attr.to_string(), unescape_value(raw_value)?),
518 rest,
519 )),
520 _ => unreachable!(),
521 }
522}
523
524fn parse_extensible_match(input: &str) -> Result<(Filter, &str), ProtoError> {
525 let eq_pos = input
527 .find(":=")
528 .ok_or_else(|| ProtoError::FilterParse("extensible match missing ':='".into()))?;
529
530 let prefix = &input[..eq_pos];
531 let value_start = eq_pos + 2;
532 let value_end = find_value_end(&input[value_start..]);
533 let raw_value = &input[value_start..value_start + value_end];
534 let rest = &input[value_start + value_end..];
535
536 let mut attr = None;
537 let mut matching_rule = None;
538 let mut dn_attributes = false;
539
540 let parts: Vec<&str> = prefix.split(':').collect();
541 match parts.len() {
542 1 => {
543 if !parts[0].is_empty() {
544 attr = Some(parts[0].to_string());
545 }
546 }
547 2 => {
548 if !parts[0].is_empty() {
549 attr = Some(parts[0].to_string());
550 }
551 if parts[1] == "dn" {
552 dn_attributes = true;
553 } else if !parts[1].is_empty() {
554 matching_rule = Some(parts[1].to_string());
555 }
556 }
557 3 => {
558 if !parts[0].is_empty() {
559 attr = Some(parts[0].to_string());
560 }
561 if parts[1] == "dn" {
562 dn_attributes = true;
563 }
564 if !parts[2].is_empty() {
565 matching_rule = Some(parts[2].to_string());
566 }
567 }
568 _ => {
569 return Err(ProtoError::FilterParse(
570 "too many colon-separated parts in extensible match".into(),
571 ));
572 }
573 }
574
575 if attr.is_none() && matching_rule.is_none() && !dn_attributes {
577 return Err(ProtoError::FilterParse(
578 "extensible match requires at least one of attr, matching rule, or :dn:".into(),
579 ));
580 }
581
582 Ok((
583 Filter::ExtensibleMatch {
584 matching_rule,
585 attr,
586 value: unescape_value(raw_value)?,
587 dn_attributes,
588 },
589 rest,
590 ))
591}
592
/// Maximum number of `*`-separated segments (wildcard count + 1) that `parse_substring`
/// accepts in a single substring value.
const MAX_SUBSTRING_PARTS: usize = 64;
594
595fn parse_substring(attr: &str, raw_value: &str) -> Result<Filter, ProtoError> {
596 let parts: Vec<&str> = raw_value.split('*').collect();
597 if parts.len() > MAX_SUBSTRING_PARTS {
598 return Err(ProtoError::FilterParse(
599 "substring filter has too many wildcard parts".into(),
600 ));
601 }
602 let initial = if !parts[0].is_empty() {
603 Some(unescape_value(parts[0])?)
604 } else {
605 None
606 };
607 let r#final = match parts.last().filter(|s| !s.is_empty()) {
608 Some(s) => Some(unescape_value(s)?),
609 None => None,
610 };
611 let any: Vec<String> = parts[1..parts.len() - 1]
612 .iter()
613 .filter(|s| !s.is_empty())
614 .map(|s| unescape_value(s))
615 .collect::<Result<_, _>>()?;
616
617 if initial.is_none() && any.is_empty() && r#final.is_none() {
618 return Err(ProtoError::FilterParse(
619 "substring filter has no assertions".into(),
620 ));
621 }
622
623 Ok(Filter::Substring {
624 attr: attr.to_string(),
625 initial,
626 any,
627 r#final,
628 })
629}
630
/// Returns the byte length of the assertion value at the start of `input`.
///
/// The value ends at the first `)` that is not part of a `\xx` hex escape, or at the end
/// of `input`. A backslash not followed by two hex digits counts as an ordinary byte.
fn find_value_end(input: &str) -> usize {
    let bytes = input.as_bytes();
    let mut pos = 0;
    while let Some(&b) = bytes.get(pos) {
        if b == b')' {
            break;
        }
        let is_hex_escape = b == b'\\'
            && bytes.get(pos + 1).is_some_and(u8::is_ascii_hexdigit)
            && bytes.get(pos + 2).is_some_and(u8::is_ascii_hexdigit);
        pos += if is_hex_escape { 3 } else { 1 };
    }
    pos
}
647
648fn unescape_value(input: &str) -> Result<String, ProtoError> {
649 let mut out = Vec::with_capacity(input.len());
650 let bytes = input.as_bytes();
651 let mut i = 0;
652 while i < bytes.len() {
653 if bytes[i] == b'\\'
654 && i + 2 < bytes.len()
655 && let Ok(byte) =
656 u8::from_str_radix(std::str::from_utf8(&bytes[i + 1..i + 3]).unwrap_or(""), 16)
657 {
658 out.push(byte);
659 i += 3;
660 continue;
661 }
662 out.push(bytes[i]);
663 i += 1;
664 }
665 String::from_utf8(out)
666 .map_err(|e| ProtoError::FilterParse(format!("invalid UTF-8 in filter value: {e}")))
667}