use crate::indicators::{
max::{find_max_scalar, find_max_simd, State as MaxState},
min::{find_min_scalar, find_min_simd, State as MinState},
};
use crate::ring_buffer::buffer::{period_to_idx, BufferElement, SerdeElement};
use serde::{
de::{self, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
use std::{fmt, marker::PhantomData};
/// Fixed-capacity buffer of `N` elements that stores every pushed value twice:
/// once in a circular array (`ring`) and once in a contiguous, oldest-first
/// array (`view`).
///
/// `push` overwrites `ring` at `index` (wrapping to 0 at `N`) and appends to
/// `view`, shifting `view` left by one once `N` elements are stored.
// NOTE(review): `BufferElement` is declared elsewhere; the methods copy and
// default-construct `T`, so it presumably implies `Copy + Default` — confirm.
#[derive(Clone)]
pub struct FixedMirrorBuffer<T: BufferElement, const N: usize> {
// Circular storage; the next write goes to `ring[index]`.
pub(crate) ring: [T; N],
// Contiguous storage; `view[..count]` holds the live elements, oldest first.
pub(crate) view: [T; N],
// Position in `ring` that the next push will overwrite.
pub(crate) index: usize,
// Number of live elements; grows with each push until it reaches `N`.
pub(crate) count: usize,
}
impl<T: BufferElement, const N: usize> FixedMirrorBuffer<T, N> {
/// Creates an empty buffer whose `ring` and `view` arrays are filled with
/// `T::default()` and whose `index` and `count` are both zero.
#[inline]
pub fn new() -> Self {
    let ring = [T::default(); N];
    let view = [T::default(); N];
    Self {
        ring,
        view,
        index: 0,
        count: 0,
    }
}
/// Returns `true` once the number of stored elements equals the capacity `N`.
#[inline(always)]
pub fn is_full(&self) -> bool {
    N == self.count
}
/// Returns `true` while no element has been stored yet.
#[inline(always)]
pub fn is_empty(&self) -> bool {
    0 == self.count
}
/// Returns the number of elements currently stored (the `count` field).
#[inline(always)]
pub fn len(&self) -> usize {
self.count
}
/// Returns the compile-time capacity `N` of the buffer.
#[inline(always)]
pub const fn capacity(&self) -> usize {
N
}
/// Stores `value` as the newest element.
///
/// The value is written to `ring[index]` (with `index` wrapping back to 0
/// when it reaches `N`) and mirrored into `view`: appended while the buffer
/// is still filling, otherwise `view` is shifted left by one slot, dropping
/// its oldest element, and `value` is placed in the last slot.
///
/// # Panics
///
/// Panics when `N == 0`, because the ring write indexes an empty array.
#[inline(always)]
pub fn push(&mut self, value: T) {
    // Circular copy: overwrite the slot at the write cursor, then advance it.
    let slot = self.index;
    self.ring[slot] = value;
    let next = slot + 1;
    self.index = if next == N { 0 } else { next };

    // Contiguous copy: shift when full, append otherwise.
    if self.count >= N {
        self.view.copy_within(1.., 0);
        self.view[N - 1] = value;
    } else {
        let tail = self.count;
        self.view[tail] = value;
        self.count = tail + 1;
    }
}
/// Stores `value` as the newest element and reports what it displaced.
///
/// Returns `Some(oldest)` when the buffer was already full, where `oldest`
/// is the element that was at the front of the view before the push.
/// Returns `None` while the buffer is still filling.
///
/// # Panics
///
/// Panics when `N == 0`, via the bounds-checked [`push`](Self::push).
#[inline(always)]
pub fn push_with_info(&mut self, value: T) -> Option<T> {
    // `N != 0` is a compile-time constant, so the guard costs nothing for
    // real buffers. Without it a zero-capacity buffer (where `count == N`
    // holds trivially) would reach the unchecked accesses below and read
    // `view[0]` of an empty array from safe code.
    if N != 0 && self.count == N {
        // SAFETY: `N != 0`, so `view[0]` and `view[N - 1]` exist. The ring
        // write relies on `index < N`, which `new`, `push` and
        // `push_unchecked` all maintain by wrapping `index` to 0 at `N`.
        Some(unsafe { self.push_with_info_unchecked(value) })
    } else {
        self.push(value);
        None
    }
}
/// Stores `value` as the newest element without any bounds checks, always
/// taking the "buffer is full" path.
///
/// `count` is not modified: the view is shifted left by one slot and `value`
/// is written to its last slot.
///
/// # Safety
///
/// The caller must guarantee that `N > 0` and that `self.index < N`;
/// otherwise the unchecked writes are out of bounds.
#[inline(always)]
pub unsafe fn push_unchecked(&mut self, value: T) {
    // Circular copy: overwrite the slot at the write cursor, then advance it.
    let slot = self.index;
    *self.ring.get_unchecked_mut(slot) = value;
    let next = slot + 1;
    self.index = if next == N { 0 } else { next };

    // Contiguous copy: drop the oldest element and append the new one.
    self.view.copy_within(1.., 0);
    *self.view.get_unchecked_mut(N - 1) = value;
}
/// Unchecked push that also returns the element previously at the front of
/// the view (`view[0]`), which the push shifts out.
///
/// # Safety
///
/// Same contract as [`push_unchecked`](Self::push_unchecked): `N > 0` and
/// `self.index < N` must hold, since `view[0]` is read and the ring/view are
/// written without bounds checks.
#[inline(always)]
pub unsafe fn push_with_info_unchecked(&mut self, value: T) -> T {
    // Capture the front element before the shift overwrites it.
    let oldest = *self.view.get_unchecked(0);
    self.push_unchecked(value);
    oldest
}
/// Returns the stored elements as one contiguous slice, oldest first.
#[inline(always)]
pub fn get_slice(&self) -> &[T] {
    let live = self.count;
    &self.view[..live]
}
/// Returns the stored elements as one contiguous mutable slice, oldest first.
///
/// Only `view` is exposed; edits made through this slice are not reflected
/// in `ring` until `sync_mirrors` is called.
#[inline(always)]
pub fn get_slice_mut(&mut self) -> &mut [T] {
    let live = self.count;
    &mut self.view[..live]
}
/// Returns the newest `period` elements, oldest first.
///
/// The window is clamped to the number of stored elements. An empty slice is
/// returned when the buffer is empty or `period` is zero.
#[inline(always)]
pub fn get_slice_by_period(&self, period: usize) -> &[T] {
    let end = self.count;
    // Zero exactly when `period == 0` or nothing is stored.
    let take = period.min(end);
    if take == 0 {
        return &[];
    }
    &self.view[end - take..end]
}
/// Returns the element of the circular storage selected by `period`.
///
/// The slot is computed by `period_to_idx(self.index, N, period)` and read
/// from `ring` without a bounds check.
// NOTE(review): the exact meaning of `period` (e.g. whether 0 or 1 denotes
// the newest element) is defined by `period_to_idx`, which is not visible
// here — confirm against that helper.
#[inline(always)]
pub fn get_by_period(&self, period: usize) -> T {
    let idx = period_to_idx(self.index, N, period);
    // Debug builds verify the helper's result before the unchecked read, so
    // a bad `period` (or `N == 0`) fails loudly instead of reading out of
    // bounds. Release builds are unaffected.
    debug_assert!(
        idx < N,
        "period_to_idx returned {idx}, outside ring of capacity {}",
        N
    );
    // SAFETY: sound only if `period_to_idx` always yields an index `< N`.
    // NOTE(review): presumably it reduces modulo `N` — confirm.
    unsafe { *self.ring.get_unchecked(idx) }
}
/// Converts a position in the oldest-first slice returned by `get_slice`
/// into a "bars ago" offset, where the newest element is 0.
///
/// Requires `window_index < self.len()`. Otherwise the subtraction
/// underflows: a panic in builds with overflow checks, a wrapped value
/// without them.
#[inline(always)]
pub fn window_index_to_bars_ago(&self, window_index: usize) -> usize {
    let newest = self.count - 1;
    newest - window_index
}
/// Copies the stored elements into a new `Vec`, oldest first.
pub fn to_ordered_vec(&self) -> Vec<T> {
    let live = &self.view[..self.count];
    live.to_vec()
}
/// Rebuilds `ring` from `view`, e.g. after `view` was edited through
/// `get_slice_mut`.
///
/// While the buffer is still filling, the first `count` slots are copied
/// straight across. Once it is full, `view[i]` is written to
/// `ring[(index + i) % N]`, so the oldest element lands at the write cursor.
pub fn sync_mirrors(&mut self) {
    match self.count {
        // Nothing stored, nothing to mirror.
        0 => {}
        live if live < N => {
            self.ring[..live].copy_from_slice(&self.view[..live]);
        }
        _ => {
            let start = self.index;
            for (offset, &value) in self.view.iter().enumerate() {
                self.ring[(start + offset) % N] = value;
            }
        }
    }
}
}
impl<const N: usize> FixedMirrorBuffer<f64, N> {
/// Updates the rolling-maximum `state` for one new `bar` and returns the
/// resulting `(max, trail)` pair.
///
/// `state.trail` is incremented first. If it has reached `period`, the newest
/// `period` elements are rescanned (with `find_max_scalar` when
/// `CHUNK_SIZE == 1`, otherwise `find_max_simd::<CHUNK_SIZE>`), and `trail`
/// becomes the distance of the found maximum from the end of that window.
/// Otherwise, if `bar >= state.max`, `bar` becomes the maximum with
/// `trail = 0`. In the remaining case only the incremented `trail` is stored.
// NOTE(review): the rescan reads the buffer, not `bar`, so `bar` is presumably
// expected to have been pushed already — confirm against callers.
#[inline(always)]
pub fn max<const CHUNK_SIZE: usize>(
    &self,
    state: &mut MaxState,
    bar: f64,
    period: usize,
) -> (f64, usize) {
    let aged = state.trail + 1;
    if period <= aged {
        // The stored maximum has aged out of the window: rescan it.
        let window = self.get_slice_by_period(period);
        let (peak, peak_idx) = match CHUNK_SIZE {
            1 => find_max_scalar(window),
            _ => find_max_simd::<CHUNK_SIZE>(window),
        };
        state.max = peak;
        state.trail = window.len().saturating_sub(1 + peak_idx);
    } else if bar >= state.max {
        state.max = bar;
        state.trail = 0;
    } else {
        state.trail = aged;
    }
    (state.max, state.trail)
}
/// Updates the rolling-minimum `state` for one new `bar` and returns the
/// resulting `(min, trail)` pair.
///
/// `state.trail` is incremented first. If it has reached `period`, the newest
/// `period` elements are rescanned (with `find_min_scalar` when
/// `CHUNK_SIZE == 1`, otherwise `find_min_simd::<CHUNK_SIZE>`), and `trail`
/// becomes the distance of the found minimum from the end of that window.
/// Otherwise, if `bar <= state.min`, `bar` becomes the minimum with
/// `trail = 0`. In the remaining case only the incremented `trail` is stored.
// NOTE(review): the rescan reads the buffer, not `bar`, so `bar` is presumably
// expected to have been pushed already — confirm against callers.
#[inline(always)]
pub fn min<const CHUNK_SIZE: usize>(
    &self,
    state: &mut MinState,
    bar: f64,
    period: usize,
) -> (f64, usize) {
    let aged = state.trail + 1;
    if period <= aged {
        // The stored minimum has aged out of the window: rescan it.
        let window = self.get_slice_by_period(period);
        let (trough, trough_idx) = match CHUNK_SIZE {
            1 => find_min_scalar(window),
            _ => find_min_simd::<CHUNK_SIZE>(window),
        };
        state.min = trough;
        state.trail = window.len().saturating_sub(1 + trough_idx);
    } else if bar <= state.min {
        state.min = bar;
        state.trail = 0;
    } else {
        state.trail = aged;
    }
    (state.min, state.trail)
}
}
/// The default buffer is the empty buffer produced by `FixedMirrorBuffer::new`.
impl<T: BufferElement, const N: usize> Default for FixedMirrorBuffer<T, N> {
fn default() -> Self {
Self::new()
}
}
/// Borrowing iterator over a [`FixedMirrorBuffer`], created by calling
/// `into_iter` on a reference to the buffer.
pub struct FixedMirrorIter<'a, T: BufferElement, const N: usize> {
// Buffer being iterated.
buffer: &'a FixedMirrorBuffer<T, N>,
// Number of elements already yielded.
pos: usize,
}
/// Yields the buffer's elements by value, newest first.
impl<'a, T: BufferElement, const N: usize> Iterator for FixedMirrorIter<'a, T, N> {
    type Item = T;

    /// Returns the next-newest element, or `None` once every stored element
    /// has been yielded.
    #[inline]
    fn next(&mut self) -> Option<T> {
        // Elements still to yield; exhausted when this is zero (or when `pos`
        // is somehow already past `count`).
        let left = self
            .buffer
            .count
            .checked_sub(self.pos)
            .filter(|&left| left > 0)?;
        // `view` is oldest-first, so walk it from the back.
        let item = self.buffer.view[left - 1];
        self.pos += 1;
        Some(item)
    }

    /// Reports the exact number of elements still to be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.buffer.count.saturating_sub(self.pos);
        (left, Some(left))
    }
}
// `size_hint` returns identical lower and upper bounds (`count - pos`), so the
// iterator can advertise an exact length.
impl<'a, T: BufferElement, const N: usize> ExactSizeIterator for FixedMirrorIter<'a, T, N> {}
/// Allows `for value in &buffer`, yielding copies of the elements newest first.
impl<'a, T: BufferElement, const N: usize> IntoIterator for &'a FixedMirrorBuffer<T, N> {
    type Item = T;
    type IntoIter = FixedMirrorIter<'a, T, N>;

    /// Creates an iterator positioned at the newest element.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        let pos = 0;
        FixedMirrorIter { buffer: self, pos }
    }
}
/// Indexes the buffer by "bars ago": `buffer[0]` is the newest element and
/// `buffer[len - 1]` the oldest.
impl<T: BufferElement, const N: usize> std::ops::Index<usize> for FixedMirrorBuffer<T, N> {
    type Output = T;

    /// Returns a reference to the element pushed `bars_ago` pushes before the
    /// newest one.
    ///
    /// # Panics
    ///
    /// Panics if `bars_ago >= self.len()`.
    #[inline]
    fn index(&self, bars_ago: usize) -> &T {
        let stored = self.count;
        if bars_ago >= stored {
            panic!(
                "index out of bounds: bars_ago {bars_ago} >= count {}",
                self.count
            );
        }
        // `view` is oldest-first, so count back from its last live slot.
        let newest = stored - 1;
        &self.view[newest - bars_ago]
    }
}
/// Serializes the buffer as a struct named `FixedMirrorBuffer` with the fields
/// `ring`, `view`, `index` and `count`, in that order.
///
/// Both arrays are written in full (all `N` slots), with each element
/// converted through `T::to_repr`.
impl<T: BufferElement + SerdeElement, const N: usize> Serialize for FixedMirrorBuffer<T, N> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut out = serializer.serialize_struct("FixedMirrorBuffer", 4)?;

        // Convert every slot of an array into its serializable representation.
        let encode = |slots: &[T; N]| -> Vec<T::Repr> {
            slots.iter().map(|&slot| T::to_repr(slot)).collect()
        };
        let ring_repr = encode(&self.ring);
        let view_repr = encode(&self.view);

        out.serialize_field("ring", &ring_repr)?;
        out.serialize_field("view", &view_repr)?;
        out.serialize_field("index", &self.index)?;
        out.serialize_field("count", &self.count)?;
        out.end()
    }
}
impl<'de, T: BufferElement + SerdeElement, const N: usize> Deserialize<'de>
for FixedMirrorBuffer<T, N>
where
T::Repr: Deserialize<'de>,
{
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
const FIELDS: &[&str] = &["ring", "view", "index", "count"];
enum Field {
Ring,
View,
Index,
Count,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct FieldVisitor;
impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("`ring`, `view`, `index`, or `count`")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Field, E> {
match v {
"ring" => Ok(Field::Ring),
"view" => Ok(Field::View),
"index" => Ok(Field::Index),
"count" => Ok(Field::Count),
_ => Err(de::Error::unknown_field(v, FIELDS)),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct FMBVisitor<T, const N: usize>(PhantomData<fn() -> T>);
impl<'de, T: BufferElement + SerdeElement, const N: usize> Visitor<'de> for FMBVisitor<T, N>
where
T::Repr: Deserialize<'de>,
{
type Value = FixedMirrorBuffer<T, N>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("struct FixedMirrorBuffer")
}
fn visit_map<V: MapAccess<'de>>(
self,
mut map: V,
) -> Result<FixedMirrorBuffer<T, N>, V::Error> {
let mut ring: Option<Vec<T::Repr>> = None;
let mut view: Option<Vec<T::Repr>> = None;
let mut index: Option<usize> = None;
let mut count: Option<usize> = None;
while let Some(key) = map.next_key::<Field>()? {
match key {
Field::Ring => {
if ring.is_some() {
return Err(de::Error::duplicate_field("ring"));
}
ring = Some(map.next_value()?);
}
Field::View => {
if view.is_some() {
return Err(de::Error::duplicate_field("view"));
}
view = Some(map.next_value()?);
}
Field::Index => {
if index.is_some() {
return Err(de::Error::duplicate_field("index"));
}
index = Some(map.next_value()?);
}
Field::Count => {
if count.is_some() {
return Err(de::Error::duplicate_field("count"));
}
count = Some(map.next_value()?);
}
}
}
let ring_repr: Vec<T::Repr> =
ring.ok_or_else(|| de::Error::missing_field("ring"))?;
let view_repr: Vec<T::Repr> =
view.ok_or_else(|| de::Error::missing_field("view"))?;
let index = index.ok_or_else(|| de::Error::missing_field("index"))?;
let count = count.ok_or_else(|| de::Error::missing_field("count"))?;
let ring_vec: Vec<T> = ring_repr.into_iter().map(T::from_repr).collect();
let view_vec: Vec<T> = view_repr.into_iter().map(T::from_repr).collect();
let ring_arr: [T; N] = ring_vec.try_into().map_err(|v: Vec<T>| {
de::Error::invalid_length(v.len(), &"ring array of length N")
})?;
let view_arr: [T; N] = view_vec.try_into().map_err(|v: Vec<T>| {
de::Error::invalid_length(v.len(), &"view array of length N")
})?;
Ok(FixedMirrorBuffer {
ring: ring_arr,
view: view_arr,
index,
count,
})
}
}
deserializer.deserialize_struct(
"FixedMirrorBuffer",
FIELDS,
FMBVisitor::<T, N>(PhantomData),
)
}
}