pllm/
util.rs

1use std::{
2    iter::Sum,
3    simd::{f32x32, f32x8, num::SimdFloat, StdFloat},
4    time::Instant,
5};
6
7use rand::Rng;
8use rayon::prelude::*;
9
10use crate::errors::PllmError;
11
12pub trait FloatVec {
13    fn rope_rotate_neox(&mut self, pos: u32, header_size: u32, kv_dim: u32);
14    fn get_chunk(&self, chunk_size: u32, chunk_index: u32) -> &[f32];
15
16    fn get_mut_chunk(&mut self, chunk_size: u32, chunk_index: u32) -> &mut [f32];
17
18    fn arg_max(&self) -> u32;
19
20    fn sample(&self) -> u32;
21
22    fn rms_norm(&mut self, src: &[f32], weights: &[f32], eps: f32);
23
24    // fn mat_mul(&mut self, src: &[f32], weights: &[f32]);
25
26    fn soft_max(&mut self);
27
28    fn accum(&mut self, other: &[f32]);
29
30    fn scale(&mut self, rhs: f32);
31
32    fn rope_rotate(
33        &mut self,
34        other: &mut [f32],
35        pos: u32,
36        header_size: u32,
37        kv_dim: u32,
38    ) -> Result<(), PllmError>;
39}
40
41impl FloatVec for [f32] {
42    fn get_chunk(&self, chunk_size: u32, chunk_index: u32) -> &[f32] {
43        let range = (chunk_size * chunk_index) as usize..(chunk_size * (chunk_index + 1)) as usize;
44        &self[range]
45    }
46
47    fn get_mut_chunk(&mut self, chunk_size: u32, chunk_index: u32) -> &mut [f32] {
48        let range = (chunk_size * chunk_index) as usize..(chunk_size * (chunk_index + 1)) as usize;
49        &mut self[range]
50    }
51
52    fn arg_max(&self) -> u32 {
53        self.iter()
54            .enumerate()
55            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
56            .map(|(index, _)| index as u32)
57            .unwrap_or(0)
58    }
59
60    fn sample(&self) -> u32 {
61        let mut rng = rand::thread_rng();
62        let r = rng.gen_range(0.0..1.0);
63        let mut cdf = 0.0;
64        for (i, &p) in self.iter().enumerate() {
65            cdf = cdf + p;
66            if r < cdf {
67                return i as u32;
68            }
69        }
70
71        self.len() as u32 - 1
72    }
73    fn rms_norm(&mut self, src: &[f32], weights: &[f32], eps: f32) {
74        let mut ss: f32 = src.iter().map(|&i| i * i).sum();
75
76        ss /= src.len() as f32;
77        ss += eps;
78        ss = 1.0 / ss.sqrt();
79
80        src.iter()
81            .enumerate()
82            .for_each(|(i, &x)| self[i] = weights[i] * (ss * x));
83    }
84
85    // fn mat_mul(&mut self, src: &[f32], weights: &[f32]) {
86    //     let n = src.len();
87    //     let d = weights.len() / n;
88    //     assert_eq!(self.len(), d);
89
90    //     // let before = Instant::now();
91    //     self.par_iter_mut().enumerate().for_each(|(i, value)| {
92    //         *value = f32_dot_product(src, &weights[i * n..(i + 1) * n]);
93    //     });
94    //     // println!("Mat mul time: n={}, d={}, {:.2?}", n, d, before.elapsed());
95    // }
96
97    fn soft_max(&mut self) {
98        let len = self.len();
99        if len == 1 {
100            self[0] = 1.0;
101            return;
102        }
103
104        // find max value (for numerical stability)
105        let mut max = 0.0;
106        for n in self.iter() {
107            max = n.max(max);
108        }
109
110        // e^x
111        self.iter_mut().for_each(|v| *v = (*v - max).exp());
112
113        // normalize
114        let sum: f32 = self.iter().map(|&n| n).sum();
115        self.iter_mut().for_each(|v| *v = *v / sum);
116    }
117
118    fn accum(&mut self, other: &[f32]) {
119        self.iter_mut()
120            .zip(other)
121            .for_each(|(v1, v2)| *v1 = *v1 + *v2)
122    }
123
124    fn scale(&mut self, rhs: f32) {
125        self.iter_mut().for_each(|v| *v = *v * rhs)
126    }
127
128    fn rope_rotate(
129        &mut self,
130        other: &mut [f32],
131        pos: u32,
132        header_size: u32,
133        kv_dim: u32,
134    ) -> Result<(), PllmError> {
135        for i in (0..self.len() as usize).step_by(2) {
136            let head_dim = i as u32 % header_size;
137            let freq = 1.0 / 10000.0_f32.powf(head_dim as f32 / header_size as f32);
138            let val = pos as f32 * freq;
139            let fcr = val.cos();
140            let fci = val.sin();
141            let rotn = if (i as u32) < kv_dim { 2 } else { 1 };
142            for v in 0..rotn {
143                let dst = if v == 0 {
144                    self.as_mut()
145                } else {
146                    other.as_mut()
147                };
148                let v0 = dst[i];
149                let v1 = dst[i + 1];
150                dst[i] = v0 * fcr - v1 * fci;
151                dst[i + 1] = v0 * fci + v1 * fcr;
152            }
153        }
154
155        Ok(())
156    }
157
158    // ref. https://github.com/crabml/crabml/blob/1e622975d56c7d15dc1c627ee3cb884de0dce953/crabml-core/src/backends/cpu/primitives/rope.rs#L34
159    fn rope_rotate_neox(&mut self, pos: u32, header_size: u32, kv_dim: u32) {
160        let head_dim = header_size as usize;
161        let rope_dim = kv_dim as usize;
162        self.chunks_exact_mut(head_dim).for_each(|chunk| {
163            for i in 0..rope_dim / 2 {
164                let freq_exponents = 2.0 * i as f32 / head_dim as f32;
165                let timescale = 10000_f32.powf(freq_exponents);
166                let theta = pos as f32 / timescale;
167                let cos_theta = theta.cos();
168                let sin_theta = theta.sin();
169
170                let qp0 = chunk[i];
171                let qp1 = chunk[i + head_dim / 2];
172                chunk[i] = qp0 * cos_theta - qp1 * sin_theta;
173                chunk[i + head_dim / 2] = qp0 * sin_theta + qp1 * cos_theta;
174            }
175        });
176    }
177}
178
179// #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
180// fn f32_dot_product(a: &[f32], b: &[f32]) -> f32 {
181//     let mut sum = 0.0;
182//     for i in (0..a.len()).step_by(8) {
183//         let s1 = f32x8::from_slice(&a[i..i + 8]);
184//         let s2 = f32x8::from_slice(&b[i..i + 8]);
185//         sum += (s1 * s2).reduce_sum();
186//     }
187//     sum
188// }
189
190// #[cfg(not(any(all(target_arch = "x86_64", target_feature = "avx2"))))]
191// fn f32_dot_product(a: &[f32], b: &[f32]) -> f32 {
192//     assert_eq!(a.len(), b.len());
193//     let mut product = 0.0;
194//     for i in 0..a.len() {
195//         product += a[i] * b[i];
196//     }
197//     product
198// }