use burn::prelude::Backend;
use burn_rmexp_dyntensor::DynTensor;
use futures::future::join_all;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::pin::Pin;
/// A lookup key accepted by [`TensorLibrary::query`].
///
/// `UuidMapTensorLibrary` and `LazyBuilderLibrary` in this file only resolve the
/// [`TensorLibraryQuery::Uuid`] variant and answer `Ok(None)` for the others.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TensorLibraryQuery {
/// Identifies a tensor by its UUID.
Uuid(uuid::Uuid),
/// Identifies a tensor by a list of string segments.
/// NOTE(review): presumably a hierarchical route; nothing in this file interprets it — confirm.
Route(Vec<String>),
/// Identifies a tensor by a single string.
/// NOTE(review): presumably a path-like locator; nothing in this file interprets it — confirm.
Path(String),
}
impl From<uuid::Uuid> for TensorLibraryQuery {
fn from(uuid: uuid::Uuid) -> Self {
Self::Uuid(uuid)
}
}
/// Errors that a [`TensorLibrary`] query (or a lazy builder) can report.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated through `Box<dyn Error>`-style error handling.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorLibraryError {
    /// The given query was rejected as invalid; the offending query is carried along.
    InvalidQuery(TensorLibraryQuery),
}

impl std::fmt::Display for TensorLibraryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The query type only offers `Debug`, so that is used for the payload.
            Self::InvalidQuery(query) => write!(f, "invalid tensor library query: {query:?}"),
        }
    }
}

impl std::error::Error for TensorLibraryError {}
/// An asynchronous source of tensors addressed by [`TensorLibraryQuery`].
pub trait TensorLibrary<B: Backend>: 'static + Debug {
/// Looks up the tensor matching `query`.
///
/// The returned boxed future borrows `self` mutably for `'a` and is `Send`.
/// It resolves to `Ok(Some(tensor))` on a hit, or to a [`TensorLibraryError`]
/// on failure. The implementations in this file resolve to `Ok(None)` when
/// nothing matches or when they do not handle the query variant.
fn query<'a>(
&'a mut self,
query: TensorLibraryQuery,
) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
}
/// A group of boxed tensor libraries that is itself queryable as a [`TensorLibrary`].
#[derive(Debug)]
pub struct TensorLibraryCollection<B: Backend> {
/// Member libraries, in the order they were pushed.
libs: Vec<Box<dyn TensorLibrary<B>>>,
}
/// The default collection contains no member libraries.
impl<B: Backend> Default for TensorLibraryCollection<B> {
    fn default() -> Self {
        Self { libs: Vec::new() }
    }
}
impl<B: Backend> TensorLibraryCollection<B> {
    /// Creates a collection with no member libraries.
    pub fn new() -> Self {
        let libs = Vec::new();
        Self { libs }
    }

    /// Appends `lib` after all previously added libraries.
    pub fn push(
        &mut self,
        lib: Box<dyn TensorLibrary<B>>,
    ) {
        self.libs.push(lib)
    }

    /// Read-only view of the member libraries, in insertion order.
    pub fn libs(&self) -> &[Box<dyn TensorLibrary<B>>] {
        self.libs.as_slice()
    }

    /// Mutable view of the member libraries, in insertion order.
    pub fn libs_mut(&mut self) -> &mut [Box<dyn TensorLibrary<B>>] {
        self.libs.as_mut_slice()
    }
}
impl<B: Backend> TensorLibrary<B> for TensorLibraryCollection<B> {
    /// Forwards `query` to every member library and combines the outcomes.
    ///
    /// All member futures are awaited together through `join_all`; their
    /// results are then inspected in member order. The first error encountered
    /// in that order is returned; otherwise the first `Some` tensor wins, and
    /// `Ok(None)` is returned when no member produced a tensor.
    fn query<'a>(
        &'a mut self,
        query: TensorLibraryQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
    {
        // The member futures are created here, outside the async block, so the
        // block captures only those (`Send`) futures and not `self`, whose boxed
        // members carry no `Send` bound.
        let mut pending = Vec::with_capacity(self.libs.len());
        for lib in self.libs.iter_mut() {
            pending.push(lib.query(query.clone()));
        }

        Box::pin(async move {
            let mut found = None;
            for outcome in join_all(pending).await {
                match outcome {
                    Err(err) => return Err(err),
                    // Keep only the first hit; later hits are discarded.
                    Ok(Some(tensor)) if found.is_none() => found = Some(tensor),
                    Ok(_) => {}
                }
            }
            Ok(found)
        })
    }
}
/// An in-memory [`TensorLibrary`] that stores tensors in a map keyed by UUID.
#[derive(Debug, Clone)]
pub struct UuidMapTensorLibrary<B: Backend> {
/// Backing storage mapping each UUID to its tensor.
hash_map: HashMap<uuid::Uuid, DynTensor<B>>,
}
impl<B: Backend> From<HashMap<uuid::Uuid, DynTensor<B>>> for UuidMapTensorLibrary<B> {
fn from(hash_map: HashMap<uuid::Uuid, DynTensor<B>>) -> Self {
Self { hash_map }
}
}
impl<B: Backend> Default for UuidMapTensorLibrary<B> {
fn default() -> Self {
Self::new()
}
}
impl<B: Backend> UuidMapTensorLibrary<B> {
    /// Creates a library that holds no tensors.
    pub fn new() -> Self {
        let hash_map = HashMap::new();
        Self { hash_map }
    }

    /// Read-only access to the underlying UUID-to-tensor map.
    pub fn hash_map(&self) -> &HashMap<uuid::Uuid, DynTensor<B>> {
        &self.hash_map
    }

    /// Mutable access to the underlying UUID-to-tensor map.
    pub fn hash_map_mut(&mut self) -> &mut HashMap<uuid::Uuid, DynTensor<B>> {
        &mut self.hash_map
    }

    /// Stores `value` under `key`, converting it into a `DynTensor` first.
    ///
    /// Returns the tensor previously stored under `key`, if there was one.
    pub fn insert<T: Into<DynTensor<B>>>(
        &mut self,
        key: uuid::Uuid,
        value: T,
    ) -> Option<DynTensor<B>> {
        let tensor: DynTensor<B> = value.into();
        self.hash_map.insert(key, tensor)
    }

    /// Stores `value` under a newly generated v4 UUID and returns that UUID.
    pub fn bind<T: Into<DynTensor<B>>>(
        &mut self,
        value: T,
    ) -> uuid::Uuid {
        let fresh = uuid::Uuid::new_v4();
        // Any tensor displaced by the insert is intentionally dropped.
        let _displaced = self.insert(fresh, value);
        fresh
    }

    /// Removes and returns the tensor stored under `key`, if any.
    pub fn remove(
        &mut self,
        key: &uuid::Uuid,
    ) -> Option<DynTensor<B>> {
        self.hash_map.remove(key)
    }

    /// Drops every stored tensor.
    pub fn clear(&mut self) {
        self.hash_map.clear()
    }

    /// Number of tensors currently stored.
    pub fn len(&self) -> usize {
        self.hash_map.len()
    }

    /// Whether the library currently stores no tensors.
    pub fn is_empty(&self) -> bool {
        self.hash_map.is_empty()
    }

    /// Sum of `size_estimate()` over every stored tensor.
    pub fn size_estimate(&self) -> usize {
        let mut total = 0;
        for tensor in self.hash_map.values() {
            total += tensor.size_estimate();
        }
        total
    }

    /// Borrows the tensor stored under `key`, if any.
    pub fn get(
        &self,
        key: &uuid::Uuid,
    ) -> Option<&DynTensor<B>> {
        self.hash_map.get(key)
    }
}
impl<B: Backend> TensorLibrary<B> for UuidMapTensorLibrary<B> {
    /// Resolves UUID queries against the stored map.
    ///
    /// A `Uuid` query yields a clone of the stored tensor, or `None` when the
    /// UUID is unknown. `Route` and `Path` queries always yield `None`. This
    /// implementation never returns an error.
    fn query<'a>(
        &'a mut self,
        query: TensorLibraryQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
    {
        Box::pin(async move {
            let found = match query {
                TensorLibraryQuery::Uuid(id) => self.hash_map.get(&id).cloned(),
                TensorLibraryQuery::Route(_) | TensorLibraryQuery::Path(_) => None,
            };
            Ok(found)
        })
    }
}
/// A deferred tensor producer.
///
/// `LazyBuilderLibrary` invokes the builder registered under a UUID when that
/// UUID is queried and no cached tensor exists for it.
pub trait LazyBuilder<B: Backend>: Debug + Sync + Send + 'static {
/// Produces the tensor for `query`.
///
/// The returned boxed future borrows `self` for `'a` and is `Send`.
/// `LazyBuilderLibrary` caches an `Ok(Some(tensor))` result under the queried
/// UUID, passes `Ok(None)` through without caching, and propagates errors.
fn build<'a>(
&'a self,
query: TensorLibraryQuery,
) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>;
}
/// A [`TensorLibrary`] that builds tensors on first request and caches the result.
#[derive(Debug, Default)]
pub struct LazyBuilderLibrary<B: Backend> {
/// Builders keyed by the UUID they serve.
builders: HashMap<uuid::Uuid, Box<dyn LazyBuilder<B>>>,
/// Tensors already produced by a builder (or inserted via `cached_mut`), keyed by UUID.
cached: UuidMapTensorLibrary<B>,
}
impl<B: Backend> LazyBuilderLibrary<B> {
pub fn new() -> Self {
Self::default()
}
pub fn cached(&self) -> &UuidMapTensorLibrary<B> {
&self.cached
}
pub fn cached_mut(&mut self) -> &mut UuidMapTensorLibrary<B> {
&mut self.cached
}
pub fn register_builder<T: LazyBuilder<B> + 'static>(
&mut self,
uuid: uuid::Uuid,
builder: T,
) {
self.builders.insert(uuid, Box::new(builder));
}
}
impl<B: Backend> TensorLibrary<B> for LazyBuilderLibrary<B> {
    /// Resolves UUID queries from the cache, falling back to the registered builder.
    ///
    /// * A cached tensor for the UUID is cloned and returned without calling any builder.
    /// * Otherwise the builder registered for the UUID (if any) is awaited; a
    ///   produced tensor is stored in the cache and returned, while a builder
    ///   result of `None` is passed through without caching.
    /// * `Ok(None)` is returned when no builder is registered for the UUID, or
    ///   when the query is a `Route` or `Path`.
    ///
    /// # Errors
    ///
    /// Propagates any [`TensorLibraryError`] returned by the builder.
    fn query<'a>(
        &'a mut self,
        query: TensorLibraryQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>> + Send + 'a>>
    {
        Box::pin(async move {
            // `Uuid` is `Copy`, so extracting it leaves `query` intact for the builder.
            let uuid = match query {
                TensorLibraryQuery::Uuid(uuid) => uuid,
                TensorLibraryQuery::Route(_) | TensorLibraryQuery::Path(_) => return Ok(None),
            };

            if let Some(hit) = self.cached.get(&uuid) {
                return Ok(Some(hit.clone()));
            }

            let Some(builder) = self.builders.get(&uuid) else {
                return Ok(None);
            };

            // The query is not needed afterwards, so it is moved rather than cloned.
            let built = builder.build(query).await?;
            if let Some(tensor) = &built {
                self.cached.insert(uuid, tensor.clone());
            }
            Ok(built)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::Tensor;
    use burn::backend::Wgpu;
    use burn::backend::wgpu::WgpuDevice;
    use burn::prelude::Shape;
    use burn_rmexp_dyntensor::{DynTensor, KindFlag};

    /// Exercises `UuidMapTensorLibrary`: misses, `bind`, size accounting and lookup by UUID.
    #[tokio::test]
    async fn test_map_library() {
        type B = Wgpu;
        let device = Default::default();
        let mut library = UuidMapTensorLibrary::new();
        let source: Tensor<B, 2> = Tensor::random([2, 3], Default::default(), &device);

        // An unknown UUID is a miss, not an error.
        assert!(
            library
                .query(uuid::Uuid::new_v4().into())
                .await
                .expect("query failed")
                .is_none()
        );

        let id = library.bind(source.clone());
        assert_eq!(library.len(), 1);
        assert_eq!(
            library.size_estimate(),
            source.shape().num_elements() * source.dtype().size()
        );

        // Binding the same tensor again creates a second, independent entry.
        let _dup = library.bind(source.clone());
        assert_eq!(library.len(), 2);
        assert_eq!(
            library.size_estimate(),
            2 * source.shape().num_elements() * source.dtype().size()
        );

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");
        dyn_tensor
            .into_data()
            .unwrap()
            .assert_eq(&source.to_data(), true);
    }

    /// Exercises `LazyBuilderLibrary`: the builder runs on the first query and
    /// the cached tensor is served on the next one.
    #[tokio::test]
    async fn test_lazy_builder_library() {
        type B = Wgpu;
        let device: WgpuDevice = Default::default();

        /// Builder that produces a fresh random tensor of a fixed shape on every call.
        #[derive(Debug)]
        struct RandomBuilder<B: Backend, const R: usize> {
            pub shape: [usize; R],
            pub device: B::Device,
        }

        impl<B: Backend, const R: usize> LazyBuilder<B> for RandomBuilder<B, R> {
            fn build<'a>(
                &'a self,
                _query: TensorLibraryQuery,
            ) -> Pin<
                Box<
                    dyn Future<Output = Result<Option<DynTensor<B>>, TensorLibraryError>>
                        + Send
                        + 'a,
                >,
            > {
                Box::pin(async move {
                    Ok(Some(
                        // `[usize; R]` is `Copy`, so no explicit clone is needed.
                        Tensor::<B, R>::random(self.shape, Default::default(), &self.device)
                            .into(),
                    ))
                })
            }
        }

        let mut library: LazyBuilderLibrary<B> = LazyBuilderLibrary::new();
        let id = uuid::Uuid::new_v4();
        library.register_builder(
            id,
            RandomBuilder {
                shape: [2, 3],
                device: device.clone(),
            },
        );

        let dyn_tensor = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");
        assert_eq!(dyn_tensor.rank(), 2);
        assert_eq!(dyn_tensor.shape(), Shape::new([2, 3]));
        assert_eq!(dyn_tensor.kind(), KindFlag::Float);
        assert_eq!(dyn_tensor.device(), device);

        // The built tensor must now be cached.
        assert_eq!(library.cached().len(), 1);

        // The builder is random, so identical data on a second query shows that
        // the cached tensor was served instead of a rebuilt one.
        let again = library
            .query(id.into())
            .await
            .expect("query failed")
            .expect("tensor not found");
        again
            .into_data()
            .unwrap()
            .assert_eq(&dyn_tensor.into_data().unwrap(), true);
        assert_eq!(library.cached().len(), 1);
    }
}