Skip to main content

entrenar/autograd/
context.rs

//! Execution context for managing computational graphs

/// Context for managing the computational graph.
///
/// Currently the context only tracks whether the graph is run in training
/// mode or in evaluation (inference) mode. A freshly created context starts
/// in training mode.
// NOTE: `Copy` is intentionally not derived. The context may later grow
// fields (e.g. tensor bookkeeping, random state) that are not `Copy`, and
// removing a `Copy` impl afterwards would be a breaking change.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Context {
    // Possible future additions:
    // - All tensors for memory management
    // - Random state
    /// `true` in training mode, `false` in evaluation mode.
    training: bool,
}

impl Context {
    /// Creates a new context in training mode.
    #[must_use]
    pub const fn new() -> Self {
        Self { training: true }
    }

    /// Switches the context to training mode.
    ///
    /// Calling this when already in training mode has no effect.
    pub fn train(&mut self) {
        self.training = true;
    }

    /// Switches the context to evaluation (inference) mode.
    ///
    /// Calling this when already in evaluation mode has no effect.
    pub fn eval(&mut self) {
        self.training = false;
    }

    /// Returns `true` if the context is in training mode and `false` if it is
    /// in evaluation mode.
    #[must_use]
    pub const fn is_training(&self) -> bool {
        self.training
    }
}

impl Default for Context {
    /// Equivalent to [`Context::new`]: a context in training mode.
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A context built with `new` starts in training mode.
    #[test]
    fn test_context_new() {
        let ctx = Context::new();
        assert!(ctx.is_training());
    }

    /// `Default` yields the same starting mode as `new`.
    #[test]
    fn test_context_default() {
        let ctx = Context::default();
        assert!(ctx.is_training());
    }

    /// Round-trips eval -> train and checks the flag follows each call.
    #[test]
    fn test_context_train_mode() {
        let mut ctx = Context::new();
        ctx.eval();
        assert!(!ctx.is_training());

        ctx.train();
        assert!(ctx.is_training());
    }

    /// Switching a fresh context to eval mode clears the training flag.
    #[test]
    fn test_context_eval_mode() {
        let mut ctx = Context::new();
        ctx.eval();
        assert!(!ctx.is_training());
    }

    /// `train`/`eval` set the flag rather than toggle it, so repeating a call
    /// (or calling `train` on an already-training context) must be a no-op.
    #[test]
    fn test_context_mode_switches_are_idempotent() {
        let mut ctx = Context::new();
        ctx.train();
        assert!(ctx.is_training());

        ctx.eval();
        ctx.eval();
        assert!(!ctx.is_training());

        ctx.train();
        ctx.train();
        assert!(ctx.is_training());
    }
}