burn_rmexp_tensorlib/
tensor_library.rs

1use burn::prelude::Backend;
2use burn_rmexp_dyntensor::DynTensor;
3use futures::future::join_all;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::pin::Pin;
8
9/// Query for a tensor in a [`TensorLibrary`]
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum TensorLibraryQuery {
12    /// Query by UUID.
13    Uuid(uuid::Uuid),
14
15    /// Query by route.
16    Route(Vec<String>),
17
18    /// Query by string path.
19    Path(String),
20}
21
22impl From<uuid::Uuid> for TensorLibraryQuery {
23    fn from(uuid: uuid::Uuid) -> Self {
24        Self::Uuid(uuid)
25    }
26}
27
28/// Error returned by [`TensorLibrary`]
29#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub enum TensorLibraryError {
31    InvalidQuery(TensorLibraryQuery),
32}
33
34pub trait TensorLibrary<B: Backend>: 'static + Debug {
35    /// Query a tensor from the library.
36    fn query<'a>(
37        &'a mut self,
38        query: TensorLibraryQuery,
39    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
40}
41
42#[derive(Debug)]
43pub struct TensorLibraryCollection<B: Backend> {
44    libs: Vec<Box<dyn TensorLibrary<B>>>,
45}
46
47impl<B: Backend> Default for TensorLibraryCollection<B> {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl<B: Backend> TensorLibraryCollection<B> {
54    /// Create a new empty library collection.
55    pub fn new() -> Self {
56        Self { libs: Vec::new() }
57    }
58
59    /// Add a library to the collection.
60    pub fn push(
61        &mut self,
62        lib: Box<dyn TensorLibrary<B>>,
63    ) {
64        self.libs.push(lib);
65    }
66
67    /// Get a reference to the underlying libraries.
68    pub fn libs(&self) -> &[Box<dyn TensorLibrary<B>>] {
69        &self.libs
70    }
71
72    /// Get a mutable reference to the underlying libraries.
73    pub fn libs_mut(&mut self) -> &mut [Box<dyn TensorLibrary<B>>] {
74        &mut self.libs
75    }
76}
77
78impl<B: Backend> TensorLibrary<B> for TensorLibraryCollection<B> {
79    fn query<'a>(
80        &'a mut self,
81        query: TensorLibraryQuery,
82    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
83    {
84        // Query each library in parallel.
85        let fs = self
86            .libs
87            .iter_mut()
88            .map(|lib| lib.query(query.clone()))
89            .collect::<Vec<_>>();
90
91        // Forward the first error; or the first non-None result.
92        Box::pin(async move {
93            let res = join_all(fs).await;
94            res.into_iter().try_fold(None, |acc, result| match result {
95                Err(e) => Err(e),
96                Ok(Some(val)) if acc.is_none() => Ok(Some(val)),
97                Ok(_) => Ok(acc),
98            })
99        })
100    }
101}
102
103/// A [`TensorLibrary`] backed by a [`uuid::Uuid`] keyed [`HashMap`].
104#[derive(Debug, Clone)]
105pub struct UuidMapTensorLibrary<B: Backend> {
106    hash_map: HashMap<uuid::Uuid, DynTensor<B>>,
107}
108
109impl<B: Backend> From<HashMap<uuid::Uuid, DynTensor<B>>> for UuidMapTensorLibrary<B> {
110    fn from(hash_map: HashMap<uuid::Uuid, DynTensor<B>>) -> Self {
111        Self { hash_map }
112    }
113}
114
115impl<B: Backend> Default for UuidMapTensorLibrary<B> {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl<B: Backend> UuidMapTensorLibrary<B> {
122    /// Create an empty library.
123    pub fn new() -> Self {
124        Self {
125            hash_map: HashMap::new(),
126        }
127    }
128
129    /// Get a reference to the internal map.
130    pub fn hash_map(&self) -> &HashMap<uuid::Uuid, DynTensor<B>> {
131        &self.hash_map
132    }
133
134    /// Get a mutable reference to the internal map.
135    pub fn hash_map_mut(&mut self) -> &mut HashMap<uuid::Uuid, DynTensor<B>> {
136        &mut self.hash_map
137    }
138
139    /// Insert a tensor into the library.
140    /// If a tensor with the same UUID already exists, it will be replaced.
141    ///
142    /// # Returns
143    ///
144    /// The previous value, if any.
145    pub fn insert<T: Into<DynTensor<B>>>(
146        &mut self,
147        key: uuid::Uuid,
148        value: T,
149    ) -> Option<DynTensor<B>> {
150        self.hash_map.insert(key, value.into())
151    }
152
153    /// Bind a tensor into the library.
154    /// Returns the generated UUID.
155    pub fn bind<T: Into<DynTensor<B>>>(
156        &mut self,
157        value: T,
158    ) -> uuid::Uuid {
159        let key = uuid::Uuid::new_v4();
160        self.insert(key, value);
161        key
162    }
163
164    /// Remove a tensor from the library.
165    /// Returns `None` if the tensor was not found.
166    pub fn remove(
167        &mut self,
168        key: &uuid::Uuid,
169    ) -> Option<DynTensor<B>> {
170        self.hash_map.remove(key)
171    }
172
173    /// Clear the library.
174    pub fn clear(&mut self) {
175        self.hash_map.clear();
176    }
177
178    /// Returns the number of tensors in the library.
179    pub fn len(&self) -> usize {
180        self.hash_map.len()
181    }
182
183    /// Returns the size estimate of the library in bytes.
184    pub fn size_estimate(&self) -> usize {
185        self.hash_map
186            .values()
187            .map(|tensor| tensor.size_estimate())
188            .sum()
189    }
190
191    /// Get a tensor ref from the library.
192    pub fn get(
193        &self,
194        key: &uuid::Uuid,
195    ) -> Option<&DynTensor<B>> {
196        self.hash_map.get(key)
197    }
198}
199
200impl<B: Backend> TensorLibrary<B> for UuidMapTensorLibrary<B> {
201    /// Query a tensor from the library.
202    fn query<'a>(
203        &'a mut self,
204        query: TensorLibraryQuery,
205    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
206    {
207        Box::pin(async move {
208            match query {
209                TensorLibraryQuery::Uuid(uuid) => Ok(self.get(&uuid).cloned()),
210                _ => Ok(None),
211            }
212        })
213    }
214}
215
216pub trait LazyBuilder<B: Backend>: Debug + Sync + Send + 'static {
217    fn build<'a>(
218        &'a self,
219        query: TensorLibraryQuery,
220    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
221}
222
223#[derive(Debug, Default)]
224pub struct LazyBuilderLibrary<B: Backend> {
225    builders: HashMap<uuid::Uuid, Box<dyn LazyBuilder<B>>>,
226    cached: UuidMapTensorLibrary<B>,
227}
228
229impl<B: Backend> LazyBuilderLibrary<B> {
230    pub fn new() -> Self {
231        Self::default()
232    }
233
234    pub fn cached(&self) -> &UuidMapTensorLibrary<B> {
235        &self.cached
236    }
237
238    pub fn cached_mut(&mut self) -> &mut UuidMapTensorLibrary<B> {
239        &mut self.cached
240    }
241
242    pub fn register_builder<T: LazyBuilder<B> + 'static>(
243        &mut self,
244        uuid: uuid::Uuid,
245        builder: T,
246    ) {
247        self.builders.insert(uuid, Box::new(builder));
248    }
249}
250
251impl<B: Backend> TensorLibrary<B> for LazyBuilderLibrary<B> {
252    fn query<'a>(
253        &'a mut self,
254        query: TensorLibraryQuery,
255    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
256    {
257        Box::pin(async move {
258            match query {
259                TensorLibraryQuery::Uuid(uuid) => {
260                    if let Some(tensor) = self.cached.get(&uuid).cloned() {
261                        return Ok(Some(tensor));
262                    }
263
264                    let builder = self.builders.get(&uuid);
265                    if builder.is_none() {
266                        return Ok(None);
267                    }
268
269                    let qr = builder.unwrap().build(query.clone()).await?;
270                    if qr.is_some() {
271                        self.cached.insert(uuid, qr.as_ref().unwrap().clone());
272                    }
273                    Ok(qr)
274                }
275                _ => Ok(None),
276            }
277        })
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::backend::wgpu::WgpuDevice;
    use burn::prelude::Shape;
    use burn_rmexp_dyntensor::{DynTensor, KindFlag};

    /// Binding, sizing, and querying tensors in a `UuidMapTensorLibrary`.
    #[tokio::test]
    async fn test_map_library() {
        type B = Wgpu;
        let device = Default::default();

        let mut library = UuidMapTensorLibrary::new();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        // An unknown UUID is a miss, not an error.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let id = library.bind(source.clone());

        // Bytes occupied by one copy of `source`.
        let tensor_bytes = source.shape().num_elements() * source.dtype().size();

        assert_eq!(library.len(), 1);
        assert_eq!(library.size_estimate(), tensor_bytes);

        let _dup = library.bind(source.clone());

        // Each binding is counted, even for clones of the same tensor.
        assert_eq!(library.len(), 2);
        assert_eq!(library.size_estimate(), 2 * tensor_bytes);

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .to_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }

    /// A collection answers with whichever member holds the tensor.
    #[tokio::test]
    async fn test_library_collection() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        let mut holder: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();
        let id = holder.bind(source.clone());

        let mut collection: TensorLibraryCollection<B> = TensorLibraryCollection::new();

        // An empty collection answers every query with `None`.
        assert!(
            collection
                .query(id.into())
                .await
                .expect("query failed")
                .is_none()
        );

        // The first member misses; the second holds the tensor.
        collection.push(Box::new(UuidMapTensorLibrary::<B>::new()));
        collection.push(Box::new(holder));
        assert_eq!(collection.libs().len(), 2);

        let dyn_tensor = collection
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .to_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);

        // No member answers path queries.
        assert!(
            collection
                .query(TensorLibraryQuery::Path("missing".to_string()))
                .await
                .expect("query failed")
                .is_none()
        );
    }

    /// Lazy builders run on the first query; later queries are served from the cache.
    #[tokio::test]
    async fn test_lazy_builder_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        /// Builds a fresh random tensor of a fixed shape on every call.
        #[derive(Debug)]
        struct RandomBuilder<B: Backend, const R: usize> {
            pub shape: [usize; R],
            pub device: B::Device,
        }

        impl<B: Backend, const R: usize> LazyBuilder<B> for RandomBuilder<B, R> {
            fn build<'a>(
                &'a self,
                _query: TensorLibraryQuery,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async move {
                    // `[usize; R]` is `Copy`, so no clone is needed.
                    Ok(Some(
                        Tensor::<B, R>::random(self.shape, Default::default(), &self.device)
                            .into(),
                    ))
                })
            }
        }

        let mut library: LazyBuilderLibrary<B> = LazyBuilderLibrary::new();
        let id = uuid::Uuid::new_v4();

        library.register_builder(
            id,
            RandomBuilder {
                shape: [2, 3],
                device: device.clone(),
            },
        );

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        // The built tensor is cached. A second query must return the same data,
        // not a newly generated random tensor.
        assert_eq!(library.cached().len(), 1);
        let again = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");
        again
            .to_data()
            .unwrap()
            .assert_eq(&dyn_tensor.clone().to_data().unwrap(), true);

        // UUIDs without a registered builder are a miss.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        assert_eq!(dyn_tensor.rank(), 2);
        assert_eq!(dyn_tensor.shape(), Shape::new([2, 3]));

        assert_eq!(dyn_tensor.kind(), KindFlag::Float);
        assert_eq!(dyn_tensor.device(), device);
    }
}