Skip to main content

oxihuman_export/
msl_export.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! Metal MSL shader export stub.
6
/// Kind of Metal Shading Language (MSL) function.
///
/// Each variant maps to the qualifier keyword returned by
/// [`MslFunctionType::keyword`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MslFunctionType {
    /// Rendered with the `vertex` qualifier.
    Vertex,
    /// Rendered with the `fragment` qualifier.
    Fragment,
    /// Rendered with the `kernel` qualifier.
    Kernel,
}

impl MslFunctionType {
    /// Returns the MSL qualifier keyword for this function type:
    /// `"vertex"`, `"fragment"` or `"kernel"`.
    pub fn keyword(&self) -> &'static str {
        match self {
            MslFunctionType::Vertex => "vertex",
            MslFunctionType::Fragment => "fragment",
            MslFunctionType::Kernel => "kernel",
        }
    }
}

/// A single MSL function held by an [`MslExport`].
#[derive(Debug, Clone, PartialEq)]
pub struct MslFunction {
    /// Which qualifier the function is rendered with.
    pub function_type: MslFunctionType,
    /// Function name.
    pub name: String,
    /// Function body text.
    pub source: String,
}

/// An MSL export document: a list of include targets plus a list of functions.
#[derive(Debug, Clone, PartialEq)]
pub struct MslExport {
    /// Functions, in insertion order.
    pub functions: Vec<MslFunction>,
    /// Include targets (e.g. `<metal_stdlib>`); one `#include` line is rendered per entry.
    pub includes: Vec<String>,
}
37
38/// Create a new MSL export.
39pub fn new_msl_export() -> MslExport {
40    MslExport {
41        functions: Vec::new(),
42        includes: vec!["<metal_stdlib>".to_string()],
43    }
44}
45
46/// Add an MSL function.
47pub fn add_msl_function(exp: &mut MslExport, fn_type: MslFunctionType, name: &str, src: &str) {
48    exp.functions.push(MslFunction {
49        function_type: fn_type,
50        name: name.to_string(),
51        source: src.to_string(),
52    });
53}
54
55/// Add an include.
56pub fn add_msl_include(exp: &mut MslExport, include: &str) {
57    exp.includes.push(include.to_string());
58}
59
60/// Function count.
61pub fn msl_function_count(exp: &MslExport) -> usize {
62    exp.functions.len()
63}
64
65/// Find a function by name.
66pub fn find_msl_function<'a>(exp: &'a MslExport, name: &str) -> Option<&'a MslFunction> {
67    exp.functions.iter().find(|f| f.name == name)
68}
69
70/// Render full MSL source.
71pub fn render_msl_source(exp: &MslExport) -> String {
72    let mut s = String::new();
73    for inc in &exp.includes {
74        s.push_str(&format!("#include {inc}\n"));
75    }
76    s.push_str("using namespace metal;\n");
77    for f in &exp.functions {
78        s.push_str(&format!(
79            "{} {} {{ {} }}\n",
80            f.function_type.keyword(),
81            f.name,
82            f.source
83        ));
84    }
85    s
86}
87
88/// Validate (at least one function).
89pub fn validate_msl_export(exp: &MslExport) -> bool {
90    !exp.functions.is_empty()
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default export holding exactly one function.
    fn export_with(fn_type: MslFunctionType, name: &str, src: &str) -> MslExport {
        let mut exp = new_msl_export();
        add_msl_function(&mut exp, fn_type, name, src);
        exp
    }

    #[test]
    fn new_export_has_include() {
        let exp = new_msl_export();
        assert!(
            exp.includes.iter().any(|i| i.contains("metal_stdlib")),
            "default export should include metal_stdlib"
        );
    }

    #[test]
    fn add_function_increments() {
        let exp = export_with(MslFunctionType::Vertex, "vertex_main", "return pos;");
        assert_eq!(msl_function_count(&exp), 1);
    }

    #[test]
    fn add_include_appends_in_order() {
        let mut exp = new_msl_export();
        add_msl_include(&mut exp, "<simd/simd.h>");
        assert_eq!(exp.includes, vec!["<metal_stdlib>", "<simd/simd.h>"]);
    }

    #[test]
    fn keyword_vertex_correct() {
        assert_eq!(MslFunctionType::Vertex.keyword(), "vertex");
    }

    #[test]
    fn keyword_fragment_correct() {
        assert_eq!(MslFunctionType::Fragment.keyword(), "fragment");
    }

    #[test]
    fn keyword_kernel_correct() {
        assert_eq!(MslFunctionType::Kernel.keyword(), "kernel");
    }

    #[test]
    fn find_function_by_name() {
        let exp = export_with(MslFunctionType::Fragment, "frag_main", "return color;");
        assert!(find_msl_function(&exp, "frag_main").is_some());
    }

    #[test]
    fn find_missing_none() {
        let exp = new_msl_export();
        assert!(find_msl_function(&exp, "x").is_none());
    }

    #[test]
    fn find_returns_first_match() {
        let mut exp = export_with(MslFunctionType::Vertex, "main", "a");
        add_msl_function(&mut exp, MslFunctionType::Fragment, "main", "b");
        let found = find_msl_function(&exp, "main").expect("`main` should be found");
        assert_eq!(found.source, "a");
        assert!(found.function_type == MslFunctionType::Vertex);
    }

    #[test]
    fn render_contains_function_name() {
        let exp = export_with(MslFunctionType::Vertex, "vs_main", "");
        assert!(render_msl_source(&exp).contains("vs_main"));
    }

    #[test]
    fn render_contains_namespace() {
        let src = render_msl_source(&new_msl_export());
        assert!(src.contains("namespace metal"));
    }

    #[test]
    fn render_preamble_layout() {
        assert_eq!(
            render_msl_source(&new_msl_export()),
            "#include <metal_stdlib>\nusing namespace metal;\n"
        );
    }

    #[test]
    fn render_function_line_format() {
        let exp = export_with(MslFunctionType::Fragment, "fs_main", "return c;");
        assert!(render_msl_source(&exp).contains("fragment fs_main { return c; }\n"));
    }

    #[test]
    fn validate_empty_fails() {
        assert!(!validate_msl_export(&new_msl_export()));
    }

    #[test]
    fn validate_with_function_passes() {
        let exp = export_with(MslFunctionType::Kernel, "compute", "");
        assert!(validate_msl_export(&exp));
    }
}
176}