Skip to main content

nautilus_indicators/average/
ama.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_model::{
19    data::{Bar, QuoteTick, TradeTick},
20    enums::PriceType,
21};
22
23use crate::{
24    indicator::{Indicator, MovingAverage},
25    ratio::efficiency_ratio::EfficiencyRatio,
26};
27
/// An indicator which calculates an adaptive moving average (AMA) across a
/// rolling window. Developed by Perry Kaufman, the AMA is a moving average
/// designed to account for market noise and volatility. The AMA will closely
/// follow prices when the price swings are relatively small and the noise is
/// low. The AMA will increase lag when the price swings increase.
///
/// After the first input (which seeds the value directly), each update is
/// computed as `value = prior + sc * (input - prior)` where
/// `sc = (er * (alpha_fast - alpha_slow) + alpha_slow)^2` and `er` is the
/// current value of the internal `EfficiencyRatio`.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct AdaptiveMovingAverage {
    /// The period for the internal `EfficiencyRatio` indicator.
    pub period_efficiency_ratio: usize,
    /// The period for the fast smoothing constant (> 0).
    pub period_fast: usize,
    /// The period for the slow smoothing constant (> `period_fast`).
    pub period_slow: usize,
    /// The price type used for calculations.
    pub price_type: PriceType,
    /// The last indicator value.
    pub value: f64,
    /// The input count for the indicator.
    pub count: usize,
    /// Whether the indicator is initialized; set once the internal
    /// `EfficiencyRatio` reports itself initialized during an update.
    pub initialized: bool,
    /// Whether at least one input has been received since construction/reset.
    has_inputs: bool,
    /// The internal efficiency ratio that scales the smoothing constant.
    efficiency_ratio: EfficiencyRatio,
    /// The indicator value preceding the latest update (`None` before any
    /// input; on the first input it is set to that input).
    prior_value: Option<f64>,
    /// The fast smoothing constant, `2 / (period_fast + 1)`.
    alpha_fast: f64,
    /// The slow smoothing constant, `2 / (period_slow + 1)`.
    alpha_slow: f64,
}
63
64impl Display for AdaptiveMovingAverage {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        write!(
67            f,
68            "{}({},{},{})",
69            self.name(),
70            self.period_efficiency_ratio,
71            self.period_fast,
72            self.period_slow
73        )
74    }
75}
76
77impl Indicator for AdaptiveMovingAverage {
78    fn name(&self) -> String {
79        stringify!(AdaptiveMovingAverage).to_string()
80    }
81
82    fn has_inputs(&self) -> bool {
83        self.has_inputs
84    }
85
86    fn initialized(&self) -> bool {
87        self.initialized
88    }
89
90    fn handle_quote(&mut self, quote: &QuoteTick) {
91        self.update_raw(quote.extract_price(self.price_type).into());
92    }
93
94    fn handle_trade(&mut self, trade: &TradeTick) {
95        self.update_raw((&trade.price).into());
96    }
97
98    fn handle_bar(&mut self, bar: &Bar) {
99        self.update_raw((&bar.close).into());
100    }
101
102    fn reset(&mut self) {
103        self.value = 0.0;
104        self.count = 0;
105        self.has_inputs = false;
106        self.initialized = false;
107    }
108}
109
110impl AdaptiveMovingAverage {
111    /// Creates a new [`AdaptiveMovingAverage`] instance.
112    ///
113    /// # Panics
114    ///
115    /// This function panics if:
116    /// - `period_efficiency_ratio` == 0.
117    /// - `period_fast` == 0.
118    /// - `period_slow` == 0.
119    /// - `period_slow` ≤ `period_fast`.
120    #[must_use]
121    pub fn new(
122        period_efficiency_ratio: usize,
123        period_fast: usize,
124        period_slow: usize,
125        price_type: Option<PriceType>,
126    ) -> Self {
127        assert!(
128            period_efficiency_ratio > 0,
129            "period_efficiency_ratio must be a positive integer"
130        );
131        assert!(period_fast > 0, "period_fast must be a positive integer");
132        assert!(period_slow > 0, "period_slow must be a positive integer");
133        assert!(
134            period_slow > period_fast,
135            "period_slow ({period_slow}) must be greater than period_fast ({period_fast})"
136        );
137        Self {
138            period_efficiency_ratio,
139            period_fast,
140            period_slow,
141            price_type: price_type.unwrap_or(PriceType::Last),
142            value: 0.0,
143            count: 0,
144            alpha_fast: 2.0 / (period_fast + 1) as f64,
145            alpha_slow: 2.0 / (period_slow + 1) as f64,
146            prior_value: None,
147            has_inputs: false,
148            initialized: false,
149            efficiency_ratio: EfficiencyRatio::new(period_efficiency_ratio, price_type),
150        }
151    }
152
153    #[must_use]
154    pub fn alpha_diff(&self) -> f64 {
155        self.alpha_fast - self.alpha_slow
156    }
157
158    pub const fn reset(&mut self) {
159        self.value = 0.0;
160        self.prior_value = None;
161        self.count = 0;
162        self.has_inputs = false;
163        self.initialized = false;
164    }
165}
166
167impl MovingAverage for AdaptiveMovingAverage {
168    fn value(&self) -> f64 {
169        self.value
170    }
171
172    fn count(&self) -> usize {
173        self.count
174    }
175
176    fn update_raw(&mut self, value: f64) {
177        self.count += 1;
178
179        if !self.has_inputs {
180            self.prior_value = Some(value);
181            self.efficiency_ratio.update_raw(value);
182            self.value = value;
183            self.has_inputs = true;
184            return;
185        }
186
187        self.efficiency_ratio.update_raw(value);
188        self.prior_value = Some(self.value);
189
190        // Calculate the smoothing constant
191        let smoothing_constant = self
192            .efficiency_ratio
193            .value
194            .mul_add(self.alpha_diff(), self.alpha_slow)
195            .powi(2);
196
197        // Calculate the AMA
198        // TODO: Remove unwraps
199        self.value = smoothing_constant
200            .mul_add(value - self.prior_value.unwrap(), self.prior_value.unwrap());
201
202        if self.efficiency_ratio.initialized() {
203            self.initialized = true;
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick, TradeTick};
    use rstest::rstest;

    use crate::{
        average::ama::AdaptiveMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    // A fresh indicator reports its name/parameters and has no state yet.
    #[rstest]
    fn test_ama_initialized(indicator_ama_10: AdaptiveMovingAverage) {
        let display_str = format!("{indicator_ama_10}");
        assert_eq!(display_str, "AdaptiveMovingAverage(10,2,30)");
        assert_eq!(indicator_ama_10.name(), "AdaptiveMovingAverage");
        assert!(!indicator_ama_10.has_inputs());
        assert!(!indicator_ama_10.initialized());
    }

    // The first input seeds the average directly.
    #[rstest]
    fn test_value_with_one_input(mut indicator_ama_10: AdaptiveMovingAverage) {
        indicator_ama_10.update_raw(1.0);
        assert_eq!(indicator_ama_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs(mut indicator_ama_10: AdaptiveMovingAverage) {
        indicator_ama_10.update_raw(1.0);
        indicator_ama_10.update_raw(2.0);
        assert_eq!(indicator_ama_10.value, 1.444_444_444_444_444_2);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_ama_10: AdaptiveMovingAverage) {
        indicator_ama_10.update_raw(1.0);
        indicator_ama_10.update_raw(2.0);
        indicator_ama_10.update_raw(3.0);
        assert_eq!(indicator_ama_10.value, 2.135_802_469_135_802);
    }

    #[rstest]
    fn test_reset(mut indicator_ama_10: AdaptiveMovingAverage) {
        for _ in 0..10 {
            indicator_ama_10.update_raw(1.0);
        }
        assert!(indicator_ama_10.initialized);
        // Method-call syntax resolves to the inherent `reset`, which takes
        // precedence over `Indicator::reset`.
        indicator_ama_10.reset();
        assert!(!indicator_ama_10.initialized);
        assert!(!indicator_ama_10.has_inputs);
        assert_eq!(indicator_ama_10.value, 0.0);
        assert_eq!(indicator_ama_10.count, 0);
    }

    #[rstest]
    fn test_initialized_after_correct_number_of_input(indicator_ama_10: AdaptiveMovingAverage) {
        let mut ama = indicator_ama_10;
        for _ in 0..9 {
            ama.update_raw(1.0);
        }
        assert!(!ama.initialized);
        ama.update_raw(1.0);
        assert!(ama.initialized);
    }

    #[rstest]
    fn test_count_increments(mut indicator_ama_10: AdaptiveMovingAverage) {
        assert_eq!(indicator_ama_10.count(), 0);
        indicator_ama_10.update_raw(1.0);
        assert_eq!(indicator_ama_10.count(), 1);
        indicator_ama_10.update_raw(2.0);
        indicator_ama_10.update_raw(3.0);
        assert_eq!(indicator_ama_10.count(), 3);
    }

    #[rstest]
    fn test_handle_quote_tick(mut indicator_ama_10: AdaptiveMovingAverage, stub_quote: QuoteTick) {
        indicator_ama_10.handle_quote(&stub_quote);
        assert!(indicator_ama_10.has_inputs);
        assert!(!indicator_ama_10.initialized);
        assert_eq!(indicator_ama_10.value, 1501.0);
        assert_eq!(indicator_ama_10.count(), 1);
    }

    #[rstest]
    fn test_handle_trade_tick_update(
        mut indicator_ama_10: AdaptiveMovingAverage,
        stub_trade: TradeTick,
    ) {
        indicator_ama_10.handle_trade(&stub_trade);
        assert!(indicator_ama_10.has_inputs);
        assert!(!indicator_ama_10.initialized);
        assert_eq!(indicator_ama_10.value, 1500.0);
        assert_eq!(indicator_ama_10.count(), 1);
    }

    #[rstest]
    fn test_handle_bar(
        mut indicator_ama_10: AdaptiveMovingAverage,
        bar_ethusdt_binance_minute_bid: Bar,
    ) {
        indicator_ama_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert!(indicator_ama_10.has_inputs);
        assert!(!indicator_ama_10.initialized);
        assert_eq!(indicator_ama_10.value, 1522.0);
        assert_eq!(indicator_ama_10.count(), 1);
    }

    // The constructor tests below pin the specific assertion message so that
    // an unrelated panic cannot make them pass.

    #[rstest]
    #[should_panic(expected = "must be greater than period_fast")]
    fn new_panics_when_slow_not_greater_than_fast() {
        let _ = AdaptiveMovingAverage::new(10, 20, 20, None);
    }

    #[rstest]
    #[should_panic(expected = "period_efficiency_ratio must be a positive integer")]
    fn new_panics_when_er_is_zero() {
        let _ = AdaptiveMovingAverage::new(0, 2, 30, None);
    }

    #[rstest]
    #[should_panic(expected = "period_fast must be a positive integer")]
    fn new_panics_when_fast_is_zero() {
        let _ = AdaptiveMovingAverage::new(10, 0, 30, None);
    }

    #[rstest]
    #[should_panic(expected = "period_slow must be a positive integer")]
    fn new_panics_when_slow_is_zero() {
        let _ = AdaptiveMovingAverage::new(10, 2, 0, None);
    }

    #[rstest]
    #[should_panic(expected = "must be greater than period_fast")]
    fn new_panics_when_slow_less_than_fast() {
        let _ = AdaptiveMovingAverage::new(10, 20, 5, None);
    }
}