1use crate::types::TimeInterval;
4use rust_decimal::Decimal;
5use rust_decimal::prelude::FromPrimitive;
6use std::fmt;
7
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize, Serializer, Deserializer};
10
11#[cfg(feature = "api-server")]
12use utoipa::ToSchema;
13
/// Serializes `max_price`, writing the "unbounded" sentinel as `null`.
///
/// The sentinel is the value returned by [`default_max_price`]. Only that exact
/// value becomes `null`; any other value (including values above the sentinel) is
/// written as a regular decimal. This keeps a serialize/deserialize round trip
/// lossless, because [`deserialize_decimal_inf`] maps `null` back to the sentinel.
#[cfg(feature = "serde")]
fn serialize_decimal_inf<S>(value: &Decimal, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    // Exact comparison: collapsing everything >= sentinel to `null` would turn
    // e.g. 2e15 into 1e15 after a round trip.
    if *value == default_max_price() {
        serializer.serialize_none()
    } else {
        Serialize::serialize(value, serializer)
    }
}
27
/// Deserializes `max_price`, mapping `null` to the "unbounded" sentinel.
///
/// A present value is returned unchanged. A `null` becomes [`default_max_price`].
/// A field that is missing entirely is handled by the field's
/// `serde(default = "default_max_price")` attribute instead.
#[cfg(feature = "serde")]
fn deserialize_decimal_inf<'de, D>(deserializer: D) -> Result<Decimal, D::Error>
where
    D: Deserializer<'de>,
{
    let value: Option<Decimal> = Option::deserialize(deserializer)?;
    // Reuse the single source of truth for the sentinel instead of re-deriving it.
    Ok(value.unwrap_or_else(default_max_price))
}
36
/// Direction of the price trend requested from the generator.
///
/// With the `serde` feature, variants serialize as lowercase names (`"bullish"`,
/// `"bearish"`, `"sideways"`). They also accept the aliases `"up"`, `"down"` and
/// `"flat"` when deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
#[cfg_attr(feature = "api-server", derive(ToSchema))]
pub enum TrendDirection {
    /// Bullish trend; also deserializes from `"up"`.
    #[cfg_attr(feature = "serde", serde(alias = "up"))]
    Bullish,
    /// Bearish trend; also deserializes from `"down"`.
    #[cfg_attr(feature = "serde", serde(alias = "down"))]
    Bearish,
    /// Sideways (no trend); also deserializes from `"flat"`. This is the default.
    #[cfg_attr(feature = "serde", serde(alias = "flat"))]
    Sideways,
}
53
54#[derive(Debug, Clone)]
56#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
57#[cfg_attr(feature = "serde", serde(default))]
58#[cfg_attr(feature = "api-server", derive(ToSchema))]
59pub struct GeneratorConfig {
60 #[cfg_attr(feature = "serde", serde(default = "default_starting_price"))]
62 pub starting_price: Decimal,
63 #[cfg_attr(feature = "serde", serde(default = "default_min_price"))]
65 pub min_price: Decimal,
66 #[cfg_attr(feature = "serde", serde(default = "default_max_price"))]
68 #[cfg_attr(feature = "serde", serde(serialize_with = "serialize_decimal_inf"))]
69 #[cfg_attr(feature = "serde", serde(deserialize_with = "deserialize_decimal_inf"))]
70 pub max_price: Decimal,
71 #[cfg_attr(feature = "serde", serde(default = "default_trend_direction"))]
73 pub trend_direction: TrendDirection,
74 #[cfg_attr(feature = "serde", serde(default = "default_trend_strength"))]
76 pub trend_strength: Decimal,
77 #[cfg_attr(feature = "serde", serde(default = "default_volatility"))]
79 pub volatility: Decimal,
80 #[cfg_attr(feature = "serde", serde(default = "default_time_interval"))]
82 pub time_interval: TimeInterval,
83 #[cfg_attr(feature = "serde", serde(default = "default_num_points"))]
85 pub num_points: usize,
86 #[cfg_attr(feature = "serde", serde(default))]
88 pub seed: Option<u64>,
89 #[cfg_attr(feature = "serde", serde(default = "default_base_volume"))]
91 pub base_volume: u64,
92 #[cfg_attr(feature = "serde", serde(default = "default_volume_volatility"))]
94 pub volume_volatility: f64,
95}
96
97fn default_starting_price() -> Decimal {
99 Decimal::from_f64(100.0).expect("100.0 should always convert to Decimal")
100}
101
102fn default_min_price() -> Decimal {
103 Decimal::from_f64(1.0).expect("1.0 should always convert to Decimal")
104}
105
106fn default_max_price() -> Decimal {
107 Decimal::from_f64(1e15).expect("1e15 should always convert to Decimal")
108}
109
110fn default_trend_direction() -> TrendDirection {
111 TrendDirection::Sideways
112}
113
114fn default_trend_strength() -> Decimal {
115 Decimal::ZERO
116}
117
118fn default_volatility() -> Decimal {
119 Decimal::from_f64(0.02).expect("0.02 should always convert to Decimal")
120}
121
122fn default_time_interval() -> TimeInterval {
123 TimeInterval::OneMinute
124}
125
/// Default number of points to generate.
fn default_num_points() -> usize {
    100_usize
}
129
/// Default base volume.
fn default_base_volume() -> u64 {
    100_000
}
133
/// Default volume volatility.
fn default_volume_volatility() -> f64 {
    0.3_f64
}
137
138impl Default for GeneratorConfig {
139 fn default() -> Self {
140 Self {
141 starting_price: default_starting_price(),
142 min_price: default_min_price(),
143 max_price: default_max_price(),
144 trend_direction: default_trend_direction(),
145 trend_strength: default_trend_strength(),
146 volatility: default_volatility(),
147 time_interval: default_time_interval(),
148 num_points: default_num_points(),
149 seed: None,
150 base_volume: default_base_volume(),
151 volume_volatility: default_volume_volatility(),
152 }
153 }
154}
155
156impl GeneratorConfig {
157 pub fn new() -> Self {
159 Self::default()
160 }
161
162 pub fn builder() -> ConfigBuilder {
164 ConfigBuilder::new()
165 }
166
167 pub fn apply_smart_defaults(&mut self) {
170 if self.min_price == default_min_price() && self.starting_price > Decimal::from(1000) {
172 self.min_price = self.starting_price * Decimal::from_f64(0.01).expect("0.01 should always convert to Decimal"); }
174
175 if self.max_price == default_max_price() && self.starting_price != default_starting_price() {
177 self.max_price = self.starting_price * Decimal::from(100); }
179
180 if self.min_price >= self.starting_price {
182 self.min_price = self.starting_price * Decimal::from_f64(0.5).expect("0.5 should always convert to Decimal");
183 }
184 if self.max_price <= self.starting_price {
185 self.max_price = self.starting_price * Decimal::from(2);
186 }
187
188 if self.volatility == default_volatility() {
190 if self.starting_price > Decimal::from(10000) {
191 self.volatility = Decimal::from_f64(0.05).expect("0.05 should always convert to Decimal"); } else if self.starting_price < Decimal::from(10) {
194 self.volatility = Decimal::from_f64(0.005).expect("0.005 should always convert to Decimal"); }
197 }
198
199 if self.trend_strength == Decimal::ZERO {
201 match self.trend_direction {
202 TrendDirection::Bullish => self.trend_strength = Decimal::from_f64(0.0001).expect("0.0001 should always convert to Decimal"),
203 TrendDirection::Bearish => self.trend_strength = Decimal::from_f64(-0.0001).expect("-0.0001 should always convert to Decimal"),
204 TrendDirection::Sideways => {}
205 }
206 }
207 }
208
209 pub fn validate(&self) -> Result<(), ConfigError> {
211 if self.starting_price <= Decimal::ZERO {
212 return Err(ConfigError::InvalidPrice("Starting price must be positive".into()));
213 }
214 if self.min_price <= Decimal::ZERO {
215 return Err(ConfigError::InvalidPrice("Minimum price must be positive".into()));
216 }
217 if self.min_price >= self.max_price {
218 return Err(ConfigError::InvalidPrice("Minimum price must be less than maximum price".into()));
219 }
220 if self.volatility < Decimal::ZERO {
221 return Err(ConfigError::InvalidVolatility("Volatility must be non-negative".into()));
222 }
223 let one = Decimal::ONE;
224 if self.trend_strength < -one || self.trend_strength > one {
225 return Err(ConfigError::InvalidTrend("Trend strength must be between -100% and +100%".into()));
226 }
227 if self.num_points == 0 {
228 return Err(ConfigError::InvalidParameter("Number of points must be positive".into()));
229 }
230 if self.base_volume == 0 {
231 return Err(ConfigError::InvalidParameter("Base volume must be positive".into()));
232 }
233 if self.volume_volatility < 0.0 {
234 return Err(ConfigError::InvalidVolatility("Volume volatility must be non-negative".into()));
235 }
236 Ok(())
237 }
238}
239
240pub struct ConfigBuilder {
242 config: GeneratorConfig,
243}
244
245impl Default for ConfigBuilder {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251impl ConfigBuilder {
252 pub fn new() -> Self {
254 Self {
255 config: GeneratorConfig::default(),
256 }
257 }
258
259 pub fn starting_price(mut self, price: Decimal) -> Self {
261 self.config.starting_price = price;
262 self
263 }
264
265 pub fn starting_price_f64(mut self, price: f64) -> Self {
267 self.config.starting_price = Decimal::from_f64(price)
268 .unwrap_or_else(|| Decimal::from_f64(100.0).expect("100.0 should always convert to Decimal"));
269 self
270 }
271
272 pub fn price_range(mut self, min: Decimal, max: Decimal) -> Self {
274 self.config.min_price = min;
275 self.config.max_price = max;
276 self
277 }
278
279 pub fn price_range_f64(mut self, min: f64, max: f64) -> Self {
281 self.config.min_price = Decimal::from_f64(min)
282 .unwrap_or_else(|| Decimal::from_f64(1.0).expect("1.0 should always convert to Decimal"));
283 self.config.max_price = Decimal::from_f64(max)
284 .unwrap_or_else(|| Decimal::from_f64(1e15).expect("1e15 should always convert to Decimal"));
285 self
286 }
287
288 pub fn trend(mut self, direction: TrendDirection, strength: Decimal) -> Self {
290 self.config.trend_direction = direction;
291 self.config.trend_strength = strength;
292 self
293 }
294
295 pub fn trend_f64(mut self, direction: TrendDirection, strength: f64) -> Self {
297 self.config.trend_direction = direction;
298 self.config.trend_strength = Decimal::from_f64(strength).unwrap_or(Decimal::ZERO);
299 self
300 }
301
302 pub fn volatility(mut self, volatility: Decimal) -> Self {
304 self.config.volatility = volatility;
305 self
306 }
307
308 pub fn volatility_f64(mut self, volatility: f64) -> Self {
310 self.config.volatility = Decimal::from_f64(volatility)
311 .unwrap_or_else(|| Decimal::from_f64(0.02).expect("0.02 should always convert to Decimal"));
312 self
313 }
314
315 pub fn time_interval(mut self, interval: TimeInterval) -> Self {
317 self.config.time_interval = interval;
318 self
319 }
320
321 pub fn num_points(mut self, num: usize) -> Self {
323 self.config.num_points = num;
324 self
325 }
326
327 pub fn seed(mut self, seed: u64) -> Self {
329 self.config.seed = Some(seed);
330 self
331 }
332
333 pub fn base_volume(mut self, volume: u64) -> Self {
335 self.config.base_volume = volume;
336 self
337 }
338
339 pub fn volume_volatility(mut self, volatility: f64) -> Self {
341 self.config.volume_volatility = volatility;
342 self
343 }
344
345 pub fn build(self) -> Result<GeneratorConfig, ConfigError> {
347 self.config.validate()?;
348 Ok(self.config)
349 }
350}
351
/// Reason a [`GeneratorConfig`] failed validation.
///
/// Each variant carries a human-readable detail message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// A price field (`starting_price`, `min_price`, `max_price`) is invalid.
    InvalidPrice(String),
    /// `volatility` or `volume_volatility` is invalid.
    InvalidVolatility(String),
    /// `trend_strength` is out of range.
    InvalidTrend(String),
    /// Another parameter (e.g. `num_points`, `base_volume`) is invalid.
    InvalidParameter(String),
}
360
361impl fmt::Display for ConfigError {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 match self {
364 ConfigError::InvalidPrice(msg) => write!(f, "Invalid price configuration: {msg}"),
365 ConfigError::InvalidVolatility(msg) => write!(f, "Invalid volatility configuration: {msg}"),
366 ConfigError::InvalidTrend(msg) => write!(f, "Invalid trend configuration: {msg}"),
367 ConfigError::InvalidParameter(msg) => write!(f, "Invalid parameter: {msg}"),
368 }
369 }
370}
371
372impl std::error::Error for ConfigError {}
373
374impl GeneratorConfig {
376 pub fn volatile() -> Self {
378 Self {
379 volatility: Decimal::from_f64(0.05).expect("0.05 should always convert to Decimal"), volume_volatility: 0.5, ..Self::default()
382 }
383 }
384
385 pub fn stable() -> Self {
387 Self {
388 volatility: Decimal::from_f64(0.005).expect("0.005 should always convert to Decimal"), volume_volatility: 0.1, ..Self::default()
391 }
392 }
393
394 pub fn bull_market() -> Self {
396 Self {
397 trend_direction: TrendDirection::Bullish,
398 trend_strength: Decimal::from_f64(0.002).expect("0.002 should always convert to Decimal"), volatility: Decimal::from_f64(0.02).expect("0.02 should always convert to Decimal"),
400 ..Self::default()
401 }
402 }
403
404 pub fn bear_market() -> Self {
406 Self {
407 trend_direction: TrendDirection::Bearish,
408 trend_strength: Decimal::from_f64(0.002).expect("0.002 should always convert to Decimal"), volatility: Decimal::from_f64(0.03).expect("0.03 should always convert to Decimal"), ..Self::default()
411 }
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = GeneratorConfig::default();
        assert_eq!(config.starting_price, Decimal::from_f64(100.0).unwrap());
        assert_eq!(config.min_price, Decimal::from_f64(1.0).unwrap());
        assert_eq!(config.max_price, default_max_price());
        assert_eq!(config.trend_direction, TrendDirection::Sideways);
        assert_eq!(config.trend_strength, Decimal::ZERO);
        assert_eq!(config.seed, None);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_builder() {
        let config = GeneratorConfig::builder()
            .starting_price_f64(50.0)
            .price_range_f64(10.0, 200.0)
            .trend_f64(TrendDirection::Bullish, 0.01)
            .volatility_f64(0.03)
            .num_points(500)
            .seed(42)
            .build()
            .unwrap();

        assert_eq!(config.starting_price, Decimal::from_f64(50.0).unwrap());
        assert_eq!(config.min_price, Decimal::from_f64(10.0).unwrap());
        assert_eq!(config.max_price, Decimal::from_f64(200.0).unwrap());
        assert_eq!(config.trend_direction, TrendDirection::Bullish);
        assert_eq!(config.trend_strength, Decimal::from_f64(0.01).unwrap());
        assert_eq!(config.volatility, Decimal::from_f64(0.03).unwrap());
        assert_eq!(config.num_points, 500);
        assert_eq!(config.seed, Some(42));
    }

    #[test]
    fn test_builder_reports_validation_errors() {
        let result = GeneratorConfig::builder().num_points(0).build();
        assert!(matches!(result, Err(ConfigError::InvalidParameter(_))));

        let result = GeneratorConfig::builder().price_range_f64(200.0, 10.0).build();
        assert!(matches!(result, Err(ConfigError::InvalidPrice(_))));
    }

    #[test]
    fn test_config_validation() {
        let config = GeneratorConfig {
            starting_price: Decimal::from_f64(-10.0).unwrap(),
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidPrice(_))));

        let config = GeneratorConfig {
            min_price: Decimal::from_f64(100.0).unwrap(),
            max_price: Decimal::from_f64(50.0).unwrap(),
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidPrice(_))));

        let config = GeneratorConfig {
            volatility: Decimal::from_f64(-0.1).unwrap(),
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidVolatility(_))));

        let config = GeneratorConfig {
            trend_strength: Decimal::from_f64(1.5).unwrap(),
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidTrend(_))));

        let config = GeneratorConfig {
            num_points: 0,
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidParameter(_))));

        let config = GeneratorConfig {
            base_volume: 0,
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidParameter(_))));

        let config = GeneratorConfig {
            volume_volatility: -0.1,
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(ConfigError::InvalidVolatility(_))));
    }

    #[test]
    fn test_error_display() {
        let err = ConfigError::InvalidPrice("Starting price must be positive".into());
        assert_eq!(
            err.to_string(),
            "Invalid price configuration: Starting price must be positive"
        );
        let err = ConfigError::InvalidParameter("x".into());
        assert_eq!(err.to_string(), "Invalid parameter: x");
    }

    #[test]
    fn test_apply_smart_defaults_derives_bounds_from_starting_price() {
        let mut config = GeneratorConfig {
            starting_price: Decimal::from(5000),
            ..Default::default()
        };
        config.apply_smart_defaults();

        // min: 1% of 5000; max: 100 x 5000; volatility untouched (5000 <= 10000).
        assert_eq!(config.min_price, Decimal::from(50));
        assert_eq!(config.max_price, Decimal::from(500_000));
        assert_eq!(config.volatility, default_volatility());
        assert_eq!(config.trend_strength, Decimal::ZERO);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_apply_smart_defaults_adjusts_volatility_and_trend() {
        let mut config = GeneratorConfig {
            starting_price: Decimal::from(5),
            trend_direction: TrendDirection::Bullish,
            ..Default::default()
        };
        config.apply_smart_defaults();

        assert_eq!(config.min_price, Decimal::ONE);
        assert_eq!(config.max_price, Decimal::from(500));
        assert_eq!(config.volatility, Decimal::from_f64(0.005).unwrap());
        assert_eq!(config.trend_strength, Decimal::from_f64(0.0001).unwrap());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_apply_smart_defaults_repairs_min_above_start() {
        let mut config = GeneratorConfig {
            min_price: Decimal::from(150),
            ..Default::default()
        };
        config.apply_smart_defaults();

        assert_eq!(config.min_price, Decimal::from(50));
        // Default starting price: the sentinel max is left alone.
        assert_eq!(config.max_price, default_max_price());
    }

    #[test]
    fn test_preset_configs() {
        let volatile = GeneratorConfig::volatile();
        assert_eq!(volatile.volatility, Decimal::from_f64(0.05).unwrap());
        assert!(volatile.validate().is_ok());

        let stable = GeneratorConfig::stable();
        assert_eq!(stable.volatility, Decimal::from_f64(0.005).unwrap());
        assert!(stable.validate().is_ok());

        let bull = GeneratorConfig::bull_market();
        assert_eq!(bull.trend_direction, TrendDirection::Bullish);
        assert!(bull.validate().is_ok());

        let bear = GeneratorConfig::bear_market();
        assert_eq!(bear.trend_direction, TrendDirection::Bearish);
        assert!(bear.validate().is_ok());
    }

    #[cfg(feature = "serde")]
    mod serde_tests {
        use super::*;

        #[test]
        fn test_trend_direction_serialization() {
            let trend = TrendDirection::Bullish;
            let json = serde_json::to_string(&trend).unwrap();
            assert_eq!(json, r#""bullish""#);

            let deserialized: TrendDirection = serde_json::from_str(&json).unwrap();
            assert_eq!(trend, deserialized);
        }

        #[test]
        fn test_trend_direction_aliases() {
            let up: TrendDirection = serde_json::from_str(r#""up""#).unwrap();
            assert_eq!(up, TrendDirection::Bullish);
            let down: TrendDirection = serde_json::from_str(r#""down""#).unwrap();
            assert_eq!(down, TrendDirection::Bearish);
            let flat: TrendDirection = serde_json::from_str(r#""flat""#).unwrap();
            assert_eq!(flat, TrendDirection::Sideways);
        }

        #[test]
        fn test_generator_config_serialization() {
            let config = GeneratorConfig::builder()
                .starting_price_f64(50.0)
                .price_range_f64(10.0, 200.0)
                .trend_f64(TrendDirection::Bullish, 0.01)
                .volatility_f64(0.03)
                .num_points(500)
                .seed(42)
                .base_volume(100000)
                .volume_volatility(0.3)
                .time_interval(TimeInterval::FiveMinutes)
                .build()
                .unwrap();

            let json = serde_json::to_string(&config).unwrap();
            let deserialized: GeneratorConfig = serde_json::from_str(&json).unwrap();

            assert_eq!(config.starting_price, deserialized.starting_price);
            assert_eq!(config.min_price, deserialized.min_price);
            assert_eq!(config.max_price, deserialized.max_price);
            assert_eq!(config.trend_direction, deserialized.trend_direction);
            assert_eq!(config.trend_strength, deserialized.trend_strength);
            assert_eq!(config.volatility, deserialized.volatility);
            assert_eq!(config.num_points, deserialized.num_points);
            assert_eq!(config.seed, deserialized.seed);
            assert_eq!(config.base_volume, deserialized.base_volume);
            assert!((config.volume_volatility - deserialized.volume_volatility).abs() < 1e-12);
            assert_eq!(config.time_interval, deserialized.time_interval);
        }

        #[test]
        fn test_config_json_format() {
            let config = GeneratorConfig::default();
            let json = serde_json::to_string_pretty(&config).unwrap();

            let _: GeneratorConfig = serde_json::from_str(&json).unwrap();

            assert!(json.contains("starting_price"));
            assert!(json.contains("trend_direction"));
            assert!(json.contains("volatility"));
        }

        #[test]
        fn test_default_max_price_round_trips_as_null() {
            let config = GeneratorConfig::default();
            let json = serde_json::to_string(&config).unwrap();
            assert!(json.contains(r#""max_price":null"#));

            let back: GeneratorConfig = serde_json::from_str(&json).unwrap();
            assert_eq!(back.max_price, default_max_price());
        }

        #[test]
        fn test_missing_fields_use_defaults() {
            let config: GeneratorConfig = serde_json::from_str("{}").unwrap();
            assert_eq!(config.starting_price, default_starting_price());
            assert_eq!(config.min_price, default_min_price());
            assert_eq!(config.max_price, default_max_price());
            assert_eq!(config.num_points, default_num_points());
            assert_eq!(config.base_volume, default_base_volume());
            assert_eq!(config.seed, None);
        }
    }
}