// s2gpp 1.0.2
//
// Algorithm for Highly Efficient Detection of Correlation Anomalies in Multivariate Time Series
// Documentation
use crate::s2gpp as orig_s2gpp;
use crate::training::Clustering;
use crate::{Parameters, Role};
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2};
use pyo3::exceptions;
use pyo3::prelude::*;
use std::panic;
use std::str::FromStr;

/// Runs Series2Graph++ on an in-memory 2-D `f32` array and returns the anomaly scores
/// as a 1-D NumPy array.
///
/// `data` is copied into an owned array before the algorithm is invoked. Every other
/// argument is stored in the `Parameters` field of the same name; `clustering` is
/// parsed with `Clustering::from_str` first.
///
/// # Errors
///
/// * `ValueError` if `clustering` cannot be parsed.
/// * `RuntimeError` if the algorithm reports an error.
/// * `TypeError` if the algorithm finishes without producing scores.
#[pyfunction]
fn s2gpp_local_array<'py>(
    py: Python<'py>,
    data: PyReadonlyArray2<'py, f32>,
    pattern_length: usize,
    latent: usize,
    query_length: usize,
    rate: usize,
    n_threads: usize,
    clustering: String,
    self_correction: bool,
) -> PyResult<&'py PyArray1<f32>> {
    let mut params = Parameters::default();
    params.pattern_length = pattern_length;
    params.latent = latent;
    params.query_length = query_length;
    params.rate = rate;
    params.n_threads = n_threads;
    // `clustering` comes straight from the Python caller, so a bad value is reported
    // as an ordinary exception rather than a panic.
    params.clustering = Clustering::from_str(&clustering).map_err(|err| {
        exceptions::PyValueError::new_err(format!(
            "invalid clustering method {:?}: {:?}",
            clustering, err
        ))
    })?;
    params.self_correction = self_correction;

    let data = data.as_array().to_owned();
    let anomaly_scores = orig_s2gpp(params, Some(data)).map_err(|err| {
        exceptions::PyRuntimeError::new_err(format!(
            "Series2Graph++ did not terminate correctly: {:?}",
            err
        ))
    })?;

    match anomaly_scores {
        Some(res) => Ok(res.into_pyarray(py)),
        // The exception type is kept for backward compatibility; only the message changed.
        None => Err(exceptions::PyTypeError::new_err(
            "Series2Graph++ finished without returning anomaly scores",
        )),
    }
}

/// Runs Series2Graph++ in the `Role::Main` role on the data set at `data_path`.
///
/// All remaining arguments are stored in the `Parameters` field of the same name;
/// `clustering` is parsed with `Clustering::from_str` and `local_host` with
/// `str::parse`. The scores themselves are not returned to Python.
///
/// # Errors
///
/// Raises `TypeError` when `clustering` or `local_host` cannot be parsed, when the
/// algorithm reports an error, or when it panics. The exception message names the
/// cause.
#[pyfunction]
fn s2gpp_local_file<'py>(
    _py: Python<'py>,
    data_path: String,
    pattern_length: usize,
    latent: usize,
    query_length: usize,
    rate: usize,
    n_threads: usize,
    score_output_path: Option<String>,
    column_start: usize,
    column_end: isize,
    clustering: String,
    self_correction: bool,
    local_host: String,
) -> PyResult<()> {
    // Expected failures travel out of the closure as `Err(String)`; `catch_unwind`
    // remains as a safety net for panics raised further down the call stack.
    let outcome = panic::catch_unwind(|| -> Result<(), String> {
        let mut params = Parameters::default();

        params.role = Role::Main {
            data_path: Some(data_path),
        };
        params.pattern_length = pattern_length;
        params.latent = latent;
        params.query_length = query_length;
        params.rate = rate;
        params.n_threads = n_threads;
        params.score_output_path = score_output_path;
        params.column_start = column_start;
        params.column_end = column_end;
        params.clustering = Clustering::from_str(&clustering).map_err(|err| {
            format!("invalid clustering method {:?}: {:?}", clustering, err)
        })?;
        params.self_correction = self_correction;
        params.local_host = local_host
            .parse()
            .map_err(|err| format!("invalid local_host {:?}: {:?}", local_host, err))?;

        orig_s2gpp(params, None)
            .map(|_| ())
            .map_err(|err| format!("Series2Graph++ did not terminate correctly: {:?}", err))
    });

    // `TypeError` is kept for every failure path so existing `except TypeError`
    // handlers keep working; only the message is now specific.
    match outcome {
        Ok(Ok(())) => Ok(()),
        Ok(Err(message)) => Err(exceptions::PyTypeError::new_err(message)),
        Err(payload) => {
            // Panic payloads produced by `panic!` are either `&'static str` or `String`.
            let message = payload
                .downcast_ref::<&str>()
                .map(|text| (*text).to_owned())
                .or_else(|| payload.downcast_ref::<String>().cloned())
                .unwrap_or_else(|| String::from("Series2Graph++ panicked"));
            Err(exceptions::PyTypeError::new_err(message))
        }
    }
}

#[pyfunction]
fn s2gpp_distributed_main<'py>(
    _py: Python<'py>,
    data_path: String,
    pattern_length: usize,
    latent: usize,
    query_length: usize,
    rate: usize,
    n_threads: usize,
    score_output_path: Option<String>,
    column_start: usize,
    column_end: isize,
    clustering: String,
    self_correction: bool,
    local_host: String,
    n_cluster_nodes: usize,
) -> PyResult<()> {
    let mut params = Parameters::default();

    params.role = Role::Main {
        data_path: Some(data_path),
    };
    params.pattern_length = pattern_length;
    params.latent = latent;
    params.query_length = query_length;
    params.rate = rate;
    params.n_threads = n_threads;
    params.score_output_path = score_output_path;
    params.column_start = column_start;
    params.column_end = column_end;
    params.clustering = Clustering::from_str(&clustering).unwrap();
    params.self_correction = self_correction;
    params.local_host = local_host.parse()?;
    params.n_cluster_nodes = n_cluster_nodes;

    orig_s2gpp(params, None).expect("Series2Graph++ did not terminate correctly!");
    Ok(())
}

#[pyfunction]
fn s2gpp_distributed_sub<'py>(
    _py: Python<'py>,
    pattern_length: usize,
    latent: usize,
    query_length: usize,
    rate: usize,
    n_threads: usize,
    score_output_path: Option<String>,
    column_start: usize,
    column_end: isize,
    clustering: String,
    self_correction: bool,
    local_host: String,
    n_cluster_nodes: usize,
    mainhost: String,
) -> PyResult<()> {
    let mut params = Parameters::default();

    params.role = Role::Sub {
        mainhost: mainhost.parse()?,
    };
    params.pattern_length = pattern_length;
    params.latent = latent;
    params.query_length = query_length;
    params.rate = rate;
    params.n_threads = n_threads;
    params.score_output_path = score_output_path;
    params.column_start = column_start;
    params.column_end = column_end;
    params.clustering = Clustering::from_str(&clustering).unwrap();
    params.self_correction = self_correction;
    params.local_host = local_host.parse()?;
    params.n_cluster_nodes = n_cluster_nodes;

    orig_s2gpp(params, None).expect("Series2Graph++ did not terminate correctly!");
    Ok(())
}

// Module initialiser for the `s2gpp` Python extension: wraps each `#[pyfunction]`
// defined in this file and adds it to the module `m`, stopping at the first failure.
// Plain `//` comments are used so the Python-visible module docstring is unchanged.
#[pymodule]
fn s2gpp(_py: Python, m: &PyModule) -> PyResult<()> {
    // Single-machine entry points.
    let local_file = wrap_pyfunction!(s2gpp_local_file, m)?;
    m.add_function(local_file)?;

    let local_array = wrap_pyfunction!(s2gpp_local_array, m)?;
    m.add_function(local_array)?;

    // Entry points that configure a main or sub role.
    let distributed_main = wrap_pyfunction!(s2gpp_distributed_main, m)?;
    m.add_function(distributed_main)?;

    let distributed_sub = wrap_pyfunction!(s2gpp_distributed_sub, m)?;
    m.add_function(distributed_sub)
}