use std::any::Any;
pub use bb_ir::slot_value::*;
use crate::ids::ComponentRef;
/// Type-erased container for a backend tensor value.
///
/// The concrete tensor type is hidden behind `dyn Any`; the operations that
/// need the concrete type (cloning and wire encoding) are captured as plain
/// function pointers when the carrier is built by
/// `BackendTensorCarrier::from_typed`.
pub struct BackendTensorCarrier {
/// The wrapped tensor, boxed as `dyn Any + Send + Sync`.
pub(crate) inner: Box<dyn Any + Send + Sync>,
/// Produces a fresh boxed copy of the value behind `inner`.
/// As installed by `from_typed`, it downcasts to the concrete type and calls `Clone::clone`.
pub(crate) clone_fn: fn(&(dyn Any + Send + Sync)) -> Box<dyn Any + Send + Sync>,
/// Encodes the value behind `inner` into bytes.
/// As installed by `from_typed`, it downcasts to the concrete type and runs `bincode::serialize`.
pub(crate) wire_encode_fn: fn(&(dyn Any + Send + Sync)) -> Result<Vec<u8>, SlotValueError>,
/// Caller-supplied 64-bit hash reported by `type_hash()`.
/// NOTE(review): presumably identifies the concrete tensor type — confirm with callers.
pub(crate) type_hash: u64,
/// Caller-supplied byte count reported by `charged_bytes()`.
/// NOTE(review): looks like an accounting size for the payload — confirm units with callers.
pub(crate) charged_bytes: usize,
/// Caller-supplied component reference reported by `backend_ref()`.
/// NOTE(review): presumably the backend that produced the tensor — confirm.
pub(crate) backend_ref: ComponentRef,
}
impl BackendTensorCarrier {
pub fn from_typed<T>(
tensor: T,
type_hash: u64,
charged_bytes: usize,
backend_ref: ComponentRef,
) -> Self
where
T: Any + Send + Sync + Clone + serde::Serialize + 'static,
{
Self {
inner: Box::new(tensor),
clone_fn: |any| {
let t: &T = any.downcast_ref::<T>().expect("inner is T by construction");
Box::new(t.clone())
},
wire_encode_fn: |any| {
let t: &T = any.downcast_ref::<T>().expect("inner is T by construction");
bincode::serialize(t).map_err(|e| SlotValueError::EncodeFailed(Box::new(e)))
},
type_hash,
charged_bytes,
backend_ref,
}
}
pub fn type_hash(&self) -> u64 {
self.type_hash
}
pub fn backend_ref(&self) -> ComponentRef {
self.backend_ref
}
pub fn downcast_inner<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
self.inner.downcast_ref::<T>()
}
}
/// Debug output lists only the carrier's metadata; the erased payload and the
/// stored function pointers are left out.
impl std::fmt::Debug for BackendTensorCarrier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BackendTensorCarrier");
        // The hash is shown as zero-padded hex: `0x` followed by 16 digits.
        out.field("type_hash", &format_args!("{:#018x}", self.type_hash));
        out.field("charged_bytes", &self.charged_bytes);
        out.field("backend_ref", &self.backend_ref);
        out.finish()
    }
}
impl SlotValue for BackendTensorCarrier {
    /// Exposes the carrier itself (not the wrapped payload) as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Converts the boxed carrier itself into a boxed `dyn Any + Send + Sync`.
    fn into_any_boxed(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
        self
    }

    /// Builds an independent carrier: the payload is copied through the
    /// stored `clone_fn`, every other field is copied as-is.
    fn clone_boxed(&self) -> Box<dyn SlotValue> {
        let Self {
            inner,
            clone_fn,
            wire_encode_fn,
            type_hash,
            charged_bytes,
            backend_ref,
        } = self;

        let copy = Self {
            inner: (*clone_fn)(&**inner),
            clone_fn: *clone_fn,
            wire_encode_fn: *wire_encode_fn,
            type_hash: *type_hash,
            charged_bytes: *charged_bytes,
            backend_ref: *backend_ref,
        };
        Box::new(copy)
    }

    /// Encodes the payload through the stored `wire_encode_fn`.
    fn to_wire_bytes(&self) -> Result<Vec<u8>, SlotValueError> {
        let encode = self.wire_encode_fn;
        let payload: &(dyn Any + Send + Sync) = &*self.inner;
        encode(payload)
    }

    /// Returns the type hash stored in the carrier.
    fn type_hash(&self) -> u64 {
        let Self { type_hash, .. } = self;
        *type_hash
    }

    /// Returns the byte count stored in the carrier.
    fn charged_bytes(&self) -> usize {
        let Self { charged_bytes, .. } = self;
        *charged_bytes
    }
}
/// Error carrying a free-form description of a failed materialisation.
///
/// Its `Display` output is the fixed prefix
/// `Backend::materialize_from_wire: ` followed by `summary`.
#[derive(Debug, Clone)]
pub struct BackendMaterializeError {
    /// Human-readable description appended after the fixed prefix.
    pub summary: String,
}

impl std::fmt::Display for BackendMaterializeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit prefix and summary as two raw writes; no padding or other
        // formatter flags are applied to either part.
        f.write_str("Backend::materialize_from_wire: ")?;
        f.write_str(&self.summary)
    }
}

/// No underlying source error is exposed; the default `Error` methods apply.
impl std::error::Error for BackendMaterializeError {}