// entrenar/autograd/context.rs

/// Context that records whether autograd is in training or evaluation mode.
///
/// A freshly created context (via [`Context::new`] or [`Default`]) starts in
/// training mode. Use [`Context::train`] and [`Context::eval`] to switch modes,
/// and [`Context::is_training`] to query the current one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Context {
    /// `true` while in training mode, `false` while in evaluation mode.
    training: bool,
}

impl Context {
    /// Creates a new context in training mode.
    #[must_use]
    pub const fn new() -> Self {
        Self { training: true }
    }

    /// Switches the context to training mode.
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Switches the context to evaluation mode.
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Returns `true` in training mode and `false` in evaluation mode.
    #[must_use]
    pub const fn is_training(&self) -> bool {
        self.training
    }
}

impl Default for Context {
    /// Same as [`Context::new`]: the default context is in training mode.
    ///
    /// Written by hand rather than derived, because `#[derive(Default)]` would
    /// produce `training: false`.
    fn default() -> Self {
        Self::new()
    }
}

// Unit tests for `Context` construction and training/evaluation mode switching.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_new() {
        let ctx = Context::new();
        assert!(ctx.is_training());
    }

    #[test]
    fn test_context_default() {
        let ctx = Context::default();
        assert!(ctx.is_training());
    }

    #[test]
    fn test_context_train_mode() {
        let mut ctx = Context::new();
        ctx.eval();
        assert!(!ctx.is_training());

        ctx.train();
        assert!(ctx.is_training());
    }

    #[test]
    fn test_context_eval_mode() {
        let mut ctx = Context::new();
        ctx.eval();
        assert!(!ctx.is_training());
    }

    // Repeated calls must not toggle: each call sets the mode, it does not flip it.
    #[test]
    fn test_context_mode_switches_are_idempotent() {
        let mut ctx = Context::new();
        ctx.train();
        ctx.train();
        assert!(ctx.is_training());

        ctx.eval();
        ctx.eval();
        assert!(!ctx.is_training());
    }
}