1use burn::prelude::Backend;
2use burn_rmexp_dyntensor::DynTensor;
3use futures::future::join_all;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::pin::Pin;
8
/// A request for a tensor, addressed to a [`TensorLibrary`].
///
/// The libraries defined in this module only resolve [`TensorLibraryQuery::Uuid`];
/// they answer `Ok(None)` for every other variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TensorLibraryQuery {
    /// Look up a tensor by its UUID key.
    Uuid(uuid::Uuid),

    /// Look up a tensor by a sequence of string segments.
    ///
    /// NOTE(review): segment semantics are not defined in this module — confirm
    /// with the libraries that resolve this variant.
    Route(Vec<String>),

    /// Look up a tensor by a string path.
    ///
    /// NOTE(review): path syntax is not defined in this module — confirm with the
    /// libraries that resolve this variant.
    Path(String),
}
21
22impl From<uuid::Uuid> for TensorLibraryQuery {
23 fn from(uuid: uuid::Uuid) -> Self {
24 Self::Uuid(uuid)
25 }
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub enum TensorLibraryError {
31 InvalidQuery(TensorLibraryQuery),
32}
33
/// An asynchronous source of tensors that can be queried with a [`TensorLibraryQuery`].
///
/// The implementations in this module return `Ok(Some(tensor))` when they can
/// answer a query, `Ok(None)` when they have nothing for it, and propagate `Err(_)`
/// from nested sources.
pub trait TensorLibrary<B: Backend>: 'static + Debug {
    /// Resolves `query` to a tensor.
    ///
    /// Takes `&mut self` so an implementation may update internal state while
    /// answering (e.g. caching built tensors); the returned future borrows `self`
    /// for `'a`.
    fn query<'a>(
        &'a mut self,
        query: TensorLibraryQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
}
41
/// A [`TensorLibrary`] that forwards each query to an ordered list of child libraries.
#[derive(Debug)]
pub struct TensorLibraryCollection<B: Backend> {
    /// Child libraries; when several answer the same query, the earliest one wins.
    libs: Vec<Box<dyn TensorLibrary<B>>>,
}
46
47impl<B: Backend> Default for TensorLibraryCollection<B> {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl<B: Backend> TensorLibraryCollection<B> {
54 pub fn new() -> Self {
56 Self { libs: Vec::new() }
57 }
58
59 pub fn push(
61 &mut self,
62 lib: Box<dyn TensorLibrary<B>>,
63 ) {
64 self.libs.push(lib);
65 }
66
67 pub fn libs(&self) -> &[Box<dyn TensorLibrary<B>>] {
69 &self.libs
70 }
71
72 pub fn libs_mut(&mut self) -> &mut [Box<dyn TensorLibrary<B>>] {
74 &mut self.libs
75 }
76}
77
78impl<B: Backend> TensorLibrary<B> for TensorLibraryCollection<B> {
79 fn query<'a>(
80 &'a mut self,
81 query: TensorLibraryQuery,
82 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
83 {
84 let fs = self
86 .libs
87 .iter_mut()
88 .map(|lib| lib.query(query.clone()))
89 .collect::<Vec<_>>();
90
91 Box::pin(async move {
93 let res = join_all(fs).await;
94 res.into_iter().try_fold(None, |acc, result| match result {
95 Err(e) => Err(e),
96 Ok(Some(val)) if acc.is_none() => Ok(Some(val)),
97 Ok(_) => Ok(acc),
98 })
99 })
100 }
101}
102
/// A [`TensorLibrary`] backed by an in-memory map from UUID to tensor.
///
/// Only [`TensorLibraryQuery::Uuid`] queries are answered; other variants yield `Ok(None)`.
#[derive(Debug, Clone)]
pub struct UuidMapTensorLibrary<B: Backend> {
    /// The stored tensors, keyed by UUID.
    hash_map: HashMap<uuid::Uuid, DynTensor<B>>,
}
108
109impl<B: Backend> From<HashMap<uuid::Uuid, DynTensor<B>>> for UuidMapTensorLibrary<B> {
110 fn from(hash_map: HashMap<uuid::Uuid, DynTensor<B>>) -> Self {
111 Self { hash_map }
112 }
113}
114
115impl<B: Backend> Default for UuidMapTensorLibrary<B> {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl<B: Backend> UuidMapTensorLibrary<B> {
122 pub fn new() -> Self {
124 Self {
125 hash_map: HashMap::new(),
126 }
127 }
128
129 pub fn hash_map(&self) -> &HashMap<uuid::Uuid, DynTensor<B>> {
131 &self.hash_map
132 }
133
134 pub fn hash_map_mut(&mut self) -> &mut HashMap<uuid::Uuid, DynTensor<B>> {
136 &mut self.hash_map
137 }
138
139 pub fn insert<T: Into<DynTensor<B>>>(
146 &mut self,
147 key: uuid::Uuid,
148 value: T,
149 ) -> Option<DynTensor<B>> {
150 self.hash_map.insert(key, value.into())
151 }
152
153 pub fn bind<T: Into<DynTensor<B>>>(
156 &mut self,
157 value: T,
158 ) -> uuid::Uuid {
159 let key = uuid::Uuid::new_v4();
160 self.insert(key, value);
161 key
162 }
163
164 pub fn remove(
167 &mut self,
168 key: &uuid::Uuid,
169 ) -> Option<DynTensor<B>> {
170 self.hash_map.remove(key)
171 }
172
173 pub fn clear(&mut self) {
175 self.hash_map.clear();
176 }
177
178 pub fn len(&self) -> usize {
180 self.hash_map.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.len() == 0
186 }
187
188 pub fn size_estimate(&self) -> usize {
190 self.hash_map
191 .values()
192 .map(|tensor| tensor.size_estimate())
193 .sum()
194 }
195
196 pub fn get(
198 &self,
199 key: &uuid::Uuid,
200 ) -> Option<&DynTensor<B>> {
201 self.hash_map.get(key)
202 }
203}
204
205impl<B: Backend> TensorLibrary<B> for UuidMapTensorLibrary<B> {
206 fn query<'a>(
208 &'a mut self,
209 query: TensorLibraryQuery,
210 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
211 {
212 Box::pin(async move {
213 match query {
214 TensorLibraryQuery::Uuid(uuid) => Ok(self.get(&uuid).cloned()),
215 _ => Ok(None),
216 }
217 })
218 }
219}
220
/// Produces a tensor on demand for a [`LazyBuilderLibrary`].
///
/// The `Send + Sync` bounds let a borrowed builder be held across `.await`
/// inside the library's `Send` query future.
pub trait LazyBuilder<B: Backend>: Debug + Sync + Send + 'static {
    /// Builds the tensor for `query`.
    ///
    /// [`LazyBuilderLibrary`] calls this only for `Uuid` queries, passing the full
    /// query. It caches an `Ok(Some(_))` result and passes `Ok(None)` / `Err(_)`
    /// through uncached.
    fn build<'a>(
        &'a self,
        query: TensorLibraryQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
}
227
/// A [`TensorLibrary`] that builds tensors on first request and caches them.
///
/// A `Uuid` query is answered from `cached` first. On a miss, the builder
/// registered for that UUID is run and a successful result is stored in `cached`.
#[derive(Debug, Default)]
pub struct LazyBuilderLibrary<B: Backend> {
    /// Builders keyed by the UUID of the tensor they produce.
    builders: HashMap<uuid::Uuid, Box<dyn LazyBuilder<B>>>,
    /// Tensors already produced by builders (or inserted directly through `cached_mut`).
    cached: UuidMapTensorLibrary<B>,
}
233
234impl<B: Backend> LazyBuilderLibrary<B> {
235 pub fn new() -> Self {
236 Self::default()
237 }
238
239 pub fn cached(&self) -> &UuidMapTensorLibrary<B> {
240 &self.cached
241 }
242
243 pub fn cached_mut(&mut self) -> &mut UuidMapTensorLibrary<B> {
244 &mut self.cached
245 }
246
247 pub fn register_builder<T: LazyBuilder<B> + 'static>(
248 &mut self,
249 uuid: uuid::Uuid,
250 builder: T,
251 ) {
252 self.builders.insert(uuid, Box::new(builder));
253 }
254}
255
256impl<B: Backend> TensorLibrary<B> for LazyBuilderLibrary<B> {
257 fn query<'a>(
258 &'a mut self,
259 query: TensorLibraryQuery,
260 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
261 {
262 Box::pin(async move {
263 match query {
264 TensorLibraryQuery::Uuid(uuid) => {
265 if let Some(tensor) = self.cached.get(&uuid).cloned() {
266 return Ok(Some(tensor));
267 }
268
269 match self.builders.get(&uuid) {
270 None => Ok(None),
271 Some(builder) => match builder.build(query.clone()).await? {
272 Some(tensor) => {
273 self.cached.insert(uuid, tensor.clone());
274 Ok(Some(tensor))
275 }
276 None => Ok(None),
277 },
278 }
279 }
280 _ => Ok(None),
281 }
282 })
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::backend::wgpu::WgpuDevice;
    use burn::prelude::Shape;
    use burn_rmexp_dyntensor::{DynTensor, KindFlag};

    /// Binding, sizing, and querying tensors in a `UuidMapTensorLibrary`.
    #[tokio::test]
    async fn test_map_library() {
        type B = Wgpu;
        let device = Default::default();

        let mut library = UuidMapTensorLibrary::new();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        // An unknown UUID resolves to `None`, not an error.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        // Expected `size_estimate` contribution of one stored copy of `source`.
        let bytes_per_tensor = source.shape().num_elements() * source.dtype().size();

        let id = library.bind(source.clone());

        assert_eq!(library.len(), 1);
        assert_eq!(library.size_estimate(), bytes_per_tensor);

        // Binding the same tensor again creates a second, separately counted entry.
        let _dup = library.bind(source.clone());

        assert_eq!(library.len(), 2);
        assert_eq!(library.size_estimate(), 2 * bytes_per_tensor);

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .into_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }

    /// A registered builder is run on first query and produces the expected tensor.
    #[tokio::test]
    async fn test_lazy_builder_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        #[derive(Debug)]
        struct RandomBuilder<B: Backend, const R: usize> {
            pub shape: [usize; R],
            pub device: B::Device,
        }

        impl<B: Backend, const R: usize> LazyBuilder<B> for RandomBuilder<B, R> {
            fn build<'a>(
                &'a self,
                _query: TensorLibraryQuery,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async move {
                    Ok(Some(
                        // `[usize; R]` is `Copy`; no clone needed.
                        Tensor::<B, R>::random(self.shape, Default::default(), &self.device)
                            .into(),
                    ))
                })
            }
        }

        let mut library: LazyBuilderLibrary<B> = LazyBuilderLibrary::new();
        let id = uuid::Uuid::new_v4();

        library.register_builder(
            id,
            RandomBuilder {
                shape: [2, 3],
                device: device.clone(),
            },
        );

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        assert_eq!(dyn_tensor.rank(), 2);
        assert_eq!(dyn_tensor.shape(), Shape::new([2, 3]));

        assert_eq!(dyn_tensor.kind(), KindFlag::Float);
        assert_eq!(dyn_tensor.device(), device);
    }
}