1use crate::error::RuntimeError;
11use wave_compiler::{CompilerConfig, Language, OptLevel};
12
13pub fn compile_kernel(source: &str, language: Language) -> Result<Vec<u8>, RuntimeError> {
19 let config = CompilerConfig {
20 language,
21 opt_level: OptLevel::O2,
22 ..CompilerConfig::default()
23 };
24 let wbin = wave_compiler::compile_source(source, &config)?;
25 Ok(wbin)
26}
27
28pub fn compile_kernel_with_config(
34 source: &str,
35 config: &CompilerConfig,
36) -> Result<Vec<u8>, RuntimeError> {
37 let wbin = wave_compiler::compile_source(source, config)?;
38 Ok(wbin)
39}
40
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check that a minimal Python `@kernel` compiles and that the emitted
    /// binary starts with the `WAVE` magic bytes.
    #[test]
    fn test_compile_python_kernel() {
        let source = r#"
@kernel
def vector_add(a: Buffer[f32], b: Buffer[f32], out: Buffer[f32], n: u32):
    gid = thread_id()
    if gid < n:
        out[gid] = a[gid] + b[gid]
"#;
        // `expect` includes the compiler error in the failure output, unlike a bare
        // `assert!(result.is_ok())`.
        let wbin = compile_kernel(source, Language::Python)
            .expect("vector_add kernel should compile");
        // `starts_with` avoids an out-of-bounds slice panic when the output is shorter
        // than the 4-byte magic, and reports what was actually produced.
        assert!(
            wbin.starts_with(b"WAVE"),
            "compiled binary should start with the WAVE magic, got leading bytes {:?}",
            &wbin[..wbin.len().min(4)]
        );
    }
}