oxihuman_export/
msl_export.rs1#![allow(dead_code)]
4
/// The kind of Metal entry point an `MslFunction` represents.
///
/// Each variant maps to the MSL function qualifier returned by
/// [`MslFunctionType::keyword`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MslFunctionType {
    /// Vertex entry point, rendered with the `vertex` qualifier.
    Vertex,
    /// Fragment entry point, rendered with the `fragment` qualifier.
    Fragment,
    /// Compute entry point, rendered with the `kernel` qualifier.
    Kernel,
}

impl MslFunctionType {
    /// Returns the MSL qualifier keyword placed before the function name when
    /// rendering: `"vertex"`, `"fragment"`, or `"kernel"`.
    #[must_use]
    pub fn keyword(&self) -> &'static str {
        match self {
            Self::Vertex => "vertex",
            Self::Fragment => "fragment",
            Self::Kernel => "kernel",
        }
    }
}
24
25pub struct MslFunction {
27 pub function_type: MslFunctionType,
28 pub name: String,
29 pub source: String,
30}
31
32pub struct MslExport {
34 pub functions: Vec<MslFunction>,
35 pub includes: Vec<String>,
36}
37
38pub fn new_msl_export() -> MslExport {
40 MslExport {
41 functions: Vec::new(),
42 includes: vec!["<metal_stdlib>".to_string()],
43 }
44}
45
46pub fn add_msl_function(exp: &mut MslExport, fn_type: MslFunctionType, name: &str, src: &str) {
48 exp.functions.push(MslFunction {
49 function_type: fn_type,
50 name: name.to_string(),
51 source: src.to_string(),
52 });
53}
54
55pub fn add_msl_include(exp: &mut MslExport, include: &str) {
57 exp.includes.push(include.to_string());
58}
59
60pub fn msl_function_count(exp: &MslExport) -> usize {
62 exp.functions.len()
63}
64
65pub fn find_msl_function<'a>(exp: &'a MslExport, name: &str) -> Option<&'a MslFunction> {
67 exp.functions.iter().find(|f| f.name == name)
68}
69
70pub fn render_msl_source(exp: &MslExport) -> String {
72 let mut s = String::new();
73 for inc in &exp.includes {
74 s.push_str(&format!("#include {inc}\n"));
75 }
76 s.push_str("using namespace metal;\n");
77 for f in &exp.functions {
78 s.push_str(&format!(
79 "{} {} {{ {} }}\n",
80 f.function_type.keyword(),
81 f.name,
82 f.source
83 ));
84 }
85 s
86}
87
88pub fn validate_msl_export(exp: &MslExport) -> bool {
90 !exp.functions.is_empty()
91}
92
#[cfg(test)]
mod tests {
    // Unit tests for the MSL exporter's public free functions.
    use super::*;

    #[test]
    fn new_export_has_include() {
        let exp = new_msl_export();
        assert!(exp.includes.iter().any(|i| i.contains("metal_stdlib")));
    }

    #[test]
    fn add_function_increments() {
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Vertex, "vertex_main", "return pos;");
        assert_eq!(msl_function_count(&exp), 1);
    }

    #[test]
    fn keyword_vertex_correct() {
        assert_eq!(MslFunctionType::Vertex.keyword(), "vertex");
    }

    #[test]
    fn keyword_fragment_correct() {
        assert_eq!(MslFunctionType::Fragment.keyword(), "fragment");
    }

    #[test]
    fn keyword_kernel_correct() {
        assert_eq!(MslFunctionType::Kernel.keyword(), "kernel");
    }

    #[test]
    fn find_function_by_name() {
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Fragment, "frag_main", "return color;");
        assert!(find_msl_function(&exp, "frag_main").is_some());
    }

    #[test]
    fn find_missing_none() {
        let exp = new_msl_export();
        assert!(find_msl_function(&exp, "x").is_none());
    }

    #[test]
    fn find_returns_first_of_duplicates() {
        // Duplicate names are accepted on insert; lookup must return the earliest.
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Vertex, "dup", "first");
        add_msl_function(&mut exp, MslFunctionType::Vertex, "dup", "second");
        let found = find_msl_function(&exp, "dup").map(|f| f.source.as_str());
        assert_eq!(found, Some("first"));
    }

    #[test]
    fn render_contains_function_name() {
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Vertex, "vs_main", "");
        let src = render_msl_source(&exp);
        assert!(src.contains("vs_main"));
    }

    #[test]
    fn render_contains_namespace() {
        let exp = new_msl_export();
        let src = render_msl_source(&exp);
        assert!(src.contains("namespace metal"));
    }

    #[test]
    fn render_exact_output() {
        // Pins the full layout: includes, namespace line, then one line per function.
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Fragment, "frag_main", "return color;");
        assert_eq!(
            render_msl_source(&exp),
            "#include <metal_stdlib>\nusing namespace metal;\nfragment frag_main { return color; }\n"
        );
    }

    #[test]
    fn validate_empty_fails() {
        let exp = new_msl_export();
        assert!(!validate_msl_export(&exp));
    }

    #[test]
    fn validate_with_function_passes() {
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, MslFunctionType::Kernel, "compute", "");
        assert!(validate_msl_export(&exp));
    }
}