1use std::{
2 iter::Sum,
3 simd::{f32x32, f32x8, num::SimdFloat, StdFloat},
4 time::Instant,
5};
6
7use rand::Rng;
8use rayon::prelude::*;
9
10use crate::errors::PllmError;
11
12pub trait FloatVec {
13 fn rope_rotate_neox(&mut self, pos: u32, header_size: u32, kv_dim: u32);
14 fn get_chunk(&self, chunk_size: u32, chunk_index: u32) -> &[f32];
15
16 fn get_mut_chunk(&mut self, chunk_size: u32, chunk_index: u32) -> &mut [f32];
17
18 fn arg_max(&self) -> u32;
19
20 fn sample(&self) -> u32;
21
22 fn rms_norm(&mut self, src: &[f32], weights: &[f32], eps: f32);
23
24 fn soft_max(&mut self);
27
28 fn accum(&mut self, other: &[f32]);
29
30 fn scale(&mut self, rhs: f32);
31
32 fn rope_rotate(
33 &mut self,
34 other: &mut [f32],
35 pos: u32,
36 header_size: u32,
37 kv_dim: u32,
38 ) -> Result<(), PllmError>;
39}
40
41impl FloatVec for [f32] {
42 fn get_chunk(&self, chunk_size: u32, chunk_index: u32) -> &[f32] {
43 let range = (chunk_size * chunk_index) as usize..(chunk_size * (chunk_index + 1)) as usize;
44 &self[range]
45 }
46
47 fn get_mut_chunk(&mut self, chunk_size: u32, chunk_index: u32) -> &mut [f32] {
48 let range = (chunk_size * chunk_index) as usize..(chunk_size * (chunk_index + 1)) as usize;
49 &mut self[range]
50 }
51
52 fn arg_max(&self) -> u32 {
53 self.iter()
54 .enumerate()
55 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
56 .map(|(index, _)| index as u32)
57 .unwrap_or(0)
58 }
59
60 fn sample(&self) -> u32 {
61 let mut rng = rand::thread_rng();
62 let r = rng.gen_range(0.0..1.0);
63 let mut cdf = 0.0;
64 for (i, &p) in self.iter().enumerate() {
65 cdf = cdf + p;
66 if r < cdf {
67 return i as u32;
68 }
69 }
70
71 self.len() as u32 - 1
72 }
73 fn rms_norm(&mut self, src: &[f32], weights: &[f32], eps: f32) {
74 let mut ss: f32 = src.iter().map(|&i| i * i).sum();
75
76 ss /= src.len() as f32;
77 ss += eps;
78 ss = 1.0 / ss.sqrt();
79
80 src.iter()
81 .enumerate()
82 .for_each(|(i, &x)| self[i] = weights[i] * (ss * x));
83 }
84
85 fn soft_max(&mut self) {
98 let len = self.len();
99 if len == 1 {
100 self[0] = 1.0;
101 return;
102 }
103
104 let mut max = 0.0;
106 for n in self.iter() {
107 max = n.max(max);
108 }
109
110 self.iter_mut().for_each(|v| *v = (*v - max).exp());
112
113 let sum: f32 = self.iter().map(|&n| n).sum();
115 self.iter_mut().for_each(|v| *v = *v / sum);
116 }
117
118 fn accum(&mut self, other: &[f32]) {
119 self.iter_mut()
120 .zip(other)
121 .for_each(|(v1, v2)| *v1 = *v1 + *v2)
122 }
123
124 fn scale(&mut self, rhs: f32) {
125 self.iter_mut().for_each(|v| *v = *v * rhs)
126 }
127
128 fn rope_rotate(
129 &mut self,
130 other: &mut [f32],
131 pos: u32,
132 header_size: u32,
133 kv_dim: u32,
134 ) -> Result<(), PllmError> {
135 for i in (0..self.len() as usize).step_by(2) {
136 let head_dim = i as u32 % header_size;
137 let freq = 1.0 / 10000.0_f32.powf(head_dim as f32 / header_size as f32);
138 let val = pos as f32 * freq;
139 let fcr = val.cos();
140 let fci = val.sin();
141 let rotn = if (i as u32) < kv_dim { 2 } else { 1 };
142 for v in 0..rotn {
143 let dst = if v == 0 {
144 self.as_mut()
145 } else {
146 other.as_mut()
147 };
148 let v0 = dst[i];
149 let v1 = dst[i + 1];
150 dst[i] = v0 * fcr - v1 * fci;
151 dst[i + 1] = v0 * fci + v1 * fcr;
152 }
153 }
154
155 Ok(())
156 }
157
158 fn rope_rotate_neox(&mut self, pos: u32, header_size: u32, kv_dim: u32) {
160 let head_dim = header_size as usize;
161 let rope_dim = kv_dim as usize;
162 self.chunks_exact_mut(head_dim).for_each(|chunk| {
163 for i in 0..rope_dim / 2 {
164 let freq_exponents = 2.0 * i as f32 / head_dim as f32;
165 let timescale = 10000_f32.powf(freq_exponents);
166 let theta = pos as f32 / timescale;
167 let cos_theta = theta.cos();
168 let sin_theta = theta.sin();
169
170 let qp0 = chunk[i];
171 let qp1 = chunk[i + head_dim / 2];
172 chunk[i] = qp0 * cos_theta - qp1 * sin_theta;
173 chunk[i + head_dim / 2] = qp0 * sin_theta + qp1 * cos_theta;
174 }
175 });
176 }
177}
178
179