use cubecl_ir::{ManagedVariable, Variable};
use crate::prelude::*;
use crate::{self as cubecl};
// Placeholder scalar type for the `expand_*` entry points below: each one binds it to
// the storage type of its `lhs` operand via `scope.register_type::<ElemA>(..)` before
// expanding the cube function.
define_scalar!(ElemA);
// Placeholder vector width, bound per call via `scope.register_size::<SizeA>(..)` to
// the `lhs` operand's vector size.
define_size!(SizeA);
/// Element-wise Euclidean norm `sqrt(lhs² + rhs²)`, evaluated in a scaled form.
///
/// With `m = max(|lhs|, |rhs|)` and `n = min(|lhs|, |rhs|)`, the result is
/// `m * sqrt(1 + (n / m)²)`. The squared ratio therefore stays in `[0, 1]`, and the
/// squares of the raw inputs are never formed. Lanes where both inputs are zero
/// yield zero.
///
/// NOTE(review): when both inputs are infinite, the ratio is `inf / inf = NaN`, so
/// this returns NaN, whereas IEEE-754 `hypot` returns +inf.
#[cube]
pub fn hypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
    let zero = Vector::new(F::from_int(0));
    let unit = Vector::new(F::from_int(1));

    let abs_lhs = lhs.abs();
    let abs_rhs = rhs.abs();
    let larger = max(abs_lhs, abs_rhs);
    let smaller = min(abs_lhs, abs_rhs);

    // A zero divisor is replaced by 1, so all-zero lanes produce ratio 0 instead of 0/0.
    let divisor = select_many(larger.equal(zero), unit, larger);
    let ratio = smaller / divisor;

    let scale = fma(ratio, ratio, unit).sqrt();
    larger * scale
}
#[allow(missing_docs)]
pub fn expand_hypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
scope.register_type::<ElemA>(lhs.ty.storage_type());
scope.register_size::<SizeA>(lhs.vector_size());
let res = hypot::expand::<ElemA, SizeA>(
scope,
ManagedVariable::Plain(lhs).into(),
ManagedVariable::Plain(rhs).into(),
);
assign::expand_no_check(scope, res, ManagedVariable::Plain(out).into());
}
/// Element-wise reciprocal Euclidean norm `1 / sqrt(lhs² + rhs²)`, evaluated in a
/// scaled form.
///
/// With `m = max(|lhs|, |rhs|)` and `n = min(|lhs|, |rhs|)`, the result is
/// `inverse_sqrt(1 + (n / m)²) / m`. Lanes where both inputs are zero end up dividing
/// by `m = 0` and yield infinity.
///
/// NOTE(review): when both inputs are infinite, the ratio is `inf / inf = NaN`, so
/// this returns NaN rather than `1 / inf = 0`.
#[cube]
pub fn rhypot<F: Float, N: Size>(lhs: Vector<F, N>, rhs: Vector<F, N>) -> Vector<F, N> {
    let zero = Vector::new(F::from_int(0));
    let unit = Vector::new(F::from_int(1));

    let abs_lhs = lhs.abs();
    let abs_rhs = rhs.abs();
    let larger = max(abs_lhs, abs_rhs);
    let smaller = min(abs_lhs, abs_rhs);

    // Only the ratio's divisor is guarded; the final division by `larger` is not.
    let divisor = select_many(larger.equal(zero), unit, larger);
    let ratio = smaller / divisor;

    let inv_scale = fma(ratio, ratio, unit).inverse_sqrt();
    inv_scale / larger
}
/// Emits `rhypot` for raw IR operands and writes the result into `out`.
///
/// The placeholder generics `ElemA` / `SizeA` are bound to the storage type and vector
/// size of `lhs`. The cube function is then expanded into `scope`, and its result is
/// written to `out` with `assign::expand_no_check`.
///
/// NOTE: only `lhs` is inspected; `rhs` is assumed to share its storage type and
/// vector size.
pub fn expand_rhypot(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
    let storage = lhs.ty.storage_type();
    let width = lhs.vector_size();
    scope.register_type::<ElemA>(storage);
    scope.register_size::<SizeA>(width);

    let result = rhypot::expand::<ElemA, SizeA>(
        scope,
        ManagedVariable::Plain(lhs).into(),
        ManagedVariable::Plain(rhs).into(),
    );
    assign::expand_no_check(scope, result, ManagedVariable::Plain(out).into());
}