1use crate::v2::adx::types::{ADXConfig, ADXError, ADXInput, ADXOutput, ADXPeriodData, ADXState, TrendDirection, TrendStrength};
2
3pub struct ADX {
17 state: ADXState,
18}
19
20impl ADX {
21 pub fn new() -> Self {
23 Self::with_config(ADXConfig::default())
24 }
25
26 pub fn with_period(period: usize) -> Result<Self, ADXError> {
28 if period == 0 {
29 return Err(ADXError::InvalidPeriod);
30 }
31
32 let config = ADXConfig {
33 period,
34 adx_smoothing: period,
35 ..Default::default()
36 };
37 Ok(Self::with_config(config))
38 }
39
40 pub fn with_periods(period: usize, adx_smoothing: usize) -> Result<Self, ADXError> {
42 if period == 0 || adx_smoothing == 0 {
43 return Err(ADXError::InvalidPeriod);
44 }
45
46 let config = ADXConfig {
47 period,
48 adx_smoothing,
49 ..Default::default()
50 };
51 Ok(Self::with_config(config))
52 }
53
54 pub fn with_config(config: ADXConfig) -> Self {
56 Self { state: ADXState::new(config) }
57 }
58
59 pub fn calculate(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
61 self.validate_input(&input)?;
63 self.validate_config()?;
64
65 if self.state.is_first {
66 self.handle_first_calculation(input)
67 } else {
68 self.handle_normal_calculation(input)
69 }
70 }
71
72 pub fn calculate_batch(&mut self, inputs: &[ADXInput]) -> Result<Vec<ADXOutput>, ADXError> {
74 inputs.iter().map(|input| self.calculate(*input)).collect()
75 }
76
77 pub fn reset(&mut self) {
79 self.state = ADXState::new(self.state.config);
80 }
81
82 pub fn get_state(&self) -> &ADXState {
84 &self.state
85 }
86
87 pub fn set_state(&mut self, state: ADXState) {
89 self.state = state;
90 }
91
92 pub fn trend_strength(&self) -> TrendStrength {
94 if let Some(adx) = self.state.current_adx {
95 self.classify_trend_strength(adx)
96 } else {
97 TrendStrength::Insufficient
98 }
99 }
100
101 pub fn trend_direction(&self) -> Option<TrendDirection> {
103 self.state
104 .period_data
105 .back()
106 .map(|last_data| self.determine_trend_direction(last_data.plus_di, last_data.minus_di))
107 }
108
109 fn validate_input(&self, input: &ADXInput) -> Result<(), ADXError> {
112 if !input.high.is_finite() || !input.low.is_finite() || !input.close.is_finite() {
114 return Err(ADXError::InvalidPrice);
115 }
116
117 if input.high < input.low {
119 return Err(ADXError::InvalidHLC);
120 }
121
122 if input.close < input.low || input.close > input.high {
123 return Err(ADXError::InvalidHLC);
124 }
125
126 Ok(())
127 }
128
129 fn validate_config(&self) -> Result<(), ADXError> {
130 if self.state.config.period == 0 || self.state.config.adx_smoothing == 0 {
131 return Err(ADXError::InvalidPeriod);
132 }
133
134 if self.state.config.strong_trend_threshold >= self.state.config.very_strong_trend_threshold {
135 return Err(ADXError::InvalidThresholds);
136 }
137
138 Ok(())
139 }
140
141 fn handle_first_calculation(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
142 self.state.previous_high = Some(input.high);
144 self.state.previous_low = Some(input.low);
145 self.state.previous_close = Some(input.close);
146 self.state.is_first = false;
147
148 Ok(ADXOutput {
150 adx: 0.0,
151 plus_di: 0.0,
152 minus_di: 0.0,
153 dx: 0.0,
154 true_range: 0.0,
155 trend_strength: TrendStrength::Insufficient,
156 trend_direction: TrendDirection::Sideways,
157 di_spread: 0.0,
158 })
159 }
160
161 fn handle_normal_calculation(&mut self, input: ADXInput) -> Result<ADXOutput, ADXError> {
162 let true_range = self.calculate_true_range(&input);
164
165 let (plus_dm, minus_dm) = self.calculate_directional_movements(&input);
167
168 if self.state.period_data.len() < self.state.config.period {
170 self.accumulate_initial_data(true_range, plus_dm, minus_dm);
172 } else {
173 self.update_smoothed_values(true_range, plus_dm, minus_dm);
175 }
176
177 let (plus_di, minus_di) = self.calculate_directional_indicators();
179
180 let dx = self.calculate_dx(plus_di, minus_di)?;
182
183 let adx = self.calculate_adx(dx);
185
186 let period_data = ADXPeriodData {
188 true_range,
189 plus_dm,
190 minus_dm,
191 plus_di,
192 minus_di,
193 dx,
194 };
195
196 if self.state.period_data.len() >= self.state.config.period {
198 self.state.period_data.pop_front();
199 }
200 self.state.period_data.push_back(period_data);
201
202 self.state.previous_high = Some(input.high);
204 self.state.previous_low = Some(input.low);
205 self.state.previous_close = Some(input.close);
206
207 let trend_strength = self.classify_trend_strength(adx);
209 let trend_direction = self.determine_trend_direction(plus_di, minus_di);
210 let di_spread = plus_di - minus_di;
211
212 Ok(ADXOutput {
213 adx,
214 plus_di,
215 minus_di,
216 dx,
217 true_range,
218 trend_strength,
219 trend_direction,
220 di_spread,
221 })
222 }
223
224 fn calculate_true_range(&self, input: &ADXInput) -> f64 {
225 if let Some(prev_close) = self.state.previous_close {
226 let hl = input.high - input.low;
227 let hc = (input.high - prev_close).abs();
228 let lc = (input.low - prev_close).abs();
229 hl.max(hc).max(lc)
230 } else {
231 input.high - input.low
232 }
233 }
234
235 fn calculate_directional_movements(&self, input: &ADXInput) -> (f64, f64) {
236 if let (Some(prev_high), Some(prev_low)) = (self.state.previous_high, self.state.previous_low) {
237 let up_move = input.high - prev_high;
238 let down_move = prev_low - input.low;
239
240 let plus_dm = if up_move > down_move && up_move > 0.0 { up_move } else { 0.0 };
241
242 let minus_dm = if down_move > up_move && down_move > 0.0 { down_move } else { 0.0 };
243
244 (plus_dm, minus_dm)
245 } else {
246 (0.0, 0.0)
247 }
248 }
249
250 fn accumulate_initial_data(&mut self, true_range: f64, plus_dm: f64, minus_dm: f64) {
251 match self.state.smoothed_tr {
253 Some(tr) => self.state.smoothed_tr = Some(tr + true_range),
254 None => self.state.smoothed_tr = Some(true_range),
255 }
256
257 match self.state.smoothed_plus_dm {
258 Some(dm) => self.state.smoothed_plus_dm = Some(dm + plus_dm),
259 None => self.state.smoothed_plus_dm = Some(plus_dm),
260 }
261
262 match self.state.smoothed_minus_dm {
263 Some(dm) => self.state.smoothed_minus_dm = Some(dm + minus_dm),
264 None => self.state.smoothed_minus_dm = Some(minus_dm),
265 }
266
267 if self.state.period_data.len() + 1 >= self.state.config.period {
269 self.state.has_di_data = true;
270 }
271 }
272
273 fn update_smoothed_values(&mut self, true_range: f64, plus_dm: f64, minus_dm: f64) {
274 let period = self.state.config.period as f64;
275
276 if let Some(smoothed_tr) = self.state.smoothed_tr {
278 self.state.smoothed_tr = Some((smoothed_tr * (period - 1.0) + true_range) / period);
279 }
280
281 if let Some(smoothed_plus_dm) = self.state.smoothed_plus_dm {
282 self.state.smoothed_plus_dm = Some((smoothed_plus_dm * (period - 1.0) + plus_dm) / period);
283 }
284
285 if let Some(smoothed_minus_dm) = self.state.smoothed_minus_dm {
286 self.state.smoothed_minus_dm = Some((smoothed_minus_dm * (period - 1.0) + minus_dm) / period);
287 }
288 }
289
290 fn calculate_directional_indicators(&self) -> (f64, f64) {
291 if let (Some(smoothed_tr), Some(smoothed_plus_dm), Some(smoothed_minus_dm)) = (self.state.smoothed_tr, self.state.smoothed_plus_dm, self.state.smoothed_minus_dm) {
292 if smoothed_tr != 0.0 {
293 let plus_di = (smoothed_plus_dm / smoothed_tr) * 100.0;
294 let minus_di = (smoothed_minus_dm / smoothed_tr) * 100.0;
295 (plus_di, minus_di)
296 } else {
297 (0.0, 0.0)
298 }
299 } else {
300 (0.0, 0.0)
301 }
302 }
303
304 fn calculate_dx(&self, plus_di: f64, minus_di: f64) -> Result<f64, ADXError> {
305 let di_sum = plus_di + minus_di;
306 if di_sum == 0.0 {
307 Ok(0.0)
308 } else {
309 let di_diff = (plus_di - minus_di).abs();
310 Ok((di_diff / di_sum) * 100.0)
311 }
312 }
313
314 fn calculate_adx(&mut self, dx: f64) -> f64 {
315 if self.state.dx_history.len() >= self.state.config.adx_smoothing {
317 self.state.dx_history.pop_front();
318 }
319 self.state.dx_history.push_back(dx);
320
321 if self.state.dx_history.len() >= self.state.config.adx_smoothing {
323 if !self.state.has_adx_data {
324 let adx = self.state.dx_history.iter().sum::<f64>() / self.state.dx_history.len() as f64;
326 self.state.current_adx = Some(adx);
327 self.state.has_adx_data = true;
328 adx
329 } else {
330 if let Some(prev_adx) = self.state.current_adx {
332 let period = self.state.config.adx_smoothing as f64;
333 let adx = (prev_adx * (period - 1.0) + dx) / period;
334 self.state.current_adx = Some(adx);
335 adx
336 } else {
337 0.0
338 }
339 }
340 } else {
341 0.0
342 }
343 }
344
345 fn classify_trend_strength(&self, adx: f64) -> TrendStrength {
346 if !self.state.has_adx_data {
347 TrendStrength::Insufficient
348 } else if adx >= self.state.config.very_strong_trend_threshold {
349 TrendStrength::VeryStrong
350 } else if adx >= self.state.config.strong_trend_threshold {
351 TrendStrength::Strong
352 } else {
353 TrendStrength::Weak
354 }
355 }
356
357 fn determine_trend_direction(&self, plus_di: f64, minus_di: f64) -> TrendDirection {
358 if plus_di > minus_di {
359 TrendDirection::Up
360 } else if minus_di > plus_di {
361 TrendDirection::Down
362 } else {
363 TrendDirection::Sideways
364 }
365 }
366}
367
368impl Default for ADX {
369 fn default() -> Self {
370 Self::new()
371 }
372}
373
374pub fn calculate_adx_simple(highs: &[f64], lows: &[f64], closes: &[f64], period: usize) -> Result<Vec<f64>, ADXError> {
376 let len = highs.len();
377 if len != lows.len() || len != closes.len() {
378 return Err(ADXError::InvalidInput("All price arrays must have same length".to_string()));
379 }
380
381 if len == 0 {
382 return Ok(Vec::new());
383 }
384
385 let mut adx_calculator = ADX::with_period(period)?;
386 let mut results = Vec::with_capacity(len);
387
388 for i in 0..len {
389 let input = ADXInput {
390 high: highs[i],
391 low: lows[i],
392 close: closes[i],
393 };
394 let output = adx_calculator.calculate(input)?;
395 results.push(output.adx);
396 }
397
398 Ok(results)
399}