use std::any::Any;
use std::cell::RefCell;
use std::ops::{Add, Sub, Mul, Div, Index, BitAnd, BitOr, BitXor, Not};
use std::rc::Rc;
use std::sync::Arc;
use rayon::prelude::*;
/// Number of elements allocated by `LQueue::lvec` when the caller does not
/// request an explicit capacity.
const DEFAULT_CAPACITY: usize = 128;
/// Type-erased deferred operation that a queue can store and run later.
///
/// Outputs are exchanged as `Arc<dyn Any + Send + Sync>` so that operations
/// over different element types can be stored in the same list.
trait ErasedOperation: Send + Sync {
/// Runs the operation and returns its output.
///
/// `results` is the result table indexed by result id; a slot is `None`
/// until the operation that produces it has run.
fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync>;
/// Id of the result-table slot that receives this operation's output.
fn result_id(&self) -> usize;
/// Result ids of the pending values this operation reads.
fn dependencies(&self) -> Vec<usize>;
}
/// Arithmetic operator applied element-wise by a `BinaryOp`.
#[derive(Clone, Copy, Debug)]
enum BinaryOpType {
/// Element-wise `left + right`.
Add,
/// Element-wise `left - right`.
Sub,
/// Element-wise `left * right`.
Mul,
/// Element-wise `left / right`.
Div,
}
/// Where a deferred operation finds one of its inputs.
#[derive(Clone)]
enum OperandSource<T: Clone + Send + Sync> {
/// Elements that are already available.
Direct(Arc<Vec<T>>),
/// Result id of another operation; resolved against the result table by
/// `get_data` when the consuming operation executes.
Pending(usize),
}
/// Deferred element-wise arithmetic between two operands.
struct BinaryOp<T: Clone + Send + Sync + 'static> {
    /// Operator applied to each pair of elements.
    op_type: BinaryOpType,
    /// Left-hand operand.
    left: OperandSource<T>,
    /// Right-hand operand.
    right: OperandSource<T>,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T> ErasedOperation for BinaryOp<T>
where
    T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
    /// Applies the operator pairwise, in parallel, over the common prefix of
    /// both operands and returns the outputs as a `Vec<T>`.
    fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let lhs = get_data::<T>(&self.left, results);
        let rhs = get_data::<T>(&self.right, results);
        let kind = self.op_type;
        // Only the prefix both operands share takes part in the operation.
        let shared = lhs.len().min(rhs.len());
        let combined: Vec<T> = lhs[..shared]
            .par_iter()
            .zip(rhs[..shared].par_iter())
            .map(|(a, b)| {
                let (a, b) = (a.clone(), b.clone());
                match kind {
                    BinaryOpType::Add => a + b,
                    BinaryOpType::Sub => a - b,
                    BinaryOpType::Mul => a * b,
                    BinaryOpType::Div => a / b,
                }
            })
            .collect();
        Arc::new(combined)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// Lists the pending operands, left before right.
    fn dependencies(&self) -> Vec<usize> {
        [&self.left, &self.right]
            .iter()
            .filter_map(|operand| match operand {
                OperandSource::Pending(id) => Some(*id),
                OperandSource::Direct(_) => None,
            })
            .collect()
    }
}
/// Deferred element-wise transformation of a single operand.
struct MapOp<T: Clone + Send + Sync + 'static> {
    /// Operand whose elements are transformed.
    source: OperandSource<T>,
    /// Callback receiving the element index and a reference to the element.
    func: Arc<dyn Fn(usize, &T) -> T + Send + Sync>,
    /// Upper bound on the number of elements to process.
    len: usize,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T: Clone + Send + Sync + 'static> ErasedOperation for MapOp<T> {
    /// Applies `func` in parallel to the first `len` elements of the source
    /// (fewer when the source is shorter) and returns the outputs as a `Vec<T>`.
    fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let input = get_data::<T>(&self.source, results);
        let count = self.len.min(input.len());
        let mapped: Vec<T> = input[..count]
            .par_iter()
            .enumerate()
            .map(|(i, item)| (self.func)(i, item))
            .collect();
        Arc::new(mapped)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// The source's result id when it is pending; otherwise nothing.
    fn dependencies(&self) -> Vec<usize> {
        match &self.source {
            OperandSource::Pending(id) => vec![*id],
            OperandSource::Direct(_) => Vec::new(),
        }
    }
}
/// Deferred transformation that picks one of two callbacks per element.
struct MapWhereOp<T: Clone + Send + Sync + 'static> {
    /// Operand whose elements are transformed.
    source: OperandSource<T>,
    /// Predicate deciding which callback handles an element.
    condition: Arc<dyn Fn(usize, &T) -> bool + Send + Sync>,
    /// Callback used when `condition` returns `true`.
    if_true: Arc<dyn Fn(usize, &T) -> T + Send + Sync>,
    /// Callback used when `condition` returns `false`.
    if_false: Arc<dyn Fn(usize, &T) -> T + Send + Sync>,
    /// Upper bound on the number of elements to process.
    len: usize,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T: Clone + Send + Sync + 'static> ErasedOperation for MapWhereOp<T> {
    /// Evaluates `condition` for each of the first `len` source elements
    /// (fewer when the source is shorter) and maps the element through the
    /// matching callback, in parallel.
    fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let input = get_data::<T>(&self.source, results);
        let count = self.len.min(input.len());
        let mapped: Vec<T> = input[..count]
            .par_iter()
            .enumerate()
            .map(|(i, item)| {
                let branch = if (self.condition)(i, item) {
                    &self.if_true
                } else {
                    &self.if_false
                };
                branch(i, item)
            })
            .collect();
        Arc::new(mapped)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// The source's result id when it is pending; otherwise nothing.
    fn dependencies(&self) -> Vec<usize> {
        match &self.source {
            OperandSource::Pending(id) => vec![*id],
            OperandSource::Direct(_) => Vec::new(),
        }
    }
}
/// Deferred per-element selection between two operands, driven by a mask.
struct BlendOp<T: Clone + Send + Sync + 'static> {
    /// Operand supplying the element where the mask is `false`.
    if_false: OperandSource<T>,
    /// Operand supplying the element where the mask is `true`.
    if_true: OperandSource<T>,
    /// Selection flags, one per element.
    mask: Arc<Vec<bool>>,
    /// Upper bound on the number of elements to produce.
    len: usize,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T: Clone + Send + Sync + 'static> ErasedOperation for BlendOp<T> {
    /// Builds the output by cloning, for every index, the element of `if_true`
    /// where the mask is set and the element of `if_false` otherwise.
    ///
    /// The number of produced elements is `len` clamped to the shortest of the
    /// mask and both operands, so a buffer that turns out shorter than the
    /// recorded length truncates the output instead of panicking on an
    /// out-of-bounds index (the same clamping `MapOp` applies to its source).
    fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let false_data = get_data::<T>(&self.if_false, results);
        let true_data = get_data::<T>(&self.if_true, results);
        let count = self
            .len
            .min(self.mask.len())
            .min(true_data.len())
            .min(false_data.len());
        // Slice once so every index below is known to be in range.
        let mask = &self.mask[..count];
        let when_true = &true_data[..count];
        let when_false = &false_data[..count];
        let result: Vec<T> = (0..count)
            .into_par_iter()
            .map(|i| {
                if mask[i] {
                    when_true[i].clone()
                } else {
                    when_false[i].clone()
                }
            })
            .collect();
        Arc::new(result)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// Lists the pending operands, `if_false` before `if_true`.
    fn dependencies(&self) -> Vec<usize> {
        let mut deps = Vec::new();
        if let OperandSource::Pending(id) = &self.if_false {
            deps.push(*id);
        }
        if let OperandSource::Pending(id) = &self.if_true {
            deps.push(*id);
        }
        deps
    }
}
/// Deferred transformation applied only to the elements selected by a mask.
struct MaskedApplyOp<T: Clone + Send + Sync + 'static> {
    /// Operand whose elements are transformed or copied.
    source: OperandSource<T>,
    /// Selection flags, one per element.
    mask: Arc<Vec<bool>>,
    /// Callback applied to the selected elements.
    func: Arc<dyn Fn(usize, &T) -> T + Send + Sync>,
    /// Upper bound on the number of elements to produce.
    len: usize,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T: Clone + Send + Sync + 'static> ErasedOperation for MaskedApplyOp<T> {
    /// Maps every element whose mask flag is set through `func` and clones the
    /// remaining elements unchanged.
    ///
    /// The number of produced elements is `len` clamped to the shorter of the
    /// mask and the source, so a buffer that turns out shorter than the
    /// recorded length truncates the output instead of panicking on an
    /// out-of-bounds index (the same clamping `MapOp` applies to its source).
    fn execute(&self, results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let data = get_data::<T>(&self.source, results);
        let count = self.len.min(self.mask.len()).min(data.len());
        // Slice once so every index below is known to be in range.
        let mask = &self.mask[..count];
        let items = &data[..count];
        let func = &self.func;
        let result: Vec<T> = (0..count)
            .into_par_iter()
            .map(|i| {
                if mask[i] {
                    func(i, &items[i])
                } else {
                    items[i].clone()
                }
            })
            .collect();
        Arc::new(result)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// The source's result id when it is pending; otherwise nothing.
    fn dependencies(&self) -> Vec<usize> {
        if let OperandSource::Pending(id) = &self.source {
            vec![*id]
        } else {
            vec![]
        }
    }
}
/// Deferred generation of a vector from an index-based callback.
struct FillOp<T: Clone + Send + Sync + 'static> {
    /// Callback producing the element for a given index.
    func: Arc<dyn Fn(usize) -> T + Send + Sync>,
    /// Number of elements to generate.
    len: usize,
    /// Result-table slot that receives the output.
    result_id: usize,
}

impl<T: Clone + Send + Sync + 'static> ErasedOperation for FillOp<T> {
    /// Calls `func` for every index in `0..len`, in parallel, and returns the
    /// outputs as a `Vec<T>`. The result table is not consulted.
    fn execute(&self, _results: &[Option<Arc<dyn Any + Send + Sync>>]) -> Arc<dyn Any + Send + Sync> {
        let generated: Vec<T> = (0..self.len)
            .into_par_iter()
            .map(|index| (self.func)(index))
            .collect();
        Arc::new(generated)
    }

    fn result_id(&self) -> usize {
        self.result_id
    }

    /// A fill reads no other result.
    fn dependencies(&self) -> Vec<usize> {
        Vec::new()
    }
}
fn get_data<T: Clone + Send + Sync + 'static>(
source: &OperandSource<T>,
results: &[Option<Arc<dyn Any + Send + Sync>>],
) -> Arc<Vec<T>> {
match source {
OperandSource::Direct(data) => Arc::clone(data),
OperandSource::Pending(id) => {
let any_ref = results[*id].as_ref().unwrap();
any_ref.clone().downcast::<Vec<T>>().unwrap()
}
}
}
/// Boolean selection mask with one flag per element.
///
/// Cloning is cheap: the flags live behind an `Arc` and are shared.
#[derive(Clone)]
pub struct LMask {
    /// The flags themselves.
    data: Arc<Vec<bool>>,
    /// Number of flags exposed by the mask.
    len: usize,
}

impl LMask {
    /// Creates a mask of `len` flags, all set to `value`.
    pub fn new(len: usize, value: bool) -> Self {
        let flags = vec![value; len];
        LMask {
            data: Arc::new(flags),
            len,
        }
    }

    /// Creates a mask of `len` flags where flag `i` is `f(i)`.
    pub fn from_fn<F>(len: usize, f: F) -> Self
    where
        F: Fn(usize) -> bool,
    {
        let mut flags = Vec::with_capacity(len);
        flags.extend((0..len).map(f));
        LMask {
            data: Arc::new(flags),
            len,
        }
    }

    /// Number of flags in the mask.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the mask holds no flags.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Borrows the flags as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[bool] {
        &self.data[..self.len]
    }

    /// Combines two masks flag by flag over their common prefix.
    fn combine(&self, other: &LMask, op: impl Fn(bool, bool) -> bool) -> LMask {
        let len = self.len.min(other.len);
        let flags: Vec<bool> = self.data[..len]
            .iter()
            .zip(other.data[..len].iter())
            .map(|(&a, &b)| op(a, b))
            .collect();
        LMask {
            data: Arc::new(flags),
            len,
        }
    }
}

impl Index<usize> for LMask {
    type Output = bool;

    /// Returns the flag at `index`; panics when `index` is out of bounds.
    #[inline]
    fn index(&self, index: usize) -> &Self::Output {
        &self.data[index]
    }
}

/// Flag-wise AND; the result is as long as the shorter mask.
impl BitAnd for &LMask {
    type Output = LMask;

    fn bitand(self, other: Self) -> Self::Output {
        self.combine(other, |a, b| a && b)
    }
}

impl BitAnd for LMask {
    type Output = LMask;

    fn bitand(self, other: Self) -> Self::Output {
        &self & &other
    }
}

/// Flag-wise OR; the result is as long as the shorter mask.
impl BitOr for &LMask {
    type Output = LMask;

    fn bitor(self, other: Self) -> Self::Output {
        self.combine(other, |a, b| a || b)
    }
}

impl BitOr for LMask {
    type Output = LMask;

    fn bitor(self, other: Self) -> Self::Output {
        &self | &other
    }
}

/// Flag-wise XOR; the result is as long as the shorter mask.
impl BitXor for &LMask {
    type Output = LMask;

    fn bitxor(self, other: Self) -> Self::Output {
        self.combine(other, |a, b| a ^ b)
    }
}

impl BitXor for LMask {
    type Output = LMask;

    fn bitxor(self, other: Self) -> Self::Output {
        &self ^ &other
    }
}

/// Flag-wise negation.
impl Not for &LMask {
    type Output = LMask;

    fn not(self) -> Self::Output {
        let mut flipped = Vec::with_capacity(self.data.len());
        for &flag in self.data.iter() {
            flipped.push(!flag);
        }
        LMask {
            data: Arc::new(flipped),
            len: self.len,
        }
    }
}

impl Not for LMask {
    type Output = LMask;

    fn not(self) -> Self::Output {
        !&self
    }
}

impl std::fmt::Debug for LMask {
    /// Formats the mask as `LMask([..flags..])`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("LMask({:?})", self.as_slice()))
    }
}
/// Mutable state shared by an `LQueue` handle and the `LVec`s created from it.
struct LQueueInner {
/// Operations recorded since the last `start()`, in recording order.
operations: Vec<Box<dyn ErasedOperation>>,
/// Result table indexed by result id; a slot stays `None` until the
/// operation that produces it has run.
results: Vec<Option<Arc<dyn Any + Send + Sync>>>,
/// `true` between `start()` and `end()`.
recording: bool,
/// Result id handed to the next recorded operation.
next_result_id: usize,
}
/// Handle to a queue of deferred vector operations.
///
/// Operations are recorded between [`LQueue::start`] and [`LQueue::end`];
/// `end` then runs everything that was recorded. Clones share the same queue.
#[derive(Clone)]
pub struct LQueue {
    inner: Rc<RefCell<LQueueInner>>,
}

impl LQueue {
    /// Creates an empty queue that is not recording.
    pub fn new() -> Self {
        let state = LQueueInner {
            operations: Vec::new(),
            results: Vec::new(),
            recording: false,
            next_result_id: 0,
        };
        LQueue {
            inner: Rc::new(RefCell::new(state)),
        }
    }

    /// Creates a vector of `DEFAULT_CAPACITY` default-initialised elements
    /// bound to this queue.
    pub fn lvec<T>(&self) -> LVec<T>
    where
        T: Clone + Send + Sync + Default + 'static,
    {
        self.lvec_with_capacity(DEFAULT_CAPACITY)
    }

    /// Creates a vector of `capacity` default-initialised elements bound to
    /// this queue. Its length equals its capacity.
    pub fn lvec_with_capacity<T>(&self, capacity: usize) -> LVec<T>
    where
        T: Clone + Send + Sync + Default + 'static,
    {
        let defaults = vec![T::default(); capacity];
        LVec {
            data: Arc::new(defaults),
            len: capacity,
            capacity,
            queue: Rc::clone(&self.inner),
            pending_result_id: None,
        }
    }

    /// Begins a recording session, discarding the operations and results of
    /// any previous session and restarting result ids at zero.
    pub fn start(&self) {
        let mut state = self.inner.borrow_mut();
        state.recording = true;
        state.operations.clear();
        state.results.clear();
        state.next_result_id = 0;
    }

    /// Stops recording and executes every recorded operation.
    pub fn end(&self) {
        let mut state = self.inner.borrow_mut();
        state.recording = false;
        Self::execute_all(&mut state);
    }

    /// Runs the recorded operations grouped by dependency depth: an operation
    /// is placed one level below the deepest recorded operation it reads
    /// from, and levels are executed in ascending order.
    fn execute_all(inner: &mut LQueueInner) {
        if inner.operations.is_empty() {
            return;
        }
        inner.results.resize_with(inner.next_result_id, || None);

        // Map every result id to the index of the operation producing it.
        let producer_of: std::collections::HashMap<usize, usize> = inner
            .operations
            .iter()
            .enumerate()
            .map(|(idx, op)| (op.result_id(), idx))
            .collect();

        let mut depth_of: Vec<usize> = vec![0; inner.operations.len()];
        let mut schedule: Vec<Vec<usize>> = Vec::new();
        for (idx, op) in inner.operations.iter().enumerate() {
            // Dependencies without a recorded producer do not affect the depth.
            let depth = op
                .dependencies()
                .into_iter()
                .filter_map(|dep_id| producer_of.get(&dep_id))
                .map(|&producer| depth_of[producer] + 1)
                .max()
                .unwrap_or(0);
            depth_of[idx] = depth;
            if schedule.len() <= depth {
                schedule.resize_with(depth + 1, Vec::new);
            }
            schedule[depth].push(idx);
        }

        for idx in schedule.into_iter().flatten() {
            let output = inner.operations[idx].execute(&inner.results);
            let slot = inner.operations[idx].result_id();
            inner.results[slot] = Some(output);
        }
    }

    /// Returns `true` while the queue is between `start()` and `end()`.
    pub fn is_recording(&self) -> bool {
        let state = self.inner.borrow();
        state.recording
    }
}

impl Default for LQueue {
    /// Equivalent to [`LQueue::new`].
    fn default() -> Self {
        LQueue::new()
    }
}
#[derive(Clone)]
pub struct LVec<T: Clone + Send + Sync + 'static> {
data: Arc<Vec<T>>,
len: usize,
capacity: usize,
queue: Rc<RefCell<LQueueInner>>,
pending_result_id: Option<usize>,
}
impl<T: Clone + Send + Sync + 'static> LVec<T> {
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[inline]
pub fn capacity(&self) -> usize {
self.capacity
}
#[inline]
pub fn is_pending(&self) -> bool {
self.pending_result_id.is_some()
}
pub fn materialize(&self) -> Option<Arc<Vec<T>>> {
if let Some(result_id) = self.pending_result_id {
let inner = self.queue.borrow();
inner.results.get(result_id).and_then(|r| {
r.as_ref().and_then(|arc| arc.clone().downcast::<Vec<T>>().ok())
})
} else {
Some(Arc::clone(&self.data))
}
}
#[inline]
pub fn as_slice(&self) -> &[T] {
&self.data[..self.len]
}
fn get_source(&self) -> OperandSource<T> {
if let Some(id) = self.pending_result_id {
OperandSource::Pending(id)
} else {
OperandSource::Direct(Arc::clone(&self.data))
}
}
fn create_pending(&self, result_id: usize, len: usize) -> Self {
LVec {
data: Arc::new(Vec::new()),
len,
capacity: len,
queue: Rc::clone(&self.queue),
pending_result_id: Some(result_id),
}
}
pub fn fill_with<F>(&self, f: F) -> Self
where
F: Fn(usize) -> T + Send + Sync + 'static,
{
let mut inner = self.queue.borrow_mut();
if !inner.recording {
panic!("fill_with must be called between q.start() and q.end()");
}
let result_id = inner.next_result_id;
inner.next_result_id += 1;
inner.operations.push(Box::new(FillOp {
func: Arc::new(f),
len: self.len,
result_id,
}));
drop(inner);
self.create_pending(result_id, self.len)
}
pub fn fill(&self, value: T) -> Self {
self.fill_with(move |_| value.clone())
}
pub fn map<F>(&self, f: F) -> Self
where
F: Fn(usize, &T) -> T + Send + Sync + 'static,
{
let mut inner = self.queue.borrow_mut();
if !inner.recording {
panic!("map must be called between q.start() and q.end()");
}
let result_id = inner.next_result_id;
inner.next_result_id += 1;
inner.operations.push(Box::new(MapOp {
source: self.get_source(),
func: Arc::new(f),
len: self.len,
result_id,
}));
drop(inner);
self.create_pending(result_id, self.len)
}
pub fn map_where<C, TF, FF>(&self, condition: C, if_true: TF, if_false: FF) -> Self
where
C: Fn(usize, &T) -> bool + Send + Sync + 'static,
TF: Fn(usize, &T) -> T + Send + Sync + 'static,
FF: Fn(usize, &T) -> T + Send + Sync + 'static,
{
let mut inner = self.queue.borrow_mut();
if !inner.recording {
panic!("map_where must be called between q.start() and q.end()");
}
let result_id = inner.next_result_id;
inner.next_result_id += 1;
inner.operations.push(Box::new(MapWhereOp {
source: self.get_source(),
condition: Arc::new(condition),
if_true: Arc::new(if_true),
if_false: Arc::new(if_false),
len: self.len,
result_id,
}));
drop(inner);
self.create_pending(result_id, self.len)
}
pub fn mask<F>(&self, predicate: F) -> LMask
where
F: Fn(usize, &T) -> bool,
{
let data: Vec<bool> = (0..self.len)
.map(|i| predicate(i, &self.data[i]))
.collect();
LMask {
data: Arc::new(data),
len: self.len,
}
}
pub fn blend(&self, other: &Self, mask: &LMask) -> Self {
let mut inner = self.queue.borrow_mut();
if !inner.recording {
panic!("blend must be called between q.start() and q.end()");
}
let result_id = inner.next_result_id;
inner.next_result_id += 1;
let len = self.len.min(other.len).min(mask.len());
inner.operations.push(Box::new(BlendOp {
if_false: self.get_source(),
if_true: other.get_source(),
mask: Arc::clone(&mask.data),
len,
result_id,
}));
drop(inner);
self.create_pending(result_id, len)
}
pub fn select(mask: &LMask, if_true: &Self, if_false: &Self) -> Self {
if_false.blend(if_true, mask)
}
pub fn masked_apply<F>(&self, mask: &LMask, f: F) -> Self
where
F: Fn(usize, &T) -> T + Send + Sync + 'static,
{
let mut inner = self.queue.borrow_mut();
if !inner.recording {
panic!("masked_apply must be called between q.start() and q.end()");
}
let result_id = inner.next_result_id;
inner.next_result_id += 1;
let len = self.len.min(mask.len());
inner.operations.push(Box::new(MaskedApplyOp {
source: self.get_source(),
mask: Arc::clone(&mask.data),
func: Arc::new(f),
len,
result_id,
}));
drop(inner);
self.create_pending(result_id, len)
}
pub fn masked_fill(&self, mask: &LMask, value: T) -> Self
where
T: 'static,
{
self.masked_apply(mask, move |_, _| value.clone())
}
}
/// `LVec` cannot be default-constructed because it must be bound to a queue.
///
/// # Panics
///
/// `default()` always panics; create vectors with `LQueue::lvec()` or
/// `LQueue::lvec_with_capacity()` instead.
impl<T: Clone + Send + Sync + Default + 'static> Default for LVec<T> {
fn default() -> Self {
panic!("LVec must be created via LQueue::lvec() or LQueue::lvec_with_capacity()")
}
}
/// Indexes the element buffer stored in the handle itself.
///
/// # Panics
///
/// Panics when `index` is out of bounds for the stored buffer. Handles made
/// by `create_pending` store an empty buffer, so indexing a pending handle
/// always panics; `materialize` gives access to a computed result.
impl<T: Clone + Send + Sync + 'static> Index<usize> for LVec<T> {
type Output = T;
#[inline]
fn index(&self, index: usize) -> &Self::Output {
&self.data[index]
}
}
/// Records an element-wise binary operation on `left`'s queue and returns the
/// pending handle for its result.
///
/// The handle's logical length is the shorter of the two operand lengths.
///
/// # Panics
///
/// Panics when the queue is not recording.
fn record_binary_op<T>(left: &LVec<T>, right: &LVec<T>, op_type: BinaryOpType) -> LVec<T>
where
    T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
    // Keep the queue borrowed only while the operation is being registered.
    let result_id = {
        let mut state = left.queue.borrow_mut();
        if !state.recording {
            panic!("Binary operations must be called between q.start() and q.end()");
        }
        let id = state.next_result_id;
        state.next_result_id += 1;
        let operation = BinaryOp {
            op_type,
            left: left.get_source(),
            right: right.get_source(),
            result_id: id,
        };
        state.operations.push(Box::new(operation));
        id
    };
    let shared_len = left.len.min(right.len);
    left.create_pending(result_id, shared_len)
}
/// `a + b` on owned vectors: records an element-wise addition via
/// `record_binary_op` and returns the pending result. Panics when the queue
/// is not recording.
impl<T> Add for LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = Self;
fn add(self, other: Self) -> Self::Output {
record_binary_op(&self, &other, BinaryOpType::Add)
}
}
/// `&a + &b`: same as the owned form, without consuming the operands.
impl<T> Add for &LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = LVec<T>;
fn add(self, other: Self) -> Self::Output {
record_binary_op(self, other, BinaryOpType::Add)
}
}
/// `a - b` on owned vectors: records an element-wise subtraction via
/// `record_binary_op` and returns the pending result. Panics when the queue
/// is not recording.
impl<T> Sub for LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = Self;
fn sub(self, other: Self) -> Self::Output {
record_binary_op(&self, &other, BinaryOpType::Sub)
}
}
/// `&a - &b`: same as the owned form, without consuming the operands.
impl<T> Sub for &LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = LVec<T>;
fn sub(self, other: Self) -> Self::Output {
record_binary_op(self, other, BinaryOpType::Sub)
}
}
/// `a * b` on owned vectors: records an element-wise multiplication via
/// `record_binary_op` and returns the pending result. Panics when the queue
/// is not recording.
impl<T> Mul for LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = Self;
fn mul(self, other: Self) -> Self::Output {
record_binary_op(&self, &other, BinaryOpType::Mul)
}
}
/// `&a * &b`: same as the owned form, without consuming the operands.
impl<T> Mul for &LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = LVec<T>;
fn mul(self, other: Self) -> Self::Output {
record_binary_op(self, other, BinaryOpType::Mul)
}
}
/// `a / b` on owned vectors: records an element-wise division via
/// `record_binary_op` and returns the pending result. Panics when the queue
/// is not recording.
impl<T> Div for LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = Self;
fn div(self, other: Self) -> Self::Output {
record_binary_op(&self, &other, BinaryOpType::Div)
}
}
/// `&a / &b`: same as the owned form, without consuming the operands.
impl<T> Div for &LVec<T>
where
T: Clone + Send + Sync + Add<Output = T> + Sub<Output = T> + Mul<Output = T> + Div<Output = T> + 'static,
{
type Output = LVec<T>;
fn div(self, other: Self) -> Self::Output {
record_binary_op(self, other, BinaryOpType::Div)
}
}
impl<T: Clone + Send + Sync + std::fmt::Debug + 'static> std::fmt::Debug for LVec<T> {
    /// Prints `LVec<pending>` for a pending handle and `LVec([..elements..])`
    /// for a concrete one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pending handles are checked first: they have no elements to print.
        if self.is_pending() {
            return f.write_str("LVec<pending>");
        }
        let elements = self.as_slice();
        write!(f, "LVec({:?})", elements)
    }
}