use std::mem::size_of;
use crate::{
Ops, Primitive, Shape, SigNode, Signature, Uiua, UiuaResult,
algorithm::{FixedRowsData, fixed_rows, get_ops, pervade::pervade_dim},
array::Array,
value::Value,
};
use super::{multi_output, validate_size_impl};
/// Swaps the argument order of a two-argument function.
///
/// The returned closure is called as `(b, a)` and forwards to `f(a, b)`.
/// Since `f` is `Copy`, the returned closure is `Copy` too.
pub fn flip<A, B, C>(f: impl Fn(A, B) -> C + Copy) -> impl Fn(B, A) -> C + Copy {
    move |right: B, left: A| -> C { f(left, right) }
}
/// Implements the `repeat` modifier.
///
/// `ops` holds the repeated function. When `with_inverse` is set, it also holds
/// the function's inverse, which is used for negative repetition counts.
///
/// - When `count_convergence` is set, the function runs until its top result
///   stops changing. The number of iterations is then pushed as a number.
/// - Otherwise the repetition count is popped from the stack.
///   - A scalar count repeats the function that many times in place.
///   - An array of counts applies the function element-wise (see `rep_recur`).
///     This requires the function to have as many arguments as outputs.
///
/// # Errors
/// Fails if the stack is too short or the count is not numeric. It also fails
/// if an array count is used with a function whose argument and output counts
/// differ, or if an argument's trailing dimensions are incompatible with those
/// of the count array.
pub fn repeat(ops: Ops, with_inverse: bool, count_convergence: bool, env: &mut Uiua) -> UiuaResult {
    crate::profile_function!();
    let (f, inv) = if with_inverse {
        let [f, inv] = get_ops(ops, env)?;
        (f, Some(inv))
    } else {
        let [f] = get_ops(ops, env)?;
        (f, None)
    };
    if count_convergence {
        let count = repeat_impl(f, inv, f64::INFINITY, env)?;
        env.push(count as f64);
        return Ok(());
    }
    let n = env.pop("repetition count")?;
    env.require_height(f.sig.args())?;
    if n.rank() == 0 {
        // Scalar count: repeat the function directly on the stack
        let n = rep_count(n, env)?;
        repeat_impl(f, inv, n.data[0], env)?;
        return Ok(());
    }
    // Array of counts: each count is applied to the corresponding rows of the arguments
    let sig = f.sig;
    if sig.args() != sig.outputs() {
        return Err(env.error(format!(
            "{} with a non-scalar repetition count \
            must use a function with the same number \
            of arguments and outputs, but its signature \
            is {sig}",
            Primitive::Repeat.format()
        )));
    }
    // `n` is moved into `args` below, so keep its shape for error reporting.
    // It is never empty here because `n.rank() != 0`.
    let n_shape = n.shape.clone();
    let mut args = Vec::with_capacity(sig.args() + 1);
    args.push(n);
    for i in 0..sig.args() {
        let arg = env.pop(i + 1)?;
        // Dimensions after the first must match or be 1. The leading axis is
        // handled when the values are split into rows in `rep_recur`.
        if arg.rank() > 0
            && !(arg.shape.iter().skip(1))
                .zip(n_shape.iter().skip(1))
                .all(|(a, b)| *a == 1 || *b == 1 || a == b)
        {
            return Err(env.error(format!(
                "Cannot {} with counts of shape {n_shape} \
                when argument {} has shape {}",
                Primitive::Repeat.format(),
                i + 1,
                arg.shape
            )));
        }
        args.push(arg);
    }
    // Arguments were popped top-first. Reverse them so that pushing them in
    // order later restores the original stack layout.
    args[1..].reverse();
    rep_recur(f, inv, args, env)
}
/// Runs `repeat` with an array of counts by recursing over rows.
///
/// `args[0]` is the count value. `args[1..]` are the function's arguments in
/// stack order, with the last element being the top of the stack.
///
/// - Once the count is a scalar, the arguments are pushed back and the function
///   is repeated normally.
/// - Otherwise the count and arguments are split into rows with `fixed_rows`.
///   Each row is processed recursively, and each output's per-row results are
///   reassembled into one array, which is pushed.
fn rep_recur(f: SigNode, inv: Option<SigNode>, mut args: Vec<Value>, env: &mut Uiua) -> UiuaResult {
    if args[0].rank() == 0 {
        let n = rep_count(args.remove(0), env)?;
        for arg in args {
            env.push(arg);
        }
        repeat_impl(f, inv, n.data[0], env)?;
        return Ok(());
    }
    let FixedRowsData {
        mut rows,
        row_count,
        is_empty,
        ..
    } = fixed_rows(Primitive::Repeat.format(), f.sig.outputs(), args, env)?;
    let mut new_values = multi_output(f.sig.outputs(), Vec::new());
    for _ in 0..row_count {
        // Split values yield their next row. Unsplit values (`Err`) are
        // cloned for every iteration.
        let row_args: Vec<_> = rows
            .iter_mut()
            .map(|arg| match arg {
                Ok(row_iter) => row_iter.next().unwrap(),
                Err(row) => row.clone(),
            })
            .collect();
        rep_recur(f.clone(), inv.clone(), row_args, env)?;
        // The top of the stack is the first output
        for values in new_values.iter_mut() {
            values.push(env.pop("repeat's function result")?);
        }
    }
    // Push in reverse so that the first output ends up on top of the stack
    for values in new_values.into_iter().rev() {
        let mut rowsed = Value::from_row_values(values, env)?;
        if is_empty {
            // Presumably `fixed_rows` processed a placeholder row for empty
            // input, so drop it from the result.
            rowsed.pop_row();
        }
        rowsed.validate();
        env.push(rowsed);
    }
    Ok(())
}
/// Converts a repetition-count value into a number array.
///
/// Number arrays are returned unchanged and byte arrays are converted to
/// numbers. Any other value type is rejected with an error naming that type.
fn rep_count(value: Value, env: &Uiua) -> UiuaResult<Array<f64>> {
    let counts = match value {
        Value::Num(nums) => nums,
        Value::Byte(bytes) => bytes.convert(),
        other => {
            let message = format!(
                "Repetitions must be a scalar or list of \
                natural numbers or infinity, \
                but it is {}",
                other.type_name_plural()
            );
            return Err(env.error(message));
        }
    };
    Ok(counts)
}
/// Runs `f` on the stack `n` times. Returns the iteration count of a
/// converging run.
///
/// - If `n` is negative, `inv` is run `-n` times instead. An error is returned
///   if there is no inverse.
/// - If `n` is infinite, the function runs until the new top-of-stack value
///   equals the previous one. The returned count is the number of iterations
///   before the one that produced an equal value.
/// - For finite `n`, the count must be an integer and `0` is returned.
///
/// Signature handling:
/// - If `f` has more outputs than arguments, each iteration's extra outputs are
///   taken off the stack and collected. After the loop, each group is combined
///   into one array and pushed.
/// - If `f` has more arguments than outputs, the extra arguments are copied once
///   up front and re-inserted before every iteration, so each iteration sees
///   them again.
fn repeat_impl(f: SigNode, inv: Option<SigNode>, n: f64, env: &mut Uiua) -> UiuaResult<u64> {
    // Signature of the function as passed in, taken before any swap to the inverse
    let sig = f.sig;
    let (f, n) = if n >= 0.0 {
        (f, n)
    } else {
        let f = inv.ok_or_else(|| {
            env.error(format!(
                "Cannot use negative repetition count {n} \
                because the function is not invertible"
            ))
        })?;
        (f, -n)
    };
    // Outputs beyond the argument count; collected per iteration
    let excess_count = sig.outputs().saturating_sub(sig.args());
    // Arguments beyond the output count; re-supplied every iteration
    let preserve_count = sig.args().saturating_sub(sig.outputs());
    // NOTE(review): this depth uses the possibly-inverted `f.sig`, while the
    // counts above use the original `sig`. The two only agree when both
    // signatures have the same argument count. Confirm this is intended when
    // the inverse is used.
    let preserved = env.copy_n_down(preserve_count, f.sig.args())?;
    let mut excess_rows = vec![Vec::new(); excess_count];
    let mut convergence_count = 0;
    // The whole loop runs inside `without_fill` (presumably with any active fill value disabled)
    env.without_fill(|env| -> UiuaResult {
        if n.is_infinite() {
            // Converging mode: needs at least one argument to compare between iterations
            if sig.args() == 0 {
                return Err(env.error(format!(
                    "Converging {}'s function must have at least 1 argument",
                    Primitive::Repeat.format()
                )));
            }
            // Snapshot the current top value without consuming it
            let mut prev = env.pop(1)?;
            env.push(prev.clone());
            loop {
                if preserve_count > 0 {
                    // Re-supply the preserved arguments below the function's outputs
                    env.insert_stack(sig.outputs(), preserved.iter().cloned())?;
                }
                env.exec(f.clone())?;
                // Take this iteration's excess outputs off the stack
                for (i, row) in env
                    .remove_n(excess_count, sig.args() + excess_count)?
                    .enumerate()
                {
                    excess_rows[i].push(row);
                }
                let next = env.pop("converging function result")?;
                let converged = next == prev;
                if converged {
                    env.push(next);
                    break;
                } else {
                    env.push(next.clone());
                    prev = next;
                }
                // Only reached when this iteration changed the value
                convergence_count += 1;
            }
        } else {
            if n.fract() != 0.0 {
                return Err(env.error("Repetitions must be an integer or infinity"));
            }
            // `n` is finite and non-negative here; float-to-int `as` saturates out-of-range values
            let n = n as usize;
            if sig.outputs() > sig.args() {
                // Each iteration leaves `delta` extra values. Reject counts that
                // would accumulate too many (presumably a size/overflow check
                // in `validate_size_impl`).
                let delta = sig.outputs() - sig.args();
                if validate_size_impl(size_of::<Value>(), [n, delta]).is_err() {
                    return Err(env.error(format!(
                        "{} would create too many values on the stack",
                        Primitive::Repeat.format()
                    )));
                }
            }
            for _ in 0..n {
                if preserve_count > 0 {
                    // Re-supply the preserved arguments below the function's outputs
                    env.insert_stack(sig.outputs(), preserved.iter().cloned())?;
                }
                env.exec(f.clone())?;
                // Take this iteration's excess outputs off the stack
                for (i, row) in env
                    .remove_n(excess_count, sig.args() + excess_count)?
                    .enumerate()
                {
                    excess_rows[i].push(row);
                }
            }
        }
        Ok(())
    })?;
    // Discard values that were only kept on the stack between iterations.
    // NOTE(review): exact positions depend on `remove_n`'s (count, depth) semantics.
    if excess_count > 0 {
        _ = env.remove_n(sig.args(), sig.args())?;
    } else if preserve_count > 0 {
        _ = env.remove_n(preserve_count, sig.args())?;
    }
    // Combine each group of collected excess outputs into one array and push it
    for rows in excess_rows.into_iter().rev() {
        env.respect_execution_limit()?;
        let new_val = Value::from_row_values(rows, env)?;
        env.push(new_val);
    }
    Ok(convergence_count)
}
/// Implements the `do` modifier, a while-loop over the stack.
///
/// `ops` holds the loop body and the condition function, in that order. Each
/// iteration runs the condition and pops its top output as a boolean. If the
/// boolean is false the loop stops. Otherwise the body runs and the loop
/// repeats.
///
/// The net signature of one iteration (the body combined with the
/// condition's non-boolean part) decides the bookkeeping:
/// - Excess outputs are collected each iteration and pushed as arrays at the end.
/// - Missing outputs are made up by re-inserting preserved values before the body.
pub fn do_(ops: Ops, env: &mut Uiua) -> UiuaResult {
    crate::profile_function!();
    let [body, cond] = get_ops(ops, env)?;
    // A condition with no outputs cannot produce the boolean. The error is
    // built here but only returned after the condition has run once (see loop).
    let cond_sig_err = if cond.sig.outputs() == 0 {
        Some(env.error(format!(
            "Do's condition function must return at least 1 value, \
            but its signature is {}",
            cond.sig
        )))
    } else {
        None
    };
    // Number of the condition's arguments that are not given back as its
    // non-boolean outputs. These get duplicated before each condition run.
    let copy_count = (cond.sig.args()).saturating_sub(cond.sig.outputs().saturating_sub(1));
    // The condition's signature with the boolean removed and the copies added back
    let cond_sub_sig = Signature::new(
        cond.sig.args(),
        (cond.sig.outputs() + copy_count).saturating_sub(1),
    );
    // Net signature of one iteration (condition part composed with the body)
    let comp_sig = body.sig.compose(cond_sub_sig);
    let excess_count = comp_sig.outputs().saturating_sub(comp_sig.args());
    let mut excess_rows = vec![Vec::new(); excess_count];
    let preserve_count = comp_sig.args().saturating_sub(comp_sig.outputs());
    let preserved = env.copy_n_down(preserve_count, comp_sig.args())?;
    loop {
        // If the stack is too short for the copies, pop one value at a time so
        // the failing pop reports the missing argument (presumably an underflow error).
        if env.stack().len() < copy_count {
            for i in 0..copy_count {
                env.pop(i + 1)?;
            }
        }
        // Duplicate the condition's arguments so they survive its execution
        env.dup_values(copy_count, copy_count)?;
        env.exec(cond.clone())?;
        if let Some(err) = cond_sig_err {
            return Err(err);
        }
        // This `cond` shadows the condition function only for the rest of this iteration
        let cond = (env.pop("do condition")?).as_bool(env, "Do condition must be a boolean")?;
        if !cond {
            break;
        }
        if preserve_count > 0 {
            // Re-supply the preserved values below the iteration's outputs
            env.insert_stack(comp_sig.outputs(), preserved.iter().cloned())?;
        }
        env.exec(body.clone())?;
        // Take this iteration's excess outputs off the stack
        for (i, row) in env
            .remove_n(excess_count, comp_sig.args() + excess_count)?
            .enumerate()
        {
            excess_rows[i].push(row);
        }
    }
    // Discard values that were only kept on the stack between iterations.
    // NOTE(review): exact positions depend on `remove_n`'s (count, depth) semantics.
    if excess_count > 0 {
        _ = env.remove_n(comp_sig.args(), comp_sig.args())?;
    } else if preserve_count > 0 {
        _ = env.remove_n(preserve_count, comp_sig.args())?;
    }
    // Combine each group of collected excess outputs into one array and push it
    for rows in excess_rows.into_iter().rev() {
        let new_val = Value::from_row_values(rows, env)?;
        env.push(new_val);
    }
    Ok(())
}