cmake_init/
src_main_cpp.rs1use std::fs::{self, File};
2use std::io::Write;
3use std::path::Path;
4
5pub fn src_main_cpp(mode: Option<&str>) {
6 let src_dir = Path::new("./src");
7 let (main_path, content);
8
9 match mode {
10 Some("CUDA") => {
11 main_path = src_dir.join("main.cu");
12 content = include_str!("../files/cuda/main.cu");
13 if !src_dir.exists() {
14 fs::create_dir_all(src_dir).expect("Failed to create src directory");
15 }
16
17 if main_path.exists() {
18 return;
19 }
20 }
21 Some("HIP") => {
22 main_path = src_dir.join("main.hip");
23 content = include_str!("../files/hip/main.hip");
24 if !src_dir.exists() {
25 fs::create_dir_all(src_dir).expect("Failed to create src directory");
26 }
27
28 if main_path.exists() {
29 return;
30 }
31 }
32 Some("MPI") => {
33 main_path = src_dir.join("main.cpp");
34 content = include_str!("../files/mpi/main.cpp");
35 if !src_dir.exists() {
36 fs::create_dir_all(src_dir).expect("Failed to create src directory");
37 }
38
39 if main_path.exists() {
40 return;
41 }
42 }
43 _ => {
44 main_path = src_dir.join("main.cpp");
45 content = include_str!("../files/main.cpp");
46 if !src_dir.exists() {
47 fs::create_dir_all(src_dir).expect("Failed to create src directory");
48 }
49
50 if main_path.exists() {
51 return;
52 }
53 }
54 }
55
56 let mut file = File::create(&main_path).expect("Failed to create main.cpp/.cu/.hip file");
57 file.write_all(content.as_bytes())
58 .expect("Failed to write to main.cpp");
59}
60
61#[cfg(test)]
62mod tests {
63 use super::*;
64
65 #[test]
66 fn test_src_main_cpp() {
67 src_main_cpp(Some("C++"));
68 let content = std::fs::read_to_string("./src/main.cpp").unwrap();
69 assert!(content.contains("int main(int argc, char* argv[]) {"));
70 }
71
72 #[test]
73 fn test_src_main_cuda() {
74 src_main_cpp(Some("CUDA"));
75 let content = std::fs::read_to_string("./src/main.cu").unwrap();
76 assert!(content.contains("__global__"));
77 }
78
79 #[test]
80 fn test_src_main_hip() {
81 src_main_cpp(Some("HIP"));
82 let content = std::fs::read_to_string("./src/main.hip").unwrap();
83 assert!(content.contains("__global__"));
84 }
85}