1#![allow(unknown_lints, unnecessary_transmutes)]
2
3use std::{
4 fmt::{Debug, Display},
5 sync::Arc,
6};
7
8use cubecl_core::prelude::Visibility;
9use cubecl_opt::Optimizer;
10use rspirv::{binary::Disassemble, dr::Module};
11
12mod arithmetic;
13mod atomic;
14mod bitwise;
15mod branch;
16mod cmma;
17mod compiler;
18mod debug;
19mod extensions;
20mod globals;
21mod instruction;
22mod item;
23mod lookups;
24mod metadata;
25mod subgroup;
26mod sync;
27mod target;
28mod transformers;
29mod variable;
30
31pub use compiler::*;
32use serde::{Deserialize, Serialize};
33pub use target::*;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36pub struct SpirvKernel {
37 #[serde(skip)]
38 pub module: Option<Arc<Module>>,
39 #[serde(skip)]
40 pub optimizer: Option<Arc<Optimizer>>,
41
42 pub assembled_module: Vec<u32>,
43 pub bindings: Vec<Visibility>,
44 pub shared_size: usize,
45}
46
47impl Eq for SpirvKernel {}
48impl PartialEq for SpirvKernel {
49 fn eq(&self, other: &Self) -> bool {
50 self.assembled_module == other.assembled_module
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
55pub struct SpirvCacheEntry {
56 pub entrypoint_name: String,
57 pub kernel: SpirvKernel,
58}
59
60impl SpirvCacheEntry {
61 pub fn new(entrypoint_name: String, kernel: SpirvKernel) -> Self {
62 SpirvCacheEntry {
63 entrypoint_name,
64 kernel,
65 }
66 }
67}
68
69impl Display for SpirvKernel {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 if let Some(module) = &self.module {
72 write!(f, "{}", module.disassemble())
73 } else {
74 f.write_str("SPIR-V")
75 }
76 }
77}