pub use nvidia_impl::{try_nvidia_kmeans, try_nvidia_search_batch};
pub use rocm_impl::{try_rocm_kmeans, try_rocm_search_batch};
mod nvidia_impl {
use std::ffi::c_void;
use ailake_core::{RowId, VectorMetric};
use libloading::{Library, Symbol};
use tracing::warn;
/// `cudaMemcpyKind` value used for host-to-device copies (`cudaMemcpyHostToDevice`).
const H2D: i32 = 1;
/// `cudaMemcpyKind` value used for device-to-host copies (`cudaMemcpyDeviceToHost`).
const D2H: i32 = 2;
/// `cublasOperation_t` value requesting the transposed operand (`CUBLAS_OP_T`).
const OP_T: i32 = 1;
/// `cublasOperation_t` value requesting the operand as stored (`CUBLAS_OP_N`).
const OP_N: i32 = 0;

/// Shape of the `cudaMalloc` symbol: writes the new device pointer through
/// `dev_ptr` and returns a status code (this module treats `0` as success).
type CudaMallocFn = unsafe extern "C" fn(dev_ptr: *mut *mut c_void, size: usize) -> i32;
/// Shape of the `cudaFree` symbol.
type CudaFreeFn = unsafe extern "C" fn(dev_ptr: *mut c_void) -> i32;
/// Shape of the `cudaMemcpy` symbol; `kind` is [`H2D`] or [`D2H`] in this module.
type CudaMemcpyFn =
    unsafe extern "C" fn(dst: *mut c_void, src: *const c_void, count: usize, kind: i32) -> i32;
/// Shape of the `cudaDeviceSynchronize` symbol.
type CudaSyncFn = unsafe extern "C" fn() -> i32;
// Candidate file names for the CUDA runtime library. They are handed to
// `try_open`, which tries them in order and keeps the first one that loads.
#[cfg(target_os = "linux")]
const RT_LIBS: &[&str] = &["libcudart.so", "libcudart.so.12", "libcudart.so.11"];
#[cfg(windows)]
const RT_LIBS: &[&str] = &["cudart64_12.dll", "cudart64_11.dll"];
// No candidates on other platforms: the public entry points check for an
// empty list and return `None` without attempting to load anything.
#[cfg(not(any(target_os = "linux", windows)))]
const RT_LIBS: &[&str] = &[];
// Candidate file names for the cuBLAS library, resolved the same way.
#[cfg(target_os = "linux")]
const BLAS_LIBS: &[&str] = &["libcublas.so", "libcublas.so.12", "libcublas.so.11"];
#[cfg(windows)]
const BLAS_LIBS: &[&str] = &["cublas64_12.dll", "cublas64_11.dll"];
// Empty on unsupported platforms (see the note on `RT_LIBS`).
#[cfg(not(any(target_os = "linux", windows)))]
const BLAS_LIBS: &[&str] = &[];
/// Shape of the `cublasSgemm_v2` symbol as this module calls it.
///
/// Parameter names follow the BLAS GEMM convention
/// (`C = alpha * op(A) * op(B) + beta * C`); `alpha` and `beta` are passed by
/// pointer, and the matrices are opaque device pointers.
type SgemmFn = unsafe extern "C" fn(
    handle: *mut c_void,
    transa: i32,
    transb: i32,
    m: i32,
    n: i32,
    k: i32,
    alpha: *const f32,
    a: *const c_void,
    lda: i32,
    b: *const c_void,
    ldb: i32,
    beta: *const f32,
    c: *mut c_void,
    ldc: i32,
) -> i32;
/// Owning guard for one raw device allocation: the pointer is released with
/// `free_fn` when the guard goes out of scope.
struct DevBuf {
    /// Device pointer obtained from the runtime's allocator; may be null.
    ptr: *mut c_void,
    /// Deallocation entry point to call on `ptr`.
    free_fn: CudaFreeFn,
}

impl Drop for DevBuf {
    /// Frees the allocation unless the pointer is null. The status code
    /// returned by the deallocator is ignored.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: `ptr` is non-null and was paired with `free_fn` by whoever
        // built this guard; the guard is dropped at most once.
        let _status = unsafe { (self.free_fn)(self.ptr) };
    }
}
/// Owning guard for a BLAS library handle: `destroy_fn` is invoked on the
/// handle when the guard is dropped.
struct BlasHandle {
    /// Opaque handle produced by the library's create call; may be null.
    handle: *mut c_void,
    /// Destructor entry point matching `handle`.
    destroy_fn: unsafe extern "C" fn(*mut c_void) -> i32,
}

impl Drop for BlasHandle {
    /// Destroys the handle unless it is null. The returned status is ignored.
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        // SAFETY: `handle` is non-null and was paired with `destroy_fn` by
        // whoever built this guard; the guard is dropped at most once.
        let _status = unsafe { (self.destroy_fn)(self.handle) };
    }
}
/// Attempts to load each library name in order and returns the first one that
/// loads, or `None` when every attempt fails (or `names` is empty).
fn try_open(names: &[&str]) -> Option<Library> {
    for &candidate in names {
        // SAFETY: loading a shared library runs its initialisers; the names
        // passed here are the fixed vendor library names declared in this module.
        if let Ok(lib) = unsafe { Library::new(candidate) } {
            return Some(lib);
        }
    }
    None
}
/// Brute-force batch search on an NVIDIA GPU.
///
/// Returns `None` when CUDA is not detected, when this platform has no
/// candidate library names, or when the GPU path fails (a warning is logged in
/// the last case) so the caller can fall back to another implementation.
/// With no queries or no rows the result is one empty hit list per query.
pub fn try_nvidia_search_batch(
    queries: &[&[f32]],
    row_ids: &[u64],
    flat_vecs: &[f32],
    dim: usize,
    metric: VectorMetric,
    top_k: usize,
) -> Option<Vec<Vec<(RowId, f32)>>> {
    let gpu_usable =
        crate::hardware::detect_cuda() && !RT_LIBS.is_empty() && !BLAS_LIBS.is_empty();
    if !gpu_usable {
        return None;
    }

    let query_count = queries.len();
    if query_count == 0 || row_ids.is_empty() {
        return Some(vec![vec![]; query_count]);
    }

    match batch_inner(queries, row_ids, flat_vecs, dim, metric, top_k) {
        Some(hits) => Some(hits),
        None => {
            warn!(
                "ailake: NVIDIA GPU search failed at runtime (cuBLAS error or allocation failure); \
                 falling back to CPU SIMD — check CUDA runtime libraries and available GPU memory"
            );
            None
        }
    }
}
/// Runs k-means on an NVIDIA GPU and returns the centroids.
///
/// `k` is clamped to the number of input vectors, and the dimensionality is
/// taken from `vectors[0]`.
///
/// Returns `None` when CUDA is not detected, when this platform has no
/// candidate library names, or when the GPU path fails (a warning is logged in
/// the last case) so the caller can fall back to the CPU implementation.
/// Returns `Some(vec![])` when there are no input vectors or `k == 0`.
pub fn try_nvidia_kmeans(
    vectors: &[Vec<f32>],
    k: usize,
    max_iter: usize,
) -> Option<Vec<Vec<f32>>> {
    if !crate::hardware::detect_cuda() {
        return None;
    }
    if RT_LIBS.is_empty() || BLAS_LIBS.is_empty() {
        return None;
    }
    // `k == 0` must stop here: the seeding step downstream divides by `k`.
    if vectors.is_empty() || k == 0 {
        return Some(vec![]);
    }
    let n = vectors.len();
    let dim = vectors[0].len();
    let k = k.min(n);
    let result = kmeans_inner(vectors, k, max_iter, n, dim);
    if result.is_none() {
        warn!(
            "ailake: NVIDIA GPU k-means failed at runtime (cuBLAS error or allocation failure); \
             falling back to CPU k-means — check CUDA runtime libraries and available GPU memory"
        );
    }
    result
}
/// Resolves `cudaMalloc`, `cudaFree`, `cudaMemcpy` and `cudaDeviceSynchronize`
/// from `rt`, returning them as plain function pointers (in that order), or
/// `None` if any symbol is missing.
///
/// # Safety
///
/// `rt` must be a CUDA runtime library whose symbols really have the declared
/// signatures, and the returned pointers must not be called after `rt` has
/// been dropped.
unsafe fn load_cuda_fns(
    rt: &Library,
) -> Option<(CudaMallocFn, CudaFreeFn, CudaMemcpyFn, CudaSyncFn)> {
    let malloc_fn = *rt.get::<CudaMallocFn>(b"cudaMalloc\0").ok()?;
    let free_fn = *rt.get::<CudaFreeFn>(b"cudaFree\0").ok()?;
    let memcpy_fn = *rt.get::<CudaMemcpyFn>(b"cudaMemcpy\0").ok()?;
    let sync_fn = *rt.get::<CudaSyncFn>(b"cudaDeviceSynchronize\0").ok()?;
    Some((malloc_fn, free_fn, memcpy_fn, sync_fn))
}
/// Allocates a device buffer the size of `data` and copies `data` into it.
///
/// Returns `None` if either the allocation or the copy reports a non-zero
/// status; a buffer whose copy failed is freed before returning.
///
/// # Safety
///
/// The three function pointers must be the matching allocate / free / copy
/// entry points of a runtime library that is still loaded.
unsafe fn upload(
    data: &[f32],
    malloc_fn: CudaMallocFn,
    free_fn: CudaFreeFn,
    memcpy_fn: CudaMemcpyFn,
) -> Option<DevBuf> {
    let byte_len = std::mem::size_of_val(data);
    let mut dev_ptr = std::ptr::null_mut::<c_void>();
    let alloc_status = malloc_fn(&mut dev_ptr, byte_len);
    if alloc_status != 0 {
        return None;
    }
    // Hand ownership to the guard before copying so a failed copy still frees.
    let owned = DevBuf {
        ptr: dev_ptr,
        free_fn,
    };
    let copy_status = memcpy_fn(owned.ptr, data.as_ptr() as *const c_void, byte_len, H2D);
    if copy_status == 0 {
        Some(owned)
    } else {
        None
    }
}
/// Allocates an uninitialised device buffer holding `len` `f32` values.
///
/// Returns `None` when the byte size does not fit in `usize` or when the
/// allocator reports a non-zero status.
///
/// # Safety
///
/// `malloc_fn` and `free_fn` must be the matching allocate / free entry points
/// of a runtime library that is still loaded.
unsafe fn alloc_dev(
    len: usize,
    malloc_fn: CudaMallocFn,
    free_fn: CudaFreeFn,
) -> Option<DevBuf> {
    // A wrapped product would request a buffer smaller than the caller will
    // write to, so overflow is treated as an allocation failure.
    let bytes = len.checked_mul(std::mem::size_of::<f32>())?;
    let mut ptr: *mut c_void = std::ptr::null_mut();
    if malloc_fn(&mut ptr, bytes) != 0 {
        return None;
    }
    Some(DevBuf { ptr, free_fn })
}
/// Scales every `dim`-long row of `data` to unit L2 norm, in place, and
/// returns the buffer.
///
/// Rows whose norm is at most `1e-8` are left untouched to avoid dividing by
/// (near) zero. If `data.len()` is not a multiple of `dim`, the shorter
/// trailing chunk is normalised as a row of its own. `dim == 0` returns
/// `data` unchanged instead of panicking.
fn normalize_rows(mut data: Vec<f32>, dim: usize) -> Vec<f32> {
    if dim == 0 {
        // `chunks_mut(0)` panics; there are no rows to normalise anyway.
        return data;
    }
    for row in data.chunks_mut(dim) {
        let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for x in row.iter_mut() {
                *x /= norm;
            }
        }
    }
    data
}
/// Runs one brute-force batch search on the GPU via cuBLAS SGEMM.
///
/// The database (`flat_vecs`, `row_ids.len()` rows of `dim` floats) and the
/// queries are uploaded, a single SGEMM produces the scaled cross products
/// for every (row, query) pair, and the host turns them into distances:
///
/// * `DotProduct`: `-dot`
/// * `Cosine` (inputs are normalised here) / `NormalizedCosine`: `1 - dot`
/// * `Euclidean`: `sqrt(max(0, |q|² + |v|² - 2·dot))`
///
/// Each query yields its `top_k` smallest distances in ascending order.
///
/// Returns `None` when the inputs do not have the expected shape (`dim == 0`,
/// a query whose length is not `dim`, `flat_vecs` shorter than
/// `row_ids.len() * dim`, or extents that do not fit the BLAS `i32`
/// interface), when a library or symbol cannot be loaded, or when any
/// CUDA/cuBLAS call reports a non-zero status.
fn batch_inner(
    queries: &[&[f32]],
    row_ids: &[u64],
    flat_vecs: &[f32],
    dim: usize,
    metric: VectorMetric,
    top_k: usize,
) -> Option<Vec<Vec<(RowId, f32)>>> {
    /// Converts a matrix extent to the 32-bit integer the BLAS ABI takes.
    fn blas_int(extent: usize) -> Option<i32> {
        <i32 as std::convert::TryFrom<usize>>::try_from(extent).ok()
    }

    let n = row_ids.len();
    let q = queries.len();
    if n == 0 || q == 0 {
        return Some(vec![Vec::new(); q]);
    }

    // Validate every shape before it reaches the FFI boundary: SGEMM trusts
    // the extents it is given and would read past the device buffers.
    if dim == 0 || queries.iter().any(|query| query.len() != dim) {
        return None;
    }
    let flat_vecs = flat_vecs.get(..n.checked_mul(dim)?)?;
    let out_len = n.checked_mul(q)?;
    let out_bytes = out_len.checked_mul(std::mem::size_of::<f32>())?;
    let (n_i32, q_i32, dim_i32) = (blas_int(n)?, blas_int(q)?, blas_int(dim)?);

    // The libraries are declared first so that they are dropped last, after
    // every guard that still needs to call into them.
    let rt = try_open(RT_LIBS)?;
    let blas_lib = try_open(BLAS_LIBS)?;
    // SAFETY: `rt` was opened from the CUDA runtime candidate names and
    // outlives every use of the returned function pointers.
    let (cuda_malloc, cuda_free, cuda_memcpy, cuda_sync) = unsafe { load_cuda_fns(&rt) }?;
    // SAFETY: the symbol types below are the signatures this module assumes
    // for the cuBLAS entry points of the same name.
    let blas_create: Symbol<unsafe extern "C" fn(*mut *mut c_void) -> i32> =
        unsafe { blas_lib.get(b"cublasCreate_v2\0") }.ok()?;
    let blas_destroy: unsafe extern "C" fn(*mut c_void) -> i32 = *unsafe {
        blas_lib.get::<unsafe extern "C" fn(*mut c_void) -> i32>(b"cublasDestroy_v2\0")
    }
    .ok()?;
    let sgemm: Symbol<SgemmFn> = unsafe { blas_lib.get(b"cublasSgemm_v2\0") }.ok()?;

    let mut raw_handle: *mut c_void = std::ptr::null_mut();
    // SAFETY: `raw_handle` is a valid out-pointer for the create call.
    if unsafe { blas_create(&mut raw_handle) } != 0 {
        return None;
    }
    let _blas = BlasHandle {
        handle: raw_handle,
        destroy_fn: blas_destroy,
    };

    // Host-side operands. Only `Cosine` needs normalised copies.
    let is_cosine = matches!(metric, VectorMetric::Cosine);
    let q_flat: Vec<f32> = queries
        .iter()
        .flat_map(|query| query.iter().copied())
        .collect();
    let q_flat = if is_cosine {
        normalize_rows(q_flat, dim)
    } else {
        q_flat
    };
    let db_normalized: Option<Vec<f32>> = if is_cosine {
        Some(normalize_rows(flat_vecs.to_vec(), dim))
    } else {
        None
    };
    let db_data: &[f32] = db_normalized.as_deref().unwrap_or(flat_vecs);

    // SAFETY: the function pointers come from `rt`, which is still loaded.
    let db_dev = unsafe { upload(db_data, cuda_malloc, cuda_free, cuda_memcpy) }?;
    let q_dev = unsafe { upload(&q_flat, cuda_malloc, cuda_free, cuda_memcpy) }?;
    let c_dev = unsafe { alloc_dev(out_len, cuda_malloc, cuda_free) }?;

    // The scale folds the sign (and the factor 2 of the Euclidean expansion)
    // into the product so the host only has to add.
    let alpha: f32 = match metric {
        VectorMetric::Euclidean => -2.0,
        VectorMetric::DotProduct | VectorMetric::NormalizedCosine | VectorMetric::Cosine => -1.0,
    };
    let beta = 0.0f32;

    // Row-major `rows x dim` data is a column-major `dim x rows` matrix, so
    // `op(A) = Aᵀ` is `n x dim`, `B` is `dim x q`, and `C` is `n x q` with
    // element (row, query) at `row + query * n`.
    // SAFETY: `db_dev`, `q_dev` and `c_dev` hold exactly `n*dim`, `q*dim` and
    // `n*q` floats (validated above), matching the extents passed here.
    let rc = unsafe {
        sgemm(
            raw_handle,
            OP_T,
            OP_N,
            n_i32,
            q_i32,
            dim_i32,
            &alpha,
            db_dev.ptr as *const c_void,
            dim_i32,
            q_dev.ptr as *const c_void,
            dim_i32,
            &beta,
            c_dev.ptr,
            n_i32,
        )
    };
    if rc != 0 {
        return None;
    }
    // SAFETY: plain runtime call with no arguments.
    if unsafe { cuda_sync() } != 0 {
        return None;
    }

    let mut c_host = vec![0.0f32; out_len];
    // SAFETY: `c_host` and `c_dev` both span `out_bytes` bytes.
    let copy_rc = unsafe {
        cuda_memcpy(
            c_host.as_mut_ptr() as *mut c_void,
            c_dev.ptr as *const c_void,
            out_bytes,
            D2H,
        )
    };
    if copy_rc != 0 {
        return None;
    }

    // Squared norms of the database rows, needed only for Euclidean distance.
    let is_euclidean = matches!(metric, VectorMetric::Euclidean);
    let db_sq: Vec<f32> = if is_euclidean {
        flat_vecs
            .chunks_exact(dim)
            .map(|row| row.iter().map(|x| x * x).sum::<f32>())
            .collect()
    } else {
        Vec::new()
    };

    // `total_cmp` is a total order even with NaN distances, which the
    // standard sorting routines require.
    let by_distance = |a: &(usize, f32), b: &(usize, f32)| a.1.total_cmp(&b.1);

    let results: Vec<Vec<(RowId, f32)>> = queries
        .iter()
        .zip(c_host.chunks_exact(n))
        .map(|(query, column)| {
            // The query norm is the same for every row, so compute it once.
            let q_sq: f32 = if is_euclidean {
                query.iter().map(|x| x * x).sum()
            } else {
                0.0
            };
            let mut ranked: Vec<(usize, f32)> = column
                .iter()
                .enumerate()
                .map(|(ni, &raw)| {
                    let dist = match metric {
                        VectorMetric::DotProduct => raw,
                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => 1.0 + raw,
                        VectorMetric::Euclidean => (q_sq + db_sq[ni] + raw).max(0.0).sqrt(),
                    };
                    (ni, dist)
                })
                .collect();
            // Partition out the `top_k` nearest first so only they get sorted.
            if top_k < ranked.len() {
                ranked.select_nth_unstable_by(top_k, by_distance);
                ranked.truncate(top_k);
            }
            ranked.sort_unstable_by(by_distance);
            ranked
                .into_iter()
                .map(|(ni, dist)| (RowId::new(row_ids[ni]), dist))
                .collect()
        })
        .collect();
    Some(results)
}
/// Lloyd's k-means with the assignment step's cross products computed by
/// cuBLAS SGEMM.
///
/// Centroids are seeded from evenly spaced input vectors (`vectors[i * (n / k)]`).
/// Each iteration uploads the centroids, computes `-2 · centroidᵀ · x` on the
/// device, assigns every vector to the centroid minimising
/// `|x|² + |c|² - 2·x·c` on the host, and recomputes the means. A cluster
/// that receives no vectors keeps its previous centroid. Iteration stops after
/// `max_iter` rounds or as soon as the assignment stops changing.
///
/// Returns `Some(vec![])` when `k == 0` or `n == 0`. Returns `None` when the
/// input shape is unusable (`dim == 0`, `n != vectors.len()`, a vector whose
/// length is not `dim`, extents beyond the BLAS `i32` interface), when a
/// library or symbol cannot be loaded, or when any CUDA/cuBLAS call reports a
/// non-zero status.
fn kmeans_inner(
    vectors: &[Vec<f32>],
    k: usize,
    max_iter: usize,
    n: usize,
    dim: usize,
) -> Option<Vec<Vec<f32>>> {
    /// Converts a matrix extent to the 32-bit integer the BLAS ABI takes.
    fn blas_int(extent: usize) -> Option<i32> {
        <i32 as std::convert::TryFrom<usize>>::try_from(extent).ok()
    }

    // Zero clusters (or zero points) has a trivial answer and would otherwise
    // divide by zero while seeding.
    if k == 0 || n == 0 {
        return Some(Vec::new());
    }
    // SGEMM trusts the extents it is given, so ragged input must not reach it.
    if dim == 0 || vectors.len() != n || vectors.iter().any(|v| v.len() != dim) {
        return None;
    }
    let cross_len = k.checked_mul(n)?;
    let cross_bytes = cross_len.checked_mul(std::mem::size_of::<f32>())?;
    let centroid_len = k.checked_mul(dim)?;
    let (k_i32, n_i32, dim_i32) = (blas_int(k)?, blas_int(n)?, blas_int(dim)?);

    // The libraries are declared first so that they are dropped last, after
    // every guard that still needs to call into them.
    let rt = try_open(RT_LIBS)?;
    let blas_lib = try_open(BLAS_LIBS)?;
    // SAFETY: `rt` was opened from the CUDA runtime candidate names and
    // outlives every use of the returned function pointers.
    let (cuda_malloc, cuda_free, cuda_memcpy, cuda_sync) = unsafe { load_cuda_fns(&rt) }?;
    // SAFETY: the symbol types below are the signatures this module assumes
    // for the cuBLAS entry points of the same name.
    let blas_create: Symbol<unsafe extern "C" fn(*mut *mut c_void) -> i32> =
        unsafe { blas_lib.get(b"cublasCreate_v2\0") }.ok()?;
    let blas_destroy: unsafe extern "C" fn(*mut c_void) -> i32 = *unsafe {
        blas_lib.get::<unsafe extern "C" fn(*mut c_void) -> i32>(b"cublasDestroy_v2\0")
    }
    .ok()?;
    let sgemm: Symbol<SgemmFn> = unsafe { blas_lib.get(b"cublasSgemm_v2\0") }.ok()?;

    let mut raw_handle: *mut c_void = std::ptr::null_mut();
    // SAFETY: `raw_handle` is a valid out-pointer for the create call.
    if unsafe { blas_create(&mut raw_handle) } != 0 {
        return None;
    }
    let _blas = BlasHandle {
        handle: raw_handle,
        destroy_fn: blas_destroy,
    };

    let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
    // SAFETY: the function pointers come from `rt`, which is still loaded.
    let x_dev = unsafe { upload(&flat, cuda_malloc, cuda_free, cuda_memcpy) }?;
    let x_sq: Vec<f32> = vectors
        .iter()
        .map(|v| v.iter().map(|x| x * x).sum())
        .collect();

    let step = n / k;
    let mut centroids_flat: Vec<f32> = (0..k)
        .flat_map(|i| vectors[(i * step) % n].iter().copied())
        .collect();
    let mut prev_asgn: Vec<u32> = vec![];

    // The cross-product scratch space has the same size every round, so it is
    // allocated once and overwritten in full by each SGEMM / copy-back.
    // SAFETY: as for `upload` above.
    let cross_dev = unsafe { alloc_dev(cross_len, cuda_malloc, cuda_free) }?;
    let mut cross_host = vec![0.0f32; cross_len];
    let alpha = -2.0f32;
    let beta = 0.0f32;

    for _ in 0..max_iter {
        // SAFETY: as for `upload` above.
        let c_dev = unsafe { upload(&centroids_flat, cuda_malloc, cuda_free, cuda_memcpy) }?;
        // Result is a column-major `k x n` matrix: entry (centroid, point)
        // lives at `centroid + point * k`.
        // SAFETY: `c_dev`, `x_dev` and `cross_dev` hold exactly `k*dim`,
        // `n*dim` and `k*n` floats, matching the extents passed here.
        let rc = unsafe {
            sgemm(
                raw_handle,
                OP_T,
                OP_N,
                k_i32,
                n_i32,
                dim_i32,
                &alpha,
                c_dev.ptr as *const c_void,
                dim_i32,
                x_dev.ptr as *const c_void,
                dim_i32,
                &beta,
                cross_dev.ptr,
                k_i32,
            )
        };
        if rc != 0 {
            return None;
        }
        // SAFETY: plain runtime call with no arguments.
        if unsafe { cuda_sync() } != 0 {
            return None;
        }
        // SAFETY: `cross_host` and `cross_dev` both span `cross_bytes` bytes.
        let copy_rc = unsafe {
            cuda_memcpy(
                cross_host.as_mut_ptr() as *mut c_void,
                cross_dev.ptr as *const c_void,
                cross_bytes,
                D2H,
            )
        };
        if copy_rc != 0 {
            return None;
        }

        let c_sq: Vec<f32> = centroids_flat
            .chunks_exact(dim)
            .map(|c| c.iter().map(|x| x * x).sum())
            .collect();

        // Assignment step: nearest centroid per point.
        let asgn: Vec<u32> = cross_host
            .chunks_exact(k)
            .zip(&x_sq)
            .map(|(cross, &point_sq)| {
                let dist = |c: usize| point_sq + c_sq[c] + cross[c];
                let best = (0..k)
                    .min_by(|&a, &b| dist(a).total_cmp(&dist(b)))
                    .unwrap_or(0);
                // Lossless: `k` was checked to fit in `i32` above.
                best as u32
            })
            .collect();
        if asgn == prev_asgn {
            break;
        }

        // Update step: mean of the members, or the old centroid if empty.
        let mut new_flat = vec![0.0f32; centroid_len];
        let mut counts = vec![0usize; k];
        for (point, &ci) in vectors.iter().zip(&asgn) {
            let ci = ci as usize;
            let acc = &mut new_flat[ci * dim..(ci + 1) * dim];
            for (sum, &v) in acc.iter_mut().zip(point) {
                *sum += v;
            }
            counts[ci] += 1;
        }
        for ((new_c, old_c), &count) in new_flat
            .chunks_exact_mut(dim)
            .zip(centroids_flat.chunks_exact(dim))
            .zip(&counts)
        {
            if count > 0 {
                let inv = 1.0 / count as f32;
                new_c.iter_mut().for_each(|x| *x *= inv);
            } else {
                new_c.copy_from_slice(old_c);
            }
        }
        centroids_flat = new_flat;
        prev_asgn = asgn;
    }
    Some(centroids_flat.chunks_exact(dim).map(|c| c.to_vec()).collect())
}
}
#[cfg(test)]
mod tests {
use ailake_core::{RowId, VectorMetric};
/// Reads the `AILAKE_GPU_BACKEND` environment variable, defaulting to
/// `"none"` when it is unset or cannot be read as a string.
fn gpu_backend() -> String {
    match std::env::var("AILAKE_GPU_BACKEND") {
        Ok(backend) => backend,
        Err(_) => String::from("none"),
    }
}
/// Builds `n` deterministic test vectors of length `dim`; component `d` of
/// vector `i` is `sin(i * dim + d + 1)`.
fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
    let mut rows = Vec::with_capacity(n);
    for row_idx in 0..n {
        let mut row = Vec::with_capacity(dim);
        for col_idx in 0..dim {
            let seed = row_idx * dim + col_idx + 1;
            row.push((seed as f32).sin());
        }
        rows.push(row);
    }
    rows
}
/// Cosine search for a vector that is itself in the database must rank that
/// vector first with a distance of (almost) zero. Skipped unless
/// `AILAKE_GPU_BACKEND` selects a backend.
#[test]
fn gpu_search_batch_cosine_top1_exact() {
    let backend = gpu_backend();
    if backend == "none" {
        println!("AILAKE_GPU_BACKEND=none — skipping gpu_search_batch_cosine_top1_exact");
        return;
    }

    let dim = 16;
    let corpus = make_vecs(64, dim);
    let flat: Vec<f32> = corpus.iter().flatten().copied().collect();
    let row_ids: Vec<u64> = (0..64).collect();
    // Query with the first database vector.
    let probe = corpus[0].clone();
    let queries: &[&[f32]] = &[probe.as_slice()];

    let outcome = match backend.as_str() {
        "cuda" => {
            super::try_nvidia_search_batch(queries, &row_ids, &flat, dim, VectorMetric::Cosine, 5)
        }
        "rocm" => {
            super::try_rocm_search_batch(queries, &row_ids, &flat, dim, VectorMetric::Cosine, 5)
        }
        other => panic!("unknown AILAKE_GPU_BACKEND={other}"),
    };
    let hits =
        outcome.expect("GPU cosine search returned None — check driver/library installation");

    assert_eq!(hits.len(), 1);
    let (top_row, top_dist) = hits[0][0];
    assert_eq!(top_row, RowId::new(0), "top-1 must be the query itself");
    assert!(
        top_dist < 1e-3,
        "cosine dist to self must be ≈0, got {top_dist}"
    );
}
/// Euclidean search for a vector that is itself in the database must rank
/// that vector first with a distance of (almost) zero. Skipped unless
/// `AILAKE_GPU_BACKEND` selects a backend.
#[test]
fn gpu_search_batch_euclidean_top1_exact() {
    let backend = gpu_backend();
    if backend == "none" {
        println!("AILAKE_GPU_BACKEND=none — skipping gpu_search_batch_euclidean_top1_exact");
        return;
    }

    let dim = 8;
    let corpus = make_vecs(32, dim);
    let flat: Vec<f32> = corpus.iter().flatten().copied().collect();
    let row_ids: Vec<u64> = (0..32).collect();
    // Query with database vector number 7.
    let probe = corpus[7].clone();
    let queries: &[&[f32]] = &[probe.as_slice()];

    let outcome = match backend.as_str() {
        "cuda" => {
            super::try_nvidia_search_batch(queries, &row_ids, &flat, dim, VectorMetric::Euclidean, 3)
        }
        "rocm" => {
            super::try_rocm_search_batch(queries, &row_ids, &flat, dim, VectorMetric::Euclidean, 3)
        }
        other => panic!("unknown AILAKE_GPU_BACKEND={other}"),
    };
    let hits = outcome.expect("GPU euclidean search returned None");

    let (top_row, top_dist) = hits[0][0];
    assert_eq!(top_row, RowId::new(7), "top-1 must be the query itself");
    assert!(
        top_dist < 1e-4,
        "euclidean dist to self must be 0, got {top_dist}"
    );
}
/// GPU k-means over four well-separated groups of ten points must return
/// exactly `k` centroids, each of length `dim`. Skipped unless
/// `AILAKE_GPU_BACKEND` selects a backend.
#[test]
fn gpu_kmeans_returns_k_centroids() {
    let backend = gpu_backend();
    if backend == "none" {
        println!("AILAKE_GPU_BACKEND=none — skipping gpu_kmeans_returns_k_centroids");
        return;
    }
    let dim = 8;
    let k = 4usize;
    // Group `c` sits around `c * 20`, far from every other group.
    let vecs: Vec<Vec<f32>> = (0..k)
        .flat_map(|c| {
            (0..10).map(move |_| {
                (0..dim)
                    .map(|d| c as f32 * 20.0 + d as f32 * 0.01)
                    .collect()
            })
        })
        .collect();
    let centroids = match backend.as_str() {
        "cuda" => super::try_nvidia_kmeans(&vecs, k, 20),
        "rocm" => super::try_rocm_kmeans(&vecs, k, 20),
        other => panic!("unknown AILAKE_GPU_BACKEND={other}"),
    };
    let centroids =
        centroids.expect("GPU k-means returned None — check driver/library installation");
    assert_eq!(
        centroids.len(),
        k,
        "expected {k} centroids, got {}",
        centroids.len()
    );
    for c in &centroids {
        assert_eq!(
            c.len(),
            dim,
            "centroid dim mismatch: expected {dim}, got {}",
            c.len()
        );
    }
}
}
mod rocm_impl {
use std::ffi::c_void;
use ailake_core::{RowId, VectorMetric};
use libloading::{Library, Symbol};
use tracing::warn;
/// `hipMemcpyKind` value used for host-to-device copies (`hipMemcpyHostToDevice`).
const H2D: i32 = 1;
/// `hipMemcpyKind` value used for device-to-host copies (`hipMemcpyDeviceToHost`).
const D2H: i32 = 2;
/// `hipblasOperation_t` value requesting the transposed operand (`HIPBLAS_OP_T`).
const OP_T: i32 = 112;
/// `hipblasOperation_t` value requesting the operand as stored (`HIPBLAS_OP_N`).
const OP_N: i32 = 111;

/// Shape of the `hipMalloc` symbol: writes the new device pointer through
/// `dev_ptr` and returns a status code (this module treats `0` as success).
type HipMallocFn = unsafe extern "C" fn(dev_ptr: *mut *mut c_void, size: usize) -> i32;
/// Shape of the `hipFree` symbol.
type HipFreeFn = unsafe extern "C" fn(dev_ptr: *mut c_void) -> i32;
/// Shape of the `hipMemcpy` symbol; `kind` is [`H2D`] or [`D2H`] in this module.
type HipMemcpyFn =
    unsafe extern "C" fn(dst: *mut c_void, src: *const c_void, count: usize, kind: i32) -> i32;
/// Shape of the `hipDeviceSynchronize` symbol.
type HipSyncFn = unsafe extern "C" fn() -> i32;
// File name of the HIP runtime library loaded with `Library::new`.
#[cfg(target_os = "linux")]
const HIP_LIB: &str = "libamdhip64.so";
#[cfg(windows)]
const HIP_LIB: &str = "amdhip64.dll";
// Empty on other platforms: the public entry points check for an empty name
// and return `None` without attempting to load anything.
#[cfg(not(any(target_os = "linux", windows)))]
const HIP_LIB: &str = "";
// File name of the hipBLAS library, resolved the same way.
#[cfg(target_os = "linux")]
const BLAS_LIB: &str = "libhipblas.so";
#[cfg(windows)]
const BLAS_LIB: &str = "hipblas.dll";
// Empty on unsupported platforms (see the note on `HIP_LIB`).
#[cfg(not(any(target_os = "linux", windows)))]
const BLAS_LIB: &str = "";
/// Shape of the `hipblasSgemm` symbol as this module calls it.
///
/// Parameter names follow the BLAS GEMM convention
/// (`C = alpha * op(A) * op(B) + beta * C`); `alpha` and `beta` are passed by
/// pointer, and the matrices are opaque device pointers.
type SgemmFn = unsafe extern "C" fn(
    handle: *mut c_void,
    transa: i32,
    transb: i32,
    m: i32,
    n: i32,
    k: i32,
    alpha: *const f32,
    a: *const c_void,
    lda: i32,
    b: *const c_void,
    ldb: i32,
    beta: *const f32,
    c: *mut c_void,
    ldc: i32,
) -> i32;
/// Owning guard for one raw device allocation: the pointer is released with
/// `free_fn` when the guard goes out of scope.
struct DevBuf {
    /// Device pointer obtained from the runtime's allocator; may be null.
    ptr: *mut c_void,
    /// Deallocation entry point to call on `ptr`.
    free_fn: unsafe extern "C" fn(*mut c_void) -> i32,
}

impl Drop for DevBuf {
    /// Frees the allocation unless the pointer is null. The status code
    /// returned by the deallocator is ignored.
    fn drop(&mut self) {
        if self.ptr.is_null() {
            return;
        }
        // SAFETY: `ptr` is non-null and was paired with `free_fn` by whoever
        // built this guard; the guard is dropped at most once.
        let _status = unsafe { (self.free_fn)(self.ptr) };
    }
}
/// Owning guard for a BLAS library handle: `destroy_fn` is invoked on the
/// handle when the guard is dropped.
struct BlasHandle {
    /// Opaque handle produced by the library's create call; may be null.
    handle: *mut c_void,
    /// Destructor entry point matching `handle`.
    destroy_fn: unsafe extern "C" fn(*mut c_void) -> i32,
}

impl Drop for BlasHandle {
    /// Destroys the handle unless it is null. The returned status is ignored.
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        // SAFETY: `handle` is non-null and was paired with `destroy_fn` by
        // whoever built this guard; the guard is dropped at most once.
        let _status = unsafe { (self.destroy_fn)(self.handle) };
    }
}
/// Brute-force batch search on an AMD GPU.
///
/// Returns `None` when ROCm is not detected, when this platform has no
/// library names configured, or when the GPU path fails (a warning is logged
/// in the last case) so the caller can fall back to another implementation.
/// With no queries or no rows the result is one empty hit list per query.
pub fn try_rocm_search_batch(
    queries: &[&[f32]],
    row_ids: &[u64],
    flat_vecs: &[f32],
    dim: usize,
    metric: VectorMetric,
    top_k: usize,
) -> Option<Vec<Vec<(RowId, f32)>>> {
    let gpu_usable =
        crate::hardware::detect_rocm() && !HIP_LIB.is_empty() && !BLAS_LIB.is_empty();
    if !gpu_usable {
        return None;
    }

    let query_count = queries.len();
    if row_ids.is_empty() || query_count == 0 {
        return Some(vec![vec![]; query_count]);
    }

    match batch_inner(queries, row_ids, flat_vecs, dim, metric, top_k) {
        Some(hits) => Some(hits),
        None => {
            warn!(
                "ailake: AMD ROCm GPU search failed at runtime (hipBLAS error or allocation failure); \
                 falling back to CPU SIMD — check ROCm runtime libraries and available GPU memory"
            );
            None
        }
    }
}
/// Runs k-means on an AMD GPU and returns the centroids.
///
/// `k` is clamped to the number of input vectors, and the dimensionality is
/// taken from `vectors[0]`.
///
/// Returns `None` when ROCm is not detected, when this platform has no
/// library names configured, or when the GPU path fails (a warning is logged
/// in the last case) so the caller can fall back to the CPU implementation.
/// Returns `Some(vec![])` when there are no input vectors or `k == 0`.
pub fn try_rocm_kmeans(
    vectors: &[Vec<f32>],
    k: usize,
    max_iter: usize,
) -> Option<Vec<Vec<f32>>> {
    if !crate::hardware::detect_rocm() {
        return None;
    }
    if HIP_LIB.is_empty() || BLAS_LIB.is_empty() {
        return None;
    }
    let n = vectors.len();
    // `k == 0` must stop here: the seeding step downstream divides by `k`.
    if n == 0 || k == 0 {
        return Some(vec![]);
    }
    let dim = vectors[0].len();
    let k = k.min(n);
    let result = kmeans_inner(vectors, k, max_iter, n, dim);
    if result.is_none() {
        warn!(
            "ailake: AMD ROCm GPU k-means failed at runtime (hipBLAS error or allocation failure); \
             falling back to CPU k-means — check ROCm runtime libraries and available GPU memory"
        );
    }
    result
}
/// Resolves `hipMalloc`, `hipFree`, `hipMemcpy` and `hipDeviceSynchronize`
/// from `lib`, returning them as plain function pointers (in that order), or
/// `None` if any symbol is missing.
///
/// # Safety
///
/// `lib` must be a HIP runtime library whose symbols really have the declared
/// signatures, and the returned pointers must not be called after `lib` has
/// been dropped.
unsafe fn load_hip_fns(
    lib: &Library,
) -> Option<(HipMallocFn, HipFreeFn, HipMemcpyFn, HipSyncFn)> {
    let malloc_fn = *lib.get::<HipMallocFn>(b"hipMalloc\0").ok()?;
    let free_fn = *lib.get::<HipFreeFn>(b"hipFree\0").ok()?;
    let memcpy_fn = *lib.get::<HipMemcpyFn>(b"hipMemcpy\0").ok()?;
    let sync_fn = *lib.get::<HipSyncFn>(b"hipDeviceSynchronize\0").ok()?;
    Some((malloc_fn, free_fn, memcpy_fn, sync_fn))
}
/// Allocates a device buffer the size of `data` and copies `data` into it.
///
/// Returns `None` if either the allocation or the copy reports a non-zero
/// status; a buffer whose copy failed is freed before returning.
///
/// # Safety
///
/// The three function pointers must be the matching allocate / free / copy
/// entry points of a runtime library that is still loaded.
unsafe fn upload(
    data: &[f32],
    malloc_fn: HipMallocFn,
    free_fn: HipFreeFn,
    memcpy_fn: HipMemcpyFn,
) -> Option<DevBuf> {
    let byte_len = std::mem::size_of_val(data);
    let mut dev_ptr = std::ptr::null_mut::<c_void>();
    let alloc_status = malloc_fn(&mut dev_ptr, byte_len);
    if alloc_status != 0 {
        return None;
    }
    // Hand ownership to the guard before copying so a failed copy still frees.
    let owned = DevBuf {
        ptr: dev_ptr,
        free_fn,
    };
    let copy_status = memcpy_fn(owned.ptr, data.as_ptr() as *const c_void, byte_len, H2D);
    if copy_status == 0 {
        Some(owned)
    } else {
        None
    }
}
/// Allocates an uninitialised device buffer holding `len` `f32` values.
///
/// Returns `None` when the byte size does not fit in `usize` or when the
/// allocator reports a non-zero status.
///
/// # Safety
///
/// `malloc_fn` and `free_fn` must be the matching allocate / free entry points
/// of a runtime library that is still loaded.
unsafe fn alloc_dev(len: usize, malloc_fn: HipMallocFn, free_fn: HipFreeFn) -> Option<DevBuf> {
    // A wrapped product would request a buffer smaller than the caller will
    // write to, so overflow is treated as an allocation failure.
    let bytes = len.checked_mul(std::mem::size_of::<f32>())?;
    let mut ptr: *mut c_void = std::ptr::null_mut();
    if malloc_fn(&mut ptr, bytes) != 0 {
        return None;
    }
    Some(DevBuf { ptr, free_fn })
}
/// Scales every `dim`-long row of `data` to unit L2 norm, in place, and
/// returns the buffer.
///
/// Rows whose norm is at most `1e-8` are left untouched to avoid dividing by
/// (near) zero. If `data.len()` is not a multiple of `dim`, the shorter
/// trailing chunk is normalised as a row of its own. `dim == 0` returns
/// `data` unchanged instead of panicking.
fn normalize_rows(mut data: Vec<f32>, dim: usize) -> Vec<f32> {
    if dim == 0 {
        // `chunks_mut(0)` panics; there are no rows to normalise anyway.
        return data;
    }
    for row in data.chunks_mut(dim) {
        let norm = row.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for x in row.iter_mut() {
                *x /= norm;
            }
        }
    }
    data
}
/// Runs one brute-force batch search on the GPU via hipBLAS SGEMM.
///
/// The database (`flat_vecs`, `row_ids.len()` rows of `dim` floats) and the
/// queries are uploaded, a single SGEMM produces the scaled cross products
/// for every (row, query) pair, and the host turns them into distances:
///
/// * `DotProduct`: `-dot`
/// * `Cosine` (inputs are normalised here) / `NormalizedCosine`: `1 - dot`
/// * `Euclidean`: `sqrt(max(0, |q|² + |v|² - 2·dot))`
///
/// Each query yields its `top_k` smallest distances in ascending order.
///
/// Returns `None` when the inputs do not have the expected shape (`dim == 0`,
/// a query whose length is not `dim`, `flat_vecs` shorter than
/// `row_ids.len() * dim`, or extents that do not fit the BLAS `i32`
/// interface), when a library or symbol cannot be loaded, or when any
/// HIP/hipBLAS call reports a non-zero status.
fn batch_inner(
    queries: &[&[f32]],
    row_ids: &[u64],
    flat_vecs: &[f32],
    dim: usize,
    metric: VectorMetric,
    top_k: usize,
) -> Option<Vec<Vec<(RowId, f32)>>> {
    /// Converts a matrix extent to the 32-bit integer the BLAS ABI takes.
    fn blas_int(extent: usize) -> Option<i32> {
        <i32 as std::convert::TryFrom<usize>>::try_from(extent).ok()
    }

    let n = row_ids.len();
    let q = queries.len();
    if n == 0 || q == 0 {
        return Some(vec![Vec::new(); q]);
    }

    // Validate every shape before it reaches the FFI boundary: SGEMM trusts
    // the extents it is given and would read past the device buffers.
    if dim == 0 || queries.iter().any(|query| query.len() != dim) {
        return None;
    }
    let flat_vecs = flat_vecs.get(..n.checked_mul(dim)?)?;
    let out_len = n.checked_mul(q)?;
    let out_bytes = out_len.checked_mul(std::mem::size_of::<f32>())?;
    let (n_i32, q_i32, dim_i32) = (blas_int(n)?, blas_int(q)?, blas_int(dim)?);

    // The libraries are declared first so that they are dropped last, after
    // every guard that still needs to call into them.
    // SAFETY: loading a shared library runs its initialisers; the names are
    // the fixed vendor library names declared in this module.
    let hip = unsafe { Library::new(HIP_LIB) }.ok()?;
    let blas_lib = unsafe { Library::new(BLAS_LIB) }.ok()?;
    // SAFETY: `hip` is the HIP runtime and outlives every use of the returned
    // function pointers.
    let (hip_malloc, hip_free, hip_memcpy, hip_sync) = unsafe { load_hip_fns(&hip) }?;
    // SAFETY: the symbol types below are the signatures this module assumes
    // for the hipBLAS entry points of the same name.
    let blas_create: Symbol<unsafe extern "C" fn(*mut *mut c_void) -> i32> =
        unsafe { blas_lib.get(b"hipblasCreate\0") }.ok()?;
    let blas_destroy: unsafe extern "C" fn(*mut c_void) -> i32 = *unsafe {
        blas_lib.get::<unsafe extern "C" fn(*mut c_void) -> i32>(b"hipblasDestroy\0")
    }
    .ok()?;
    let sgemm: Symbol<SgemmFn> = unsafe { blas_lib.get(b"hipblasSgemm\0") }.ok()?;

    let mut raw_handle: *mut c_void = std::ptr::null_mut();
    // SAFETY: `raw_handle` is a valid out-pointer for the create call.
    if unsafe { blas_create(&mut raw_handle) } != 0 {
        return None;
    }
    let _blas = BlasHandle {
        handle: raw_handle,
        destroy_fn: blas_destroy,
    };

    // Host-side operands. Only `Cosine` needs normalised copies.
    let is_cosine = matches!(metric, VectorMetric::Cosine);
    let q_flat: Vec<f32> = queries
        .iter()
        .flat_map(|query| query.iter().copied())
        .collect();
    let q_flat = if is_cosine {
        normalize_rows(q_flat, dim)
    } else {
        q_flat
    };
    let db_normalized: Option<Vec<f32>> = if is_cosine {
        Some(normalize_rows(flat_vecs.to_vec(), dim))
    } else {
        None
    };
    let db_data: &[f32] = db_normalized.as_deref().unwrap_or(flat_vecs);

    // SAFETY: the function pointers come from `hip`, which is still loaded.
    let db_dev = unsafe { upload(db_data, hip_malloc, hip_free, hip_memcpy) }?;
    let q_dev = unsafe { upload(&q_flat, hip_malloc, hip_free, hip_memcpy) }?;
    let c_dev = unsafe { alloc_dev(out_len, hip_malloc, hip_free) }?;

    // The scale folds the sign (and the factor 2 of the Euclidean expansion)
    // into the product so the host only has to add.
    let alpha: f32 = match metric {
        VectorMetric::Euclidean => -2.0,
        VectorMetric::DotProduct | VectorMetric::NormalizedCosine | VectorMetric::Cosine => -1.0,
    };
    let beta = 0.0f32;

    // Row-major `rows x dim` data is a column-major `dim x rows` matrix, so
    // `op(A) = Aᵀ` is `n x dim`, `B` is `dim x q`, and `C` is `n x q` with
    // element (row, query) at `row + query * n`.
    // SAFETY: `db_dev`, `q_dev` and `c_dev` hold exactly `n*dim`, `q*dim` and
    // `n*q` floats (validated above), matching the extents passed here.
    let rc = unsafe {
        sgemm(
            raw_handle,
            OP_T,
            OP_N,
            n_i32,
            q_i32,
            dim_i32,
            &alpha,
            db_dev.ptr as *const c_void,
            dim_i32,
            q_dev.ptr as *const c_void,
            dim_i32,
            &beta,
            c_dev.ptr,
            n_i32,
        )
    };
    if rc != 0 {
        return None;
    }
    // SAFETY: plain runtime call with no arguments.
    if unsafe { hip_sync() } != 0 {
        return None;
    }

    let mut c_host = vec![0.0f32; out_len];
    // SAFETY: `c_host` and `c_dev` both span `out_bytes` bytes.
    let copy_rc = unsafe {
        hip_memcpy(
            c_host.as_mut_ptr() as *mut c_void,
            c_dev.ptr as *const c_void,
            out_bytes,
            D2H,
        )
    };
    if copy_rc != 0 {
        return None;
    }

    // Squared norms of the database rows, needed only for Euclidean distance.
    let is_euclidean = matches!(metric, VectorMetric::Euclidean);
    let db_sq: Vec<f32> = if is_euclidean {
        flat_vecs
            .chunks_exact(dim)
            .map(|row| row.iter().map(|x| x * x).sum::<f32>())
            .collect()
    } else {
        Vec::new()
    };

    // `total_cmp` is a total order even with NaN distances, which the
    // standard sorting routines require.
    let by_distance = |a: &(usize, f32), b: &(usize, f32)| a.1.total_cmp(&b.1);

    let results: Vec<Vec<(RowId, f32)>> = queries
        .iter()
        .zip(c_host.chunks_exact(n))
        .map(|(query, column)| {
            // The query norm is the same for every row, so compute it once.
            let q_sq: f32 = if is_euclidean {
                query.iter().map(|x| x * x).sum()
            } else {
                0.0
            };
            let mut ranked: Vec<(usize, f32)> = column
                .iter()
                .enumerate()
                .map(|(ni, &raw)| {
                    let dist = match metric {
                        VectorMetric::DotProduct => raw,
                        VectorMetric::Cosine | VectorMetric::NormalizedCosine => 1.0 + raw,
                        VectorMetric::Euclidean => (q_sq + db_sq[ni] + raw).max(0.0).sqrt(),
                    };
                    (ni, dist)
                })
                .collect();
            // Partition out the `top_k` nearest first so only they get sorted.
            if top_k < ranked.len() {
                ranked.select_nth_unstable_by(top_k, by_distance);
                ranked.truncate(top_k);
            }
            ranked.sort_unstable_by(by_distance);
            ranked
                .into_iter()
                .map(|(ni, dist)| (RowId::new(row_ids[ni]), dist))
                .collect()
        })
        .collect();
    Some(results)
}
/// Lloyd's k-means with the assignment step's cross products computed by
/// hipBLAS SGEMM.
///
/// Centroids are seeded from evenly spaced input vectors (`vectors[i * (n / k)]`).
/// Each iteration uploads the centroids, computes `-2 · centroidᵀ · x` on the
/// device, assigns every vector to the centroid minimising
/// `|x|² + |c|² - 2·x·c` on the host, and recomputes the means. A cluster
/// that receives no vectors keeps its previous centroid. Iteration stops after
/// `max_iter` rounds or as soon as the assignment stops changing.
///
/// Returns `Some(vec![])` when `k == 0` or `n == 0`. Returns `None` when the
/// input shape is unusable (`dim == 0`, `n != vectors.len()`, a vector whose
/// length is not `dim`, extents beyond the BLAS `i32` interface), when a
/// library or symbol cannot be loaded, or when any HIP/hipBLAS call reports a
/// non-zero status.
fn kmeans_inner(
    vectors: &[Vec<f32>],
    k: usize,
    max_iter: usize,
    n: usize,
    dim: usize,
) -> Option<Vec<Vec<f32>>> {
    /// Converts a matrix extent to the 32-bit integer the BLAS ABI takes.
    fn blas_int(extent: usize) -> Option<i32> {
        <i32 as std::convert::TryFrom<usize>>::try_from(extent).ok()
    }

    // Zero clusters (or zero points) has a trivial answer and would otherwise
    // divide by zero while seeding.
    if k == 0 || n == 0 {
        return Some(Vec::new());
    }
    // SGEMM trusts the extents it is given, so ragged input must not reach it.
    if dim == 0 || vectors.len() != n || vectors.iter().any(|v| v.len() != dim) {
        return None;
    }
    let cross_len = k.checked_mul(n)?;
    let cross_bytes = cross_len.checked_mul(std::mem::size_of::<f32>())?;
    let centroid_len = k.checked_mul(dim)?;
    let (k_i32, n_i32, dim_i32) = (blas_int(k)?, blas_int(n)?, blas_int(dim)?);

    // The libraries are declared first so that they are dropped last, after
    // every guard that still needs to call into them.
    // SAFETY: loading a shared library runs its initialisers; the names are
    // the fixed vendor library names declared in this module.
    let hip = unsafe { Library::new(HIP_LIB) }.ok()?;
    let blas_lib = unsafe { Library::new(BLAS_LIB) }.ok()?;
    // SAFETY: `hip` is the HIP runtime and outlives every use of the returned
    // function pointers.
    let (hip_malloc, hip_free, hip_memcpy, hip_sync) = unsafe { load_hip_fns(&hip) }?;
    // SAFETY: the symbol types below are the signatures this module assumes
    // for the hipBLAS entry points of the same name.
    let blas_create: Symbol<unsafe extern "C" fn(*mut *mut c_void) -> i32> =
        unsafe { blas_lib.get(b"hipblasCreate\0") }.ok()?;
    let blas_destroy: unsafe extern "C" fn(*mut c_void) -> i32 = *unsafe {
        blas_lib.get::<unsafe extern "C" fn(*mut c_void) -> i32>(b"hipblasDestroy\0")
    }
    .ok()?;
    let sgemm: Symbol<SgemmFn> = unsafe { blas_lib.get(b"hipblasSgemm\0") }.ok()?;

    let mut raw_handle: *mut c_void = std::ptr::null_mut();
    // SAFETY: `raw_handle` is a valid out-pointer for the create call.
    if unsafe { blas_create(&mut raw_handle) } != 0 {
        return None;
    }
    let _blas = BlasHandle {
        handle: raw_handle,
        destroy_fn: blas_destroy,
    };

    let flat: Vec<f32> = vectors.iter().flat_map(|v| v.iter().copied()).collect();
    // SAFETY: the function pointers come from `hip`, which is still loaded.
    let x_dev = unsafe { upload(&flat, hip_malloc, hip_free, hip_memcpy) }?;
    let x_sq: Vec<f32> = vectors
        .iter()
        .map(|v| v.iter().map(|x| x * x).sum())
        .collect();

    let step = n / k;
    let mut centroids_flat: Vec<f32> = (0..k)
        .flat_map(|i| vectors[(i * step) % n].iter().copied())
        .collect();
    let mut prev_asgn: Vec<u32> = vec![];

    // The cross-product scratch space has the same size every round, so it is
    // allocated once and overwritten in full by each SGEMM / copy-back.
    // SAFETY: as for `upload` above.
    let cross_dev = unsafe { alloc_dev(cross_len, hip_malloc, hip_free) }?;
    let mut cross_host = vec![0.0f32; cross_len];
    let alpha = -2.0f32;
    let beta = 0.0f32;

    for _ in 0..max_iter {
        // SAFETY: as for `upload` above.
        let c_dev = unsafe { upload(&centroids_flat, hip_malloc, hip_free, hip_memcpy) }?;
        // Result is a column-major `k x n` matrix: entry (centroid, point)
        // lives at `centroid + point * k`.
        // SAFETY: `c_dev`, `x_dev` and `cross_dev` hold exactly `k*dim`,
        // `n*dim` and `k*n` floats, matching the extents passed here.
        let rc = unsafe {
            sgemm(
                raw_handle,
                OP_T,
                OP_N,
                k_i32,
                n_i32,
                dim_i32,
                &alpha,
                c_dev.ptr as *const c_void,
                dim_i32,
                x_dev.ptr as *const c_void,
                dim_i32,
                &beta,
                cross_dev.ptr,
                k_i32,
            )
        };
        if rc != 0 {
            return None;
        }
        // SAFETY: plain runtime call with no arguments.
        if unsafe { hip_sync() } != 0 {
            return None;
        }
        // SAFETY: `cross_host` and `cross_dev` both span `cross_bytes` bytes.
        let copy_rc = unsafe {
            hip_memcpy(
                cross_host.as_mut_ptr() as *mut c_void,
                cross_dev.ptr as *const c_void,
                cross_bytes,
                D2H,
            )
        };
        if copy_rc != 0 {
            return None;
        }

        let c_sq: Vec<f32> = centroids_flat
            .chunks_exact(dim)
            .map(|c| c.iter().map(|x| x * x).sum())
            .collect();

        // Assignment step: nearest centroid per point.
        let asgn: Vec<u32> = cross_host
            .chunks_exact(k)
            .zip(&x_sq)
            .map(|(cross, &point_sq)| {
                let dist = |c: usize| point_sq + c_sq[c] + cross[c];
                let best = (0..k)
                    .min_by(|&a, &b| dist(a).total_cmp(&dist(b)))
                    .unwrap_or(0);
                // Lossless: `k` was checked to fit in `i32` above.
                best as u32
            })
            .collect();
        if asgn == prev_asgn {
            break;
        }

        // Update step: mean of the members, or the old centroid if empty.
        let mut new_flat = vec![0.0f32; centroid_len];
        let mut counts = vec![0usize; k];
        for (point, &ci) in vectors.iter().zip(&asgn) {
            let ci = ci as usize;
            let acc = &mut new_flat[ci * dim..(ci + 1) * dim];
            for (sum, &v) in acc.iter_mut().zip(point) {
                *sum += v;
            }
            counts[ci] += 1;
        }
        for ((new_c, old_c), &count) in new_flat
            .chunks_exact_mut(dim)
            .zip(centroids_flat.chunks_exact(dim))
            .zip(&counts)
        {
            if count > 0 {
                let inv = 1.0 / count as f32;
                new_c.iter_mut().for_each(|x| *x *= inv);
            } else {
                new_c.copy_from_slice(old_c);
            }
        }
        centroids_flat = new_flat;
        prev_asgn = asgn;
    }
    Some(centroids_flat.chunks_exact(dim).map(|c| c.to_vec()).collect())
}
}