#![allow(clippy::module_inception)]
use crate::error::*;
use crate::util;
use self::BinOpKind::*;
use self::ExprKind::*;
use self::ScalarKind::*;
use std::collections::BTreeMap;
use std::fmt;
use std::rc::Rc;
use std::vec;
/// Symbol name used for placeholder identifiers.
///
/// `Symbol::placeholder()` builds a symbol with this name (and id 0), and
/// `Placeholder::is_placeholder` recognises an `Ident` expression as a placeholder
/// by comparing its symbol name against this string.
const PLACEHOLDER_NAME: &str = "#placeholder";
/// A sorted set of `key -> value` string annotations attached to expressions and builders.
///
/// The map is allocated lazily: a fresh or cleared `Annotations` stores `None`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Annotations {
    values: Option<BTreeMap<String, String>>,
}

impl Annotations {
    /// Creates an empty annotation set.
    pub fn new() -> Annotations {
        Self::default()
    }

    /// Sets `key` to `value`, allocating the backing map on first use.
    ///
    /// Returns the value previously stored under `key`, if any.
    pub fn set<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) -> Option<String> {
        let map = self.values.get_or_insert_with(BTreeMap::new);
        map.insert(key.into(), value.into())
    }

    /// Looks up the value stored under `key`.
    pub fn get<K: AsRef<str>>(&self, key: K) -> Option<&str> {
        let map = self.values.as_ref()?;
        map.get(key.as_ref()).map(String::as_str)
    }

    /// Returns `true` when no annotation is set.
    pub fn is_empty(&self) -> bool {
        self.values.as_ref().map_or(true, |map| map.is_empty())
    }

    /// Removes every annotation and releases the backing map.
    pub fn clear(&mut self) {
        self.values = None;
    }
}

impl fmt::Display for Annotations {
    /// Renders the annotations as `@(k1:v1,k2:v2)` in key order, or nothing at all
    /// when the map was never allocated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.values {
            None => Ok(()),
            Some(ref map) => {
                let pairs: Vec<String> = map.iter().map(|(k, v)| format!("{}:{}", k, v)).collect();
                write!(f, "@({})", pairs.join(","))
            }
        }
    }
}
/// A type in the AST.
///
/// `Display` renders these as `i32`, `simd[i32]`, `vec[T]`, `dict[K,V]`, `{A,B}`,
/// `|A,B|(R)`, builder syntax, `ALIAS(name)` and `?` respectively.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Type {
    /// A single scalar value.
    Scalar(ScalarKind),
    /// A SIMD value of the given scalar kind (`Type::simd_type` maps `Scalar(k)` to `Simd(k)`).
    Simd(ScalarKind),
    /// A vector of elements of the boxed type; strings are `vec[i8]` (see `string_type`).
    Vector(Box<Type>),
    /// A dictionary from the first (key) type to the second (value) type.
    Dict(Box<Type>, Box<Type>),
    /// A builder of the given kind, with optional annotations printed before it.
    Builder(BuilderKind, Annotations),
    /// A struct whose fields have the listed types, in order.
    Struct(Vec<Type>),
    /// A function from the parameter types to the boxed return type.
    Function(Vec<Type>, Box<Type>),
    /// A named alias for the boxed type; `Display` prints only the name.
    Alias(String, Box<Type>),
    /// A type that has not been determined yet; makes a type "partial" (`partial_type`).
    Unknown,
}
impl Type {
    /// Returns an iterator over the types nested directly inside this one.
    ///
    /// Scalars, SIMD types and `Unknown` have none. Builders expose the type(s)
    /// carried by their `BuilderKind`; functions expose their parameter types
    /// followed by the return type.
    pub fn children(&self) -> vec::IntoIter<&Type> {
        use self::BuilderKind::*;
        use self::Type::*;
        let nested: Vec<&Type> = match *self {
            Unknown | Scalar(_) | Simd(_) => Vec::new(),
            Alias(_, ref inner) | Vector(ref inner) => vec![inner.as_ref()],
            Dict(ref key, ref value) => vec![key.as_ref(), value.as_ref()],
            Builder(ref kind, _) => match *kind {
                Appender(ref elem) | Merger(ref elem, _) | VecMerger(ref elem, _) => {
                    vec![elem.as_ref()]
                }
                DictMerger(ref key, ref value, _) | GroupMerger(ref key, ref value) => {
                    vec![key.as_ref(), value.as_ref()]
                }
            },
            Struct(ref fields) => fields.iter().collect(),
            Function(ref params, ref ret) => {
                let mut all: Vec<&Type> = params.iter().collect();
                all.push(ret.as_ref());
                all
            }
        };
        nested.into_iter()
    }

    /// Mutable counterpart of `children`, yielding the same types in the same order.
    pub fn children_mut(&mut self) -> vec::IntoIter<&mut Type> {
        use self::BuilderKind::*;
        use self::Type::*;
        let nested: Vec<&mut Type> = match *self {
            Unknown | Scalar(_) | Simd(_) => Vec::new(),
            Alias(_, ref mut inner) | Vector(ref mut inner) => vec![inner.as_mut()],
            Dict(ref mut key, ref mut value) => vec![key.as_mut(), value.as_mut()],
            Builder(ref mut kind, _) => match *kind {
                Appender(ref mut elem) | Merger(ref mut elem, _) | VecMerger(ref mut elem, _) => {
                    vec![elem.as_mut()]
                }
                DictMerger(ref mut key, ref mut value, _)
                | GroupMerger(ref mut key, ref mut value) => vec![key.as_mut(), value.as_mut()],
            },
            Struct(ref mut fields) => fields.iter_mut().collect(),
            Function(ref mut params, ref mut ret) => {
                let mut all: Vec<&mut Type> = params.iter_mut().collect();
                all.push(ret.as_mut());
                all
            }
        };
        nested.into_iter()
    }

    /// The string type, represented as `vec[i8]`.
    pub fn string_type() -> Type {
        Type::Vector(Box::new(Type::Scalar(ScalarKind::I8)))
    }

    /// Returns `true` for SIMD types, builders, and structs made only of such types.
    pub fn is_simd(&self) -> bool {
        match *self {
            Type::Simd(_) | Type::Builder(..) => true,
            Type::Struct(ref fields) => fields.iter().all(Type::is_simd),
            _ => false,
        }
    }

    /// Returns `true` only for `Type::Scalar`.
    pub fn is_scalar(&self) -> bool {
        matches!(*self, Type::Scalar(_))
    }

    /// Returns `true` if this type is, or transitively contains, a builder.
    pub fn contains_builder(&self) -> bool {
        matches!(*self, Type::Builder(..)) || self.children().any(Type::contains_builder)
    }

    /// Returns `true` for builders and for structs made only of builders.
    pub fn is_builder(&self) -> bool {
        match *self {
            Type::Builder(..) => true,
            Type::Struct(ref tys) => tys.iter().all(Type::is_builder),
            _ => false,
        }
    }

    /// Returns `true` if values of this type may be used as hash keys.
    ///
    /// Scalars and SIMD values are hashable; structs and vectors are hashable
    /// when their components are. Everything else is not.
    pub fn is_hashable(&self) -> bool {
        use self::Type::*;
        match *self {
            Scalar(_) | Simd(_) => true,
            Struct(ref tys) => tys.iter().all(Type::is_hashable),
            Vector(ref elem) => elem.is_hashable(),
            Builder(..) | Dict(..) | Function(..) | Alias(..) | Unknown => false,
        }
    }

    /// Converts a scalar (or struct of scalars) to its SIMD counterpart.
    ///
    /// Builders are returned unchanged; any other type is a compile error.
    pub fn simd_type(&self) -> WeldResult<Type> {
        use self::Type::*;
        match *self {
            Scalar(kind) => Ok(Simd(kind)),
            Builder(..) => Ok(self.clone()),
            Struct(ref fields) => fields
                .iter()
                .map(Type::simd_type)
                .collect::<WeldResult<Vec<_>>>()
                .map(Struct),
            _ => compile_err!("simd_type called on non-scalar type {}", self),
        }
    }

    /// Converts a SIMD type (or struct of SIMD types) back to its scalar counterpart.
    ///
    /// Builders are returned unchanged; any other type is a compile error.
    pub fn scalar_type(&self) -> WeldResult<Type> {
        use self::Type::*;
        match *self {
            Simd(kind) => Ok(Scalar(kind)),
            Builder(..) => Ok(self.clone()),
            Struct(ref fields) => fields
                .iter()
                .map(Type::scalar_type)
                .collect::<WeldResult<Vec<_>>>()
                .map(Struct),
            _ => compile_err!("scalar_type called on non-SIMD type {}", self),
        }
    }

    /// Returns the type of values merged into this builder; errors for non-builders.
    pub fn merge_type(&self) -> WeldResult<Type> {
        match *self {
            Type::Builder(ref kind, _) => Ok(kind.merge_type()),
            _ => compile_err!("merge_type called on non-builder type {}", self),
        }
    }

    /// Returns `true` if this type is `Unknown` or contains `Unknown` anywhere inside it.
    pub fn partial_type(&self) -> bool {
        matches!(*self, Type::Unknown) || self.children().any(Type::partial_type)
    }
}
impl fmt::Display for Type {
    /// Renders the type in textual type syntax, e.g. `vec[i32]`, `{i32,f64}` or
    /// `|i32,i64|(bool)` for functions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use self::Type::*;
        match *self {
            Scalar(ref kind) => write!(f, "{}", kind),
            Simd(ref kind) => write!(f, "simd[{}]", kind),
            Vector(ref elem) => write!(f, "vec[{}]", elem),
            Dict(ref key, ref value) => write!(f, "dict[{},{}]", key, value),
            Struct(ref elems) => {
                let fields = elems.iter().map(|elem| elem.to_string());
                f.write_str(&util::join("{", ",", "}", fields))
            }
            Function(ref params, ref return_type) => {
                // Parameters go between pipes; the return type is wrapped in parentheses.
                let args = util::join("|", ",", "|(", params.iter().map(|p| p.to_string()));
                write!(f, "{}{})", args, return_type)
            }
            Builder(ref kind, ref annotations) => write!(f, "{}{}", annotations, kind),
            Alias(ref name, _) => write!(f, "ALIAS({})", name),
            Unknown => f.write_str("?"),
        }
    }
}
/// Primitive scalar element kinds.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ScalarKind {
    Bool,
    I8,
    I16,
    I32,
    I64,
    U8,
    U16,
    U32,
    U64,
    F32,
    F64,
}

impl ScalarKind {
    /// Returns `true` for `F32` and `F64`.
    pub fn is_float(self) -> bool {
        matches!(self, ScalarKind::F32 | ScalarKind::F64)
    }

    /// Returns `true` only for `Bool`.
    pub fn is_bool(self) -> bool {
        self == ScalarKind::Bool
    }

    /// Returns `true` for `I8`, `I16`, `I32` and `I64`.
    pub fn is_signed_integer(self) -> bool {
        matches!(
            self,
            ScalarKind::I8 | ScalarKind::I16 | ScalarKind::I32 | ScalarKind::I64
        )
    }

    /// Returns `true` for `U8`, `U16`, `U32` and `U64`.
    pub fn is_unsigned_integer(self) -> bool {
        matches!(
            self,
            ScalarKind::U8 | ScalarKind::U16 | ScalarKind::U32 | ScalarKind::U64
        )
    }

    /// Signed kinds are the signed integers plus both floating-point kinds.
    pub fn is_signed(self) -> bool {
        self.is_float() || self.is_signed_integer()
    }

    /// Returns `true` for any signed or unsigned integer kind (not `Bool`).
    pub fn is_integer(self) -> bool {
        self.is_unsigned_integer() || self.is_signed_integer()
    }

    /// Returns `true` for integer and floating-point kinds (not `Bool`).
    pub fn is_numeric(self) -> bool {
        self.is_float() || self.is_integer()
    }

    /// Width of the kind in bits (`Bool` counts as a single bit).
    pub fn bits(self) -> u32 {
        match self {
            ScalarKind::Bool => 1,
            ScalarKind::I8 | ScalarKind::U8 => 8,
            ScalarKind::I16 | ScalarKind::U16 => 16,
            ScalarKind::I32 | ScalarKind::U32 | ScalarKind::F32 => 32,
            ScalarKind::I64 | ScalarKind::U64 | ScalarKind::F64 => 64,
        }
    }

    /// Returns `true` if `target` is at least as wide as `self` (width only, not signedness).
    pub fn is_upcast(self, target: ScalarKind) -> bool {
        self.bits() <= target.bits()
    }

    /// Returns `true` if `target` is strictly wider than `self`.
    pub fn is_strict_upcast(self, target: ScalarKind) -> bool {
        self.bits() < target.bits()
    }
}

impl fmt::Display for ScalarKind {
    /// Renders the lowercase type keyword, e.g. `bool`, `i32`, `f64`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match *self {
            ScalarKind::Bool => "bool",
            ScalarKind::I8 => "i8",
            ScalarKind::I16 => "i16",
            ScalarKind::I32 => "i32",
            ScalarKind::I64 => "i64",
            ScalarKind::U8 => "u8",
            ScalarKind::U16 => "u16",
            ScalarKind::U32 => "u32",
            ScalarKind::U64 => "u64",
            ScalarKind::F32 => "f32",
            ScalarKind::F64 => "f64",
        };
        f.write_str(keyword)
    }
}
/// The kinds of builder types.
///
/// Per `BuilderKind::merge_type` / `BuilderKind::result_type`:
/// - `Appender(T)` merges `T` and yields `vec[T]`;
/// - `Merger(T, op)` merges `T` and yields `T`;
/// - `DictMerger(K, V, op)` merges `{K,V}` and yields `dict[K,V]`;
/// - `GroupMerger(K, V)` merges `{K,V}` and yields `dict[K,vec[V]]`;
/// - `VecMerger(T, op)` merges `{i64,T}` and yields `vec[T]`.
///
/// The `BinOpKind` carried by the mergers is presumably the operator used to
/// combine merged values — TODO confirm against codegen.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum BuilderKind {
    Appender(Box<Type>),
    Merger(Box<Type>, BinOpKind),
    DictMerger(Box<Type>, Box<Type>, BinOpKind),
    GroupMerger(Box<Type>, Box<Type>),
    VecMerger(Box<Type>, BinOpKind),
}
impl BuilderKind {
    /// Type of a single value merged into a builder of this kind.
    ///
    /// Dictionary-style mergers take `{key, value}` structs; `VecMerger` takes
    /// `{i64, elem}` structs.
    pub fn merge_type(&self) -> Type {
        use self::BuilderKind::*;
        match *self {
            Appender(ref elem) | Merger(ref elem, _) => Type::clone(elem),
            DictMerger(ref key, ref value, _) | GroupMerger(ref key, ref value) => {
                Type::Struct(vec![Type::clone(key), Type::clone(value)])
            }
            VecMerger(ref elem, _) => {
                Type::Struct(vec![Type::Scalar(ScalarKind::I64), Type::clone(elem)])
            }
        }
    }

    /// Type produced when the builder's result is taken.
    pub fn result_type(&self) -> Type {
        use self::BuilderKind::*;
        match *self {
            Appender(ref elem) | VecMerger(ref elem, _) => Type::Vector(elem.clone()),
            Merger(ref elem, _) => Type::clone(elem),
            DictMerger(ref key, ref value, _) => Type::Dict(key.clone(), value.clone()),
            GroupMerger(ref key, ref value) => {
                // Grouping collects every value that shares a key into a vector.
                let groups = Type::Vector(value.clone());
                Type::Dict(key.clone(), Box::new(groups))
            }
        }
    }
}
impl fmt::Display for BuilderKind {
    /// Renders the builder type, e.g. `appender[i32]` or `dictmerger[i32,i64,+]`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use self::BuilderKind::*;
        match *self {
            Appender(ref elem) => write!(f, "appender[{}]", elem),
            Merger(ref elem, op) => write!(f, "merger[{},{}]", elem, op),
            VecMerger(ref elem, op) => write!(f, "vecmerger[{},{}]", elem, op),
            DictMerger(ref key, ref value, op) => {
                write!(f, "dictmerger[{},{},{}]", key, value, op)
            }
            GroupMerger(ref key, ref value) => write!(f, "groupmerger[{},{}]", key, value),
        }
    }
}
/// An identifier: a shared name plus a numeric id that distinguishes symbols
/// with the same name.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Symbol {
    name: Rc<String>,
    id: i32,
}

impl Symbol {
    /// Creates a symbol with the given name and id.
    pub fn new<T: Into<String>>(name: T, id: i32) -> Symbol {
        let name = Rc::new(name.into());
        Symbol { name, id }
    }

    /// The symbol used by placeholder expressions (`PLACEHOLDER_NAME`, id 0).
    pub fn placeholder() -> Symbol {
        Self::new(PLACEHOLDER_NAME, 0)
    }

    /// Returns an owned copy of the symbol's name.
    pub fn name(&self) -> String {
        (*self.name).clone()
    }

    /// Returns the symbol's id.
    pub fn id(&self) -> i32 {
        self.id
    }
}

impl fmt::Display for Symbol {
    /// Renders `name` for id 0 and `name__id` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.id {
            0 => write!(f, "{}", self.name),
            id => write!(f, "{}__{}", self.name, id),
        }
    }
}
/// A node of the expression tree.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Expr {
    /// Type of this expression. Nodes created by `expr_box` or as placeholders start
    /// out as `Type::Unknown`.
    pub ty: Type,
    /// What the expression is, including its sub-expressions.
    pub kind: ExprKind,
    /// Annotations attached to this node.
    pub annotations: Annotations,
}
/// The kind of an `Iter` used by `For` loops.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum IterKind {
    ScalarIter,
    SimdIter,
    FringeIter,
    NdIter,
    RangeIter,
}

impl fmt::Display for IterKind {
    /// Renders the iterator keyword: `iter`, `simditer`, `fringeiter`, `nditer` or
    /// `rangeiter`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use self::IterKind::*;
        // Each kind contributes a prefix that is followed by "iter". `NdIter`'s prefix
        // must be "nd": using "nditer" would render the doubled keyword "nditeriter".
        let prefix = match *self {
            ScalarIter => "",
            SimdIter => "simd",
            FringeIter => "fringe",
            NdIter => "nd",
            RangeIter => "range",
        };
        f.write_str(prefix)?;
        f.write_str("iter")
    }
}
/// One input iterator of a `For` loop.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Iter {
    /// The expression being iterated over.
    pub data: Box<Expr>,
    /// Optional start bound.
    pub start: Option<Box<Expr>>,
    /// Optional end bound.
    pub end: Option<Box<Expr>>,
    /// Optional stride.
    pub stride: Option<Box<Expr>>,
    /// How the data is iterated (scalar, SIMD, fringe, n-dimensional, range).
    pub kind: IterKind,
    /// NOTE(review): presumably only meaningful for `IterKind::NdIter` — confirm.
    pub strides: Option<Box<Expr>>,
    /// NOTE(review): presumably only meaningful for `IterKind::NdIter` — confirm.
    pub shape: Option<Box<Expr>>,
}
impl Iter {
    /// Returns `true` for a plain scalar iterator with no start, end or stride.
    pub fn is_simple(&self) -> bool {
        let unbounded = self.start.is_none() && self.end.is_none() && self.stride.is_none();
        unbounded && self.kind == IterKind::ScalarIter
    }
}
/// The different kinds of expression nodes.
///
/// Every boxed or listed `Expr` field is a sub-expression and is visited by
/// `Expr::children` / `Expr::children_mut`.
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
pub enum ExprKind {
    /// A literal constant.
    Literal(LiteralKind),
    /// A reference to a named symbol.
    Ident(Symbol),
    Not(Box<Expr>),
    Assert(Box<Expr>),
    Negate(Box<Expr>),
    Broadcast(Box<Expr>),
    /// A binary operation `left <kind> right`.
    BinOp {
        kind: BinOpKind,
        left: Box<Expr>,
        right: Box<Expr>,
    },
    /// A unary math function applied to `value`.
    UnaryOp { kind: UnaryOpKind, value: Box<Expr> },
    /// Conversion of `child_expr` to the scalar kind `kind`.
    Cast {
        kind: ScalarKind,
        child_expr: Box<Expr>,
    },
    ToVec { child_expr: Box<Expr> },
    MakeStruct { elems: Vec<Expr> },
    MakeVector { elems: Vec<Expr> },
    Zip { vectors: Vec<Expr> },
    /// Access of field number `index` of a struct-valued `expr`.
    GetField { expr: Box<Expr>, index: u32 },
    Length { data: Box<Expr> },
    Lookup { data: Box<Expr>, index: Box<Expr> },
    // NOTE(review): presumably a lookup that tolerates a missing key — confirm.
    OptLookup { data: Box<Expr>, index: Box<Expr> },
    KeyExists { data: Box<Expr>, key: Box<Expr> },
    Slice {
        data: Box<Expr>,
        index: Box<Expr>,
        size: Box<Expr>,
    },
    Sort { data: Box<Expr>, cmpfunc: Box<Expr> },
    /// Binds `name` to `value` within `body` (`Expr::substitute` treats `name` as
    /// shadowing inside `body` only).
    Let {
        name: Symbol,
        value: Box<Expr>,
        body: Box<Expr>,
    },
    If {
        cond: Box<Expr>,
        on_true: Box<Expr>,
        on_false: Box<Expr>,
    },
    Iterate {
        initial: Box<Expr>,
        update_func: Box<Expr>,
    },
    Select {
        cond: Box<Expr>,
        on_true: Box<Expr>,
        on_false: Box<Expr>,
    },
    /// A function literal; its parameters shadow outer symbols in `body`.
    Lambda {
        params: Vec<Parameter>,
        body: Box<Expr>,
    },
    Apply { func: Box<Expr>, params: Vec<Expr> },
    /// A call to the external symbol `sym_name` with `args`, declared to return `return_ty`.
    CUDF {
        sym_name: String,
        args: Vec<Expr>,
        return_ty: Box<Type>,
    },
    Serialize(Box<Expr>),
    /// Deserialization of `value` into a value of type `value_ty`.
    Deserialize {
        value: Box<Expr>,
        value_ty: Box<Type>,
    },
    /// Creation of a new builder, with an optional argument expression.
    NewBuilder(Option<Box<Expr>>),
    /// A loop over `iters` that applies `func` with `builder`.
    For {
        iters: Vec<Iter>,
        builder: Box<Expr>,
        func: Box<Expr>,
    },
    /// Merge of `value` into `builder`.
    Merge {
        builder: Box<Expr>,
        value: Box<Expr>,
    },
    /// Retrieval of a builder's result.
    Res { builder: Box<Expr> },
}
impl ExprKind {
pub fn name(&self) -> &str {
match *self {
Literal(_) => "Literal",
Ident(_) => "Ident",
Not(_) => "Not",
Assert(_) => "Assert",
Negate(_) => "Negate",
Broadcast(_) => "Broadcast",
BinOp { .. } => "BinOp",
UnaryOp { .. } => "UnaryOp",
Cast { .. } => "Cast",
ToVec { .. } => "ToVec",
MakeStruct { .. } => "MakeStruct",
MakeVector { .. } => "MakeVector",
Zip { .. } => "Zip",
GetField { .. } => "GetField",
Length { .. } => "Length",
Lookup { .. } => "Lookup",
OptLookup { .. } => "OptLookup",
KeyExists { .. } => "KeyExists",
Slice { .. } => "Slice",
Sort { .. } => "Sort",
Let { .. } => "Let",
If { .. } => "If",
Iterate { .. } => "Iterate",
Select { .. } => "Select",
Lambda { .. } => "Lambda",
Apply { .. } => "Apply",
CUDF { .. } => "CUDF",
Serialize(_) => "Serialize",
Deserialize { .. } => "Deserialize",
NewBuilder(_) => "NewBuilder",
For { .. } => "For",
Merge { .. } => "Merge",
Res { .. } => "Res",
}
}
pub fn is_builder_expr(&self) -> bool {
match *self {
Merge { .. } | Res { .. } | For { .. } | NewBuilder(_) => true,
_ => false,
}
}
}
/// A literal constant.
///
/// Floating-point literals store the raw IEEE-754 bit pattern (`to_bits`), so the
/// enum can derive `Eq`, `Hash` and `Ord`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum LiteralKind {
    BoolLiteral(bool),
    I8Literal(i8),
    I16Literal(i16),
    I32Literal(i32),
    I64Literal(i64),
    U8Literal(u8),
    U16Literal(u16),
    U32Literal(u32),
    U64Literal(u64),
    F32Literal(u32),
    F64Literal(u64),
    StringLiteral(String),
}

impl fmt::Display for LiteralKind {
    /// Renders the literal with its type suffix: `c` for i8, `si` for i16, `L` for
    /// i64, `F` for f32. Floats always contain a decimal point.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        use self::LiteralKind::*;
        match *self {
            BoolLiteral(b) => write!(f, "{}", b),
            I8Literal(n) => write!(f, "{}c", n),
            I16Literal(n) => write!(f, "{}si", n),
            I32Literal(n) => write!(f, "{}", n),
            I64Literal(n) => write!(f, "{}L", n),
            U8Literal(n) => write!(f, "{}", n),
            U16Literal(n) => write!(f, "{}", n),
            U32Literal(n) => write!(f, "{}", n),
            U64Literal(n) => write!(f, "{}", n),
            F32Literal(bits) => {
                let digits = f32::from_bits(bits).to_string();
                // Integral values print without a point; add ".0" so they stay floats.
                let point = if digits.contains('.') { "" } else { ".0" };
                write!(f, "{}{}F", digits, point)
            }
            F64Literal(bits) => {
                let digits = f64::from_bits(bits).to_string();
                let point = if digits.contains('.') { "" } else { ".0" };
                write!(f, "{}{}", digits, point)
            }
            StringLiteral(ref text) => write!(f, "\"{}\"", text),
        }
    }
}
/// Binary operators.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinOpKind {
    Add,
    Subtract,
    Multiply,
    Divide,
    Modulo,
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    LogicalAnd,
    LogicalOr,
    BitwiseAnd,
    BitwiseOr,
    Xor,
    Max,
    Min,
    Pow,
}

impl BinOpKind {
    /// Returns `true` for the six equality/ordering comparisons.
    pub fn is_comparison(self) -> bool {
        matches!(
            self,
            BinOpKind::Equal
                | BinOpKind::NotEqual
                | BinOpKind::LessThan
                | BinOpKind::GreaterThan
                | BinOpKind::LessThanOrEqual
                | BinOpKind::GreaterThanOrEqual
        )
    }
}

impl fmt::Display for BinOpKind {
    /// Renders the operator symbol (`+`, `==`, `&&`, ...) or name (`max`, `min`, `pow`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            BinOpKind::Add => "+",
            BinOpKind::Subtract => "-",
            BinOpKind::Multiply => "*",
            BinOpKind::Divide => "/",
            BinOpKind::Modulo => "%",
            BinOpKind::Equal => "==",
            BinOpKind::NotEqual => "!=",
            BinOpKind::LessThan => "<",
            BinOpKind::LessThanOrEqual => "<=",
            BinOpKind::GreaterThan => ">",
            BinOpKind::GreaterThanOrEqual => ">=",
            BinOpKind::LogicalAnd => "&&",
            BinOpKind::LogicalOr => "||",
            BinOpKind::BitwiseAnd => "&",
            BinOpKind::BitwiseOr => "|",
            BinOpKind::Xor => "^",
            BinOpKind::Max => "max",
            BinOpKind::Min => "min",
            BinOpKind::Pow => "pow",
        };
        f.write_str(symbol)
    }
}
/// Unary math functions.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnaryOpKind {
    Exp,
    Log,
    Sqrt,
    Sin,
    Cos,
    Tan,
    ASin,
    ACos,
    ATan,
    Sinh,
    Cosh,
    Tanh,
    Erf,
}

impl fmt::Display for UnaryOpKind {
    /// Renders the lowercase variant name, e.g. `exp`, `asin`, `tanh`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            UnaryOpKind::Exp => "exp",
            UnaryOpKind::Log => "log",
            UnaryOpKind::Sqrt => "sqrt",
            UnaryOpKind::Sin => "sin",
            UnaryOpKind::Cos => "cos",
            UnaryOpKind::Tan => "tan",
            UnaryOpKind::ASin => "asin",
            UnaryOpKind::ACos => "acos",
            UnaryOpKind::ATan => "atan",
            UnaryOpKind::Sinh => "sinh",
            UnaryOpKind::Cosh => "cosh",
            UnaryOpKind::Tanh => "tanh",
            UnaryOpKind::Erf => "erf",
        };
        f.write_str(name)
    }
}
/// A named, typed parameter of a `Lambda` expression.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Parameter {
    pub name: Symbol,
    pub ty: Type,
}

impl fmt::Display for Parameter {
    /// Renders the parameter as `name:type`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Parameter { ref name, ref ty } = *self;
        write!(f, "{}:{}", name, ty)
    }
}
impl Expr {
    /// Returns an iterator over the direct sub-expressions of this expression.
    ///
    /// For `For` loops this yields, for each iterator in order, its `data` followed by
    /// whichever of `start`, `end`, `stride`, `shape` and `strides` are present, and
    /// then the builder and the loop function.
    pub fn children(&self) -> vec::IntoIter<&Expr> {
        use self::ExprKind::*;
        match self.kind {
            BinOp {
                ref left,
                ref right,
                ..
            } => vec![left.as_ref(), right.as_ref()],
            UnaryOp { ref value, .. } => vec![value.as_ref()],
            Cast { ref child_expr, .. } => vec![child_expr.as_ref()],
            ToVec { ref child_expr } => vec![child_expr.as_ref()],
            Let {
                ref value,
                ref body,
                ..
            } => vec![value.as_ref(), body.as_ref()],
            Lambda { ref body, .. } => vec![body.as_ref()],
            MakeStruct { ref elems } => elems.iter().collect(),
            MakeVector { ref elems } => elems.iter().collect(),
            Zip { ref vectors } => vectors.iter().collect(),
            GetField { ref expr, .. } => vec![expr.as_ref()],
            Length { ref data } => vec![data.as_ref()],
            Lookup {
                ref data,
                ref index,
            } => vec![data.as_ref(), index.as_ref()],
            OptLookup {
                ref data,
                ref index,
            } => vec![data.as_ref(), index.as_ref()],
            KeyExists { ref data, ref key } => vec![data.as_ref(), key.as_ref()],
            Slice {
                ref data,
                ref index,
                ref size,
            } => vec![data.as_ref(), index.as_ref(), size.as_ref()],
            Sort {
                ref data,
                ref cmpfunc,
            } => vec![data.as_ref(), cmpfunc.as_ref()],
            Merge {
                ref builder,
                ref value,
            } => vec![builder.as_ref(), value.as_ref()],
            Res { ref builder } => vec![builder.as_ref()],
            For {
                ref iters,
                ref builder,
                ref func,
            } => {
                let mut res: Vec<&Expr> = vec![];
                for iter in iters {
                    res.push(iter.data.as_ref());
                    if let Some(ref s) = iter.start {
                        res.push(s);
                    }
                    if let Some(ref e) = iter.end {
                        res.push(e);
                    }
                    if let Some(ref s) = iter.stride {
                        res.push(s);
                    }
                    // `shape` and `strides` are sub-expressions too; leaving them out
                    // would hide them from substitution, transforms and symbol searches.
                    if let Some(ref s) = iter.shape {
                        res.push(s);
                    }
                    if let Some(ref s) = iter.strides {
                        res.push(s);
                    }
                }
                res.push(builder.as_ref());
                res.push(func.as_ref());
                res
            }
            If {
                ref cond,
                ref on_true,
                ref on_false,
            } => vec![cond.as_ref(), on_true.as_ref(), on_false.as_ref()],
            Iterate {
                ref initial,
                ref update_func,
            } => vec![initial.as_ref(), update_func.as_ref()],
            Select {
                ref cond,
                ref on_true,
                ref on_false,
            } => vec![cond.as_ref(), on_true.as_ref(), on_false.as_ref()],
            Apply {
                ref func,
                ref params,
            } => {
                let mut res = vec![func.as_ref()];
                res.extend(params.iter());
                res
            }
            NewBuilder(ref opt) => {
                if let Some(ref e) = *opt {
                    vec![e.as_ref()]
                } else {
                    vec![]
                }
            }
            Serialize(ref e) => vec![e.as_ref()],
            Deserialize { ref value, .. } => vec![value.as_ref()],
            CUDF { ref args, .. } => args.iter().collect(),
            Negate(ref t) => vec![t.as_ref()],
            Not(ref t) => vec![t.as_ref()],
            Assert(ref t) => vec![t.as_ref()],
            Broadcast(ref t) => vec![t.as_ref()],
            Literal(_) | Ident(_) => vec![],
        }
        .into_iter()
    }

    /// Mutable counterpart of `children`: yields the same sub-expressions in the same
    /// order, including every `For` iterator's optional `shape` and `strides`.
    pub fn children_mut(&mut self) -> vec::IntoIter<&mut Expr> {
        use self::ExprKind::*;
        match self.kind {
            BinOp {
                ref mut left,
                ref mut right,
                ..
            } => vec![left.as_mut(), right.as_mut()],
            UnaryOp { ref mut value, .. } => vec![value.as_mut()],
            Cast {
                ref mut child_expr, ..
            } => vec![child_expr.as_mut()],
            ToVec { ref mut child_expr } => vec![child_expr.as_mut()],
            Let {
                ref mut value,
                ref mut body,
                ..
            } => vec![value.as_mut(), body.as_mut()],
            Lambda { ref mut body, .. } => vec![body.as_mut()],
            MakeStruct { ref mut elems } => elems.iter_mut().collect(),
            MakeVector { ref mut elems } => elems.iter_mut().collect(),
            Zip { ref mut vectors } => vectors.iter_mut().collect(),
            GetField { ref mut expr, .. } => vec![expr.as_mut()],
            Length { ref mut data } => vec![data.as_mut()],
            Lookup {
                ref mut data,
                ref mut index,
            } => vec![data.as_mut(), index.as_mut()],
            OptLookup {
                ref mut data,
                ref mut index,
            } => vec![data.as_mut(), index.as_mut()],
            KeyExists {
                ref mut data,
                ref mut key,
            } => vec![data.as_mut(), key.as_mut()],
            Slice {
                ref mut data,
                ref mut index,
                ref mut size,
            } => vec![data.as_mut(), index.as_mut(), size.as_mut()],
            Sort {
                ref mut data,
                ref mut cmpfunc,
            } => vec![data.as_mut(), cmpfunc.as_mut()],
            Merge {
                ref mut builder,
                ref mut value,
            } => vec![builder.as_mut(), value.as_mut()],
            Res { ref mut builder } => vec![builder.as_mut()],
            For {
                ref mut iters,
                ref mut builder,
                ref mut func,
            } => {
                let mut res: Vec<&mut Expr> = vec![];
                for iter in iters {
                    res.push(iter.data.as_mut());
                    if let Some(ref mut s) = iter.start {
                        res.push(s);
                    }
                    if let Some(ref mut e) = iter.end {
                        res.push(e);
                    }
                    if let Some(ref mut s) = iter.stride {
                        res.push(s);
                    }
                    // Keep in sync with `children`: shape/strides are sub-expressions.
                    if let Some(ref mut s) = iter.shape {
                        res.push(s);
                    }
                    if let Some(ref mut s) = iter.strides {
                        res.push(s);
                    }
                }
                res.push(builder.as_mut());
                res.push(func.as_mut());
                res
            }
            If {
                ref mut cond,
                ref mut on_true,
                ref mut on_false,
            } => vec![cond.as_mut(), on_true.as_mut(), on_false.as_mut()],
            Iterate {
                ref mut initial,
                ref mut update_func,
            } => vec![initial.as_mut(), update_func.as_mut()],
            Select {
                ref mut cond,
                ref mut on_true,
                ref mut on_false,
            } => vec![cond.as_mut(), on_true.as_mut(), on_false.as_mut()],
            Apply {
                ref mut func,
                ref mut params,
            } => {
                let mut res = vec![func.as_mut()];
                res.extend(params.iter_mut());
                res
            }
            NewBuilder(ref mut opt) => {
                if let Some(ref mut e) = *opt {
                    vec![e.as_mut()]
                } else {
                    vec![]
                }
            }
            Serialize(ref mut e) => vec![e.as_mut()],
            Deserialize { ref mut value, .. } => vec![value.as_mut()],
            CUDF { ref mut args, .. } => args.iter_mut().collect(),
            Negate(ref mut t) => vec![t.as_mut()],
            Not(ref mut t) => vec![t.as_mut()],
            Assert(ref mut t) => vec![t.as_mut()],
            Broadcast(ref mut t) => vec![t.as_mut()],
            Literal(_) | Ident(_) => vec![],
        }
        .into_iter()
    }

    /// Returns whether this expression is partially typed.
    ///
    /// That is the case when its own type contains `Unknown` anywhere (as decided by
    /// `Type::partial_type`, so `vec[?]` counts), or when the type of one of its direct
    /// children does.
    pub fn partially_typed(&self) -> bool {
        self.ty.partial_type() || self.children().any(|e| e.ty.partial_type())
    }

    /// Replaces every free occurrence of `symbol` in this expression with a clone of
    /// `replacement`.
    ///
    /// Occurrences shadowed by a `Let` binding of the same name (inside its body) or by
    /// a `Lambda` parameter of the same name are left untouched.
    pub fn substitute(&mut self, symbol: &Symbol, replacement: &Expr) {
        use self::ExprKind::*;
        let is_target = if let Ident(ref sym) = self.kind {
            sym == symbol
        } else {
            false
        };
        if is_target {
            *self = replacement.clone();
            return;
        }
        match self.kind {
            Let {
                ref name,
                ref mut value,
                ref mut body,
            } => {
                // The bound value is outside the new scope, so it is always rewritten.
                value.substitute(symbol, replacement);
                if name != symbol {
                    body.substitute(symbol, replacement);
                }
            }
            Lambda {
                ref params,
                ref mut body,
            } => {
                if params.iter().all(|p| p.name != *symbol) {
                    body.substitute(symbol, replacement);
                }
            }
            _ => {
                for c in self.children_mut() {
                    c.substitute(symbol, replacement);
                }
            }
        }
    }

    /// Calls `func` on this expression and then, depth-first in `children` order, on
    /// every sub-expression.
    pub fn traverse<F>(&self, func: &mut F)
    where
        F: FnMut(&Expr),
    {
        func(self);
        for c in self.children() {
            c.traverse(func);
        }
    }

    /// Returns `true` if an `Ident` referring to `sym` appears anywhere in this tree.
    ///
    /// This is a purely syntactic search: shadowing bindings are not considered.
    pub fn contains_symbol(&self, sym: &Symbol) -> bool {
        let mut found = false;
        self.traverse(&mut |e: &Expr| {
            if let ExprKind::Ident(ref s) = e.kind {
                if s == sym {
                    found = true;
                }
            }
        });
        found
    }

    /// Top-down rewrite driven by `func`, which returns `(replacement, continue)`.
    ///
    /// - `(Some(e), true)`: replace this node with `e` and run `func` again on it.
    /// - `(Some(e), false)`: replace this node with `e` and stop here.
    /// - `(None, true)`: keep this node and recurse into its children.
    /// - `(None, false)`: keep this node and skip its subtree.
    pub fn transform_and_continue<F>(&mut self, func: &mut F)
    where
        F: FnMut(&mut Expr) -> (Option<Expr>, bool),
    {
        match func(self) {
            (Some(e), true) => {
                *self = e;
                self.transform_and_continue(func)
            }
            (Some(e), false) => {
                *self = e;
            }
            (None, true) => {
                for c in self.children_mut() {
                    c.transform_and_continue(func);
                }
            }
            (None, false) => {}
        }
    }

    /// Fallible variant of `transform_and_continue`.
    ///
    /// If `func` returns `Err` for a node, the error is discarded and that node and its
    /// whole subtree are left unchanged; the rest of the tree is still visited.
    pub fn transform_and_continue_res<F>(&mut self, func: &mut F)
    where
        F: FnMut(&mut Expr) -> WeldResult<(Option<Expr>, bool)>,
    {
        if let Ok(result) = func(self) {
            match result {
                (Some(e), true) => {
                    *self = e;
                    return self.transform_and_continue_res(func);
                }
                (Some(e), false) => {
                    *self = e;
                }
                (None, true) => {
                    for c in self.children_mut() {
                        c.transform_and_continue_res(func);
                    }
                }
                (None, false) => {}
            }
        }
    }

    /// Top-down rewrite: while `func` returns a replacement for this node, the node is
    /// replaced and `func` is applied again; then the children are transformed.
    ///
    /// `func` must eventually return `None` for each node or this does not terminate.
    pub fn transform<F>(&mut self, func: &mut F)
    where
        F: FnMut(&mut Expr) -> Option<Expr>,
    {
        if let Some(e) = func(self) {
            *self = e;
            return self.transform(func);
        }
        for c in self.children_mut() {
            c.transform(func);
        }
    }

    /// Bottom-up rewrite: transforms all children first, then applies `func` once to
    /// this node, replacing it if `func` returns `Some`.
    pub fn transform_up<F>(&mut self, func: &mut F)
    where
        F: FnMut(&mut Expr) -> Option<Expr>,
    {
        for c in self.children_mut() {
            c.transform_up(func);
        }
        if let Some(e) = func(self) {
            *self = e;
        }
    }

    /// Like `transform`, but `func` replaces only the node's `kind`, keeping its type
    /// and annotations.
    pub fn transform_kind<F>(&mut self, func: &mut F)
    where
        F: FnMut(&mut Expr) -> Option<ExprKind>,
    {
        if let Some(k) = func(self) {
            self.kind = k;
            return self.transform_kind(func);
        }
        for c in self.children_mut() {
            c.transform_kind(func);
        }
    }

    /// Returns `true` if `other` is equal to this expression or to any expression in
    /// this expression's subtree.
    pub fn contains(&self, other: &Expr) -> bool {
        // Search *this* tree for `other`; the previous version walked `other`'s
        // children instead, answering the reverse question.
        *self == *other || self.children().any(|c| c.contains(other))
    }
}
/// Boxes a new expression of the given kind and annotations, with type `Unknown`.
pub fn expr_box(kind: ExprKind, annot: Annotations) -> Box<Expr> {
    let node = Expr {
        kind,
        annotations: annot,
        ty: Type::Unknown,
    };
    Box::new(node)
}
/// Values that can stand in as a placeholder expression.
///
/// In this module a placeholder is an `Ident` expression whose symbol name is
/// `PLACEHOLDER_NAME`.
pub trait Placeholder {
    /// Returns `true` if this value is a placeholder.
    fn is_placeholder(&self) -> bool;
    /// Creates a new placeholder value.
    fn new_placeholder() -> Self;
}
impl Placeholder for Expr {
    /// An expression is a placeholder if it is an `Ident` named `PLACEHOLDER_NAME`
    /// (the symbol's id is not checked).
    fn is_placeholder(&self) -> bool {
        match self.kind {
            Ident(ref sym) => sym.name.as_str() == PLACEHOLDER_NAME,
            _ => false,
        }
    }

    /// Creates an untyped, unannotated `Ident` of `Symbol::placeholder()`.
    fn new_placeholder() -> Expr {
        let kind = Ident(Symbol::placeholder());
        Expr {
            ty: Type::Unknown,
            kind,
            annotations: Annotations::new(),
        }
    }
}
impl Placeholder for Box<Expr> {
fn is_placeholder(&self) -> bool {
self.as_ref().is_placeholder()
}
fn new_placeholder() -> Box<Expr> {
Box::new(Expr {
ty: Type::Unknown,
kind: Ident(Symbol::placeholder()),
annotations: Annotations::new(),
})
}
}
/// Values that can be moved out of place, leaving a placeholder behind.
pub trait Takeable: Placeholder {
    /// Moves the current value out and leaves a placeholder of the same type in its
    /// place (per the `Expr` and `Box<Expr>` impls in this module).
    fn take(&mut self) -> Self;
}
impl Takeable for Expr {
fn take(&mut self) -> Expr {
use std::mem;
let mut new = Self::new_placeholder();
new.ty = self.ty.clone();
mem::swap(self, &mut new);
new
}
}
impl Takeable for Box<Expr> {
fn take(&mut self) -> Box<Expr> {
use std::mem;
let mut new = Self::new_placeholder();
new.ty = self.ty.clone();
mem::swap(self.as_mut(), new.as_mut());
new
}
}