use bitvec::vec::BitVec;
use std::{
fmt::{Debug, Display},
ops::Index,
};
use crate::structure::{
concrete_index::FlatIndex, dimension::Dimension, representation::Representation,
slot::IsAbstractSlot, TensorStructure,
};
use super::indices::{AbstractFiberIndex, FiberClassIndex, FiberData, FiberIndex};
use super::traits::AbstractFiber;
/// Structure-independent core of a fiber: one [`FiberIndex`] per position,
/// each either `Free` or `Fixed` to a value.
#[derive(Debug, Clone)]
struct BareFiber {
    /// One entry per index position (constructors such as `zeros` size this
    /// from `structure.order()`; the filter constructors take the filter length).
    indices: Vec<FiberIndex>,
    /// Cached "single free index" marker: `Fixed(pos)` means exactly one entry
    /// (at `pos`) is free; `Free` means "not single" or "not computed yet".
    /// Refreshed by `BareFiber::is_single`.
    is_single: FiberIndex,
}
impl BareFiber {
    /// Builds a bare fiber from any [`FiberData`] description.
    ///
    /// Per variant:
    /// * `Flat(i)` – every position fixed to the coordinates of `structure.expanded_index(i)`.
    /// * `BoolFilter(b)` / `BitVec(b)` – `true` positions are free, all others fixed to `0`.
    /// * `Single(i)` – only position `i` is free, all others fixed to `0`.
    /// * `IntFilter(v)` – positions whose entry is `> 0` are free, all others fixed to `0`.
    /// * `Pos(v)` – negative entries are free, non-negative entries are fixed to that value.
    ///
    /// The single-free-index cache is refreshed before returning, so every
    /// variant reports a lone free index through `single()` in the same way.
    ///
    /// # Panics
    /// Panics if `Single`/`IntFilter`/`Pos` name a position `>= structure.order()`,
    /// or if a `Flat` index cannot be expanded for `structure`.
    pub fn from<I: TensorStructure>(data: FiberData, structure: &I) -> Self {
        let mut out = match data {
            FiberData::Flat(flat) => Self::from_flat(flat, structure),
            FiberData::BoolFilter(filter) => Self::from_filter(filter),
            FiberData::BitVec(bits) => Self::from_bitvec(bits),
            FiberData::Single(pos) => {
                let mut fiber = Self::zeros(structure);
                fiber.free(pos);
                fiber
            }
            FiberData::IntFilter(values) => {
                let mut fiber = Self::zeros(structure);
                for (pos, val) in values.iter().enumerate() {
                    if *val > 0 {
                        fiber.free(pos);
                    }
                }
                fiber
            }
            FiberData::Pos(values) => {
                let mut fiber = Self::zeros(structure);
                for (pos, val) in values.iter().enumerate() {
                    if *val < 0 {
                        fiber.free(pos);
                    } else {
                        // `val` is non-negative in this branch.
                        fiber.fix(pos, *val as usize);
                    }
                }
                fiber
            }
        };
        // Populate the cache for every variant, not just the filter-based ones,
        // so `single()` does not report `None` for a lone free index.
        out.is_single();
        out
    }

    /// Bit mask with a `1` at every position whose index reports `is_free()`.
    pub fn bitvec(&self) -> BitVec {
        self.indices.iter().map(|x| x.is_free()).collect()
    }

    /// Bit mask with a `1` at every position whose index reports `is_fixed()`.
    pub fn bitvecinv(&self) -> BitVec {
        self.indices.iter().map(|x| x.is_fixed()).collect()
    }

    /// Builds a fully fixed fiber whose positions are the expanded coordinates
    /// of `flat` in `structure`.
    ///
    /// # Panics
    /// Panics if `structure.expanded_index(flat)` fails.
    pub fn from_flat<I>(flat: FlatIndex, structure: &I) -> BareFiber
    where
        I: TensorStructure,
    {
        let expanded = structure
            .expanded_index(flat)
            .expect("flat index could not be expanded for this structure");
        BareFiber {
            indices: expanded.into_iter().map(FiberIndex::from).collect(),
            // Every position is fixed, so the fiber cannot be single.
            is_single: FiberIndex::Free,
        }
    }

    /// Builds a fiber from a boolean filter: `true` = free, `false` = fixed to `0`.
    pub fn from_filter(filter: &[bool]) -> BareFiber {
        Self::from_bools(filter.iter().copied())
    }

    /// Builds a fiber from a bit mask: set bit = free, clear bit = fixed to `0`.
    pub fn from_bitvec(filter: &BitVec) -> BareFiber {
        Self::from_bools(filter.iter().map(|bit| *bit))
    }

    /// Shared implementation of `from_filter` / `from_bitvec`: maps each flag to
    /// `Free`/`Fixed(0)` and fills the single-index cache.
    fn from_bools(flags: impl IntoIterator<Item = bool>) -> BareFiber {
        let mut fiber = BareFiber {
            indices: flags
                .into_iter()
                .map(|is_free| {
                    if is_free {
                        FiberIndex::Free
                    } else {
                        FiberIndex::Fixed(0)
                    }
                })
                .collect(),
            is_single: FiberIndex::Free,
        };
        fiber.is_single();
        fiber
    }

    /// A fiber with every one of the `structure.order()` positions fixed to `0`.
    pub fn zeros<I: TensorStructure>(structure: &I) -> BareFiber {
        BareFiber {
            indices: vec![FiberIndex::Fixed(0); structure.order()],
            is_single: FiberIndex::Free,
        }
    }

    /// Fixes position `pos` to `val`.
    ///
    /// If `pos` was the cached single free position, the fiber now has no free
    /// position, so the cache is reset.
    ///
    /// # Panics
    /// Panics if `pos` is out of bounds.
    pub fn fix(&mut self, pos: usize, val: usize) {
        if let FiberIndex::Fixed(single_pos) = self.is_single {
            if single_pos == pos {
                self.is_single = FiberIndex::Free;
            }
        }
        self.indices[pos] = val.into();
    }

    /// Returns (and caches) `Fixed(pos)` if exactly one position is free, else `Free`.
    ///
    /// A cached `Fixed` marker is returned as-is; a `Free` marker triggers a rescan.
    pub fn is_single(&mut self) -> FiberIndex {
        if let FiberIndex::Fixed(pos) = self.is_single {
            return FiberIndex::Fixed(pos);
        }
        let mut free_positions = self
            .indices
            .iter()
            .enumerate()
            .filter_map(|(pos, index)| matches!(index, FiberIndex::Free).then_some(pos));
        // Two probes are enough to distinguish "none", "exactly one" and "several".
        let first = free_positions.next();
        let second = free_positions.next();
        match (first, second) {
            (Some(pos), None) => {
                self.is_single = FiberIndex::Fixed(pos);
                FiberIndex::Fixed(pos)
            }
            _ => {
                self.is_single = FiberIndex::Free;
                FiberIndex::Free
            }
        }
    }

    /// Marks position `pos` as free.
    ///
    /// If a *different* position was cached as the single free one, the fiber now
    /// has at least two free positions, so the cache is invalidated. Without this,
    /// `single()` would keep reporting a stale position.
    ///
    /// # Panics
    /// Panics if `pos` is out of bounds.
    pub fn free(&mut self, pos: usize) {
        self.indices[pos] = FiberIndex::Free;
        if let FiberIndex::Fixed(single_pos) = self.is_single {
            if single_pos != pos {
                self.is_single = FiberIndex::Free;
            }
        }
    }
}
/// Read-only positional access to the bare fiber's [`FiberIndex`] entries.
///
/// Panics on an out-of-bounds position, exactly like slice indexing.
impl Index<usize> for BareFiber {
    type Output = FiberIndex;

    fn index(&self, position: usize) -> &FiberIndex {
        &self.indices[position]
    }
}
/// A fiber over a tensor structure: a free/fixed pattern (`BareFiber`) together
/// with a shared borrow of the structure it is read against.
#[derive(Debug)]
pub struct Fiber<'a, I: TensorStructure> {
    /// Structure supplying strides, shape, representations and order.
    pub(crate) structure: &'a I,
    /// Which positions are free and which are fixed (and to what value).
    bare_fiber: BareFiber,
}
/// Hand-written `Clone`: the structure borrow is simply copied, so `I` itself
/// does not need to implement `Clone` (which `#[derive(Clone)]` would require).
impl<I: TensorStructure> Clone for Fiber<'_, I> {
    fn clone(&self) -> Self {
        let bare_fiber = self.bare_fiber.clone();
        Self {
            bare_fiber,
            structure: self.structure,
        }
    }
}
/// Positional access to the fiber's [`FiberIndex`] entries; forwards to the
/// underlying bare fiber.
impl<I: TensorStructure> Index<usize> for Fiber<'_, I> {
    type Output = FiberIndex;

    fn index(&self, position: usize) -> &FiberIndex {
        Index::index(&self.bare_fiber, position)
    }
}
/// Exposes a [`Fiber`] through [`AbstractFiber`]: geometry comes from the
/// borrowed structure, the free/fixed pattern from the bare fiber.
impl<I> AbstractFiber<FiberIndex> for Fiber<'_, I>
where
    I: TensorStructure,
{
    type Repr = <I::Slot as IsAbstractSlot>::R;

    fn shape(&self) -> Vec<Dimension> {
        self.structure.shape()
    }

    /// Panics if the structure cannot produce strides.
    fn strides(&self) -> Vec<usize> {
        self.structure.strides().unwrap()
    }

    fn reps(&self) -> Vec<Representation<Self::Repr>> {
        self.structure.reps()
    }

    fn order(&self) -> usize {
        self.structure.order()
    }

    /// The cached lone free position, or `None` if the cache holds `Free`
    /// (several/no free positions, or not computed yet).
    fn single(&self) -> Option<usize> {
        match self.bare_fiber.is_single {
            FiberIndex::Fixed(pos) => Some(pos),
            _ => None,
        }
    }

    /// Mask of the free positions.
    fn bitvec(&self) -> BitVec {
        self.bare_fiber.bitvec()
    }
}
impl<'a, S> Fiber<'a, S>
where
    S: TensorStructure,
{
    /// Returns `self` unchanged.
    pub fn conj(self) -> Self {
        self
    }

    /// Consumes the fiber into a flat-core iterator (flag `false`).
    pub fn iter(
        self,
    ) -> super::fiber_iterators::FiberIterator<'a, S, super::core_iterators::CoreFlatFiberIterator>
    {
        let conj = false;
        super::fiber_iterators::FiberIterator::new(self, conj)
    }

    /// Consumes the fiber into a flat-core iterator with the flag set to `true`.
    pub fn iter_conj(
        self,
    ) -> super::fiber_iterators::FiberIterator<'a, S, super::core_iterators::CoreFlatFiberIterator>
    {
        let conj = true;
        super::fiber_iterators::FiberIterator::new(self, conj)
    }

    /// Consumes the fiber into an expanded-core iterator that applies `permutation`.
    pub fn iter_perm(
        self,
        permutation: linnet::permutation::Permutation,
    ) -> super::fiber_iterators::FiberIterator<
        'a,
        S,
        super::core_iterators::CoreExpandedFiberIterator<<S::Slot as IsAbstractSlot>::R>,
    > {
        let conj = false;
        super::fiber_iterators::FiberIterator::new_permuted(self, permutation, conj)
    }

    /// Consumes the fiber into a metric-core iterator.
    pub fn iter_metric(
        self,
    ) -> super::fiber_iterators::FiberIterator<
        'a,
        S,
        super::core_iterators::MetricFiberIterator<<S::Slot as IsAbstractSlot>::R>,
    > {
        let conj = false;
        super::fiber_iterators::FiberIterator::new(self, conj)
    }

    /// Consumes the fiber into a metric-core iterator that applies `permutation`.
    pub fn iter_perm_metric(
        self,
        permutation: linnet::permutation::Permutation,
    ) -> super::fiber_iterators::FiberIterator<
        'a,
        S,
        super::core_iterators::MetricFiberIterator<<S::Slot as IsAbstractSlot>::R>,
    > {
        let conj = false;
        super::fiber_iterators::FiberIterator::new_permuted(self, permutation, conj)
    }

    /// Builds a fiber over `structure` from anything convertible into [`FiberData`].
    pub fn from<'b>(data: impl Into<FiberData<'b>>, structure: &'a S) -> Self {
        let data: FiberData<'b> = data.into();
        Self {
            structure,
            bare_fiber: BareFiber::from(data, structure),
        }
    }

    /// Mask of the free positions.
    pub fn bitvec(&self) -> BitVec {
        self.bare_fiber.bitvec()
    }

    /// Mask of the fixed positions.
    pub fn bitvecinv(&self) -> BitVec {
        self.bare_fiber.bitvecinv()
    }

    /// Fully fixed fiber at the coordinates of the flat index `flat`.
    pub fn from_flat(flat: FlatIndex, structure: &'a S) -> Fiber<'a, S> {
        Self {
            structure,
            bare_fiber: BareFiber::from_flat(flat, structure),
        }
    }

    /// Fiber whose free positions are the `true` entries of `filter`.
    pub fn from_filter(filter: &[bool], structure: &'a S) -> Fiber<'a, S> {
        Self {
            structure,
            bare_fiber: BareFiber::from_filter(filter),
        }
    }

    /// Fiber with every position fixed to `0`.
    pub fn zeros(structure: &'a S) -> Fiber<'a, S> {
        Self {
            structure,
            bare_fiber: BareFiber::zeros(structure),
        }
    }

    /// Fixes position `pos` to `val`.
    pub fn fix(&mut self, pos: usize, val: usize) {
        self.bare_fiber.fix(pos, val);
    }

    /// Returns (and caches) `Fixed(pos)` when exactly one position is free.
    pub fn is_single(&mut self) -> FiberIndex {
        self.bare_fiber.is_single()
    }

    /// Marks position `pos` as free.
    pub fn free(&mut self, pos: usize) {
        self.bare_fiber.free(pos);
    }
}
/// Writes each index's `Display` output followed by one space (so the output
/// ends with a trailing space when the fiber is non-empty).
impl<I: TensorStructure> Display for Fiber<'_, I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.bare_fiber
            .indices
            .iter()
            .try_for_each(|index| write!(f, "{} ", index))
    }
}
/// Like `Fiber`, but holding an exclusive (`&mut`) borrow of the structure.
#[derive(Debug)]
pub struct FiberMut<'a, I: TensorStructure> {
    /// Exclusively borrowed structure supplying strides, shape, reps and order.
    pub structure: &'a mut I,
    /// Which positions are free and which are fixed (and to what value).
    bare_fiber: BareFiber,
}
/// Positional access to the fiber's [`FiberIndex`] entries; forwards to the
/// underlying bare fiber.
impl<I: TensorStructure> Index<usize> for FiberMut<'_, I> {
    type Output = FiberIndex;

    fn index(&self, position: usize) -> &FiberIndex {
        Index::index(&self.bare_fiber, position)
    }
}
/// Exposes a [`FiberMut`] through [`AbstractFiber`]: geometry comes from the
/// borrowed structure, the free/fixed pattern from the bare fiber.
impl<I> AbstractFiber<FiberIndex> for FiberMut<'_, I>
where
    I: TensorStructure,
{
    type Repr = <I::Slot as IsAbstractSlot>::R;

    fn shape(&self) -> Vec<Dimension> {
        self.structure.shape()
    }

    /// Panics if the structure cannot produce strides.
    fn strides(&self) -> Vec<usize> {
        self.structure.strides().unwrap()
    }

    fn reps(&self) -> Vec<Representation<Self::Repr>> {
        self.structure.reps()
    }

    fn order(&self) -> usize {
        self.structure.order()
    }

    /// The cached lone free position, or `None` if the cache holds `Free`.
    fn single(&self) -> Option<usize> {
        match self.bare_fiber.is_single {
            FiberIndex::Fixed(pos) => Some(pos),
            _ => None,
        }
    }

    /// Mask of the free positions.
    fn bitvec(&self) -> BitVec {
        self.bare_fiber.bitvec()
    }
}
impl<'a, I> FiberMut<'a, I>
where
I: TensorStructure,
{
pub fn from<'b>(data: FiberData<'b>, structure: &'a mut I) -> Self {
FiberMut {
bare_fiber: BareFiber::from(data, &*structure),
structure,
}
}
pub fn conj(self) -> Self {
self
}
pub fn bitvec(&self) -> BitVec {
self.bare_fiber.bitvec()
}
pub fn bitvecinv(&self) -> BitVec {
self.bare_fiber.bitvecinv()
}
pub fn zeros(structure: &'a I) -> Fiber<'a, I> {
Fiber {
bare_fiber: BareFiber::zeros(structure),
structure,
}
}
pub fn fix(&mut self, pos: usize, val: usize) {
self.bare_fiber.fix(pos, val);
}
pub fn is_single(&mut self) -> FiberIndex {
self.bare_fiber.is_single()
}
pub fn free(&mut self, pos: usize) {
self.bare_fiber.free(pos);
}
}
/// Writes each index's `Display` output followed by one space (so the output
/// ends with a trailing space when the fiber is non-empty).
impl<I: TensorStructure> Display for FiberMut<'_, I> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.bare_fiber
            .indices
            .iter()
            .try_for_each(|index| write!(f, "{} ", index))
    }
}
impl<'a, I: TensorStructure> FiberMut<'a, I> {
    /// Consumes the fiber into a mutable flat-core iterator (flag `false`).
    pub fn iter(
        self,
    ) -> super::fiber_iterators::MutFiberIterator<'a, I, super::core_iterators::CoreFlatFiberIterator>
    {
        let conj = false;
        super::fiber_iterators::MutFiberIterator::new(self, conj)
    }
}
/// Complementary view of a fiber. It holds the same `BareFiber` pattern, but
/// the fiber's fixed positions are treated as this class's free ones: indexing
/// swaps `Free`/`Fixed`, and `bitvec` is negated.
pub struct FiberClass<'a, I: TensorStructure> {
    /// Structure supplying strides, shape, representations and order.
    structure: &'a I,
    /// The underlying fiber pattern (not inverted in storage).
    bare_fiber: BareFiber,
}
/// Hand-written `Clone`: the structure borrow is copied, so `I` need not be `Clone`.
impl<I: TensorStructure> Clone for FiberClass<'_, I> {
    fn clone(&self) -> Self {
        let bare_fiber = self.bare_fiber.clone();
        Self {
            structure: self.structure,
            bare_fiber,
        }
    }
}
/// Positional access in class terms, with roles swapped relative to the
/// underlying fiber: a position the fiber fixes reads as
/// [`FiberClassIndex::Free`], and any other position reads as
/// [`FiberClassIndex::Fixed`].
impl<I: TensorStructure> Index<usize> for FiberClass<'_, I> {
    type Output = FiberClassIndex;

    fn index(&self, position: usize) -> &FiberClassIndex {
        match self.bare_fiber[position].is_fixed() {
            true => &FiberClassIndex::Free,
            false => &FiberClassIndex::Fixed,
        }
    }
}
impl<'a, I: TensorStructure> From<Fiber<'a, I>> for FiberClass<'a, I> {
fn from(fiber: Fiber<'a, I>) -> Self {
FiberClass {
bare_fiber: fiber.bare_fiber,
structure: fiber.structure,
}
}
}
impl<'a, I: TensorStructure> From<FiberClass<'a, I>> for Fiber<'a, I> {
fn from(fiber: FiberClass<'a, I>) -> Self {
Fiber {
bare_fiber: fiber.bare_fiber,
structure: fiber.structure,
}
}
}
/// Exposes a [`FiberClass`] through [`AbstractFiber`]. Geometry comes from the
/// structure; the mask is the complement of the underlying fiber's free mask.
impl<I: TensorStructure> AbstractFiber<FiberClassIndex> for FiberClass<'_, I> {
    type Repr = <I::Slot as IsAbstractSlot>::R;

    fn order(&self) -> usize {
        self.structure.order()
    }

    fn shape(&self) -> Vec<Dimension> {
        self.structure.shape()
    }

    /// Panics if the structure cannot produce strides.
    fn strides(&self) -> Vec<usize> {
        self.structure.strides().unwrap()
    }

    fn reps(&self) -> Vec<Representation<Self::Repr>> {
        self.structure.reps()
    }

    /// The underlying fiber's cached lone free position.
    ///
    /// NOTE(review): this reports the fiber's single *free* position, while the
    /// class treats the fiber's fixed positions as its own free ones. Confirm
    /// this is what the class iterators expect.
    fn single(&self) -> Option<usize> {
        if let FiberIndex::Fixed(i) = self.bare_fiber.is_single {
            Some(i)
        } else {
            None
        }
    }

    /// Complement of the underlying fiber's free mask.
    fn bitvec(&self) -> BitVec {
        let free_mask = self.bare_fiber.bitvec();
        !free_mask
    }
}
impl<'a, S: TensorStructure> FiberClass<'a, S> {
    /// Consumes the class into its default class iterator.
    pub fn iter(self) -> super::fiber_iterators::FiberClassIterator<'a, S> {
        let class = self;
        super::fiber_iterators::FiberClassIterator::new(class)
    }

    /// Consumes the class into an expanded-core class iterator applying `permutation`.
    pub fn iter_perm(
        self,
        permutation: linnet::permutation::Permutation,
    ) -> super::fiber_iterators::FiberClassIterator<
        'a,
        S,
        super::core_iterators::CoreExpandedFiberIterator<<S::Slot as IsAbstractSlot>::R>,
    > {
        let class = self;
        super::fiber_iterators::FiberClassIterator::new_permuted(class, permutation)
    }

    /// Consumes the class into a metric-core class iterator applying `permutation`.
    pub fn iter_perm_metric(
        self,
        permutation: linnet::permutation::Permutation,
    ) -> super::fiber_iterators::FiberClassIterator<
        'a,
        S,
        super::core_iterators::MetricFiberIterator<<S::Slot as IsAbstractSlot>::R>,
    > {
        let class = self;
        super::fiber_iterators::FiberClassIterator::new_permuted(class, permutation)
    }
}
/// Complementary view of a `FiberMut`: same pattern and exclusive structure
/// borrow, with the fiber's fixed positions treated as the class's free ones.
pub struct FiberClassMut<'a, I: TensorStructure> {
    /// Exclusively borrowed structure supplying strides, shape, reps and order.
    structure: &'a mut I,
    /// The underlying fiber pattern (not inverted in storage).
    bare_fiber: BareFiber,
}
/// Positional access in class terms, with roles swapped relative to the
/// underlying fiber: a position the fiber fixes reads as
/// [`FiberClassIndex::Free`], and any other position reads as
/// [`FiberClassIndex::Fixed`].
impl<I: TensorStructure> Index<usize> for FiberClassMut<'_, I> {
    type Output = FiberClassIndex;

    fn index(&self, position: usize) -> &FiberClassIndex {
        match self.bare_fiber[position].is_fixed() {
            true => &FiberClassIndex::Free,
            false => &FiberClassIndex::Fixed,
        }
    }
}
impl<'a, I: TensorStructure> From<FiberMut<'a, I>> for FiberClassMut<'a, I> {
fn from(fiber: FiberMut<'a, I>) -> Self {
FiberClassMut {
bare_fiber: fiber.bare_fiber,
structure: fiber.structure,
}
}
}
impl<'a, I: TensorStructure> From<FiberClassMut<'a, I>> for FiberMut<'a, I> {
fn from(fiber: FiberClassMut<'a, I>) -> Self {
FiberMut {
bare_fiber: fiber.bare_fiber,
structure: fiber.structure,
}
}
}
/// Exposes a [`FiberClassMut`] through [`AbstractFiber`]. Geometry comes from
/// the structure; the mask is the complement of the underlying fiber's free mask.
impl<I: TensorStructure> AbstractFiber<FiberClassIndex> for FiberClassMut<'_, I> {
    type Repr = <I::Slot as IsAbstractSlot>::R;

    fn order(&self) -> usize {
        self.structure.order()
    }

    fn shape(&self) -> Vec<Dimension> {
        self.structure.shape()
    }

    /// Panics if the structure cannot produce strides.
    fn strides(&self) -> Vec<usize> {
        self.structure.strides().unwrap()
    }

    fn reps(&self) -> Vec<Representation<Self::Repr>> {
        self.structure.reps()
    }

    /// The underlying fiber's cached lone free position.
    ///
    /// NOTE(review): this reports the fiber's single *free* position, while the
    /// class treats the fiber's fixed positions as its own free ones. Confirm
    /// this is what the class iterators expect.
    fn single(&self) -> Option<usize> {
        if let FiberIndex::Fixed(i) = self.bare_fiber.is_single {
            Some(i)
        } else {
            None
        }
    }

    /// Complement of the underlying fiber's free mask.
    fn bitvec(&self) -> BitVec {
        let free_mask = self.bare_fiber.bitvec();
        !free_mask
    }
}