1use std::{fmt, sync::Arc};
25
26use crate::positive_float::PositiveFloat;
27
/// Initial parameter values passed to [`Optimizer::optimize`] as the optimizer's starting point.
///
/// NOTE(review): the per-field meanings below follow the usual Prophet parameterisation and
/// are not established by this file — confirm against the model definition.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct InitialParams {
    /// Base trend growth rate (presumably).
    pub k: f64,
    /// Trend offset (presumably).
    pub m: f64,
    /// Trend rate adjustments, presumably one per changepoint (`Data::S`) — TODO confirm.
    pub delta: Vec<f64>,
    /// Regressor coefficients, presumably one per column of `Data::X` (`Data::K`) — TODO confirm.
    pub beta: Vec<f64>,
    /// Observation noise scale (presumably a standard deviation), stored as a [`PositiveFloat`].
    pub sigma_obs: PositiveFloat,
}
43
/// Selects which trend model the optimizer fits.
///
/// With the `serde` feature enabled this type is (de)serialized as a bare integer code:
/// `Linear` = 0, `Logistic` = 1, `Flat` = 2.
#[derive(Clone, Debug, Copy, Eq, PartialEq)]
pub enum TrendIndicator {
    /// Linear trend (presumably piecewise-linear between changepoints).
    Linear,
    /// Logistic (saturating) trend; presumably uses `Data::cap` as the capacity — TODO confirm.
    Logistic,
    /// Flat (constant) trend.
    Flat,
}
54
/// Serializes a [`TrendIndicator`] as its bare `u8` code: 0 = linear, 1 = logistic, 2 = flat.
#[cfg(feature = "serde")]
impl serde::Serialize for TrendIndicator {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code: u8 = match *self {
            TrendIndicator::Linear => 0,
            TrendIndicator::Logistic => 1,
            TrendIndicator::Flat => 2,
        };
        serializer.serialize_u8(code)
    }
}
68
/// Deserializes a [`TrendIndicator`] from the integer code written by its `Serialize` impl
/// (`0` = linear, `1` = logistic, `2` = flat).
///
/// # Errors
///
/// Fails if the input is not a `u8`, or if it is a `u8` other than 0, 1 or 2. In the latter
/// case the error reports the offending value and the accepted codes.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for TrendIndicator {
    // `D::Error: serde::de::Error` is already guaranteed by the `Deserializer` trait,
    // so no extra bound is needed here.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        match <u8 as serde::Deserialize<'de>>::deserialize(deserializer)? {
            0 => Ok(Self::Linear),
            1 => Ok(Self::Logistic),
            2 => Ok(Self::Flat),
            other => Err(serde::de::Error::invalid_value(
                serde::de::Unexpected::Unsigned(u64::from(other)),
                &"a trend indicator: 0 (linear), 1 (logistic) or 2 (flat)",
            )),
        }
    }
}
85
/// Data passed to [`Optimizer::optimize`] alongside [`InitialParams`].
///
/// The upper-case single-letter fields (`T`, `S`, `K`, `X`) are written verbatim as keys by
/// the `Serialize` impl. NOTE(review): presumably they mirror the variable names of the Stan
/// model's data block — confirm.
#[derive(Clone, Debug, PartialEq)]
// Field names must match the serialized keys exactly, so the naming lint is silenced
// rather than renaming the fields.
#[allow(non_snake_case)]
pub struct Data {
    /// Number of time points; presumably equal to `y.len()` and `t.len()` — TODO confirm.
    pub T: i32,
    /// Observed values, one per time point (presumably scaled) — TODO confirm.
    pub y: Vec<f64>,
    /// Time value of each observation (presumably scaled) — TODO confirm.
    pub t: Vec<f64>,
    /// Capacity per observation; presumably only meaningful for [`TrendIndicator::Logistic`].
    pub cap: Vec<f64>,
    /// Number of changepoints; presumably equal to `t_change.len()` — TODO confirm.
    pub S: i32,
    /// Times of the trend changepoints.
    pub t_change: Vec<f64>,
    /// Which trend model to fit.
    pub trend_indicator: TrendIndicator,
    /// Number of columns in `X`. Serialization fails if this is zero.
    pub K: i32,
    /// Per-column indicator, presumably 1 when the regressor is additive — TODO confirm.
    pub s_a: Vec<i32>,
    /// Per-column indicator, presumably 1 when the regressor is multiplicative — TODO confirm.
    pub s_m: Vec<i32>,
    /// Regressor matrix flattened row-major: every consecutive run of `K` values is one row,
    /// and that is how the `Serialize` impl splits it.
    pub X: Vec<f64>,
    /// Prior scales for the regressor coefficients, presumably one per column of `X`.
    pub sigmas: Vec<PositiveFloat>,
    /// Prior scale for the changepoint rate adjustments (presumably) — TODO confirm.
    pub tau: PositiveFloat,
}
126
/// Serializes [`Data`] as a 13-field struct keyed by the field names (`T`, `y`, ..., `tau`).
///
/// `X` is stored flat but is written as a sequence of rows of `K` values each. The last row is
/// shorter if `X.len()` is not a multiple of `K`.
///
/// # Errors
///
/// Fails before anything is written if `K` is negative. Fails with
/// "Invalid value for K: cannot be zero" while writing `X` if `K` is zero. Any error from the
/// underlying serializer is propagated.
#[cfg(feature = "serde")]
impl serde::Serialize for Data {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::{SerializeSeq, SerializeStruct};

        /// Writes a flat slice (`.0`) as a sequence of rows holding `.1` elements each.
        struct XSerializer<'a>(&'a [f64], usize);

        impl serde::Serialize for XSerializer<'_> {
            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                let (values, row_len) = (self.0, self.1);
                // `slice::chunks` panics on a zero chunk size, so reject it first.
                if row_len == 0 {
                    return Err(serde::ser::Error::custom(
                        "Invalid value for K: cannot be zero",
                    ));
                }
                let rows = values.chunks(row_len);
                // `Chunks` is exact-size and counts a trailing partial row, so the length hint
                // always matches the number of elements actually written. Length-prefixed
                // formats rely on this.
                let mut outer = serializer.serialize_seq(Some(rows.len()))?;
                for row in rows {
                    outer.serialize_element(row)?;
                }
                outer.end()
            }
        }

        // `K` is an `i32`: `as usize` would turn a negative value into an enormous row length
        // (one oversized row behind a zero length hint), so reject it explicitly.
        let row_len = usize::try_from(self.K).map_err(|_| {
            <S::Error as serde::ser::Error>::custom("Invalid value for K: cannot be negative")
        })?;
        let x = XSerializer(&self.X, row_len);

        let mut s = serializer.serialize_struct("Data", 13)?;
        s.serialize_field("T", &self.T)?;
        s.serialize_field("y", &self.y)?;
        s.serialize_field("t", &self.t)?;
        s.serialize_field("cap", &self.cap)?;
        s.serialize_field("S", &self.S)?;
        s.serialize_field("t_change", &self.t_change)?;
        s.serialize_field("trend_indicator", &self.trend_indicator)?;
        s.serialize_field("K", &self.K)?;
        s.serialize_field("s_a", &self.s_a)?;
        s.serialize_field("s_m", &self.s_m)?;
        s.serialize_field("X", &x)?;
        s.serialize_field("sigmas", &self.sigmas)?;
        s.serialize_field("tau", &self.tau)?;
        s.end()
    }
}
176
/// The optimization algorithm to run.
///
/// Its `Display` impl renders each variant as a lower-case name:
/// `"newton"`, `"bfgs"` or `"lbfgs"`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum Algorithm {
    /// Newton's method.
    Newton,
    /// BFGS quasi-Newton method.
    Bfgs,
    /// Limited-memory BFGS.
    Lbfgs,
}
187
188impl fmt::Display for Algorithm {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 let s = match self {
191 Self::Lbfgs => "lbfgs",
192 Self::Newton => "newton",
193 Self::Bfgs => "bfgs",
194 };
195 f.write_str(s)
196 }
197}
198
/// Options for a single [`Optimizer::optimize`] call.
///
/// Every field is an `Option`, and the derived `Default` leaves them all `None`.
/// NOTE(review): presumably `None` means "use the optimizer implementation's own default".
/// The fields appear to mirror CmdStan's `optimize` method arguments — confirm per
/// implementation.
#[derive(Default, Debug, Clone)]
pub struct OptimizeOpts {
    /// Algorithm to run.
    pub algorithm: Option<Algorithm>,
    /// Random seed.
    pub seed: Option<u32>,
    /// Chain identifier (presumably selects the RNG stream) — TODO confirm.
    pub chain: Option<u32>,
    /// Initial line-search step size (presumably only used by BFGS / L-BFGS).
    pub init_alpha: Option<f64>,
    /// Convergence tolerance on absolute changes in the objective (presumably).
    pub tol_obj: Option<f64>,
    /// Convergence tolerance on relative changes in the objective (presumably).
    pub tol_rel_obj: Option<f64>,
    /// Convergence tolerance on the gradient norm (presumably).
    pub tol_grad: Option<f64>,
    /// Convergence tolerance on the relative gradient magnitude (presumably).
    pub tol_rel_grad: Option<f64>,
    /// Convergence tolerance on changes in the parameters (presumably).
    pub tol_param: Option<f64>,
    /// History size of L-BFGS's Hessian approximation (presumably only used by L-BFGS).
    pub history_size: Option<u32>,
    /// Maximum number of iterations (presumably).
    pub iter: Option<u32>,
    /// Whether to apply the Jacobian adjustment for constrained parameters — TODO confirm.
    pub jacobian: Option<bool>,
    /// Progress-reporting interval, in iterations — TODO confirm.
    pub refresh: Option<u32>,
}
232
/// Parameters returned by [`Optimizer::optimize`].
///
/// `k`, `m`, `sigma_obs`, `delta` and `beta` correspond to the same-named fields of
/// [`InitialParams`]. `trend` is additional output.
#[derive(Debug, Clone)]
pub struct OptimizedParams {
    /// Optimized counterpart of `InitialParams::k`.
    pub k: f64,
    /// Optimized counterpart of `InitialParams::m`.
    pub m: f64,
    /// Optimized counterpart of `InitialParams::sigma_obs`.
    pub sigma_obs: PositiveFloat,
    /// Optimized counterpart of `InitialParams::delta`.
    pub delta: Vec<f64>,
    /// Optimized counterpart of `InitialParams::beta`.
    pub beta: Vec<f64>,
    /// Fitted trend values, presumably one per time point in `Data::t` — TODO confirm.
    pub trend: Vec<f64>,
}
249
/// Error returned by an [`Optimizer`].
///
/// Construct it with [`Error::static_str`], [`Error::string`] or [`Error::custom`].
/// `#[error(transparent)]` forwards `Display` and `source` to the wrapped kind.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct Error(
    // Private field wrapping the private `ErrorKind` enum; presumably kept private so the set
    // of kinds can evolve without a breaking change.
    #[from]
    ErrorKind,
);
261
262impl Error {
263 pub fn static_str(s: &'static str) -> Self {
265 Self(ErrorKind::StaticStr(s))
266 }
267
268 pub fn string(s: String) -> Self {
270 Self(ErrorKind::String(s))
271 }
272
273 pub fn custom<E: std::error::Error + 'static>(e: E) -> Self {
275 Self(ErrorKind::Custom(Box::new(e)))
276 }
277}
278
/// Internal representation of [`Error`]; not exported from this module.
///
/// Every variant displays as `Error in optimization: <inner message>`.
// NOTE(review): `Custom` boxes `dyn std::error::Error` without `Send + Sync`, so `Error` is
// neither `Send` nor `Sync` (e.g. it cannot cross threads or become an `anyhow::Error`).
// Adding those bounds would also require tightening the public `Error::custom` signature.
#[derive(Debug, thiserror::Error)]
enum ErrorKind {
    /// Message known at compile time.
    #[error("Error in optimization: {0}")]
    StaticStr(&'static str),
    /// Message assembled at runtime.
    #[error("Error in optimization: {0}")]
    String(String),
    /// Any other boxed error; `#[from]` also derives `From<Box<dyn Error>>`.
    #[error("Error in optimization: {0}")]
    Custom(#[from] Box<dyn std::error::Error>),
}
288
/// A backend that finds optimized model parameters for some [`Data`], starting from
/// [`InitialParams`].
///
/// `Debug` is a supertrait, so every optimizer must be printable.
pub trait Optimizer: std::fmt::Debug {
    /// Runs the optimizer from `init` on `data`, configured by `opts`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if optimization fails. Which conditions count as failure is up to
    /// the implementation.
    fn optimize(
        &self,
        init: &InitialParams,
        data: &Data,
        opts: &OptimizeOpts,
    ) -> Result<OptimizedParams, Error>;
}
301
302impl Optimizer for Arc<dyn Optimizer> {
306 fn optimize(
307 &self,
308 init: &InitialParams,
309 data: &Data,
310 opts: &OptimizeOpts,
311 ) -> Result<OptimizedParams, Error> {
312 (**self).optimize(init, data, opts)
313 }
314}
315
#[cfg(test)]
pub(crate) mod mock_optimizer {
    //! A test double for [`Optimizer`] that records its most recent call.

    use std::cell::RefCell;

    use super::*;

    /// Owned copies of the arguments passed to one `optimize` call.
    #[derive(Debug, Clone)]
    pub(crate) struct OptimizeCall {
        pub init: InitialParams,
        pub data: Data,
        pub _opts: OptimizeOpts,
    }

    /// Optimizer stub: remembers the latest call and echoes the initial parameters back.
    #[derive(Clone, Debug)]
    pub(crate) struct MockOptimizer {
        /// The most recent call, if any; each `optimize` overwrites it.
        pub call: RefCell<Option<OptimizeCall>>,
    }

    impl MockOptimizer {
        /// Creates a mock that has not been called yet.
        pub(crate) fn new() -> Self {
            Self {
                call: RefCell::default(),
            }
        }

        /// Removes and returns the recorded call, leaving `None` in its place.
        pub(crate) fn take_call(&self) -> Option<OptimizeCall> {
            self.call.take()
        }
    }

    impl Optimizer for MockOptimizer {
        fn optimize(
            &self,
            init: &InitialParams,
            data: &Data,
            opts: &OptimizeOpts,
        ) -> Result<OptimizedParams, Error> {
            let recorded = OptimizeCall {
                init: init.clone(),
                data: data.clone(),
                _opts: opts.clone(),
            };
            *self.call.borrow_mut() = Some(recorded);

            // Echo the starting point back unchanged; the trend is left empty.
            let InitialParams {
                k,
                m,
                delta,
                beta,
                sigma_obs,
            } = init.clone();
            Ok(OptimizedParams {
                k,
                m,
                sigma_obs,
                delta,
                beta,
                trend: Vec::new(),
            })
        }
    }
}
378
#[cfg(test)]
mod tests {

    // Pins the exact pretty-printed JSON for `Data`: field order, the integer encoding of
    // `trend_indicator`, and the reshaping of the flat `X` into rows of `K` values.
    #[cfg(feature = "serde")]
    #[test]
    fn serialize_data() {
        use super::*;
        let data = Data {
            T: 3,
            y: vec![1.0, 2.0, 3.0],
            t: vec![0.0, 1.0, 2.0],
            // Six values with K = 2, so this should serialize as three rows of two.
            X: vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
            sigmas: vec![
                1.0.try_into().unwrap(),
                2.0.try_into().unwrap(),
                3.0.try_into().unwrap(),
            ],
            tau: 1.0.try_into().unwrap(),
            K: 2,
            s_a: vec![1, 1, 1],
            s_m: vec![0, 0, 0],
            cap: vec![0.0, 0.0, 0.0],
            S: 2,
            t_change: vec![0.0, 0.0, 0.0],
            trend_indicator: TrendIndicator::Linear,
        };
        let serialized = serde_json::to_string_pretty(&data).unwrap();
        pretty_assertions::assert_eq!(
            serialized,
            r#"{
  "T": 3,
  "y": [
    1.0,
    2.0,
    3.0
  ],
  "t": [
    0.0,
    1.0,
    2.0
  ],
  "cap": [
    0.0,
    0.0,
    0.0
  ],
  "S": 2,
  "t_change": [
    0.0,
    0.0,
    0.0
  ],
  "trend_indicator": 0,
  "K": 2,
  "s_a": [
    1,
    1,
    1
  ],
  "s_m": [
    0,
    0,
    0
  ],
  "X": [
    [
      1.0,
      2.0
    ],
    [
      3.0,
      1.0
    ],
    [
      2.0,
      3.0
    ]
  ],
  "sigmas": [
    1.0,
    2.0,
    3.0
  ],
  "tau": 1.0
}"#
        );
    }
}