// zyx 0.15.6
//
// Zyx machine learning library
// Documentation
// Copyright (C) 2025 zk4x
// SPDX-License-Identifier: LGPL-3.0-only
//
// Auto-generated by regression.py. 15 features, Huber.
//! Cost prediction for kernel selection.
//!
//! This module provides a learned cost model for predicting kernel
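//! execution time. The cost model is trained using regression analysis
//! on measured kernel execution times. The current coefficients are
//! applied to features derived from:
//!
//! - Number of work groups and work items per group
//! - Instruction counts (total ops, compute ops) and barrier count
//! - Global and local memory traffic (load/store bits)
//! - Register load traffic relative to global memory traffic
//! - Warp size relative to work-group size
//!
//! Other kernel characteristics passed to the predictor (branches, loop
//! depth, strides, peak register usage, launch dimensions) are currently
//! not used by the model.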
//!
//! The cost model is used during autotuning to select the best
//! kernel configuration based on predicted performance.

#![allow(unused)]
#![allow(clippy::decimal_literal_representation)]
use super::cost::Cost;
impl Cost {
    pub(crate) fn predict_time_us(
        num_groups: u32,
        wi_per_group: u32,
        wi_ops: u32,
        wi_compute_ops: u32,
        wi_barriers: u32,
        wi_global_load_bits: u32,
        wi_global_store_bits: u32,
        wi_local_load_bits: u32,
        wi_local_store_bits: u32,
        wi_peak_reg_bytes: u32,
        wi_branches: u32,
        wi_global_load_lidx_stride: u32,
        wi_global_store_lidx_stride: u32,
        wi_local_load_lidx_stride: u32,
        wi_local_store_lidx_stride: u32,
        warp_size: u32,
        max_local_threads: u32,
        max_register_bytes: u32,
        wi_register_load_bits: u32,
        wi_register_store_bits: u32,
        gws0: u32,
        gws1: u32,
        gws2: u32,
        lws0: u32,
        lws1: u32,
        lws2: u32,
        max_loop_depth: u32,
        preferred_vector_size: u32,
        local_mem_size: u32,
    ) -> f64 {
        let num_groups = num_groups as f64;
        let wi_per_group = wi_per_group as f64;
        let wi_ops = wi_ops as f64;
        let wi_compute_ops = wi_compute_ops as f64;
        let wi_barriers = wi_barriers as f64;
        let wi_global_load_bits = wi_global_load_bits as f64;
        let wi_global_store_bits = wi_global_store_bits as f64;
        let wi_local_load_bits = wi_local_load_bits as f64;
        let wi_local_store_bits = wi_local_store_bits as f64;
        let wi_peak_reg_bytes = wi_peak_reg_bytes as f64;
        let wi_branches = wi_branches as f64;
        let wi_global_load_lidx_stride = wi_global_load_lidx_stride as f64;
        let wi_global_store_lidx_stride = wi_global_store_lidx_stride as f64;
        let wi_local_load_lidx_stride = wi_local_load_lidx_stride as f64;
        let wi_local_store_lidx_stride = wi_local_store_lidx_stride as f64;
        let warp_size = warp_size as f64;
        let max_local_threads = max_local_threads as f64;
        let max_register_bytes = max_register_bytes as f64;
        let wi_register_load_bits = wi_register_load_bits as f64;
        let wi_register_store_bits = wi_register_store_bits as f64;
        let gws0 = gws0 as f64;
        let gws1 = gws1 as f64;
        let gws2 = gws2 as f64;
        let lws0 = lws0 as f64;
        let lws1 = lws1 as f64;
        let lws2 = lws2 as f64;
        let max_loop_depth = max_loop_depth as f64;
        let preferred_vector_size = preferred_vector_size as f64;
        let local_mem_size = local_mem_size as f64;
        let lng = num_groups.ln();
        let lwpg = (wi_per_group + 1.0).ln();
        let lops = wi_ops.ln();
        let lcop = wi_compute_ops.ln();
        let lgmem = (wi_global_load_bits + wi_global_store_bits + 1.0).ln();
        let ci = wi_compute_ops / (wi_global_load_bits + wi_global_store_bits).max(1.0);
        let barr = wi_barriers;
        let wr = wi_per_group / warp_size.max(1.0);
        let rr = wi_peak_reg_bytes / max_register_bytes.max(1.0);

        let features: [f64; 15] = [
            lwpg,
            ci,
            (wi_ops / num_groups.max(1.0) + 1.0).ln(),
            (wi_ops / (num_groups * wi_per_group).max(1.0)).ln_1p(),
            (1.0 / (num_groups * wi_per_group).max(1.0)).ln_1p(),
            (if wi_barriers == 7.0 { 1.0 } else { 0.0 }) * lops,
            lng * (warp_size / wi_per_group.max(1.0)).ln(),
            ((num_groups * wi_per_group) / 2048.0).min(1.0),
            (wi_local_load_bits + wi_local_store_bits) / 65536.0,
            (warp_size - wi_per_group) / warp_size,
            (num_groups / wi_ops.max(1.0) + 1.0).ln(),
            ((num_groups * wi_per_group) / wi_ops.max(1.0) + 1.0).ln(),
            wi_register_load_bits / (wi_global_load_bits + wi_global_store_bits).max(1.0),
            ci * lgmem,
            ci * lwpg,
        ];

        const RC: &[f64] = &[
            6.85046493e-02,
            -1.78014871e-02,
            2.42729311e-02,
            4.03479574e-02,
            9.17336507e-03,
            -2.55729374e-02,
            1.78794261e-02,
            9.98035625e-03,
            -1.17094612e-03,
            -7.73214011e-03,
            -5.15075983e-02,
            6.13613143e-02,
            -2.87523819e-02,
            -5.14433895e-02,
            -5.10241876e-02,
        ];
        let ri: f64 = -9.47892805e-03;

        let mut pred = ri;
        for i in 0..15 {
            pred += RC[i] * features[i];
        }
        pred * 1_000_000.0
    }
}