cmake_init/
src_main_cpp.rs

1use std::fs::{self, File};
2use std::io::Write;
3use std::path::Path;
4
5pub fn src_main_cpp(mode: Option<&str>) {
6    let src_dir = Path::new("./src");
7    let (main_path, content);
8
9    match mode {
10        Some("CUDA") => {
11            main_path = src_dir.join("main.cu");
12            content = include_str!("../files/cuda/main.cu");
13            if !src_dir.exists() {
14                fs::create_dir_all(src_dir).expect("Failed to create src directory");
15            }
16
17            if main_path.exists() {
18                return;
19            }
20        }
21        Some("HIP") => {
22            main_path = src_dir.join("main.hip");
23            content = include_str!("../files/hip/main.hip");
24            if !src_dir.exists() {
25                fs::create_dir_all(src_dir).expect("Failed to create src directory");
26            }
27
28            if main_path.exists() {
29                return;
30            }
31        }
32        Some("MPI") => {
33            main_path = src_dir.join("main.cpp");
34            content = include_str!("../files/mpi/main.cpp");
35            if !src_dir.exists() {
36                fs::create_dir_all(src_dir).expect("Failed to create src directory");
37            }
38
39            if main_path.exists() {
40                return;
41            }
42        }
43        _ => {
44            main_path = src_dir.join("main.cpp");
45            content = include_str!("../files/main.cpp");
46            if !src_dir.exists() {
47                fs::create_dir_all(src_dir).expect("Failed to create src directory");
48            }
49
50            if main_path.exists() {
51                return;
52            }
53        }
54    }
55
56    let mut file = File::create(&main_path).expect("Failed to create main.cpp/.cu/.hip file");
57    file.write_all(content.as_bytes())
58        .expect("Failed to write to main.cpp");
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the generator for `mode`, then returns the text found at `path`.
    ///
    /// Paths are relative to the working directory the tests run in, because
    /// `src_main_cpp` always writes under `./src`.
    fn generate_and_read(mode: &str, path: &str) -> String {
        src_main_cpp(Some(mode));
        std::fs::read_to_string(path).unwrap()
    }

    // "C++" matches none of the named modes, so this exercises the fallback arm.
    #[test]
    fn test_src_main_cpp() {
        let generated = generate_and_read("C++", "./src/main.cpp");
        assert!(generated.contains("int main(int argc, char* argv[]) {"));
    }

    #[test]
    fn test_src_main_cuda() {
        let generated = generate_and_read("CUDA", "./src/main.cu");
        assert!(generated.contains("__global__"));
    }

    #[test]
    fn test_src_main_hip() {
        let generated = generate_and_read("HIP", "./src/main.hip");
        assert!(generated.contains("__global__"));
    }
}
85}