use cubecl_ir::{ExpandElement, Variable};
use crate::prelude::*;
use crate::{self as cubecl};
/// Computes `sqrt(lhs² + rhs²)` element-wise without intermediate overflow or underflow.
///
/// The larger magnitude is factored out so that the squared term never exceeds `1`:
/// `hypot(a, b) = max * sqrt(1 + (min / max)²)`.
///
/// Edge cases:
/// - `hypot(0, 0)` is `0`: the divisor is replaced by `1` so `0 / 0` is never evaluated.
/// - `hypot(±inf, ±inf)` is `+inf`: when both magnitudes are equal the ratio is set to
///   exactly `1` instead of evaluating `inf / inf = NaN`.
#[cube]
pub fn hypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
    let one = Line::empty(lhs.size()).fill(F::from_int(1));
    let zero = Line::empty(lhs.size()).fill(F::from_int(0));
    let a = lhs.abs();
    let b = rhs.abs();
    let max_val = max(a, b);
    let min_val = min(a, b);
    // Guard the division when both inputs are zero.
    let max_val_safe = select_many(max_val.equal(zero), one, max_val);
    let ratio = min_val / max_val_safe;
    // Equal magnitudes have a ratio of exactly 1; selecting it explicitly also
    // avoids `inf / inf = NaN` when both inputs are infinite.
    let t = select_many(min_val.equal(max_val), one, ratio);
    max_val * fma(t, t, one).sqrt()
}
/// Expands [`hypot`] into `scope` for the raw IR operands `lhs` and `rhs`,
/// writing the result into `out`.
///
/// The float placeholder `FloatExpand<0>` is bound to the storage type of `lhs`
/// before the expansion runs.
pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
    let float_storage = lhs.ty.storage_type();
    scope.register_type::<FloatExpand<0>>(float_storage);

    let x = ExpandElement::Plain(lhs);
    let y = ExpandElement::Plain(rhs);
    let result = hypot::expand::<FloatExpand<0>>(scope, x.into(), y.into());

    let target = ExpandElement::Plain(out);
    assign::expand_no_check(scope, result, target.into());
}
/// Computes `1 / sqrt(lhs² + rhs²)` element-wise without intermediate overflow or underflow.
///
/// The larger magnitude is factored out so that the squared term never exceeds `1`:
/// `rhypot(a, b) = rsqrt(1 + (min / max)²) / max`.
///
/// Edge cases:
/// - `rhypot(0, 0)` is `+inf`: the ratio divisor is replaced by `1` so `0 / 0` is never
///   evaluated; only the final division by `max = 0` remains.
/// - `rhypot(±inf, ±inf)` is `0`: when both magnitudes are equal the ratio is set to
///   exactly `1` instead of evaluating `inf / inf = NaN`.
#[cube]
pub fn rhypot<F: Float>(lhs: Line<F>, rhs: Line<F>) -> Line<F> {
    let one = Line::empty(lhs.size()).fill(F::from_int(1));
    let zero = Line::empty(lhs.size()).fill(F::from_int(0));
    let a = lhs.abs();
    let b = rhs.abs();
    let max_val = max(a, b);
    let min_val = min(a, b);
    // Guard the ratio division when both inputs are zero.
    let max_val_safe = select_many(max_val.equal(zero), one, max_val);
    let ratio = min_val / max_val_safe;
    // Equal magnitudes have a ratio of exactly 1; selecting it explicitly also
    // avoids `inf / inf = NaN` when both inputs are infinite.
    let t = select_many(min_val.equal(max_val), one, ratio);
    fma(t, t, one).inverse_sqrt() / max_val
}
/// Expands [`rhypot`] into `scope` for the raw IR operands `lhs` and `rhs`,
/// writing the result into `out`.
///
/// The float placeholder `FloatExpand<0>` is bound to the storage type of `lhs`
/// before the expansion runs.
pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
    let float_storage = lhs.ty.storage_type();
    scope.register_type::<FloatExpand<0>>(float_storage);

    let x = ExpandElement::Plain(lhs);
    let y = ExpandElement::Plain(rhs);
    let result = rhypot::expand::<FloatExpand<0>>(scope, x.into(), y.into());

    let target = ExpandElement::Plain(out);
    assign::expand_no_check(scope, result, target.into());
}