use crate::tensor::Tensor;
use super::classify::should_use_f16;
/// Returns a copy of `tensor` whose elements have been squeezed through
/// 16-bit float precision.
///
/// Each `f32` element is converted to `half::f16` and immediately back to
/// `f32`. The result therefore keeps `f32` storage but only carries the
/// precision (and range) that an `f16` can represent. The input tensor is not
/// modified, and the shape is cloned unchanged into the new tensor.
pub fn round_to_f16_precision(tensor: &Tensor) -> Tensor {
    let through_half = |value: f32| -> f32 { half::f16::from_f32(value).to_f32() };
    let rounded: Vec<f32> = tensor.data.iter().copied().map(through_half).collect();
    let shape = tensor.shape.clone();
    Tensor::new(rounded, shape)
}
/// Reports whether every later consumer of this node's outputs is an op
/// accepted by `should_use_f16`.
///
/// # Arguments
/// * `node_outputs` - output tensor names of the current node; empty names are ignored.
/// * `all_nodes` - the node list to scan.
/// * `current_node_idx` - index of the current node; only nodes strictly after
///   it are considered consumers.
///
/// # Returns
/// `false` as soon as a later node that lists one of the outputs among its
/// `inputs` has an op for which `should_use_f16` is false. Otherwise returns
/// `true`, including when no later node consumes any output.
pub fn next_consumers_all_f16(
    node_outputs: &[String],
    all_nodes: &[crate::graph::Node],
    current_node_idx: usize,
) -> bool {
    // Fresh iterator over the nodes that come after the current one.
    let downstream = || all_nodes.iter().skip(current_node_idx + 1);

    node_outputs
        .iter()
        .filter(|name| !name.is_empty())
        .all(|name| {
            downstream()
                .filter(|node| node.inputs.contains(name))
                .all(|node| should_use_f16(node.op.as_str()))
        })
}