1use std::mem::{align_of, size_of};
20
21use crate::format::ModelType;
22
23#[derive(Debug, Clone)]
52pub struct TruenoNativeModel {
53 pub model_type: ModelType,
55
56 pub n_params: u32,
58
59 pub n_features: u32,
61
62 pub n_outputs: u32,
64
65 pub params: Option<AlignedVec<f32>>,
67
68 pub bias: Option<AlignedVec<f32>>,
70
71 pub extra: Option<ModelExtra>,
73}
74
75impl TruenoNativeModel {
76 #[must_use]
78 pub const fn new(
79 model_type: ModelType,
80 n_params: u32,
81 n_features: u32,
82 n_outputs: u32,
83 ) -> Self {
84 Self {
85 model_type,
86 n_params,
87 n_features,
88 n_outputs,
89 params: None,
90 bias: None,
91 extra: None,
92 }
93 }
94
95 #[must_use]
97 pub fn with_params(mut self, params: AlignedVec<f32>) -> Self {
98 self.params = Some(params);
99 self
100 }
101
102 #[must_use]
104 pub fn with_bias(mut self, bias: AlignedVec<f32>) -> Self {
105 self.bias = Some(bias);
106 self
107 }
108
109 #[must_use]
111 pub fn with_extra(mut self, extra: ModelExtra) -> Self {
112 self.extra = Some(extra);
113 self
114 }
115
116 #[must_use]
118 pub fn is_aligned(&self) -> bool {
119 let params_aligned = self.params.as_ref().map_or(true, AlignedVec::is_aligned);
120 let bias_aligned = self.bias.as_ref().map_or(true, AlignedVec::is_aligned);
121 params_aligned && bias_aligned
122 }
123
124 #[must_use]
126 pub fn size_bytes(&self) -> usize {
127 let params_size = self.params.as_ref().map_or(0, AlignedVec::size_bytes);
128 let bias_size = self.bias.as_ref().map_or(0, AlignedVec::size_bytes);
129 let extra_size = self.extra.as_ref().map_or(0, ModelExtra::size_bytes);
130 params_size + bias_size + extra_size
131 }
132
133 pub fn validate(&self) -> Result<(), NativeModelError> {
135 if let Some(ref params) = self.params {
137 if params.len() != self.n_params as usize {
138 return Err(NativeModelError::ParamCountMismatch {
139 declared: self.n_params as usize,
140 actual: params.len(),
141 });
142 }
143 }
144
145 if let Some(ref params) = self.params {
147 for (i, &val) in params.as_slice().iter().enumerate() {
148 if !val.is_finite() {
149 return Err(NativeModelError::InvalidParameter {
150 index: i,
151 value: val,
152 });
153 }
154 }
155 }
156
157 if let Some(ref bias) = self.bias {
159 for (i, &val) in bias.as_slice().iter().enumerate() {
160 if !val.is_finite() {
161 return Err(NativeModelError::InvalidBias {
162 index: i,
163 value: val,
164 });
165 }
166 }
167 }
168
169 Ok(())
170 }
171
172 #[must_use]
177 pub fn params_ptr(&self) -> Option<*const f32> {
178 self.params.as_ref().map(AlignedVec::as_ptr)
179 }
180
181 #[must_use]
186 pub fn bias_ptr(&self) -> Option<*const f32> {
187 self.bias.as_ref().map(AlignedVec::as_ptr)
188 }
189
190 pub fn predict_linear(&self, features: &[f32]) -> Result<f32, NativeModelError> {
195 if features.len() != self.n_features as usize {
196 return Err(NativeModelError::FeatureMismatch {
197 expected: self.n_features as usize,
198 got: features.len(),
199 });
200 }
201
202 let params = self
203 .params
204 .as_ref()
205 .ok_or(NativeModelError::MissingParams)?;
206
207 let dot: f32 = params
208 .as_slice()
209 .iter()
210 .zip(features.iter())
211 .map(|(p, x)| p * x)
212 .sum();
213
214 let bias = self
215 .bias
216 .as_ref()
217 .and_then(|b| b.as_slice().first().copied())
218 .unwrap_or(0.0);
219
220 Ok(dot + bias)
221 }
222}
223
224impl Default for TruenoNativeModel {
225 fn default() -> Self {
226 Self::new(ModelType::LinearRegression, 0, 0, 1)
227 }
228}
229
/// A growable buffer of `Copy` elements with a rounded-up slot count.
///
/// The slot count is chosen so that the backing storage spans a whole number
/// of 64-byte chunks.
///
/// The first `len` slots are the logical contents. The remaining slots up to
/// `capacity` are padding: they hold `T::default()`, or stale values left
/// behind by [`AlignedVec::clear`]. Padding is never exposed through the
/// accessors or through indexing.
///
/// NOTE(review): the storage is a plain `Vec<T>`, so the base pointer is only
/// guaranteed to satisfy `align_of::<T>()`. Only the *size* of the buffer is
/// rounded to 64 bytes; a 64-byte-aligned start would need a custom
/// allocation.
#[derive(Debug, Clone)]
pub struct AlignedVec<T: Copy + Default> {
    /// Backing storage; every slot is initialized and `data.len() == capacity`.
    data: Vec<T>,
    /// Number of logically valid elements at the front of `data`.
    len: usize,
    /// Number of slots in `data`.
    capacity: usize,
}

impl<T: Copy + Default> AlignedVec<T> {
    /// Returns how many slots to allocate for `capacity` elements.
    ///
    /// The result is the smallest count whose byte size is a multiple of
    /// 64 bytes, and it is never less than `capacity`.
    ///
    /// Zero-sized types, and byte counts that would overflow `usize`, fall back
    /// to `capacity` unchanged. In the overflow case the following `vec!`
    /// allocation then fails loudly instead of this arithmetic wrapping
    /// silently.
    fn aligned_capacity(capacity: usize) -> usize {
        let elem_size = size_of::<T>();
        if elem_size == 0 {
            return capacity;
        }
        capacity
            .checked_mul(elem_size)
            .and_then(|bytes| bytes.checked_add(63))
            .map_or(capacity, |padded| {
                (padded / 64 * 64 / elem_size).max(capacity)
            })
    }

    /// Creates an empty buffer with room for at least `capacity` elements.
    ///
    /// The slot count is rounded up to fill whole 64-byte chunks, and every
    /// slot is pre-filled with `T::default()`.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let aligned_cap = Self::aligned_capacity(capacity);
        Self {
            data: vec![T::default(); aligned_cap],
            len: 0,
            capacity: aligned_cap,
        }
    }

    /// Creates a buffer holding a copy of `slice`.
    #[must_use]
    pub fn from_slice(slice: &[T]) -> Self {
        let mut vec = Self::with_capacity(slice.len());
        vec.data[..slice.len()].copy_from_slice(slice);
        vec.len = slice.len();
        vec
    }

    /// Creates a buffer of `len` elements, each `T::default()`.
    #[must_use]
    pub fn zeros(len: usize) -> Self {
        let mut vec = Self::with_capacity(len);
        vec.len = len;
        vec
    }

    /// Number of logical elements.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when there are no logical elements.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Number of allocated slots, including padding.
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Raw pointer to the first slot of the backing storage.
    #[must_use]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_ptr()
    }

    /// Mutable raw pointer to the first slot of the backing storage.
    #[must_use]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.data.as_mut_ptr()
    }

    /// The logical contents as a slice (padding excluded).
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        &self.data[..self.len]
    }

    /// The logical contents as a mutable slice (padding excluded).
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        &mut self.data[..self.len]
    }

    /// Returns `true` when the storage pointer satisfies `align_of::<T>()`.
    ///
    /// `Vec<T>` already guarantees this, so it does NOT check for 64-byte
    /// alignment. Empty storage and zero-sized `T` are trivially aligned.
    #[must_use]
    pub fn is_aligned(&self) -> bool {
        if self.data.is_empty() || size_of::<T>() == 0 {
            return true;
        }
        self.data.as_ptr() as usize % align_of::<T>() == 0
    }

    /// Size of the logical contents in bytes (padding excluded).
    #[must_use]
    pub fn size_bytes(&self) -> usize {
        self.len * size_of::<T>()
    }

    /// Appends `value`.
    ///
    /// When the buffer is full, the slot count grows to twice the current
    /// capacity (minimum 16).
    pub fn push(&mut self, value: T) {
        if self.len >= self.data.len() {
            let new_cap = self.capacity.saturating_mul(2).max(16);
            // `resize` keeps every existing element (here `len == data.len()`)
            // and fills the new tail with defaults, like the padding elsewhere.
            self.data.resize(new_cap, T::default());
            self.capacity = new_cap;
        }
        self.data[self.len] = value;
        self.len += 1;
    }

    /// Drops all logical elements while keeping the allocation.
    ///
    /// Old values remain in the now-padding slots but are no longer reachable
    /// through the accessors.
    pub fn clear(&mut self) {
        self.len = 0;
    }

    /// Returns the element at `index`, or `None` if `index >= len()`.
    #[must_use]
    pub fn get(&self, index: usize) -> Option<&T> {
        self.as_slice().get(index)
    }

    /// Returns a mutable reference to the element at `index`, or `None` if
    /// `index >= len()`.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        self.as_mut_slice().get_mut(index)
    }

    /// Overwrites the element at `index`.
    ///
    /// Returns `false` (and does nothing) when `index >= len()`.
    pub fn set(&mut self, index: usize, value: T) -> bool {
        match self.get_mut(index) {
            Some(slot) => {
                *slot = value;
                true
            }
            None => false,
        }
    }
}

impl<T: Copy + Default> Default for AlignedVec<T> {
    /// An empty buffer with no allocated slots.
    fn default() -> Self {
        Self::with_capacity(0)
    }
}

impl<T: Copy + Default> std::ops::Index<usize> for AlignedVec<T> {
    type Output = T;

    /// Indexes the logical contents.
    ///
    /// # Panics
    ///
    /// Panics if `index >= len()`. Padding slots are not addressable, matching
    /// `get`.
    fn index(&self, index: usize) -> &Self::Output {
        &self.as_slice()[index]
    }
}

impl<T: Copy + Default> std::ops::IndexMut<usize> for AlignedVec<T> {
    /// Mutably indexes the logical contents.
    ///
    /// # Panics
    ///
    /// Panics if `index >= len()`. Padding slots are not addressable, matching
    /// `get_mut`.
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.as_mut_slice()[index]
    }
}

impl<T: Copy + Default> std::iter::FromIterator<T> for AlignedVec<T> {
    /// Collects the items into a buffer sized exactly as `from_slice` would
    /// size it.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let collected: Vec<T> = iter.into_iter().collect();
        Self::from_slice(&collected)
    }
}

impl<T: Copy + Default + PartialEq> PartialEq for AlignedVec<T> {
    /// Compares logical contents only; capacity and padding are ignored.
    fn eq(&self, other: &Self) -> bool {
        self.as_slice() == other.as_slice()
    }
}
442
443mod model_extra;
444pub use model_extra::*;