Skip to main content

cubecl_spirv/
lib.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::{
4    fmt::{Debug, Display},
5    sync::Arc,
6};
7
8use cubecl_core::prelude::Visibility;
9use cubecl_opt::Optimizer;
10use rspirv::{binary::Disassemble, dr::Module};
11
12mod arithmetic;
13mod atomic;
14mod bitwise;
15mod branch;
16mod cmma;
17mod compiler;
18mod debug;
19mod extensions;
20mod globals;
21mod instruction;
22mod item;
23mod lookups;
24mod metadata;
25mod subgroup;
26mod sync;
27mod target;
28mod transformers;
29mod variable;
30
31pub use compiler::*;
32use serde::{Deserialize, Serialize};
33pub use target::*;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SpirvKernel {
37    #[serde(skip)]
38    pub module: Option<Arc<Module>>,
39    #[serde(skip)]
40    pub optimizer: Option<Arc<Optimizer>>,
41
42    pub assembled_module: Vec<u32>,
43    pub bindings: Vec<Visibility>,
44    pub shared_size: usize,
45    pub uniform_info: bool,
46}
47
48impl Eq for SpirvKernel {}
49impl PartialEq for SpirvKernel {
50    fn eq(&self, other: &Self) -> bool {
51        self.assembled_module == other.assembled_module
52    }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
56pub struct SpirvCacheEntry {
57    pub entrypoint_name: String,
58    pub kernel: SpirvKernel,
59}
60
61impl SpirvCacheEntry {
62    pub fn new(entrypoint_name: String, kernel: SpirvKernel) -> Self {
63        SpirvCacheEntry {
64            entrypoint_name,
65            kernel,
66        }
67    }
68}
69
70impl Display for SpirvKernel {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        if let Some(module) = &self.module {
73            write!(f, "{}", module.disassemble())
74        } else {
75            f.write_str("SPIR-V")
76        }
77    }
78}