airs_types/variable/
names.rs

1use crate::Tensor;
2use std::{
3    fmt::{Debug, Display, Formatter, Write},
4    ops::Div,
5};
6
7use super::*;
8
9impl<'s> Debug for VariableName<'s> {
10    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
11        f.debug_tuple("VariableName").field(&self.to_string()).field(&self.store.device).finish()
12    }
13}
14
15impl<'s> Display for VariableName<'s> {
16    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17        if self.path.is_empty() {
18            return Ok(());
19        }
20        f.write_str(&self.pure_name())?;
21        if self.index != 0 {
22            f.write_char('_')?;
23            f.write_str(&self.index.to_string())?
24        }
25        Ok(())
26    }
27}
28
29impl<'s, T> Add<T> for &mut VariableName<'s>
30where
31    T: ToString,
32{
33    type Output = VariableName<'s>;
34
35    fn add(self, rhs: T) -> Self::Output {
36        self.store.add_name(self, rhs)
37    }
38}
39
40impl<'s, T> Add<T> for &VariableName<'s>
41where
42    T: ToString,
43{
44    type Output = VariableName<'s>;
45
46    fn add(self, rhs: T) -> Self::Output {
47        self.store.add_name(self, rhs)
48    }
49}
50
51impl<'s, T> Add<T> for VariableName<'s>
52where
53    T: ToString,
54{
55    type Output = VariableName<'s>;
56
57    fn add(self, rhs: T) -> Self::Output {
58        self.store.add_name(&self, rhs)
59    }
60}
61
62impl<'a, T> Div<T> for &mut VariableName<'a>
63where
64    T: ToString,
65{
66    type Output = VariableName<'a>;
67
68    fn div(self, rhs: T) -> Self::Output {
69        let path = self.store.div_name(&self.path, rhs);
70        VariableName { path, index: 0, store: &self.store }
71    }
72}
73
74impl<'a, T> Div<T> for &VariableName<'a>
75where
76    T: ToString,
77{
78    type Output = VariableName<'a>;
79
80    fn div(self, rhs: T) -> Self::Output {
81        let path = self.store.div_name(&self.path, rhs);
82        VariableName { path, index: 0, store: &self.store }
83    }
84}
85
86impl<'a, T> Div<T> for VariableName<'a>
87where
88    T: ToString,
89{
90    type Output = VariableName<'a>;
91
92    fn div(self, rhs: T) -> Self::Output {
93        let path = self.store.div_name(&self.path, rhs);
94        VariableName { path, index: 0, store: &self.store }
95    }
96}
97
98impl<'s> VariableName<'s> {
99    fn insert_deduplication(&mut self) {
100        self.index += 1;
101        let name = self.to_string();
102        if self.store.variables.contains_key(&name) {
103            if self.index == u8::MAX {
104                panic!("Too many variables with the same name `{}`", self.pure_name());
105            }
106            self.insert_deduplication()
107        }
108        else {
109            self.store.variables.insert(name, Tensor::default());
110        }
111    }
112    #[inline(always)]
113    fn pure_name(&self) -> String {
114        self.path.join(VARIABLE_NAME_SEPARATOR)
115    }
116}
117
118impl VariableStore {
119    pub fn root(&self) -> VariableName {
120        VariableName { path: vec![], index: 0, store: self }
121    }
122    /// Create a new variable name if not exist, or reuse existence variable.
123    pub fn div_name<T: ToString>(&self, names: &[String], new: T) -> Vec<String> {
124        let (name, names) = new_name(names.to_vec(), new.to_string());
125        if self.variables.contains_key(&name) {
126            return names;
127        }
128        self.variables.insert(name, Tensor::default());
129        names
130    }
131    /// Create a new variable in any case.
132    pub fn add_name<T: ToString>(&self, old: &VariableName, new: T) -> VariableName {
133        assert!(std::ptr::eq(old.store, self), "VariableName must be created by this VariableStore");
134        let mut path = old.path.to_vec();
135        path.push(new.to_string());
136        let mut insert = VariableName { path, index: 0, store: &self };
137        if self.variables.contains_key(&insert.to_string()) {
138            insert.insert_deduplication();
139        }
140        self.variables.insert(insert.to_string(), Tensor::default());
141        insert
142    }
143}
144
145#[inline(always)]
146fn new_name(mut names: Vec<String>, mut name: String) -> (String, Vec<String>) {
147    names.push(name);
148    name = names.join(VARIABLE_NAME_SEPARATOR);
149    (name, names)
150}