#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]

//! Burn JIT backend, built on top of [`cubecl`] runtimes (see [`JitRuntime`]).

#[macro_use]
extern crate derive_new;
extern crate alloc;

mod ops;

/// Kernel module.
pub mod kernel;
/// Tensor module.
pub mod tensor;

/// Element types usable with the JIT backend (see the re-exported element traits).
pub mod element;

use burn_tensor::backend::{DeviceId, DeviceOps};
use cubecl::{compute::CubeTask, Feature, Runtime};
pub use element::{BoolElement, FloatElement, IntElement, JitElement};

mod backend;

pub use backend::*;

// Re-export `cubecl` so dependents can reach its types through this crate.
pub use cubecl;

mod tune_key;
pub use tune_key::JitAutotuneKey;

/// Fusion support (compiled with the `fusion` feature, and always for tests).
#[cfg(any(feature = "fusion", test))]
pub mod fusion;

/// Template support (compiled only with the `template` feature).
#[cfg(feature = "template")]
pub mod template;

/// Test module, exported only when the `export_tests` feature is enabled.
#[cfg(feature = "export_tests")]
pub mod tests;
44
45pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer> {
47 type JitDevice: burn_tensor::backend::DeviceOps;
49 type JitServer: cubecl::server::ComputeServer<
51 Kernel = Box<dyn CubeTask<Self::Compiler>>,
52 Feature = Feature,
53 >;
54}
55
56#[derive(Hash, PartialEq, Eq, Debug, Clone)]
58pub struct JitTuneId {
59 device: DeviceId,
60 name: &'static str,
61}
62
63impl JitTuneId {
64 pub fn new<R: JitRuntime>(device: &R::Device) -> Self {
66 Self {
67 device: DeviceOps::id(device),
68 name: R::name(),
69 }
70 }
71}
72
73impl core::fmt::Display for JitTuneId {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 f.write_fmt(format_args!(
76 "device-{}-{}-{}",
77 self.device.type_id, self.device.index_id, self.name
78 ))
79 }
80}