burn_jit/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3
4//! Burn JIT Backend
5
6#[macro_use]
7extern crate derive_new;
8extern crate alloc;
9
10mod ops;
11
12/// Kernel module
13pub mod kernel;
14/// Tensor module.
15pub mod tensor;
16
17/// Elements for JIT backend
18pub mod element;
19
20use burn_tensor::backend::{DeviceId, DeviceOps};
21use cubecl::{compute::CubeTask, Feature, Runtime};
22pub use element::{BoolElement, FloatElement, IntElement, JitElement};
23
24mod backend;
25
26pub use backend::*;
27
28// Re-export cubecl.
29pub use cubecl;
30
31mod tune_key;
32pub use tune_key::JitAutotuneKey;
33
34#[cfg(any(feature = "fusion", test))]
35/// Module for interacting with fusion
36pub mod fusion;
37
38#[cfg(feature = "template")]
39/// Module for compiling custom non-jit kernels
40pub mod template;
41
42#[cfg(feature = "export_tests")]
43pub mod tests;
44
45/// Just-in-Time runtime extending the [cube runtime](Runtime).
46pub trait JitRuntime: Runtime<Device = Self::JitDevice, Server = Self::JitServer> {
47    /// The device that should also implement [burn_tensor::backend::DeviceOps].
48    type JitDevice: burn_tensor::backend::DeviceOps;
49    /// The cube server with the [JitAutotuneKey].
50    type JitServer: cubecl::server::ComputeServer<
51        Kernel = Box<dyn CubeTask<Self::Compiler>>,
52        Feature = Feature,
53    >;
54}
55
56/// ID used to identify a Just-in-Time environment.
57#[derive(Hash, PartialEq, Eq, Debug, Clone)]
58pub struct JitTuneId {
59    device: DeviceId,
60    name: &'static str,
61}
62
63impl JitTuneId {
64    /// Create a new ID.
65    pub fn new<R: JitRuntime>(device: &R::Device) -> Self {
66        Self {
67            device: DeviceOps::id(device),
68            name: R::name(),
69        }
70    }
71}
72
73impl core::fmt::Display for JitTuneId {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.write_fmt(format_args!(
76            "device-{}-{}-{}",
77            self.device.type_id, self.device.index_id, self.name
78        ))
79    }
80}