// rai_core/primitives/others.rs

use crate::{Primitive, Tensor};
use std::any::Any;
use tracing::Level;

/// Configuration carried by the flash-attention primitive.
///
/// This type only stores the values below and reports them through its
/// accessors and `dot_label`; the code that interprets them is not in this
/// file, so the per-field notes are name-derived and should be confirmed
/// against the backend implementation.
#[derive(Clone, Debug, PartialEq)]
pub struct FlashAttention {
    /// Scaling factor for the softmax step (presumably applied to the
    /// attention scores before softmax — TODO confirm against the kernel).
    pub softmax_scale: f32,
    /// Optional left window size; `None` presumably means "unbounded" — confirm.
    pub window_size_left: Option<usize>,
    /// Optional right window size; `None` presumably means "unbounded" — confirm.
    pub window_size_right: Option<usize>,
    /// Optional tensor of ALiBi slopes; `None` when no slopes are supplied.
    pub alibi_slopes: Option<Tensor>,
}

impl FlashAttention {
    pub fn new(
        softmax_scale: f32,
        window_size_left: Option<usize>,
        window_size_right: Option<usize>,
        alibi_slopes: Option<Tensor>,
    ) -> Self {
        Self {
            softmax_scale,
            window_size_left,
            window_size_right,
            alibi_slopes,
        }
    }

    pub fn softmax_scale(&self) -> f32 {
        self.softmax_scale
    }

    pub fn window_size_left(&self) -> Option<usize> {
        self.window_size_left
    }

    pub fn window_size_right(&self) -> Option<usize> {
        self.window_size_right
    }

    pub fn alibi_slopes(&self) -> Option<&Tensor> {
        self.alibi_slopes.as_ref()
    }
}

impl Primitive for FlashAttention {
    /// Clones this primitive and returns the copy as a boxed trait object.
    fn clone_boxed(&self) -> Box<dyn Primitive> {
        let duplicate: Self = self.clone();
        Box::new(duplicate)
    }

    /// Exposes this primitive as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Renders a one-line label listing the scale, both window sizes and the
    /// ALiBi slopes.
    ///
    /// The slopes tensor is shown only by id, as `tensor(<id>)`, and the whole
    /// option is `Debug`-formatted, so it appears as `None` or
    /// `Some("tensor(<id>)")`.
    fn dot_label(&self) -> String {
        // Resolve the slopes to their short textual form first so the final
        // format call only deals with already-printable values.
        let slopes_label: Option<String> = self
            .alibi_slopes
            .as_ref()
            .map(|slopes| format!("tensor({})", slopes.id()));

        format!(
            "FlashAttention({}, {:?}, {:?}, {:?})",
            self.softmax_scale, self.window_size_left, self.window_size_right, slopes_label
        )
    }

    /// Placeholder: the forward-mode rule is not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn jvp(&self, _output: &Tensor, _primals: &[Tensor], tangents: &[Tensor]) -> Tensor {
        todo!("jvp for FlashAttention")
    }

    /// Placeholder: the reverse-mode rule is not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!`.
    #[tracing::instrument(ret(level = Level::TRACE))]
    fn vjp(&self, _output: &Tensor, _primals: &[Tensor], cotangent: &Tensor) -> Vec<Tensor> {
        todo!("vjp for FlashAttention")
    }
}