1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::{
4 fmt::{Debug, Display},
5 sync::Arc,
6};
7
8use cubecl_core::prelude::Visibility;
9use cubecl_opt::Optimizer;
10use rspirv::{binary::Disassemble, dr::Module};
11
12mod arithmetic;
13mod atomic;
14mod bitwise;
15mod branch;
16mod cmma;
17mod compiler;
18mod debug;
19mod extensions;
20mod globals;
21mod instruction;
22mod item;
23mod lookups;
24mod metadata;
25mod subgroup;
26mod sync;
27mod target;
28mod transformers;
29mod variable;
30
31pub use compiler::*;
32use serde::{Deserialize, Serialize};
33pub use target::*;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SpirvKernel {
37 #[serde(skip)]
38 pub module: Option<Arc<Module>>,
39 #[serde(skip)]
40 pub optimizer: Option<Arc<Optimizer>>,
41
42 pub assembled_module: Vec<u32>,
43 pub bindings: Vec<Visibility>,
44 pub shared_size: usize,
45 pub uniform_info: bool,
46}
47
48impl Eq for SpirvKernel {}
49impl PartialEq for SpirvKernel {
50 fn eq(&self, other: &Self) -> bool {
51 self.assembled_module == other.assembled_module
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
56pub struct SpirvCacheEntry {
57 pub entrypoint_name: String,
58 pub kernel: SpirvKernel,
59}
60
61impl SpirvCacheEntry {
62 pub fn new(entrypoint_name: String, kernel: SpirvKernel) -> Self {
63 SpirvCacheEntry {
64 entrypoint_name,
65 kernel,
66 }
67 }
68}
69
70impl Display for SpirvKernel {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 if let Some(module) = &self.module {
73 write!(f, "{}", module.disassemble())
74 } else {
75 f.write_str("SPIR-V")
76 }
77 }
78}