use crate::arena::{ContextStack, DataValue, IterGuard};
use crate::node::ReduceHint;
use crate::opcode::OpCode;
use crate::{CompiledNode, Engine, Result};
use bumpalo::Bump;
use super::helpers::{IterArgKind, IterSrc, ResolvedInput, resolve_iter_input};
/// Evaluates a `reduce` operator with arguments `[iterable, body]` or
/// `[iterable, body, initial]`.
///
/// When `initial` is omitted it is the `null` singleton; when present it is
/// evaluated before the iterable is resolved. An empty or absent iterable
/// yields `initial` unchanged, and a bridged value is delegated to
/// `reduce_arena_bridge`. When tracing is off and the body is a two-operand
/// `+`, `*` or `-`, a numeric shortcut (`try_reduce_fast_path`) is tried first;
/// if it declines, every element is run through the body by `reduce_general`.
///
/// # Errors
/// Returns `Error::invalid_args()` unless exactly two or three arguments are
/// supplied, and propagates any error from evaluating the arguments or body.
#[inline]
pub(crate) fn evaluate_reduce<'a>(
    args: &'a [CompiledNode],
    iter_arg_kind: IterArgKind,
    ctx: &mut ContextStack<'a>,
    engine: &Engine,
    arena: &'a Bump,
) -> Result<&'a DataValue<'a>> {
    if !(2..=3).contains(&args.len()) {
        return Err(crate::Error::invalid_args());
    }
    // The arity check above guarantees index 2 exists exactly when len == 3.
    let initial: &'a DataValue<'a> = match args.get(2) {
        Some(initial_node) => engine.dispatch_node(initial_node, ctx, arena)?,
        None => crate::arena::singletons::singleton_null(),
    };
    let body = &args[1];

    let src = match resolve_iter_input(&args[0], iter_arg_kind, ctx, engine, arena)? {
        ResolvedInput::Iterable(iterable) => iterable,
        ResolvedInput::Empty => return Ok(initial),
        ResolvedInput::Bridge(bridged) => {
            return reduce_arena_bridge(bridged, body, initial, ctx, engine, arena);
        }
    };
    if src.is_empty() {
        return Ok(initial);
    }

    // The numeric shortcut skips per-step body evaluation, so it is only
    // eligible when tracing is off and the body is a binary +, * or -.
    let arithmetic_body = if ctx.is_tracing() {
        None
    } else {
        match body {
            CompiledNode::BuiltinOperator {
                opcode,
                args: operands,
                ..
            } if operands.len() == 2
                && matches!(opcode, OpCode::Add | OpCode::Multiply | OpCode::Subtract) =>
            {
                Some((*opcode, operands))
            }
            _ => None,
        }
    };
    if let Some((op, operands)) = arithmetic_body {
        if let Some(shortcut) = try_reduce_fast_path(&src, initial, operands, op, arena) {
            return Ok(shortcut);
        }
    }

    reduce_general(&src, body, initial, ctx, engine, arena)
}
/// Folds `body` over every element of `src`, starting from `initial`.
///
/// For each index the element and the running accumulator are handed to an
/// `IterGuard` via `step_reduce` (presumably binding them for the body — see
/// `IterGuard`), then `body` is run with the element index and total count.
/// The body's result becomes the next accumulator; the first error aborts the
/// fold and is returned.
#[inline]
fn reduce_general<'a>(
    src: &IterSrc<'a>,
    body: &'a CompiledNode,
    initial: &'a DataValue<'a>,
    ctx: &mut ContextStack<'a>,
    engine: &Engine,
    arena: &'a Bump,
) -> Result<&'a DataValue<'a>> {
    let count = src.len();
    let total = count as u32;
    let mut guard = IterGuard::new(ctx);
    let folded = (0..count).try_fold(initial, |acc, index| {
        guard.step_reduce(src.get(index), acc);
        engine.run_iter_body(body, guard.stack(), arena, index as u32, total)
    });
    // Release the guard before handing the result (or error) back.
    drop(guard);
    folded
}
/// Folds `body` over a bridged input value, starting from `initial`.
///
/// Arrays contribute their elements and objects contribute their values (keys
/// are ignored), each in stored order. Any other kind of value yields
/// `initial` unchanged.
#[inline]
fn reduce_arena_bridge<'a>(
    input: &'a DataValue<'a>,
    body: &'a CompiledNode,
    initial: &'a DataValue<'a>,
    ctx: &mut ContextStack<'a>,
    engine: &Engine,
    arena: &'a Bump,
) -> Result<&'a DataValue<'a>> {
    /// Runs `body` once per value, threading the accumulator through an
    /// `IterGuard`; `total` is the value count reported to the body.
    fn fold_values<'a, I>(
        values: I,
        total: u32,
        body: &'a CompiledNode,
        initial: &'a DataValue<'a>,
        ctx: &mut ContextStack<'a>,
        engine: &Engine,
        arena: &'a Bump,
    ) -> Result<&'a DataValue<'a>>
    where
        I: Iterator<Item = &'a DataValue<'a>>,
    {
        let mut running = initial;
        let mut guard = IterGuard::new(ctx);
        for (position, value) in values.enumerate() {
            guard.step_reduce(value, running);
            running = engine.run_iter_body(body, guard.stack(), arena, position as u32, total)?;
        }
        drop(guard);
        Ok(running)
    }

    match input {
        DataValue::Array(items) => fold_values(
            items.iter(),
            items.len() as u32,
            body,
            initial,
            ctx,
            engine,
            arena,
        ),
        DataValue::Object(pairs) => fold_values(
            pairs.iter().map(|(_key, value)| value),
            pairs.len() as u32,
            body,
            initial,
            ctx,
            engine,
            arena,
        ),
        _ => Ok(initial),
    }
}
/// Attempts to evaluate a `reduce` whose body is a binary `+`, `*` or `-` of a
/// `current` var (optionally a path into the element) and a bare `accumulator`
/// var, without running the body through the engine.
///
/// Returns `None` whenever the shortcut cannot reproduce the general path
/// exactly, so the caller falls back to per-element evaluation:
/// - `body_args` is not exactly one `current` var and one bare `accumulator` var
///   (a path into the accumulator is rejected, since the shortcut only tracks a
///   number and cannot mimic traversing into it);
/// - a `current` path fails to resolve, or a value is not numeric;
/// - an integer step overflows `i64` (instead of returning a wrapped value);
/// - `opcode` is not one of `Add`, `Multiply`, `Subtract`.
///
/// While `initial` and every element are integers the fold is done exactly in
/// `i64`; otherwise the whole fold is recomputed in `f64`. The operand order
/// written in the body is respected, which matters for `-`.
fn try_reduce_fast_path<'a>(
    src: &IterSrc<'a>,
    initial: &'a DataValue<'a>,
    body_args: &[CompiledNode],
    opcode: OpCode,
    arena: &'a Bump,
) -> Option<&'a DataValue<'a>> {
    let (lhs, rhs) = match body_args {
        [lhs, rhs] => (lhs, rhs),
        _ => return None,
    };

    // Find the operand that reads `current` and remember whether it is on the
    // left, so `-` computes `current - accumulator` or `accumulator - current`
    // as written.
    let (current_arg, current_on_left) = match (lhs, rhs) {
        (
            CompiledNode::Var {
                reduce_hint: lhs_hint,
                ..
            },
            CompiledNode::Var {
                reduce_hint: rhs_hint,
                ..
            },
        ) => match (lhs_hint, rhs_hint) {
            (ReduceHint::Current | ReduceHint::CurrentPath, ReduceHint::Accumulator) => {
                (lhs, true)
            }
            (ReduceHint::Accumulator, ReduceHint::Current | ReduceHint::CurrentPath) => {
                (rhs, false)
            }
            _ => return None,
        },
        _ => return None,
    };

    // Path to read from each element, after the leading `current` segment.
    let current_segments = match current_arg {
        CompiledNode::Var {
            segments,
            reduce_hint,
            ..
        } => match reduce_hint {
            ReduceHint::Current => &[][..],
            ReduceHint::CurrentPath if segments.len() >= 2 => &segments[1..],
            _ => return None,
        },
        _ => return None,
    };

    // The `current` operand for element `idx`; `None` if its path is missing.
    let current_at = |idx: usize| {
        let item = src.get(idx);
        if current_segments.is_empty() {
            Some(item)
        } else {
            crate::arena::value::traverse_segments(item, current_segments)
        }
    };

    let len = src.len();

    if let Some(mut acc) = initial.as_i64() {
        let mut all_int = true;
        for idx in 0..len {
            let cur = match current_at(idx)?.as_i64() {
                Some(cur) => cur,
                None => {
                    all_int = false;
                    break;
                }
            };
            let (a, b) = if current_on_left { (cur, acc) } else { (acc, cur) };
            let next = match opcode {
                OpCode::Add => a.checked_add(b),
                OpCode::Multiply => a.checked_mul(b),
                OpCode::Subtract => a.checked_sub(b),
                _ => return None,
            };
            // On overflow defer to the general path rather than wrapping.
            acc = next?;
        }
        if all_int {
            return Some(
                crate::arena::singletons::singleton_small_int(acc).unwrap_or_else(|| {
                    &*arena.alloc(DataValue::Number(datavalue::NumberValue::from_i64(acc)))
                }),
            );
        }
    }

    // Mixed or float operands: recompute the whole fold in f64.
    let mut acc = initial.as_f64()?;
    for idx in 0..len {
        let cur = current_at(idx)?.as_f64()?;
        let (a, b) = if current_on_left { (cur, acc) } else { (acc, cur) };
        acc = match opcode {
            OpCode::Add => a + b,
            OpCode::Multiply => a * b,
            OpCode::Subtract => a - b,
            _ => return None,
        };
    }
    Some(&*arena.alloc(DataValue::Number(datavalue::NumberValue::from_f64(acc))))
}