airs_types/variable/
names.rs1use crate::Tensor;
2use std::{
3 fmt::{Debug, Display, Formatter, Write},
4 ops::Div,
5};
6
7use super::*;
8
9impl<'s> Debug for VariableName<'s> {
10 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
11 f.debug_tuple("VariableName").field(&self.to_string()).field(&self.store.device).finish()
12 }
13}
14
15impl<'s> Display for VariableName<'s> {
16 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17 if self.path.is_empty() {
18 return Ok(());
19 }
20 f.write_str(&self.pure_name())?;
21 if self.index != 0 {
22 f.write_char('_')?;
23 f.write_str(&self.index.to_string())?
24 }
25 Ok(())
26 }
27}
28
29impl<'s, T> Add<T> for &mut VariableName<'s>
30where
31 T: ToString,
32{
33 type Output = VariableName<'s>;
34
35 fn add(self, rhs: T) -> Self::Output {
36 self.store.add_name(self, rhs)
37 }
38}
39
40impl<'s, T> Add<T> for &VariableName<'s>
41where
42 T: ToString,
43{
44 type Output = VariableName<'s>;
45
46 fn add(self, rhs: T) -> Self::Output {
47 self.store.add_name(self, rhs)
48 }
49}
50
51impl<'s, T> Add<T> for VariableName<'s>
52where
53 T: ToString,
54{
55 type Output = VariableName<'s>;
56
57 fn add(self, rhs: T) -> Self::Output {
58 self.store.add_name(&self, rhs)
59 }
60}
61
62impl<'a, T> Div<T> for &mut VariableName<'a>
63where
64 T: ToString,
65{
66 type Output = VariableName<'a>;
67
68 fn div(self, rhs: T) -> Self::Output {
69 let path = self.store.div_name(&self.path, rhs);
70 VariableName { path, index: 0, store: &self.store }
71 }
72}
73
74impl<'a, T> Div<T> for &VariableName<'a>
75where
76 T: ToString,
77{
78 type Output = VariableName<'a>;
79
80 fn div(self, rhs: T) -> Self::Output {
81 let path = self.store.div_name(&self.path, rhs);
82 VariableName { path, index: 0, store: &self.store }
83 }
84}
85
86impl<'a, T> Div<T> for VariableName<'a>
87where
88 T: ToString,
89{
90 type Output = VariableName<'a>;
91
92 fn div(self, rhs: T) -> Self::Output {
93 let path = self.store.div_name(&self.path, rhs);
94 VariableName { path, index: 0, store: &self.store }
95 }
96}
97
98impl<'s> VariableName<'s> {
99 fn insert_deduplication(&mut self) {
100 self.index += 1;
101 let name = self.to_string();
102 if self.store.variables.contains_key(&name) {
103 if self.index == u8::MAX {
104 panic!("Too many variables with the same name `{}`", self.pure_name());
105 }
106 self.insert_deduplication()
107 }
108 else {
109 self.store.variables.insert(name, Tensor::default());
110 }
111 }
112 #[inline(always)]
113 fn pure_name(&self) -> String {
114 self.path.join(VARIABLE_NAME_SEPARATOR)
115 }
116}
117
118impl VariableStore {
119 pub fn root(&self) -> VariableName {
120 VariableName { path: vec![], index: 0, store: self }
121 }
122 pub fn div_name<T: ToString>(&self, names: &[String], new: T) -> Vec<String> {
124 let (name, names) = new_name(names.to_vec(), new.to_string());
125 if self.variables.contains_key(&name) {
126 return names;
127 }
128 self.variables.insert(name, Tensor::default());
129 names
130 }
131 pub fn add_name<T: ToString>(&self, old: &VariableName, new: T) -> VariableName {
133 assert!(std::ptr::eq(old.store, self), "VariableName must be created by this VariableStore");
134 let mut path = old.path.to_vec();
135 path.push(new.to_string());
136 let mut insert = VariableName { path, index: 0, store: &self };
137 if self.variables.contains_key(&insert.to_string()) {
138 insert.insert_deduplication();
139 }
140 self.variables.insert(insert.to_string(), Tensor::default());
141 insert
142 }
143}
144
145#[inline(always)]
146fn new_name(mut names: Vec<String>, mut name: String) -> (String, Vec<String>) {
147 names.push(name);
148 name = names.join(VARIABLE_NAME_SEPARATOR);
149 (name, names)
150}