1use std::num::NonZeroU64;
2
3use dbn::{
4 decode::{DbnMetadata, DecodeRecordRef},
5 RType, Record, RecordRef, Schema,
6};
7
/// Decoder adapter that yields only the records whose `rtype` corresponds to a
/// single [`Schema`], discarding all other records read from the inner decoder.
#[derive(Debug)]
pub struct SchemaFilter<D> {
    /// The wrapped decoder that records are read from.
    decoder: D,
    /// The record type to keep, derived from the requested schema; `None` lets
    /// every record through unfiltered.
    rtype: Option<RType>,
}
13
14impl<D> SchemaFilter<D>
15where
16 D: DbnMetadata,
17{
18 pub fn new(mut decoder: D, schema: Option<Schema>) -> Self {
19 if let Some(schema) = schema {
20 decoder.metadata_mut().schema = Some(schema);
21 }
22 Self::new_no_metadata(decoder, schema)
23 }
24}
25
26impl<D> SchemaFilter<D> {
27 pub fn new_no_metadata(decoder: D, schema: Option<Schema>) -> Self {
28 Self {
29 decoder,
30 rtype: schema.map(RType::from),
31 }
32 }
33}
34
35impl<D: DbnMetadata> DbnMetadata for SchemaFilter<D> {
36 fn metadata(&self) -> &dbn::Metadata {
37 self.decoder.metadata()
38 }
39
40 fn metadata_mut(&mut self) -> &mut dbn::Metadata {
41 self.decoder.metadata_mut()
42 }
43}
44
45impl<D: DecodeRecordRef> DecodeRecordRef for SchemaFilter<D> {
46 fn decode_record_ref(&mut self) -> dbn::Result<Option<dbn::RecordRef<'_>>> {
47 while let Some(record) = self.decoder.decode_record_ref()? {
48 if self
49 .rtype
50 .map(|rtype| rtype as u8 == record.header().rtype)
51 .unwrap_or(true)
52 {
53 return Ok(Some(unsafe {
56 RecordRef::unchecked_from_header(record.header())
57 }));
58 }
59 }
60 Ok(None)
61 }
62}
63
/// Decoder adapter that stops yielding records once a maximum count has been
/// returned.
#[derive(Debug)]
pub struct LimitFilter<D> {
    /// The wrapped decoder that records are read from.
    decoder: D,
    /// Maximum number of records to yield; `None` means no limit.
    limit: Option<NonZeroU64>,
    /// Number of records returned to the caller so far.
    record_count: u64,
}
70
71impl<D> LimitFilter<D>
72where
73 D: DbnMetadata,
74{
75 pub fn new(mut decoder: D, limit: Option<NonZeroU64>) -> Self {
76 if let Some(limit) = limit {
77 let metadata_limit = &mut decoder.metadata_mut().limit;
78 if let Some(metadata_limit) = metadata_limit {
79 *metadata_limit = (*metadata_limit).min(limit);
80 } else {
81 *metadata_limit = Some(limit);
82 }
83 }
84 Self::new_no_metadata(decoder, limit)
85 }
86}
87
88impl<D> LimitFilter<D> {
89 pub fn new_no_metadata(decoder: D, limit: Option<NonZeroU64>) -> Self {
90 Self {
91 decoder,
92 limit,
93 record_count: 0,
94 }
95 }
96}
97
98impl<D: DbnMetadata> DbnMetadata for LimitFilter<D> {
99 fn metadata(&self) -> &dbn::Metadata {
100 self.decoder.metadata()
101 }
102
103 fn metadata_mut(&mut self) -> &mut dbn::Metadata {
104 self.decoder.metadata_mut()
105 }
106}
107
108impl<D: DecodeRecordRef> DecodeRecordRef for LimitFilter<D> {
109 fn decode_record_ref(&mut self) -> dbn::Result<Option<RecordRef<'_>>> {
110 if self
111 .limit
112 .map(|limit| self.record_count >= limit.get())
113 .unwrap_or(false)
114 {
115 return Ok(None);
116 }
117 Ok(self.decoder.decode_record_ref()?.inspect(|_| {
118 self.record_count += 1;
119 }))
120 }
121}