runmat-accelerate 0.4.4

Pluggable GPU acceleration layer for RunMat (CUDA, ROCm, Metal, Vulkan/SPIR-V)
Documentation
use anyhow::{anyhow, ensure, Result};
use runmat_accelerate_api::{SortComparison, SortOrder, SortRowsColumnSpec};
use std::cmp::Ordering;

/// Result of a host-side `sortrows` evaluation, as produced by `sort_rows_host`.
#[derive(Debug, Clone, PartialEq)]
pub struct SortRowsHostOutputs {
    /// Matrix data with its rows permuted into sorted order, stored in the
    /// same flat layout and with the same length as the input data.
    pub values: Vec<f64>,
    /// One-based index of the source row for every output row, encoded as `f64`.
    pub indices: Vec<f64>,
    /// Shape of `indices`; populated as `[rows, 1]` (a column vector).
    pub indices_shape: Vec<usize>,
}

/// Sorts the rows of a column-major matrix on the host.
///
/// `data` holds the matrix in column-major order (element `(r, c)` lives at
/// `r + c * rows`) and `shape` gives its dimensions. `columns` lists the sort
/// keys in priority order and `comparison` selects the comparison method that
/// is forwarded to the per-element comparator.
///
/// Returns the row-permuted data together with the one-based source row index
/// of every output row. The sort is stable, so rows that compare equal keep
/// their original relative order. When there is nothing to sort (at most one
/// row, no columns, no data, or no sort keys) the input is returned unchanged
/// with identity indices.
///
/// # Errors
///
/// Fails when `shape` is not effectively two-dimensional, when the product of
/// the dimensions overflows `usize`, or when `data.len()` disagrees with
/// `shape`.
pub fn sort_rows_host(
    data: &[f64],
    shape: &[usize],
    columns: &[SortRowsColumnSpec],
    comparison: SortComparison,
) -> Result<SortRowsHostOutputs> {
    ensure!(
        is_matrix_shape(shape),
        "sortrows: input must be a 2-D matrix on the provider path"
    );

    let (rows, cols) = rows_cols_for_shape(shape);
    let element_count = rows
        .checked_mul(cols)
        .ok_or_else(|| anyhow!("sortrows: dimension product exceeds limits"))?;
    ensure!(
        element_count == data.len(),
        "sortrows: tensor data length {} does not match shape {:?}",
        data.len(),
        shape
    );

    // Degenerate inputs cannot be reordered; hand them back untouched.
    let nothing_to_sort = rows <= 1 || cols == 0 || data.is_empty() || columns.is_empty();
    if nothing_to_sort {
        return Ok(SortRowsHostOutputs {
            values: data.to_vec(),
            indices: identity_indices(rows),
            indices_shape: vec![rows, 1],
        });
    }

    // `permutation[dest]` is the source row that ends up at row `dest`.
    let mut permutation: Vec<usize> = (0..rows).collect();
    permutation.sort_by(|&lhs, &rhs| compare_rows(data, rows, cols, columns, comparison, lhs, rhs));

    // Each column is a contiguous run of `rows` elements, so gather the
    // permuted rows one column at a time.
    let mut values = Vec::with_capacity(data.len());
    for column in data.chunks_exact(rows) {
        values.extend(permutation.iter().map(|&src_row| column[src_row]));
    }

    let indices = permutation
        .into_iter()
        .map(|src_row| (src_row + 1) as f64)
        .collect();

    Ok(SortRowsHostOutputs {
        values,
        indices,
        indices_shape: vec![rows, 1],
    })
}

/// Reports whether `shape` describes at most a 2-D matrix.
///
/// Shapes with zero, one, or two dimensions always qualify; longer shapes
/// qualify only when every dimension past the second equals 1.
fn is_matrix_shape(shape: &[usize]) -> bool {
    // With two or fewer dimensions nothing is left after the skip, and `all`
    // over an empty iterator is true.
    shape.iter().skip(2).all(|&extent| extent == 1)
}

/// Extracts `(rows, cols)` from a shape slice.
///
/// An empty shape is treated as `1 x 1`. A single-dimension shape becomes a
/// column vector whose row count is clamped to at least 1. For longer shapes
/// the first two dimensions are returned and any trailing ones are ignored.
fn rows_cols_for_shape(shape: &[usize]) -> (usize, usize) {
    match *shape {
        [] => (1, 1),
        [len] => (len.max(1), 1),
        [rows, cols, ..] => (rows, cols),
    }
}

/// Orders rows `a` and `b` of a column-major matrix by the given key columns.
///
/// Keys are consulted in the order they appear in `columns`; the first key
/// whose elements differ decides the result. Keys whose `index` is not below
/// `cols` are skipped. If every usable key ties, the rows compare as equal.
fn compare_rows(
    data: &[f64],
    rows: usize,
    cols: usize,
    columns: &[SortRowsColumnSpec],
    comparison: SortComparison,
    a: usize,
    b: usize,
) -> Ordering {
    columns
        .iter()
        .filter(|spec| spec.index < cols)
        .map(|spec| {
            // Offset of the key column's first element in the flat buffer.
            let column_start = spec.index * rows;
            let lhs = data[a + column_start];
            let rhs = data[b + column_start];
            compare_scalar(lhs, rhs, spec.order, comparison)
        })
        // The chain is lazy, so later keys are only evaluated after a tie.
        .find(|outcome| *outcome != Ordering::Equal)
        .unwrap_or(Ordering::Equal)
}

/// Compares two elements, giving NaN a fixed position relative to numbers.
///
/// Two NaNs tie. A single NaN sorts after any number when ascending and
/// before any number when descending. Non-NaN pairs are delegated to
/// `compare_finite`.
fn compare_scalar(a: f64, b: f64, order: SortOrder, comparison: SortComparison) -> Ordering {
    let a_nan = a.is_nan();
    let b_nan = b.is_nan();

    if !a_nan && !b_nan {
        return compare_finite(a, b, order, comparison);
    }
    if a_nan && b_nan {
        return Ordering::Equal;
    }

    // Exactly one operand is NaN; in ascending order the NaN goes last.
    let ascending = if a_nan {
        Ordering::Greater
    } else {
        Ordering::Less
    };
    match order {
        SortOrder::Ascend => ascending,
        SortOrder::Descend => ascending.reverse(),
    }
}

/// Compares two values according to `order` and `comparison`.
///
/// With `SortComparison::Abs` the magnitudes are compared first and the signed
/// values only break magnitude ties; with any other method the signed values
/// are compared directly. An incomparable pair (which would require a NaN
/// operand) is treated as equal.
fn compare_finite(a: f64, b: f64, order: SortOrder, comparison: SortComparison) -> Ordering {
    let by_value = a.partial_cmp(&b).unwrap_or(Ordering::Equal);

    // Work out the ascending answer once, then flip it for descending order.
    let ascending = if matches!(comparison, SortComparison::Abs) {
        match a.abs().partial_cmp(&b.abs()).unwrap_or(Ordering::Equal) {
            Ordering::Equal => by_value,
            by_magnitude => by_magnitude,
        }
    } else {
        by_value
    };

    match order {
        SortOrder::Ascend => ascending,
        SortOrder::Descend => ascending.reverse(),
    }
}

/// Builds the one-based identity permutation `1.0, 2.0, ..., rows` as `f64` values.
fn identity_indices(rows: usize) -> Vec<f64> {
    (1..=rows).map(|position| position as f64).collect()
}