1use std::num::NonZeroU64;
2
3use dbn::{
4    decode::{DbnMetadata, DecodeRecordRef},
5    RType, Record, RecordRef, Schema,
6};
7
8#[derive(Debug)]
9pub struct SchemaFilter<D> {
10    decoder: D,
11    rtype: Option<RType>,
12}
13
14impl<D> SchemaFilter<D>
15where
16    D: DbnMetadata,
17{
18    pub fn new(mut decoder: D, schema: Option<Schema>) -> Self {
19        if let Some(schema) = schema {
20            decoder.metadata_mut().schema = Some(schema);
21        }
22        Self::new_no_metadata(decoder, schema)
23    }
24}
25
26impl<D> SchemaFilter<D> {
27    pub fn new_no_metadata(decoder: D, schema: Option<Schema>) -> Self {
28        Self {
29            decoder,
30            rtype: schema.map(RType::from),
31        }
32    }
33}
34
35impl<D: DbnMetadata> DbnMetadata for SchemaFilter<D> {
36    fn metadata(&self) -> &dbn::Metadata {
37        self.decoder.metadata()
38    }
39
40    fn metadata_mut(&mut self) -> &mut dbn::Metadata {
41        self.decoder.metadata_mut()
42    }
43}
44
45impl<D: DecodeRecordRef> DecodeRecordRef for SchemaFilter<D> {
46    fn decode_record_ref(&mut self) -> dbn::Result<Option<dbn::RecordRef<'_>>> {
47        while let Some(record) = self.decoder.decode_record_ref()? {
48            if self
49                .rtype
50                .map(|rtype| rtype as u8 == record.header().rtype)
51                .unwrap_or(true)
52            {
53                return Ok(Some(unsafe {
56                    RecordRef::unchecked_from_header(record.header())
57                }));
58            }
59        }
60        Ok(None)
61    }
62}
63
/// Decoder adapter that stops after a maximum number of records.
///
/// With no limit, every record from the wrapped decoder is passed through.
#[derive(Debug)]
pub struct LimitFilter<D> {
    /// Wrapped decoder that records are pulled from.
    decoder: D,
    /// Maximum number of records to yield, or `None` for no cap.
    limit: Option<NonZeroU64>,
    /// Number of records yielded so far.
    record_count: u64,
}
70
71impl<D> LimitFilter<D>
72where
73    D: DbnMetadata,
74{
75    pub fn new(mut decoder: D, limit: Option<NonZeroU64>) -> Self {
76        if let Some(limit) = limit {
77            let metadata_limit = &mut decoder.metadata_mut().limit;
78            if let Some(metadata_limit) = metadata_limit {
79                *metadata_limit = (*metadata_limit).min(limit);
80            } else {
81                *metadata_limit = Some(limit);
82            }
83        }
84        Self::new_no_metadata(decoder, limit)
85    }
86}
87
88impl<D> LimitFilter<D> {
89    pub fn new_no_metadata(decoder: D, limit: Option<NonZeroU64>) -> Self {
90        Self {
91            decoder,
92            limit,
93            record_count: 0,
94        }
95    }
96}
97
98impl<D: DbnMetadata> DbnMetadata for LimitFilter<D> {
99    fn metadata(&self) -> &dbn::Metadata {
100        self.decoder.metadata()
101    }
102
103    fn metadata_mut(&mut self) -> &mut dbn::Metadata {
104        self.decoder.metadata_mut()
105    }
106}
107
108impl<D: DecodeRecordRef> DecodeRecordRef for LimitFilter<D> {
109    fn decode_record_ref(&mut self) -> dbn::Result<Option<RecordRef<'_>>> {
110        if self
111            .limit
112            .map(|limit| self.record_count >= limit.get())
113            .unwrap_or(false)
114        {
115            return Ok(None);
116        }
117        Ok(self.decoder.decode_record_ref()?.inspect(|_| {
118            self.record_count += 1;
119        }))
120    }
121}