// burn_core/module/param/running.rs

1use super::ParamId;
2use crate::module::{
3    AutodiffModule, Content, Module, ModuleDisplay, ModuleDisplayDefault, ModuleMapper,
4    ModuleVisitor, Param,
5};
6
7use alloc::string::ToString;
8use alloc::vec::Vec;
9
10#[cfg(target_has_atomic = "ptr")]
11use alloc::sync::Arc;
12
13#[cfg(not(target_has_atomic = "ptr"))]
14use portable_atomic_util::Arc;
15
16use burn_common::stub::Mutex;
17use burn_tensor::{
18    Tensor,
19    backend::{AutodiffBackend, Backend},
20    ops::Device,
21};
22
#[cfg(feature = "std")]
mod threading {
    pub(super) use std::collections::HashMap;
    pub(super) use std::thread::ThreadId;

    /// Returns the identifier of the calling thread.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        std::thread::current().id()
    }
}
33
#[cfg(not(feature = "std"))]
mod threading {
    pub(super) use burn_common::stub::ThreadId;
    pub(super) use hashbrown::HashMap;

    /// Thread ids are not available without `std`; calling this always panics.
    #[inline(always)]
    pub(super) fn get_thread_current_id() -> ThreadId {
        panic!("Current thread id is not available")
    }
}
44
45// Re-export items from the disabled/enabled blocks
46use threading::*;
47
/// A state that can be updated during the forward pass while being thread safe.
///
/// # Note
///
/// The state value is the average of all updates on all threads.
#[derive(Clone, Debug)]
pub struct RunningState<V> {
    // Identifier used when recording/loading this state as a parameter.
    id: ParamId,
    // Latest value submitted by each thread, pending aggregation.
    values: Arc<Mutex<HashMap<ThreadId, V>>>,
    // Aggregated (averaged) value shared across clones of this state.
    value: Arc<Mutex<V>>,
}
59
60// Implement display for the module
61
62impl<V> core::fmt::Display for RunningState<V> {
63    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
64        write!(f, "RunningState(id={})", self.id)
65    }
66}
67
68impl<V> ModuleDisplayDefault for RunningState<V> {
69    fn content(&self, content: Content) -> Option<Content> {
70        content
71            .add_formatted(&"RunningState".to_string())
72            .optional()
73    }
74}
75
76impl<V> ModuleDisplay for RunningState<V> {}
77
impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>> {
    /// The record stores only the aggregated value, not per-thread updates.
    type Record = Param<Tensor<B, D>>;

    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        // Visitors operate on `Param`, so wrap a clone of the current value.
        // Pending per-thread updates are NOT aggregated here.
        let tensor = self.value.lock().unwrap();
        let param = Param::initialized(self.id, tensor.clone());
        visitor.visit_float(&param)
    }

    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let mut tensor = self.value.lock().unwrap();
        // Round-trip through `Param` so the mapper sees a regular parameter.
        let param = Param::initialized(self.id, tensor.clone());
        let param_out = mapper.map_float(param);
        let (_, tensor_out, _) = param_out.consume();

        // Store the mapped tensor back in place; the id is kept as-is.
        *tensor = tensor_out;
        // Release the lock explicitly before returning `self`.
        core::mem::drop(tensor);

        self
    }

    fn into_record(self) -> Self::Record {
        // Aggregate any pending per-thread updates before snapshotting.
        self.sync();
        let tensor = self.value.lock().unwrap();

        Param::initialized(self.id, tensor.clone())
    }

    fn load_record(mut self, record: Self::Record) -> Self {
        let mut tensor = self.value.lock().unwrap();
        // Keep the current device; only the data and the id come from the record.
        *tensor = record.val().to_device(&tensor.device());
        self.id = record.id;

        core::mem::drop(tensor);

        self
    }

    fn to_device(self, device: &Device<B>) -> Self {
        let mut tensor = self.value.lock().unwrap();
        let tensor_out = tensor.clone().to_device(device);

        *tensor = tensor_out;
        core::mem::drop(tensor);

        self
    }

    fn fork(self, device: &Device<B>) -> Self {
        self.to_device(device) // Same thing here since no grad.
    }

    fn collect_devices(&self, mut devices: Vec<Device<B>>) -> Vec<Device<B>> {
        let device = self.value.lock().unwrap().device();

        // Deduplicate: only add the device if it isn't already listed.
        if !devices.contains(&device) {
            devices.push(device)
        }

        devices
    }
}
140
impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
    /// Create a new running state with a freshly generated parameter id.
    pub fn new(value: Tensor<B, D>) -> Self {
        Self {
            id: ParamId::new(),
            values: Arc::new(Mutex::new(HashMap::new())),
            value: Arc::new(Mutex::new(value)),
        }
    }

    /// Create a new running state with the given parameter id.
    pub fn with_id(id: ParamId, value: Tensor<B, D>) -> Self {
        Self {
            id,
            values: Arc::new(Mutex::new(HashMap::new())),
            value: Arc::new(Mutex::new(value)),
        }
    }

    /// Create a new running state from a record, reusing the record's id.
    pub fn from_record(record: Param<Tensor<B, D>>) -> Self {
        let tensor = record.val();
        Self {
            id: record.id,
            values: Arc::new(Mutex::new(HashMap::new())),
            value: Arc::new(Mutex::new(tensor)),
        }
    }

    /// Update the value on the current thread.
    pub fn update(&self, value: Tensor<B, D>) {
        let thread_id = get_thread_current_id();
        let mut map = self.values.lock().unwrap();

        // If this thread already has a pending value, aggregate everything
        // first so each thread contributes at most one update per aggregation.
        if map.contains_key(&thread_id) {
            self.update_value(&mut map);
        }

        map.insert(thread_id, value);
    }

    /// Get the current value,
    ///
    /// # Note
    ///
    /// The current value might be outdated by one update.
    pub fn value(&self) -> Tensor<B, D> {
        let value = self.value.lock().unwrap();
        value.clone()
    }

    /// Get the current value and make sure it is sync.
    ///
    /// # Note
    ///
    /// Don't use this function after an update on the same thread where other threads might have to
    /// register their update before the actual synchronization needs to happen.
    pub fn value_sync(&self) -> Tensor<B, D> {
        let thread_id = get_thread_current_id();
        let mut map = self.values.lock().unwrap();

        // NOTE(review): aggregation only happens when the *current* thread has
        // a pending update; pending updates from other threads alone do not
        // trigger it here (contrast with `sync`, which checks for any entry).
        // This matches the doc note above — confirm it is intentional.
        if map.contains_key(&thread_id) {
            self.update_value(&mut map);
        }

        let value = self.value.lock().unwrap();
        value.clone()
    }

    // Aggregate all pending per-thread updates into the shared value, if any.
    fn sync(&self) {
        let mut map = self.values.lock().unwrap();

        if !map.is_empty() {
            self.update_value(&mut map);
        }
    }

    // Drain the per-thread map and store the mean of all pending tensors.
    // Each tensor is moved to the device of the running sum before addition,
    // then the total is divided by the number of contributions.
    fn update_value(&self, map: &mut HashMap<ThreadId, Tensor<B, D>>) {
        let mut value_updated: Option<Tensor<B, D>> = None;
        let mut counter = 0;

        for (_key, tensor) in map.drain() {
            counter += 1;

            value_updated = match value_updated {
                Some(current) => {
                    let device = current.device();
                    Some(tensor.to_device(&device).add(current))
                }
                None => Some(tensor),
            };
        }

        // `value_updated` is `None` when the map was empty; nothing to store.
        if let Some(value) = value_updated {
            let value = value.div_scalar(counter);
            let mut value_old = self.value.lock().unwrap();
            *value_old = value;
        }
    }
}
241
242impl<const D: usize, B: AutodiffBackend> AutodiffModule<B> for RunningState<Tensor<B, D>> {
243    type InnerModule = RunningState<Tensor<B::InnerBackend, D>>;
244
245    fn valid(&self) -> Self::InnerModule {
246        self.sync();
247        let value = self.value();
248
249        RunningState::with_id(self.id, value.inner())
250    }
251}