// fsrs/inference.rs

use std::collections::HashMap;
use std::ops::{Add, Sub};

use crate::model::{FSRS, Get, MemoryStateTensors};
use crate::simulation::{D_MAX, D_MIN, S_MIN};
use crate::training::ComputeParametersInput;
use burn::nn::loss::Reduction;
use burn::tensor::cast::ToElement;
use burn::tensor::{Shape, Tensor, TensorData};
use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};

use crate::dataset::{
    FSRSBatch, FSRSBatcher, constant_weighted_fsrs_items, recency_weighted_fsrs_items,
};
use crate::error::Result;
use crate::model::Model;
use crate::training::BCELoss;
use crate::{FSRSError, FSRSItem};
use burn::tensor::ElementConversion;
/// This is a slice for efficiency, but should always be 21 in length.
pub type Parameters = [f32];
use itertools::izip;

pub const FSRS5_DEFAULT_DECAY: f32 = 0.5;
pub const FSRS6_DEFAULT_DECAY: f32 = 0.1542;

pub static DEFAULT_PARAMETERS: [f32; 21] = [
    0.212,
    1.2931,
    2.3065,
    8.2956,
    6.4133,
    0.8334,
    3.0194,
    0.001,
    1.8722,
    0.1666,
    0.796,
    1.4835,
    0.0614,
    0.2629,
    1.6483,
    0.6014,
    1.8729,
    0.5425,
    0.0912,
    0.0658,
    FSRS6_DEFAULT_DECAY,
];

fn infer<B: Backend>(
    model: &Model<B>,
    batch: FSRSBatch<B>,
) -> (MemoryStateTensors<B>, Tensor<B, 1>) {
    let state = model.forward(batch.t_historys, batch.r_historys, None);
    let retrievability = model.power_forgetting_curve(batch.delta_ts, state.stability.clone());
    (state, retrievability)
}

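/// Retrievability after `days_elapsed` days for a memory with the given stability,
/// following the FSRS power forgetting curve `R(t) = (1 + factor * t / S)^(-decay)`
/// with `factor = 0.9^(-1/decay) - 1`, so that `R(S) = 0.9` by construction.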
pub fn current_retrievability(state: MemoryState, days_elapsed: f32, decay: f32) -> f32 {
    let factor = 0.9f32.powf(1.0 / -decay) - 1.0;
    (days_elapsed / state.stability * factor + 1.0).powf(-decay)
}

#[derive(Debug, PartialEq, Clone, Copy)]
pub struct MemoryState {
    pub stability: f32,
    pub difficulty: f32,
}

impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {
    fn from(m: MemoryStateTensors<B>) -> Self {
        Self {
            stability: m.stability.into_scalar().elem(),
            difficulty: m.difficulty.into_scalar().elem(),
        }
    }
}

impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
    fn from(m: MemoryState) -> Self {
        Self {
            stability: Tensor::from_floats([m.stability], &B::Device::default()),
            difficulty: Tensor::from_floats([m.difficulty], &B::Device::default()),
        }
    }
}

#[derive(Default)]
struct RMatrixValue {
    predicted: f32,
    actual: f32,
    count: f32,
    weight: f32,
}

impl<B: Backend> FSRS<B> {
    fn item_to_tensors(&self, item: &FSRSItem) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let (time_history, rating_history) =
            item.reviews.iter().map(|r| (r.delta_t, r.rating)).unzip();
        let size = item.reviews.len();
        let time_history = Tensor::<B, 1>::from_data(
            TensorData::new(time_history, Shape { dims: vec![size] }),
            &self.device(),
        )
        .unsqueeze()
        .transpose();
        let rating_history = Tensor::<B, 1>::from_data(
            TensorData::new(rating_history, Shape { dims: vec![size] }),
            &self.device(),
        )
        .unsqueeze()
        .transpose();
        (time_history, rating_history)
    }

    /// Calculate the current memory state for a given card's history of reviews.
    /// In the case of truncated reviews, `starting_state` can be set to the value of
    /// [FSRS::memory_state_from_sm2] for the first review (which should not be included
    /// in FSRSItem). If not provided, the card starts as new.
    /// Parameters must have been provided when calling FSRS::new().
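    ///
    /// # Example
    ///
    /// A minimal usage sketch; it assumes `FSRS`, `FSRSItem`, `FSRSReview` and
    /// `DEFAULT_PARAMETERS` are re-exported at the crate root, and the review values are
    /// illustrative only.
    ///
    /// ```ignore
    /// use fsrs::{DEFAULT_PARAMETERS, FSRS, FSRSItem, FSRSReview};
    ///
    /// let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
    /// let item = FSRSItem {
    ///     reviews: vec![
    ///         FSRSReview { rating: 3, delta_t: 0 },
    ///         FSRSReview { rating: 3, delta_t: 5 },
    ///     ],
    /// };
    /// let state = fsrs.memory_state(item, None)?;
    /// println!("stability: {}, difficulty: {}", state.stability, state.difficulty);
    /// ```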
    pub fn memory_state(
        &self,
        item: FSRSItem,
        starting_state: Option<MemoryState>,
    ) -> Result<MemoryState> {
        let (time_history, rating_history) = self.item_to_tensors(&item);
        let state: MemoryState = self
            .model()
            .forward(time_history, rating_history, starting_state.map(Into::into))
            .into();
        if !state.stability.is_finite() || !state.difficulty.is_finite() {
            Err(FSRSError::InvalidInput)
        } else {
            Ok(state)
        }
    }

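    /// Calculate the memory state after each review in the item's history.
    /// If `starting_state` is provided, it is also returned as the first element of the
    /// resulting vector.
    /// Parameters must have been provided when calling FSRS::new().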
    pub fn historical_memory_states(
        &self,
        item: FSRSItem,
        starting_state: Option<MemoryState>,
    ) -> Result<Vec<MemoryState>> {
        let (time_history, rating_history) = self.item_to_tensors(&item);
        let mut states = vec![];
        if let Some(starting_state) = starting_state {
            states.push(starting_state);
        }
        let [seq_len, _batch_size] = time_history.dims();
        let mut inner_state = starting_state.map(Into::into);
        for i in 0..seq_len {
            let delta_t = time_history.get(i).squeeze(0);
            // [batch_size]
            let rating = rating_history.get(i).squeeze(0);
            // [batch_size]
            inner_state = Some(self.model().step(delta_t, rating, inner_state.clone()));
            if let Some(state) = inner_state.clone() {
                let state: MemoryState = state.into();
                if !state.stability.is_finite() || !state.difficulty.is_finite() {
                    return Err(FSRSError::InvalidInput);
                }
                states.push(state);
            }
        }
        Ok(states)
    }

    /// If a card has an incomplete learning history, its memory state can be approximated
    /// from its current SM-2 values.
    /// Parameters must have been provided when calling FSRS::new().
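    ///
    /// # Example
    ///
    /// A minimal sketch; the SM-2 values are illustrative only.
    ///
    /// ```ignore
    /// use fsrs::FSRS;
    ///
    /// let fsrs = FSRS::new(Some(&[]))?;
    /// // A card with ease factor 2.5, a 10-day interval and an estimated 90% retention.
    /// let state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9)?;
    /// println!("approximated stability: {}", state.stability);
    /// ```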
    pub fn memory_state_from_sm2(
        &self,
        ease_factor: f32,
        interval: f32,
        sm2_retention: f32,
    ) -> Result<MemoryState> {
        let w = &self.model().w;
        let decay: f32 = w.get(20).neg().into_scalar().elem();
        let factor = 0.9f32.powf(1.0 / decay) - 1.0;
        let stability = interval.max(S_MIN) * factor / (sm2_retention.powf(1.0 / decay) - 1.0);
        let w8: f32 = w.get(8).into_scalar().elem();
        let w9: f32 = w.get(9).into_scalar().elem();
        let w10: f32 = w.get(10).into_scalar().elem();
        let difficulty = 11.0
            - (ease_factor - 1.0)
                / (w8.exp() * stability.powf(-w9) * ((1.0 - sm2_retention) * w10).exp_m1());
        if !stability.is_finite() || !difficulty.is_finite() {
            Err(FSRSError::InvalidInput)
        } else {
            Ok(MemoryState {
                stability,
                difficulty: difficulty.clamp(D_MIN, D_MAX),
            })
        }
    }

    /// Calculate the next interval for the current memory state, for rescheduling. Stability
    /// should be provided except when the card is new. Rating is ignored except when the card is new.
    /// Parameters must have been provided when calling FSRS::new().
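    ///
    /// # Example
    ///
    /// A minimal sketch; the stability and retention values are illustrative only.
    ///
    /// ```ignore
    /// use fsrs::{DEFAULT_PARAMETERS, FSRS};
    ///
    /// let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
    /// // Existing card with stability 10.0, rescheduled for 90% desired retention.
    /// let interval = fsrs.next_interval(Some(10.0), 0.9, 3);
    /// // New card answered "good" (rating 3): stability is derived from the rating.
    /// let first_interval = fsrs.next_interval(None, 0.9, 3);
    /// ```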
    pub fn next_interval(
        &self,
        stability: Option<f32>,
        desired_retention: f32,
        rating: u32,
    ) -> f32 {
        let model = self.model();
        let stability = stability.unwrap_or_else(|| {
            // get initial stability for new card
            let rating = Tensor::from_floats([rating], &self.device());
            model.init_stability(rating).into_scalar().elem()
        });
        model
            .next_interval(
                Tensor::from_floats([stability], &self.device()),
                Tensor::from_floats([desired_retention], &self.device()),
            )
            .into_scalar()
            .elem()
    }

    /// The intervals and memory states for each answer button.
    /// Parameters must have been provided when calling FSRS::new().
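    ///
    /// # Example
    ///
    /// A minimal sketch; the memory state and elapsed days are illustrative only.
    ///
    /// ```ignore
    /// use fsrs::{DEFAULT_PARAMETERS, FSRS, MemoryState};
    ///
    /// let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
    /// let current = MemoryState { stability: 20.0, difficulty: 7.0 };
    /// // Last reviewed 21 days ago, targeting 90% retention.
    /// let next = fsrs.next_states(Some(current), 0.9, 21)?;
    /// println!("interval if answered Good: {}", next.good.interval);
    /// ```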
    pub fn next_states(
        &self,
        current_memory_state: Option<MemoryState>,
        desired_retention: f32,
        days_elapsed: u32,
    ) -> Result<NextStates> {
        let delta_t = Tensor::from_data(
            TensorData::new(vec![days_elapsed], Shape { dims: vec![1] }),
            &self.device(),
        );
        let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from);
        let model = self.model();
        let mut next_memory_states = (1..=4).map(|rating| {
            Ok({
                let state = MemoryState::from(model.step(
                    delta_t.clone(),
                    Tensor::from_data(
                        TensorData::new(vec![rating], Shape { dims: vec![1] }),
                        &self.device(),
                    ),
                    current_memory_state_tensors.clone(),
                ));
                if !state.stability.is_finite() || !state.difficulty.is_finite() {
                    return Err(FSRSError::InvalidInput);
                }
                state
            })
        });

        let mut get_next_state = || {
            let memory = next_memory_states.next().unwrap()?;
            let interval = model
                .next_interval(
                    Tensor::from_floats([memory.stability], &self.device()),
                    Tensor::from_floats([desired_retention], &self.device()),
                )
                .into_scalar()
                .elem();
            Ok(ItemState { memory, interval })
        };

        Ok(NextStates {
            again: get_next_state()?,
            hard: get_next_state()?,
            good: get_next_state()?,
            easy: get_next_state()?,
        })
    }

    /// Determine how well the model and parameters predict performance.
    /// Parameters must have been provided when calling FSRS::new().
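    ///
    /// # Example
    ///
    /// A minimal sketch; `parameters` and `items` stand in for previously optimized
    /// parameters and the user's review history.
    ///
    /// ```ignore
    /// use fsrs::FSRS;
    ///
    /// let fsrs = FSRS::new(Some(&parameters))?;
    /// // Returning false from the progress callback interrupts the evaluation.
    /// let metrics = fsrs.evaluate(items, |_progress| true)?;
    /// println!("log loss: {}, RMSE(bins): {}", metrics.log_loss, metrics.rmse_bins);
    /// ```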
    pub fn evaluate<F>(&self, items: Vec<FSRSItem>, mut progress: F) -> Result<ModelEvaluation>
    where
        F: FnMut(ItemProgress) -> bool,
    {
        if items.is_empty() {
            return Err(FSRSError::NotEnoughData);
        }
        let weighted_items = recency_weighted_fsrs_items(items);
        let batcher = FSRSBatcher::new(self.device());
        let mut all_retrievability = vec![];
        let mut all_labels = vec![];
        let mut all_weights = vec![];
        let mut progress_info = ItemProgress {
            current: 0,
            total: weighted_items.len(),
        };
        let model = self.model();
        let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();

        for chunk in weighted_items.chunks(512) {
            let batch = batcher.batch(chunk.to_vec(), &self.device());
            let (_state, retrievability) = infer::<B>(model, batch.clone());
            let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
            let true_val = batch.labels.clone().to_data().to_vec::<i64>().unwrap();
            all_retrievability.push(retrievability);
            all_labels.push(batch.labels);
            all_weights.push(batch.weights);
            izip!(chunk, pred, true_val).for_each(|(weighted_item, p, y)| {
                let bin = weighted_item.item.r_matrix_index();
                let value = r_matrix.entry(bin).or_default();
                value.predicted += p;
                value.actual += y as f32;
                value.count += 1.0;
                value.weight += weighted_item.weight;
            });
            progress_info.current += chunk.len();
            if !progress(progress_info) {
                return Err(FSRSError::Interrupted);
            }
        }
        let rmse = (r_matrix
            .values()
            .map(|v| {
                let pred = v.predicted / v.count;
                let real = v.actual / v.count;
                (pred - real).powi(2) * v.weight
            })
            .sum::<f32>()
            / r_matrix.values().map(|v| v.weight).sum::<f32>())
        .sqrt();
        let all_retrievability = Tensor::cat(all_retrievability, 0);
        let all_labels = Tensor::cat(all_labels, 0).float();
        let all_weights = Tensor::cat(all_weights, 0);
        let loss =
            BCELoss::new().forward(all_retrievability, all_labels, all_weights, Reduction::Auto);
        Ok(ModelEvaluation {
            log_loss: loss.into_scalar().to_f32(),
            rmse_bins: rmse,
        })
    }

    /// Determine how well the model and parameters predict performance using time series splits.
    /// For each split:
    /// 1. Use training data to compute parameters
    /// 2. Use test data to make predictions
    /// 3. Collect all predictions
    ///
    /// Finally, evaluate all predictions together.
    ///
    /// # Arguments
    /// * `ComputeParametersInput` - The input whose `train_set` is split and evaluated
    /// * `progress` - A callback function to report progress
    ///
    /// # Returns
    /// A ModelEvaluation containing metrics for all predictions
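    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes `ComputeParametersInput` is reachable from the crate
    /// root and `items` stands in for a time-ordered review history.
    ///
    /// ```ignore
    /// use fsrs::{ComputeParametersInput, FSRS};
    ///
    /// let fsrs = FSRS::new(None)?;
    /// let input = ComputeParametersInput {
    ///     train_set: items,
    ///     progress: None,
    ///     enable_short_term: true,
    ///     num_relearning_steps: None,
    /// };
    /// let metrics = fsrs.evaluate_with_time_series_splits(input, |_progress| true)?;
    /// ```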
    pub fn evaluate_with_time_series_splits<F>(
        &self,
        ComputeParametersInput {
            train_set,
            enable_short_term,
            num_relearning_steps,
            ..
        }: ComputeParametersInput,
        mut progress: F,
    ) -> Result<ModelEvaluation>
    where
        F: FnMut(ItemProgress) -> bool,
    {
        if train_set.is_empty() {
            return Err(FSRSError::NotEnoughData);
        }

        let splits = TimeSeriesSplit::split(train_set, 5);
        let mut all_predictions = Vec::new();
        let mut progress_info = ItemProgress {
            current: 0,
            total: splits.len(),
        };

        for split in splits.into_iter() {
            // Compute parameters on training data
            let input = ComputeParametersInput {
                train_set: split.train_items.clone(),
                enable_short_term,
                num_relearning_steps,
                progress: None,
            };
            let parameters = self.compute_parameters(input)?;

            // Make predictions on test data
            let predictions = batch_predict::<B>(split.test_items, &parameters)?;

            // Collect predictions
            all_predictions.extend(predictions);

            progress_info.current += 1;
            if !progress(progress_info) {
                return Err(FSRSError::Interrupted);
            }
        }

        // Evaluate all predictions together
        evaluate::<B>(all_predictions)
    }

    /// How well the user is likely to remember the item after `days_elapsed` days since the
    /// previous review.
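    ///
    /// # Example
    ///
    /// A minimal sketch using the FSRS-6 default decay; it assumes `FSRS6_DEFAULT_DECAY` is
    /// reachable from the crate root, and the memory state is illustrative.
    ///
    /// ```ignore
    /// use fsrs::{FSRS, FSRS6_DEFAULT_DECAY, MemoryState};
    ///
    /// let fsrs = FSRS::new(None)?;
    /// let state = MemoryState { stability: 10.0, difficulty: 5.0 };
    /// // After `stability` days have elapsed, retrievability is 90% by construction.
    /// let r = fsrs.current_retrievability(state, 10, FSRS6_DEFAULT_DECAY);
    /// assert!((r - 0.9).abs() < 1e-5);
    /// ```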
    pub fn current_retrievability(&self, state: MemoryState, days_elapsed: u32, decay: f32) -> f32 {
        current_retrievability(state, days_elapsed as f32, decay)
    }

    /// How well the user is likely to remember the item after `seconds_elapsed` seconds since
    /// the previous review.
    pub fn current_retrievability_seconds(
        &self,
        state: MemoryState,
        seconds_elapsed: u32,
        decay: f32,
    ) -> f32 {
        current_retrievability(state, seconds_elapsed as f32 / 86400.0, decay)
    }

    /// Returns the universal metrics for the existing and provided parameters. If the first value
    /// is smaller than the second value, the existing parameters are better than the provided ones.
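    ///
    /// # Example
    ///
    /// A minimal sketch; `trained_parameters` and `items` stand in for previously optimized
    /// parameters and the user's review history.
    ///
    /// ```ignore
    /// use fsrs::{DEFAULT_PARAMETERS, FSRS};
    ///
    /// let fsrs = FSRS::new(Some(&trained_parameters))?;
    /// let (self_by_other, other_by_self) =
    ///     fsrs.universal_metrics(items, &DEFAULT_PARAMETERS, |_progress| true)?;
    /// // Lower is better: if self_by_other < other_by_self, the existing parameters win.
    /// ```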
    pub fn universal_metrics<F>(
        &self,
        items: Vec<FSRSItem>,
        parameters: &Parameters,
        mut progress: F,
    ) -> Result<(f32, f32)>
    where
        F: FnMut(ItemProgress) -> bool,
    {
        if items.is_empty() {
            return Err(FSRSError::NotEnoughData);
        }
        let weighted_items = constant_weighted_fsrs_items(items);
        let batcher = FSRSBatcher::new(self.device());
        let mut all_predictions_self = vec![];
        let mut all_predictions_other = vec![];
        let mut all_true_val = vec![];
        let mut progress_info = ItemProgress {
            current: 0,
            total: weighted_items.len(),
        };
        let model_self = self.model();
        let fsrs_other = Self::new_with_backend(Some(parameters), self.device())?;
        let model_other = fsrs_other.model();
        for chunk in weighted_items.chunks(512) {
            let batch = batcher.batch(chunk.to_vec(), &self.device());

            let (_state, retrievability) = infer::<B>(model_self, batch.clone());
            let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
            all_predictions_self.extend(pred);

            let (_state, retrievability) = infer::<B>(model_other, batch.clone());
            let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
            all_predictions_other.extend(pred);

            let true_val: Vec<f32> = batch
                .labels
                .clone()
                .to_data()
                .convert::<f32>()
                .to_vec()
                .unwrap();
            all_true_val.extend(true_val);
            progress_info.current += chunk.len();
            if !progress(progress_info) {
                return Err(FSRSError::Interrupted);
            }
        }
        let self_by_other =
            measure_a_by_b(&all_predictions_self, &all_predictions_other, &all_true_val);
        let other_by_self =
            measure_a_by_b(&all_predictions_other, &all_predictions_self, &all_true_val);
        Ok((self_by_other, other_by_self))
    }
}

#[derive(Debug, Clone)]
pub struct PredictedFSRSItem {
    pub item: FSRSItem,
    pub retrievability: f32,
}

/// Batch predict retrievability for a set of items.
///
/// # Arguments
/// * `items` - The dataset to predict
/// * `parameters` - The model parameters to use for prediction
///
/// # Returns
/// A vector of PredictedFSRSItem containing the original items and their predicted retrievability
fn batch_predict<B: Backend>(
    items: Vec<FSRSItem>,
    parameters: &[f32],
) -> Result<Vec<PredictedFSRSItem>> {
    if items.is_empty() {
        return Err(FSRSError::NotEnoughData);
    }
    let weighted_items = constant_weighted_fsrs_items(items);
    let batcher = FSRSBatcher::new(B::Device::default());

    let fsrs = FSRS::<B>::new_with_backend(Some(parameters), B::Device::default())?;
    let model = fsrs.model();
    let mut predicted_items = Vec::with_capacity(weighted_items.len());

    for chunk in weighted_items.chunks(512) {
        let batch = batcher.batch(chunk.to_vec(), &B::Device::default());
        let (_state, retrievability) = infer::<B>(model, batch.clone());
        let pred = retrievability.to_data().to_vec::<f32>().unwrap();

        for (weighted_item, p) in chunk.iter().zip(pred) {
            predicted_items.push(PredictedFSRSItem {
                item: weighted_item.item.clone(),
                retrievability: p,
            });
        }
    }

    Ok(predicted_items)
}

/// Evaluate model predictions against ground truth.
///
/// # Arguments
/// * `predicted_items` - The items with their predicted retrievability values
///
/// # Returns
/// A ModelEvaluation containing log loss and RMSE metrics
fn evaluate<B: Backend>(predicted_items: Vec<PredictedFSRSItem>) -> Result<ModelEvaluation> {
    if predicted_items.is_empty() {
        return Err(FSRSError::NotEnoughData);
    }
    let mut all_labels = Vec::with_capacity(predicted_items.len());
    let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();
    for predicted_item in predicted_items.iter() {
        let pred = predicted_item.retrievability;
        let y = (predicted_item.item.current().rating > 1) as i32;
        all_labels.push(y);
        let bin = predicted_item.item.r_matrix_index();
        let value = r_matrix.entry(bin).or_default();
        value.predicted += pred;
        value.actual += y as f32;
        value.count += 1.0;
        value.weight += 1.0;
    }

    let rmse = (r_matrix
        .values()
        .map(|v| {
            let pred = v.predicted / v.count;
            let real = v.actual / v.count;
            (pred - real).powi(2) * v.weight
        })
        .sum::<f32>()
        / r_matrix.values().map(|v| v.weight).sum::<f32>())
    .sqrt();

    let all_labels = Tensor::from_data(
        TensorData::new(
            all_labels.clone(),
            Shape {
                dims: vec![all_labels.len()],
            },
        ),
        &B::Device::default(),
    );
    let all_weights = Tensor::ones(all_labels.shape(), &B::Device::default());
    let all_retrievability: Tensor<B, 1> = Tensor::from_data(
        TensorData::new(
            predicted_items.iter().map(|p| p.retrievability).collect(),
            Shape {
                dims: vec![predicted_items.len()],
            },
        ),
        &B::Device::default(),
    );

    let loss = BCELoss::new().forward(all_retrievability, all_labels, all_weights, Reduction::Auto);
    Ok(ModelEvaluation {
        log_loss: loss.into_scalar().to_f32(),
        rmse_bins: rmse,
    })
}

#[derive(Debug, Copy, Clone)]
pub struct ModelEvaluation {
    pub log_loss: f32,
    pub rmse_bins: f32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct NextStates {
    pub again: ItemState,
    pub hard: ItemState,
    pub good: ItemState,
    pub easy: ItemState,
}

#[derive(Debug, PartialEq, Clone)]
pub struct ItemState {
    pub memory: MemoryState,
    pub interval: f32,
}

#[derive(Debug, Clone, Copy)]
pub struct ItemProgress {
    pub current: usize,
    pub total: usize,
}

#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    pub train_items: Vec<FSRSItem>,
    pub test_items: Vec<FSRSItem>,
}

impl TimeSeriesSplit {
    /// Split the dataset into training and test sets based on time order.
    /// Creates n_splits folds where each fold's test set is a single segment,
    /// and the training set consists of all segments before the test segment.
    ///
    /// For example, with n_splits=5, the folds would be:
    /// Fold 0: Train=[0], Test=[1]
    /// Fold 1: Train=[0,1], Test=[2]
    /// Fold 2: Train=[0,1,2], Test=[3]
    /// Fold 3: Train=[0,1,2,3], Test=[4]
    /// Fold 4: Train=[0,1,2,3,4], Test=[5]
    ///
    /// # Arguments
    /// * `sorted_items` - The dataset to split, assumed to be in time order
    /// * `n_splits` - Number of splits to create
    ///
    /// # Returns
    /// A vector of TimeSeriesSplit, each containing train and test items
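    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes `TimeSeriesSplit` is reachable from the crate root and
    /// `items` stands in for a time-ordered `Vec<FSRSItem>` with at least `n_splits + 1`
    /// elements.
    ///
    /// ```ignore
    /// use fsrs::TimeSeriesSplit;
    ///
    /// let splits = TimeSeriesSplit::split(items, 5);
    /// for (i, split) in splits.iter().enumerate() {
    ///     println!(
    ///         "fold {}: {} train, {} test",
    ///         i,
    ///         split.train_items.len(),
    ///         split.test_items.len()
    ///     );
    /// }
    /// ```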
    pub fn split(sorted_items: Vec<FSRSItem>, n_splits: usize) -> Vec<TimeSeriesSplit> {
        if sorted_items.is_empty() || n_splits == 0 {
            return vec![];
        }
        let total_items = sorted_items.len();
        let segment_size = total_items / (n_splits + 1);
        if segment_size == 0 {
            return vec![];
        }

        (0..n_splits)
            .map(|i| {
                // Calculate the start of the test segment
                let test_start = (i + 1) * segment_size;
                // Calculate the end of the test segment (or the end of the data)
                let test_end = if i == n_splits - 1 {
                    total_items
                } else {
                    (i + 2) * segment_size
                };

                // Create the split
                TimeSeriesSplit {
                    train_items: sorted_items[..test_start].to_vec(),
                    test_items: sorted_items[test_start..test_end].to_vec(),
                }
            })
            .collect()
    }
}

fn get_bin(x: f32, bins: i32) -> i32 {
    let log_base = (bins.add(1) as f32).ln();
    let binned_x = (x * log_base).exp().floor().sub(1.0);
    (binned_x as i32).clamp(0, bins - 1)
}

fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 {
    let mut groups = HashMap::new();
    izip!(pred_a, pred_b, true_val).for_each(|(a, b, t)| {
        let bin = get_bin(*b, 20);
        groups.entry(bin).or_insert_with(Vec::new).push((a, t));
    });
    let mut total_sum = 0.0;
    let mut total_count = 0.0;
    for group in groups.values() {
        let count = group.len() as f32;
        let pred_mean = group.iter().map(|(p, _)| *p).sum::<f32>() / count;
        let true_mean = group.iter().map(|(_, t)| *t).sum::<f32>() / count;

        let rmse = (pred_mean - true_mean).powi(2);
        total_sum += rmse * count;
        total_count += count;
    }

    (total_sum / total_count).sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        FSRSReview, convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::filter_outlier,
        test_helpers::TestHelper,
    };

    static PARAMETERS: &[f32] = &[
        0.6845422,
        1.6790825,
        4.7349424,
        10.042885,
        7.4410233,
        0.64219797,
        1.071918,
        0.0025195254,
        1.432437,
        0.1544,
        0.8692766,
        2.0696752,
        0.0953,
        0.2975,
        2.4691248,
        0.19542035,
        3.201072,
        0.18046261,
        0.121442534,
    ];

    #[test]
    fn test_get_bin() {
        let pred = (0..=100).map(|i| i as f32 / 100.0).collect::<Vec<_>>();
        let bin = pred.iter().map(|p| get_bin(*p, 20)).collect::<Vec<_>>();
        assert_eq!(
            bin,
            [
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4,
                4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
                11, 11, 11, 12, 12, 13, 13, 14, 14, 14, 15, 15, 16, 17, 17, 18, 18, 19, 19
            ]
        );
    }

    #[test]
    fn test_memo_state() -> Result<()> {
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 1,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 1,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 3,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 8,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 21,
                },
            ],
        };
        let fsrs = FSRS::new(Some(PARAMETERS))?;
        assert_eq!(
            fsrs.memory_state(item, None).unwrap(),
            MemoryState {
                stability: 31.722992,
                difficulty: 7.382128
            }
        );

        assert_eq!(
            fsrs.next_states(
                Some(MemoryState {
                    stability: 20.925528,
                    difficulty: 7.005062
                }),
                0.9,
                21
            )
            .unwrap()
            .good
            .memory,
            MemoryState {
                stability: 40.87456,
                difficulty: 6.9913807
            }
        );
        Ok(())
    }

    fn assert_memory_state(w: &[f32], expected_stability: f32, expected_difficulty: f32) {
        let desired_retention = 0.9;
        let fsrs = FSRS::new(Some(w)).unwrap();
        let ratings: [u32; 6] = [1, 3, 3, 3, 3, 3];
        let intervals: [u32; 6] = [0, 0, 1, 3, 8, 21];

        let mut memory_state = None;
        for (&rating, &interval) in ratings.iter().zip(intervals.iter()) {
            let state = fsrs
                .next_states(memory_state, desired_retention, interval)
                .unwrap();
            memory_state = match rating {
                1 => Some(state.again.memory),
                2 => Some(state.hard.memory),
                3 => Some(state.good.memory),
                4 => Some(state.easy.memory),
                _ => None,
            };
            // dbg!(
            //     "stability: {}, difficulty: {}",
            //     memory_state.as_ref().unwrap().stability,
            //     memory_state.as_ref().unwrap().difficulty
            // );
        }

        let memory_state = memory_state.unwrap();
        let stability = memory_state.stability;
        let difficulty = memory_state.difficulty;
        assert!(
            (stability - expected_stability).abs() < 1e-4,
            "stability: {}",
            stability
        );
        assert!(
            (difficulty - expected_difficulty).abs() < 1e-4,
            "difficulty: {}",
            difficulty
        );
    }

    #[test]
    fn test_memory_state() {
        let mut w = DEFAULT_PARAMETERS;
        assert_memory_state(&w, 53.62691, 6.3574867);
        // freeze short term
        w[17] = 0.0;
        w[18] = 0.0;
        w[19] = 0.0;
        assert_memory_state(&w, 53.335106, 6.3574867);
    }

    #[test]
    fn test_next_interval() {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS)).unwrap();
        let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
        let intervals = desired_retentions
            .iter()
            .map(|r| fsrs.next_interval(Some(1.0), *r, 1).round().max(1.0) as i32)
            .collect::<Vec<_>>();
        assert_eq!(intervals, [3116766, 34793, 2508, 387, 90, 27, 9, 3, 1, 1]);
    }

    #[test]
    fn test_evaluate() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let (mut dataset_for_initialization, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = items
            .into_iter()
            .partition(|item| item.long_term_review_cnt() == 1);
        (dataset_for_initialization, trainset) =
            filter_outlier(dataset_for_initialization, trainset);
        let items = [dataset_for_initialization, trainset].concat();

        let fsrs = FSRS::new(Some(&[
            0.335561,
            1.6840581,
            5.166598,
            11.659035,
            7.466705,
            0.7205129,
            2.622295,
            0.001,
            1.315015,
            0.10468433,
            0.8349206,
            1.822305,
            0.12473127,
            0.26111007,
            2.3030033,
            0.13117497,
            3.0265594,
            0.41468078,
            0.09714265,
            0.106824234,
            0.20447432,
        ]))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.20580745, 0.026005825]);

        let fsrs = FSRS::new(Some(&[]))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.20967911, 0.030774858]);

        let fsrs = FSRS::new(Some(PARAMETERS))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.208_657_4, 0.030_946_612]);

        let (self_by_other, other_by_self) = fsrs
            .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)
            .unwrap();

        [self_by_other, other_by_self].assert_approx_eq([0.014087644, 0.017199915]);

        Ok(())
    }

    #[test]
    fn test_time_series_split() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let splits = TimeSeriesSplit::split(items[..6].to_vec(), 5);
        assert_eq!(splits.len(), 5);
        assert_eq!(splits[0].train_items.len(), 1);
        assert_eq!(splits[0].test_items.len(), 1);
        assert_eq!(splits[1].train_items.len(), 2);
        assert_eq!(splits[1].test_items.len(), 1);
        assert_eq!(splits[2].train_items.len(), 3);
        assert_eq!(splits[2].test_items.len(), 1);
        assert_eq!(splits[3].train_items.len(), 4);
        assert_eq!(splits[3].test_items.len(), 1);
        assert_eq!(splits[4].train_items.len(), 5);
        assert_eq!(splits[4].test_items.len(), 1);

        let splits = TimeSeriesSplit::split(items[..5].to_vec(), 5);
        assert!(splits.is_empty());

        let splits = TimeSeriesSplit::split(items[..6].to_vec(), 0);
        assert!(splits.is_empty());

        Ok(())
    }

    #[test]
    fn test_evaluate_with_time_series_splits() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let (mut dataset_for_initialization, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = items
            .into_iter()
            .partition(|item| item.long_term_review_cnt() == 1);
        (dataset_for_initialization, trainset) =
            filter_outlier(dataset_for_initialization, trainset);
        let items = [dataset_for_initialization, trainset].concat();
        let input = ComputeParametersInput {
            train_set: items.clone(),
            progress: None,
            enable_short_term: true,
            num_relearning_steps: None,
        };

        let fsrs = FSRS::new(None)?;
        let metrics = fsrs
            .evaluate_with_time_series_splits(input.clone(), |_| true)
            .unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.19692886, 0.025453836]);

        let result = fsrs.evaluate_with_time_series_splits(
            ComputeParametersInput {
                train_set: items[..5].to_vec(),
                progress: None,
                enable_short_term: true,
                num_relearning_steps: None,
            },
            |_| true,
        );
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn next_states() -> Result<()> {
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 1,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 1,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 3,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 8,
                },
            ],
        };
        let fsrs = FSRS::new(Some(PARAMETERS))?;
        let state = fsrs.memory_state(item, None).unwrap();
        assert_eq!(
            fsrs.next_states(Some(state), 0.9, 21).unwrap(),
            NextStates {
                again: ItemState {
                    memory: MemoryState {
                        stability: 2.9691455,
                        difficulty: 8.000659
                    },
                    interval: 2.9691455
                },
                hard: ItemState {
                    memory: MemoryState {
                        stability: 17.091452,
                        difficulty: 7.6913934
                    },
                    interval: 17.091452
                },
                good: ItemState {
                    memory: MemoryState {
                        stability: 31.722992,
                        difficulty: 7.382128
                    },
                    interval: 31.722992
                },
                easy: ItemState {
                    memory: MemoryState {
                        stability: 71.7502,
                        difficulty: 7.0728626
                    },
                    interval: 71.7502
                }
            }
        );
        assert_eq!(fsrs.next_interval(Some(121.01552), 0.9, 1), 121.01551);
        Ok(())
    }

    #[test]
    #[ignore = "just for exploration"]
    fn short_term_stability() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let mut state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };

        let mut stability = Vec::new();
        for _ in 0..20 {
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().good.memory;
            stability.push(state.stability);
        }

        dbg!(stability);
        Ok(())
    }

    #[test]
    #[ignore = "just for exploration"]
    fn good_again_loop_during_the_same_day() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let mut state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };

        let mut stability = Vec::with_capacity(10);
        for _ in 0..10 {
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().good.memory;
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().again.memory;
            stability.push(state.stability);
        }

        dbg!(stability);
        Ok(())
    }

    #[test]
    #[ignore = "just for exploration"]
    fn stability_after_same_day_review_less_than_next_day_review() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let state = MemoryState {
            stability: 10.0,
            difficulty: 5.0,
        };

        let next_state = fsrs.next_states(Some(state), 0.9, 0)?.good.memory;
        dbg!(next_state);
        // let next_state = fsrs.next_states(Some(next_state), 0.9, 0)?.good.memory;
        // dbg!(next_state);
        let next_state = fsrs.next_states(Some(state), 0.9, 1)?.good.memory;
        dbg!(next_state);
        Ok(())
    }

    #[test]
    #[ignore = "just for exploration"]
    fn init_stability_after_same_day_review_hard_vs_good_vs_easy() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 2,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
            ],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
            ],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        let item = FSRSItem {
            reviews: vec![FSRSReview {
                rating: 4,
                delta_t: 0,
            }],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        Ok(())
    }

    #[test]
    fn current_retrievability() {
        let fsrs = FSRS::new(None).unwrap();
        let state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };
        assert_eq!(fsrs.current_retrievability(state, 0, 0.2), 1.0);
        assert_eq!(fsrs.current_retrievability(state, 1, 0.2), 0.9);
        assert_eq!(fsrs.current_retrievability(state, 2, 0.2), 0.84028935);
        assert_eq!(fsrs.current_retrievability(state, 3, 0.2), 0.7985001);
    }

    #[test]
    fn current_retrievability_seconds() {
        let fsrs = FSRS::new(None).unwrap();
        let state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };
        assert_eq!(fsrs.current_retrievability_seconds(state, 0, 0.2), 1.0);
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 1, 0.2),
            0.9999984
        );
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 60, 0.2),
            0.9999037
        );
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 3600, 0.2),
            0.9943189
        );
        assert_eq!(fsrs.current_retrievability_seconds(state, 86400, 0.2), 0.9);
    }

    #[test]
    fn memory_from_sm2() -> Result<()> {
        let fsrs = FSRS::new(Some(&[]))?;
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([10.0, 6.9140563]);
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.8).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([3.01572, 9.393428]);
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([24.841097, 1.2974405]);
        let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([20.0, 10.0]);
        let interval = 15;
        let ease_factor = 2.0;
        let fsrs_factor = fsrs
            .next_states(
                Some(
                    fsrs.memory_state_from_sm2(ease_factor, interval as f32, 0.9)
                        .unwrap(),
                ),
                0.9,
                interval,
            )?
            .good
            .memory
            .stability
            / interval as f32;
        assert!((fsrs_factor - ease_factor).abs() < 0.01);
        Ok(())
    }
}