use core::slice;
use smallvec::SmallVec;
use crate::core::{PolicyGroupId, DEFAULT_POLICY_GROUP_ID};
use crate::param::Price;
/// Prices bucketed by policy group.
///
/// Prices recorded under the default policy group live in a dedicated list; every other group
/// owns a section that is created the first time that group receives a price.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PreTradeLock {
    /// Prices recorded under the default policy group.
    default_prices: SmallVec<[Price; 1]>,
    /// One section per non-default policy group, in the order the groups were first seen.
    other: SmallVec<[GroupSection; 2]>,
}

/// The prices recorded for one non-default policy group.
#[derive(Debug, Clone, PartialEq, Eq)]
struct GroupSection {
    /// The group these prices belong to.
    group_id: PolicyGroupId,
    /// The group's prices, in insertion order.
    prices: SmallVec<[Price; 1]>,
}
impl PreTradeLock {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.default_prices.len() + self.other.iter().map(|s| s.prices.len()).sum::<usize>()
}
pub fn is_empty(&self) -> bool {
self.default_prices.is_empty() && self.other.iter().all(|s| s.prices.is_empty())
}
pub fn from_entries<EntriesIter>(entries: EntriesIter) -> Self
where
EntriesIter: IntoIterator<Item = (PolicyGroupId, Price)>,
{
let mut lock = Self::default();
lock.append(entries);
lock
}
pub fn push(&mut self, policy_group_id: PolicyGroupId, price: Price) {
if policy_group_id == DEFAULT_POLICY_GROUP_ID {
self.default_prices.push(price);
return;
}
for section in self.other.iter_mut() {
if section.group_id == policy_group_id {
section.prices.push(price);
return;
}
}
let mut prices = SmallVec::<[Price; 1]>::new();
prices.push(price);
self.other.push(GroupSection {
group_id: policy_group_id,
prices,
});
}
pub fn push_many<PricesIter>(&mut self, policy_group_id: PolicyGroupId, prices: PricesIter)
where
PricesIter: IntoIterator<Item = Price>,
PricesIter::IntoIter: ExactSizeIterator,
{
let iter = prices.into_iter();
let count = iter.len();
if count == 0 {
return;
}
if policy_group_id == DEFAULT_POLICY_GROUP_ID {
self.default_prices.reserve(count);
self.default_prices.extend(iter);
return;
}
if let Some(section) = self
.other
.iter_mut()
.find(|s| s.group_id == policy_group_id)
{
section.prices.reserve(count);
section.prices.extend(iter);
return;
}
let mut section_prices = SmallVec::<[Price; 1]>::with_capacity(count);
section_prices.extend(iter);
self.other.push(GroupSection {
group_id: policy_group_id,
prices: section_prices,
});
}
pub fn merge(&mut self, other: &Self) {
self.default_prices.extend_from_slice(&other.default_prices);
for section in &other.other {
if let Some(existing) = self
.other
.iter_mut()
.find(|s| s.group_id == section.group_id)
{
existing.prices.extend_from_slice(§ion.prices);
} else {
self.other.push(section.clone());
}
}
}
pub fn entries(&self) -> Entries<'_> {
Entries {
default_iter: self.default_prices.iter(),
sections_iter: self.other.iter(),
current: None,
}
}
pub fn prices_of(&self, policy_group_id: PolicyGroupId) -> PricesByGroup<'_> {
const EMPTY: &[Price] = &[];
let slice: &[Price] = if policy_group_id == DEFAULT_POLICY_GROUP_ID {
self.default_prices.as_slice()
} else {
self.other
.iter()
.find(|section| section.group_id == policy_group_id)
.map_or(EMPTY, |section| section.prices.as_slice())
};
PricesByGroup { iter: slice.iter() }
}
fn append<EntriesIter>(&mut self, entries: EntriesIter)
where
EntriesIter: IntoIterator<Item = (PolicyGroupId, Price)>,
{
let iter = entries.into_iter();
let (lower, _) = iter.size_hint();
if lower > 0 {
self.other.reserve(lower);
}
for (group_id, price) in iter {
self.push(group_id, price);
}
}
}
impl FromIterator<(PolicyGroupId, Price)> for PreTradeLock {
fn from_iter<EntriesIter>(iter: EntriesIter) -> Self
where
EntriesIter: IntoIterator<Item = (PolicyGroupId, Price)>,
{
Self::from_entries(iter)
}
}
impl Extend<(PolicyGroupId, Price)> for PreTradeLock {
fn extend<EntriesIter>(&mut self, iter: EntriesIter)
where
EntriesIter: IntoIterator<Item = (PolicyGroupId, Price)>,
{
self.append(iter);
}
}
/// Iterator over every `(group, price)` pair of a `PreTradeLock`, returned by
/// [`PreTradeLock::entries`].
///
/// Default-group prices come first, then each non-default section in storage order.
#[derive(Clone, Debug)]
pub struct Entries<'a> {
    /// Default-group prices still to be yielded; drained before any section is visited.
    default_iter: slice::Iter<'a, Price>,
    /// Non-default sections that have not been started yet.
    sections_iter: slice::Iter<'a, GroupSection>,
    /// The section currently being drained, paired with its group id.
    current: Option<(PolicyGroupId, slice::Iter<'a, Price>)>,
}

impl Iterator for Entries<'_> {
    type Item = (PolicyGroupId, Price);

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(price) = self.default_iter.next() {
            return Some((DEFAULT_POLICY_GROUP_ID, *price));
        }
        loop {
            if let Some((group_id, prices)) = self.current.as_mut() {
                if let Some(price) = prices.next() {
                    return Some((*group_id, *price));
                }
                self.current = None;
            }
            // No section in progress: start the next one, or finish when none remain.
            let section = self.sections_iter.next()?;
            self.current = Some((section.group_id, section.prices.iter()));
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let in_current = self
            .current
            .as_ref()
            .map_or(0, |(_, prices)| prices.len());
        let in_pending: usize = self
            .sections_iter
            .as_slice()
            .iter()
            .map(|section| section.prices.len())
            .sum();
        let remaining = self.default_iter.len() + in_current + in_pending;
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for Entries<'_> {}

// Once `next` returns `None`, every inner slice iterator is exhausted, so it keeps
// returning `None`.
impl core::iter::FusedIterator for Entries<'_> {}
/// Iterator over the prices of a single policy group, returned by
/// [`PreTradeLock::prices_of`].
#[derive(Clone, Debug)]
pub struct PricesByGroup<'a> {
    /// The group's prices that have not been yielded yet.
    iter: slice::Iter<'a, Price>,
}

impl Iterator for PricesByGroup<'_> {
    type Item = Price;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().copied()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

impl DoubleEndedIterator for PricesByGroup<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back().copied()
    }
}

impl ExactSizeIterator for PricesByGroup<'_> {
    fn len(&self) -> usize {
        self.iter.len()
    }
}

// Slice iterators are fused, and this type only forwards to one.
impl core::iter::FusedIterator for PricesByGroup<'_> {}
#[cfg(feature = "serde")]
mod serde_impl {
    //! Serde support for `PreTradeLock`.
    //!
    //! Wire format: a sequence whose first element is the list of default-group prices and
    //! whose remaining elements are `[group_id, price, ...]` sections, one per non-default
    //! group. An empty lock is written as an empty sequence; when reading, a missing first
    //! element also yields an empty lock.

    use core::fmt;

    use serde::de::{DeserializeSeed, Error as DeError, SeqAccess, Visitor};
    use serde::ser::SerializeSeq;
    use serde::{Deserializer, Serializer};

    use super::{GroupSection, PreTradeLock};
    use crate::core::{PolicyGroupId, DEFAULT_POLICY_GROUP_ID};
    use crate::param::Price;

    /// Upper bound, in bytes, on what a deserializer's `size_hint` may make us preallocate.
    ///
    /// The hint can come straight from the input (for example a length prefix in a binary
    /// format). Trusting it verbatim would let a tiny payload request an enormous allocation.
    const MAX_PREALLOC_BYTES: usize = 1024 * 1024;

    /// Clamps an untrusted element-count hint to at most `MAX_PREALLOC_BYTES` worth of prices.
    fn cautious_capacity(hint: usize) -> usize {
        let element_size = core::mem::size_of::<Price>().max(1);
        hint.min(MAX_PREALLOC_BYTES / element_size)
    }

    /// Serializes one non-default section as `[group_id, price, price, ...]`.
    struct SectionView<'a>(&'a GroupSection);

    impl serde::Serialize for SectionView<'_> {
        fn serialize<Target>(&self, target: Target) -> Result<Target::Ok, Target::Error>
        where
            Target: Serializer,
        {
            let mut seq = target.serialize_seq(Some(1 + self.0.prices.len()))?;
            seq.serialize_element(&self.0.group_id.value())?;
            for price in &self.0.prices {
                seq.serialize_element(price)?;
            }
            seq.end()
        }
    }

    impl serde::Serialize for PreTradeLock {
        fn serialize<Target>(&self, target: Target) -> Result<Target::Ok, Target::Error>
        where
            Target: Serializer,
        {
            if self.is_empty() {
                return target.serialize_seq(Some(0))?.end();
            }
            let mut seq = target.serialize_seq(Some(1 + self.other.len()))?;
            seq.serialize_element(self.default_prices.as_slice())?;
            for section in &self.other {
                seq.serialize_element(&SectionView(section))?;
            }
            seq.end()
        }
    }

    /// Reads the leading default-group price list directly into the target lock.
    struct DefaultPricesSeed<'a>(&'a mut PreTradeLock);

    impl<'de> DeserializeSeed<'de> for DefaultPricesSeed<'_> {
        type Value = ();

        fn deserialize<Source>(self, source: Source) -> Result<Self::Value, Source::Error>
        where
            Source: Deserializer<'de>,
        {
            source.deserialize_seq(DefaultPricesVisitor(self.0))
        }
    }

    struct DefaultPricesVisitor<'a>(&'a mut PreTradeLock);

    impl<'de> Visitor<'de> for DefaultPricesVisitor<'_> {
        type Value = ();

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a list of prices for the default group")
        }

        fn visit_seq<SeqEntries>(
            self,
            mut entries: SeqEntries,
        ) -> Result<Self::Value, SeqEntries::Error>
        where
            SeqEntries: SeqAccess<'de>,
        {
            // The hint is input-controlled: use it only as a bounded capacity hint.
            if let Some(hint) = entries.size_hint() {
                self.0.default_prices.reserve(cautious_capacity(hint));
            }
            while let Some(price) = entries.next_element::<Price>()? {
                self.0.default_prices.push(price);
            }
            Ok(())
        }
    }

    /// Reads one `[group_id, prices...]` section directly into the target lock.
    struct GroupSectionSeed<'a>(&'a mut PreTradeLock);

    impl<'de> DeserializeSeed<'de> for GroupSectionSeed<'_> {
        type Value = ();

        fn deserialize<Source>(self, source: Source) -> Result<Self::Value, Source::Error>
        where
            Source: Deserializer<'de>,
        {
            source.deserialize_seq(GroupSectionVisitor(self.0))
        }
    }

    struct GroupSectionVisitor<'a>(&'a mut PreTradeLock);

    impl<'de> Visitor<'de> for GroupSectionVisitor<'_> {
        type Value = ();

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("a non-default group section [group_id, prices...]")
        }

        fn visit_seq<SeqEntries>(
            self,
            mut entries: SeqEntries,
        ) -> Result<Self::Value, SeqEntries::Error>
        where
            SeqEntries: SeqAccess<'de>,
        {
            let raw_group_id: u16 = entries.next_element()?.ok_or_else(|| {
                DeError::invalid_length(0, &"non-default group section must start with a group_id")
            })?;
            let group_id = PolicyGroupId::new(raw_group_id);
            // The default group has its own slot; accepting it here would allow two encodings.
            if group_id == DEFAULT_POLICY_GROUP_ID {
                return Err(DeError::custom(
                    "default group must be encoded as the first sublist, not as a tagged section",
                ));
            }
            // `push` merges repeated group ids into a single section.
            while let Some(price) = entries.next_element::<Price>()? {
                self.0.push(group_id, price);
            }
            Ok(())
        }
    }

    /// Visitor for the outer sequence of the wire format.
    struct LockVisitor;

    impl<'de> Visitor<'de> for LockVisitor {
        type Value = PreTradeLock;

        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(
                "a sequence whose first element is the default-group price list \
and whose remaining elements are [group_id, prices...] sections",
            )
        }

        fn visit_seq<SeqEntries>(
            self,
            mut entries: SeqEntries,
        ) -> Result<Self::Value, SeqEntries::Error>
        where
            SeqEntries: SeqAccess<'de>,
        {
            let mut lock = PreTradeLock::new();
            if entries
                .next_element_seed(DefaultPricesSeed(&mut lock))?
                .is_none()
            {
                return Ok(lock);
            }
            while entries
                .next_element_seed(GroupSectionSeed(&mut lock))?
                .is_some()
            {}
            Ok(lock)
        }
    }

    impl<'de> serde::Deserialize<'de> for PreTradeLock {
        fn deserialize<Source>(source: Source) -> Result<Self, Source::Error>
        where
            Source: Deserializer<'de>,
        {
            source.deserialize_seq(LockVisitor)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::PreTradeLock;
    use crate::core::{PolicyGroupId, DEFAULT_POLICY_GROUP_ID};
    use crate::param::Price;

    /// Parses a test price, panicking on malformed literals.
    fn price(value: &str) -> Price {
        Price::from_str(value).expect("price must be valid")
    }

    #[test]
    fn new_is_empty() {
        let lock = PreTradeLock::new();
        assert!(lock.is_empty());
        assert_eq!(lock.len(), 0);
        assert_eq!(lock.entries().count(), 0);
        assert!(lock.prices_of(DEFAULT_POLICY_GROUP_ID).next().is_none());
    }

    #[test]
    fn from_entries_populates_default_group() {
        let lock = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("185"))]);
        assert_eq!(lock.len(), 1);
        let by_default: Vec<_> = lock.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(by_default, vec![price("185")]);
    }

    #[test]
    fn push_merges_prices_per_group() {
        let gid = PolicyGroupId::new(7);
        let mut lock = PreTradeLock::new();
        lock.push(DEFAULT_POLICY_GROUP_ID, price("100"));
        lock.push(gid, price("200"));
        lock.push(DEFAULT_POLICY_GROUP_ID, price("101"));
        lock.push(gid, price("201"));
        let defaults: Vec<_> = lock.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("100"), price("101")]);
        let others: Vec<_> = lock.prices_of(gid).collect();
        assert_eq!(others, vec![price("200"), price("201")]);
    }

    #[test]
    fn push_many_appends_to_default_and_named_groups() {
        let gid = PolicyGroupId::new(9);
        let mut lock = PreTradeLock::new();
        lock.push_many(DEFAULT_POLICY_GROUP_ID, [price("1"), price("2")]);
        lock.push_many(gid, [price("10")]);
        // A second batch for an existing group extends its section instead of adding one.
        lock.push_many(gid, [price("11"), price("12")]);
        assert_eq!(lock.len(), 5);
        let defaults: Vec<_> = lock.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("1"), price("2")]);
        let others: Vec<_> = lock.prices_of(gid).collect();
        assert_eq!(others, vec![price("10"), price("11"), price("12")]);
    }

    #[test]
    fn push_many_with_no_prices_is_a_no_op() {
        let mut lock = PreTradeLock::new();
        lock.push_many(PolicyGroupId::new(4), Vec::<Price>::new());
        assert!(lock.is_empty());
        // Equality also compares sections, so this proves no empty section was created.
        assert_eq!(lock, PreTradeLock::new());
    }

    #[test]
    fn prices_of_unknown_group_is_empty() {
        let mut lock = PreTradeLock::new();
        lock.push(PolicyGroupId::new(1), price("10"));
        assert!(lock.prices_of(PolicyGroupId::new(99)).next().is_none());
    }

    #[test]
    fn prices_of_reports_exact_len() {
        let gid = PolicyGroupId::new(6);
        let lock = PreTradeLock::from_entries([(gid, price("1")), (gid, price("2"))]);
        assert_eq!(lock.prices_of(gid).len(), 2);
        assert_eq!(lock.prices_of(PolicyGroupId::new(7)).len(), 0);
    }

    #[test]
    fn entries_iterates_default_first_then_each_group() {
        let gid_a = PolicyGroupId::new(3);
        let gid_b = PolicyGroupId::new(5);
        let mut lock = PreTradeLock::new();
        lock.push(gid_a, price("300"));
        lock.push(DEFAULT_POLICY_GROUP_ID, price("100"));
        lock.push(gid_b, price("500"));
        lock.push(gid_a, price("301"));
        let collected: Vec<_> = lock.entries().collect();
        assert_eq!(
            collected,
            vec![
                (DEFAULT_POLICY_GROUP_ID, price("100")),
                (gid_a, price("300")),
                (gid_a, price("301")),
                (gid_b, price("500")),
            ]
        );
    }

    #[test]
    fn from_iterator_and_extend_traits_work() {
        let gid = PolicyGroupId::new(2);
        let source = [
            (DEFAULT_POLICY_GROUP_ID, price("10")),
            (gid, price("20")),
            (DEFAULT_POLICY_GROUP_ID, price("11")),
            (gid, price("21")),
        ];
        let collected: PreTradeLock = source.iter().copied().collect();
        assert_eq!(collected.len(), 4);
        let mut extended = PreTradeLock::new();
        extended.extend(source.iter().copied());
        assert_eq!(extended, collected);
    }

    #[test]
    fn entries_size_hint_matches_total_length() {
        let mut lock = PreTradeLock::new();
        lock.push(DEFAULT_POLICY_GROUP_ID, price("1"));
        lock.push(PolicyGroupId::new(1), price("2"));
        lock.push(PolicyGroupId::new(2), price("3"));
        lock.push(PolicyGroupId::new(1), price("4"));
        let iter = lock.entries();
        assert_eq!(iter.size_hint(), (4, Some(4)));
        assert_eq!(iter.len(), 4);
    }

    #[test]
    fn merge_appends_entries() {
        let gid = PolicyGroupId::new(7);
        let mut base = PreTradeLock::from_entries([
            (DEFAULT_POLICY_GROUP_ID, price("100")),
            (gid, price("200")),
        ]);
        let extra = PreTradeLock::from_entries([
            (DEFAULT_POLICY_GROUP_ID, price("101")),
            (gid, price("201")),
        ]);
        base.merge(&extra);
        assert_eq!(base.len(), 4);
        let defaults: Vec<_> = base.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("100"), price("101")]);
        let others: Vec<_> = base.prices_of(gid).collect();
        assert_eq!(others, vec![price("200"), price("201")]);
    }

    #[test]
    fn merge_adds_groups_missing_from_base() {
        let gid_a = PolicyGroupId::new(1);
        let gid_b = PolicyGroupId::new(2);
        let mut base = PreTradeLock::from_entries([(gid_a, price("10"))]);
        let extra = PreTradeLock::from_entries([(gid_b, price("20")), (gid_a, price("11"))]);
        base.merge(&extra);
        // Existing groups keep their position; new groups are appended after them.
        let collected: Vec<_> = base.entries().collect();
        assert_eq!(
            collected,
            vec![
                (gid_a, price("10")),
                (gid_a, price("11")),
                (gid_b, price("20")),
            ]
        );
    }

    #[test]
    fn merge_single_entry_hot_path() {
        let mut base = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("100"))]);
        let single = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("101"))]);
        base.merge(&single);
        assert_eq!(base.len(), 2);
        let defaults: Vec<_> = base.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("100"), price("101")]);
    }

    #[test]
    fn merge_empty_into_non_empty() {
        let mut base = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("1"))]);
        base.merge(&PreTradeLock::new());
        assert_eq!(base.len(), 1);
    }

    #[test]
    fn merge_non_empty_into_empty() {
        let mut base = PreTradeLock::new();
        let other = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("42"))]);
        base.merge(&other);
        assert_eq!(base.len(), 1);
        assert_eq!(
            base.prices_of(DEFAULT_POLICY_GROUP_ID).next(),
            Some(price("42"))
        );
    }

    #[test]
    fn lock_clone_preserves_state() {
        let mut lock = PreTradeLock::from_entries([(DEFAULT_POLICY_GROUP_ID, price("185"))]);
        lock.push(PolicyGroupId::new(4), price("400"));
        lock.push(PolicyGroupId::new(4), price("401"));
        let cloned = lock.clone();
        assert_eq!(cloned, lock);
        assert_eq!(cloned.len(), 3);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn lock_implements_serde_traits() {
        fn assert_serde<Subject: serde::Serialize + serde::de::DeserializeOwned>() {}
        assert_serde::<PreTradeLock>();
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip_empty() {
        let original = PreTradeLock::new();
        let json = serde_json::to_string(&original).expect("serialize must succeed");
        assert_eq!(json, "[]", "an empty lock is encoded as an empty sequence");
        let restored: PreTradeLock = serde_json::from_str(&json).expect("deserialize must succeed");
        assert_eq!(restored, original);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip_default_only() {
        let mut original = PreTradeLock::new();
        original.push(DEFAULT_POLICY_GROUP_ID, price("100"));
        original.push(DEFAULT_POLICY_GROUP_ID, price("200"));
        let json = serde_json::to_string(&original).expect("serialize must succeed");
        let restored: PreTradeLock = serde_json::from_str(&json).expect("deserialize must succeed");
        assert_eq!(restored.len(), 2, "len must be preserved after round-trip");
        assert!(!restored.is_empty(), "must not be empty after round-trip");
        let defaults: Vec<_> = restored.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("100"), price("200")]);
        let json2 = serde_json::to_string(&restored).expect("re-serialize must succeed");
        assert_eq!(json, json2, "repeated serialize must be idempotent");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip_without_default_prices() {
        let gid = PolicyGroupId::new(8);
        let original = PreTradeLock::from_entries([(gid, price("800"))]);
        let json = serde_json::to_string(&original).expect("serialize must succeed");
        let restored: PreTradeLock = serde_json::from_str(&json).expect("deserialize must succeed");
        assert_eq!(restored, original);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_round_trip_mixed() {
        let gid = PolicyGroupId::new(3);
        let mut original = PreTradeLock::new();
        original.push(DEFAULT_POLICY_GROUP_ID, price("100"));
        original.push(DEFAULT_POLICY_GROUP_ID, price("101"));
        original.push(gid, price("300"));
        original.push(gid, price("301"));
        let json = serde_json::to_string(&original).expect("serialize must succeed");
        let restored: PreTradeLock = serde_json::from_str(&json).expect("deserialize must succeed");
        assert_eq!(restored.len(), 4);
        assert!(!restored.is_empty());
        let defaults: Vec<_> = restored.prices_of(DEFAULT_POLICY_GROUP_ID).collect();
        assert_eq!(defaults, vec![price("100"), price("101")]);
        let others: Vec<_> = restored.prices_of(gid).collect();
        assert_eq!(others, vec![price("300"), price("301")]);
        let json2 = serde_json::to_string(&restored).expect("re-serialize must succeed");
        assert_eq!(json, json2, "repeated serialize must be idempotent");
    }

    #[cfg(feature = "serde")]
    #[test]
    fn serde_rejects_tagged_default_section() {
        // The default group may only appear as the leading untagged list.
        let json = format!("[[],[{}]]", DEFAULT_POLICY_GROUP_ID.value());
        assert!(serde_json::from_str::<PreTradeLock>(&json).is_err());
    }
}