use rand::Rng;
use rand_distr::{Distribution, Normal};
use rust_decimal::Decimal;
use rust_decimal::prelude::{FromPrimitive, ToPrimitive};
use crate::config::{GeneratorConfig, TrendDirection};
/// Stateful generator that evolves a price by a multiplicative random walk and
/// draws per-bar trading volumes.
///
/// Each call to `next_price` moves the price by a fraction of itself (trend drift
/// plus Gaussian noise) and then clamps it using the configured `min_price` /
/// `max_price`.
pub struct RandomWalkGenerator {
/// Price the walk currently sits at. It starts at `config.starting_price`,
/// `reset` returns it there, and every `next_price` call overwrites it.
current_price: Decimal,
/// Configuration supplying the trend direction/strength, the price bounds and
/// the starting price.
config: GeneratorConfig,
/// `N(0, volatility)`; each sample is used as a fractional change of the price.
price_distribution: Normal<f64>,
/// `N(base_volume, base_volume * volume_volatility)`, sampled by `generate_volume`.
volume_distribution: Normal<f64>,
}
impl RandomWalkGenerator {
pub fn new(config: GeneratorConfig) -> Result<Self, String> {
let volatility_f64 = config.volatility.to_f64()
.ok_or("Failed to convert volatility to f64")?;
let price_distribution = Normal::new(0.0, volatility_f64)
.map_err(|e| format!("Failed to create price distribution: {e}"))?;
let volume_distribution = Normal::new(
config.base_volume as f64,
config.base_volume as f64 * config.volume_volatility
).map_err(|e| format!("Failed to create volume distribution: {e}"))?;
Ok(Self {
current_price: config.starting_price,
config,
price_distribution,
volume_distribution,
})
}
pub fn next_price<R: Rng>(&mut self, rng: &mut R) -> Decimal {
let drift = match self.config.trend_direction {
TrendDirection::Bullish => self.config.trend_strength,
TrendDirection::Bearish => -self.config.trend_strength,
TrendDirection::Sideways => Decimal::ZERO,
};
let random_change_f64 = self.price_distribution.sample(rng);
let random_change = Decimal::from_f64(random_change_f64)
.unwrap_or(Decimal::ZERO);
let price_change = self.current_price * (drift + random_change);
self.current_price += price_change;
self.current_price = self.current_price
.max(self.config.min_price)
.min(self.config.max_price);
self.current_price
}
pub fn generate_ohlc<R: Rng>(&mut self, rng: &mut R, num_ticks: usize) -> (Decimal, Decimal, Decimal, Decimal) {
if num_ticks == 0 {
let price = self.current_price;
return (price, price, price, price);
}
let open = self.current_price;
let mut high = open;
let mut low = open;
for _ in 0..num_ticks {
let price = self.next_price(rng);
high = high.max(price);
low = low.min(price);
}
let close = self.current_price;
(open, high, low, close)
}
pub fn generate_volume<R: Rng>(&mut self, rng: &mut R) -> u64 {
let volume = self.volume_distribution.sample(rng);
volume.max(0.0) as u64
}
pub fn reset(&mut self) {
self.current_price = self.config.starting_price;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// RNG seed shared by every test so runs are reproducible.
    const SEED: u64 = 42;

    #[test]
    fn test_random_walk_creation() {
        let config = GeneratorConfig::default();
        assert!(RandomWalkGenerator::new(config).is_ok());
    }

    #[test]
    fn test_price_generation() {
        let config = GeneratorConfig {
            seed: Some(42),
            volatility: Decimal::new(1, 2), // 0.01
            ..Default::default()
        };
        let mut generator = RandomWalkGenerator::new(config).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);

        let price1 = generator.next_price(&mut rng);
        let price2 = generator.next_price(&mut rng);

        assert!(price1 > Decimal::ZERO);
        assert!(price2 > Decimal::ZERO);
        assert_ne!(price1, price2);
    }

    #[test]
    fn test_bullish_trend() {
        let config = GeneratorConfig {
            seed: Some(42),
            trend_direction: TrendDirection::Bullish,
            trend_strength: Decimal::new(1, 2), // 0.01
            volatility: Decimal::new(1, 3),     // 0.001
            starting_price: Decimal::new(100, 0),
            ..Default::default()
        };
        let mut generator = RandomWalkGenerator::new(config).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);
        let start_price = generator.config.starting_price;

        let mut last_price = start_price;
        for _ in 0..100 {
            last_price = generator.next_price(&mut rng);
        }

        assert!(last_price > start_price);
    }

    #[test]
    fn test_price_boundaries() {
        let min_price = Decimal::new(50, 0);
        let max_price = Decimal::new(150, 0);
        let config = GeneratorConfig {
            min_price,
            max_price,
            starting_price: Decimal::new(100, 0),
            volatility: Decimal::new(5, 1), // 0.5
            ..Default::default()
        };
        let mut generator = RandomWalkGenerator::new(config).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);

        for _ in 0..1000 {
            let price = generator.next_price(&mut rng);
            assert!(price >= min_price);
            assert!(price <= max_price);
        }
    }

    #[test]
    fn test_ohlc_generation() {
        let config = GeneratorConfig::default();
        let mut generator = RandomWalkGenerator::new(config).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);
        let start = generator.config.starting_price;

        let (open, high, low, close) = generator.generate_ohlc(&mut rng, 10);

        // A fresh generator opens at the configured starting price and closes at
        // wherever the walk ended.
        assert_eq!(open, start);
        assert_eq!(close, generator.current_price);
        assert!(high >= open);
        assert!(high >= close);
        assert!(low <= open);
        assert!(low <= close);
        assert!(high >= low);
    }

    #[test]
    fn test_ohlc_zero_ticks_does_not_advance() {
        let mut generator = RandomWalkGenerator::new(GeneratorConfig::default()).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);
        let start = generator.config.starting_price;

        assert_eq!(
            generator.generate_ohlc(&mut rng, 0),
            (start, start, start, start)
        );
        assert_eq!(generator.current_price, start);
    }

    #[test]
    fn test_reset_restores_starting_price() {
        let mut generator = RandomWalkGenerator::new(GeneratorConfig::default()).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);
        let start = generator.config.starting_price;

        for _ in 0..10 {
            generator.next_price(&mut rng);
        }
        generator.reset();

        assert_eq!(generator.current_price, start);
    }

    #[test]
    fn test_volume_generation() {
        let config = GeneratorConfig {
            base_volume: 100_000,
            volume_volatility: 0.2,
            ..Default::default()
        };
        let mut generator = RandomWalkGenerator::new(config).unwrap();
        let mut rng = StdRng::seed_from_u64(SEED);

        let volume = generator.generate_volume(&mut rng);

        assert!(volume > 0);
    }
}