use crate::core::{Error, PeriodType, ValueType, Window};
use crate::core::{Method, MovingAverage};
use crate::helpers::Peekable;
use std::{cmp::Ordering, slice::SliceIndex};
#[cfg(feature = "serde")]
use serde::{ser::SerializeStruct, Deserialize, Deserializer, Serialize, Serializer};
/// Returns the element (for a `usize` index) or sub-slice (for a range) of `slice` selected
/// by `index`, WITHOUT bounds checking.
///
/// Compiled only with the `unsafe_performance` feature; otherwise the checked function of the
/// same name is used instead.
///
/// NOTE(review): this is a safe `fn` around `get_unchecked`, so an out-of-bounds `index` is
/// undefined behaviour rather than a panic. Every caller must guarantee `index` is in bounds.
#[inline]
#[cfg(feature = "unsafe_performance")]
#[allow(unsafe_code)]
fn get<T>(slice: &[ValueType], index: T) -> &T::Output
where
T: SliceIndex<[ValueType]>,
{
// SAFETY: relies entirely on the caller passing an in-bounds `index`; nothing here checks it.
unsafe { slice.get_unchecked(index) }
}
/// Bounds-checked access into `slice`, used when the `unsafe_performance` feature is off.
///
/// `index` may be a single position (yielding one element) or any range (yielding a
/// sub-slice). An out-of-bounds `index` panics exactly as `slice[index]` would.
#[inline]
#[cfg(not(feature = "unsafe_performance"))]
fn get<T>(slice: &[ValueType], index: T) -> &T::Output
where
    T: SliceIndex<[ValueType]>,
{
    std::ops::Index::index(slice, index)
}
/// Performs one bisection step, shared by `find_index` and `find_insert_index`.
///
/// `value` is compared with the middle element of the non-empty, ascending `slice`.
/// - A bit-identical match returns that element's position, offset by `padding`.
/// - A strictly greater `value` continues through `f` on the upper half.
/// - Anything else continues on the lower half.
///
/// `padding` is advanced as needed so that the final result is a position in the
/// original, unsplit slice.
#[inline]
fn next_half(
    value: ValueType,
    slice: &[ValueType],
    padding: usize,
    f: fn(value: ValueType, slice: &[ValueType], padding: usize) -> usize,
) -> usize {
    let mid = slice.len() / 2;
    let pivot = *get(slice, mid);

    if value.to_bits() == pivot.to_bits() {
        return padding + mid;
    }

    if value > pivot {
        let upper_start = mid + 1;
        f(value, get(slice, upper_start..), padding + upper_start)
    } else {
        f(value, get(slice, ..mid), padding)
    }
}
/// Returns the position of `value` within the ascending `slice`, offset by `padding`.
///
/// `value` is expected to be present, because it is the value just evicted from the window.
/// A one-element slice is therefore accepted without comparison. An empty slice yields
/// `padding + 1`.
#[inline]
fn find_index(value: ValueType, slice: &[ValueType], padding: usize) -> usize {
    match slice.len() {
        0 => padding + 1,
        1 => padding,
        _ => next_half(value, slice, padding, find_index),
    }
}
/// Returns the position (in `0..=slice.len()`, offset by `padding`) at which `value` keeps
/// the ascending `slice` sorted.
///
/// If the search meets an element bit-identical to `value`, that element's position is
/// returned.
#[inline]
fn find_insert_index(value: ValueType, slice: &[ValueType], padding: usize) -> usize {
    match slice {
        [] => padding,
        _ => next_half(value, slice, padding, find_insert_index),
    }
}
/// Simple moving median over the last `length` input values.
///
/// The output is the median of the window:
/// - for odd lengths, the middle element;
/// - for even lengths, the mean of the two middle elements.
///
/// A copy of the window is kept sorted, so each update only needs a binary search plus a
/// shift of the elements in between.
#[derive(Debug, Clone)]
pub struct SMM {
/// Index of the upper middle element in `slice` (`length / 2`).
half: PeriodType,
/// Index of the lower middle element: equal to `half` for odd lengths, `half - 1` for even ones.
half_m1: PeriodType,
/// The most recent `length` inputs; `push` hands back the value that drops out.
window: Window<ValueType>,
/// The same values as `window`, kept sorted in ascending order.
slice: Box<[ValueType]>,
}
impl SMM {
#[inline]
#[must_use]
pub const fn get_window(&self) -> &Window<ValueType> {
&self.window
}
#[inline]
#[must_use]
#[deprecated(since = "0.5.1", note = "Use `Peekable::peek` instead")]
pub fn get_last_value(&self) -> ValueType {
self.peek()
}
}
impl Method for SMM {
    type Params = PeriodType;
    type Input = ValueType;
    type Output = Self::Input;

    /// Creates a moving median of `length` values, with the whole window pre-filled by `value`.
    ///
    /// # Errors
    ///
    /// - Returns [`Error::InvalidCandles`] if `value` is not finite.
    /// - Returns [`Error::WrongMethodParameters`] if `length` is zero.
    fn new(length: Self::Params, &value: &Self::Input) -> Result<Self, Error> {
        if !value.is_finite() {
            return Err(Error::InvalidCandles);
        }

        match length {
            0 => Err(Error::WrongMethodParameters),
            length => {
                let half = length / 2;
                let is_even = length % 2 == 0;

                Ok(Self {
                    half,
                    // For even lengths the median averages two middle elements; for odd ones
                    // both indices point at the same element.
                    half_m1: half.saturating_sub(is_even as PeriodType),
                    window: Window::new(length, value),
                    slice: vec![value; length as usize].into(),
                })
            }
        }
    }

    /// Pushes `value` into the window and returns the updated median.
    ///
    /// Steps:
    /// 1. Binary-search the sorted buffer for the evicted value.
    /// 2. Find where `value` belongs.
    /// 3. Shift the elements in between by one slot.
    /// 4. Write `value` into the freed position.
    ///
    /// Cost: O(log n) search plus O(n) shift.
    ///
    /// # Panics
    ///
    /// Panics if `value` is not finite.
    #[inline]
    fn next(&mut self, &value: &Self::Input) -> Self::Output {
        assert!(
            value.is_finite(),
            "SMM method cannot operate with NAN values"
        );

        let old_value = self.window.push(value);
        // `slice` holds the same values as the window did, so `old_value` is present in it.
        let old_index = find_index(old_value, &self.slice, 0);
        let insert_at = find_insert_index(value, &self.slice, 0);
        // The insertion point was computed while `old_value` was still present. Removing it
        // shifts every later position one step left.
        let index = if old_index < insert_at {
            insert_at - 1
        } else {
            insert_at
        };

        if cfg!(feature = "unsafe_performance") {
            if index != old_index {
                // Elements between the two positions move one step towards `old_index`,
                // overwriting the evicted value and opening a gap at `index`.
                let (start, dest, count) = if index > old_index {
                    (old_index + 1, old_index, index - old_index)
                } else {
                    (index, index + 1, old_index - index)
                };
                // Derive source and destination from ONE mutable pointer. Taking `as_ptr()`
                // and then `as_mut_ptr()` creates a fresh unique borrow, which invalidates the
                // first pointer before `copy` reads through it.
                let base = self.slice.as_mut_ptr();
                #[allow(unsafe_code)]
                // SAFETY: `old_index < len` (it locates a value present in the slice), and
                // `index < len` (an insertion point of `len` is reduced by one above).
                // Hence `start + count <= len` and `dest + count <= len`: both ranges lie
                // inside the slice, and `ptr::copy` allows them to overlap.
                unsafe {
                    std::ptr::copy(base.add(start), base.add(dest), count);
                }
            }
            #[allow(unsafe_code)]
            // SAFETY: `index < self.slice.len()`, as argued above.
            unsafe {
                *self.slice.get_unchecked_mut(index) = value;
            }
        } else {
            match index.cmp(&old_index) {
                Ordering::Greater => self.slice.copy_within((old_index + 1)..=index, old_index),
                Ordering::Less => self.slice.copy_within(index..old_index, index + 1),
                Ordering::Equal => {}
            };
            self.slice[index] = value;
        }

        self.peek()
    }
}
#[cfg(feature = "serde")]
impl Serialize for SMM {
    /// Serializes only the window. The sorted buffer and the median indices are derived data,
    /// and deserialization rebuilds them from the window.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        const FIELD_COUNT: usize = 1;

        let mut state = serializer.serialize_struct("SMM", FIELD_COUNT)?;
        state.serialize_field("window", &self.window)?;
        SerializeStruct::end(state)
    }
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for SMM {
    /// Rebuilds an `SMM` from its serialized window, re-deriving the sorted buffer and the
    /// median indices.
    ///
    /// # Errors
    ///
    /// Fails if the window is empty, or if it contains a NaN or infinite value. Those are the
    /// same inputs that `Method::new` and `Method::next` reject.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct DeserializedSMM {
            window: Window<ValueType>,
        }

        let window = DeserializedSMM::deserialize(deserializer)?.window;

        if window.is_empty() {
            return Err(serde::de::Error::custom("SMM must have non-zero length."));
        }

        // Validate explicitly instead of relying on the sort comparator. A one-element window
        // is never compared, so a lone NaN used to slip through. Infinities compare normally
        // but are rejected everywhere else in this method.
        if !window.as_slice().iter().all(|v| v.is_finite()) {
            return Err(serde::de::Error::custom(
                "SMM cannot operate NaN or infinite values",
            ));
        }

        let mut slice: Box<[ValueType]> = window.as_slice().into();
        // Every value is finite, so `partial_cmp` always returns `Some`.
        slice.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));

        let half = window.len() / 2;
        let is_even = window.len() % 2 == 0;

        Ok(Self {
            half,
            half_m1: half.saturating_sub(is_even as PeriodType),
            window,
            slice,
        })
    }
}
// `SMM` implements `MovingAverage` without overriding anything; any provided behaviour comes
// from the trait's own definition, which is not visible in this file.
impl MovingAverage for SMM {}
impl Peekable<<Self as Method>::Output> for SMM {
    /// Returns the current median without pushing a new value.
    ///
    /// - Odd lengths: `half == half_m1`, and the middle element is returned as is.
    /// - Even lengths: the two middle elements of the sorted buffer are averaged.
    fn peek(&self) -> <Self as Method>::Output {
        let upper = *get(&self.slice, self.half as usize);
        if self.half == self.half_m1 {
            // Returning the element directly avoids `x + x` overflowing to infinity when `x`
            // is larger than half the type's maximum.
            return upper;
        }

        let lower = *get(&self.slice, self.half_m1 as usize);
        let sum = upper + lower;
        if sum.is_infinite() {
            // The inputs are finite but their sum overflowed; halving first cannot overflow.
            lower * 0.5 + upper * 0.5
        } else {
            // Exact same result as before for every non-overflowing pair.
            sum * 0.5
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{Method, SMM as TestingMethod};
    use crate::core::ValueType;
    use crate::helpers::{assert_eq_float, RandomCandles};
    use crate::methods::tests::test_const;

    /// A constant input must keep producing that same constant, for every supported length.
    #[test]
    fn test_smm_const() {
        for length in 1..255 {
            let input = (length as ValueType + 56.0) / 16.3251;
            let mut method = TestingMethod::new(length, &input).unwrap();
            let output = method.next(&input);
            test_const(&mut method, &input, &output);
        }
    }

    /// With a length of one, the median is always the most recent value.
    #[test]
    fn test_smm1() {
        let mut candles = RandomCandles::default();
        let first_close = candles.first().close;
        let mut ma = TestingMethod::new(1, &first_close).unwrap();

        for candle in candles.take(100) {
            assert_eq_float(candle.close, ma.next(&candle.close));
        }
    }

    /// Checks the incremental sorted buffer and the median against a brute-force
    /// recomputation over a sliding window of random closes, for odd and even lengths.
    #[test]
    fn test_smm() {
        let candles = RandomCandles::default();
        let src: Vec<ValueType> = candles.take(3000).map(|x| x.close).collect();

        for &period in &[1, 2, 3, 5, 11, 23, 51, 100, 150, 203, 254] {
            let mut ma = TestingMethod::new(period, &src[0]).unwrap();
            let length = period as usize;

            for (i, x) in src.iter().enumerate() {
                let value = ma.next(x);

                // The method starts with a window full of `src[0]`, so missing history is
                // padded with that value.
                let start = i.saturating_sub(length - 1);
                let mut expected = src[start..=i].to_vec();
                expected.resize(length, src[0]);
                expected.sort_by(|a, b| a.partial_cmp(b).unwrap());

                assert_eq!(expected.len(), ma.slice.len());
                for (&a, &b) in expected.iter().zip(ma.slice.iter()) {
                    assert_eq!(a.to_bits(), b.to_bits());
                }

                let median = if length % 2 == 0 {
                    (expected[length / 2] + expected[length / 2 - 1]) / 2.0
                } else {
                    expected[length / 2]
                };
                assert_eq_float(median, value);
            }
        }
    }
}