use super::parser::{NgramRecord, NgramRecordRef};
/// Summary of all in-range records for one ngram, produced by [`YearAggregator`].
///
/// Derives `PartialEq`/`Eq` so whole aggregates can be compared directly (e.g. in tests
/// or deduplication) instead of field by field.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AggregatedNgram {
    /// The ngram text shared by every merged record.
    pub ngram: String,
    /// Sum of `match_count` over the merged records.
    pub total_count: u64,
    /// Largest `volume_count` seen among the merged records.
    pub max_volume_count: u32,
    /// Number of records merged into this aggregate.
    ///
    /// This equals the number of distinct years only when the input holds at most one
    /// record per (ngram, year); it is *not* `last_year - first_year + 1`.
    pub year_span: u16,
    /// Smallest year among the merged records.
    pub first_year: u16,
    /// Largest year among the merged records.
    pub last_year: u16,
}
/// Streaming aggregator that merges consecutive records sharing the same ngram.
///
/// Derives `Debug` (so it can be inspected/logged like the rest of the public types in
/// this module) and `Clone` (all fields are cloneable).
#[derive(Clone, Debug)]
pub struct YearAggregator {
    /// Optional inclusive `(start, end)` year filter; `None` accepts every year.
    year_range: Option<(u16, u16)>,
    /// The group currently being accumulated, if any.
    current: Option<AggregatedNgram>,
}
impl YearAggregator {
    /// Creates an aggregator with no pending group.
    ///
    /// `year_range`, when `Some((start, end))`, restricts aggregation to records whose
    /// year lies in the inclusive range `start..=end`; other records are ignored.
    /// `None` accepts every year.
    pub fn new(year_range: Option<(u16, u16)>) -> Self {
        Self {
            year_range,
            current: None,
        }
    }

    /// Feeds an owned record into the aggregator.
    ///
    /// Only *consecutive* records with the same `ngram` are merged, so a later,
    /// non-contiguous repeat of an ngram produces a separate aggregate.
    ///
    /// Returns the previously pending aggregate when `record` starts a new group, and
    /// `None` otherwise — including when `record` falls outside the year filter, in which
    /// case it neither extends nor closes the current group. When a new group starts, the
    /// record's owned `ngram` string is moved into it rather than copied.
    pub fn push(&mut self, record: NgramRecord) -> Option<AggregatedNgram> {
        self.push_parts(
            record.ngram,
            record.year,
            record.match_count,
            record.volume_count,
        )
    }

    /// Borrowed-record counterpart of [`push`](Self::push) with identical semantics.
    ///
    /// The ngram text is copied into an owned `String` only when the record starts a new
    /// group; records that extend the current group do not allocate.
    pub fn push_ref(&mut self, record: &NgramRecordRef<'_>) -> Option<AggregatedNgram> {
        self.push_parts(
            record.ngram,
            record.year,
            record.match_count,
            record.volume_count,
        )
    }

    /// Takes the pending aggregate, if any, leaving the aggregator empty.
    ///
    /// Call this after the last record so the final group is not lost.
    pub fn flush(&mut self) -> Option<AggregatedNgram> {
        self.current.take()
    }

    /// Discards the pending aggregate without returning it. The year filter is kept.
    pub fn reset(&mut self) {
        self.current = None;
    }

    /// Shared implementation of [`push`](Self::push) and [`push_ref`](Self::push_ref).
    ///
    /// Generic over the ngram representation: an owned `String` is moved into a new
    /// group, while a borrowed `&str` is only converted when a new group actually starts.
    fn push_parts<N>(
        &mut self,
        ngram: N,
        year: u16,
        match_count: u64,
        volume_count: u32,
    ) -> Option<AggregatedNgram>
    where
        N: AsRef<str> + Into<String>,
    {
        if !self.accepts_year(year) {
            return None;
        }

        if let Some(current) = self.current.as_mut() {
            if current.ngram == ngram.as_ref() {
                // Saturate instead of overflowing: with `+=`, a huge count or an
                // unexpectedly long run of records would panic in debug builds and wrap
                // silently in release builds.
                current.total_count = current.total_count.saturating_add(match_count);
                current.max_volume_count = current.max_volume_count.max(volume_count);
                current.year_span = current.year_span.saturating_add(1);
                current.first_year = current.first_year.min(year);
                current.last_year = current.last_year.max(year);
                return None;
            }
        }

        // Either nothing is pending or the ngram changed: start a new group and hand back
        // whatever was pending before (`replace` yields `None` on the very first record).
        self.current.replace(AggregatedNgram {
            ngram: ngram.into(),
            total_count: match_count,
            max_volume_count: volume_count,
            year_span: 1,
            first_year: year,
            last_year: year,
        })
    }

    /// Returns `true` when `year` passes the optional inclusive year filter.
    fn accepts_year(&self, year: u16) -> bool {
        match self.year_range {
            Some((start, end)) => (start..=end).contains(&year),
            None => true,
        }
    }
}
/// Iterator adapter that yields one [`AggregatedNgram`] per run of consecutive
/// records sharing an ngram, driven by a [`YearAggregator`].
///
/// Usually constructed via `AggregateExt::aggregate_years`.
#[must_use = "iterator adapters are lazy and do nothing unless consumed"]
pub struct AggregatingIterator<I> {
    /// Source of records; only consecutive records with equal ngrams are merged.
    inner: I,
    /// Accumulates the group currently being built.
    aggregator: YearAggregator,
    /// Set once `inner` has reported exhaustion and the final pending group was flushed.
    flushed: bool,
}
impl<I: Iterator<Item = NgramRecord>> AggregatingIterator<I> {
    /// Wraps `inner`, merging its records through a fresh [`YearAggregator`] that applies
    /// the optional inclusive `year_range` filter.
    pub fn new(inner: I, year_range: Option<(u16, u16)>) -> Self {
        let aggregator = YearAggregator::new(year_range);
        Self {
            flushed: false,
            inner,
            aggregator,
        }
    }
}
impl<I> Iterator for AggregatingIterator<I>
where
    I: Iterator<Item = NgramRecord>,
{
    type Item = AggregatedNgram;

    /// Yields the next completed aggregate.
    ///
    /// Pulls records from `inner` until one closes the current group. When `inner` is
    /// exhausted, the final pending group is flushed exactly once; afterwards this
    /// always returns `None` *without polling `inner` again*. Previously, a non-fused
    /// source that resumed after its first `None` had its extra records pushed into the
    /// aggregator, and the last group was never emitted, so data was silently dropped.
    fn next(&mut self) -> Option<Self::Item> {
        if self.flushed {
            return None;
        }
        // `by_ref` borrows only `self.inner`, so `self.aggregator` stays usable in the loop.
        for record in self.inner.by_ref() {
            if let Some(aggregated) = self.aggregator.push(record) {
                return Some(aggregated);
            }
        }
        self.flushed = true;
        self.aggregator.flush()
    }
}

/// Once the final group has been flushed, `next` never polls `inner` again and always
/// returns `None`, so the adapter is fused regardless of whether `inner` is.
impl<I> std::iter::FusedIterator for AggregatingIterator<I> where I: Iterator<Item = NgramRecord> {}
/// Extension trait that adds [`aggregate_years`](AggregateExt::aggregate_years) to every
/// iterator of owned [`NgramRecord`]s.
pub trait AggregateExt: Iterator<Item = NgramRecord> + Sized {
    /// Consumes `self` and returns an [`AggregatingIterator`] that merges consecutive
    /// records sharing an ngram, applying the optional inclusive `year_range` filter.
    fn aggregate_years(self, year_range: Option<(u16, u16)>) -> AggregatingIterator<Self> {
        AggregatingIterator::<Self>::new(self, year_range)
    }
}

/// Blanket implementation: any iterator yielding `NgramRecord` gets `aggregate_years`.
impl<I> AggregateExt for I where I: Iterator<Item = NgramRecord> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned record with a fixed `volume_count` of 100.
    fn make_record(ngram: &str, year: u16, count: u64) -> NgramRecord {
        make_record_with_volume(ngram, year, count, 100)
    }

    /// Builds an owned record with an explicit `volume_count`.
    fn make_record_with_volume(
        ngram: &str,
        year: u16,
        count: u64,
        volume_count: u32,
    ) -> NgramRecord {
        NgramRecord {
            ngram: ngram.to_string(),
            year,
            match_count: count,
            volume_count,
        }
    }

    /// Builds a borrowed record with a fixed `volume_count` of 100.
    fn make_record_ref<'a>(ngram: &'a str, year: u16, count: u64) -> NgramRecordRef<'a> {
        NgramRecordRef {
            ngram,
            year,
            match_count: count,
            volume_count: 100,
        }
    }

    #[test]
    fn test_single_ngram_multiple_years() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push(make_record("the", 2000, 100)).is_none());
        assert!(agg.push(make_record("the", 2001, 200)).is_none());
        assert!(agg.push(make_record("the", 2002, 150)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.ngram, "the");
        assert_eq!(result.total_count, 450);
        assert_eq!(result.year_span, 3);
        assert_eq!(result.first_year, 2000);
        assert_eq!(result.last_year, 2002);
    }

    #[test]
    fn test_unordered_years_track_first_and_last() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push(make_record("the", 2005, 1)).is_none());
        assert!(agg.push(make_record("the", 2001, 1)).is_none());
        assert!(agg.push(make_record("the", 2003, 1)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.first_year, 2001);
        assert_eq!(result.last_year, 2005);
        assert_eq!(result.year_span, 3);
    }

    #[test]
    fn test_max_volume_count_tracks_maximum() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push(make_record_with_volume("v", 2000, 1, 10)).is_none());
        assert!(agg.push(make_record_with_volume("v", 2001, 1, 50)).is_none());
        assert!(agg.push(make_record_with_volume("v", 2002, 1, 30)).is_none());
        assert_eq!(agg.flush().unwrap().max_volume_count, 50);
    }

    #[test]
    fn test_multiple_ngrams() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push(make_record("a", 2000, 100)).is_none());
        assert!(agg.push(make_record("a", 2001, 200)).is_none());
        let result = agg.push(make_record("b", 2000, 50)).unwrap();
        assert_eq!(result.ngram, "a");
        assert_eq!(result.total_count, 300);
        assert!(agg.push(make_record("b", 2001, 60)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.ngram, "b");
        assert_eq!(result.total_count, 110);
    }

    #[test]
    fn test_year_filter() {
        let mut agg = YearAggregator::new(Some((2000, 2010)));
        assert!(agg.push(make_record("the", 1999, 100)).is_none());
        assert!(agg.push(make_record("the", 2011, 100)).is_none());
        assert!(agg.push(make_record("the", 2000, 100)).is_none());
        assert!(agg.push(make_record("the", 2005, 200)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.total_count, 300);
        assert_eq!(result.year_span, 2);
    }

    #[test]
    fn test_year_filter_is_inclusive() {
        let mut agg = YearAggregator::new(Some((2000, 2010)));
        assert!(agg.push(make_record("the", 2000, 1)).is_none());
        assert!(agg.push(make_record("the", 2010, 1)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.year_span, 2);
        assert_eq!(result.first_year, 2000);
        assert_eq!(result.last_year, 2010);
    }

    #[test]
    fn test_filtered_record_does_not_close_group() {
        let mut agg = YearAggregator::new(Some((2000, 2010)));
        assert!(agg.push(make_record("a", 2000, 1)).is_none());
        assert!(agg.push(make_record("b", 2020, 1)).is_none());
        assert!(agg.push(make_record("a", 2001, 1)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.ngram, "a");
        assert_eq!(result.year_span, 2);
    }

    #[test]
    fn test_flush_empty_and_twice() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.flush().is_none());
        assert!(agg.push(make_record("a", 2000, 1)).is_none());
        assert!(agg.flush().is_some());
        assert!(agg.flush().is_none());
    }

    #[test]
    fn test_reset_discards_pending() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push(make_record("a", 2000, 1)).is_none());
        agg.reset();
        assert!(agg.flush().is_none());
    }

    #[test]
    fn test_iterator_adapter() {
        let records = vec![
            make_record("a", 2000, 100),
            make_record("a", 2001, 200),
            make_record("b", 2000, 50),
            make_record("b", 2001, 60),
        ];
        let aggregated: Vec<_> = records.into_iter().aggregate_years(None).collect();
        assert_eq!(aggregated.len(), 2);
        assert_eq!(aggregated[0].ngram, "a");
        assert_eq!(aggregated[0].total_count, 300);
        assert_eq!(aggregated[1].ngram, "b");
        assert_eq!(aggregated[1].total_count, 110);
    }

    #[test]
    fn test_iterator_empty_input() {
        let records: Vec<NgramRecord> = Vec::new();
        assert_eq!(records.into_iter().aggregate_years(None).count(), 0);
    }

    #[test]
    fn test_iterator_applies_year_filter() {
        let records = vec![
            make_record("a", 1999, 100),
            make_record("a", 2000, 7),
            make_record("b", 2011, 50),
        ];
        let aggregated: Vec<_> = records
            .into_iter()
            .aggregate_years(Some((2000, 2010)))
            .collect();
        assert_eq!(aggregated.len(), 1);
        assert_eq!(aggregated[0].ngram, "a");
        assert_eq!(aggregated[0].total_count, 7);
    }

    #[test]
    fn test_push_ref_accumulates_same_ngram() {
        let mut agg = YearAggregator::new(None);
        assert!(agg.push_ref(&make_record_ref("the", 2000, 100)).is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2001, 100)).is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2002, 100)).is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2003, 100)).is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2004, 100)).is_none());
        let previous = agg
            .push_ref(&make_record_ref("a", 2000, 50))
            .expect("new ngram should yield previous");
        assert_eq!(previous.ngram, "the");
        assert_eq!(previous.total_count, 500);
        assert_eq!(previous.year_span, 5);
        assert_eq!(previous.first_year, 2000);
        assert_eq!(previous.last_year, 2004);
    }

    #[test]
    fn test_push_ref_year_filter() {
        let mut agg = YearAggregator::new(Some((2000, 2010)));
        assert!(agg.push_ref(&make_record_ref("the", 1999, 100)).is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2011, 100)).is_none());
        assert!(agg.flush().is_none());
        assert!(agg.push_ref(&make_record_ref("the", 2010, 5)).is_none());
        let result = agg.flush().unwrap();
        assert_eq!(result.total_count, 5);
        assert_eq!(result.year_span, 1);
    }

    #[test]
    fn test_push_ref_vs_push_equivalent() {
        let inputs = [
            ("the", 2000u16, 100u64),
            ("the", 2001, 200),
            ("the", 2002, 150),
            ("a", 2000, 50),
            ("a", 2001, 60),
            ("a", 2010, 70),
        ];
        let mut owned_agg = YearAggregator::new(None);
        let mut owned_results = Vec::new();
        for &(n, y, c) in &inputs {
            if let Some(r) = owned_agg.push(make_record(n, y, c)) {
                owned_results.push(r);
            }
        }
        if let Some(r) = owned_agg.flush() {
            owned_results.push(r);
        }
        let mut ref_agg = YearAggregator::new(None);
        let mut ref_results = Vec::new();
        for &(n, y, c) in &inputs {
            if let Some(r) = ref_agg.push_ref(&make_record_ref(n, y, c)) {
                ref_results.push(r);
            }
        }
        if let Some(r) = ref_agg.flush() {
            ref_results.push(r);
        }
        assert_eq!(owned_results.len(), ref_results.len());
        for (o, r) in owned_results.iter().zip(ref_results.iter()) {
            assert_eq!(o.ngram, r.ngram);
            assert_eq!(o.total_count, r.total_count);
            assert_eq!(o.max_volume_count, r.max_volume_count);
            assert_eq!(o.year_span, r.year_span);
            assert_eq!(o.first_year, r.first_year);
            assert_eq!(o.last_year, r.last_year);
        }
    }
}