1use std::collections::HashMap;
2use std::ops::{Add, Sub};
3
4use crate::model::{FSRS, Get, MemoryStateTensors};
5use crate::simulation::{D_MAX, D_MIN, S_MIN};
6use crate::training::ComputeParametersInput;
7use burn::nn::loss::Reduction;
8use burn::tensor::cast::ToElement;
9use burn::tensor::{Shape, Tensor, TensorData};
10use burn::{data::dataloader::batcher::Batcher, tensor::backend::Backend};
11
12use crate::dataset::{
13 FSRSBatch, FSRSBatcher, constant_weighted_fsrs_items, recency_weighted_fsrs_items,
14};
15use crate::error::Result;
16use crate::model::Model;
17use crate::training::BCELoss;
18use crate::{FSRSError, FSRSItem};
19use burn::tensor::ElementConversion;
/// Model parameters, passed around as an unsized slice of `f32` weights.
///
/// [`DEFAULT_PARAMETERS`] is one such set (21 entries, the last being the decay).
pub type Parameters = [f32];
22use itertools::izip;
23
/// Default decay value named for FSRS-5 parameter sets.
pub const FSRS5_DEFAULT_DECAY: f32 = 0.5;
/// Default decay value named for FSRS-6 parameter sets; it is the final entry of
/// [`DEFAULT_PARAMETERS`].
pub const FSRS6_DEFAULT_DECAY: f32 = 0.1542;

/// Default 21-entry parameter vector.
///
/// Within this module, `memory_state_from_sm2` reads `w[8]`, `w[9]` and `w[10]`, and uses
/// `-w[20]` as the decay. The meaning of the remaining entries is defined by the model.
#[rustfmt::skip]
pub static DEFAULT_PARAMETERS: [f32; 21] = [
    // w[0] ..= w[3]
    0.212, 1.2931, 2.3065, 8.2956,
    // w[4] ..= w[7]
    6.4133, 0.8334, 3.0194, 0.001,
    // w[8] ..= w[11]
    1.8722, 0.1666, 0.796, 1.4835,
    // w[12] ..= w[15]
    0.0614, 0.2629, 1.6483, 0.6014,
    // w[16] ..= w[19]
    1.8729, 0.5425, 0.0912, 0.0658,
    // w[20]: decay
    FSRS6_DEFAULT_DECAY,
];
50
51fn infer<B: Backend>(
52 model: &Model<B>,
53 batch: FSRSBatch<B>,
54) -> (MemoryStateTensors<B>, Tensor<B, 1>) {
55 let state = model.forward(batch.t_historys, batch.r_historys, None);
56 let retrievability = model.power_forgetting_curve(batch.delta_ts, state.stability.clone());
57 (state, retrievability)
58}
59
60pub fn current_retrievability(state: MemoryState, days_elapsed: f32, decay: f32) -> f32 {
61 let factor = 0.9f32.powf(1.0 / -decay) - 1.0;
62 (days_elapsed / state.stability * factor + 1.0).powf(-decay)
63}
64
/// Memory state of a single item, as produced by the model.
///
/// Converts to and from `MemoryStateTensors` for use with the tensor-based model.
#[derive(Debug, PartialEq, Clone, Copy)]
pub struct MemoryState {
    /// Stability. `current_retrievability` divides elapsed days by this value. Its curve yields
    /// a retrievability of exactly 0.9 when the elapsed time equals the stability.
    pub stability: f32,
    /// Difficulty. `memory_state_from_sm2` clamps it to `[D_MIN, D_MAX]`; the other methods in
    /// this file only check that it is finite.
    pub difficulty: f32,
}
70
71impl<B: Backend> From<MemoryStateTensors<B>> for MemoryState {
72 fn from(m: MemoryStateTensors<B>) -> Self {
73 Self {
74 stability: m.stability.into_scalar().elem(),
75 difficulty: m.difficulty.into_scalar().elem(),
76 }
77 }
78}
79
80impl<B: Backend> From<MemoryState> for MemoryStateTensors<B> {
81 fn from(m: MemoryState) -> Self {
82 Self {
83 stability: Tensor::from_floats([m.stability], &B::Device::default()),
84 difficulty: Tensor::from_floats([m.difficulty], &B::Device::default()),
85 }
86 }
87}
88
/// Per-bin accumulator used when computing the binned RMSE.
///
/// Bins are keyed by `FSRSItem::r_matrix_index()`. A bin's mean prediction and mean outcome
/// are `predicted / count` and `actual / count`; its squared error is scaled by `weight`.
#[derive(Default)]
struct RMatrixValue {
    /// Sum of predicted retrievabilities of the items in this bin.
    predicted: f32,
    /// Sum of observed labels (cast to `f32`) of the items in this bin.
    actual: f32,
    /// Number of items in this bin.
    count: f32,
    /// Sum of the item weights in this bin.
    weight: f32,
}
96
97impl<B: Backend> FSRS<B> {
98 fn item_to_tensors(&self, item: &FSRSItem) -> (Tensor<B, 2>, Tensor<B, 2>) {
99 let (time_history, rating_history) =
100 item.reviews.iter().map(|r| (r.delta_t, r.rating)).unzip();
101 let size = item.reviews.len();
102 let time_history = Tensor::<B, 1>::from_data(
103 TensorData::new(time_history, Shape { dims: vec![size] }),
104 &self.device(),
105 )
106 .unsqueeze()
107 .transpose();
108 let rating_history = Tensor::<B, 1>::from_data(
109 TensorData::new(rating_history, Shape { dims: vec![size] }),
110 &self.device(),
111 )
112 .unsqueeze()
113 .transpose();
114 (time_history, rating_history)
115 }
116
117 pub fn memory_state(
123 &self,
124 item: FSRSItem,
125 starting_state: Option<MemoryState>,
126 ) -> Result<MemoryState> {
127 let (time_history, rating_history) = self.item_to_tensors(&item);
128 let state: MemoryState = self
129 .model()
130 .forward(time_history, rating_history, starting_state.map(Into::into))
131 .into();
132 if !state.stability.is_finite() || !state.difficulty.is_finite() {
133 Err(FSRSError::InvalidInput)
134 } else {
135 Ok(state)
136 }
137 }
138
139 pub fn historical_memory_states(
140 &self,
141 item: FSRSItem,
142 starting_state: Option<MemoryState>,
143 ) -> Result<Vec<MemoryState>> {
144 let (time_history, rating_history) = self.item_to_tensors(&item);
145 let mut states = vec![];
146 if let Some(starting_state) = starting_state {
147 states.push(starting_state);
148 }
149 let [seq_len, _batch_size] = time_history.dims();
150 let mut inner_state = starting_state.map(Into::into);
151 for i in 0..seq_len {
152 let delta_t = time_history.get(i).squeeze(0);
153 let rating = rating_history.get(i).squeeze(0);
155 inner_state = Some(self.model().step(delta_t, rating, inner_state.clone()));
157 if let Some(state) = inner_state.clone() {
158 let state: MemoryState = state.into();
159 if !state.stability.is_finite() || !state.difficulty.is_finite() {
160 return Err(FSRSError::InvalidInput);
161 }
162 states.push(state);
163 }
164 }
165 Ok(states)
166 }
167
168 pub fn memory_state_from_sm2(
172 &self,
173 ease_factor: f32,
174 interval: f32,
175 sm2_retention: f32,
176 ) -> Result<MemoryState> {
177 let w = &self.model().w;
178 let decay: f32 = w.get(20).neg().into_scalar().elem();
179 let factor = 0.9f32.powf(1.0 / decay) - 1.0;
180 let stability = interval.max(S_MIN) * factor / (sm2_retention.powf(1.0 / decay) - 1.0);
181 let w8: f32 = w.get(8).into_scalar().elem();
182 let w9: f32 = w.get(9).into_scalar().elem();
183 let w10: f32 = w.get(10).into_scalar().elem();
184 let difficulty = 11.0
185 - (ease_factor - 1.0)
186 / (w8.exp() * stability.powf(-w9) * ((1.0 - sm2_retention) * w10).exp_m1());
187 if !stability.is_finite() || !difficulty.is_finite() {
188 Err(FSRSError::InvalidInput)
189 } else {
190 Ok(MemoryState {
191 stability,
192 difficulty: difficulty.clamp(D_MIN, D_MAX),
193 })
194 }
195 }
196
197 pub fn next_interval(
201 &self,
202 stability: Option<f32>,
203 desired_retention: f32,
204 rating: u32,
205 ) -> f32 {
206 let model = self.model();
207 let stability = stability.unwrap_or_else(|| {
208 let rating = Tensor::from_floats([rating], &self.device());
210 model.init_stability(rating).into_scalar().elem()
211 });
212 model
213 .next_interval(
214 Tensor::from_floats([stability], &self.device()),
215 Tensor::from_floats([desired_retention], &self.device()),
216 )
217 .into_scalar()
218 .elem()
219 }
220
221 pub fn next_states(
224 &self,
225 current_memory_state: Option<MemoryState>,
226 desired_retention: f32,
227 days_elapsed: u32,
228 ) -> Result<NextStates> {
229 let delta_t = Tensor::from_data(
230 TensorData::new(vec![days_elapsed], Shape { dims: vec![1] }),
231 &self.device(),
232 );
233 let current_memory_state_tensors = current_memory_state.map(MemoryStateTensors::from);
234 let model = self.model();
235 let mut next_memory_states = (1..=4).map(|rating| {
236 Ok({
237 let state = MemoryState::from(model.step(
238 delta_t.clone(),
239 Tensor::from_data(
240 TensorData::new(vec![rating], Shape { dims: vec![1] }),
241 &self.device(),
242 ),
243 current_memory_state_tensors.clone(),
244 ));
245 if !state.stability.is_finite() || !state.difficulty.is_finite() {
246 return Err(FSRSError::InvalidInput);
247 }
248 state
249 })
250 });
251
252 let mut get_next_state = || {
253 let memory = next_memory_states.next().unwrap()?;
254 let interval = model
255 .next_interval(
256 Tensor::from_floats([memory.stability], &self.device()),
257 Tensor::from_floats([desired_retention], &self.device()),
258 )
259 .into_scalar()
260 .elem();
261 Ok(ItemState { memory, interval })
262 };
263
264 Ok(NextStates {
265 again: get_next_state()?,
266 hard: get_next_state()?,
267 good: get_next_state()?,
268 easy: get_next_state()?,
269 })
270 }
271
272 pub fn evaluate<F>(&self, items: Vec<FSRSItem>, mut progress: F) -> Result<ModelEvaluation>
275 where
276 F: FnMut(ItemProgress) -> bool,
277 {
278 if items.is_empty() {
279 return Err(FSRSError::NotEnoughData);
280 }
281 let weighted_items = recency_weighted_fsrs_items(items);
282 let batcher = FSRSBatcher::new(self.device());
283 let mut all_retrievability = vec![];
284 let mut all_labels = vec![];
285 let mut all_weights = vec![];
286 let mut progress_info = ItemProgress {
287 current: 0,
288 total: weighted_items.len(),
289 };
290 let model = self.model();
291 let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();
292
293 for chunk in weighted_items.chunks(512) {
294 let batch = batcher.batch(chunk.to_vec(), &self.device());
295 let (_state, retrievability) = infer::<B>(model, batch.clone());
296 let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
297 let true_val = batch.labels.clone().to_data().to_vec::<i64>().unwrap();
298 all_retrievability.push(retrievability);
299 all_labels.push(batch.labels);
300 all_weights.push(batch.weights);
301 izip!(chunk, pred, true_val).for_each(|(weighted_item, p, y)| {
302 let bin = weighted_item.item.r_matrix_index();
303 let value = r_matrix.entry(bin).or_default();
304 value.predicted += p;
305 value.actual += y as f32;
306 value.count += 1.0;
307 value.weight += weighted_item.weight;
308 });
309 progress_info.current += chunk.len();
310 if !progress(progress_info) {
311 return Err(FSRSError::Interrupted);
312 }
313 }
314 let rmse = (r_matrix
315 .values()
316 .map(|v| {
317 let pred = v.predicted / v.count;
318 let real = v.actual / v.count;
319 (pred - real).powi(2) * v.weight
320 })
321 .sum::<f32>()
322 / r_matrix.values().map(|v| v.weight).sum::<f32>())
323 .sqrt();
324 let all_retrievability = Tensor::cat(all_retrievability, 0);
325 let all_labels = Tensor::cat(all_labels, 0).float();
326 let all_weights = Tensor::cat(all_weights, 0);
327 let loss =
328 BCELoss::new().forward(all_retrievability, all_labels, all_weights, Reduction::Auto);
329 Ok(ModelEvaluation {
330 log_loss: loss.into_scalar().to_f32(),
331 rmse_bins: rmse,
332 })
333 }
334
335 pub fn evaluate_with_time_series_splits<F>(
350 &self,
351 ComputeParametersInput {
352 train_set,
353 enable_short_term,
354 num_relearning_steps,
355 ..
356 }: ComputeParametersInput,
357 mut progress: F,
358 ) -> Result<ModelEvaluation>
359 where
360 F: FnMut(ItemProgress) -> bool,
361 {
362 if train_set.is_empty() {
363 return Err(FSRSError::NotEnoughData);
364 }
365
366 let splits = TimeSeriesSplit::split(train_set, 5);
367 let mut all_predictions = Vec::new();
368 let mut progress_info = ItemProgress {
369 current: 0,
370 total: splits.len(),
371 };
372
373 for split in splits.into_iter() {
374 let input = ComputeParametersInput {
376 train_set: split.train_items.clone(),
377 enable_short_term,
378 num_relearning_steps,
379 progress: None,
380 };
381 let parameters = self.compute_parameters(input)?;
382
383 let predictions = batch_predict::<B>(split.test_items, ¶meters)?;
385
386 all_predictions.extend(predictions);
388
389 progress_info.current += 1;
390 if !progress(progress_info) {
391 return Err(FSRSError::Interrupted);
392 }
393 }
394
395 evaluate::<B>(all_predictions)
397 }
398
399 pub fn current_retrievability(&self, state: MemoryState, days_elapsed: u32, decay: f32) -> f32 {
402 current_retrievability(state, days_elapsed as f32, decay)
403 }
404
405 pub fn current_retrievability_seconds(
408 &self,
409 state: MemoryState,
410 seconds_elapsed: u32,
411 decay: f32,
412 ) -> f32 {
413 current_retrievability(state, seconds_elapsed as f32 / 86400.0, decay)
414 }
415
416 pub fn universal_metrics<F>(
419 &self,
420 items: Vec<FSRSItem>,
421 parameters: &Parameters,
422 mut progress: F,
423 ) -> Result<(f32, f32)>
424 where
425 F: FnMut(ItemProgress) -> bool,
426 {
427 if items.is_empty() {
428 return Err(FSRSError::NotEnoughData);
429 }
430 let weighted_items = constant_weighted_fsrs_items(items);
431 let batcher = FSRSBatcher::new(self.device());
432 let mut all_predictions_self = vec![];
433 let mut all_predictions_other = vec![];
434 let mut all_true_val = vec![];
435 let mut progress_info = ItemProgress {
436 current: 0,
437 total: weighted_items.len(),
438 };
439 let model_self = self.model();
440 let fsrs_other = Self::new_with_backend(Some(parameters), self.device())?;
441 let model_other = fsrs_other.model();
442 for chunk in weighted_items.chunks(512) {
443 let batch = batcher.batch(chunk.to_vec(), &self.device());
444
445 let (_state, retrievability) = infer::<B>(model_self, batch.clone());
446 let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
447 all_predictions_self.extend(pred);
448
449 let (_state, retrievability) = infer::<B>(model_other, batch.clone());
450 let pred = retrievability.clone().to_data().to_vec::<f32>().unwrap();
451 all_predictions_other.extend(pred);
452
453 let true_val: Vec<f32> = batch
454 .labels
455 .clone()
456 .to_data()
457 .convert::<f32>()
458 .to_vec()
459 .unwrap();
460 all_true_val.extend(true_val);
461 progress_info.current += chunk.len();
462 if !progress(progress_info) {
463 return Err(FSRSError::Interrupted);
464 }
465 }
466 let self_by_other =
467 measure_a_by_b(&all_predictions_self, &all_predictions_other, &all_true_val);
468 let other_by_self =
469 measure_a_by_b(&all_predictions_other, &all_predictions_self, &all_true_val);
470 Ok((self_by_other, other_by_self))
471 }
472}
473
/// An item paired with the retrievability a model predicted for it.
#[derive(Debug, Clone)]
pub struct PredictedFSRSItem {
    /// The item that was evaluated.
    pub item: FSRSItem,
    /// Retrievability predicted for this item, as returned by `infer`.
    pub retrievability: f32,
}
478}
479
480fn batch_predict<B: Backend>(
490 items: Vec<FSRSItem>,
491 parameters: &[f32],
492) -> Result<Vec<PredictedFSRSItem>>
493where
494{
495 if items.is_empty() {
496 return Err(FSRSError::NotEnoughData);
497 }
498 let weighted_items = constant_weighted_fsrs_items(items);
499 let batcher = FSRSBatcher::new(B::Device::default());
500
501 let fsrs = FSRS::<B>::new_with_backend(Some(parameters), B::Device::default())?;
502 let model = fsrs.model();
503 let mut predicted_items = Vec::with_capacity(weighted_items.len());
504
505 for chunk in weighted_items.chunks(512) {
506 let batch = batcher.batch(chunk.to_vec(), &B::Device::default());
507 let (_state, retrievability) = infer::<B>(model, batch.clone());
508 let pred = retrievability.to_data().to_vec::<f32>().unwrap();
509
510 for (weighted_item, p) in chunk.iter().zip(pred) {
511 predicted_items.push(PredictedFSRSItem {
512 item: weighted_item.item.clone(),
513 retrievability: p,
514 });
515 }
516 }
517
518 Ok(predicted_items)
519}
520
521fn evaluate<B: Backend>(predicted_items: Vec<PredictedFSRSItem>) -> Result<ModelEvaluation> {
530 if predicted_items.is_empty() {
531 return Err(FSRSError::NotEnoughData);
532 }
533 let mut all_labels = Vec::with_capacity(predicted_items.len());
534 let mut r_matrix: HashMap<(u32, u32, u32), RMatrixValue> = HashMap::new();
535 for predicted_item in predicted_items.iter() {
536 let pred = predicted_item.retrievability;
537 let y = (predicted_item.item.current().rating > 1) as i32;
538 all_labels.push(y);
539 let bin = predicted_item.item.r_matrix_index();
540 let value = r_matrix.entry(bin).or_default();
541 value.predicted += pred;
542 value.actual += y as f32;
543 value.count += 1.0;
544 value.weight += 1.0;
545 }
546
547 let rmse = (r_matrix
548 .values()
549 .map(|v| {
550 let pred = v.predicted / v.count;
551 let real = v.actual / v.count;
552 (pred - real).powi(2) * v.weight
553 })
554 .sum::<f32>()
555 / r_matrix.values().map(|v| v.weight).sum::<f32>())
556 .sqrt();
557
558 let all_labels = Tensor::from_data(
559 TensorData::new(
560 all_labels.clone(),
561 Shape {
562 dims: vec![all_labels.len()],
563 },
564 ),
565 &B::Device::default(),
566 );
567 let all_weights = Tensor::ones(all_labels.shape(), &B::Device::default());
568 let all_retrievability: Tensor<B, 1> = Tensor::from_data(
569 TensorData::new(
570 predicted_items.iter().map(|p| p.retrievability).collect(),
571 Shape {
572 dims: vec![predicted_items.len()],
573 },
574 ),
575 &B::Device::default(),
576 );
577
578 let loss = BCELoss::new().forward(all_retrievability, all_labels, all_weights, Reduction::Auto);
579 Ok(ModelEvaluation {
580 log_loss: loss.into_scalar().to_f32(),
581 rmse_bins: rmse,
582 })
583}
584
/// Summary metrics produced by the evaluation functions in this module.
#[derive(Debug, Copy, Clone)]
pub struct ModelEvaluation {
    /// Binary cross-entropy between predicted retrievability and recall labels, computed by
    /// `BCELoss` with `Reduction::Auto`.
    pub log_loss: f32,
    /// Weighted root-mean-square difference between mean predicted and mean actual recall,
    /// computed per r-matrix bin.
    pub rmse_bins: f32,
}
590
/// Candidate next states, one per possible rating of the upcoming review.
///
/// `FSRS::next_states` fills the fields from ratings 1 (`again`), 2 (`hard`), 3 (`good`)
/// and 4 (`easy`), in that order.
#[derive(Debug, Clone, PartialEq)]
pub struct NextStates {
    pub again: ItemState,
    pub hard: ItemState,
    pub good: ItemState,
    pub easy: ItemState,
}
598
/// A memory state together with the interval the model suggests for it.
#[derive(Debug, PartialEq, Clone)]
pub struct ItemState {
    /// Memory state after the review.
    pub memory: MemoryState,
    /// Interval returned by the model's `next_interval` for the requested desired retention;
    /// not rounded.
    pub interval: f32,
}
604
/// Progress report passed to the callbacks of long-running operations.
///
/// The unit depends on the caller: `evaluate` and `universal_metrics` count items, while
/// `evaluate_with_time_series_splits` counts splits.
#[derive(Debug, Clone, Copy)]
pub struct ItemProgress {
    /// Units completed so far.
    pub current: usize,
    /// Total number of units.
    pub total: usize,
}
610
/// One train/test split produced by [`TimeSeriesSplit::split`].
#[derive(Debug, Clone)]
pub struct TimeSeriesSplit {
    /// Every item that precedes the test segment.
    pub train_items: Vec<FSRSItem>,
    /// The contiguous segment evaluated in this split.
    pub test_items: Vec<FSRSItem>,
}
616
617impl TimeSeriesSplit {
618 pub fn split(sorted_items: Vec<FSRSItem>, n_splits: usize) -> Vec<TimeSeriesSplit> {
636 if sorted_items.is_empty() || n_splits == 0 {
637 return vec![];
638 }
639 let total_items = sorted_items.len();
640 let segment_size = total_items / (n_splits + 1);
641 if segment_size == 0 {
642 return vec![];
643 }
644
645 (0..n_splits)
646 .map(|i| {
647 let test_start = (i + 1) * segment_size;
649 let test_end = if i == n_splits - 1 {
651 total_items
652 } else {
653 (i + 2) * segment_size
654 };
655
656 TimeSeriesSplit {
658 train_items: sorted_items[..test_start].to_vec(),
659 test_items: sorted_items[test_start..test_end].to_vec(),
660 }
661 })
662 .collect()
663 }
664}
665
/// Maps a probability `x` to one of `bins` log-spaced buckets.
///
/// Computes `floor((bins + 1)^x) - 1` (via `exp(x * ln(bins + 1))`) and clamps the result to
/// `0..=bins - 1`, so buckets become narrower as `x` approaches 1.
fn get_bin(x: f32, bins: i32) -> i32 {
    let base = bins.add(1) as f32;
    let scaled = f32::exp(x * f32::ln(base));
    let raw_index = f32::floor(scaled).sub(1.0) as i32;
    raw_index.clamp(0, bins - 1)
}
671
672fn measure_a_by_b(pred_a: &[f32], pred_b: &[f32], true_val: &[f32]) -> f32 {
673 let mut groups = HashMap::new();
674 izip!(pred_a, pred_b, true_val).for_each(|(a, b, t)| {
675 let bin = get_bin(*b, 20);
676 groups.entry(bin).or_insert_with(Vec::new).push((a, t));
677 });
678 let mut total_sum = 0.0;
679 let mut total_count = 0.0;
680 for group in groups.values() {
681 let count = group.len() as f32;
682 let pred_mean = group.iter().map(|(p, _)| *p).sum::<f32>() / count;
683 let true_mean = group.iter().map(|(_, t)| *t).sum::<f32>() / count;
684
685 let rmse = (pred_mean - true_mean).powi(2);
686 total_sum += rmse * count;
687 total_count += count;
688 }
689
690 (total_sum / total_count).sqrt()
691}
692
#[cfg(test)]
mod tests {
    // Regression tests for inference, evaluation and the time-series split. Most expected
    // values are recorded outputs for fixed inputs and parameters.
    use super::*;
    use crate::{
        FSRSReview, convertor_tests::anki21_sample_file_converted_to_fsrs, dataset::filter_outlier,
        test_helpers::TestHelper,
    };

    // A fixed 19-entry parameter set shared by several tests below.
    static PARAMETERS: &[f32] = &[
        0.6845422,
        1.6790825,
        4.7349424,
        10.042885,
        7.4410233,
        0.64219797,
        1.071918,
        0.0025195254,
        1.432437,
        0.1544,
        0.8692766,
        2.0696752,
        0.0953,
        0.2975,
        2.4691248,
        0.19542035,
        3.201072,
        0.18046261,
        0.121442534,
    ];

    // Pins the bucket index for every prediction 0.00, 0.01, ..., 1.00 with 20 bins.
    #[test]
    fn test_get_bin() {
        let pred = (0..=100).map(|i| i as f32 / 100.0).collect::<Vec<_>>();
        let bin = pred.iter().map(|p| get_bin(*p, 20)).collect::<Vec<_>>();
        assert_eq!(
            bin,
            [
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
                1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4,
                4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
                11, 11, 11, 12, 12, 13, 13, 14, 14, 14, 15, 15, 16, 17, 17, 18, 18, 19, 19
            ]
        );
    }

    // `memory_state` on a fixed five-review history, then a single `next_states` step,
    // both compared exactly against recorded values.
    #[test]
    fn test_memo_state() -> Result<()> {
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 1,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 1,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 3,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 8,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 21,
                },
            ],
        };
        let fsrs = FSRS::new(Some(PARAMETERS))?;
        assert_eq!(
            fsrs.memory_state(item, None).unwrap(),
            MemoryState {
                stability: 31.722992,
                difficulty: 7.382128
            }
        );

        assert_eq!(
            fsrs.next_states(
                Some(MemoryState {
                    stability: 20.925528,
                    difficulty: 7.005062
                }),
                0.9,
                21
            )
            .unwrap()
            .good
            .memory,
            MemoryState {
                stability: 40.87456,
                difficulty: 6.9913807
            }
        );
        Ok(())
    }

    // Replays a fixed rating/interval sequence through `next_states` (picking the state for
    // each actual rating) and compares the final state within 1e-4.
    fn assert_memory_state(w: &[f32], expected_stability: f32, expected_difficulty: f32) {
        let desired_retention = 0.9;
        let fsrs = FSRS::new(Some(w)).unwrap();
        let ratings: [u32; 6] = [1, 3, 3, 3, 3, 3];
        let intervals: [u32; 6] = [0, 0, 1, 3, 8, 21];

        let mut memory_state = None;
        for (&rating, &interval) in ratings.iter().zip(intervals.iter()) {
            let state = fsrs
                .next_states(memory_state, desired_retention, interval)
                .unwrap();
            memory_state = match rating {
                1 => Some(state.again.memory),
                2 => Some(state.hard.memory),
                3 => Some(state.good.memory),
                4 => Some(state.easy.memory),
                _ => None,
            };
        }

        let memory_state = memory_state.unwrap();
        let stability = memory_state.stability;
        let difficulty = memory_state.difficulty;
        assert!(
            (stability - expected_stability).abs() < 1e-4,
            "stability: {}",
            stability
        );
        assert!(
            (difficulty - expected_difficulty).abs() < 1e-4,
            "difficulty: {}",
            difficulty
        );
    }
    // Default parameters, then the same replay with w[17], w[18] and w[19] zeroed.
    #[test]
    fn test_memory_state() {
        let mut w = DEFAULT_PARAMETERS;
        assert_memory_state(&w, 53.62691, 6.3574867);
        w[17] = 0.0;
        w[18] = 0.0;
        w[19] = 0.0;
        assert_memory_state(&w, 53.335106, 6.3574867);
    }

    // Suggested intervals for stability 1.0 across desired retention 0.1..=1.0, rounded and
    // floored at 1.
    #[test]
    fn test_next_interval() {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS)).unwrap();
        let desired_retentions = (1..=10).map(|i| i as f32 / 10.0).collect::<Vec<_>>();
        let intervals = desired_retentions
            .iter()
            .map(|r| fsrs.next_interval(Some(1.0), *r, 1).round().max(1.0) as i32)
            .collect::<Vec<_>>();
        assert_eq!(intervals, [3116766, 34793, 2508, 387, 90, 27, 9, 3, 1, 1]);
    }

    // Recorded evaluation metrics on the sample collection for three parameter sets (the
    // `&[]` set presumably falls back to defaults — see `FSRS::new`), plus universal metrics
    // against `DEFAULT_PARAMETERS`.
    #[test]
    fn test_evaluate() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let (mut dataset_for_initialization, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = items
            .into_iter()
            .partition(|item| item.long_term_review_cnt() == 1);
        (dataset_for_initialization, trainset) =
            filter_outlier(dataset_for_initialization, trainset);
        let items = [dataset_for_initialization, trainset].concat();

        let fsrs = FSRS::new(Some(&[
            0.335561,
            1.6840581,
            5.166598,
            11.659035,
            7.466705,
            0.7205129,
            2.622295,
            0.001,
            1.315015,
            0.10468433,
            0.8349206,
            1.822305,
            0.12473127,
            0.26111007,
            2.3030033,
            0.13117497,
            3.0265594,
            0.41468078,
            0.09714265,
            0.106824234,
            0.20447432,
        ]))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.20580745, 0.026005825]);

        let fsrs = FSRS::new(Some(&[]))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.20967911, 0.030774858]);

        let fsrs = FSRS::new(Some(PARAMETERS))?;
        let metrics = fsrs.evaluate(items.clone(), |_| true).unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.208_657_4, 0.030_946_612]);

        let (self_by_other, other_by_self) = fsrs
            .universal_metrics(items.clone(), &DEFAULT_PARAMETERS, |_| true)
            .unwrap();

        [self_by_other, other_by_self].assert_approx_eq([0.014087644, 0.017199915]);

        Ok(())
    }

    // Six items into five splits gives one test item per split with a growing training set.
    // Too few items or zero splits gives no splits.
    #[test]
    fn test_time_series_split() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let splits = TimeSeriesSplit::split(items[..6].to_vec(), 5);
        assert_eq!(splits.len(), 5);
        assert_eq!(splits[0].train_items.len(), 1);
        assert_eq!(splits[0].test_items.len(), 1);
        assert_eq!(splits[1].train_items.len(), 2);
        assert_eq!(splits[1].test_items.len(), 1);
        assert_eq!(splits[2].train_items.len(), 3);
        assert_eq!(splits[2].test_items.len(), 1);
        assert_eq!(splits[3].train_items.len(), 4);
        assert_eq!(splits[3].test_items.len(), 1);
        assert_eq!(splits[4].train_items.len(), 5);
        assert_eq!(splits[4].test_items.len(), 1);

        let splits = TimeSeriesSplit::split(items[..5].to_vec(), 5);
        assert!(splits.is_empty());

        let splits = TimeSeriesSplit::split(items[..6].to_vec(), 0);
        assert!(splits.is_empty());

        Ok(())
    }

    // Recorded out-of-sample metrics on the sample collection. Five items cannot be split
    // and must produce an error.
    #[test]
    fn test_evaluate_with_time_series_splits() -> Result<()> {
        let items = anki21_sample_file_converted_to_fsrs();
        let (mut dataset_for_initialization, mut trainset): (Vec<FSRSItem>, Vec<FSRSItem>) = items
            .into_iter()
            .partition(|item| item.long_term_review_cnt() == 1);
        (dataset_for_initialization, trainset) =
            filter_outlier(dataset_for_initialization, trainset);
        let items = [dataset_for_initialization, trainset].concat();
        let input = ComputeParametersInput {
            train_set: items.clone(),
            progress: None,
            enable_short_term: true,
            num_relearning_steps: None,
        };

        let fsrs = FSRS::new(None)?;
        let metrics = fsrs
            .evaluate_with_time_series_splits(input.clone(), |_| true)
            .unwrap();

        [metrics.log_loss, metrics.rmse_bins].assert_approx_eq([0.19692886, 0.025453836]);

        let result = fsrs.evaluate_with_time_series_splits(
            ComputeParametersInput {
                train_set: items[..5].to_vec(),
                progress: None,
                enable_short_term: true,
                num_relearning_steps: None,
            },
            |_| true,
        );
        assert!(result.is_err());
        Ok(())
    }

    // All four candidate states and intervals after a fixed four-review history and 21 days.
    #[test]
    fn next_states() -> Result<()> {
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 1,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 1,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 3,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 8,
                },
            ],
        };
        let fsrs = FSRS::new(Some(PARAMETERS))?;
        let state = fsrs.memory_state(item, None).unwrap();
        assert_eq!(
            fsrs.next_states(Some(state), 0.9, 21).unwrap(),
            NextStates {
                again: ItemState {
                    memory: MemoryState {
                        stability: 2.9691455,
                        difficulty: 8.000659
                    },
                    interval: 2.9691455
                },
                hard: ItemState {
                    memory: MemoryState {
                        stability: 17.091452,
                        difficulty: 7.6913934
                    },
                    interval: 17.091452
                },
                good: ItemState {
                    memory: MemoryState {
                        stability: 31.722992,
                        difficulty: 7.382128
                    },
                    interval: 31.722992
                },
                easy: ItemState {
                    memory: MemoryState {
                        stability: 71.7502,
                        difficulty: 7.0728626
                    },
                    interval: 71.7502
                }
            }
        );
        assert_eq!(fsrs.next_interval(Some(121.01552), 0.9, 1), 121.01551);
        Ok(())
    }

    // Exploration only: stability after repeated same-day "good" reviews.
    #[test]
    #[ignore = "just for exploration"]
    fn short_term_stability() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let mut state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };

        let mut stability = Vec::new();
        for _ in 0..20 {
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().good.memory;
            stability.push(state.stability);
        }

        dbg!(stability);
        Ok(())
    }

    // Exploration only: stability under alternating same-day "good"/"again" reviews.
    #[test]
    #[ignore = "just for exploration"]
    fn good_again_loop_during_the_same_day() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let mut state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };

        let mut stability = Vec::with_capacity(10);
        for _ in 0..10 {
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().good.memory;
            state = fsrs.next_states(Some(state), 0.9, 0).unwrap().again.memory;
            stability.push(state.stability);
        }

        dbg!(stability);
        Ok(())
    }

    // Exploration only: compares a "good" review after 0 days with one after 1 day.
    #[test]
    #[ignore = "just for exploration"]
    fn stability_after_same_day_review_less_than_next_day_review() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let state = MemoryState {
            stability: 10.0,
            difficulty: 5.0,
        };

        let next_state = fsrs.next_states(Some(state), 0.9, 0)?.good.memory;
        dbg!(next_state);
        let next_state = fsrs.next_states(Some(state), 0.9, 1)?.good.memory;
        dbg!(next_state);
        Ok(())
    }

    // Exploration only: initial states reached via same-day hard/good/easy sequences.
    #[test]
    #[ignore = "just for exploration"]
    fn init_stability_after_same_day_review_hard_vs_good_vs_easy() -> Result<()> {
        let fsrs = FSRS::new(Some(&DEFAULT_PARAMETERS))?;
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 2,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
            ],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        let item = FSRSItem {
            reviews: vec![
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
                FSRSReview {
                    rating: 3,
                    delta_t: 0,
                },
            ],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        let item = FSRSItem {
            reviews: vec![FSRSReview {
                rating: 4,
                delta_t: 0,
            }],
        };
        let state = fsrs.memory_state(item, None).unwrap();
        dbg!(state);
        Ok(())
    }

    // With stability 1.0 and decay 0.2, retrievability is 1.0 at day 0 and 0.9 at day 1.
    #[test]
    fn current_retrievability() {
        let fsrs = FSRS::new(None).unwrap();
        let state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };
        assert_eq!(fsrs.current_retrievability(state, 0, 0.2), 1.0);
        assert_eq!(fsrs.current_retrievability(state, 1, 0.2), 0.9);
        assert_eq!(fsrs.current_retrievability(state, 2, 0.2), 0.84028935);
        assert_eq!(fsrs.current_retrievability(state, 3, 0.2), 0.7985001);
    }

    // The seconds-based variant must agree with the day-based one at 86400 seconds.
    #[test]
    fn current_retrievability_seconds() {
        let fsrs = FSRS::new(None).unwrap();
        let state = MemoryState {
            stability: 1.0,
            difficulty: 5.0,
        };
        assert_eq!(fsrs.current_retrievability_seconds(state, 0, 0.2), 1.0);
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 1, 0.2),
            0.9999984
        );
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 60, 0.2),
            0.9999037
        );
        assert_eq!(
            fsrs.current_retrievability_seconds(state, 3600, 0.2),
            0.9943189
        );
        assert_eq!(fsrs.current_retrievability_seconds(state, 86400, 0.2), 0.9);
    }

    // SM-2 conversion at several retentions. The final check verifies that, after one "good"
    // review at the SM-2 interval, stability grows by roughly the SM-2 ease factor.
    #[test]
    fn memory_from_sm2() -> Result<()> {
        let fsrs = FSRS::new(Some(&[]))?;
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.9).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([10.0, 6.9140563]);
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.8).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([3.01572, 9.393428]);
        let memory_state = fsrs.memory_state_from_sm2(2.5, 10.0, 0.95).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([24.841097, 1.2974405]);
        let memory_state = fsrs.memory_state_from_sm2(1.3, 20.0, 0.9).unwrap();

        [memory_state.stability, memory_state.difficulty].assert_approx_eq([20.0, 10.0]);
        let interval = 15;
        let ease_factor = 2.0;
        let fsrs_factor = fsrs
            .next_states(
                Some(
                    fsrs.memory_state_from_sm2(ease_factor, interval as f32, 0.9)
                        .unwrap(),
                ),
                0.9,
                interval,
            )?
            .good
            .memory
            .stability
            / interval as f32;
        assert!((fsrs_factor - ease_factor).abs() < 0.01);
        Ok(())
    }
}