// airs_types/variable/mod.rs

use std::ops::Add;

use dashmap::DashMap;

use crate::{Device, Tensor};

mod names;
8
/// Separator string placed between the components of a variable name.
///
/// NOTE(review): presumably used to join the segments of a `VariableName` path
/// into a single key — confirm against the `names` module.
pub const VARIABLE_NAME_SEPARATOR: &str = ".";
10
/// Collection of variables.
///
/// This is currently a placeholder with no fields.
///
/// NOTE(review): earlier drafts sketched a `named_variables: HashMap<String, Tensor>`
/// map and a `trainable_variables: Vec<Var>` list here. The intent was that when the
/// variable store is frozen, `trainable_variables` keeps the same tensors but they are
/// set not to require gradients. Neither field nor any freezing logic exists yet —
/// confirm the design before relying on this type.
#[derive(Debug)]
pub struct Variables {}
19
20/// A VarStore is used to store variables used by one or multiple layers.
21/// It specifies a single device where all variables are stored.
22#[derive(Debug)]
23pub struct VariableStore {
24    variables: DashMap<String, Tensor>,
25    device: Device,
26}
27
28/// A variable store with an associated path for variables naming.
29#[derive(Clone)]
30pub struct VariableName<'s> {
31    path: Vec<String>,
32    index: u8,
33    store: &'s VariableStore,
34}
35
36impl VariableStore {
37    pub fn new(device: Device) -> Self {
38        Self { variables: DashMap::new(), device }
39    }
40}