burn_rmexp_tensorlib/
tensor_library.rs

1use burn::prelude::Backend;
2use burn_rmexp_dyntensor::DynTensor;
3use futures::future::join_all;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::pin::Pin;
8
9/// Query for a tensor in a [`TensorLibrary`]
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum TensorLibraryQuery {
12    /// Query by UUID.
13    Uuid(uuid::Uuid),
14
15    /// Query by route.
16    Route(Vec<String>),
17
18    /// Query by string path.
19    Path(String),
20}
21
22impl From<uuid::Uuid> for TensorLibraryQuery {
23    fn from(uuid: uuid::Uuid) -> Self {
24        Self::Uuid(uuid)
25    }
26}
27
28/// Error returned by [`TensorLibrary`]
29#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub enum TensorLibraryError {
31    InvalidQuery(TensorLibraryQuery),
32}
33
34pub trait TensorLibrary<B: Backend>: 'static + Debug {
35    /// Query a tensor from the library.
36    fn query<'a>(
37        &'a mut self,
38        query: TensorLibraryQuery,
39    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
40}
41
42#[derive(Debug)]
43pub struct TensorLibraryCollection<B: Backend> {
44    libs: Vec<Box<dyn TensorLibrary<B>>>,
45}
46
47impl<B: Backend> Default for TensorLibraryCollection<B> {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl<B: Backend> TensorLibraryCollection<B> {
54    /// Create a new empty library collection.
55    pub fn new() -> Self {
56        Self { libs: Vec::new() }
57    }
58
59    /// Add a library to the collection.
60    pub fn push(
61        &mut self,
62        lib: Box<dyn TensorLibrary<B>>,
63    ) {
64        self.libs.push(lib);
65    }
66
67    /// Get a reference to the underlying libraries.
68    pub fn libs(&self) -> &[Box<dyn TensorLibrary<B>>] {
69        &self.libs
70    }
71
72    /// Get a mutable reference to the underlying libraries.
73    pub fn libs_mut(&mut self) -> &mut [Box<dyn TensorLibrary<B>>] {
74        &mut self.libs
75    }
76}
77
78impl<B: Backend> TensorLibrary<B> for TensorLibraryCollection<B> {
79    fn query<'a>(
80        &'a mut self,
81        query: TensorLibraryQuery,
82    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
83    {
84        // Query each library in parallel.
85        let fs = self
86            .libs
87            .iter_mut()
88            .map(|lib| lib.query(query.clone()))
89            .collect::<Vec<_>>();
90
91        // Forward the first error; or the first non-None result.
92        Box::pin(async move {
93            let res = join_all(fs).await;
94            res.into_iter().try_fold(None, |acc, result| match result {
95                Err(e) => Err(e),
96                Ok(Some(val)) if acc.is_none() => Ok(Some(val)),
97                Ok(_) => Ok(acc),
98            })
99        })
100    }
101}
102
103/// A [`TensorLibrary`] backed by a [`uuid::Uuid`] keyed [`HashMap`].
104#[derive(Debug, Clone)]
105pub struct UuidMapTensorLibrary<B: Backend> {
106    hash_map: HashMap<uuid::Uuid, DynTensor<B>>,
107}
108
109impl<B: Backend> From<HashMap<uuid::Uuid, DynTensor<B>>> for UuidMapTensorLibrary<B> {
110    fn from(hash_map: HashMap<uuid::Uuid, DynTensor<B>>) -> Self {
111        Self { hash_map }
112    }
113}
114
115impl<B: Backend> Default for UuidMapTensorLibrary<B> {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl<B: Backend> UuidMapTensorLibrary<B> {
122    /// Create an empty library.
123    pub fn new() -> Self {
124        Self {
125            hash_map: HashMap::new(),
126        }
127    }
128
129    /// Get a reference to the internal map.
130    pub fn hash_map(&self) -> &HashMap<uuid::Uuid, DynTensor<B>> {
131        &self.hash_map
132    }
133
134    /// Get a mutable reference to the internal map.
135    pub fn hash_map_mut(&mut self) -> &mut HashMap<uuid::Uuid, DynTensor<B>> {
136        &mut self.hash_map
137    }
138
139    /// Insert a tensor into the library.
140    /// If a tensor with the same UUID already exists, it will be replaced.
141    ///
142    /// # Returns
143    ///
144    /// The previous value, if any.
145    pub fn insert<T: Into<DynTensor<B>>>(
146        &mut self,
147        key: uuid::Uuid,
148        value: T,
149    ) -> Option<DynTensor<B>> {
150        self.hash_map.insert(key, value.into())
151    }
152
153    /// Bind a tensor into the library.
154    /// Returns the generated UUID.
155    pub fn bind<T: Into<DynTensor<B>>>(
156        &mut self,
157        value: T,
158    ) -> uuid::Uuid {
159        let key = uuid::Uuid::new_v4();
160        self.insert(key, value);
161        key
162    }
163
164    /// Remove a tensor from the library.
165    /// Returns `None` if the tensor was not found.
166    pub fn remove(
167        &mut self,
168        key: &uuid::Uuid,
169    ) -> Option<DynTensor<B>> {
170        self.hash_map.remove(key)
171    }
172
173    /// Clear the library.
174    pub fn clear(&mut self) {
175        self.hash_map.clear();
176    }
177
178    /// Returns the number of tensors in the library.
179    pub fn len(&self) -> usize {
180        self.hash_map.len()
181    }
182
183    /// Returns `true` if the library contains no tensors.
184    pub fn is_empty(&self) -> bool {
185        self.len() == 0
186    }
187
188    /// Returns the size estimate of the library in bytes.
189    pub fn size_estimate(&self) -> usize {
190        self.hash_map
191            .values()
192            .map(|tensor| tensor.size_estimate())
193            .sum()
194    }
195
196    /// Get a tensor ref from the library.
197    pub fn get(
198        &self,
199        key: &uuid::Uuid,
200    ) -> Option<&DynTensor<B>> {
201        self.hash_map.get(key)
202    }
203}
204
205impl<B: Backend> TensorLibrary<B> for UuidMapTensorLibrary<B> {
206    /// Query a tensor from the library.
207    fn query<'a>(
208        &'a mut self,
209        query: TensorLibraryQuery,
210    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
211    {
212        Box::pin(async move {
213            match query {
214                TensorLibraryQuery::Uuid(uuid) => Ok(self.get(&uuid).cloned()),
215                _ => Ok(None),
216            }
217        })
218    }
219}
220
221pub trait LazyBuilder<B: Backend>: Debug + Sync + Send + 'static {
222    fn build<'a>(
223        &'a self,
224        query: TensorLibraryQuery,
225    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
226}
227
228#[derive(Debug, Default)]
229pub struct LazyBuilderLibrary<B: Backend> {
230    builders: HashMap<uuid::Uuid, Box<dyn LazyBuilder<B>>>,
231    cached: UuidMapTensorLibrary<B>,
232}
233
234impl<B: Backend> LazyBuilderLibrary<B> {
235    pub fn new() -> Self {
236        Self::default()
237    }
238
239    pub fn cached(&self) -> &UuidMapTensorLibrary<B> {
240        &self.cached
241    }
242
243    pub fn cached_mut(&mut self) -> &mut UuidMapTensorLibrary<B> {
244        &mut self.cached
245    }
246
247    pub fn register_builder<T: LazyBuilder<B> + 'static>(
248        &mut self,
249        uuid: uuid::Uuid,
250        builder: T,
251    ) {
252        self.builders.insert(uuid, Box::new(builder));
253    }
254}
255
256impl<B: Backend> TensorLibrary<B> for LazyBuilderLibrary<B> {
257    fn query<'a>(
258        &'a mut self,
259        query: TensorLibraryQuery,
260    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
261    {
262        Box::pin(async move {
263            match query {
264                TensorLibraryQuery::Uuid(uuid) => {
265                    if let Some(tensor) = self.cached.get(&uuid).cloned() {
266                        return Ok(Some(tensor));
267                    }
268
269                    match self.builders.get(&uuid) {
270                        None => Ok(None),
271                        Some(builder) => match builder.build(query.clone()).await? {
272                            Some(tensor) => {
273                                self.cached.insert(uuid, tensor.clone());
274                                Ok(Some(tensor))
275                            }
276                            None => Ok(None),
277                        },
278                    }
279                }
280                _ => Ok(None),
281            }
282        })
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::backend::wgpu::WgpuDevice;
    use burn::prelude::Shape;
    use burn_rmexp_dyntensor::{DynTensor, KindFlag};

    #[tokio::test]
    async fn test_map_library() {
        type B = Wgpu;
        let device = Default::default();

        let mut library: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);
        // Expected `size_estimate` contribution of one stored copy of `source`.
        let bytes_per_copy = source.shape().num_elements() * source.dtype().size();

        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let id = library.bind(source.clone());

        assert_eq!(library.len(), 1);
        assert_eq!(library.size_estimate(), bytes_per_copy);

        // Binding the same tensor again creates a second, independent entry.
        let _dup = library.bind(source.clone());

        assert_eq!(library.len(), 2);
        assert_eq!(library.size_estimate(), 2 * bytes_per_copy);

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .into_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }

    #[tokio::test]
    async fn test_lazy_builder_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        /// Samples a fresh random tensor of `shape` on every call, so a rebuild
        /// is distinguishable from a cached result.
        #[derive(Debug)]
        struct RandomBuilder<B: Backend, const R: usize> {
            pub shape: [usize; R],
            pub device: B::Device,
        }

        impl<B: Backend, const R: usize> LazyBuilder<B> for RandomBuilder<B, R> {
            fn build<'a>(
                &'a self,
                _query: TensorLibraryQuery,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async move {
                    Ok(Some(
                        Tensor::<B, R>::random(self.shape, Default::default(), &self.device)
                            .into(),
                    ))
                })
            }
        }

        let mut library: LazyBuilderLibrary<B> = LazyBuilderLibrary::new();
        let id = uuid::Uuid::new_v4();

        library.register_builder(
            id,
            RandomBuilder {
                shape: [2, 3],
                device: device.clone(),
            },
        );

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        // The first query must have memoized its result. A second query is
        // served from the cache rather than sampling a new random tensor.
        assert_eq!(library.cached().len(), 1);
        let again = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");
        again
            .into_data()
            .unwrap()
            .assert_eq(&dyn_tensor.clone().into_data().unwrap(), true);

        assert_eq!(dyn_tensor.rank(), 2);
        assert_eq!(dyn_tensor.shape(), Shape::new([2, 3]));

        assert_eq!(dyn_tensor.kind(), KindFlag::Float);
        assert_eq!(dyn_tensor.device(), device);
    }

    #[tokio::test]
    async fn test_collection_library() {
        type B = Wgpu;
        let device = Default::default();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        // Only the second child library holds `id`.
        let mut first: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();
        let _unrelated = first.bind(source.clone());
        let mut second: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();
        let id = second.bind(source.clone());

        let mut collection: TensorLibraryCollection<B> = TensorLibraryCollection::new();
        collection.push(Box::new(first));
        collection.push(Box::new(second));
        assert_eq!(collection.libs().len(), 2);

        assert!(
            collection
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let found = collection
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        found
            .into_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }
}