1use crate::ctypes::ShaderStage;
2use crate::error::GlslangError;
3use crate::{Compiler, Shader};
4use glslang_sys as sys;
5use glslang_sys::glslang_spv_options_s;
6use rustc_hash::FxHashMap;
7use std::ffi::{CStr, CString};
8use std::marker::PhantomData;
9use std::ptr::NonNull;
10
11pub struct Program<'a> {
13 handle: NonNull<sys::glslang_program_t>,
14 cache: FxHashMap<ShaderStage, bool>,
15 _compiler: PhantomData<&'a Compiler>,
16}
17
18impl<'a> Program<'a> {
19 pub fn new(_compiler: &'a Compiler) -> Self {
21 let program = Self {
22 handle: unsafe {
23 NonNull::new(sys::glslang_program_create()).expect("glslang created null shader")
24 },
25 cache: FxHashMap::default(),
26 _compiler: PhantomData,
27 };
28
29 program
30 }
31
32 pub fn add_shader<'shader>(&mut self, shader: &'shader Shader<'shader>)
34 where
35 'shader: 'a,
36 {
37 unsafe { sys::glslang_program_add_shader(self.handle.as_ptr(), shader.handle.as_ptr()) }
38 self.cache.insert(shader.stage, shader.is_spirv);
39 }
40
41 pub fn map_io(&mut self) -> Result<(), GlslangError> {
44 if unsafe { sys::glslang_program_map_io(self.handle.as_ptr()) } == 0 {
45 return Err(GlslangError::MapIoError(self.get_log()));
46 }
47
48 Ok(())
49 }
50
51 pub fn link(self) -> Result<(), GlslangError> {
55 let messages = glslang_sys::glslang_messages_t::DEFAULT
56 | glslang_sys::glslang_messages_t::VULKAN_RULES
57 | glslang_sys::glslang_messages_t::SPV_RULES;
58
59 if unsafe { sys::glslang_program_link(self.handle.as_ptr(), messages.0) } == 0 {
60 return Err(GlslangError::LinkError(self.get_log()));
61 }
62 Ok(())
63 }
64
65 pub fn compile(self, stage: ShaderStage) -> Result<Vec<u32>, GlslangError> {
69 if !self.cache.contains_key(&stage) {
71 return Err(GlslangError::ShaderStageNotFound(stage));
72 }
73
74 if let Some(false) = self.cache.get(&stage) {
75 return Err(GlslangError::NoLanguageTarget);
76 }
77
78 let messages = glslang_sys::glslang_messages_t::DEFAULT
79 | glslang_sys::glslang_messages_t::VULKAN_RULES
80 | glslang_sys::glslang_messages_t::SPV_RULES;
81
82 if unsafe { sys::glslang_program_link(self.handle.as_ptr(), messages.0) } == 0 {
83 return Err(GlslangError::LinkError(self.get_log()));
84 }
85
86 unsafe { sys::glslang_program_SPIRV_generate(self.handle.as_ptr(), stage) }
91
92 let size = unsafe { sys::glslang_program_SPIRV_get_size(self.handle.as_ptr()) };
93 let mut buffer = vec![0u32; size];
94
95 unsafe {
96 sys::glslang_program_SPIRV_get(self.handle.as_ptr(), buffer.as_mut_ptr());
97 }
98
99 Ok(buffer)
100 }
101
102 pub fn compile_size_optimized(self, stage: ShaderStage) -> Result<Vec<u32>, GlslangError> {
106 if !self.cache.contains_key(&stage) {
108 return Err(GlslangError::ShaderStageNotFound(stage));
109 }
110
111 if let Some(false) = self.cache.get(&stage) {
112 return Err(GlslangError::NoLanguageTarget);
113 }
114
115 let messages = glslang_sys::glslang_messages_t::DEFAULT
116 | glslang_sys::glslang_messages_t::VULKAN_RULES
117 | glslang_sys::glslang_messages_t::SPV_RULES;
118
119 if unsafe { sys::glslang_program_link(self.handle.as_ptr(), messages.0) } == 0 {
120 return Err(GlslangError::LinkError(self.get_log()));
121 }
122
123 let mut options = glslang_spv_options_s {
124 generate_debug_info: false,
125 strip_debug_info: false,
126 disable_optimizer: false,
127 optimize_size: true,
128 disassemble: false,
129 validate: false,
130 emit_nonsemantic_shader_debug_info: false,
131 emit_nonsemantic_shader_debug_source: false,
132 compile_only: false,
133 optimize_allow_expanded_id_bound: false,
134 };
135
136 unsafe {
141 sys::glslang_program_SPIRV_generate_with_options(
142 self.handle.as_ptr(),
143 stage,
144 &mut options,
145 )
146 }
147
148 let size = unsafe { sys::glslang_program_SPIRV_get_size(self.handle.as_ptr()) };
149 let mut buffer = vec![0u32; size];
150
151 unsafe {
152 sys::glslang_program_SPIRV_get(self.handle.as_ptr(), buffer.as_mut_ptr());
153 }
154
155 Ok(buffer)
156 }
157
158 fn get_log(&self) -> String {
159 let c_str =
160 unsafe { CStr::from_ptr(sys::glslang_program_get_info_log(self.handle.as_ptr())) };
161
162 let string = CString::from(c_str)
163 .into_string()
164 .expect("Expected glslang info log to be valid UTF-8");
165
166 string
167 }
168}
169
170impl<'a> Drop for Program<'a> {
171 fn drop(&mut self) {
172 unsafe { sys::glslang_program_delete(self.handle.as_ptr()) }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ctypes::ShaderStage;
    use crate::include::{IncludeHandler, IncludeResult};
    use crate::shader::{CompilerOptions, OpenGlVersion, ShaderInput, ShaderSource, Target};
    use crate::{GlslProfile, ShaderMessage, SourceLanguage};
    use rspirv::binary::Disassemble;

    /// Links a program containing a single fragment shader.
    #[test]
    pub fn test_link() {
        let compiler = Compiler::acquire().unwrap();

        let source = ShaderSource::try_from(String::from(
            r#"
#version 450

layout(location = 0) out vec4 color;
layout(binding = 1) uniform sampler2D tex;

void main() {
    color = texture(tex, vec2(0.0));
}
        "#,
        ))
        .expect("source");

        let input = ShaderInput::new(
            &source,
            ShaderStage::Fragment,
            &CompilerOptions::default(),
            None,
            None,
        )
        .expect("target");
        let shader = Shader::new(&compiler, input).expect("shader init");

        // The shader must actually be attached. Otherwise the linker has
        // nothing to check and this test exercises nothing.
        let mut program = Program::new(&compiler);
        program.add_shader(&shader);
        program.link().expect("link");
    }

    #[test]
    pub fn test_compile() {
        let compiler = Compiler::acquire().unwrap();

        let source = ShaderSource::try_from(String::from(
            r#"
#version 450

layout(location = 0) out vec4 color;
layout(binding = 1) uniform sampler2D tex;

void main() {
    color = texture(tex, vec2(0.0));
}
        "#,
        ))
        .expect("source");

        let input = ShaderInput::new(
            &source,
            ShaderStage::Fragment,
            &CompilerOptions::default(),
            None,
            None,
        )
        .expect("target");
        let shader = Shader::new(&compiler, input).expect("shader init");
        let code = shader.compile().expect("compile");
        let mut loader = rspirv::dr::Loader::new();
        rspirv::binary::parse_words(&code, &mut loader).unwrap();
        let module = loader.module();

        println!("{}", module.disassemble())
    }

    /// Runs `test_compile` concurrently on several threads.
    #[test]
    pub fn test_compile_thread() {
        let handles: Vec<_> = (0..8).map(|_| std::thread::spawn(test_compile)).collect();

        for handle in handles {
            handle.join().unwrap()
        }
    }

    #[test]
    pub fn test_verify_old_gl() {
        let compiler = Compiler::acquire().unwrap();

        let source = ShaderSource::from(String::from(
            r#"#version 120

varying vec2 texcoord;

void main() {
    gl_Position = ftransform();
    texcoord = gl_MultiTexCoord0.st;
}
        "#,
        ));

        let input = ShaderInput::new(
            &source,
            ShaderStage::Vertex,
            &CompilerOptions {
                source_language: SourceLanguage::GLSL,
                target: Target::OpenGL {
                    version: OpenGlVersion::OpenGL4_5,
                    spirv_version: None,
                },
                messages: ShaderMessage::DEBUG_INFO | ShaderMessage::DEFAULT,
                version_profile: Some((120, GlslProfile::None)),
            },
            None,
            None,
        )
        .expect("target");
        let _shader = Shader::new(&compiler, input).expect("shader init");
    }

    #[test]
    pub fn test_no_language_target_does_not_segfault() {
        let compiler = Compiler::acquire().unwrap();

        let source = ShaderSource::try_from(String::from(
            r#"
#version 460


layout(location = 0) out vec4 color;
layout(binding = 1) uniform sampler2D tex;

void main() {
    color = texture(tex, vec2(0.0));
}
        "#,
        ))
        .expect("source");

        let input = ShaderInput::new(
            &source,
            ShaderStage::Vertex,
            &CompilerOptions {
                source_language: SourceLanguage::GLSL,
                target: Target::None(None),
                messages: ShaderMessage::DEBUG_INFO | ShaderMessage::DEFAULT,
                version_profile: None,
            },
            None,
            None,
        )
        .expect("target");
        let shader = Shader::new(&compiler, input).expect("shader init");
        assert!(matches!(
            shader.compile(),
            Err(GlslangError::NoLanguageTarget)
        ));
    }

    #[test]
    pub fn test_compile_program() {
        let compiler = Compiler::acquire().unwrap();

        let fragment = ShaderSource::from(
            r#"
#version 450

layout(location = 0) out vec4 color;
layout(binding = 1) uniform sampler2D tex;

void main() {
    color = texture(tex, vec2(0.0));
}
        "#,
        );

        let vertex = ShaderSource::from(
            r#"
#version 450
layout(set = 0, binding = 0, std140) uniform UBO
{
    mat4 MVP;
};

layout(location = 0) in vec4 Position;
layout(location = 1) in vec2 TexCoord;
layout(location = 0) out vec2 vTexCoord;
void main()
{
    gl_Position = MVP * Position;
    vTexCoord = TexCoord;
}
"#,
        );

        let fragment = ShaderInput::new(
            &fragment,
            ShaderStage::Fragment,
            &CompilerOptions::default(),
            None,
            None,
        )
        .expect("target");
        let fragment = Shader::new(&compiler, fragment).expect("shader init");

        let vertex = ShaderInput::new(
            &vertex,
            ShaderStage::Vertex,
            &CompilerOptions::default(),
            None,
            None,
        )
        .expect("target");
        let vertex = Shader::new(&compiler, vertex).expect("shader init");

        let mut program = Program::new(&compiler);

        program.add_shader(&fragment);
        program.add_shader(&vertex);

        let _code = program.compile(ShaderStage::Fragment).expect("shader");

        let mut program = compiler.create_program();
        program.add_shader(&vertex);
        let code2 = program.compile(ShaderStage::Vertex).expect("shader");

        let mut loader = rspirv::dr::Loader::new();
        rspirv::binary::parse_words(&code2, &mut loader).unwrap();
        let module = loader.module();

        println!("{}", module.disassemble());
    }

    #[test]
    pub fn test_add_macros() {
        let compiler = Compiler::acquire().unwrap();

        let source = ShaderSource::try_from(String::from(
            r#"
#version 460

layout(location = 0) out vec4 color;

void main() {
    color = vec4(CUSTOM_MACRO);
}
        "#,
        ))
        .expect("source");

        let input = ShaderInput::new(
            &source,
            ShaderStage::Vertex,
            &CompilerOptions {
                source_language: SourceLanguage::GLSL,
                target: Target::None(None),
                messages: ShaderMessage::DEBUG_INFO | ShaderMessage::DEFAULT,
                version_profile: None,
            },
            Some(&[("CUSTOM_MACRO", Some("1.0"))]),
            None,
        )
        .expect("target");
        let _shader = Shader::new(&compiler, input).expect("shader init");
    }

    #[test]
    pub fn test_include_handler() {
        let compiler = Compiler::acquire().unwrap();

        /// Records every requested header and answers with a fixed macro definition.
        struct MyIncludeHandler {
            header_included: Vec<String>,
        }
        impl IncludeHandler for MyIncludeHandler {
            fn include(
                &mut self,
                _ty: crate::include::IncludeType,
                header_name: &str,
                _includer_name: &str,
                _include_depth: usize,
            ) -> Option<IncludeResult> {
                self.header_included.push(header_name.into());
                Some(IncludeResult {
                    name: "included_macro".into(),
                    data: "#define INCLUDED_MACRO 0.0".into(),
                })
            }
        }

        let source = ShaderSource::try_from(String::from(
            r#"
#version 460
#extension GL_GOOGLE_include_directive : require
#include "custom_include.glsl"

layout(location = 0) out vec4 color;

void main() {
    color = vec4(INCLUDED_MACRO);
}
        "#,
        ))
        .expect("source");
        let mut include_handler = MyIncludeHandler {
            header_included: vec![],
        };
        let input = ShaderInput::new(
            &source,
            ShaderStage::Vertex,
            &CompilerOptions {
                source_language: SourceLanguage::GLSL,
                target: Target::OpenGL {
                    version: OpenGlVersion::OpenGL4_5,
                    spirv_version: None,
                },
                messages: ShaderMessage::DEBUG_INFO | ShaderMessage::DEFAULT,
                version_profile: None,
            },
            Some(&[("CUSTOM_MACRO", Some("1.0"))]),
            Some(&mut include_handler),
        )
        .expect("target");
        let _shader = Shader::new(&compiler, input).expect("shader init");
        // Exactly one include, for exactly this header.
        assert_eq!(include_handler.header_included, vec!["custom_include.glsl"]);
    }
}