hanzo-rocm-kernels 0.10.2

ROCm/HIP kernels for Hanzo
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
use crate::error::KernelError;
use crate::wrappers::SendSyncModule;
use rocm_rs::hip::Device;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::{Arc, Mutex};

/// Single unified cache for compiled kernel modules.
///
/// Combines the functionality of the old CacheManager (disk cache)
/// and KernelManager (module cache) into one simpler struct.
///
/// Both in-memory maps are wrapped in a `Mutex`; every lookup and insert
/// goes through the corresponding lock.
pub struct KernelCache {
    /// Directory that holds the cached code objects for this cache instance.
    /// `new` sets it to `{base cache dir}/{arch}-{rocm_version}`.
    cache_dir: PathBuf,
    /// GPU architecture string as detected at construction (e.g. "gfx90a");
    /// passed to the compiler when a kernel has to be built.
    arch: String,
    /// ROCm version string as detected at construction; only used to name
    /// the cache directory and exposed through `rocm_version()`.
    rocm_version: String,
    /// Loaded modules, keyed by the module name handed to `get_or_load`.
    modules: Mutex<HashMap<&'static str, Arc<SendSyncModule>>>,
    /// Resolved kernel function handles (hipFunction_t as usize), keyed by func name.
    /// hipModuleGetFunction is a slow driver round-trip through the WSL bridge (~90us),
    /// so resolving once per function (not per op) is a major decode speedup. The handle
    /// stays valid because its module is held forever in `modules`.
    functions: Mutex<HashMap<String, usize>>,
}

impl KernelCache {
    /// Create a new `KernelCache` for the given device.
    ///
    /// Detects the GPU architecture and ROCm version, then makes sure the
    /// on-disk cache directory `{base cache dir}/{arch}-{rocm_version}/` exists.
    ///
    /// # Errors
    ///
    /// Returns an error when architecture/version detection fails, when no
    /// cache base directory can be determined, or when the cache directory
    /// cannot be created.
    pub fn new(device: &Device) -> Result<Self, KernelError> {
        let arch = detect_gpu_arch(device)?;
        let rocm_version = detect_rocm_version()?;
        let cache_dir = get_cache_dir()?;

        // Cache layout: {base cache dir}/{arch}-{rocm_version}/
        let kernel_dir = cache_dir.join(format!("{}-{}", arch, rocm_version));
        fs::create_dir_all(&kernel_dir).map_err(|e| {
            KernelError::Io(format!(
                "Failed to create cache directory {}: {}",
                kernel_dir.display(),
                e
            ))
        })?;

        Ok(Self {
            cache_dir: kernel_dir,
            arch,
            rocm_version,
            modules: Mutex::new(HashMap::new()),
            functions: Mutex::new(HashMap::new()),
        })
    }

    /// Get a resolved kernel function handle (hipFunction_t as usize), caching it so
    /// hipModuleGetFunction runs once per function instead of once per op.
    ///
    /// The cache is keyed by `func_name` only, so kernel function names must be
    /// unique across all modules served by this cache; a second module exporting
    /// the same name would get the handle resolved for the first one.
    ///
    /// # Errors
    ///
    /// Returns an error when a cache lock is poisoned, when the module cannot be
    /// loaded or compiled, or when the module has no function named `func_name`.
    pub fn get_func_raw(
        &self,
        module_name: &'static str,
        source: &'static str,
        func_name: &str,
    ) -> Result<usize, KernelError> {
        // Fast path: handle already resolved. The lock is released before the
        // (potentially slow) module load below.
        {
            let funcs = self
                .functions
                .lock()
                .map_err(|_| KernelError::Internal("Failed to lock functions cache".to_string()))?;
            if let Some(&ptr) = funcs.get(func_name) {
                return Ok(ptr);
            }
        }

        let module = self.get_or_load(module_name, source)?;
        let func = module.get_function(func_name).map_err(|e| {
            KernelError::Compilation(format!("Kernel function {} not found: {}", func_name, e))
        })?;
        let raw = func.as_raw() as usize;

        let mut funcs = self
            .functions
            .lock()
            .map_err(|_| KernelError::Internal("Failed to lock functions cache".to_string()))?;
        // If another thread resolved the same function in the meantime, keep and
        // return its entry so every caller observes one handle per name.
        Ok(*funcs.entry(func_name.to_string()).or_insert(raw))
    }

    /// Get or compile a kernel module.
    ///
    /// This method checks the in-memory cache first, then the disk cache,
    /// and compiles from source if needed.
    ///
    /// When several threads race to load the same module, the first one to
    /// register it wins and every caller receives that same module. A module
    /// is never replaced once registered, which is what the raw function
    /// handles cached by `get_func_raw` rely on.
    ///
    /// # Errors
    ///
    /// Returns an error when the cache lock is poisoned, when the cached binary
    /// cannot be read, when compilation fails, or when the binary cannot be
    /// loaded as a module.
    pub fn get_or_load(
        &self,
        name: &'static str,
        source: &'static str,
    ) -> Result<Arc<SendSyncModule>, KernelError> {
        // Check in-memory cache first
        {
            let modules = self
                .modules
                .lock()
                .map_err(|_| KernelError::Internal("Failed to lock modules cache".to_string()))?;
            if let Some(module) = modules.get(name) {
                return Ok(module.clone());
            }
        }

        // The source hash is part of the file name, so edited kernels never
        // pick up a stale cached binary.
        let source_hash = compute_source_hash(source);
        let cache_file = self.cache_dir.join(format!("{}_{}.cso", name, source_hash));

        // Try to load from disk cache or compile
        let binary = if cache_file.exists() {
            fs::read(&cache_file).map_err(|e| {
                KernelError::Io(format!(
                    "Failed to read cached binary {}: {}",
                    cache_file.display(),
                    e
                ))
            })?
        } else {
            // compile_kernel writes the resulting code object to `cache_file`
            // itself, so it is not written a second time here.
            compile_kernel(name, source, &self.arch, &cache_file)?
        };

        // Load module from binary
        let module = SendSyncModule::load_data(&binary).map_err(|e| {
            KernelError::Compilation(format!(
                "Failed to load module {} from compiled binary: {}",
                name, e
            ))
        })?;

        // Register the module unless another thread got there first; in that
        // case our freshly loaded copy is dropped and the registered one is
        // returned. Overwriting the entry instead would let the previously
        // registered module be dropped while raw handles into it are cached.
        let mut modules = self
            .modules
            .lock()
            .map_err(|_| KernelError::Internal("Failed to lock modules cache".to_string()))?;
        Ok(modules.entry(name).or_insert(Arc::new(module)).clone())
    }

    /// Get the cache directory path
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Get GPU architecture
    pub fn arch(&self) -> &str {
        &self.arch
    }

    /// Get ROCm version
    pub fn rocm_version(&self) -> &str {
        &self.rocm_version
    }
}

/// Detect the GPU architecture (e.g., "gfx908", "gfx90a", "gfx942").
///
/// Resolution order:
/// 1. the `CANDLE_ROCM_ARCH` environment variable (ignored when empty or
///    whitespace-only, so a blank override cannot produce an empty arch),
/// 2. the first `Name:` line of `rocminfo -a` output that mentions a `gfx` target,
/// 3. a hard-coded `gfx908` default, provided `hipcc` can be executed.
///
/// The device argument is currently not consulted.
///
/// # Errors
///
/// Returns `KernelError::Compilation` when nothing could be detected and
/// `hipcc` cannot be executed either.
fn detect_gpu_arch(_device: &Device) -> Result<String, KernelError> {
    // First try to get from environment variable (useful for testing/build machines)
    if let Ok(arch) = std::env::var("CANDLE_ROCM_ARCH") {
        let arch = arch.trim();
        if !arch.is_empty() {
            return Ok(arch.to_string());
        }
    }

    // Try to use rocminfo to detect the architecture
    match Command::new("rocminfo").arg("-a").output() {
        Ok(output) => {
            let stdout = String::from_utf8_lossy(&output.stdout);
            if let Some(arch) = stdout.lines().find_map(parse_gfx_arch) {
                return Ok(arch.to_string());
            }
        }
        Err(e) => {
            eprintln!("Warning: Failed to run rocminfo: {}", e);
        }
    }

    // hipcc is only probed to tell "ROCm present, arch unknown" (use the
    // default) apart from "ROCm missing" (report an error).
    match Command::new("hipcc").arg("--version").output() {
        Ok(_) => {
            eprintln!("Warning: Could not detect GPU architecture, defaulting to gfx908");
            Ok("gfx908".to_string())
        }
        Err(e) => Err(KernelError::Compilation(format!(
            "hipcc not found: {}. Please install ROCm or set CANDLE_ROCM_ARCH environment variable",
            e
        ))),
    }
}

/// Extract the `gfx…` token from a `Name:` line, or `None` if the line has none.
///
/// The token runs from the first "gfx" up to the next non-alphanumeric character.
fn parse_gfx_arch(line: &str) -> Option<&str> {
    if !line.contains("Name:") {
        return None;
    }
    let start = line.find("gfx")?;
    let tail = &line[start..];
    let end = tail
        .find(|c: char| !c.is_alphanumeric())
        .unwrap_or(tail.len());
    Some(&tail[..end])
}

/// Detect the ROCm version as a `major.minor` string.
///
/// Resolution order:
/// 1. the `CANDLE_ROCM_VERSION` environment variable (ignored when empty or
///    whitespace-only),
/// 2. the first `HIP version:` / `HIP_VERSION:` line of `hipcc --version`
///    that carries a non-empty value, truncated to its first two
///    dot-separated components,
/// 3. the default `"6.0"` when `hipcc` ran but no version could be parsed.
///
/// # Errors
///
/// Returns `KernelError::Compilation` when `hipcc` cannot be executed.
fn detect_rocm_version() -> Result<String, KernelError> {
    // Try to get from environment variable first
    if let Ok(version) = std::env::var("CANDLE_ROCM_VERSION") {
        let version = version.trim();
        if !version.is_empty() {
            return Ok(version.to_string());
        }
    }

    // Try to get from hipcc --version
    let output = Command::new("hipcc")
        .arg("--version")
        .output()
        .map_err(|e| {
            KernelError::Compilation(format!(
                "hipcc not found: {}. Please install ROCm or set CANDLE_ROCM_VERSION environment variable",
                e
            ))
        })?;

    let stdout = String::from_utf8_lossy(&output.stdout);
    // Parse version from output like "HIP version: 6.1.0"
    let parsed = stdout
        .lines()
        .filter(|line| line.contains("HIP version:") || line.contains("HIP_VERSION:"))
        .filter_map(|line| line.split(':').nth(1))
        .map(|value| value.trim().split('.').take(2).collect::<Vec<_>>().join("."))
        // A line with nothing after the colon must not yield an empty version.
        .find(|version| !version.is_empty());

    // If we can't parse, return a default
    Ok(parsed.unwrap_or_else(|| "6.0".to_string()))
}

/// Get the base cache directory (`{user cache dir}/hanzo-ml-rocm`).
///
/// Uses `dirs::cache_dir()` when available. Otherwise falls back to
/// `$HOME/.cache`, so the fallback keeps the cache under a cache directory
/// instead of placing it directly in the home directory. An empty `HOME`
/// is treated as unset.
///
/// # Errors
///
/// Returns `KernelError::Internal` when neither source yields a directory.
fn get_cache_dir() -> Result<PathBuf, KernelError> {
    let base = dirs::cache_dir()
        .or_else(|| {
            // var_os: a non-UTF-8 home path is still a usable path.
            std::env::var_os("HOME")
                .filter(|home| !home.is_empty())
                .map(|home| PathBuf::from(home).join(".cache"))
        })
        .ok_or_else(|| KernelError::Internal("Could not determine cache directory".to_string()))?;

    Ok(base.join("hanzo-ml-rocm"))
}

/// Short content fingerprint of a kernel source string.
///
/// Returns the first 16 lowercase hex characters of the SHA-256 digest of
/// `source`.
fn compute_source_hash(source: &str) -> String {
    let digest = Sha256::digest(source.as_bytes());
    let mut hex = format!("{:x}", digest);
    // Keep only the leading 16 hex characters.
    hex.truncate(16);
    hex
}

/// Compile a kernel using hipcc and store the resulting code object.
///
/// Pipeline:
/// 1. `hipcc -c` compiles the HIP source to an object file,
/// 2. `objcopy` extracts the `.hip_fatbin` section,
/// 3. `clang-offload-bundler` unbundles the code object for `arch`.
///
/// The code object is returned and also written to `output_path`. The write
/// goes to a sibling temporary file that is then renamed into place, so a
/// concurrent reader of `output_path` does not see a partially written file.
///
/// Intermediate files live in the system temp directory. Their names include
/// the process id and a per-process counter so that concurrent builds of the
/// same kernel do not overwrite or delete each other's files. They are
/// removed when the function returns, on success as well as on error.
///
/// # Errors
///
/// Returns `KernelError::Io` for file system failures and
/// `KernelError::Compilation` when a tool cannot be started or exits
/// unsuccessfully (the tool's stderr is included in the message).
fn compile_kernel(
    name: &str,
    source: &str,
    arch: &str,
    output_path: &Path,
) -> Result<Vec<u8>, KernelError> {
    // Only uniqueness of the counter value matters, so Relaxed is sufficient.
    static NEXT_BUILD_ID: std::sync::atomic::AtomicUsize =
        std::sync::atomic::AtomicUsize::new(0);
    let build_id = NEXT_BUILD_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let pid = std::process::id();

    let temp_dir = std::env::temp_dir();
    let stem = format!(
        "hanzo_{}_{}_{}_{}",
        name,
        compute_source_hash(source),
        pid,
        build_id
    );
    let source_file = temp_dir.join(format!("{}.hip", stem));
    let obj_file = temp_dir.join(format!("{}.o", stem));
    let fatbin_file = temp_dir.join(format!("{}.fatbin", stem));
    let hsaco_file = temp_dir.join(format!("{}.hsaco", stem));

    // Removes the intermediate files when this function returns, whether it
    // succeeded or failed.
    let _cleanup = TempFileCleanup {
        files: vec![
            source_file.clone(),
            obj_file.clone(),
            fatbin_file.clone(),
            hsaco_file.clone(),
        ],
    };

    fs::write(&source_file, source).map_err(|e| {
        KernelError::Io(format!(
            "Failed to write source file {}: {}",
            source_file.display(),
            e
        ))
    })?;

    // Step 1: Compile HIP to object file.
    // Paths are passed as OsStr so a non-UTF-8 temp directory cannot panic.
    let output = Command::new("hipcc")
        .arg(format!("--offload-arch={}", arch))
        .args(["-O3", "-fPIC", "-c", "-o"])
        .arg(&obj_file)
        .arg(&source_file)
        .output()
        .map_err(|e| {
            KernelError::Compilation(format!("Failed to execute hipcc: {}. Is hipcc in PATH?", e))
        })?;
    ensure_tool_success(&output, "hipcc compilation", name)?;

    // Step 2: Extract fat binary from object
    let extract_output = Command::new("objcopy")
        .args(["-O", "binary", "-j", ".hip_fatbin"])
        .arg(&obj_file)
        .arg(&fatbin_file)
        .output()
        .map_err(|e| {
            KernelError::Compilation(format!(
                "Failed to execute objcopy: {}. Is binutils in PATH?",
                e
            ))
        })?;
    ensure_tool_success(&extract_output, "objcopy extraction", name)?;

    // Step 3: Unbundle the code object for specific architecture
    let target = format!("hipv4-amdgcn-amd-amdhsa--{}", arch);
    let bundler_path = find_rocm_tool("clang-offload-bundler")?;
    let unbundle_output = Command::new(&bundler_path)
        .args(["--unbundle", "--type=o", "--input"])
        .arg(&fatbin_file)
        .arg("--targets")
        .arg(&target)
        .arg("--output")
        .arg(&hsaco_file)
        .output()
        .map_err(|e| {
            KernelError::Compilation(format!(
                "Failed to execute clang-offload-bundler: {}. Is ROCm in PATH?",
                e
            ))
        })?;
    ensure_tool_success(&unbundle_output, "clang-offload-bundler extraction", name)?;

    // Read the final code object
    let binary = fs::read(&hsaco_file).map_err(|e| {
        KernelError::Io(format!(
            "Failed to read code object {}: {}",
            hsaco_file.display(),
            e
        ))
    })?;

    // Write to cache location: stage next to the destination (same directory,
    // so the rename stays on one file system), then rename into place.
    let mut staged_name = output_path
        .file_name()
        .map(|n| n.to_os_string())
        .unwrap_or_default();
    staged_name.push(format!(".{}.{}.tmp", pid, build_id));
    let staged_path = output_path.with_file_name(staged_name);

    fs::write(&staged_path, &binary)
        .and_then(|()| fs::rename(&staged_path, output_path))
        .map_err(|e| {
            // Best effort: do not leave the staging file behind.
            let _ = fs::remove_file(&staged_path);
            KernelError::Io(format!(
                "Failed to write cache file {}: {}",
                output_path.display(),
                e
            ))
        })?;

    Ok(binary)
}

/// Turn an unsuccessful tool exit into a `KernelError::Compilation` carrying the tool's stderr.
fn ensure_tool_success(
    output: &std::process::Output,
    step: &str,
    kernel: &str,
) -> Result<(), KernelError> {
    if output.status.success() {
        return Ok(());
    }
    let stderr = String::from_utf8_lossy(&output.stderr);
    Err(KernelError::Compilation(format!(
        "{} failed for {}:\n{}",
        step, kernel, stderr
    )))
}

/// Locate a ROCm toolchain binary by asking hipcc for its path.
///
/// Runs `hipcc --print-prog-name <tool_name>` and accepts the trimmed stdout
/// as the tool path when hipcc exits successfully, the output is non-empty,
/// and the path exists on disk.
///
/// # Errors
///
/// Returns `KernelError::Compilation` when hipcc cannot be started or when
/// no usable path was reported.
fn find_rocm_tool(tool_name: &str) -> Result<String, KernelError> {
    let probe = Command::new("hipcc")
        .arg("--print-prog-name")
        .arg(tool_name)
        .output()
        .map_err(|e| KernelError::Compilation(format!("Failed to run hipcc: {}", e)))?;

    let resolved = if probe.status.success() {
        let candidate = String::from_utf8_lossy(&probe.stdout).trim().to_string();
        Some(candidate).filter(|p| !p.is_empty() && Path::new(p).exists())
    } else {
        None
    };

    match resolved {
        Some(tool_path) => Ok(tool_path),
        None => Err(KernelError::Compilation(format!(
            "{} not found via hipcc. Is ROCm installed?",
            tool_name
        ))),
    }
}

/// Guard that deletes a set of temporary files when it goes out of scope.
struct TempFileCleanup {
    /// Paths to remove on drop; entries that do not exist are skipped silently.
    files: Vec<PathBuf>,
}

impl Drop for TempFileCleanup {
    /// Removes every listed file, ignoring individual failures.
    fn drop(&mut self) {
        self.files.iter().for_each(|path| {
            // Best effort: a missing or undeletable file must not abort cleanup.
            let _ = fs::remove_file(path);
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The source hash must be deterministic, sensitive to the source text,
    /// and keep the 16-lowercase-hex-character shape that cache and temp
    /// file names are built from.
    #[test]
    fn test_source_hash() {
        let source1 = "__global__ void test() {}";
        let source2 = "__global__ void test() {}";
        let source3 = "__global__ void test2() {}";

        let hash1 = compute_source_hash(source1);
        let hash2 = compute_source_hash(source2);
        let hash3 = compute_source_hash(source3);

        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);

        for hash in [&hash1, &hash3, &compute_source_hash("")] {
            assert_eq!(hash.len(), 16);
            assert!(hash
                .chars()
                .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
        }
    }
}