use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use rand_distr::{Distribution, LogNormal, Normal};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
/// One weighted normal component of a [`GaussianMixtureConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianComponent {
    /// Mixing weight; `GaussianMixtureConfig::validate` requires it to lie in `[0.0, 1.0]`
    /// and all weights to sum to approximately 1.0.
    pub weight: f64,
    /// Mean of the normal distribution (first argument to `Normal::new`).
    pub mu: f64,
    /// Standard deviation of the normal distribution; `validate` requires it to be positive.
    pub sigma: f64,
    /// Optional human-readable name, reported back by `sample_with_component`.
    #[serde(default)]
    pub label: Option<String>,
}
impl GaussianComponent {
    /// Builds an unlabelled component from its mixing weight, mean and
    /// standard deviation.
    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
        Self {
            label: None,
            weight,
            mu,
            sigma,
        }
    }

    /// Builds a component that carries a human-readable `label`; otherwise
    /// identical to [`GaussianComponent::new`].
    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..Self::new(weight, mu, sigma)
        }
    }
}
/// Configuration for a [`GaussianMixtureSampler`]: the mixture components plus
/// optional post-processing applied to every sampled value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GaussianMixtureConfig {
    /// Mixture components; their weights must sum to 1.0 within 0.01 (see `validate`).
    pub components: Vec<GaussianComponent>,
    /// When `false`, the sampler replaces each draw with its absolute value.
    /// Defaults to `true` when the field is absent from serialized input.
    #[serde(default = "default_true")]
    pub allow_negative: bool,
    /// Optional lower bound; the sampler applies it after the absolute-value step.
    #[serde(default)]
    pub min_value: Option<f64>,
    /// Optional upper bound; the sampler applies it after `min_value`.
    #[serde(default)]
    pub max_value: Option<f64>,
}
/// Serde default for boolean fields that are enabled unless the input
/// explicitly turns them off.
fn default_true() -> bool {
    true
}
impl Default for GaussianMixtureConfig {
    /// A single standard-normal component (mean 0, standard deviation 1) with
    /// negative values allowed and no clamping bounds.
    fn default() -> Self {
        let standard_normal = GaussianComponent::new(1.0, 0.0, 1.0);
        Self {
            allow_negative: true,
            min_value: None,
            max_value: None,
            components: vec![standard_normal],
        }
    }
}
impl GaussianMixtureConfig {
    /// Creates a configuration from `components`, taking every other setting
    /// from [`GaussianMixtureConfig::default`].
    pub fn new(components: Vec<GaussianComponent>) -> Self {
        Self {
            components,
            ..Default::default()
        }
    }

    /// Checks that the configuration describes a usable mixture.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - there are no components;
    /// - the weights do not sum to 1.0 within a tolerance of 0.01;
    /// - a weight is outside `[0.0, 1.0]` or is NaN;
    /// - a mean is not finite;
    /// - a standard deviation is not a finite positive number (NaN included);
    /// - both bounds are set and `min_value > max_value`.
    pub fn validate(&self) -> Result<(), String> {
        if self.components.is_empty() {
            return Err("At least one component is required".to_string());
        }
        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
        // NaN makes this comparison false; NaN weights are caught per component below.
        if (weight_sum - 1.0).abs() > 0.01 {
            return Err(format!(
                "Component weights must sum to 1.0, got {weight_sum}"
            ));
        }
        for (i, component) in self.components.iter().enumerate() {
            // `contains` is false for NaN, unlike a `< 0.0 || > 1.0` pair.
            if !(0.0..=1.0).contains(&component.weight) {
                return Err(format!(
                    "Component {} weight must be between 0.0 and 1.0, got {}",
                    i, component.weight
                ));
            }
            if !component.mu.is_finite() {
                return Err(format!(
                    "Component {} mu must be finite, got {}",
                    i, component.mu
                ));
            }
            // `NaN <= 0.0` is false, so NaN needs an explicit check.
            if component.sigma <= 0.0 || component.sigma.is_nan() {
                return Err(format!(
                    "Component {} sigma must be positive, got {}",
                    i, component.sigma
                ));
            }
            if component.sigma.is_infinite() {
                return Err(format!(
                    "Component {} sigma must be finite, got {}",
                    i, component.sigma
                ));
            }
        }
        // Inverted bounds would silently pin every sample to `max_value`.
        if let (Some(min), Some(max)) = (self.min_value, self.max_value) {
            if min > max {
                return Err(format!(
                    "min_value ({min}) must not exceed max_value ({max})"
                ));
            }
        }
        Ok(())
    }
}
/// One weighted log-normal component of a [`LogNormalMixtureConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalComponent {
    /// Mixing weight; `LogNormalMixtureConfig::validate` requires it to lie in `[0.0, 1.0]`
    /// and all weights to sum to approximately 1.0.
    pub weight: f64,
    /// Log-space location: the mean of the variable's natural logarithm
    /// (so the distribution's median is `exp(mu)`).
    pub mu: f64,
    /// Log-space scale: the standard deviation of the variable's natural
    /// logarithm; `validate` requires it to be positive.
    pub sigma: f64,
    /// Optional human-readable name, reported back by `sample_with_component`.
    #[serde(default)]
    pub label: Option<String>,
}
impl LogNormalComponent {
    /// Builds an unlabelled component from its mixing weight and log-space
    /// parameters `mu` / `sigma`.
    pub fn new(weight: f64, mu: f64, sigma: f64) -> Self {
        Self {
            label: None,
            weight,
            mu,
            sigma,
        }
    }

    /// Builds a component that carries a human-readable `label`; otherwise
    /// identical to [`LogNormalComponent::new`].
    pub fn with_label(weight: f64, mu: f64, sigma: f64, label: impl Into<String>) -> Self {
        Self {
            label: Some(label.into()),
            ..Self::new(weight, mu, sigma)
        }
    }

    /// Mean of this log-normal distribution: `exp(mu + sigma^2 / 2)`.
    pub fn expected_value(&self) -> f64 {
        let half_variance = self.sigma.powi(2) / 2.0;
        (self.mu + half_variance).exp()
    }

    /// Median of this log-normal distribution: `exp(mu)`.
    pub fn median(&self) -> f64 {
        f64::exp(self.mu)
    }
}
/// Configuration for a [`LogNormalMixtureSampler`]: the mixture components plus
/// clamping bounds and the number of decimal places samples are rounded to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogNormalMixtureConfig {
    /// Mixture components; their weights must sum to 1.0 within 0.01 (see `validate`).
    pub components: Vec<LogNormalComponent>,
    /// Lower bound for sampled values, applied before rounding; must be
    /// non-negative. Uses `default_min_value` when absent from serialized input.
    #[serde(default = "default_min_value")]
    pub min_value: f64,
    /// Optional upper bound for sampled values, applied before rounding.
    #[serde(default)]
    pub max_value: Option<f64>,
    /// Number of decimal places samples are rounded to. Uses
    /// `default_decimal_places` when absent from serialized input.
    #[serde(default = "default_decimal_places")]
    pub decimal_places: u8,
}
/// Serde default for `LogNormalMixtureConfig::min_value`: one hundredth.
fn default_min_value() -> f64 {
    1e-2
}
/// Serde default for `LogNormalMixtureConfig::decimal_places`: two places.
fn default_decimal_places() -> u8 {
    2_u8
}
impl Default for LogNormalMixtureConfig {
    /// A single log-normal component (log-space mean 7.0, log-space standard
    /// deviation 2.0), a 0.01 floor, no ceiling, and 2 decimal places.
    fn default() -> Self {
        let component = LogNormalComponent::new(1.0, 7.0, 2.0);
        Self {
            decimal_places: 2,
            max_value: None,
            min_value: 0.01,
            components: vec![component],
        }
    }
}
impl LogNormalMixtureConfig {
    /// Creates a configuration from `components`, taking `min_value`,
    /// `max_value` and `decimal_places` from [`LogNormalMixtureConfig::default`].
    pub fn new(components: Vec<LogNormalComponent>) -> Self {
        Self {
            components,
            ..Default::default()
        }
    }

    /// Preset with three labelled components ("routine", "significant",
    /// "major"; weights 0.60 / 0.30 / 0.10), a 0.01 floor, a 100,000,000
    /// ceiling and 2 decimal places.
    pub fn typical_transactions() -> Self {
        Self {
            components: vec![
                LogNormalComponent::with_label(0.60, 6.0, 1.5, "routine"),
                LogNormalComponent::with_label(0.30, 8.5, 1.0, "significant"),
                LogNormalComponent::with_label(0.10, 11.0, 0.8, "major"),
            ],
            min_value: 0.01,
            max_value: Some(100_000_000.0),
            decimal_places: 2,
        }
    }

    /// Checks that the configuration describes a usable mixture.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when:
    /// - there are no components;
    /// - the weights do not sum to 1.0 within a tolerance of 0.01;
    /// - a weight is outside `[0.0, 1.0]` or is NaN;
    /// - a log-space mean `mu` is not finite;
    /// - a log-space standard deviation `sigma` is not a finite positive number;
    /// - `min_value` is negative, NaN or infinite;
    /// - `max_value` is NaN or smaller than `min_value`.
    pub fn validate(&self) -> Result<(), String> {
        if self.components.is_empty() {
            return Err("At least one component is required".to_string());
        }
        let weight_sum: f64 = self.components.iter().map(|c| c.weight).sum();
        // NaN makes this comparison false; NaN weights are caught per component below.
        if (weight_sum - 1.0).abs() > 0.01 {
            return Err(format!(
                "Component weights must sum to 1.0, got {weight_sum}"
            ));
        }
        for (i, component) in self.components.iter().enumerate() {
            // `contains` is false for NaN, unlike a `< 0.0 || > 1.0` pair.
            if !(0.0..=1.0).contains(&component.weight) {
                return Err(format!(
                    "Component {} weight must be between 0.0 and 1.0, got {}",
                    i, component.weight
                ));
            }
            if !component.mu.is_finite() {
                return Err(format!(
                    "Component {} mu must be finite, got {}",
                    i, component.mu
                ));
            }
            // `NaN <= 0.0` is false, so NaN needs an explicit check.
            if component.sigma <= 0.0 || component.sigma.is_nan() {
                return Err(format!(
                    "Component {} sigma must be positive, got {}",
                    i, component.sigma
                ));
            }
            if component.sigma.is_infinite() {
                return Err(format!(
                    "Component {} sigma must be finite, got {}",
                    i, component.sigma
                ));
            }
        }
        if self.min_value < 0.0 {
            return Err("min_value must be non-negative".to_string());
        }
        // Catches NaN (which slips past `< 0.0`) and +inf.
        if !self.min_value.is_finite() {
            return Err(format!("min_value must be finite, got {}", self.min_value));
        }
        // A ceiling below the floor would pin every sample to `max_value`.
        if let Some(max) = self.max_value {
            if max.is_nan() || max < self.min_value {
                return Err(format!(
                    "max_value ({max}) must be at least min_value ({})",
                    self.min_value
                ));
            }
        }
        Ok(())
    }
}
/// A sampled value together with the mixture component that produced it.
#[derive(Debug, Clone)]
pub struct SampleWithComponent {
    /// The sampled value after the sampler's post-processing (clamping, and
    /// rounding for the log-normal sampler).
    pub value: f64,
    /// Index into the config's `components` of the component that was drawn.
    pub component_index: usize,
    /// A clone of that component's `label`, if it has one.
    pub component_label: Option<String>,
}
/// Seeded sampler for a [`GaussianMixtureConfig`]; two samplers built from the
/// same seed and config produce the same sequence.
pub struct GaussianMixtureSampler {
    /// RNG driving both component selection and the per-component draws.
    rng: ChaCha8Rng,
    /// The validated configuration this sampler was built from.
    config: GaussianMixtureConfig,
    /// Cumulative component weights, searched with a uniform draw to pick a component.
    cumulative_weights: Vec<f64>,
    /// One normal distribution per component, in `config.components` order.
    distributions: Vec<Normal<f64>>,
}
impl GaussianMixtureSampler {
    /// Creates a sampler whose RNG is seeded with `seed`.
    ///
    /// Component weights are divided by their sum before use. `validate`
    /// accepts weight sums within 0.01 of 1.0, and without normalisation any
    /// shortfall (or excess) would be silently given to (or taken from) the
    /// last component.
    ///
    /// # Errors
    ///
    /// Returns the message from [`GaussianMixtureConfig::validate`], an error
    /// if the weights do not sum to a finite positive number, or an error if
    /// `Normal::new` rejects a component's `mu`/`sigma`.
    pub fn new(seed: u64, config: GaussianMixtureConfig) -> Result<Self, String> {
        config.validate()?;
        let total_weight: f64 = config.components.iter().map(|c| c.weight).sum();
        if !(total_weight.is_finite() && total_weight > 0.0) {
            return Err(format!(
                "Component weights must sum to a finite positive value, got {total_weight}"
            ));
        }
        // Normalised running sums; the final entry is the total divided by itself (1.0).
        let cumulative_weights: Vec<f64> = config
            .components
            .iter()
            .scan(0.0, |running, c| {
                *running += c.weight;
                Some(*running / total_weight)
            })
            .collect();
        let distributions = config
            .components
            .iter()
            .map(|c| {
                Normal::new(c.mu, c.sigma).map_err(|e| format!("Invalid normal distribution: {e}"))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            config,
            cumulative_weights,
            distributions,
        })
    }

    /// Draws a uniform `p` in `[0, 1)` and returns the first component whose
    /// cumulative weight is strictly greater than `p`.
    ///
    /// The strict comparison means a zero-weight component (whose cumulative
    /// weight equals its predecessor's) is never chosen, even when `p` lands
    /// exactly on a boundary.
    fn select_component(&mut self) -> usize {
        let p: f64 = self.rng.random();
        let idx = self.cumulative_weights.partition_point(|&w| w <= p);
        // Only reachable if rounding left the last cumulative weight <= p;
        // the last component absorbs that sliver.
        idx.min(self.distributions.len() - 1)
    }

    /// Selects a component, samples it, then applies `allow_negative`,
    /// `min_value` and `max_value` in that order.
    ///
    /// Returns `(component_index, value)`.
    fn draw_clamped(&mut self) -> (usize, f64) {
        let component_idx = self.select_component();
        let mut value = self.distributions[component_idx].sample(&mut self.rng);
        if !self.config.allow_negative {
            value = value.abs();
        }
        if let Some(min) = self.config.min_value {
            value = value.max(min);
        }
        if let Some(max) = self.config.max_value {
            value = value.min(max);
        }
        (component_idx, value)
    }

    /// Draws one post-processed value from the mixture.
    pub fn sample(&mut self) -> f64 {
        let (_, value) = self.draw_clamped();
        value
    }

    /// Draws one post-processed value and reports which component produced it.
    pub fn sample_with_component(&mut self) -> SampleWithComponent {
        let (component_index, value) = self.draw_clamped();
        SampleWithComponent {
            value,
            component_index,
            component_label: self.config.components[component_index].label.clone(),
        }
    }

    /// Draws `n` values via [`Self::sample`].
    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
        (0..n).map(|_| self.sample()).collect()
    }

    /// Re-seeds the RNG; subsequent draws repeat those of a fresh sampler
    /// built with the same `seed` and config.
    pub fn reset(&mut self, seed: u64) {
        self.rng = ChaCha8Rng::seed_from_u64(seed);
    }

    /// The configuration this sampler was built from.
    pub fn config(&self) -> &GaussianMixtureConfig {
        &self.config
    }
}
/// Seeded sampler for a [`LogNormalMixtureConfig`]; two samplers built from the
/// same seed and config produce the same sequence.
pub struct LogNormalMixtureSampler {
    /// RNG driving both component selection and the per-component draws.
    rng: ChaCha8Rng,
    /// The validated configuration this sampler was built from.
    config: LogNormalMixtureConfig,
    /// Cumulative component weights, searched with a uniform draw to pick a component.
    cumulative_weights: Vec<f64>,
    /// One log-normal distribution per component, in `config.components` order.
    distributions: Vec<LogNormal<f64>>,
    /// `10^decimal_places`, used to round samples to `config.decimal_places`.
    decimal_multiplier: f64,
}
impl LogNormalMixtureSampler {
    /// Creates a sampler whose RNG is seeded with `seed`.
    ///
    /// Component weights are divided by their sum before use. `validate`
    /// accepts weight sums within 0.01 of 1.0, and without normalisation any
    /// shortfall (or excess) would be silently given to (or taken from) the
    /// last component.
    ///
    /// # Errors
    ///
    /// Returns the message from [`LogNormalMixtureConfig::validate`], an error
    /// if the weights do not sum to a finite positive number, or an error if
    /// `LogNormal::new` rejects a component's `mu`/`sigma`.
    pub fn new(seed: u64, config: LogNormalMixtureConfig) -> Result<Self, String> {
        config.validate()?;
        let total_weight: f64 = config.components.iter().map(|c| c.weight).sum();
        if !(total_weight.is_finite() && total_weight > 0.0) {
            return Err(format!(
                "Component weights must sum to a finite positive value, got {total_weight}"
            ));
        }
        // Normalised running sums; the final entry is the total divided by itself (1.0).
        let cumulative_weights: Vec<f64> = config
            .components
            .iter()
            .scan(0.0, |running, c| {
                *running += c.weight;
                Some(*running / total_weight)
            })
            .collect();
        let distributions = config
            .components
            .iter()
            .map(|c| {
                LogNormal::new(c.mu, c.sigma)
                    .map_err(|e| format!("Invalid log-normal distribution: {e}"))
            })
            .collect::<Result<Vec<_>, _>>()?;
        let decimal_multiplier = 10_f64.powi(i32::from(config.decimal_places));
        Ok(Self {
            rng: ChaCha8Rng::seed_from_u64(seed),
            config,
            cumulative_weights,
            distributions,
            decimal_multiplier,
        })
    }

    /// Draws a uniform `p` in `[0, 1)` and returns the first component whose
    /// cumulative weight is strictly greater than `p`.
    ///
    /// The strict comparison means a zero-weight component (whose cumulative
    /// weight equals its predecessor's) is never chosen, even when `p` lands
    /// exactly on a boundary.
    fn select_component(&mut self) -> usize {
        let p: f64 = self.rng.random();
        let idx = self.cumulative_weights.partition_point(|&w| w <= p);
        // Only reachable if rounding left the last cumulative weight <= p;
        // the last component absorbs that sliver.
        idx.min(self.distributions.len() - 1)
    }

    /// Rounds an already-clamped `value` to `decimal_places`, keeping the
    /// result inside `[min_value, max_value]` where the decimal grid allows.
    fn round_to_places(&self, value: f64) -> f64 {
        let scale = self.decimal_multiplier;
        let scaled = value * scale;
        if !scaled.is_finite() {
            // Huge value times a large multiplier overflowed; `round() / scale`
            // would turn this into `inf`, so keep the clamped value instead.
            return value;
        }
        let mut rounded = scaled.round() / scale;
        // Rounding can step outside a bound that is not itself on the grid
        // (e.g. min_value 0.01 with decimal_places = 0 would round to 0.0).
        // Snap to the nearest grid point inside the bound instead. If no grid
        // point lies within [min_value, max_value], max_value wins.
        if rounded < self.config.min_value {
            rounded = (self.config.min_value * scale).ceil() / scale;
        }
        if let Some(max) = self.config.max_value {
            if rounded > max {
                rounded = (max * scale).floor() / scale;
            }
        }
        rounded
    }

    /// Selects a component, samples it, clamps to `[min_value, max_value]`
    /// and rounds to `decimal_places`.
    ///
    /// Returns `(component_index, value)`.
    fn draw_rounded(&mut self) -> (usize, f64) {
        let component_idx = self.select_component();
        let raw = self.distributions[component_idx].sample(&mut self.rng);
        let mut value = raw.max(self.config.min_value);
        if let Some(max) = self.config.max_value {
            value = value.min(max);
        }
        (component_idx, self.round_to_places(value))
    }

    /// Draws one clamped, rounded value from the mixture.
    pub fn sample(&mut self) -> f64 {
        let (_, value) = self.draw_rounded();
        value
    }

    /// Draws one value (as [`Self::sample`]) and converts it to a [`Decimal`]
    /// with at most `decimal_places` fractional digits.
    ///
    /// `Decimal::from_f64_retain` keeps the float's full binary expansion
    /// (0.1 becomes 0.1000000000000000055…), so the result is rounded back
    /// to `decimal_places`. If the value cannot be represented as a `Decimal`,
    /// `Decimal::ONE` is returned.
    pub fn sample_decimal(&mut self) -> Decimal {
        let value = self.sample();
        Decimal::from_f64_retain(value)
            .map(|d| d.round_dp(u32::from(self.config.decimal_places)))
            .unwrap_or(Decimal::ONE)
    }

    /// Draws one clamped, rounded value and reports which component produced it.
    pub fn sample_with_component(&mut self) -> SampleWithComponent {
        let (component_index, value) = self.draw_rounded();
        SampleWithComponent {
            value,
            component_index,
            component_label: self.config.components[component_index].label.clone(),
        }
    }

    /// Draws `n` values via [`Self::sample`].
    pub fn sample_n(&mut self, n: usize) -> Vec<f64> {
        (0..n).map(|_| self.sample()).collect()
    }

    /// Draws `n` values via [`Self::sample_decimal`].
    pub fn sample_n_decimal(&mut self, n: usize) -> Vec<Decimal> {
        (0..n).map(|_| self.sample_decimal()).collect()
    }

    /// Re-seeds the RNG; subsequent draws repeat those of a fresh sampler
    /// built with the same `seed` and config.
    pub fn reset(&mut self, seed: u64) {
        self.rng = ChaCha8Rng::seed_from_u64(seed);
    }

    /// The configuration this sampler was built from.
    pub fn config(&self) -> &LogNormalMixtureConfig {
        &self.config
    }

    /// Analytic mean of the unclamped, unrounded mixture: the weight-normalised
    /// sum of each component's `exp(mu + sigma^2 / 2)`.
    ///
    /// `min_value`, `max_value` and rounding are not reflected, so the
    /// empirical mean of [`Self::sample`] can differ, most noticeably when
    /// `max_value` truncates the heavy right tail.
    pub fn expected_value(&self) -> f64 {
        let (weighted, total) = self
            .config
            .components
            .iter()
            .fold((0.0, 0.0), |(acc, tot), c| {
                (acc + c.weight * c.expected_value(), tot + c.weight)
            });
        weighted / total
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Two equally weighted, unit-variance normal modes centred at 0 and 10.
    fn bimodal_gaussian() -> GaussianMixtureConfig {
        GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 10.0, 1.0),
        ])
    }

    #[test]
    fn test_gaussian_mixture_validation() {
        let valid = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.5, 0.0, 1.0),
            GaussianComponent::new(0.5, 5.0, 2.0),
        ]);
        assert!(valid.validate().is_ok());

        // Weights sum to 0.6, far outside the tolerance.
        let light_weights = GaussianMixtureConfig::new(vec![
            GaussianComponent::new(0.3, 0.0, 1.0),
            GaussianComponent::new(0.3, 5.0, 2.0),
        ]);
        assert!(light_weights.validate().is_err());

        let negative_sigma =
            GaussianMixtureConfig::new(vec![GaussianComponent::new(1.0, 0.0, -1.0)]);
        assert!(negative_sigma.validate().is_err());
    }

    #[test]
    fn test_gaussian_mixture_sampling() {
        let mut sampler = GaussianMixtureSampler::new(42, bimodal_gaussian()).unwrap();
        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);

        // 5.0 sits halfway between the two modes.
        let (low, high): (Vec<f64>, Vec<f64>) = samples.into_iter().partition(|&x| x < 5.0);
        assert!((351..650).contains(&low.len()));
        assert!((351..650).contains(&high.len()));
    }

    #[test]
    fn test_gaussian_mixture_determinism() {
        let config = bimodal_gaussian();
        let mut first = GaussianMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = GaussianMixtureSampler::new(42, config).unwrap();
        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_mixture_validation() {
        let valid = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.6, 6.0, 1.5),
            LogNormalComponent::new(0.4, 8.5, 1.0),
        ]);
        assert!(valid.validate().is_ok());

        // Weights sum to 0.4.
        let light_weights = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::new(0.2, 6.0, 1.5),
            LogNormalComponent::new(0.2, 8.5, 1.0),
        ]);
        assert!(light_weights.validate().is_err());
    }

    #[test]
    fn test_lognormal_mixture_sampling() {
        let mut sampler =
            LogNormalMixtureSampler::new(42, LogNormalMixtureConfig::typical_transactions())
                .unwrap();
        let samples = sampler.sample_n(1000);
        assert_eq!(samples.len(), 1000);
        for &x in &samples {
            assert!(x > 0.0);
            assert!(x >= 0.01);
        }
    }

    #[test]
    fn test_sample_with_component() {
        let config = LogNormalMixtureConfig::new(vec![
            LogNormalComponent::with_label(0.5, 6.0, 1.0, "small"),
            LogNormalComponent::with_label(0.5, 10.0, 0.5, "large"),
        ]);
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();

        let (mut small, mut large) = (0_u32, 0_u32);
        for _ in 0..1000 {
            let drawn = sampler.sample_with_component();
            match drawn.component_label.as_deref() {
                Some("small") => small += 1,
                Some("large") => large += 1,
                _ => panic!("Unexpected label"),
            }
        }
        assert!((401..600).contains(&small));
        assert!((401..600).contains(&large));
    }

    #[test]
    fn test_lognormal_mixture_determinism() {
        let config = LogNormalMixtureConfig::typical_transactions();
        let mut first = LogNormalMixtureSampler::new(42, config.clone()).unwrap();
        let mut second = LogNormalMixtureSampler::new(42, config).unwrap();
        (0..100).for_each(|_| assert_eq!(first.sample(), second.sample()));
    }

    #[test]
    fn test_lognormal_expected_value() {
        let config = LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 7.0, 1.0)]);
        let expected = LogNormalMixtureSampler::new(42, config)
            .unwrap()
            .expected_value();
        // exp(7 + 1/2) is roughly 1808.04.
        assert!((expected - 1808.04).abs() < 1.0);
    }

    #[test]
    fn test_component_label() {
        let labelled = LogNormalComponent::with_label(0.5, 7.0, 1.0, "test_label");
        assert_eq!(labelled.label.as_deref(), Some("test_label"));

        let unlabelled = LogNormalComponent::new(0.5, 7.0, 1.0);
        assert!(unlabelled.label.is_none());
    }

    #[test]
    fn test_max_value_constraint() {
        let config = LogNormalMixtureConfig {
            max_value: Some(1000.0),
            ..LogNormalMixtureConfig::new(vec![LogNormalComponent::new(1.0, 10.0, 1.0)])
        };
        let mut sampler = LogNormalMixtureSampler::new(42, config).unwrap();
        assert!(sampler.sample_n(1000).into_iter().all(|x| x <= 1000.0));
    }
}