use cutile;
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile_runtime_utils::get_gpu_name;
use cutile_compiler::error::JITError;
mod common;
/// Ordinary host-side Rust function that has no cutile DSL counterpart.
///
/// The span tests call it from inside `#[cutile::entry]` kernels so that kernel
/// compilation fails at a known statement. Those kernels bind the result to an
/// unused `_`-prefixed local, so the returned value itself does not matter.
fn unsupported_dsl_call() -> i32 {
    i32::default()
}
// Kernel whose body calls `super::unsupported_dsl_call`, a plain host-side Rust
// function defined earlier in this file rather than a cutile DSL operation.
// `untyped_literal_error_has_correct_source_location` expects compiling this
// kernel to fail with a located error that points at the line of the `_x` call.
// The module body has no comments. Comment handling is exercised separately by
// `span_comments_module` / `comments_do_not_break_span_tracking`.
// The two statements after the call load `output` as a tile and store it back.
// NOTE(review): presumably they only give the kernel an otherwise well-formed
// body; confirm.
// Plain `//` comments are used here because `///` would become a `#[doc]`
// attribute passed to the proc macro.
#[cutile::module]
mod span_error_module {
use cutile::core::*;
#[cutile::entry()]
fn untyped_literal_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let _x = super::unsupported_dsl_call();
let tile = load_tile_mut(output);
output.store(tile);
}
}
/// Checks that compiling a kernel which calls a plain host function
/// (`unsupported_dsl_call`) fails with a `JITError::Located` pointing at the
/// offending statement. The location must have:
/// - the right file,
/// - the exact line,
/// - a non-zero column.
///
/// The message must also mention the call or say "not supported".
///
/// NOTE(review): despite the test and kernel names, `untyped_literal_kernel`
/// contains no untyped literal. The failure under test comes from the
/// unsupported call. The kernel name is kept because it is looked up by string
/// below.
///
/// The expected line is computed from this file's own text (`include_str!`),
/// so the assertion keeps working when edits shift the kernel up or down.
#[test]
fn untyped_literal_error_has_correct_source_location() {
    common::with_test_stack(|| {
        let modules = CUDATileModules::from_kernel(span_error_module::__module_ast_self())
            .expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            "span_error_module",
            "untyped_literal_kernel",
            &[128.to_string()],
            &[("output", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Compiler construction should succeed");

        // The kernel calls a host function the DSL cannot handle, so success is a failure here.
        let Err(err) = compiler.compile() else {
            panic!("Expected compilation to fail for an unsupported DSL call, but it succeeded.");
        };
        let err_string = format!("{err}");
        println!("\n=== UNSUPPORTED DSL CALL ERROR ===\n{err_string}");

        // Pull out the location-carrying variant up front so the assertions below
        // are not nested inside a match arm.
        let (msg, loc) = match &err {
            JITError::Located(msg, loc) => (msg, loc),
            other => panic!("Expected JITError::Located, got a different variant: {other:?}"),
        };

        assert!(
            loc.is_known(),
            "Expected a known source location on the error, got {loc:?}"
        );
        assert!(
            loc.file.ends_with("span_source_location.rs"),
            "Expected file to end with 'span_source_location.rs', got '{}'",
            loc.file
        );

        // 1-based number of the first line whose trimmed text begins with the
        // kernel's unsupported call. Comment lines cannot match because they
        // start with `//` after trimming.
        let source = include_str!("span_source_location.rs");
        let expected_line = source
            .lines()
            .position(|line| {
                line.trim_start()
                    .starts_with("let _x = super::unsupported_dsl_call();")
            })
            .map(|idx| idx + 1)
            .expect("Could not find unsupported helper call marker in test source");
        assert_eq!(
            loc.line, expected_line,
            "Expected error on line {expected_line} ('let _x = super::unsupported_dsl_call();'), got line {}",
            loc.line
        );
        assert!(
            loc.column > 0,
            "Expected a non-zero column for the unsupported call, got {}",
            loc.column
        );
        assert!(
            msg.contains("unsupported_dsl_call") || msg.contains("not supported"),
            "Expected error message to mention the unsupported call, got: {msg}"
        );
        println!(
            "\n✓ Unsupported DSL call error correctly located at {}:{}:{}\n message: {msg}",
            loc.file, loc.line, loc.column
        );
    });
}
// Same kernel shape as `span_error_module`. The differences are:
// - the module body is interleaved with ordinary comments;
// - the kernel signature spans several lines.
// `comments_do_not_break_span_tracking` expects the compile error for the `_y`
// call to be reported on that call's real line in this file, so none of these
// comments may shift the reported line.
// Only plain `//` and `/* */` comments are used inside the module. `///` doc
// comments would become `#[doc]` attributes seen by the proc macro.
#[cutile::module]
mod span_comments_module {
// Comment before the `use` item.
use cutile::core::*;
// Comment between items, right before the entry attribute.
#[cutile::entry()]
fn commented_kernel<const S: [i32; 1]>(
// Comment inside the parameter list.
output: &mut Tensor<f32, S>,
) {
// Comment directly above the unsupported call: the error must be reported
// on the call's own line below, not on these comment lines.
let _y = super::unsupported_dsl_call();
/* Block comment between statements. */
let tile = load_tile_mut(output); // trailing comment after a statement
output.store(tile);
}
}
/// Checks that comments inside a `#[cutile::module]` do not skew span tracking.
///
/// Compiling `commented_kernel` must fail with a `JITError::Located` whose line
/// is the real line of the `_y` call in this file, found by scanning the file's
/// own text via `include_str!`. The column must be non-zero, and the message
/// must reference the call or say "not supported".
#[test]
fn comments_do_not_break_span_tracking() {
    common::with_test_stack(|| {
        let tile_modules =
            CUDATileModules::from_kernel(span_comments_module::__module_ast_self())
                .expect("Failed to create CUDATileModules");
        let gpu = get_gpu_name(0);
        let fn_compiler = CUDATileFunctionCompiler::new(
            &tile_modules,
            "span_comments_module",
            "commented_kernel",
            &[128.to_string()],
            &[("output", &[1])],
            &[],
            &[],
            None,
            gpu,
            &CompileOptions::default(),
        )
        .expect("Compiler construction should succeed");

        let Err(failure) = fn_compiler.compile() else {
            panic!(
                "Expected compilation to fail for an unsupported DSL call inside \
                 commented module, but it succeeded."
            );
        };
        let rendered = format!("{failure}");
        println!("\n=== COMMENTS MODULE ERROR ===\n{rendered}");

        let (msg, loc) = match &failure {
            JITError::Located(msg, loc) => (msg, loc),
            other => panic!("Expected JITError::Located even with comments, got: {other:?}"),
        };

        assert!(
            loc.is_known(),
            "Expected a known source location even with comments, got {loc:?}"
        );
        assert!(
            loc.file.ends_with("span_source_location.rs"),
            "Expected file to end with 'span_source_location.rs', got '{}'",
            loc.file
        );

        // Pair each line with its 1-based number and keep the first one whose
        // trimmed text begins with the `_y` call.
        let expected_line = include_str!("span_source_location.rs")
            .lines()
            .zip(1_usize..)
            .find_map(|(text, number)| {
                let trimmed = text.trim_start();
                trimmed
                    .starts_with("let _y = super::unsupported_dsl_call();")
                    .then_some(number)
            })
            .expect("Could not find unsupported helper call marker in test source");
        assert_eq!(
            loc.line, expected_line,
            "Expected error on line {expected_line} ('let _y = super::unsupported_dsl_call();'), got line {}.\n\
             Comments should NOT affect line numbers when source_text() is used.",
            loc.line
        );
        assert!(
            loc.column > 0,
            "Expected a non-zero column, got {}",
            loc.column
        );
        assert!(
            msg.contains("unsupported_dsl_call") || msg.contains("not supported"),
            "Expected error message to reference the unsupported call, got: {msg}"
        );
        println!(
            "\n✓ Span tracking exact despite comments at {}:{}:{}\n message: {msg}",
            loc.file, loc.line, loc.column
        );
    });
}