1use burn::prelude::Backend;
2use burn_rmexp_dyntensor::DynTensor;
3use futures::future::join_all;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt::Debug;
7use std::pin::Pin;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub enum TensorLibraryQuery {
12 Uuid(uuid::Uuid),
14
15 Route(Vec<String>),
17
18 Path(String),
20}
21
22impl From<uuid::Uuid> for TensorLibraryQuery {
23 fn from(uuid: uuid::Uuid) -> Self {
24 Self::Uuid(uuid)
25 }
26}
27
28#[derive(Debug, Clone, PartialEq, Eq, Hash)]
30pub enum TensorLibraryError {
31 InvalidQuery(TensorLibraryQuery),
32}
33
34pub trait TensorLibrary<B: Backend>: 'static + Debug {
35 fn query<'a>(
37 &'a mut self,
38 query: TensorLibraryQuery,
39 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
40}
41
42#[derive(Debug)]
43pub struct TensorLibraryCollection<B: Backend> {
44 libs: Vec<Box<dyn TensorLibrary<B>>>,
45}
46
47impl<B: Backend> Default for TensorLibraryCollection<B> {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl<B: Backend> TensorLibraryCollection<B> {
54 pub fn new() -> Self {
56 Self { libs: Vec::new() }
57 }
58
59 pub fn push(
61 &mut self,
62 lib: Box<dyn TensorLibrary<B>>,
63 ) {
64 self.libs.push(lib);
65 }
66
67 pub fn libs(&self) -> &[Box<dyn TensorLibrary<B>>] {
69 &self.libs
70 }
71
72 pub fn libs_mut(&mut self) -> &mut [Box<dyn TensorLibrary<B>>] {
74 &mut self.libs
75 }
76}
77
78impl<B: Backend> TensorLibrary<B> for TensorLibraryCollection<B> {
79 fn query<'a>(
80 &'a mut self,
81 query: TensorLibraryQuery,
82 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
83 {
84 let fs = self
86 .libs
87 .iter_mut()
88 .map(|lib| lib.query(query.clone()))
89 .collect::<Vec<_>>();
90
91 Box::pin(async move {
93 let res = join_all(fs).await;
94 res.into_iter().try_fold(None, |acc, result| match result {
95 Err(e) => Err(e),
96 Ok(Some(val)) if acc.is_none() => Ok(Some(val)),
97 Ok(_) => Ok(acc),
98 })
99 })
100 }
101}
102
103#[derive(Debug, Clone)]
105pub struct UuidMapTensorLibrary<B: Backend> {
106 hash_map: HashMap<uuid::Uuid, DynTensor<B>>,
107}
108
109impl<B: Backend> From<HashMap<uuid::Uuid, DynTensor<B>>> for UuidMapTensorLibrary<B> {
110 fn from(hash_map: HashMap<uuid::Uuid, DynTensor<B>>) -> Self {
111 Self { hash_map }
112 }
113}
114
115impl<B: Backend> Default for UuidMapTensorLibrary<B> {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
121impl<B: Backend> UuidMapTensorLibrary<B> {
122 pub fn new() -> Self {
124 Self {
125 hash_map: HashMap::new(),
126 }
127 }
128
129 pub fn hash_map(&self) -> &HashMap<uuid::Uuid, DynTensor<B>> {
131 &self.hash_map
132 }
133
134 pub fn hash_map_mut(&mut self) -> &mut HashMap<uuid::Uuid, DynTensor<B>> {
136 &mut self.hash_map
137 }
138
139 pub fn insert<T: Into<DynTensor<B>>>(
146 &mut self,
147 key: uuid::Uuid,
148 value: T,
149 ) -> Option<DynTensor<B>> {
150 self.hash_map.insert(key, value.into())
151 }
152
153 pub fn bind<T: Into<DynTensor<B>>>(
156 &mut self,
157 value: T,
158 ) -> uuid::Uuid {
159 let key = uuid::Uuid::new_v4();
160 self.insert(key, value);
161 key
162 }
163
164 pub fn remove(
167 &mut self,
168 key: &uuid::Uuid,
169 ) -> Option<DynTensor<B>> {
170 self.hash_map.remove(key)
171 }
172
173 pub fn clear(&mut self) {
175 self.hash_map.clear();
176 }
177
178 pub fn len(&self) -> usize {
180 self.hash_map.len()
181 }
182
183 pub fn size_estimate(&self) -> usize {
185 self.hash_map
186 .values()
187 .map(|tensor| tensor.size_estimate())
188 .sum()
189 }
190
191 pub fn get(
193 &self,
194 key: &uuid::Uuid,
195 ) -> Option<&DynTensor<B>> {
196 self.hash_map.get(key)
197 }
198}
199
200impl<B: Backend> TensorLibrary<B> for UuidMapTensorLibrary<B> {
201 fn query<'a>(
203 &'a mut self,
204 query: TensorLibraryQuery,
205 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
206 {
207 Box::pin(async move {
208 match query {
209 TensorLibraryQuery::Uuid(uuid) => Ok(self.get(&uuid).cloned()),
210 _ => Ok(None),
211 }
212 })
213 }
214}
215
216pub trait LazyBuilder<B: Backend>: Debug + Sync + Send + 'static {
217 fn build<'a>(
218 &'a self,
219 query: TensorLibraryQuery,
220 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
221}
222
223#[derive(Debug, Default)]
224pub struct LazyBuilderLibrary<B: Backend> {
225 builders: HashMap<uuid::Uuid, Box<dyn LazyBuilder<B>>>,
226 cached: UuidMapTensorLibrary<B>,
227}
228
229impl<B: Backend> LazyBuilderLibrary<B> {
230 pub fn new() -> Self {
231 Self::default()
232 }
233
234 pub fn cached(&self) -> &UuidMapTensorLibrary<B> {
235 &self.cached
236 }
237
238 pub fn cached_mut(&mut self) -> &mut UuidMapTensorLibrary<B> {
239 &mut self.cached
240 }
241
242 pub fn register_builder<T: LazyBuilder<B> + 'static>(
243 &mut self,
244 uuid: uuid::Uuid,
245 builder: T,
246 ) {
247 self.builders.insert(uuid, Box::new(builder));
248 }
249}
250
251impl<B: Backend> TensorLibrary<B> for LazyBuilderLibrary<B> {
252 fn query<'a>(
253 &'a mut self,
254 query: TensorLibraryQuery,
255 ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
256 {
257 Box::pin(async move {
258 match query {
259 TensorLibraryQuery::Uuid(uuid) => {
260 if let Some(tensor) = self.cached.get(&uuid).cloned() {
261 return Ok(Some(tensor));
262 }
263
264 let builder = self.builders.get(&uuid);
265 if builder.is_none() {
266 return Ok(None);
267 }
268
269 let qr = builder.unwrap().build(query.clone()).await?;
270 if qr.is_some() {
271 self.cached.insert(uuid, qr.as_ref().unwrap().clone());
272 }
273 Ok(qr)
274 }
275 _ => Ok(None),
276 }
277 })
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::backend::wgpu::WgpuDevice;
    use burn::prelude::Shape;
    use burn_rmexp_dyntensor::{DynTensor, KindFlag};

    #[tokio::test]
    async fn test_map_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        let mut library: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);
        let single_estimate = source.shape().num_elements() * source.dtype().size();

        // An unknown key is a miss, not an error.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let id = library.bind(source.clone());

        assert_eq!(library.len(), 1);
        assert_eq!(library.size_estimate(), single_estimate);

        // Binding the same tensor again creates a second, independent entry.
        let _dup = library.bind(source.clone());

        assert_eq!(library.len(), 2);
        assert_eq!(library.size_estimate(), 2 * single_estimate);

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .to_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }

    #[tokio::test]
    async fn test_lazy_builder_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        /// Test builder that draws a fresh random tensor on every `build` call.
        #[derive(Debug)]
        struct RandomBuilder<B: Backend, const R: usize> {
            pub shape: [usize; R],
            pub device: B::Device,
        }

        impl<B: Backend, const R: usize> LazyBuilder<B> for RandomBuilder<B, R> {
            fn build<'a>(
                &'a self,
                _query: TensorLibraryQuery,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async move {
                    Ok(Some(
                        // `[usize; R]` is `Copy`, so no clone is needed.
                        Tensor::<B, R>::random(self.shape, Default::default(), &self.device)
                            .into(),
                    ))
                })
            }
        }

        let mut library: LazyBuilderLibrary<B> = LazyBuilderLibrary::new();
        let id = uuid::Uuid::new_v4();

        library.register_builder(
            id,
            RandomBuilder {
                shape: [2, 3],
                device: device.clone(),
            },
        );

        // A UUID without a builder is a miss.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let first = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        assert_eq!(first.rank(), 2);
        assert_eq!(first.shape(), Shape::new([2, 3]));
        assert_eq!(first.kind(), KindFlag::Float);
        assert_eq!(first.device(), device);

        // The successful build is memoised.
        assert_eq!(library.cached().len(), 1);

        // The builder draws new random data on every call, so identical data on a
        // second query proves it was served from the cache.
        let second = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        second
            .to_data()
            .unwrap()
            .assert_eq(&first.to_data().unwrap(), true);
    }

    #[tokio::test]
    async fn test_collection_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        let mut populated: UuidMapTensorLibrary<B> = UuidMapTensorLibrary::new();
        let id = populated.bind(source.clone());

        let mut collection: TensorLibraryCollection<B> = TensorLibraryCollection::new();
        collection.push(Box::new(UuidMapTensorLibrary::<B>::new()));
        collection.push(Box::new(populated));
        assert_eq!(collection.libs().len(), 2);

        // No member knows this key.
        assert!(
            collection
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        // The empty first member misses; the hit from the second member is returned.
        let dyn_tensor = collection
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");

        dyn_tensor
            .to_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }
}