// burn_core/src/module/param/running.rs
use super::ParamId;
2use crate::module::{
3 AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
4 ModuleVisitor, Param,
5};
6
7use alloc::string::ToString;
8use alloc::vec::Vec;
9
10#[cfg(target_has_atomic = "ptr")]
11use alloc::sync::Arc;
12
13#[cfg(not(target_has_atomic = "ptr"))]
14use portable_atomic_util::Arc;
15
16use burn_common::stub::Mutex;
17use burn_tensor::{
18 Tensor,
19 backend::{AutodiffBackend, Backend},
20 ops::Device,
21};
22
#[cfg(feature = "std")]
mod threading {
    // With `std` available, use the real OS thread id so that each thread
    // gets its own slot in the per-thread value map.
    pub(super) use std::collections::HashMap;
    pub(super) use std::thread::ThreadId;

    /// Returns the id of the calling thread.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        std::thread::current().id()
    }
}
33
#[cfg(not(feature = "std"))]
mod threading {
    // In `no_std` builds there is no notion of a current thread, so stub
    // types are used and the id lookup is unsupported.
    pub(super) use burn_common::stub::ThreadId;
    pub(super) use hashbrown::HashMap;

    /// Always panics: thread identification is unavailable without `std`.
    // NOTE(review): any code path calling `update`/`value_sync` will panic in
    // `no_std` builds — presumably those paths are std-only; confirm callers.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        panic!("Current thread id is not available")
    }
}
44
45use threading::*;
47
/// A value that can be updated independently from several threads and later
/// folded (averaged) into a single shared value — used for running statistics
/// (presumably e.g. batch-norm running mean/var; confirm against callers).
///
/// Clones share the same underlying state via `Arc`.
#[derive(Clone, Debug)]
pub struct RunningState<V> {
    /// Parameter identifier used when recording/visiting this state.
    id: ParamId,
    /// Pending per-thread updates, merged into `value` on synchronization.
    values: Arc<Mutex<HashMap<ThreadId, V>>>,
    /// The aggregated value shared across all clones.
    value: Arc<Mutex<V>>,
}
59
60impl<V> core::fmt::Display for RunningState<V> {
63 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
64 write!(f, "RunningState(id={})", self.id)
65 }
66}
67
68impl<V> ModuleDisplayDefault for RunningState<V> {
69 fn content(&self, content: Content) -> Option<Content> {
70 content
71 .add_formatted(&"RunningState".to_string())
72 .optional()
73 }
74}
75
76impl<V> ModuleDisplay for RunningState<V> {}
77
78impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
79 type Record = Param<Tensor<B, D>>;
80
81 fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
82 let tensor = self.value.lock().unwrap();
83 let param = Param::initialized(self.id, tensor.clone());
84 visitor.visit_float(¶m)
85 }
86
87 fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
88 let mut tensor = self.value.lock().unwrap();
89 let param = Param::initialized(self.id, tensor.clone());
90 let param_out = mapper.map_float(param);
91 let (_, tensor_out, _) = param_out.consume();
92
93 *tensor = tensor_out;
94 core::mem::drop(tensor);
95
96 self
97 }
98
99 fn into_record(self) -> Self::Record {
100 self.sync();
101 let tensor = self.value.lock().unwrap();
102
103 Param::initialized(self.id, tensor.clone())
104 }
105
106 fn load_record(mut self, record: Self::Record) -> Self {
107 let mut tensor = self.value.lock().unwrap();
108 *tensor = record.val().to_device(&tensor.device());
109 self.id = record.id;
110
111 core::mem::drop(tensor);
112
113 self
114 }
115
116 fn to_device(self, device: &Device<B>) -> Self {
117 let mut tensor = self.value.lock().unwrap();
118 let tensor_out = tensor.clone().to_device(device);
119
120 *tensor = tensor_out;
121 core::mem::drop(tensor);
122
123 self
124 }
125
126 fn fork(self, device: &Device<B>) -> Self {
127 self.to_device(device) }
129
130 fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
131 let device = self.value.lock().unwrap().device();
132
133 if !devices.contains(&device) {
134 devices.push(device)
135 }
136
137 devices
138 }
139}
140
141impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
142 pub fn new(value: Tensor<B, D>) -> Self {
144 Self {
145 id: ParamId::new(),
146 values: Arc::new(Mutex::new(HashMap::new())),
147 value: Arc::new(Mutex::new(value)),
148 }
149 }
150
151 pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
153 Self {
154 id,
155 values: Arc::new(Mutex::new(HashMap::new())),
156 value: Arc::new(Mutex::new(value)),
157 }
158 }
159
160 pub fn from_record(record: Param<Tensor<B, D>>) -> Self {
162 let tensor = record.val();
163 Self {
164 id: record.id,
165 values: Arc::new(Mutex::new(HashMap::new())),
166 value: Arc::new(Mutex::new(tensor)),
167 }
168 }
169
170 pub fn update(&self, value: Tensor<B, D>) {
172 let thread_id = get_thread_current_id();
173 let mut map = self.values.lock().unwrap();
174
175 if map.contains_key(&thread_id) {
176 self.update_value(&mut map);
177 }
178
179 map.insert(thread_id, value);
180 }
181
182 pub fn value(&self) -> Tensor<B, D> {
188 let value = self.value.lock().unwrap();
189 value.clone()
190 }
191
192 pub fn value_sync(&self) -> Tensor<B, D> {
199 let thread_id = get_thread_current_id();
200 let mut map = self.values.lock().unwrap();
201
202 if map.contains_key(&thread_id) {
203 self.update_value(&mut map);
204 }
205
206 let value = self.value.lock().unwrap();
207 value.clone()
208 }
209
210 fn sync(&self) {
211 let mut map = self.values.lock().unwrap();
212
213 if !map.is_empty() {
214 self.update_value(&mut map);
215 }
216 }
217
218 fn update_value(&self, map: &mut HashMap<ThreadId, Tensor<B, D>>) {
219 let mut value_updated: Option<Tensor<B, D>> = None;
220 let mut counter = 0;
221
222 for (_key, tensor) in map.drain() {
223 counter += 1;
224
225 value_updated = match value_updated {
226 Some(current) => {
227 let device = current.device();
228 Some(tensor.to_device(&device).add(current))
229 }
230 None => Some(tensor),
231 };
232 }
233
234 if let Some(value) = value_updated {
235 let value = value.div_scalar(counter);
236 let mut value_old = self.value.lock().unwrap();
237 *value_old = value;
238 }
239 }
240}
241
242impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {
243 type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;
244
245 fn valid(&self) -> Self::InnerModule {
246 self.sync();
247 let value = self.value();
248
249 RunningState::with_id(self.id, value.inner())
250 }
251}