use crate::error::EvaluatorError;
use crate::evaluator::*;
use crate::feature::Feature;
use crate::float_trait::Float;
use crate::time_series::TimeSeries;
use std::marker::PhantomData;
macro_const! {
const DOC: &str = r#"
Bulk feature extractor

- Depends on: as required by feature evaluators
- Minimum number of observations: as required by feature evaluators
- Number of features: total for all feature evaluators
"#;
}
#[doc = DOC!()]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(
into = "FeatureExtractorParameters<F>",
from = "FeatureExtractorParameters<F>",
bound = "T: Float, F: FeatureEvaluator<T>"
)]
pub struct FeatureExtractor<T, F> {
features: Vec<F>,
info: Box<EvaluatorInfo>,
phantom: PhantomData<T>,
}
impl<T, F> FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
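    /// Create a bulk extractor from a list of feature evaluators.
    ///
    /// The cached `EvaluatorInfo` aggregates the per-feature metadata: sizes
    /// are summed, minimum time-series lengths are combined with `max`, and
    /// the t/m/w/sorting requirement flags are OR-ed together.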
pub fn new(features: Vec<F>) -> Self {
let info = EvaluatorInfo {
size: features.iter().map(|x| x.size_hint()).sum(),
min_ts_length: features
.iter()
.map(|x| x.min_ts_length())
.max()
.unwrap_or(0),
t_required: features.iter().any(|x| x.is_t_required()),
m_required: features.iter().any(|x| x.is_m_required()),
w_required: features.iter().any(|x| x.is_w_required()),
sorting_required: features.iter().any(|x| x.is_sorting_required()),
}
.into();
Self {
info,
features,
phantom: PhantomData,
}
}
pub fn get_features(&self) -> &Vec<F> {
&self.features
}
pub fn into_vec(self) -> Vec<F> {
self.features
}
    pub fn add_feature(&mut self, feature: F) {
        self.features.push(feature);
        // Rebuild so the cached `info` (size, min_ts_length, requirement flags) stays in sync
        *self = Self::new(std::mem::take(&mut self.features));
    }
}
impl<T> FeatureExtractor<T, Feature<T>>
where
T: Float,
{
pub fn from_features(features: Vec<Feature<T>>) -> Self {
Self::new(features)
}
}
impl<T, F> FeatureExtractor<T, F> {
pub fn doc() -> &'static str {
DOC
}
}
impl<T, F> EvaluatorInfoTrait for FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
fn get_info(&self) -> &EvaluatorInfo {
&self.info
}
}
impl<T, F> FeatureNamesDescriptionsTrait for FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
fn get_names(&self) -> Vec<&str> {
self.features.iter().flat_map(|x| x.get_names()).collect()
}
fn get_descriptions(&self) -> Vec<&str> {
self.features
.iter()
.flat_map(|x| x.get_descriptions())
.collect()
}
}
impl<T, F> FeatureEvaluator<T> for FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
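    /// Evaluate every feature in order and concatenate the results, so the
    /// output vector is aligned with `get_names()` and `get_descriptions()`.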
fn eval(&self, ts: &mut TimeSeries<T>) -> Result<Vec<T>, EvaluatorError> {
let mut vec = Vec::with_capacity(self.size_hint());
for x in &self.features {
vec.extend(x.eval(ts)?);
}
Ok(vec)
}
fn eval_or_fill(&self, ts: &mut TimeSeries<T>, fill_value: T) -> Vec<T> {
self.features
.iter()
.flat_map(|x| x.eval_or_fill(ts, fill_value))
.collect()
}
}
#[cfg(test)]
impl<T, F> Default for FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
fn default() -> Self {
Self::new(vec![])
}
}
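/// Serialization proxy: `FeatureExtractor` is (de)serialized through this type
/// (see the `into`/`from` serde attributes above), and its JSON schema is
/// delegated to it as well.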
#[derive(Serialize, Deserialize, JsonSchema)]
#[serde(rename = "FeatureExtractor")]
struct FeatureExtractorParameters<F> {
features: Vec<F>,
}
impl<T, F> From<FeatureExtractor<T, F>> for FeatureExtractorParameters<F> {
fn from(f: FeatureExtractor<T, F>) -> Self {
Self {
features: f.features,
}
}
}
impl<T, F> From<FeatureExtractorParameters<F>> for FeatureExtractor<T, F>
where
T: Float,
F: FeatureEvaluator<T>,
{
fn from(p: FeatureExtractorParameters<F>) -> Self {
Self::new(p.features)
}
}
impl<T, F> JsonSchema for FeatureExtractor<T, F>
where
F: JsonSchema,
{
json_schema!(FeatureExtractorParameters<F>, true);
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::*;
use crate::Feature;
use serde_test::{assert_ser_tokens, Token};
serialization_name_test!(FeatureExtractor<f64, Feature<f64>>);
serde_json_test!(
feature_extractor_ser_json_de,
FeatureExtractor<f64, Feature<f64>>,
FeatureExtractor::new(vec![crate::Amplitude{}.into(), crate::BeyondNStd::new(2.0).into()]),
);
check_doc_static_method!(feature_extractor_doc_static_method, FeatureExtractor<f64, Feature<f64>>);
#[test]
fn serialization_empty() {
let fe: FeatureExtractor<f64, Feature<_>> = FeatureExtractor::new(vec![]);
assert_ser_tokens(
&fe,
&[
Token::Struct {
len: 1,
name: "FeatureExtractor",
},
Token::String("features"),
Token::Seq { len: Some(0) },
Token::SeqEnd,
Token::StructEnd,
],
)
}
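    // A minimal usage sketch: build an extractor from two features and check that
    // `eval` concatenates their outputs in the same order as `get_names`. It assumes
    // `TimeSeries::new_without_weight` accepts plain slices, as elsewhere in the crate.
    #[test]
    fn eval_concatenates_feature_outputs() {
        let fe: FeatureExtractor<f64, Feature<f64>> = FeatureExtractor::new(vec![
            crate::Amplitude {}.into(),
            crate::BeyondNStd::new(2.0).into(),
        ]);
        let t = [0.0_f64, 1.0, 2.0, 3.0, 4.0];
        let m = [1.0_f64, 5.0, 2.0, 4.0, 3.0];
        let mut ts = TimeSeries::new_without_weight(&t[..], &m[..]);
        let values = fe.eval(&mut ts).unwrap();
        assert_eq!(values.len(), fe.size_hint());
        assert_eq!(values.len(), fe.get_names().len());
        // With no failing features, `eval_or_fill` matches `eval`
        assert_eq!(fe.eval_or_fill(&mut ts, 0.0), values);
    }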
}