use crate::ring_buffer::buffer::{period_to_idx, BufferElement, SerdeElement};
use serde::{
de::{self, MapAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
};
use std::{fmt, marker::PhantomData};
/// Fixed-capacity ring buffer holding up to `N` elements in an inline array.
///
/// Pushes write to slot `index`, which advances and wraps back to `0` after
/// reaching `N`; once `count == N`, every push overwrites the oldest element.
///
/// Invariants established by [`new`](Self::new) and preserved by
/// [`push`](Self::push) (the fields are `pub(crate)`, so any other code in the
/// crate that writes them must preserve these as well):
/// - `count <= N`;
/// - `index < N` when `N > 0`, and `index == 0` when `N == 0`;
/// - while `count < N`, the elements occupy `vals[..count]` in push order and
///   `index == count`.
#[derive(Clone)]
pub struct FixedRingBuffer<T: BufferElement, const N: usize> {
    /// Backing storage; `new` initialises every slot to `T::default()`.
    pub(crate) vals: [T; N],
    /// Slot the next push writes to.
    pub(crate) index: usize,
    /// Number of live elements, saturating at `N`.
    pub(crate) count: usize,
}

impl<T: BufferElement, const N: usize> FixedRingBuffer<T, N> {
    /// Creates an empty buffer whose slots all hold `T::default()`.
    #[inline]
    pub fn new() -> Self {
        Self {
            vals: [T::default(); N],
            index: 0,
            count: 0,
        }
    }

    /// Returns `true` once `N` elements have been pushed (always `true` when `N == 0`).
    #[inline(always)]
    pub fn is_full(&self) -> bool {
        self.count == N
    }

    /// Returns `true` if the buffer holds no elements.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Number of stored elements (at most `N`).
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.count
    }

    /// Maximum number of elements the buffer can hold, i.e. `N`.
    #[inline(always)]
    pub const fn capacity(&self) -> usize {
        N
    }

    /// Appends `value`, overwriting the oldest element once the buffer is full.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0`, or if `index` has been set outside `0..N`.
    #[inline(always)]
    pub fn push(&mut self, value: T) {
        self.vals[self.index] = value;
        self.index += 1;
        if self.index == N {
            self.index = 0;
        }
        if self.count < N {
            self.count += 1;
        }
    }

    /// Appends `value` and returns the element it overwrote, if any.
    ///
    /// Returns `None` while the buffer is still filling. Once it is full, the
    /// slot about to be written holds the oldest element, and that element is
    /// returned.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0` (a zero-capacity buffer reports itself full but has
    /// no slot to read or write), or if `index` has been set outside `0..N`.
    #[inline(always)]
    pub fn push_with_info(&mut self, value: T) -> Option<T> {
        // Checked indexing keeps this safe method sound even for `N == 0` and
        // for any out-of-range `index` stored in the `pub(crate)` field; an
        // unchecked read there would be undefined behaviour.
        let evicted = self.is_full().then(|| self.vals[self.index]);
        self.push(value);
        evicted
    }

    /// Writes `value` into slot `index` and advances `index` (wrapping at `N`)
    /// without a bounds check and **without updating `count`**.
    ///
    /// Meant for the steady state where the buffer is already full. On a
    /// buffer that is still filling, `count` is left behind, so the new
    /// element is not counted by `len` and may be missing from `get_slice` /
    /// `to_ordered_vec`.
    ///
    /// # Safety
    ///
    /// The caller must ensure `self.index < N` (which also implies `N > 0`).
    /// This holds for any buffer with `N > 0` whose cursor has only been moved
    /// by `new` and the `push*` methods.
    #[inline(always)]
    pub unsafe fn push_unchecked(&mut self, value: T) {
        // In bounds by this function's safety contract (`self.index < N`).
        *self.vals.get_unchecked_mut(self.index) = value;
        self.index += 1;
        if self.index == N {
            self.index = 0;
        }
    }

    /// Like [`push_unchecked`](Self::push_unchecked), but returns the element
    /// that was stored in the overwritten slot.
    ///
    /// # Safety
    ///
    /// Same contract as `push_unchecked`: the caller must ensure `self.index < N`.
    #[inline(always)]
    pub unsafe fn push_with_info_unchecked(&mut self, value: T) -> T {
        // In bounds by this function's safety contract (`self.index < N`).
        let evicted = *self.vals.get_unchecked(self.index);
        self.push_unchecked(value);
        evicted
    }

    /// Returns the live elements in *storage* order.
    ///
    /// This order is chronological only until the buffer first wraps; use
    /// [`to_ordered_vec`](Self::to_ordered_vec) for oldest-to-newest order.
    #[inline(always)]
    pub fn get_slice(&self) -> &[T] {
        if self.count < N {
            &self.vals[..self.count]
        } else {
            &self.vals
        }
    }

    /// Returns the most recently pushed element, or `None` if empty.
    #[inline(always)]
    pub fn back(&self) -> Option<T> {
        if self.is_empty() {
            return None;
        }
        // Slot written by the last push: one before `index`, wrapping at `N`.
        // The `% N` keeps it in bounds, so this indexing cannot panic.
        let newest = (self.index + N - 1) % N;
        Some(self.vals[newest])
    }

    /// Returns the oldest stored element, or `None` if empty.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is full and `index` has been set outside `0..N`.
    #[inline(always)]
    pub fn front(&self) -> Option<T> {
        if self.is_empty() {
            return None;
        }
        // Until the first wrap the oldest element sits in slot 0; afterwards it
        // is the slot `index` is about to overwrite.
        let oldest = if self.is_full() { self.index } else { 0 };
        Some(self.vals[oldest])
    }

    /// Returns the element `period` steps back, with the slot located by
    /// `period_to_idx` (presumably counting back from the most recent push,
    /// `0` = newest, as `to_ordered_by_period` assumes).
    ///
    /// `period` is not checked against `len()`. A `period >= len()` is not
    /// rejected here, and the result is whatever the resolved slot holds.
    ///
    /// # Panics
    ///
    /// Panics if `period_to_idx` yields an index outside `0..N`.
    #[inline(always)]
    pub fn get_by_period(&self, period: usize) -> T {
        // Checked indexing: the soundness of this safe method must not depend on
        // the output range of `period_to_idx`, which lives in another module.
        self.vals[period_to_idx(self.index, N, period)]
    }

    /// Copies the live elements into a `Vec`, oldest first.
    pub fn to_ordered_vec(&self) -> Vec<T> {
        if self.is_empty() {
            return Vec::new();
        }
        if self.count < N {
            // Not yet wrapped: elements sit in `vals[..count]` in push order.
            return self.vals[..self.count].to_vec();
        }
        // Full: `index` is the next slot to overwrite, i.e. the oldest element,
        // so the chronological order is `vals[index..]` then `vals[..index]`.
        let (newer, older) = self.vals.split_at(self.index);
        let mut out = Vec::with_capacity(N);
        out.extend_from_slice(older);
        out.extend_from_slice(newer);
        out
    }

    /// Returns the last `min(period, len())` elements in chronological order:
    /// `get_by_period(take - 1)` first, down to `get_by_period(0)` last.
    ///
    /// Returns an empty `Vec` when the buffer is empty or `period == 0`.
    pub fn to_ordered_by_period(&self, period: usize) -> Vec<T> {
        let take = period.min(self.count);
        (0..take).rev().map(|p| self.get_by_period(p)).collect()
    }

    /// Maps a window position (`0` = oldest of the `len()` stored elements)
    /// to the matching `bars_ago` offset, computed as `len() - 1 - window_index`.
    ///
    /// # Panics
    ///
    /// Requires `window_index < len()`. Otherwise the subtraction underflows,
    /// which panics when overflow checks are enabled and wraps when they are not.
    #[inline(always)]
    pub fn window_index_to_bars_ago(&self, window_index: usize) -> usize {
        self.count - 1 - window_index
    }
}

impl<T: BufferElement, const N: usize> Default for FixedRingBuffer<T, N> {
    /// Equivalent to [`FixedRingBuffer::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Borrowing iterator over a [`FixedRingBuffer`], yielding copies of its
/// elements via `get_by_period(0)`, `get_by_period(1)`, … up to `count`.
pub struct FixedRingIter<'a, T, const N: usize>
where
    T: BufferElement,
{
    // Buffer being walked; shared borrow, so it cannot change mid-iteration.
    buffer: &'a FixedRingBuffer<T, N>,
    // Next period (`bars_ago`) to yield.
    pos: usize,
}
impl<'a, T, const N: usize> Iterator for FixedRingIter<'a, T, N>
where
    T: BufferElement,
{
    type Item = T;

    /// Yields `get_by_period(pos)` and advances `pos`, stopping once `pos`
    /// reaches the buffer's `count`.
    #[inline]
    fn next(&mut self) -> Option<T> {
        if self.pos < self.buffer.count {
            let item = self.buffer.get_by_period(self.pos);
            self.pos += 1;
            Some(item)
        } else {
            None
        }
    }

    /// Exact bounds: the number of periods still to visit.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = if self.pos < self.buffer.count {
            self.buffer.count - self.pos
        } else {
            0
        };
        (left, Some(left))
    }
}
/// `size_hint` reports exactly `count - pos` remaining items, so the iterator
/// can expose `len()`.
impl<'a, T: BufferElement, const N: usize> ExactSizeIterator for FixedRingIter<'a, T, N> {}

/// Once `next` returns `None` it keeps returning `None`: the buffer is only
/// borrowed immutably, so its `count` cannot change, and `pos` is not advanced
/// past that point. This lets `Iterator::fuse` be a no-op.
impl<'a, T: BufferElement, const N: usize> std::iter::FusedIterator for FixedRingIter<'a, T, N> {}
impl<'a, T, const N: usize> IntoIterator for &'a FixedRingBuffer<T, N>
where
    T: BufferElement,
{
    type Item = T;
    type IntoIter = FixedRingIter<'a, T, N>;

    /// Borrows the buffer and returns an iterator positioned at period `0`.
    #[inline]
    fn into_iter(self) -> Self::IntoIter {
        FixedRingIter { pos: 0, buffer: self }
    }
}
impl<T, const N: usize> std::ops::Index<usize> for FixedRingBuffer<T, N>
where
    T: BufferElement,
{
    type Output = T;

    /// Borrows the element `bars_ago` steps back, with its slot resolved by
    /// `period_to_idx`.
    ///
    /// A `bars_ago >= len()` trips a `debug_assert!`. In release builds it is
    /// not rejected here and only panics if the resolved slot lies outside the
    /// backing array.
    #[inline]
    fn index(&self, bars_ago: usize) -> &T {
        let live = self.count;
        debug_assert!(
            bars_ago < live,
            "index out of bounds: bars_ago {bars_ago} >= count {}",
            live
        );
        let slot = period_to_idx(self.index, N, bars_ago);
        &self.vals[slot]
    }
}
impl<T, const N: usize> Serialize for FixedRingBuffer<T, N>
where
    T: BufferElement + SerdeElement,
{
    /// Writes the buffer as a three-field struct. The fields are, in order:
    /// - `vals`: the whole slot array in storage order, each element converted
    ///   with `SerdeElement::to_repr`;
    /// - `index`: the write cursor;
    /// - `count`: the live-element count.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let mut fields = serializer.serialize_struct("FixedRingBuffer", 3)?;
        let slots: Vec<T::Repr> = self.vals.iter().copied().map(T::to_repr).collect();
        fields.serialize_field("vals", &slots)?;
        fields.serialize_field("index", &self.index)?;
        fields.serialize_field("count", &self.count)?;
        fields.end()
    }
}
impl<'de, T: BufferElement + SerdeElement, const N: usize> Deserialize<'de>
for FixedRingBuffer<T, N>
where
T::Repr: Deserialize<'de>,
{
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
const FIELDS: &[&str] = &["vals", "index", "count"];
enum Field {
Vals,
Index,
Count,
}
impl<'de> Deserialize<'de> for Field {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct FieldVisitor;
impl<'de> Visitor<'de> for FieldVisitor {
type Value = Field;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("`vals`, `index`, or `count`")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Field, E> {
match v {
"vals" => Ok(Field::Vals),
"index" => Ok(Field::Index),
"count" => Ok(Field::Count),
_ => Err(de::Error::unknown_field(v, FIELDS)),
}
}
}
deserializer.deserialize_identifier(FieldVisitor)
}
}
struct FRBVisitor<T, const N: usize>(PhantomData<fn() -> T>);
impl<'de, T: BufferElement + SerdeElement, const N: usize> Visitor<'de> for FRBVisitor<T, N>
where
T::Repr: Deserialize<'de>,
{
type Value = FixedRingBuffer<T, N>;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str("struct FixedRingBuffer")
}
fn visit_map<V: MapAccess<'de>>(
self,
mut map: V,
) -> Result<FixedRingBuffer<T, N>, V::Error> {
let mut vals: Option<Vec<T::Repr>> = None;
let mut index: Option<usize> = None;
let mut count: Option<usize> = None;
while let Some(key) = map.next_key::<Field>()? {
match key {
Field::Vals => {
if vals.is_some() {
return Err(de::Error::duplicate_field("vals"));
}
vals = Some(map.next_value()?);
}
Field::Index => {
if index.is_some() {
return Err(de::Error::duplicate_field("index"));
}
index = Some(map.next_value()?);
}
Field::Count => {
if count.is_some() {
return Err(de::Error::duplicate_field("count"));
}
count = Some(map.next_value()?);
}
}
}
let vals_repr: Vec<T::Repr> =
vals.ok_or_else(|| de::Error::missing_field("vals"))?;
let index = index.ok_or_else(|| de::Error::missing_field("index"))?;
let count = count.ok_or_else(|| de::Error::missing_field("count"))?;
let vals_vec: Vec<T> = vals_repr.into_iter().map(T::from_repr).collect();
let vals_arr: [T; N] = vals_vec.try_into().map_err(|v: Vec<T>| {
de::Error::invalid_length(v.len(), &"vals array of length N")
})?;
Ok(FixedRingBuffer {
vals: vals_arr,
index,
count,
})
}
}
deserializer.deserialize_struct("FixedRingBuffer", FIELDS, FRBVisitor::<T, N>(PhantomData))
}
}