use anyhow::{anyhow, Result};
/// Metal shader source text, embedded into the binary at compile time from
/// `metal_shaders.metal` (`include_str!` resolves the path relative to this
/// source file, and a missing file is a compile error).
///
/// This constant is not gated on `target_os = "macos"`, so the shader text is
/// available, e.g. for inspection in tests, on every platform.
///
/// NOTE(review): nothing visible in this file compiles or dispatches this
/// source yet; `MetalExecutor` does not reference it.
pub const METAL_SHADER_SOURCE: &str = include_str!("metal_shaders.metal");
/// Handle for running vector kernels on an Apple GPU (macOS only).
///
/// NOTE(review): no Metal device or pipeline state is held yet; the only real
/// state is a device name string, and `_phantom` is a zero-sized placeholder.
#[cfg(target_os = "macos")]
#[derive(Debug)]
pub struct MetalExecutor {
    /// Human-readable name reported by `device_info`.
    device_name: String,
    /// Zero-sized placeholder field; carries no data.
    _phantom: std::marker::PhantomData<()>,
}
#[cfg(target_os = "macos")]
impl MetalExecutor {
    /// Creates an executor.
    ///
    /// NOTE(review): no Metal device is queried yet; the device name is a
    /// fixed placeholder string.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` return type leaves room for a real
    /// device lookup to report failure without an API change.
    pub fn new() -> Result<Self> {
        Ok(Self {
            device_name: "Apple M-series GPU".to_string(),
            _phantom: std::marker::PhantomData,
        })
    }

    /// Computes the Euclidean (L2, non-squared) distance from `query` to each
    /// of the `num_vectors` rows stored in `database`.
    ///
    /// `database` is read row-major: row `i` is
    /// `database[i * vector_dim..(i + 1) * vector_dim]`. The returned vector
    /// has exactly `num_vectors` entries, in row order.
    ///
    /// NOTE(review): the Metal dispatch is not wired up yet (the intended
    /// launch was threadgroups of 256 threads covering `num_vectors`), so the
    /// distances are computed on the CPU. Confirm that
    /// `euclidean_distance_kernel` also returns the non-squared distance
    /// before switching to the GPU path.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following holds:
    /// - `query.len() != vector_dim`;
    /// - `num_vectors * vector_dim` overflows `usize`;
    /// - `database.len() != num_vectors * vector_dim`.
    pub fn euclidean_distance(
        &self,
        query: &[f32],
        database: &[f32],
        num_vectors: usize,
        vector_dim: usize,
    ) -> Result<Vec<f32>> {
        if query.len() != vector_dim {
            return Err(anyhow!(
                "query has {} elements but vector_dim is {}",
                query.len(),
                vector_dim
            ));
        }
        let expected_len = num_vectors
            .checked_mul(vector_dim)
            .ok_or_else(|| anyhow!("num_vectors * vector_dim overflows usize"))?;
        if database.len() != expected_len {
            return Err(anyhow!(
                "database has {} elements but num_vectors * vector_dim is {}",
                database.len(),
                expected_len
            ));
        }
        if vector_dim == 0 {
            // `chunks_exact(0)` panics. Every point of a zero-dimensional space
            // is at distance 0 from the query.
            return Ok(vec![0.0; num_vectors]);
        }

        let distances = database
            .chunks_exact(vector_dim)
            .map(|row| {
                row.iter()
                    .zip(query)
                    .map(|(a, b)| {
                        let diff = a - b;
                        diff * diff
                    })
                    .sum::<f32>()
                    .sqrt()
            })
            .collect();
        Ok(distances)
    }

    /// Returns a description of the executor's device.
    ///
    /// NOTE(review): apart from the name, every value here is a hard-coded
    /// placeholder, not queried from the hardware.
    pub fn device_info(&self) -> MetalDeviceInfo {
        MetalDeviceInfo {
            name: self.device_name.clone(),
            supports_non_uniform_threadgroups: true,
            max_threads_per_threadgroup: 1024,
            // 8 GiB; fits in `usize` on the 64-bit targets macOS runs on.
            recommended_max_working_set_size: 8 * 1024 * 1024 * 1024,
        }
    }
}
/// Capabilities reported for a Metal device.
///
/// Derives `PartialEq`/`Eq` so descriptions can be compared directly, for
/// example in tests or when checking whether a device changed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MetalDeviceInfo {
    /// Human-readable device name.
    pub name: String,
    /// Whether the device supports non-uniform threadgroup sizes.
    pub supports_non_uniform_threadgroups: bool,
    /// Upper bound on the number of threads in one threadgroup.
    pub max_threads_per_threadgroup: usize,
    /// Recommended maximum working-set size (presumably in bytes; confirm).
    pub recommended_max_working_set_size: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Kernel entry points the embedded shader source must define.
    const EXPECTED_KERNELS: [&str; 5] = [
        "euclidean_distance_kernel",
        "cosine_similarity_kernel",
        "dot_product_kernel",
        "l2_normalize_kernel",
        "matrix_multiply_kernel",
    ];

    /// Metal Shading Language tokens that the shader source must contain.
    const EXPECTED_SYNTAX: [&str; 3] = ["kernel void", "[[buffer", "[[thread_position_in_grid]]"];

    #[test]
    fn test_shader_source_defined() {
        assert!(
            !METAL_SHADER_SOURCE.is_empty(),
            "METAL_SHADER_SOURCE is empty"
        );
    }

    #[test]
    fn test_shader_contains_kernels() {
        for &kernel in EXPECTED_KERNELS.iter() {
            assert!(
                METAL_SHADER_SOURCE.contains(kernel),
                "shader source is missing kernel `{}`",
                kernel
            );
        }
    }

    #[test]
    fn test_shader_metal_syntax() {
        for &token in EXPECTED_SYNTAX.iter() {
            assert!(
                METAL_SHADER_SOURCE.contains(token),
                "shader source is missing Metal syntax `{}`",
                token
            );
        }
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_metal_executor_creation() {
        // `expect` surfaces the underlying error on failure, unlike `is_ok()`.
        MetalExecutor::new().expect("MetalExecutor::new() should succeed");
    }

    #[cfg(target_os = "macos")]
    #[test]
    fn test_device_info() {
        let executor = MetalExecutor::new().expect("MetalExecutor::new() should succeed");
        let info = executor.device_info();
        assert!(!info.name.is_empty(), "device name is empty: {:?}", info);
        assert!(
            info.max_threads_per_threadgroup > 0,
            "max_threads_per_threadgroup must be positive: {:?}",
            info
        );
    }
}