use std::collections::HashSet;
use nalgebra::{DMatrix, DVector};
use crate::error::{BlpError, Result};
/// Validated product-level inputs, with products grouped into contiguous markets.
///
/// Constructed through [`ProductData::new`] or [`ProductDataBuilder::build`]; all
/// fields are private and exposed through read-only accessors.
#[derive(Clone, Debug)]
pub struct ProductData {
    /// Market identifier of each product, in product order.
    market_ids: Vec<String>,
    /// Share of each product; every entry is validated to be strictly positive.
    shares: DVector<f64>,
    /// Linear design matrix `X1`, one row per product.
    x1: DMatrix<f64>,
    /// Nonlinear design matrix `X2`, one row per product (may have zero columns).
    x2: DMatrix<f64>,
    /// Instrument matrix `Z`, one row per product.
    instruments: DMatrix<f64>,
    /// Contiguous market segments derived from `market_ids` and `shares`.
    partition: MarketPartition,
}
impl ProductData {
    /// Builds validated product data from all five inputs in one call.
    ///
    /// Shorthand for setting every matrix on a [`ProductDataBuilder`] and calling
    /// [`ProductDataBuilder::build`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`ProductDataBuilder::build`] reports.
    pub fn new(
        market_ids: Vec<String>,
        shares: DVector<f64>,
        x1: DMatrix<f64>,
        x2: DMatrix<f64>,
        instruments: DMatrix<f64>,
    ) -> Result<Self> {
        ProductDataBuilder::new(market_ids, shares)
            .x1(x1)
            .x2(x2)
            .instruments(instruments)
            .build()
    }

    /// Number of products, i.e. the number of entries in the share vector.
    pub fn product_count(&self) -> usize {
        self.shares.nrows()
    }

    /// Number of columns of `X1`.
    pub fn linear_dim(&self) -> usize {
        self.x1().ncols()
    }

    /// Number of columns of `X2`.
    pub fn nonlinear_dim(&self) -> usize {
        self.x2().ncols()
    }

    /// Number of columns of the instrument matrix `Z`.
    pub fn instrument_dim(&self) -> usize {
        self.instruments().ncols()
    }

    /// Borrows the linear design matrix `X1`.
    pub fn x1(&self) -> &DMatrix<f64> {
        &self.x1
    }

    /// Borrows the nonlinear design matrix `X2`.
    pub fn x2(&self) -> &DMatrix<f64> {
        &self.x2
    }

    /// Borrows the instrument matrix `Z`.
    pub fn instruments(&self) -> &DMatrix<f64> {
        &self.instruments
    }

    /// Borrows the per-product share vector.
    pub fn shares(&self) -> &DVector<f64> {
        &self.shares
    }

    /// Borrows the market partition computed at construction time.
    pub fn partition(&self) -> &MarketPartition {
        &self.partition
    }

    /// Outside share of the market that contains `product_index`.
    ///
    /// # Panics
    ///
    /// Panics if `product_index` is not a valid product index.
    pub fn outside_share_for_product(&self, product_index: usize) -> f64 {
        let segment = &self.partition.markets[self.partition.market_of(product_index)];
        segment.outside_share
    }

    /// Market identifier of the product at `product_index`.
    ///
    /// # Panics
    ///
    /// Panics if `product_index` is not a valid product index.
    pub fn market_id(&self, product_index: usize) -> &str {
        self.market_ids[product_index].as_str()
    }
}
/// Incremental constructor for [`ProductData`].
///
/// Holds the mandatory market identifiers and shares plus the optional matrices;
/// all validation happens in [`ProductDataBuilder::build`].
#[derive(Debug)]
pub struct ProductDataBuilder {
    /// Market identifier of each product.
    market_ids: Vec<String>,
    /// Share of each product.
    shares: DVector<f64>,
    /// `X1`; `build` fails when it was never set.
    x1: Option<DMatrix<f64>>,
    /// `X2`; `build` substitutes an `n x 0` zero matrix when absent.
    x2: Option<DMatrix<f64>>,
    /// `Z`; `build` substitutes a copy of `X1` when absent.
    instruments: Option<DMatrix<f64>>,
}
impl ProductDataBuilder {
    /// Starts a builder from per-product market identifiers and shares.
    ///
    /// `X1` must be supplied through [`ProductDataBuilder::x1`] before
    /// [`ProductDataBuilder::build`] succeeds; `X2` and the instruments are optional.
    pub fn new(market_ids: Vec<String>, shares: DVector<f64>) -> Self {
        Self {
            market_ids,
            shares,
            x1: None,
            x2: None,
            instruments: None,
        }
    }

    /// Sets the linear design matrix `X1` (one row per product).
    #[must_use]
    pub fn x1(mut self, matrix: DMatrix<f64>) -> Self {
        self.x1 = Some(matrix);
        self
    }

    /// Sets the nonlinear design matrix `X2`; defaults to an `n x 0` matrix.
    #[must_use]
    pub fn x2(mut self, matrix: DMatrix<f64>) -> Self {
        self.x2 = Some(matrix);
        self
    }

    /// Sets the instrument matrix `Z`; defaults to a copy of `X1`.
    #[must_use]
    pub fn instruments(mut self, matrix: DMatrix<f64>) -> Self {
        self.instruments = Some(matrix);
        self
    }

    /// Validates every input and assembles a [`ProductData`].
    ///
    /// # Errors
    ///
    /// - a dimension-mismatch error when `shares` and `market_ids` differ in length,
    ///   when `X1` was never set, or when any matrix's row count differs from the
    ///   number of products;
    /// - [`BlpError::NonPositiveShare`] for the first share that is `<= 0`;
    /// - [`BlpError::NumericalError`] when any entry of `X1`, `X2` or `Z` is NaN or
    ///   infinite, or when a share is non-finite;
    /// - [`BlpError::NonContiguousMarket`] / [`BlpError::NonPositiveOutsideShare`]
    ///   from the market partition.
    pub fn build(self) -> Result<ProductData> {
        let n = self.market_ids.len();
        if self.shares.len() != n {
            return Err(BlpError::dimension_mismatch(
                "shares length",
                n,
                self.shares.len(),
            ));
        }
        // NaN compares false here; it is rejected by the partition's finiteness check.
        if let Some((index, &share)) = self
            .shares
            .iter()
            .enumerate()
            .find(|&(_, &s)| s <= 0.0)
        {
            return Err(BlpError::NonPositiveShare { index, share });
        }

        let x1 = self
            .x1
            .ok_or_else(|| BlpError::dimension_mismatch("X1", n, 0))?;
        Self::check_rows("X1 rows", n, &x1)?;
        let x2 = self.x2.unwrap_or_else(|| DMatrix::zeros(n, 0));
        Self::check_rows("X2 rows", n, &x2)?;
        // Without explicit instruments, X1 doubles as the instrument set.
        let instruments = self.instruments.unwrap_or_else(|| x1.clone());
        Self::check_rows("Z rows", n, &instruments)?;

        // Shares are already checked for finiteness; the design matrices must be too,
        // otherwise NaN/inf would silently poison every downstream computation.
        Self::ensure_finite(&x1, "X1 validation")?;
        Self::ensure_finite(&x2, "X2 validation")?;
        Self::ensure_finite(&instruments, "Z validation")?;

        let partition = MarketPartition::new(&self.market_ids, &self.shares)?;
        Ok(ProductData {
            market_ids: self.market_ids,
            shares: self.shares,
            x1,
            x2,
            instruments,
            partition,
        })
    }

    /// Fails with a dimension mismatch when `matrix` does not have exactly `expected` rows.
    fn check_rows(label: &'static str, expected: usize, matrix: &DMatrix<f64>) -> Result<()> {
        if matrix.nrows() == expected {
            Ok(())
        } else {
            Err(BlpError::dimension_mismatch(label, expected, matrix.nrows()))
        }
    }

    /// Fails with [`BlpError::NumericalError`] when any entry of `matrix` is NaN or infinite.
    fn ensure_finite(matrix: &DMatrix<f64>, context: &'static str) -> Result<()> {
        if matrix.iter().all(|value| value.is_finite()) {
            Ok(())
        } else {
            Err(BlpError::NumericalError { context })
        }
    }
}
/// Products grouped into consecutive per-market segments.
#[derive(Clone, Debug)]
pub struct MarketPartition {
    /// One segment per market, in the order the markets first appear.
    markets: Vec<MarketSegment>,
    /// For each product index, the index into `markets` of the segment holding it.
    product_to_market: Vec<usize>,
}
impl MarketPartition {
    /// Splits products into contiguous market segments and computes each outside share.
    ///
    /// All products of a market must form one consecutive run in `market_ids`.
    ///
    /// # Errors
    ///
    /// - [`BlpError::NonContiguousMarket`] when a market id starts a second run;
    /// - [`BlpError::NumericalError`] when a share is NaN or infinite;
    /// - [`BlpError::NonPositiveOutsideShare`] when one minus a market's summed
    ///   shares is not strictly positive.
    ///
    /// # Panics
    ///
    /// Panics if `shares` has fewer entries than `market_ids`.
    fn new(market_ids: &[String], shares: &DVector<f64>) -> Result<Self> {
        let product_total = market_ids.len();
        let mut segments: Vec<MarketSegment> = Vec::new();
        let mut owner_of = vec![0usize; product_total];
        let mut visited: HashSet<String> = HashSet::new();
        let mut cursor = 0usize;

        while cursor < product_total {
            let id = &market_ids[cursor];
            // An id that already opened a run before means its products are split apart.
            if !visited.insert(id.clone()) {
                return Err(BlpError::NonContiguousMarket {
                    market_id: id.clone(),
                });
            }

            // The run ends at the first product carrying a different id, or at the end.
            let stop = market_ids[cursor..]
                .iter()
                .position(|other| other != id)
                .map_or(product_total, |offset| cursor + offset);

            let segment_index = segments.len();
            owner_of[cursor..stop].fill(segment_index);

            let mut inside_share = 0.0f64;
            for product in cursor..stop {
                let share = shares[product];
                if !share.is_finite() {
                    return Err(BlpError::NumericalError {
                        context: "share validation",
                    });
                }
                inside_share += share;
            }

            let outside_share = 1.0 - inside_share;
            if outside_share <= 0.0 {
                return Err(BlpError::NonPositiveOutsideShare {
                    market_id: id.clone(),
                    share: outside_share,
                });
            }

            segments.push(MarketSegment {
                market_id: id.clone(),
                start: cursor,
                end: stop,
                outside_share,
            });
            cursor = stop;
        }

        Ok(Self {
            markets: segments,
            product_to_market: owner_of,
        })
    }

    /// Number of distinct markets.
    pub fn market_count(&self) -> usize {
        self.markets.len()
    }

    /// Iterates over the market segments in product order.
    pub fn markets(&self) -> impl Iterator<Item = &MarketSegment> {
        self.markets.as_slice().iter()
    }

    /// Index of the segment containing `product_index`.
    ///
    /// # Panics
    ///
    /// Panics if `product_index` is not a valid product index.
    pub fn market_of(&self, product_index: usize) -> usize {
        let owners = &self.product_to_market;
        owners[product_index]
    }
}
/// One market's contiguous run of products.
#[derive(Clone, Debug)]
pub struct MarketSegment {
    /// Identifier shared by every product in the run.
    market_id: String,
    /// Index of the first product in the run.
    pub(crate) start: usize,
    /// One past the index of the last product in the run.
    pub(crate) end: usize,
    /// One minus the summed shares of the run's products.
    pub outside_share: f64,
}

impl MarketSegment {
    /// Returns the market identifier.
    pub fn id(&self) -> &str {
        self.market_id.as_str()
    }

    /// Returns the half-open range `start..end` of product indices in this market.
    pub fn range(&self) -> std::ops::Range<usize> {
        std::ops::Range {
            start: self.start,
            end: self.end,
        }
    }

    /// Number of products in this market.
    pub fn product_count(&self) -> usize {
        let Self { start, end, .. } = self;
        end - start
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string literals into owned market identifiers.
    fn ids(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|&id| id.to_owned()).collect()
    }

    #[test]
    fn builder_validates_and_constructs_partition() {
        let market_ids = vec!["m1".to_string(), "m1".to_string(), "m2".to_string()];
        let shares = DVector::from_vec(vec![0.3, 0.2, 0.4]);
        let x1 = DMatrix::from_row_slice(3, 2, &[1.0, 10.0, 1.0, 11.0, 1.0, 12.0]);
        let x2 = DMatrix::from_row_slice(3, 1, &[10.0, 11.0, 12.0]);
        let instruments = x1.clone();
        let data = ProductDataBuilder::new(market_ids, shares)
            .x1(x1)
            .x2(x2)
            .instruments(instruments)
            .build()
            .expect("valid data");
        assert_eq!(data.product_count(), 3);
        assert_eq!(data.partition.market_count(), 2);
        let mut iter = data.partition.markets();
        let first = iter.next().unwrap();
        assert_eq!(first.id(), "m1");
        assert_eq!(first.product_count(), 2);
        assert!(iter.next().is_some());
    }

    #[test]
    fn builder_detects_non_contiguous_market() {
        let market_ids = vec!["m1".to_string(), "m2".to_string(), "m1".to_string()];
        let shares = DVector::from_vec(vec![0.3, 0.3, 0.3]);
        let x1 = DMatrix::from_row_slice(3, 1, &[10.0, 11.0, 12.0]);
        let result = ProductDataBuilder::new(market_ids, shares).x1(x1).build();
        assert!(matches!(result, Err(BlpError::NonContiguousMarket { .. })));
    }

    #[test]
    fn builder_rejects_non_positive_share() {
        let shares = DVector::from_vec(vec![0.3, 0.0, 0.4]);
        let x1 = DMatrix::from_row_slice(3, 1, &[1.0, 2.0, 3.0]);
        let result = ProductDataBuilder::new(ids(&["m1", "m1", "m2"]), shares)
            .x1(x1)
            .build();
        assert!(matches!(
            result,
            Err(BlpError::NonPositiveShare { index: 1, .. })
        ));
    }

    #[test]
    fn builder_rejects_non_finite_share() {
        let shares = DVector::from_vec(vec![0.3, f64::NAN]);
        let x1 = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        let result = ProductDataBuilder::new(ids(&["m1", "m1"]), shares)
            .x1(x1)
            .build();
        assert!(matches!(result, Err(BlpError::NumericalError { .. })));
    }

    #[test]
    fn builder_rejects_non_positive_outside_share() {
        // 0.6 + 0.5 exceeds one, leaving a negative outside share in m1.
        let shares = DVector::from_vec(vec![0.6, 0.5]);
        let x1 = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        let result = ProductDataBuilder::new(ids(&["m1", "m1"]), shares)
            .x1(x1)
            .build();
        assert!(matches!(
            &result,
            Err(BlpError::NonPositiveOutsideShare { market_id, .. }) if market_id == "m1"
        ));
    }

    #[test]
    fn builder_requires_x1() {
        let shares = DVector::from_vec(vec![0.3]);
        let result = ProductDataBuilder::new(ids(&["m1"]), shares).build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_rejects_row_count_mismatch() {
        let shares = DVector::from_vec(vec![0.3, 0.2, 0.4]);
        let x1 = DMatrix::from_row_slice(3, 1, &[1.0, 2.0, 3.0]);
        let x2 = DMatrix::from_row_slice(2, 1, &[1.0, 2.0]);
        let result = ProductDataBuilder::new(ids(&["m1", "m1", "m2"]), shares)
            .x1(x1)
            .x2(x2)
            .build();
        assert!(result.is_err());
    }

    #[test]
    fn builder_defaults_optional_matrices_and_reports_outside_shares() {
        let shares = DVector::from_vec(vec![0.3, 0.2, 0.4]);
        let x1 = DMatrix::from_row_slice(3, 2, &[1.0, 10.0, 1.0, 11.0, 1.0, 12.0]);
        let data = ProductDataBuilder::new(ids(&["m1", "m1", "m2"]), shares)
            .x1(x1)
            .build()
            .expect("valid data");
        assert_eq!(data.nonlinear_dim(), 0);
        assert_eq!(data.x2().nrows(), 3);
        assert_eq!(data.instruments(), data.x1());
        assert_eq!(data.instrument_dim(), data.linear_dim());
        assert_eq!(data.partition().market_of(1), 0);
        assert_eq!(data.partition().market_of(2), 1);
        assert_eq!(data.market_id(2), "m2");
        assert!((data.outside_share_for_product(0) - 0.5).abs() < 1e-12);
        assert!((data.outside_share_for_product(2) - 0.6).abs() < 1e-12);
    }
}