use std::marker::PhantomData;
use std::ops::{Div, Rem};
use std::{fmt, iter};
use async_hash::{Digest, Hash, Output};
use async_trait::async_trait;
use collate::Collator;
use destream::{de, en};
use futures::TryFutureExt;
use log::debug;
use safecast::{AsType, CastFrom, CastInto, TryCastFrom, TryCastInto};
use smallvec::SmallVec;
use tc_error::*;
use tc_transact::lock::{PermitRead, PermitWrite};
use tc_transact::{fs, IntoView, Transact, Transaction, TxnId};
use tc_value::{Number, NumberType, Value, ValueType};
use tcgeneric::{
label, path_label, Class, ClassVisitor, Instance, Label, NativeClass, PathLabel, PathSegment,
TCPathBuf, ThreadSafe,
};
pub use dense::{Buffer, DenseBase, DenseCacheFile, DenseView};
pub use shape::{AxisRange, Range, Shape};
pub use sparse::{Node, SparseBase, SparseView};
mod block;
mod complex;
pub mod dense;
pub mod public;
pub mod shape;
pub mod sparse;
mod transform;
pub(super) mod view;
/// Label "re" — presumably names the real component of complex-valued tensors; confirm.
const REAL: Label = label("re");
/// Label "im" — presumably names the imaginary component of complex-valued tensors; confirm.
const IMAG: Label = label("im");
/// Path prefix shared by every tensor classpath: `/state/collection/tensor`.
const PREFIX: PathLabel = path_label(&["state", "collection", "tensor"]);
/// Target block size, 2^16 (= 65,536).
/// NOTE(review): the unit (elements vs. bytes) is not visible here; confirm.
const IDEAL_BLOCK_SIZE: usize = 1 << 16;

/// A list of axis indices.
pub type Axes = SmallVec<[usize; 8]>;
/// A coordinate within a tensor (presumably one index per axis).
pub type Coord = SmallVec<[u64; 8]>;
/// Per-axis strides.
pub type Strides = SmallVec<[u64; 8]>;
/// A `tc_transact` lock semaphore keyed by `Range`, collated over `u64`.
type Semaphore = tc_transact::lock::Semaphore<Collator<u64>, Range>;
/// The element type and shape of a tensor.
#[derive(Clone, PartialEq, Eq)]
pub struct Schema {
    /// The numeric type of the tensor's elements.
    pub dtype: NumberType,
    /// The tensor's dimensions (its length is the number of axes).
    pub shape: Shape,
}
impl From<(NumberType, Shape)> for Schema {
fn from(schema: (NumberType, Shape)) -> Self {
let (dtype, shape) = schema;
Self { dtype, shape }
}
}
/// Casts a `Value::Tuple` of `(classpath, shape)` into a [`Schema`].
///
/// The classpath must resolve (via `ValueType::from_path`) to a `ValueType::Number`;
/// any other value, including a well-formed tuple naming a non-numeric type, is rejected.
impl TryCastFrom<Value> for Schema {
    fn can_cast_from(value: &Value) -> bool {
        // Delegate to `opt_cast_from` so the two methods can never disagree.
        // Checking only the tuple's structure would accept e.g. a string classpath that
        // `opt_cast_from` then rejects. That breaks the `TryCastFrom` contract, and callers
        // that unwrap `opt_cast_from` after a successful check would panic.
        // The clone is cheap: a schema tuple holds only a path and a shape.
        match value {
            Value::Tuple(_) => Self::opt_cast_from(value.clone()).is_some(),
            _ => false,
        }
    }

    fn opt_cast_from(value: Value) -> Option<Self> {
        match value {
            Value::Tuple(tuple) => {
                let (classpath, shape): (TCPathBuf, Shape) = tuple.opt_cast_into()?;

                match ValueType::from_path(&classpath)? {
                    ValueType::Number(dtype) => Some(Schema { dtype, shape }),
                    _ => None,
                }
            }
            _ => None,
        }
    }
}
impl CastFrom<Schema> for Value {
fn cast_from(schema: Schema) -> Self {
Value::Tuple(
vec![
ValueType::Number(schema.dtype).path().cast_into(),
schema.shape.cast_into(),
]
.into(),
)
}
}
impl<'a, D: Digest> Hash<D> for &'a Schema {
    /// Hash the shape together with the classpath of the dtype.
    ///
    /// NOTE: the hashed order is `(shape, classpath)`, the reverse of the stream encoding.
    /// Changing it would change every existing schema hash.
    fn hash(self) -> Output<D> {
        let classpath = ValueType::from(self.dtype).path();
        Hash::<D>::hash((&self.shape, classpath))
    }
}
impl fmt::Debug for Schema {
    /// Render as `tensor of type <dtype> with shape <shape>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { dtype, shape } = self;
        write!(f, "tensor of type {dtype:?} with shape {shape:?}")
    }
}
#[async_trait]
impl de::FromStream for Schema {
type Context = ();
async fn from_stream<D: de::Decoder>(cxt: (), decoder: &mut D) -> Result<Self, D::Error> {
let (classpath, shape): (TCPathBuf, Shape) =
de::FromStream::from_stream(cxt, decoder).await?;
if let Some(ValueType::Number(dtype)) = ValueType::from_path(&classpath) {
Ok(Self { shape, dtype })
} else {
Err(de::Error::invalid_value("a Number type", classpath))
}
}
}
impl<'en> en::IntoStream<'en> for Schema {
    /// Encode as a `(dtype classpath, shape)` tuple, consuming the schema.
    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
        let classpath = ValueType::from(self.dtype).path();
        en::IntoStream::into_stream((classpath, self.shape), encoder)
    }
}

impl<'en> en::ToStream<'en> for Schema {
    /// Encode as a `(dtype classpath, shape)` tuple, borrowing the shape.
    fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
        let classpath = ValueType::from(self.dtype).path();
        en::IntoStream::into_stream((classpath, &self.shape), encoder)
    }
}
/// A tensor that can grant transactional read permits over a [`Range`] of its coordinates.
#[async_trait]
pub trait TensorPermitRead: Send + Sync {
    /// Acquire read permits covering `range` for the transaction `txn_id`.
    ///
    /// Returns a list of permits (presumably one per underlying source; TODO confirm).
    async fn read_permit(
        &self,
        txn_id: TxnId,
        range: Range,
    ) -> TCResult<SmallVec<[PermitRead<Range>; 16]>>;
}

/// A tensor that can grant a transactional write permit over a [`Range`] of its coordinates.
#[async_trait]
pub trait TensorPermitWrite: Send + Sync {
    /// Acquire a single write permit covering `range` for the transaction `txn_id`.
    async fn write_permit(&self, txn_id: TxnId, range: Range) -> TCResult<PermitWrite<Range>>;
}
/// The class of a tensor.
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum TensorType {
    /// Classpath `/state/collection/tensor/dense`.
    Dense,
    /// Classpath `/state/collection/tensor/sparse`.
    Sparse,
}
impl Class for TensorType {}

impl NativeClass for TensorType {
    /// Parse a classpath of the form `/state/collection/tensor/{dense|sparse}`.
    fn from_path(path: &[PathSegment]) -> Option<Self> {
        // Only a four-segment path under the tensor prefix can name a tensor class.
        if path.len() != 4 || &path[..3] != &PREFIX[..] {
            return None;
        }

        match path[3].as_str() {
            "dense" => Some(Self::Dense),
            "sparse" => Some(Self::Sparse),
            _ => None,
        }
    }

    /// The classpath of this tensor type, beneath `/state/collection/tensor`.
    fn path(&self) -> TCPathBuf {
        let name = match self {
            Self::Dense => "dense",
            Self::Sparse => "sparse",
        };

        TCPathBuf::from(PREFIX).append(label(name))
    }
}
#[async_trait]
impl de::FromStream for TensorType {
type Context = ();
async fn from_stream<D: de::Decoder>(_: (), decoder: &mut D) -> Result<Self, D::Error> {
decoder.decode_any(ClassVisitor::default()).await
}
}
impl fmt::Debug for TensorType {
    /// Both variants render identically as `type Tensor`.
    ///
    /// NOTE(review): the output does not distinguish dense from sparse; consider whether it should.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "type Tensor")
    }
}
pub trait TensorInstance: ThreadSafe + Sized {
fn dtype(&self) -> NumberType;
fn ndim(&self) -> usize {
self.shape().len()
}
fn shape(&self) -> &Shape;
fn size(&self) -> u64 {
self.shape().iter().product()
}
fn schema(&self) -> Schema {
Schema::from((self.dtype(), self.shape().clone()))
}
}
/// Forwards the required [`TensorInstance`] accessors to the boxed tensor.
impl<T: TensorInstance> TensorInstance for Box<T> {
    fn dtype(&self) -> NumberType {
        T::dtype(self)
    }

    fn shape(&self) -> &Shape {
        T::shape(self)
    }
}
/// Logical operations between a tensor and another operand of type `O`.
pub trait TensorBoolean<O> {
    /// The output type of [`Self::or`] and [`Self::xor`].
    type Combine: TensorInstance;
    /// The output type of [`Self::and`].
    type LeftCombine: TensorInstance;

    /// Logical AND of `self` and `other`.
    fn and(self, other: O) -> TCResult<Self::LeftCombine>;

    /// Logical OR of `self` and `other`.
    fn or(self, other: O) -> TCResult<Self::Combine>;

    /// Logical XOR of `self` and `other`.
    fn xor(self, other: O) -> TCResult<Self::Combine>;
}

/// Logical operations between a tensor and a constant [`Number`].
pub trait TensorBooleanConst {
    /// The output type of every operation in this trait.
    type Combine: TensorInstance;

    /// Logical AND with the constant `other`.
    fn and_const(self, other: Number) -> TCResult<Self::Combine>;

    /// Logical OR with the constant `other`.
    fn or_const(self, other: Number) -> TCResult<Self::Combine>;

    /// Logical XOR with the constant `other`.
    fn xor_const(self, other: Number) -> TCResult<Self::Combine>;
}

/// Conversion of a tensor to a different [`NumberType`].
pub trait TensorCast {
    /// The output type of [`Self::cast_into`].
    type Cast;

    /// Cast this tensor into the given `dtype`.
    fn cast_into(self, dtype: NumberType) -> TCResult<Self::Cast>;
}

/// Comparisons between a tensor and another operand of type `O`.
pub trait TensorCompare<O> {
    /// The output type of every comparison in this trait.
    type Compare: TensorInstance;

    /// Equality comparison with `other`.
    fn eq(self, other: O) -> TCResult<Self::Compare>;

    /// Greater-than comparison with `other`.
    fn gt(self, other: O) -> TCResult<Self::Compare>;

    /// Greater-than-or-equal comparison with `other`.
    fn ge(self, other: O) -> TCResult<Self::Compare>;

    /// Less-than comparison with `other`.
    fn lt(self, other: O) -> TCResult<Self::Compare>;

    /// Less-than-or-equal comparison with `other`.
    fn le(self, other: O) -> TCResult<Self::Compare>;

    /// Inequality comparison with `other`.
    fn ne(self, other: O) -> TCResult<Self::Compare>;
}

/// Comparisons between a tensor and a constant [`Number`].
pub trait TensorCompareConst {
    /// The output type of every comparison in this trait.
    type Compare: TensorInstance;

    /// Equality comparison with the constant `other`.
    fn eq_const(self, other: Number) -> TCResult<Self::Compare>;

    /// Greater-than comparison with the constant `other`.
    fn gt_const(self, other: Number) -> TCResult<Self::Compare>;

    /// Greater-than-or-equal comparison with the constant `other`.
    fn ge_const(self, other: Number) -> TCResult<Self::Compare>;

    /// Less-than comparison with the constant `other`.
    fn lt_const(self, other: Number) -> TCResult<Self::Compare>;

    /// Less-than-or-equal comparison with the constant `other`.
    fn le_const(self, other: Number) -> TCResult<Self::Compare>;

    /// Inequality comparison with the constant `other`.
    fn ne_const(self, other: Number) -> TCResult<Self::Compare>;
}

/// Conditional selection driven by `self`.
pub trait TensorCond<Then, OrElse> {
    /// The output type of [`Self::cond`].
    type Cond: TensorInstance;

    /// Select between `then` and `or_else` according to `self`.
    ///
    /// Presumably: `then` where `self` is truthy, else `or_else`. Confirm against the
    /// implementations.
    fn cond(self, then: Then, or_else: OrElse) -> TCResult<Self::Cond>;
}
pub trait TensorConvert: ThreadSafe {
type Dense: TensorInstance;
type Sparse: TensorInstance;
fn into_dense(self) -> Self::Dense;
fn into_sparse(self) -> Self::Sparse;
}
pub trait TensorDiagonal {
type Diagonal: TensorInstance;
fn diagonal(self) -> TCResult<Self::Diagonal>;
}
pub trait TensorMath<O> {
type Combine: TensorInstance;
type LeftCombine: TensorInstance;
fn add(self, other: O) -> TCResult<Self::Combine>;
fn div(self, other: O) -> TCResult<Self::LeftCombine>;
fn log(self, base: O) -> TCResult<Self::LeftCombine>;
fn mul(self, other: O) -> TCResult<Self::LeftCombine>;
fn pow(self, other: O) -> TCResult<Self::LeftCombine>;
fn sub(self, other: O) -> TCResult<Self::Combine>;
}
pub trait TensorMathConst {
type Combine: TensorInstance;
fn add_const(self, other: Number) -> TCResult<Self::Combine>;
fn div_const(self, other: Number) -> TCResult<Self::Combine>;
fn log_const(self, base: Number) -> TCResult<Self::Combine>;
fn mul_const(self, other: Number) -> TCResult<Self::Combine>;
fn pow_const(self, other: Number) -> TCResult<Self::Combine>;
fn sub_const(self, other: Number) -> TCResult<Self::Combine>;
}
pub trait TensorMatMul<O> {
type MatMul: TensorInstance;
fn matmul(self, other: O) -> TCResult<Self::MatMul>;
}
/// Read access to a single element of a tensor.
#[async_trait]
pub trait TensorRead {
    /// Read the element at `coord` as of the transaction `txn_id`.
    async fn read_value(self, txn_id: TxnId, coord: Coord) -> TCResult<Number>;
}

/// Reductions over a tensor, either along some axes or across every element.
#[async_trait]
pub trait TensorReduce {
    /// The output type of the axis-wise reductions.
    type Reduce: TensorInstance;

    /// Whether every element is truthy, as of `txn_id`.
    async fn all(self, txn_id: TxnId) -> TCResult<bool>;

    /// Whether any element is truthy, as of `txn_id`.
    async fn any(self, txn_id: TxnId) -> TCResult<bool>;

    /// Maximum along `axes`.
    ///
    /// `keepdims` presumably retains the reduced axes with size 1; confirm.
    fn max(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce>;

    /// Maximum over all elements, as of `txn_id`.
    async fn max_all(self, txn_id: TxnId) -> TCResult<Number>;

    /// Minimum along `axes` (see [`Self::max`] for `keepdims`).
    fn min(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce>;

    /// Minimum over all elements, as of `txn_id`.
    async fn min_all(self, txn_id: TxnId) -> TCResult<Number>;

    /// Product along `axes` (see [`Self::max`] for `keepdims`).
    fn product(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce>;

    /// Product of all elements, as of `txn_id`.
    async fn product_all(self, txn_id: TxnId) -> TCResult<Number>;

    /// Sum along `axes` (see [`Self::max`] for `keepdims`).
    fn sum(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce>;

    /// Sum of all elements, as of `txn_id`.
    async fn sum_all(self, txn_id: TxnId) -> TCResult<Number>;
}

/// Shape-changing transformations of a tensor.
pub trait TensorTransform {
    /// The output type of [`Self::broadcast`].
    type Broadcast: TensorInstance;
    /// The output type of [`Self::expand`].
    type Expand: TensorInstance;
    /// The output type of [`Self::reshape`].
    type Reshape: TensorInstance;
    /// The output type of [`Self::slice`].
    type Slice: TensorInstance;
    /// The output type of [`Self::transpose`].
    type Transpose: TensorInstance;

    /// Broadcast this tensor to `shape`.
    fn broadcast(self, shape: Shape) -> TCResult<Self::Broadcast>;

    /// Insert new axes at the positions given by `axes` (presumably; confirm).
    fn expand(self, axes: Axes) -> TCResult<Self::Expand>;

    /// Reinterpret this tensor with the given `shape`.
    fn reshape(self, shape: Shape) -> TCResult<Self::Reshape>;

    /// Select the sub-tensor within `range`.
    fn slice(self, range: Range) -> TCResult<Self::Slice>;

    /// Permute this tensor's axes.
    ///
    /// Behavior when `permutation` is `None` is not visible here; presumably it reverses
    /// the axes. Confirm.
    fn transpose(self, permutation: Option<Axes>) -> TCResult<Self::Transpose>;
}

/// Trigonometric and hyperbolic functions of a tensor.
#[async_trait]
pub trait TensorTrig {
    /// The output type of every function in this trait.
    type Unary: TensorInstance;

    /// Arcsine.
    fn asin(self) -> TCResult<Self::Unary>;

    /// Sine.
    fn sin(self) -> TCResult<Self::Unary>;

    /// Hyperbolic sine.
    fn sinh(self) -> TCResult<Self::Unary>;

    /// Arccosine.
    fn acos(self) -> TCResult<Self::Unary>;

    /// Cosine.
    fn cos(self) -> TCResult<Self::Unary>;

    /// Hyperbolic cosine.
    fn cosh(self) -> TCResult<Self::Unary>;

    /// Arctangent.
    fn atan(self) -> TCResult<Self::Unary>;

    /// Tangent.
    fn tan(self) -> TCResult<Self::Unary>;

    /// Hyperbolic tangent.
    fn tanh(self) -> TCResult<Self::Unary>;
}

/// Unary arithmetic functions of a tensor.
pub trait TensorUnary {
    /// The output type of every function in this trait.
    type Unary: TensorInstance;

    /// Absolute value.
    fn abs(self) -> TCResult<Self::Unary>;

    /// Exponential.
    fn exp(self) -> TCResult<Self::Unary>;

    /// Natural logarithm.
    fn ln(self) -> TCResult<Self::Unary>;

    /// Rounding.
    fn round(self) -> TCResult<Self::Unary>;
}

/// Unary logical functions of a tensor.
pub trait TensorUnaryBoolean {
    /// The output type of [`Self::not`].
    type Unary: TensorInstance;

    /// Logical NOT.
    fn not(self) -> TCResult<Self::Unary>;
}

/// Writes of a constant [`Number`] into a tensor.
#[async_trait]
pub trait TensorWrite {
    /// Write `value` to the region `range`, as part of transaction `txn_id`.
    async fn write_value(&self, txn_id: TxnId, range: Range, value: Number) -> TCResult<()>;

    /// Write `value` to the single coordinate `coord`, as part of transaction `txn_id`.
    async fn write_value_at(&self, txn_id: TxnId, coord: Coord, value: Number) -> TCResult<()>;
}

/// Writes of another tensor-like operand of type `O` into a region of this tensor.
#[async_trait]
pub trait TensorWriteDual<O> {
    /// Write `value` into the region `range`, as part of transaction `txn_id`.
    async fn write(self, txn_id: TxnId, range: Range, value: O) -> TCResult<()>;
}
pub enum Dense<Txn, FE> {
Base(DenseBase<Txn, FE>),
View(DenseView<Txn, FE>),
}
impl<Txn, FE> Clone for Dense<Txn, FE> {
fn clone(&self) -> Self {
match self {
Self::Base(base) => Self::Base(base.clone()),
Self::View(view) => Self::View(view.clone()),
}
}
}
impl<Txn, FE> Dense<Txn, FE> {
pub fn into_view(self) -> DenseView<Txn, FE> {
self.into()
}
}
impl<Txn: ThreadSafe, FE: ThreadSafe> Instance for Dense<Txn, FE> {
type Class = TensorType;
fn class(&self) -> Self::Class {
TensorType::Dense
}
}
impl<Txn: ThreadSafe, FE: ThreadSafe> TensorInstance for Dense<Txn, FE> {
fn dtype(&self) -> NumberType {
match self {
Self::Base(base) => base.dtype(),
Self::View(view) => view.dtype(),
}
}
fn shape(&self) -> &Shape {
match self {
Self::Base(base) => base.shape(),
Self::View(view) => view.shape(),
}
}
}
impl<Txn, FE> TensorConvert for Dense<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Dense = Self;
    type Sparse = Sparse<Txn, FE>;

    /// A dense tensor is already dense: return it unchanged.
    fn into_dense(self) -> Self::Dense {
        self
    }

    /// Convert via the dense view and wrap the result as a [`Sparse::View`].
    fn into_sparse(self) -> Self::Sparse {
        let sparse = self.into_view().into_sparse();
        Sparse::View(sparse)
    }
}
#[async_trait]
impl<Txn, FE> TensorRead for Dense<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
async fn read_value(self, txn_id: TxnId, coord: Coord) -> TCResult<Number> {
match self {
Self::Base(base) => base.read_value(txn_id, coord).await,
Self::View(view) => view.read_value(txn_id, coord).await,
}
}
}
/// Only a [`Dense::Base`] tensor is writable; writing to a view is a bad request.
#[async_trait]
impl<Txn, FE> TensorWrite for Dense<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node>,
{
    async fn write_value(&self, txn_id: TxnId, range: Range, value: Number) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write_value(txn_id, range, value).await,
            Self::View(_) => Err(bad_request!("cannot write to {:?}", self)),
        }
    }

    async fn write_value_at(&self, txn_id: TxnId, coord: Coord, value: Number) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write_value_at(txn_id, coord, value).await,
            Self::View(_) => Err(bad_request!("cannot write to {:?}", self)),
        }
    }
}
#[async_trait]
impl<Txn, FE> TensorWriteDual<Self> for Dense<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    /// Write another dense tensor into `range`.
    ///
    /// Only a [`Dense::Base`] is writable; `value` is first converted into a [`DenseView`].
    async fn write(self, txn_id: TxnId, range: Range, value: Self) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write(txn_id, range, value.into()).await,
            other => Err(bad_request!("cannot write to {:?}", other)),
        }
    }
}
#[async_trait]
impl<'en, Txn, FE> IntoView<'en, FE> for Dense<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Txn = Txn;
    type View = view::DenseView;

    /// Build a `view::DenseView` of this tensor at the ID of `txn`.
    async fn into_view(self, txn: Self::Txn) -> TCResult<view::DenseView> {
        let txn_id = *txn.id();
        view::DenseView::read_from(self, txn_id).await
    }
}
impl<Txn, FE> From<DenseView<Txn, FE>> for Dense<Txn, FE> {
    /// Wrap a view as [`Dense::View`].
    fn from(view: DenseView<Txn, FE>) -> Self {
        Dense::View(view)
    }
}

impl<Txn, FE> From<Dense<Txn, FE>> for DenseView<Txn, FE> {
    /// Unwrap a [`Dense`] tensor into a view, converting a base tensor if necessary.
    fn from(dense: Dense<Txn, FE>) -> Self {
        match dense {
            Dense::View(view) => view,
            Dense::Base(base) => base.into(),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> fmt::Debug for Dense<Txn, FE> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Base(base) => base.fmt(f),
Self::View(view) => view.fmt(f),
}
}
}
pub enum Sparse<Txn, FE> {
Base(SparseBase<Txn, FE>),
View(SparseView<Txn, FE>),
}
impl<Txn, FE> Clone for Sparse<Txn, FE> {
fn clone(&self) -> Self {
match self {
Self::Base(base) => Self::Base(base.clone()),
Self::View(view) => Self::View(view.clone()),
}
}
}
impl<Txn, FE> Sparse<Txn, FE> {
pub fn into_view(self) -> SparseView<Txn, FE> {
self.into()
}
}
impl<Txn: ThreadSafe, FE: ThreadSafe> Instance for Sparse<Txn, FE> {
type Class = TensorType;
fn class(&self) -> Self::Class {
TensorType::Sparse
}
}
impl<Txn: ThreadSafe, FE: ThreadSafe> TensorInstance for Sparse<Txn, FE> {
fn dtype(&self) -> NumberType {
match self {
Self::Base(base) => base.dtype(),
Self::View(view) => view.dtype(),
}
}
fn shape(&self) -> &Shape {
match self {
Self::Base(base) => base.shape(),
Self::View(view) => view.shape(),
}
}
}
#[async_trait]
impl<Txn: Transaction<FE>, FE: DenseCacheFile + AsType<Node>> TensorRead for Sparse<Txn, FE> {
async fn read_value(self, txn_id: TxnId, coord: Coord) -> TCResult<Number> {
match self {
Self::Base(base) => base.read_value(txn_id, coord).await,
Self::View(view) => view.read_value(txn_id, coord).await,
}
}
}
/// Only a [`Sparse::Base`] tensor is writable; writing to a view is a bad request.
#[async_trait]
impl<Txn, FE> TensorWrite for Sparse<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node>,
{
    async fn write_value(&self, txn_id: TxnId, range: Range, value: Number) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write_value(txn_id, range, value).await,
            Self::View(_) => Err(bad_request!("cannot write to {:?}", self)),
        }
    }

    async fn write_value_at(&self, txn_id: TxnId, coord: Coord, value: Number) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write_value_at(txn_id, coord, value).await,
            Self::View(_) => Err(bad_request!("cannot write to {:?}", self)),
        }
    }
}
#[async_trait]
impl<Txn, FE> TensorWriteDual<Self> for Sparse<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    /// Write another sparse tensor into `range`.
    ///
    /// Only a [`Sparse::Base`] is writable; `value` is first converted into a [`SparseView`].
    async fn write(self, txn_id: TxnId, range: Range, value: Self) -> TCResult<()> {
        match self {
            Self::Base(base) => base.write(txn_id, range, value.into()).await,
            other => Err(bad_request!("cannot write to {:?}", other)),
        }
    }
}
#[async_trait]
impl<'en, Txn, FE> IntoView<'en, FE> for Sparse<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Txn = Txn;
    type View = view::SparseView;

    /// Build a `view::SparseView` of this tensor at the ID of `txn`.
    async fn into_view(self, txn: Self::Txn) -> TCResult<view::SparseView> {
        let txn_id = *txn.id();
        view::SparseView::read_from(self, txn_id).await
    }
}
impl<Txn, FE> From<SparseView<Txn, FE>> for Sparse<Txn, FE> {
fn from(view: SparseView<Txn, FE>) -> Self {
Self::View(view)
}
}
impl<Txn, FE> From<Sparse<Txn, FE>> for SparseView<Txn, FE> {
    /// Unwrap a [`Sparse`] tensor into a view, converting a base tensor if necessary.
    fn from(sparse: Sparse<Txn, FE>) -> Self {
        match sparse {
            Sparse::Base(base) => base.into(),
            // Already a `SparseView`: return it directly. An identity `.into()` here would be
            // a no-op (clippy::useless_conversion); the `Dense` counterpart does the same.
            Sparse::View(view) => view,
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> fmt::Debug for Sparse<Txn, FE> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Base(base) => base.fmt(f),
Self::View(view) => view.fmt(f),
}
}
}
/// A tensor in either representation: [`Dense`] or [`Sparse`].
pub enum Tensor<Txn, FE> {
    Dense(Dense<Txn, FE>),
    Sparse(Sparse<Txn, FE>),
}

// Written by hand rather than derived, to avoid `Txn: Clone` and `FE: Clone` bounds.
impl<Txn, FE> Clone for Tensor<Txn, FE> {
    fn clone(&self) -> Self {
        match self {
            Tensor::Dense(inner) => Tensor::Dense(Clone::clone(inner)),
            Tensor::Sparse(inner) => Tensor::Sparse(Clone::clone(inner)),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> Instance for Tensor<Txn, FE> {
type Class = TensorType;
fn class(&self) -> Self::Class {
match self {
Self::Dense(dense) => dense.class(),
Self::Sparse(sparse) => sparse.class(),
}
}
}
impl<Txn: ThreadSafe, FE: ThreadSafe> TensorInstance for Tensor<Txn, FE> {
fn dtype(&self) -> NumberType {
match self {
Self::Dense(dense) => dense.dtype(),
Self::Sparse(sparse) => sparse.dtype(),
}
}
fn shape(&self) -> &Shape {
match self {
Self::Dense(dense) => dense.shape(),
Self::Sparse(sparse) => sparse.shape(),
}
}
}
impl<Txn, FE> TensorBoolean<Self> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Combine = Self;
    type LeftCombine = Self;

    /// Logical AND.
    ///
    /// If either operand is sparse, the dense operand is converted to sparse and the
    /// sparse operand performs the operation.
    fn and(self, other: Self) -> TCResult<Self::LeftCombine> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().and(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => that
                .into_view()
                .and(this.into_view().into_sparse())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .and(that.into_view().into_sparse())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().and(that.into()).map(Self::from)
            }
        }
    }

    /// Logical OR; a mixed dense/sparse pair is computed densely.
    fn or(self, other: Self) -> TCResult<Self::Combine> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().or(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .or(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .or(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().or(that.into()).map(Self::from)
            }
        }
    }

    /// Logical XOR; a mixed dense/sparse pair is computed densely.
    fn xor(self, other: Self) -> TCResult<Self::Combine> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().xor(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .xor(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .xor(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().xor(that.into()).map(Self::from)
            }
        }
    }
}
impl<Txn, FE> TensorBooleanConst for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Combine = Self;
fn and_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().and_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().and_const(other).map(Self::from),
}
}
fn or_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().or_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().or_const(other).map(Self::from),
}
}
fn xor_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().xor_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().xor_const(other).map(Self::from),
}
}
}
impl<Txn, FE> TensorCast for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Cast = Self;
fn cast_into(self, dtype: NumberType) -> TCResult<Self::Cast> {
match self {
Self::Dense(this) => TensorCast::cast_into(this.into_view(), dtype).map(Self::from),
Self::Sparse(this) => TensorCast::cast_into(this.into_view(), dtype).map(Self::from),
}
}
}
/// Comparisons between two tensors.
///
/// Two tensors of the same representation are compared directly. For a mixed dense/sparse
/// pair, the sparse operand is converted to dense, whichever side it is on. This keeps
/// `a.op(b)` and `b.op(a)` on equivalent code paths.
impl<Txn, FE> TensorCompare<Self> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Compare = Self;

    fn eq(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().eq(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .eq(that.into_view().into_dense())
                .map(Self::from),
            // Densify the sparse side, mirroring the branch above and every other comparison.
            // Previously the dense operand was converted to sparse here, so `eq` was
            // asymmetric and paid for a full dense -> sparse conversion.
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .eq(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().eq(that.into()).map(Self::from)
            }
        }
    }

    fn gt(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().gt(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .gt(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .gt(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().gt(that.into()).map(Self::from)
            }
        }
    }

    fn ge(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().ge(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .ge(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .ge(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().ge(that.into()).map(Self::from)
            }
        }
    }

    fn lt(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().lt(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .lt(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .lt(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().lt(that.into()).map(Self::from)
            }
        }
    }

    fn le(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().le(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .le(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .le(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().le(that.into()).map(Self::from)
            }
        }
    }

    fn ne(self, other: Self) -> TCResult<Self::Compare> {
        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().ne(that.into()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_view()
                .ne(that.into_view().into_dense())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .into_dense()
                .ne(that.into())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().ne(that.into()).map(Self::from)
            }
        }
    }
}
impl<Txn, FE> TensorCompareConst for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Compare = Self;
fn eq_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().eq_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().eq_const(other).map(Self::from),
}
}
fn gt_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().gt_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().gt_const(other).map(Self::from),
}
}
fn ge_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().ge_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().ge_const(other).map(Self::from),
}
}
fn lt_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().lt_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().lt_const(other).map(Self::from),
}
}
fn le_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().le_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().le_const(other).map(Self::from),
}
}
fn ne_const(self, other: Number) -> TCResult<Self::Compare> {
match self {
Self::Dense(this) => this.into_view().ne_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().ne_const(other).map(Self::from),
}
}
}
impl<Txn, FE> TensorCond<Self, Self> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Cond = Self;

    /// Conditional selection.
    ///
    /// If all three operands share a representation, they are combined in that
    /// representation; otherwise every operand is converted to dense first.
    fn cond(self, then: Self, or_else: Self) -> TCResult<Self::Cond> {
        match (self, then, or_else) {
            (Self::Dense(this), Self::Dense(then), Self::Dense(or_else)) => this
                .into_view()
                .cond(then.into_view(), or_else.into_view())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(then), Self::Sparse(or_else)) => this
                .into_view()
                .cond(then.into_view(), or_else.into_view())
                .map(Self::from),
            (this, then, or_else) => {
                // Mixed representations: densify all three operands and retry.
                // This assumes `From<Dense>` wraps into `Tensor::Dense`, so the retry
                // reaches the all-dense arm.
                let this = Self::Dense(this.into_dense());
                let then: Self = then.into_dense().into();
                let or_else: Self = or_else.into_dense().into();
                this.cond(then, or_else)
            }
        }
    }
}
impl<Txn, FE> TensorConvert for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Dense = Dense<Txn, FE>;
type Sparse = Sparse<Txn, FE>;
fn into_dense(self) -> Dense<Txn, FE> {
match self {
Self::Dense(this) => this,
Self::Sparse(this) => Dense::View(this.into_view().into_dense()),
}
}
fn into_sparse(self) -> Sparse<Txn, FE> {
match self {
Self::Dense(this) => Sparse::View(this.into_view().into_sparse()),
Self::Sparse(this) => this,
}
}
}
impl<Txn, FE> TensorDiagonal for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Diagonal = Self;

    /// The diagonal of a dense tensor; not yet implemented for sparse tensors.
    fn diagonal(self) -> TCResult<Self::Diagonal> {
        match self {
            Self::Sparse(sparse) => Err(not_implemented!("diagonal of {:?}", sparse)),
            Self::Dense(dense) => {
                let diagonal = dense.into_view().diagonal()?;
                Ok(Self::from(diagonal))
            }
        }
    }
}
impl<Txn, FE> TensorMath<Self> for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Combine = Self;
type LeftCombine = Self;
fn add(self, other: Self) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => match other {
Self::Dense(that) => this.into_view().add(that.into()).map(Self::from),
Self::Sparse(that) => this
.into_view()
.add(that.into_view().into_dense())
.map(Self::from),
},
Self::Sparse(this) => match other {
Self::Dense(that) => this
.into_view()
.into_dense()
.add(that.into())
.map(Self::from),
Self::Sparse(that) => this.into_view().add(that.into_view()).map(Self::from),
},
}
}
fn div(self, other: Self) -> TCResult<Self::LeftCombine> {
if let Self::Dense(that) = other {
match self {
Self::Dense(this) => this.into_view().div(that.into()).map(Self::from),
Self::Sparse(this) => this
.into_view()
.div(that.into_view().into_sparse())
.map(Self::from),
}
} else {
Err(bad_request!("cannot divide by {other:?}"))
}
}
fn log(self, base: Self) -> TCResult<Self::LeftCombine> {
if let Self::Dense(that) = base {
match self {
Self::Dense(this) => this.into_view().log(that.into()).map(Self::from),
Self::Sparse(this) => this
.into_view()
.log(that.into_view().into_sparse())
.map(Self::from),
}
} else {
Err(bad_request!("log base {base:?} is undefined"))
}
}
fn mul(self, other: Self) -> TCResult<Self::LeftCombine> {
match self {
Self::Dense(this) => match other {
Self::Dense(that) => this.into_view().mul(that.into()).map(Self::from),
Self::Sparse(that) => this
.into_view()
.into_sparse()
.mul(that.into())
.map(Self::from),
},
Self::Sparse(this) => match other {
Self::Dense(that) => this
.into_view()
.mul(that.into_view().into_sparse())
.map(Self::from),
Self::Sparse(that) => this.into_view().mul(that.into()).map(Self::from),
},
}
}
fn pow(self, other: Self) -> TCResult<Self::LeftCombine> {
match self {
Self::Dense(this) => match other {
Self::Dense(that) => this.into_view().pow(that.into()).map(Self::from),
Self::Sparse(that) => this
.into_view()
.pow(that.into_view().into_dense())
.map(Self::from),
},
Self::Sparse(this) => match other {
Self::Dense(that) => this
.into_view()
.pow(that.into_view().into_sparse())
.map(Self::from),
Self::Sparse(that) => this.into_view().pow(that.into()).map(Self::from),
},
}
}
fn sub(self, other: Self) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => match other {
Self::Dense(that) => this.into_view().sub(that.into()).map(Self::from),
Self::Sparse(that) => this
.into_view()
.sub(that.into_view().into_dense())
.map(Self::from),
},
Self::Sparse(this) => match other {
Self::Dense(that) => this
.into_view()
.into_dense()
.sub(that.into())
.map(Self::from),
Self::Sparse(that) => this.into_view().sub(that.into()).map(Self::from),
},
}
}
}
impl<Txn, FE> TensorMathConst for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Combine = Self;
fn add_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().add_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().add_const(other).map(Self::from),
}
}
fn div_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().div_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().div_const(other).map(Self::from),
}
}
fn log_const(self, base: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().log_const(base).map(Self::from),
Self::Sparse(this) => this.into_view().log_const(base).map(Self::from),
}
}
fn mul_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().mul_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().mul_const(other).map(Self::from),
}
}
fn pow_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().pow_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().pow_const(other).map(Self::from),
}
}
fn sub_const(self, other: Number) -> TCResult<Self::Combine> {
match self {
Self::Dense(this) => this.into_view().sub_const(other).map(Self::from),
Self::Sparse(this) => this.into_view().sub_const(other).map(Self::from),
}
}
}
impl<Txn, FE> TensorMatMul<Self> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type MatMul = Self;

    /// Matrix multiplication; if either operand is sparse, both are multiplied as sparse.
    fn matmul(self, other: Self) -> TCResult<Self::MatMul> {
        debug!("{:?} @ {:?}", self, other);

        match (self, other) {
            (Self::Dense(this), Self::Dense(that)) => {
                this.into_view().matmul(that.into_view()).map(Self::from)
            }
            (Self::Dense(this), Self::Sparse(that)) => this
                .into_sparse()
                .into_view()
                .matmul(that.into_view())
                .map(Self::from),
            (Self::Sparse(this), Self::Dense(that)) => this
                .into_view()
                .matmul(that.into_sparse().into_view())
                .map(Self::from),
            (Self::Sparse(this), Self::Sparse(that)) => {
                this.into_view().matmul(that.into_view()).map(Self::from)
            }
        }
    }
}
#[async_trait]
impl<Txn, FE> TensorRead for Tensor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
async fn read_value(self, txn_id: TxnId, coord: Coord) -> TCResult<Number> {
match self {
Self::Dense(dense) => dense.read_value(txn_id, coord).await,
Self::Sparse(sparse) => sparse.read_value(txn_id, coord).await,
}
}
}
/// Reductions, applied to the view of whichever representation this tensor holds.
#[async_trait]
impl<Txn, FE> TensorReduce for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Reduce = Self;

    async fn all(self, txn_id: TxnId) -> TCResult<bool> {
        match self {
            Tensor::Dense(dense) => dense.into_view().all(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().all(txn_id).await,
        }
    }

    async fn any(self, txn_id: TxnId) -> TCResult<bool> {
        match self {
            Tensor::Dense(dense) => dense.into_view().any(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().any(txn_id).await,
        }
    }

    fn max(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce> {
        match self {
            Tensor::Dense(dense) => dense.into_view().max(axes, keepdims).map(Tensor::from),
            Tensor::Sparse(sparse) => sparse.into_view().max(axes, keepdims).map(Tensor::from),
        }
    }

    async fn max_all(self, txn_id: TxnId) -> TCResult<Number> {
        match self {
            Tensor::Dense(dense) => dense.into_view().max_all(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().max_all(txn_id).await,
        }
    }

    fn min(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce> {
        match self {
            Tensor::Dense(dense) => dense.into_view().min(axes, keepdims).map(Tensor::from),
            Tensor::Sparse(sparse) => sparse.into_view().min(axes, keepdims).map(Tensor::from),
        }
    }

    async fn min_all(self, txn_id: TxnId) -> TCResult<Number> {
        match self {
            Tensor::Dense(dense) => dense.into_view().min_all(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().min_all(txn_id).await,
        }
    }

    fn product(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce> {
        match self {
            Tensor::Dense(dense) => dense
                .into_view()
                .product(axes, keepdims)
                .map(Tensor::from),
            Tensor::Sparse(sparse) => sparse
                .into_view()
                .product(axes, keepdims)
                .map(Tensor::from),
        }
    }

    async fn product_all(self, txn_id: TxnId) -> TCResult<Number> {
        match self {
            Tensor::Dense(dense) => dense.into_view().product_all(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().product_all(txn_id).await,
        }
    }

    fn sum(self, axes: Axes, keepdims: bool) -> TCResult<Self::Reduce> {
        match self {
            Tensor::Dense(dense) => dense.into_view().sum(axes, keepdims).map(Tensor::from),
            Tensor::Sparse(sparse) => sparse.into_view().sum(axes, keepdims).map(Tensor::from),
        }
    }

    async fn sum_all(self, txn_id: TxnId) -> TCResult<Number> {
        match self {
            Tensor::Dense(dense) => dense.into_view().sum_all(txn_id).await,
            Tensor::Sparse(sparse) => sparse.into_view().sum_all(txn_id).await,
        }
    }
}
// Each transform delegates to the view of the wrapped dense or sparse tensor and wraps the
// transformed view back into a `Tensor`.
impl<Txn, FE> TensorTransform for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Broadcast = Self;
    type Expand = Self;
    type Reshape = Self;
    type Slice = Self;
    type Transpose = Self;

    fn broadcast(self, shape: Shape) -> TCResult<Self::Broadcast> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().broadcast(shape)?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().broadcast(shape)?),
        })
    }

    fn expand(self, axes: Axes) -> TCResult<Self::Expand> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().expand(axes)?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().expand(axes)?),
        })
    }

    fn reshape(self, shape: Shape) -> TCResult<Self::Reshape> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().reshape(shape)?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().reshape(shape)?),
        })
    }

    fn slice(self, range: Range) -> TCResult<Self::Slice> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().slice(range)?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().slice(range)?),
        })
    }

    fn transpose(self, permutation: Option<Axes>) -> TCResult<Self::Transpose> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().transpose(permutation)?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().transpose(permutation)?),
        })
    }
}
// Each trigonometric op delegates to the view of the wrapped dense or sparse tensor and wraps
// the resulting view back into a `Tensor`.
impl<Txn, FE> TensorTrig for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Unary = Self;

    fn asin(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().asin()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().asin()?),
        })
    }

    fn sin(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().sin()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().sin()?),
        })
    }

    fn sinh(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().sinh()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().sinh()?),
        })
    }

    fn acos(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().acos()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().acos()?),
        })
    }

    fn cos(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().cos()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().cos()?),
        })
    }

    fn cosh(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().cosh()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().cosh()?),
        })
    }

    fn atan(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().atan()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().atan()?),
        })
    }

    fn tan(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().tan()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().tan()?),
        })
    }

    fn tanh(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().tanh()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().tanh()?),
        })
    }
}
// Each element-wise op delegates to the view of the wrapped dense or sparse tensor and wraps
// the resulting view back into a `Tensor`.
impl<Txn, FE> TensorUnary for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Unary = Self;

    fn abs(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().abs()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().abs()?),
        })
    }

    fn exp(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().exp()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().exp()?),
        })
    }

    fn ln(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().ln()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().ln()?),
        })
    }

    fn round(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().round()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().round()?),
        })
    }
}
impl<Txn, FE> TensorUnaryBoolean for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Unary = Self;

    /// Logical negation, delegated to the view of the wrapped dense or sparse tensor.
    fn not(self) -> TCResult<Self::Unary> {
        Ok(match self {
            Self::Dense(dense) => Self::from(dense.into_view().not()?),
            Self::Sparse(sparse) => Self::from(sparse.into_view().not()?),
        })
    }
}
#[async_trait]
impl<Txn, FE> TensorWrite for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node>,
{
    /// Write the scalar `value` to every coordinate in `range` of the wrapped tensor.
    async fn write_value(&self, txn_id: TxnId, range: Range, value: Number) -> TCResult<()> {
        match self {
            Self::Sparse(sparse) => sparse.write_value(txn_id, range, value).await,
            Self::Dense(dense) => dense.write_value(txn_id, range, value).await,
        }
    }

    /// Write the scalar `value` at the single coordinate `coord` of the wrapped tensor.
    async fn write_value_at(&self, txn_id: TxnId, coord: Coord, value: Number) -> TCResult<()> {
        match self {
            Self::Sparse(sparse) => sparse.write_value_at(txn_id, coord, value).await,
            Self::Dense(dense) => dense.write_value_at(txn_id, coord, value).await,
        }
    }
}
#[async_trait]
impl<Txn, FE> TensorWriteDual<Self> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    /// Write the tensor `value` into `range` of `self`.
    ///
    /// A dense destination accepts any source (converted with `into_dense`). A sparse
    /// destination accepts only a sparse source; a dense source is rejected with a bad-request
    /// error.
    async fn write(self, txn_id: TxnId, range: Range, value: Self) -> TCResult<()> {
        match (self, value) {
            (Self::Dense(this), value) => this.write(txn_id, range, value.into_dense()).await,
            (Self::Sparse(this), Self::Sparse(value)) => this.write(txn_id, range, value).await,
            (Self::Sparse(this), Self::Dense(value)) => {
                Err(bad_request!("cannot write {value:?} to {this:?}"))
            }
        }
    }
}
#[async_trait]
impl<'en, Txn, FE> IntoView<'en, FE> for Tensor<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Txn = Txn;
    type View = view::TensorView;

    /// Read this tensor into a `view::TensorView` as of the given transaction's ID.
    async fn into_view(self, txn: Self::Txn) -> TCResult<Self::View> {
        let txn_id = *txn.id();
        view::TensorView::read_from(self, txn_id).await
    }
}
impl<Txn, FE> From<Dense<Txn, FE>> for Tensor<Txn, FE> {
fn from(dense: Dense<Txn, FE>) -> Self {
Self::Dense(dense)
}
}
impl<Txn, FE> From<DenseView<Txn, FE>> for Tensor<Txn, FE> {
fn from(dense: DenseView<Txn, FE>) -> Self {
Self::Dense(dense.into())
}
}
impl<Txn, FE> From<Sparse<Txn, FE>> for Tensor<Txn, FE> {
fn from(sparse: Sparse<Txn, FE>) -> Self {
Self::Sparse(sparse)
}
}
impl<Txn, FE> From<SparseView<Txn, FE>> for Tensor<Txn, FE> {
fn from(sparse: SparseView<Txn, FE>) -> Self {
Self::Sparse(sparse.into())
}
}
impl<Txn, FE> From<Tensor<Txn, FE>> for TensorView<Txn, FE> {
    /// Convert the wrapped dense or sparse tensor into the matching view variant.
    fn from(tensor: Tensor<Txn, FE>) -> Self {
        match tensor {
            Tensor::Sparse(sparse) => Self::Sparse(sparse.into_view()),
            Tensor::Dense(dense) => Self::Dense(dense.into_view()),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> fmt::Debug for Tensor<Txn, FE> {
    /// Format exactly as the wrapped dense or sparse tensor does, with no extra wrapper.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Sparse(sparse) => sparse.fmt(f),
            Self::Dense(dense) => dense.fmt(f),
        }
    }
}
/// The base (stored) form of a tensor: either a [`DenseBase`] or a [`SparseBase`].
pub enum TensorBase<Txn, FE> {
/// A dense base tensor.
Dense(DenseBase<Txn, FE>),
/// A sparse base tensor.
Sparse(SparseBase<Txn, FE>),
}
// Written by hand rather than derived: a derive would require `Txn: Clone` and `FE: Clone`,
// which this impl does not.
impl<Txn, FE> Clone for TensorBase<Txn, FE> {
    fn clone(&self) -> Self {
        match self {
            Self::Sparse(base) => Self::Sparse(base.clone()),
            Self::Dense(base) => Self::Dense(base.clone()),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> Instance for TensorBase<Txn, FE> {
    type Class = TensorType;

    /// Return the class reported by the wrapped dense or sparse base tensor.
    fn class(&self) -> Self::Class {
        match self {
            Self::Sparse(base) => base.class(),
            Self::Dense(base) => base.class(),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> TensorInstance for TensorBase<Txn, FE> {
    /// The numeric data type of the wrapped base tensor.
    fn dtype(&self) -> NumberType {
        match self {
            Self::Sparse(base) => base.dtype(),
            Self::Dense(base) => base.dtype(),
        }
    }

    /// The shape of the wrapped base tensor.
    fn shape(&self) -> &Shape {
        match self {
            Self::Sparse(base) => base.shape(),
            Self::Dense(base) => base.shape(),
        }
    }
}
// Transaction lifecycle hooks are forwarded unchanged to the wrapped dense or sparse base.
#[async_trait]
impl<Txn, FE> Transact for TensorBase<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + for<'en> fs::FileSave<'en> + Clone,
{
    type Commit = ();

    async fn commit(&self, txn_id: TxnId) -> Self::Commit {
        match self {
            Self::Sparse(base) => base.commit(txn_id).await,
            Self::Dense(base) => base.commit(txn_id).await,
        }
    }

    async fn rollback(&self, txn_id: &TxnId) {
        match self {
            Self::Sparse(base) => base.rollback(txn_id).await,
            Self::Dense(base) => base.rollback(txn_id).await,
        }
    }

    async fn finalize(&self, txn_id: &TxnId) {
        match self {
            Self::Sparse(base) => base.finalize(txn_id).await,
            Self::Dense(base) => base.finalize(txn_id).await,
        }
    }
}
#[async_trait]
impl<Txn, FE> fs::Persist<FE> for TensorBase<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Txn = Txn;
    type Schema = (TensorType, Schema);

    /// Create a new base tensor in `store`.
    ///
    /// The `TensorType` half of the schema selects a dense or sparse base, and the `Schema`
    /// half is forwarded to that base's constructor.
    async fn create(txn_id: TxnId, schema: Self::Schema, store: fs::Dir<FE>) -> TCResult<Self> {
        match schema {
            (TensorType::Dense, schema) => {
                Ok(Self::Dense(DenseBase::create(txn_id, schema, store).await?))
            }
            (TensorType::Sparse, schema) => {
                Ok(Self::Sparse(SparseBase::create(txn_id, schema, store).await?))
            }
        }
    }

    /// Load an existing base tensor from `store`. The schema's class selects dense or sparse.
    async fn load(txn_id: TxnId, schema: Self::Schema, store: fs::Dir<FE>) -> TCResult<Self> {
        match schema {
            (TensorType::Dense, schema) => {
                Ok(Self::Dense(DenseBase::load(txn_id, schema, store).await?))
            }
            (TensorType::Sparse, schema) => {
                Ok(Self::Sparse(SparseBase::load(txn_id, schema, store).await?))
            }
        }
    }

    /// The directory backing the wrapped base tensor.
    fn dir(&self) -> fs::Inner<FE> {
        match self {
            Self::Sparse(base) => base.dir(),
            Self::Dense(base) => base.dir(),
        }
    }
}
#[async_trait]
impl<Txn, FE> fs::CopyFrom<FE, TensorView<Txn, FE>> for TensorBase<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    /// Copy a tensor view into a new base tensor in `store`.
    ///
    /// A dense view becomes a `DenseBase` and a sparse view becomes a `SparseBase`.
    async fn copy_from(
        txn: &Txn,
        store: fs::Dir<FE>,
        instance: TensorView<Txn, FE>,
    ) -> TCResult<Self> {
        match instance {
            TensorView::Sparse(view) => {
                SparseBase::copy_from(txn, store, view)
                    .map_ok(Self::Sparse)
                    .await
            }
            TensorView::Dense(view) => {
                DenseBase::copy_from(txn, store, view)
                    .map_ok(Self::Dense)
                    .await
            }
        }
    }
}
#[async_trait]
impl<Txn, FE> fs::Restore<FE> for TensorBase<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    /// Restore this base tensor from `backup` as of `txn_id`.
    ///
    /// Both tensors must have the same kind (dense/dense or sparse/sparse). A mismatch is a
    /// bad-request error.
    async fn restore(&self, txn_id: TxnId, backup: &Self) -> TCResult<()> {
        match (self, backup) {
            (Self::Sparse(current), Self::Sparse(saved)) => current.restore(txn_id, saved).await,
            (Self::Dense(current), Self::Dense(saved)) => current.restore(txn_id, saved).await,
            (this, that) => Err(bad_request!("cannot restore {this:?} from {that:?}")),
        }
    }
}
#[async_trait]
impl<Txn, FE> de::FromStream for TensorBase<Txn, FE>
where
    Txn: Transaction<FE>,
    FE: DenseCacheFile + AsType<Node> + Clone,
{
    type Context = Txn;

    /// Decode a base tensor from a map, using `txn` as the decoding context.
    async fn from_stream<D: de::Decoder>(txn: Txn, decoder: &mut D) -> Result<Self, D::Error> {
        decoder.decode_map(TensorVisitor::new(txn)).await
    }
}
impl<Txn, FE> From<TensorBase<Txn, FE>> for Tensor<Txn, FE> {
    /// Wrap a base tensor in the `Base` variant of the matching dense or sparse tensor.
    fn from(base: TensorBase<Txn, FE>) -> Tensor<Txn, FE> {
        match base {
            TensorBase::Sparse(sparse) => Self::Sparse(Sparse::Base(sparse)),
            TensorBase::Dense(dense) => Self::Dense(Dense::Base(dense)),
        }
    }
}
impl<Txn: ThreadSafe, FE: ThreadSafe> fmt::Debug for TensorBase<Txn, FE> {
    /// Format exactly as the wrapped dense or sparse base does.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Sparse(base) => base.fmt(f),
            Self::Dense(base) => base.fmt(f),
        }
    }
}
/// A view of a tensor: either a [`DenseView`] or a [`SparseView`].
pub enum TensorView<Txn, FE> {
/// A dense tensor view.
Dense(DenseView<Txn, FE>),
/// A sparse tensor view.
Sparse(SparseView<Txn, FE>),
}
// Written by hand rather than derived: a derive would require `Txn: Clone` and `FE: Clone`,
// which this impl does not.
impl<Txn, FE> Clone for TensorView<Txn, FE> {
    fn clone(&self) -> Self {
        match self {
            Self::Sparse(view) => Self::Sparse(view.clone()),
            Self::Dense(view) => Self::Dense(view.clone()),
        }
    }
}
/// A `de::Visitor` that decodes a [`TensorBase`] from a map whose first key is the tensor class.
struct TensorVisitor<Txn, FE> {
// Passed as the context when decoding the tensor value.
txn: Txn,
// Carries the `FE` type parameter, which no field stores directly.
phantom: PhantomData<FE>,
}
impl<Txn, FE> TensorVisitor<Txn, FE> {
    /// Construct a visitor that will pass `txn` as the context for decoding the tensor value.
    fn new(txn: Txn) -> Self {
        let phantom = PhantomData;
        Self { phantom, txn }
    }
}
#[async_trait]
impl<Txn, FE> de::Visitor for TensorVisitor<Txn, FE>
where
Txn: Transaction<FE>,
FE: DenseCacheFile + AsType<Node> + Clone,
{
type Value = TensorBase<Txn, FE>;
fn expecting() -> &'static str {
"a tensor"
}
async fn visit_map<A: de::MapAccess>(self, mut map: A) -> Result<Self::Value, A::Error> {
let class = map.next_key::<TensorType>(()).await?;
let class = class.ok_or_else(|| de::Error::invalid_length(0, Self::expecting()))?;
match class {
TensorType::Dense => map.next_value(self.txn).map_ok(TensorBase::Dense).await,
TensorType::Sparse => map.next_value(self.txn).map_ok(TensorBase::Sparse).await,
}
}
}
/// Broadcast `left` and `right` to their common shape, as computed by [`broadcast_shape`].
///
/// # Errors
/// Fails if the two shapes cannot be broadcast together, or if broadcasting either operand
/// fails (`left` is broadcast first).
pub fn broadcast<L, R>(left: L, right: R) -> TCResult<(L::Broadcast, R::Broadcast)>
where
    L: TensorInstance + TensorTransform + fmt::Debug,
    R: TensorInstance + TensorTransform + fmt::Debug,
{
    let shape = broadcast_shape(left.shape().as_slice(), right.shape().as_slice())?;
    let left = left.broadcast(shape.clone())?;
    let right = right.broadcast(shape)?;
    Ok((left, right))
}
/// Compute the shape produced by broadcasting two shapes against each other.
///
/// The shapes are aligned at their last axis, and the shorter one is treated as if it were
/// left-padded with dimensions of size 1. For each aligned pair `(l, r)`, the result is `r`
/// when `l == r` or `l == 1`, and `l` when `r == 1`.
///
/// # Errors
/// Returns a bad-request error if any aligned pair differs and neither dimension is 1.
pub fn broadcast_shape(left: &[u64], right: &[u64]) -> TCResult<Shape> {
    let ndim = Ord::max(left.len(), right.len());

    let left = iter::repeat(1)
        .take(ndim - left.len())
        .chain(left.iter().copied());

    let right = iter::repeat(1)
        .take(ndim - right.len())
        .chain(right.iter().copied());

    // `left` is already an iterator, so zip it directly (no redundant `into_iter`).
    left.zip(right)
        .map(|(l, r)| {
            if l == r || l == 1 {
                Ok(r)
            } else if r == 1 {
                Ok(l)
            } else {
                Err(bad_request!("cannot broadcast dimension {l} into {r}"))
            }
        })
        .collect()
}
/// Convert a flat `offset` into a coordinate, given per-axis `strides` and `shape`.
///
/// For each axis, `offset` is divided by that axis's stride and the quotient is reduced modulo
/// the axis dimension. An axis whose stride equals the caller-supplied `zero` yields `zero`
/// instead of dividing. The result has one entry per `(stride, dim)` pair.
#[inline]
fn coord_of<T: Copy + Div<Output = T> + Rem<Output = T> + PartialEq>(
    offset: T,
    strides: &[T],
    shape: &[T],
    zero: T,
) -> SmallVec<[T; 8]> {
    let quotients = strides
        .iter()
        .map(|&stride| if stride == zero { zero } else { offset / stride });

    quotients
        .zip(shape)
        .map(|(quotient, &dim)| quotient % dim)
        .collect()
}
#[inline]
fn offset_of(coord: Coord, shape: &[u64]) -> u64 {
let strides = shape.iter().enumerate().map(|(x, dim)| {
if *dim == 1 {
0
} else {
shape.iter().rev().take(shape.len() - 1 - x).product()
}
});
coord.into_iter().zip(strides).map(|(i, dim)| i * dim).sum()
}
/// Yield one stride for each of `ndim` axes, where `shape` describes the trailing axes.
///
/// The first `ndim - shape.len()` strides are 0, for leading axes not present in `shape`.
/// Each remaining stride is the product of the dimensions after that axis, except that an
/// axis of size 1 gets stride 0.
#[inline]
fn strides_for(shape: &[u64], ndim: usize) -> impl Iterator<Item = u64> + '_ {
    debug_assert!(ndim >= shape.len());

    let padding = std::iter::repeat(0).take(ndim - shape.len());

    let per_axis = (0..shape.len()).map(move |axis| match shape[axis] {
        1 => 0,
        _ => shape[axis + 1..].iter().product::<u64>(),
    });

    padding.chain(per_axis)
}