use std::collections::VecDeque;
use num_traits::Float;
use crate::Node;
use crate::error::RoplatError;
/// Sliding-window arithmetic mean over the most recent `window_size` samples.
///
/// Until the window has filled, the mean is taken over however many samples
/// have been seen so far.
pub struct MovingAverage<T>
where
    T: Float + Send + Sync,
{
    /// Maximum number of samples kept in the window (checked to be non-zero in `new`).
    window_size: usize,
    /// Most recent samples, oldest at the front.
    buffer: VecDeque<T>,
}

impl<T> MovingAverage<T>
where
    T: Float + Send + Sync,
{
    /// Creates an empty moving average with the given window length.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is zero.
    pub fn new(window_size: usize) -> Self {
        assert!(window_size != 0, "窗口大小必须大于0");
        let buffer = VecDeque::with_capacity(window_size);
        Self { window_size, buffer }
    }
}
impl<T> Node for MovingAverage<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Adds `input` to the window and returns the mean of the samples currently held.
    ///
    /// When the window is full, the oldest sample is evicted first. The sum is
    /// recomputed from the buffer on every call, so a NaN sample affects the
    /// output only while it remains inside the window.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        let window_full = self.buffer.len() >= self.window_size;
        if window_full {
            self.buffer.pop_front();
        }
        self.buffer.push_back(input);

        // Accumulate oldest-to-newest.
        let mut total = T::zero();
        for &sample in &self.buffer {
            total = total + sample;
        }

        let samples_held = self.buffer.len();
        let divisor = T::from(samples_held).unwrap();
        Ok(total / divisor)
    }
}
pub struct ExponentialMovingAverage<T>
where
T: Float + Send + Sync,
{
alpha: T,
last_output: Option<T>,
}
impl<T> ExponentialMovingAverage<T>
where
T: Float + Send + Sync,
{
pub fn new(alpha: T) -> Self {
if alpha < T::zero() || alpha > T::one() {
panic!("alpha 必须在 [0, 1] 范围内");
}
Self { alpha, last_output: None }
}
pub fn from_window_size(n: usize) -> Self
where
T: Float + num_traits::NumCast,
{
let n_float = T::from(n).unwrap();
let two = T::from(2usize).unwrap();
let alpha = two / (n_float + T::one());
Self::new(alpha)
}
}
impl<T> Node for ExponentialMovingAverage<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Blends `input` into the running average and returns the new average.
    ///
    /// The very first sample is returned as-is and seeds the average.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        let alpha = self.alpha;
        let keep = T::one() - alpha;
        let blended = self
            .last_output
            .map_or(input, |previous| alpha * input + keep * previous);
        self.last_output = Some(blended);
        Ok(blended)
    }
}
/// Sliding-window median over the most recent `window_size` samples.
pub struct MedianFilter<T>
where
    T: Float + Send + Sync,
{
    /// Maximum number of samples kept in the window (checked to be odd and non-zero in `new`).
    window_size: usize,
    /// Most recent samples, oldest at the front.
    buffer: VecDeque<T>,
}

impl<T> MedianFilter<T>
where
    T: Float + Send + Sync,
{
    /// Creates an empty median filter with the given window length.
    ///
    /// # Panics
    ///
    /// Panics if `window_size` is zero or even.
    pub fn new(window_size: usize) -> Self {
        assert!(window_size > 0, "窗口大小必须大于0");
        // An odd window has a single middle element once it is full.
        assert!(!window_size.is_multiple_of(2), "窗口大小应该是奇数,以便计算中值");
        Self {
            buffer: VecDeque::with_capacity(window_size),
            window_size,
        }
    }
}
impl<T> Node for MedianFilter<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Adds `input` to the window and returns the median of the samples currently held.
    ///
    /// When the window is full, the oldest sample is evicted first. While the
    /// window is still filling it may hold an even number of samples; the
    /// element at index `len / 2` of the ordered samples (the upper of the two
    /// middle values) is returned in that case.
    ///
    /// NaN samples are ordered after every number, so a NaN acts like a large
    /// outlier instead of making the ordering inconsistent.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        if self.buffer.len() >= self.window_size {
            self.buffer.pop_front();
        }
        self.buffer.push_back(input);

        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN is
        // present: NaN would compare "equal" to both 1.0 and 2.0 while
        // 1.0 < 2.0. The selected element is then unspecified, and the std
        // sort/select routines may panic on such comparators. Placing NaN after
        // all numbers restores a total order.
        let nan_last = |a: &T, b: &T| match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal),
            (false, true) => std::cmp::Ordering::Less,
            (true, false) => std::cmp::Ordering::Greater,
            (true, true) => std::cmp::Ordering::Equal,
        };

        let mut samples: Vec<T> = self.buffer.iter().copied().collect();
        // `samples` is non-empty (we just pushed), so `mid < samples.len()`.
        let mid = samples.len() / 2;
        // Only the element at `mid` must land in sorted position: O(n) instead of a full sort.
        let (_, median, _) = samples.select_nth_unstable_by(mid, nan_last);
        Ok(*median)
    }
}
/// Slew-rate limiter: each output differs from the previous output by at most
/// `max_change`.
///
/// The first sample is passed through unchanged.
pub struct RateLimiter<T>
where
    T: Float + Send + Sync,
{
    /// Largest allowed per-sample step; validated to be non-negative.
    max_change: T,
    /// Previous output; `None` until the first sample has been processed.
    last_output: Option<T>,
}

impl<T> RateLimiter<T>
where
    T: Float + Send + Sync,
{
    /// Creates a rate limiter allowing steps of at most `max_change` per sample.
    ///
    /// With `max_change = 0`, finite inputs never move the output away from
    /// the first sample. With `max_change = +inf`, every finite step passes
    /// through unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `max_change` is negative or NaN. A negative bound would make
    /// the limiter step *away* from the input. A NaN bound would make every
    /// comparison false and silently disable limiting.
    pub fn new(max_change: T) -> Self {
        // `>=` is false for NaN, so this rejects NaN as well as negative values.
        assert!(max_change >= T::zero(), "最大变化量必须大于等于0");
        Self { max_change, last_output: None }
    }
}
impl<T> Node for RateLimiter<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Moves the output towards `input` by at most `max_change` and returns it.
    ///
    /// The first sample is passed through unchanged. If the step is NaN (for
    /// example because `input` is NaN), neither bound comparison succeeds and
    /// `input` is returned as-is.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        let Some(previous) = self.last_output else {
            // Nothing to limit against yet.
            self.last_output = Some(input);
            return Ok(input);
        };

        let step = input - previous;
        let limit = self.max_change;
        let limited = if step > limit {
            previous + limit
        } else if step < -limit {
            previous - limit
        } else {
            input
        };

        self.last_output = Some(limited);
        Ok(limited)
    }
}
/// Dead-zone filter: keeps reporting the previous output until the input
/// differs from it by at least `threshold`.
///
/// The first sample is passed through unchanged.
pub struct DeadzoneFilter<T>
where
    T: Float + Send + Sync,
{
    /// Minimum absolute change required for a new input to be accepted.
    threshold: T,
    /// Previous output; `None` until the first sample has been processed.
    last_output: Option<T>,
}

impl<T> DeadzoneFilter<T>
where
    T: Float + Send + Sync,
{
    /// Creates a dead-zone filter with the given threshold.
    ///
    /// The threshold is not validated. A value `<= 0` (or NaN) means no input
    /// is ever held back.
    pub fn new(threshold: T) -> Self {
        let last_output = None;
        Self { threshold, last_output }
    }
}
impl<T> Node for DeadzoneFilter<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Returns the held value if `input` is within `threshold` of it;
    /// otherwise accepts and returns `input`.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        let accepted = match self.last_output {
            // A small move stays inside the dead zone: keep the held value.
            Some(held) if (input - held).abs() < self.threshold => held,
            // First sample, or the input left the dead zone (or the difference is NaN).
            _ => input,
        };
        self.last_output = Some(accepted);
        Ok(accepted)
    }
}
/// First-order discrete low-pass filter:
/// `y[k] = y[k-1] + (x[k] - y[k-1]) / (time_constant + 1)`.
///
/// The update does not use a sampling interval, so `time_constant` acts as a
/// per-sample smoothing parameter. Mathematically this is an EMA with
/// `alpha = 1 / (time_constant + 1)`. The first sample is passed through unchanged.
pub struct LowPassFilter<T>
where
    T: Float + Send + Sync,
{
    /// Smoothing parameter; validated to be non-negative (0 means no smoothing).
    time_constant: T,
    /// Previous output; `None` until the first sample has been processed.
    last_output: Option<T>,
}

impl<T> LowPassFilter<T>
where
    T: Float + Send + Sync,
{
    /// Creates a low-pass filter with the given time constant.
    ///
    /// # Panics
    ///
    /// Panics if `time_constant` is negative or NaN.
    pub fn new(time_constant: T) -> Self {
        // `>=` is false for NaN, so NaN is rejected too. The previous
        // `time_constant < 0` test accepted NaN, which made every output after
        // the first sample NaN.
        assert!(time_constant >= T::zero(), "时间常数必须大于等于0");
        Self { time_constant, last_output: None }
    }
}
impl<T> Node for LowPassFilter<T>
where
    T: Float + Send + Sync + 'static,
{
    type Input = T;
    type Output = Result<T, RoplatError>;
    type Error = RoplatError;

    /// Moves the output a fraction `1 / (time_constant + 1)` of the way
    /// towards `input` and returns it.
    ///
    /// The first sample is returned unchanged.
    async fn process(&mut self, input: T) -> Result<T, RoplatError> {
        let filtered = if let Some(previous) = self.last_output {
            let divisor = self.time_constant + T::one();
            let step = (input - previous) / divisor;
            previous + step
        } else {
            input
        };
        self.last_output = Some(filtered);
        Ok(filtered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `|actual - expected| < tol` and reports both values on failure.
    /// A NaN `actual` always fails.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (±{tol}), got {actual}"
        );
    }

    #[tokio::test]
    async fn test_moving_average_basic() {
        let mut filter = MovingAverage::<f64>::new(3);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(20.0).await.unwrap(), 15.0, 1e-10);
        assert_close(filter.process(30.0).await.unwrap(), 20.0, 1e-10);
        // Window is full: 10.0 is evicted, leaving [20, 30, 40].
        assert_close(filter.process(40.0).await.unwrap(), 30.0, 1e-10);
    }

    #[tokio::test]
    async fn test_moving_average_smoothing() {
        let mut filter = MovingAverage::<f64>::new(5);
        let noisy_data = [10.0, 12.0, 8.0, 11.0, 9.0, 10.0, 10.0, 10.0];
        for value in noisy_data {
            filter.process(value).await.unwrap();
        }
        let result = filter.process(10.0).await.unwrap();
        assert_close(result, 10.0, 1.0);
    }

    #[test]
    #[should_panic(expected = "窗口大小必须大于0")]
    fn test_moving_average_zero_window() {
        MovingAverage::<f64>::new(0);
    }

    #[tokio::test]
    async fn test_exponential_moving_average_basic() {
        let mut filter = ExponentialMovingAverage::<f64>::new(0.5);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(20.0).await.unwrap(), 15.0, 1e-10);
        assert_close(filter.process(30.0).await.unwrap(), 22.5, 1e-10);
    }

    #[tokio::test]
    async fn test_exponential_moving_average_from_window() {
        // n = 9 gives alpha = 2 / (9 + 1) = 0.2.
        let mut filter = ExponentialMovingAverage::<f64>::from_window_size(9);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(20.0).await.unwrap(), 12.0, 0.1);
        assert_close(filter.process(30.0).await.unwrap(), 15.6, 0.1);
    }

    #[test]
    #[should_panic(expected = "alpha 必须在 [0, 1] 范围内")]
    fn test_exponential_moving_average_invalid_alpha_low() {
        ExponentialMovingAverage::<f64>::new(-0.1);
    }

    #[test]
    #[should_panic(expected = "alpha 必须在 [0, 1] 范围内")]
    fn test_exponential_moving_average_invalid_alpha_high() {
        ExponentialMovingAverage::<f64>::new(1.5);
    }

    #[tokio::test]
    async fn test_exponential_moving_average_boundary() {
        // alpha = 0 keeps reporting the first sample.
        let mut filter = ExponentialMovingAverage::<f64>::new(0.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(100.0).await.unwrap(), 10.0, 1e-10);

        // alpha = 1 passes every sample straight through.
        let mut filter = ExponentialMovingAverage::<f64>::new(1.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(100.0).await.unwrap(), 100.0, 1e-10);
    }

    #[tokio::test]
    async fn test_median_filter_basic() {
        let mut filter = MedianFilter::<f64>::new(3);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        // Two samples held: the upper middle value is reported.
        assert_close(filter.process(20.0).await.unwrap(), 20.0, 1e-10);
        assert_close(filter.process(30.0).await.unwrap(), 20.0, 1e-10);
        assert_close(filter.process(100.0).await.unwrap(), 30.0, 1e-10);
    }

    #[tokio::test]
    async fn test_median_filter_removes_spikes() {
        let mut filter = MedianFilter::<f64>::new(5);
        for _ in 0..3 {
            filter.process(10.0).await.unwrap();
        }
        let result = filter.process(1000.0).await.unwrap();
        assert!(result < 20.0, "positive spike leaked through: {result}");
        let result = filter.process(-1000.0).await.unwrap();
        assert!(result > 0.0, "negative spike leaked through: {result}");
    }

    #[test]
    #[should_panic(expected = "窗口大小必须大于0")]
    fn test_median_filter_zero_window() {
        MedianFilter::<f64>::new(0);
    }

    #[test]
    #[should_panic(expected = "窗口大小应该是奇数")]
    fn test_median_filter_even_window() {
        MedianFilter::<f64>::new(4);
    }

    #[tokio::test]
    async fn test_rate_limiter_basic() {
        let mut filter = RateLimiter::<f64>::new(5.0);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(25.0).await.unwrap(), 15.0, 1e-10);
        assert_close(filter.process(0.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(13.0).await.unwrap(), 13.0, 1e-10);
    }

    #[tokio::test]
    async fn test_rate_limiter_zero_max_change() {
        let mut filter = RateLimiter::<f64>::new(0.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(100.0).await.unwrap(), 10.0, 1e-10);
    }

    #[tokio::test]
    async fn test_deadzone_filter_basic() {
        let mut filter = DeadzoneFilter::<f64>::new(2.0);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(11.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(16.0).await.unwrap(), 16.0, 1e-10);
        assert_close(filter.process(15.0).await.unwrap(), 16.0, 1e-10);
    }

    #[tokio::test]
    async fn test_deadzone_filter_zero_threshold() {
        let mut filter = DeadzoneFilter::<f64>::new(0.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(10.0001).await.unwrap(), 10.0001, 1e-10);
    }

    #[tokio::test]
    async fn test_low_pass_filter_basic() {
        let mut filter = LowPassFilter::<f64>::new(1.0);
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
        assert_close(filter.process(20.0).await.unwrap(), 15.0, 1e-10);
        assert_close(filter.process(30.0).await.unwrap(), 22.5, 1e-10);
    }

    #[tokio::test]
    async fn test_low_pass_filter_large_time_constant() {
        let mut filter = LowPassFilter::<f64>::new(10.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(20.0).await.unwrap(), 11.0, 0.1);
    }

    #[test]
    #[should_panic(expected = "时间常数必须大于等于0")]
    fn test_low_pass_filter_negative_time_constant() {
        LowPassFilter::<f64>::new(-1.0);
    }

    #[tokio::test]
    async fn test_low_pass_filter_zero_time_constant() {
        let mut filter = LowPassFilter::<f64>::new(0.0);
        filter.process(10.0).await.unwrap();
        assert_close(filter.process(20.0).await.unwrap(), 20.0, 1e-10);
    }

    #[tokio::test]
    async fn test_filter_chain() {
        let mut moving_avg = MovingAverage::<f64>::new(3);
        let mut rate_limiter = RateLimiter::<f64>::new(2.0);
        let avg_result = moving_avg.process(10.0).await.unwrap();
        let final_result = rate_limiter.process(avg_result + 1.0).await.unwrap();
        assert_close(final_result, 11.0, 1e-10);
    }

    #[tokio::test]
    async fn test_filters_with_nan() {
        let mut filter = MovingAverage::<f64>::new(3);
        filter.process(f64::NAN).await.unwrap();
        let result = filter.process(10.0).await.unwrap();
        assert!(result.is_nan(), "expected NaN while NaN is in the window, got {result}");
    }

    #[tokio::test]
    async fn test_filters_with_infinity() {
        let mut filter = RateLimiter::<f64>::new(5.0);
        filter.process(10.0).await.unwrap();
        // An infinite jump is still clamped to a single `max_change` step.
        let result = filter.process(f64::INFINITY).await.unwrap();
        assert_close(result, 15.0, 1e-10);
    }

    #[tokio::test]
    async fn test_filters_with_negative_numbers() {
        let mut filter = MovingAverage::<f64>::new(3);
        filter.process(-10.0).await.unwrap();
        filter.process(-20.0).await.unwrap();
        assert_close(filter.process(-30.0).await.unwrap(), -20.0, 1e-10);
    }

    #[tokio::test]
    async fn test_filters_stability() {
        let mut filter = ExponentialMovingAverage::<f64>::new(0.1);
        filter.process(100.0).await.unwrap();
        for _ in 0..100 {
            filter.process(50.0).await.unwrap();
        }
        assert_close(filter.process(50.0).await.unwrap(), 50.0, 1.0);
    }

    #[tokio::test]
    async fn test_median_filter_with_duplicates() {
        let mut filter = MedianFilter::<f64>::new(5);
        for _ in 0..5 {
            filter.process(10.0).await.unwrap();
        }
        assert_close(filter.process(10.0).await.unwrap(), 10.0, 1e-10);
    }
}