#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
use scirs2_core::gpu::backends::metal_mpsgraph::MPSGraphContext;
#[cfg(all(feature = "mpsgraph", target_os = "macos"))]
fn main() {
    use objc2::msg_send;
    use objc2::rc::Retained;
    use objc2::runtime::ProtocolObject;
    use objc2_metal::{MTLBuffer, MTLCommandQueue, MTLCreateSystemDefaultDevice};

    println!("MPSGraph Operations Example");
    println!("============================\n");
    unsafe {
        // Acquire the system default Metal device; bail out gracefully when
        // none is available (e.g. on a headless CI runner).
        let device = match MTLCreateSystemDefaultDevice() {
            Some(dev) => dev,
            None => {
                eprintln!("Error: No Metal device available");
                return;
            }
        };
        println!("✓ Metal device initialized");

        // `newCommandQueue` can return nil under resource pressure, so handle
        // the Option explicitly rather than letting msg_send! panic.
        let command_queue: Option<Retained<ProtocolObject<dyn MTLCommandQueue>>> =
            msg_send![&device, newCommandQueue];
        let command_queue = command_queue.expect("failed to create Metal command queue");
        println!("✓ Command queue created\n");

        let ctx = MPSGraphContext::new(device.clone(), command_queue);
        println!("✓ MPSGraphContext initialized\n");
println!("Example 1: Matrix Multiplication");
println!("---------------------------------");
{
let m = 4;
let k = 4;
let n = 4;
let a_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (m * k * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
let b_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (k * n * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
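            // Illustrative only: since Metal zero-fills these buffers, the
            // matmul below would run on all-zero data. To exercise it with
            // real values, one can write f32s through the shared `contents`
            // pointer (a minimal sketch; assumes row-major f32 layout):
            let a_ptr: *mut core::ffi::c_void = msg_send![&a_buffer, contents];
            let a_ptr = a_ptr.cast::<f32>();
            for i in 0..(m * k) {
                // Fill A with 0.0, 1.0, 2.0, ... as sample data.
                *a_ptr.add(i as usize) = i as f32;
            }
            let b_ptr: *mut core::ffi::c_void = msg_send![&b_buffer, contents];
            let b_ptr = b_ptr.cast::<f32>();
            for i in 0..(k * n) {
                // Make B the identity so that A @ B == A.
                *b_ptr.add(i as usize) = if i % (n + 1) == 0 { 1.0 } else { 0.0 };
            }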
            match ctx.matmul(&a_buffer, &b_buffer, m, k, n) {
                Ok(_) => println!(
                    "✓ Matrix multiplication [{}×{}] @ [{}×{}] succeeded",
                    m, k, k, n
                ),
                Err(e) => println!("✗ Error: {}", e),
            }
        }
        println!();
println!("Example 2: Softmax");
println!("------------------");
{
let shape = [2, 8, 16];
let size: usize = shape.iter().product();
let input_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
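            // A scalar CPU reference for softmax over the last axis, useful
            // for validating GPU output once graph execution lands (a minimal
            // sketch, not part of the MPSGraphContext API):
            #[allow(dead_code)]
            fn cpu_softmax_last_axis(data: &mut [f32], last: usize) {
                for row in data.chunks_mut(last) {
                    // Subtract the row max for numerical stability.
                    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                    let mut sum = 0.0;
                    for x in row.iter_mut() {
                        *x = (*x - max).exp();
                        sum += *x;
                    }
                    for x in row.iter_mut() {
                        *x /= sum;
                    }
                }
            }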
            match ctx.softmax(&input_buffer, &shape, -1) {
                Ok(_) => println!("✓ Softmax on shape {:?} along axis -1 succeeded", shape),
                Err(e) => println!("✗ Error: {}", e),
            }
        }
        println!();
println!("Example 3: GeLU Activation");
println!("--------------------------");
{
let shape = [2, 128, 512];
let size: usize = shape.iter().product();
let input_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
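            // For reference, the tanh approximation of GeLU that transformer
            // stacks commonly use (a sketch; whether the backend applies the
            // exact erf form or this approximation is implementation-defined):
            #[allow(dead_code)]
            fn gelu_tanh_approx(x: f32) -> f32 {
                // sqrt(2 / pi)
                const SQRT_2_OVER_PI: f32 = 0.797_884_6;
                0.5 * x * (1.0 + (SQRT_2_OVER_PI * (x + 0.044_715 * x * x * x)).tanh())
            }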
            match ctx.gelu(&input_buffer, &shape) {
                Ok(_) => println!("✓ GeLU activation on shape {:?} succeeded", shape),
                Err(e) => println!("✗ Error: {}", e),
            }
        }
        println!();
println!("Example 4: LayerNorm");
println!("--------------------");
{
let shape = [2, 128, 768];
let normalized_shape = [768];
let size: usize = shape.iter().product();
let norm_size: usize = normalized_shape.iter().product();
let input_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
let gamma_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (norm_size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
let beta_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (norm_size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
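            // Illustrative only: an all-zero gamma would zero the normalized
            // output, so fill gamma with 1.0 and leave beta at its zero fill
            // (the identity affine transform) via the shared pointer:
            let gamma_ptr: *mut core::ffi::c_void = msg_send![&gamma_buffer, contents];
            let gamma_ptr = gamma_ptr.cast::<f32>();
            for i in 0..norm_size {
                *gamma_ptr.add(i) = 1.0;
            }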
            match ctx.layer_norm(
                &input_buffer,
                &gamma_buffer,
                &beta_buffer,
                &shape,
                &normalized_shape,
                1e-5,
            ) {
                Ok(_) => println!("✓ LayerNorm on shape {:?} succeeded", shape),
                Err(e) => println!("✗ Error: {}", e),
            }
        }
        println!();
println!("Example 5: Scaled Dot-Product Attention (SDPA)");
println!("-----------------------------------------------");
{
let batch = 1;
let num_heads = 8;
let seq_len = 64;
let head_dim = 64;
let qkv_size = batch * num_heads * seq_len * head_dim;
let q_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (qkv_size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
let k_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (qkv_size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
let v_buffer: Retained<ProtocolObject<dyn MTLBuffer>> = msg_send![
&device,
newBufferWithLength: (qkv_size * 4) as u64,
options: objc2_metal::MTLResourceOptions::StorageModeShared
];
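            // SDPA computes softmax(Q Kᵀ * scale) V; with the causal flag set,
            // each position attends only to itself and earlier tokens. For
            // reference, the scale conventionally applied (assuming the usual
            // 1/sqrt(head_dim) scaling):
            let expected_scale = 1.0f32 / (head_dim as f32).sqrt();
            println!("  (expected attention scale: {:.4})", expected_scale);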
            match ctx.scaled_dot_product_attention(
                &q_buffer, &k_buffer, &v_buffer, batch, num_heads, seq_len, seq_len, head_dim,
                true,
            ) {
                Ok(_) => println!(
                    "✓ SDPA [batch={}, heads={}, seq={}, dim={}] with causal mask succeeded",
                    batch, num_heads, seq_len, head_dim
                ),
                Err(e) => println!("✗ Error: {}", e),
            }
        }
        println!();
println!("Summary");
println!("=======");
println!("✓ All MPSGraph operations are available and functional");
println!("✓ API structure is complete for TrustformeRS integration");
println!();
println!("Note: Graph execution is currently stubbed.");
println!(" Graphs are correctly built but not yet executed on GPU.");
println!(" Expected performance gains will be realized once execution is complete.");
println!();
println!("Target Performance (from MPSGRAPH.md):");
println!(" - SDPA: 10-30x speedup");
println!(" - Matmul: 5-10x speedup");
println!(" - GeLU: 10-20x speedup");
println!(" - End-to-end: 50-200x for rinna-1b GPT-2");
}
}
#[cfg(not(all(feature = "mpsgraph", target_os = "macos")))]
fn main() {
    println!("This example requires macOS with MPSGraph support.");
    println!(
        "Please run on macOS with: cargo run --example mpsgraph_operations_example --features mpsgraph"
    );
}