Skip to main content

cubecl_spirv/
lib.rs

1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::{
4    fmt::{Debug, Display},
5    sync::Arc,
6};
7
8use cubecl_core::prelude::Visibility;
9use cubecl_opt::Optimizer;
10use rspirv::{binary::Disassemble, dr::Module};
11
12mod arithmetic;
13mod atomic;
14mod bitwise;
15mod branch;
16mod cmma;
17mod compiler;
18mod debug;
19mod extensions;
20mod globals;
21mod instruction;
22mod item;
23mod lookups;
24mod metadata;
25mod subgroup;
26mod sync;
27mod target;
28mod transformers;
29mod variable;
30
31pub use compiler::*;
32use serde::{Deserialize, Serialize};
33pub use target::*;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SpirvKernel {
37    #[serde(skip)]
38    pub module: Option<Arc<Module>>,
39    #[serde(skip)]
40    pub optimizer: Option<Arc<Optimizer>>,
41
42    pub assembled_module: Vec<u32>,
43    pub bindings: Vec<Visibility>,
44    pub shared_size: usize,
45}
46
47impl Eq for SpirvKernel {}
48impl PartialEq for SpirvKernel {
49    fn eq(&self, other: &Self) -> bool {
50        self.assembled_module == other.assembled_module
51    }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55pub struct SpirvCacheEntry {
56    pub entrypoint_name: String,
57    pub kernel: SpirvKernel,
58}
59
60impl SpirvCacheEntry {
61    pub fn new(entrypoint_name: String, kernel: SpirvKernel) -> Self {
62        SpirvCacheEntry {
63            entrypoint_name,
64            kernel,
65        }
66    }
67}
68
69impl Display for SpirvKernel {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        if let Some(module) = &self.module {
72            write!(f, "{}", module.disassemble())
73        } else {
74            f.write_str("SPIR-V")
75        }
76    }
77}