use std::fmt;
use super::Aggregate;
use crate::{Headers, Row, error};
/// An [`Aggregate`] that folds the values of one source column through a
/// user-supplied closure, starting from an initial accumulator.
pub struct Closure<F, C> {
    /// Column whose field is passed to the closure on each update.
    source: String,
    /// Name returned by `Aggregate::colname`.
    colname: String,
    /// Current accumulator. It starts as `Some(init)`; it is temporarily taken
    /// while the closure runs and is not put back if the closure fails.
    current: Option<C>,
    /// Folding function: `(accumulator, field) -> Result<new accumulator>`.
    closure: F,
}

// Hand-written instead of `#[derive(Debug)]`: the derive would require
// `F: Debug`, which closures never implement, making the impl unusable in
// practice. The closure field is therefore left out of the output.
impl<F, C: fmt::Debug> fmt::Debug for Closure<F, C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Closure")
            .field("source", &self.source)
            .field("colname", &self.colname)
            .field("current", &self.current)
            .finish()
    }
}
impl<F, C> Closure<F, C>
where
    F: FnMut(C, &str) -> error::Result<C>,
    C: fmt::Display,
{
    /// Creates a closure-backed aggregate.
    ///
    /// * `colname` — the name later reported by `Aggregate::colname`.
    /// * `source` — the column whose field is handed to `closure` on every update.
    /// * `closure` — called as `closure(accumulator, field)`; its `Ok` value
    ///   becomes the new accumulator.
    /// * `init` — the accumulator value before any row has been seen.
    pub fn new(colname: String, source: String, closure: F, init: C) -> Closure<F, C> {
        let current = Some(init);
        Closure { colname, source, current, closure }
    }
}
impl <F, C> Aggregate for Closure<F, C>
where
F: FnMut(C, &str) -> error::Result<C>,
C: fmt::Display,
{
fn update(&mut self, headers: &Headers, row: &Row) -> error::Result<()> {
match headers.get_field(row, &self.source) {
Some(data) => {
match (self.closure)(self.current.take().unwrap(), data) {
Ok(s) => {
self.current = Some(s);
Ok(())
}
Err(e) => Err(e),
}
}
None => Err(error::Error::ColumnNotFound(self.source.to_string())),
}
}
fn value(&self) -> String {
self.current.as_ref().unwrap().to_string()
}
fn colname(&self) -> &str {
&self.colname
}
}
#[cfg(test)]
mod tests {
    use super::{Aggregate, Closure};
    use crate::Row;

    /// Summing two rows through the closure accumulates their values.
    #[test]
    fn test_sum() {
        let mut c = Closure::new("col".into(), "source".into(), |acc, cur| {
            Ok(acc + cur.parse::<f64>().unwrap())
        }, 0.0);
        let h = Row::from(vec!["source"]).into();
        let r = Row::from(vec!["3.0"]);
        c.update(&h, &r).unwrap();
        let r = Row::from(vec!["2.0"]);
        c.update(&h, &r).unwrap();
        assert_eq!(c.value(), "5");
    }

    /// A missing source column is reported as an error and does not
    /// discard the accumulated value.
    #[test]
    fn test_missing_column() {
        let mut c = Closure::new("col".into(), "source".into(), |acc, cur| {
            Ok(acc + cur.parse::<f64>().unwrap())
        }, 0.0);
        let h = Row::from(vec!["source"]).into();
        c.update(&h, &Row::from(vec!["3.0"])).unwrap();

        let other = Row::from(vec!["other"]).into();
        assert!(c.update(&other, &Row::from(vec!["4.0"])).is_err());
        assert_eq!(c.value(), "3");
    }
}