use super::fixtures::*;
use crate::commands;
use crate::db::GpuDb;
use rusqlite::params;
/// Checks that `escape_regex` backslash-escapes the regex metacharacters that
/// show up in real kernel and library names, while leaving characters that are
/// not metacharacters (such as `<`) untouched.
#[test]
fn escape_regex_template_names() {
    use crate::commands::escape_regex;

    // Template angle brackets are plain characters in regex syntax and must survive.
    let template =
        escape_regex("void cutlass::Kernel<cutlass::gemm::kernel::GemmUniversal<float>>");
    assert!(template.contains('<'), "< should be preserved (not a regex metachar)");

    // A C-style signature: grouping parentheses and the pointer star.
    let signature = escape_regex("void func(float*, int)");
    assert!(signature.contains(r"\("), "( must be escaped: {signature}");
    assert!(signature.contains(r"\)"), ") must be escaped: {signature}");
    assert!(signature.contains(r"\*"), "* must be escaped: {signature}");

    // The alternation bar.
    let alternation = escape_regex("a|b");
    assert!(alternation.contains(r"\|"), "| must be escaped: {alternation}");

    // A literal backslash has to come out doubled.
    let backslashed = escape_regex(r"foo\bar");
    assert!(backslashed.starts_with(r"foo\\"), "backslash must be escaped: {backslashed}");

    // The dot in a shared-library file name.
    let library = escape_regex("libcuda.so");
    assert!(library.contains(r"\."), ". must be escaped: {library}");
}
/// Imports an ncu CSV that contains awkward rows and checks that metrics for
/// both kernels still land in the database. The awkward rows are:
/// - a thousands-separated value (`1,234,567`),
/// - an empty metric value,
/// - a `==PROF==` banner line,
/// - a blank line in the middle of the file.
#[test]
fn ncu_csv_edge_cases() {
    use crate::parsers::ncu::import_ncu_csv;
    use std::io::Write;

    // Keep the TempDir guard alive so the directory is removed when the test
    // ends (`into_path()` would leak it). It is declared before `db`, so the
    // database is dropped first.
    let dir = tempfile::tempdir().unwrap();
    let db = GpuDb::create(&dir.path().join("csv_edge.db")).unwrap();
    let lid = db.add_layer("ncu", "test.csv", None, None, None).unwrap();

    let mut tmp = tempfile::NamedTempFile::new().unwrap();
    writeln!(tmp, r#""ID","Kernel Name","Metric Name","Metric Unit","Metric Value""#).unwrap();
    writeln!(tmp, r#""1","my_kernel","sm__warps_active.avg.pct_of_peak_sustained_active","%","75.0""#).unwrap();
    writeln!(tmp, r#""1","my_kernel","gpu__time_duration.sum","nsecond","1,234,567""#).unwrap();
    writeln!(tmp, r#""1","my_kernel","sm__throughput.avg.pct_of_peak_sustained_elapsed","%","""#).unwrap();
    writeln!(tmp, "==PROF== Disconnected").unwrap();
    writeln!(tmp).unwrap();
    writeln!(tmp, r#""2","other_kernel","dram__throughput.avg.pct_of_peak_sustained_elapsed","%","65.2""#).unwrap();
    writeln!(tmp, r#""2","other_kernel","sm__throughput.avg.pct_of_peak_sustained_elapsed","%","12.0""#).unwrap();
    import_ncu_csv(&db.conn, tmp.path(), lid).unwrap();

    let occ: f64 = db
        .conn
        .query_row(
            "SELECT occupancy_pct FROM metrics WHERE kernel_name = 'my_kernel'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert!((occ - 75.0).abs() < 0.1, "occupancy should be 75.0, got {occ}");

    // The comma-separated duration must not stop the launch row from being created.
    let has_launch: bool = db
        .conn
        .query_row(
            "SELECT COUNT(*) > 0 FROM launches WHERE kernel_name = 'my_kernel'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert!(has_launch, "ncu should insert launch for kernel with duration");

    // Rows after the banner and blank line must still be parsed.
    let bound: String = db
        .conn
        .query_row(
            "SELECT boundedness FROM metrics WHERE kernel_name = 'other_kernel'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(bound, "memory", "65.2% mem vs 12.0% compute should be memory-bound");
}
/// Pins `classify_boundedness` in three situations:
/// - at the "latency" threshold (both inputs just under / exactly at 10),
/// - when either input is missing,
/// - for clearly compute- or memory-dominated inputs.
#[test]
fn boundedness_edge_cases() {
    use crate::parsers::ncu::classify_boundedness;

    // (first %, second %, expected class), checked in the original order.
    // The compute/memory rows suggest the first argument is compute throughput
    // and the second is memory throughput. NOTE(review): confirm against the signature.
    let cases = [
        (Some(9.9), Some(9.9), Some("latency")),
        (Some(0.0), Some(0.0), Some("latency")),
        (Some(10.0), Some(10.0), Some("memory")),
        (None, Some(50.0), None),
        (Some(50.0), None, None),
        (None, None, None),
        (Some(90.0), Some(10.0), Some("compute")),
        (Some(40.0), Some(60.0), Some("memory")),
    ];
    for (first, second, expected) in cases {
        assert_eq!(classify_boundedness(first, second).as_deref(), expected);
    }
}
/// Enables the focus, ignore and region filters together and checks that:
/// - the combined `kernel_filter()` clause contains all three parts,
/// - it is valid SQL and only narrows the timeline-filtered launch set,
/// - ignored (`nccl`) kernels are excluded,
/// - the listing commands run under it without panicking.
#[test]
fn triple_filter_interaction() {
    let mut db = build_session();
    db.focus = Some("kernel".to_string());
    db.ignore = Some("nccl".to_string());
    db.region_filter = Some("Step#1".to_string());

    let filter = db.kernel_filter();
    let tl = db.timeline_filter();
    assert!(filter.contains("LIKE '%kernel%'"), "focus clause missing");
    assert!(filter.contains("NOT LIKE '%nccl%'"), "ignore clause missing");
    assert!(filter.contains("regions"), "region clause missing");

    // `unwrap` fails the test if the composed WHERE clause is not valid SQL.
    let count_where = |predicate: &str| -> i64 {
        db.conn
            .query_row(
                &format!("SELECT COUNT(*) FROM launches WHERE {predicate}"),
                [],
                |row| row.get(0),
            )
            .unwrap()
    };

    // Replaces a tautological `count >= 0` check. ANDing extra predicates onto
    // the timeline filter can never add rows.
    let timeline_total = count_where(&tl);
    let filtered = count_where(&format!("{filter} AND {tl}"));
    assert!(
        filtered <= timeline_total,
        "filters must only narrow the launch set: {filtered} > {timeline_total}"
    );

    let nccl = count_where(&format!("kernel_name LIKE '%nccl%' AND {filter} AND {tl}"));
    assert_eq!(nccl, 0, "nccl should be excluded by ignore filter");

    commands::cmd_kernels(&db, &[]);
    commands::cmd_small(&db, &[]);
    commands::cmd_stats(&db);
    commands::cmd_timeline(&db, &[]);
}
/// Diffs a session against a copy of itself and checks that per-kernel totals
/// match on both sides. The copy is made with the SQLite online backup API and
/// attached under the schema name `other`.
#[test]
fn diff_identical_sessions() {
    // Create the directory first so it outlives `db`, which has the copy attached.
    let dir = tempfile::tempdir().unwrap();
    let db = build_cuda_only_session();
    let dest = dir.path().join("copy.gpu.db");
    {
        // Scoped so the backup and destination connection are dropped before attaching.
        let mut dest_conn = rusqlite::Connection::open(&dest).unwrap();
        let backup = rusqlite::backup::Backup::new(&db.conn, &mut dest_conn).unwrap();
        backup
            .run_to_completion(100, std::time::Duration::from_millis(10), None)
            .unwrap();
    }
    db.attach(dest.to_str().unwrap(), "other").unwrap();

    let rows: Vec<(String, f64, f64)> = db.query_vec(
        "SELECT
            COALESCE(c.kernel_name, o.kernel_name) as name,
            COALESCE(o.total, 0) as before,
            COALESCE(c.total, 0) as after
         FROM
            (SELECT kernel_name, SUM(duration_us) as total FROM launches GROUP BY kernel_name) c
         FULL OUTER JOIN
            (SELECT kernel_name, SUM(duration_us) as total FROM other.launches GROUP BY kernel_name) o
         ON c.kernel_name = o.kernel_name",
        [],
        |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
    );

    // Without this check, an empty result would make the loop below pass vacuously.
    assert!(!rows.is_empty(), "identical diff should list at least one kernel");
    for (name, before, after) in &rows {
        assert!(
            (before - after).abs() < 0.01,
            "kernel '{name}': before ({before}) != after ({after}) in identical diff"
        );
    }
    db.detach("other").unwrap();
}
/// Inserts the same kernel twice:
/// - from an nsys layer, with `start_us` and `stream_id`,
/// - from an ncu layer, with duration only.
/// Total GPU time and the timeline filter must count only the nsys launch, and
/// the gaps/timeline commands must run without panicking.
#[test]
fn ncu_launches_excluded_from_timeline() {
    // Borrow the path instead of `into_path()` so the directory is cleaned up.
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("ncu_tl.gpu.db");
    let db = GpuDb::create(&path).unwrap();

    let nsys_id = db.add_layer("nsys", "t", None, None, None).unwrap();
    db.conn
        .execute(
            "INSERT INTO launches (kernel_name, duration_us, start_us, stream_id, layer_id)
             VALUES ('real_kernel', 100.0, 500.0, 7, ?1)",
            params![nsys_id],
        )
        .unwrap();

    let ncu_id = db.add_layer("ncu", "t", None, None, None).unwrap();
    db.conn
        .execute(
            "INSERT INTO launches (kernel_name, duration_us, layer_id)
             VALUES ('real_kernel', 95.0, ?1)",
            params![ncu_id],
        )
        .unwrap();

    let total = db.total_gpu_time_us();
    assert!(
        (total - 100.0).abs() < 0.01,
        "should see nsys launch only (100us), got {total}"
    );

    let timeline_count: i64 = db
        .conn
        .query_row(
            &format!(
                "SELECT COUNT(*) FROM launches
                 WHERE start_us IS NOT NULL AND {}",
                db.timeline_filter()
            ),
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(timeline_count, 1, "timeline should see only 1 nsys launch");

    commands::cmd_gaps(&db, &[]);
    commands::cmd_timeline(&db, &[]);
}
/// Smoke test: `recompute_op_gpu_times` must not panic on the CUDA-only
/// fixture session (by its name, presumably one with launches but no ops).
#[test]
fn recompute_with_no_ops() {
    build_cuda_only_session().recompute_op_gpu_times();
}
/// Smoke test: `recompute_op_gpu_times` must not panic on a freshly created,
/// empty database.
#[test]
fn recompute_with_no_launches() {
    // Keep the TempDir guard (declared before `db`) so the directory is removed on drop.
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("empty_recompute.gpu.db");
    let db = GpuDb::create(&path).unwrap();
    db.recompute_op_gpu_times();
}
/// A launch with a negative duration must not drive total GPU time to zero or
/// below, and the reporting commands must not panic on it.
#[test]
fn negative_duration_does_not_corrupt() {
    // Keep the TempDir guard (declared before `db`) so the directory is removed on drop.
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("neg_dur.gpu.db");
    let db = GpuDb::create(&path).unwrap();
    let lid = db.add_layer("nsys", "t", None, None, None).unwrap();

    db.conn
        .execute(
            "INSERT INTO launches (kernel_name, duration_us, start_us, stream_id, layer_id)
             VALUES ('good', 100.0, 0.0, 7, ?1)",
            params![lid],
        )
        .unwrap();
    db.conn
        .execute(
            "INSERT INTO launches (kernel_name, duration_us, start_us, stream_id, layer_id)
             VALUES ('bad', -50.0, 200.0, 7, ?1)",
            params![lid],
        )
        .unwrap();

    let total = db.total_gpu_time_us();
    assert!(total > 0.0, "total should be positive even with negative duration: {total}");

    commands::cmd_stats(&db);
    commands::cmd_kernels(&db, &[]);
    commands::cmd_gaps(&db, &[]);
    commands::cmd_timeline(&db, &[]);
    commands::cmd_warmup(&db);
}
/// Stores a kernel name containing SQL (`'; DROP TABLE launches; --`) through
/// bound parameters, then feeds fragments of it to the inspect, variance and
/// focus commands. The launches table must survive with its single row intact.
#[test]
fn sql_injection_in_kernel_name() {
    // Keep the TempDir guard (declared before `db`) so the directory is removed on drop.
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("inject.gpu.db");
    let mut db = GpuDb::create(&path).unwrap();
    let lid = db.add_layer("nsys", "t", None, None, None).unwrap();

    let evil_name = "'; DROP TABLE launches; --";
    db.conn
        .execute(
            "INSERT INTO launches (kernel_name, duration_us, start_us, layer_id)
             VALUES (?1, 100.0, 0.0, ?2)",
            params![evil_name, lid],
        )
        .unwrap();

    commands::cmd_inspect(&db, &["DROP"]);
    commands::cmd_variance(&db, &["DROP"]);
    commands::cmd_kernels(&db, &[]);
    // The focus pattern itself carries a quote that must not break out of the SQL literal.
    commands::cmd_focus(&mut db, &["'; DROP"]);
    commands::cmd_kernels(&db, &[]);

    let count = db.total_launch_count();
    assert_eq!(count, 1, "table should survive injection attempt");
}
/// Pins the current behavior of `escape_sql_like`:
/// - `%` is prefixed with a backslash,
/// - single quotes are doubled,
/// - `_` is passed through unchanged.
#[test]
fn sql_like_wildcard_escaping() {
    use crate::db::escape_sql_like;

    // NOTE(review): `_` is also a LIKE wildcard (any single character), so names
    // containing it match slightly more broadly. SQLite's LIKE has no default
    // escape character, so `\%` is only literal if the query adds `ESCAPE '\'`.
    // Confirm the callers do that.
    let expectations = [
        ("100%", r"100\%"),
        ("a_b", "a_b"),
        ("it's", "it''s"),
        ("50% of it's_done", r"50\% of it''s_done"),
    ];
    for (raw, escaped) in expectations {
        assert_eq!(escape_sql_like(raw), escaped);
    }
}
/// Takes the five kernels with the largest total duration and joins their
/// `escape_regex`-escaped names into one `|` alternation. The result must
/// compile as a regex and match every original name.
#[test]
fn suggest_ncu_regex_is_valid() {
    use crate::commands::escape_regex;

    let db = build_session();
    let tl = db.timeline_filter();
    let names: Vec<String> = db.query_vec(
        &format!(
            "SELECT kernel_name FROM launches WHERE {tl}
             GROUP BY kernel_name ORDER BY SUM(duration_us) DESC LIMIT 5"
        ),
        [],
        |row| row.get(0),
    );
    // An empty alternation compiles to a pattern that matches everything, which
    // would make the checks below vacuous.
    assert!(!names.is_empty(), "fixture session should contain kernel launches");

    let regex_str = names.iter().map(|n| escape_regex(n)).collect::<Vec<_>>().join("|");
    let re = regex::Regex::new(&regex_str).unwrap_or_else(|e| {
        panic!("suggest regex should be valid: {regex_str}\nerror: {e:?}")
    });
    for name in &names {
        assert!(re.is_match(name), "regex should match original name '{name}'");
    }
}
/// Imports a small Chrome trace and checks three things:
/// - CPU time is summed per op name (`aten::mm` appears twice: 80 + 30),
/// - the kernel that starts inside the first `aten::mm` span is mapped to that op,
/// - the op's GPU time equals that kernel's duration.
#[test]
fn chrome_trace_op_cpu_time_aggregation() {
    use crate::parsers::chrome_trace::import_chrome_trace;

    // Keep the TempDir guard (declared before `db`) so the directory is removed on drop.
    let dir = tempfile::tempdir().unwrap();
    let db = GpuDb::create(&dir.path().join("ct.db")).unwrap();
    let lid = db.add_layer("torch", "trace.json", None, None, None).unwrap();

    let trace = serde_json::json!({
        "traceEvents": [
            {"ph": "X", "cat": "kernel", "name": "gemm_kernel", "ts": 100.0, "dur": 50.0, "args": {}},
            {"ph": "X", "cat": "cpu_op", "name": "aten::mm", "ts": 90.0, "dur": 80.0, "args": {}},
            {"ph": "X", "cat": "cpu_op", "name": "aten::mm", "ts": 200.0, "dur": 30.0, "args": {}},
            {"ph": "X", "cat": "cpu_op", "name": "aten::add", "ts": 300.0, "dur": 10.0, "args": {}},
        ]
    });
    let tmp = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(tmp.path(), serde_json::to_string(&trace).unwrap()).unwrap();
    import_chrome_trace(&db.conn, tmp.path(), lid).unwrap();

    let mm_cpu: f64 = db
        .conn
        .query_row("SELECT cpu_time_us FROM ops WHERE name = 'aten::mm'", [], |row| row.get(0))
        .unwrap();
    assert!(
        (mm_cpu - 110.0).abs() < 0.01,
        "aten::mm CPU time should be 110 (80+30), got {mm_cpu}"
    );

    let add_cpu: f64 = db
        .conn
        .query_row("SELECT cpu_time_us FROM ops WHERE name = 'aten::add'", [], |row| row.get(0))
        .unwrap();
    assert!((add_cpu - 10.0).abs() < 0.01, "aten::add CPU should be 10, got {add_cpu}");

    let mapped_op: String = db
        .conn
        .query_row(
            "SELECT o.name FROM op_kernel_map okm JOIN ops o ON o.id = okm.op_id
             WHERE okm.kernel_name = 'gemm_kernel'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(mapped_op, "aten::mm", "kernel should map to containing op");

    let mm_gpu: f64 = db
        .conn
        .query_row("SELECT gpu_time_us FROM ops WHERE name = 'aten::mm'", [], |row| row.get(0))
        .unwrap();
    assert!((mm_gpu - 50.0).abs() < 0.01, "aten::mm GPU should be 50, got {mm_gpu}");
}
/// Nested ops: `aten::mm` (ts 100–200) sits inside `aten::linear` (ts 0–1000).
/// `sgemm` at ts 150 must map to the innermost op, `aten::mm`. `elementwise`
/// at ts 500 lies only inside `aten::linear`, so it must map there.
#[test]
fn chrome_trace_innermost_op_wins() {
    use crate::parsers::chrome_trace::import_chrome_trace;

    // Keep the TempDir guard (declared before `db`) so the directory is removed on drop.
    let dir = tempfile::tempdir().unwrap();
    let db = GpuDb::create(&dir.path().join("nested.db")).unwrap();
    let lid = db.add_layer("torch", "trace.json", None, None, None).unwrap();

    let trace = serde_json::json!({
        "traceEvents": [
            {"ph": "X", "cat": "cpu_op", "name": "aten::linear", "ts": 0.0, "dur": 1000.0, "args": {}},
            {"ph": "X", "cat": "cpu_op", "name": "aten::mm", "ts": 100.0, "dur": 100.0, "args": {}},
            {"ph": "X", "cat": "kernel", "name": "sgemm", "ts": 150.0, "dur": 30.0, "args": {}},
            {"ph": "X", "cat": "kernel", "name": "elementwise", "ts": 500.0, "dur": 10.0, "args": {}},
        ]
    });
    let tmp = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(tmp.path(), serde_json::to_string(&trace).unwrap()).unwrap();
    import_chrome_trace(&db.conn, tmp.path(), lid).unwrap();

    let sgemm_op: String = db
        .conn
        .query_row(
            "SELECT o.name FROM op_kernel_map okm JOIN ops o ON o.id = okm.op_id
             WHERE okm.kernel_name = 'sgemm'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(sgemm_op, "aten::mm", "sgemm should map to innermost op aten::mm");

    let ew_op: String = db
        .conn
        .query_row(
            "SELECT o.name FROM op_kernel_map okm JOIN ops o ON o.id = okm.op_id
             WHERE okm.kernel_name = 'elementwise'",
            [],
            |row| row.get(0),
        )
        .unwrap();
    assert_eq!(ew_op, "aten::linear", "elementwise should map to outer op aten::linear");
}