Skip to main content

akuna_infer/
model.rs

1use std::{collections::HashMap, fmt, fs, path::Path};
2
3use burn::tensor::{
4    Tensor, TensorData, activation::softmax, backend::Backend, module::conv1d,
5    ops::ConvOptions,
6};
7use onnx_ir::{ModelProto, TensorProto};
8use protobuf::Message;
9
10use crate::{
11    config::{ModelConfig, runtime_config},
12    detection::{Detection, RankedAlternative},
13    file::FileType,
14    preprocess::{PreparedInput, prepare_input},
15    vendor::{content::ContentType, model as vendor_model},
16};
17
18pub(crate) use crate::vendor::model::{CONFIG, Label, NUM_LABELS};
19const NUM_CLASSES: usize = 257;
20const SEQ_LEN: usize = 2048;
21const EMBED_DIM: usize = 64;
22const TOKENS_PER_BLOCK: usize = 512;
23const CHANNELS_PER_TOKEN: usize = 256;
24const CONV_OUT_CHANNELS: usize = 512;
25const CONV_KERNEL: usize = 5;
26const DENSE_OUT: usize = vendor_model::NUM_LABELS;
27const EMBEDDED_MODEL: &[u8] =
28    include_bytes!("vendor/assets/models/standard_v3_3/model.onnx");
29
/// Describes one ONNX initializer the loader expects to find in the graph.
///
/// The shape lives in a fixed four-slot array so specs can be plain
/// constants; only the first `rank` entries are meaningful and the
/// remaining slots are ignored.
struct TensorSpec {
    /// Initializer name used as the lookup key.
    name: &'static str,
    /// Expected dimensions, padded to four slots.
    shape: &'static [usize; 4],
    /// Number of leading entries of `shape` that are significant.
    rank: usize,
}
35
36const EMBEDDING_WEIGHT: TensorSpec = TensorSpec {
37    name: "jax2tf_get_logits_/Const:0",
38    shape: &[NUM_CLASSES, EMBED_DIM, 0, 0],
39    rank: 2,
40};
41
42const EMBEDDING_BIAS: TensorSpec = TensorSpec {
43    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Dense_0/Reshape:0",
44    shape: &[1, 1, EMBED_DIM, 0],
45    rank: 3,
46};
47
48const LAYER_NORM_0_WEIGHT: TensorSpec = TensorSpec {
49    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_0/Reshape_2:0",
50    shape: &[1, TOKENS_PER_BLOCK, 1, 0],
51    rank: 3,
52};
53
54const LAYER_NORM_0_BIAS: TensorSpec = TensorSpec {
55    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_0/Reshape_3:0",
56    shape: &[1, TOKENS_PER_BLOCK, 1, 0],
57    rank: 3,
58};
59
60const CONV_WEIGHT: TensorSpec = TensorSpec {
61    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Conv_0/transpose_3:0",
62    shape: &[CONV_OUT_CHANNELS, CHANNELS_PER_TOKEN, CONV_KERNEL, 1],
63    rank: 4,
64};
65
66const CONV_BIAS: TensorSpec = TensorSpec {
67    name: "const_fold_opt__209",
68    shape: &[1, CONV_OUT_CHANNELS, 1, 0],
69    rank: 3,
70};
71
72const LAYER_NORM_1_WEIGHT: TensorSpec = TensorSpec {
73    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_1/Reshape_2:0",
74    shape: &[1, CONV_OUT_CHANNELS, 0, 0],
75    rank: 2,
76};
77
78const LAYER_NORM_1_BIAS: TensorSpec = TensorSpec {
79    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_1/Reshape_3:0",
80    shape: &[1, CONV_OUT_CHANNELS, 0, 0],
81    rank: 2,
82};
83
84const DENSE_WEIGHT: TensorSpec = TensorSpec {
85    name: "jax2tf_get_logits_/Const_24:0",
86    shape: &[CONV_OUT_CHANNELS, DENSE_OUT, 0, 0],
87    rank: 2,
88};
89
90const DENSE_BIAS: TensorSpec = TensorSpec {
91    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Dense_1/Reshape:0",
92    shape: &[1, DENSE_OUT, 0, 0],
93    rank: 2,
94};
95
/// Errors returned while loading the model or running inference.
#[derive(Debug)]
pub enum MagikaInferenceError {
    /// A filesystem operation failed.
    Io(std::io::Error),
    /// A configuration problem, described by the message.
    InvalidConfig(String),
    /// Model parsing, tensor validation, or inference failed; the message
    /// describes what went wrong.
    Runtime(String),
}
102
103impl fmt::Display for MagikaInferenceError {
104    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
105        match self {
106            Self::Io(e) => write!(f, "io error: {e}"),
107            Self::InvalidConfig(e) => write!(f, "invalid configuration: {e}"),
108            Self::Runtime(e) => write!(f, "inference runtime error: {e}"),
109        }
110    }
111}
112
113impl std::error::Error for MagikaInferenceError {}
114
115impl From<std::io::Error> for MagikaInferenceError {
116    fn from(value: std::io::Error) -> Self {
117        Self::Io(value)
118    }
119}
120
121pub struct MagikaModel<B: Backend> {
122    device: B::Device,
123    config: ModelConfig,
124    top_k: usize,
125    embedding_weight: Vec<f32>,
126    embedding_bias: Vec<f32>,
127    layer_norm_0_weight: Tensor<B, 3>,
128    layer_norm_0_bias: Tensor<B, 3>,
129    conv_weight: Tensor<B, 3>,
130    conv_bias: Tensor<B, 1>,
131    layer_norm_1_weight: Tensor<B, 2>,
132    layer_norm_1_bias: Tensor<B, 2>,
133    dense_weight: Tensor<B, 2>,
134    dense_bias: Tensor<B, 2>,
135}
136
137impl<B: Backend<FloatElem = f32>> MagikaModel<B> {
138    pub fn from_embedded(
139        device: &B::Device,
140    ) -> Result<Self, MagikaInferenceError> {
141        Self::from_bytes(device, EMBEDDED_MODEL)
142    }
143
144    pub fn from_file(
145        device: &B::Device,
146        path: impl AsRef<Path>,
147    ) -> Result<Self, MagikaInferenceError> {
148        let model_bytes = fs::read(path)?;
149        Self::from_bytes(device, &model_bytes)
150    }
151
152    pub fn from_bytes(
153        device: &B::Device,
154        model_bytes: &[u8],
155    ) -> Result<Self, MagikaInferenceError> {
156        let model =
157            ModelProto::parse_from_bytes(model_bytes).map_err(|err| {
158                MagikaInferenceError::Runtime(format!("parse model: {err}"))
159            })?;
160        let graph = model.graph.as_ref().ok_or_else(|| {
161            MagikaInferenceError::Runtime("model graph missing".to_string())
162        })?;
163
164        let initializers = graph
165            .initializer
166            .iter()
167            .map(|tensor| (tensor.name.as_str(), tensor))
168            .collect::<HashMap<_, _>>();
169
170        Ok(Self {
171            device: (*device).clone(),
172            config: runtime_config(),
173            top_k: 3,
174            embedding_weight: read_tensor_spec(
175                &initializers,
176                &EMBEDDING_WEIGHT,
177            )?,
178            embedding_bias: read_tensor_spec(&initializers, &EMBEDDING_BIAS)?,
179            layer_norm_0_weight: tensor_3d(
180                device,
181                &initializers,
182                &LAYER_NORM_0_WEIGHT,
183                [1, TOKENS_PER_BLOCK, 1],
184            )?,
185            layer_norm_0_bias: tensor_3d(
186                device,
187                &initializers,
188                &LAYER_NORM_0_BIAS,
189                [1, TOKENS_PER_BLOCK, 1],
190            )?,
191            conv_weight: tensor_3d_from_flat(
192                device,
193                read_conv_weight(&initializers)?,
194                [CONV_OUT_CHANNELS, CHANNELS_PER_TOKEN, CONV_KERNEL],
195            )?,
196            conv_bias: tensor_1d_from_flat(
197                device,
198                read_tensor_spec(&initializers, &CONV_BIAS)?,
199            )?,
200            layer_norm_1_weight: tensor_2d_from_flat(
201                device,
202                read_tensor_spec(&initializers, &LAYER_NORM_1_WEIGHT)?,
203                [1, CONV_OUT_CHANNELS],
204            )?,
205            layer_norm_1_bias: tensor_2d_from_flat(
206                device,
207                read_tensor_spec(&initializers, &LAYER_NORM_1_BIAS)?,
208                [1, CONV_OUT_CHANNELS],
209            )?,
210            dense_weight: tensor_2d_from_flat(
211                device,
212                read_tensor_spec(&initializers, &DENSE_WEIGHT)?,
213                [CONV_OUT_CHANNELS, DENSE_OUT],
214            )?,
215            dense_bias: tensor_2d_from_flat(
216                device,
217                read_tensor_spec(&initializers, &DENSE_BIAS)?,
218                [1, DENSE_OUT],
219            )?,
220        })
221    }
222
223    pub fn with_top_k(mut self, top_k: usize) -> Self {
224        self.top_k = top_k.max(1);
225        self
226    }
227
228    pub fn detect_path(
229        &self,
230        path: impl AsRef<Path>,
231    ) -> Result<Detection, MagikaInferenceError> {
232        let bytes = fs::read(path)?;
233        self.detect_bytes(&bytes)
234    }
235
236    pub fn identify_path(
237        &self,
238        path: impl AsRef<Path>,
239    ) -> Result<FileType, MagikaInferenceError> {
240        let path = path.as_ref();
241        let metadata = fs::symlink_metadata(path)?;
242        if metadata.is_dir() {
243            return Ok(FileType::Directory);
244        }
245        if metadata.file_type().is_symlink() {
246            return Ok(FileType::Symlink);
247        }
248
249        let bytes = fs::read(path)?;
250        self.identify_bytes(&bytes)
251    }
252
253    pub fn detect_bytes(
254        &self,
255        bytes: &[u8],
256    ) -> Result<Detection, MagikaInferenceError> {
257        let mut all = self.detect_batch(vec![bytes])?;
258        Ok(all.remove(0))
259    }
260
261    pub fn identify_bytes(
262        &self,
263        bytes: &[u8],
264    ) -> Result<FileType, MagikaInferenceError> {
265        let mut all = self.detect_content_type_batch(vec![bytes])?;
266        let content_type = all.remove(0);
267        Ok(FileType::Ruled(content_type))
268    }
269
270    pub fn detect_batch(
271        &self,
272        inputs: Vec<&[u8]>,
273    ) -> Result<Vec<Detection>, MagikaInferenceError> {
274        if inputs.is_empty() {
275            return Ok(Vec::new());
276        }
277
278        let mut detections = vec![None; inputs.len()];
279        let mut pending_positions = Vec::new();
280        let mut pending_features = Vec::new();
281
282        for (index, bytes) in inputs.into_iter().enumerate() {
283            match prepare_input(bytes, &self.config) {
284                PreparedInput::Ruled(content_type) => {
285                    detections[index] =
286                        Some(detection_for_content_type(content_type));
287                }
288                PreparedInput::Features(features) => {
289                    pending_positions.push(index);
290                    pending_features.push(features);
291                }
292            }
293        }
294
295        if !pending_features.is_empty() {
296            let rows = self.infer_batch(&pending_features)?;
297            if rows.len() != pending_positions.len() {
298                return Err(MagikaInferenceError::Runtime(
299                    "runtime returned mismatched batch size".to_string(),
300                ));
301            }
302
303            for (position, row) in pending_positions.into_iter().zip(rows) {
304                detections[position] = Some(self.row_to_detection(row)?);
305            }
306        }
307
308        detections
309            .into_iter()
310            .map(|detection| {
311                detection.ok_or_else(|| {
312                    MagikaInferenceError::Runtime(
313                        "missing detection result".to_string(),
314                    )
315                })
316            })
317            .collect()
318    }
319
320    fn detect_content_type_batch(
321        &self,
322        inputs: Vec<&[u8]>,
323    ) -> Result<Vec<ContentType>, MagikaInferenceError> {
324        if inputs.is_empty() {
325            return Ok(Vec::new());
326        }
327
328        let mut detections = vec![None; inputs.len()];
329        let mut pending_positions = Vec::new();
330        let mut pending_features = Vec::new();
331
332        for (index, bytes) in inputs.into_iter().enumerate() {
333            match prepare_input(bytes, &self.config) {
334                PreparedInput::Ruled(content_type) => {
335                    detections[index] = Some(content_type);
336                }
337                PreparedInput::Features(features) => {
338                    pending_positions.push(index);
339                    pending_features.push(features);
340                }
341            }
342        }
343
344        if !pending_features.is_empty() {
345            let rows = self.infer_batch(&pending_features)?;
346            if rows.len() != pending_positions.len() {
347                return Err(MagikaInferenceError::Runtime(
348                    "runtime returned mismatched batch size".to_string(),
349                ));
350            }
351
352            for (position, row) in pending_positions.into_iter().zip(rows) {
353                detections[position] = Some(self.row_to_content_type(row)?);
354            }
355        }
356
357        detections
358            .into_iter()
359            .map(|detection| {
360                detection.ok_or_else(|| {
361                    MagikaInferenceError::Runtime(
362                        "missing detection result".to_string(),
363                    )
364                })
365            })
366            .collect()
367    }
368
369    fn forward(
370        &self,
371        batch_features: &[Vec<i32>],
372    ) -> Result<Tensor<B, 2>, MagikaInferenceError> {
373        let logits = self.forward_logits(batch_features)?;
374        let logits = tensor_2d_from_flat(
375            &self.device,
376            logits,
377            [batch_features.len(), DENSE_OUT],
378        )?;
379        Ok(softmax(logits, 1))
380    }
381
382    fn forward_logits(
383        &self,
384        batch_features: &[Vec<i32>],
385    ) -> Result<Vec<f32>, MagikaInferenceError> {
386        let batch_size = batch_features.len();
387        let flat = batch_features
388            .iter()
389            .flat_map(|features| features.iter().map(|value| *value as f32))
390            .collect::<Vec<_>>();
391
392        if flat.len() != batch_size * SEQ_LEN {
393            return Err(MagikaInferenceError::Runtime(
394                "unexpected feature batch shape".to_string(),
395            ));
396        }
397
398        let mut embedded = Vec::with_capacity(batch_size * SEQ_LEN * EMBED_DIM);
399        for features in batch_features {
400            for &feature in features {
401                let index = usize::try_from(feature).map_err(|_| {
402                    MagikaInferenceError::Runtime(format!(
403                        "negative feature value: {feature}"
404                    ))
405                })?;
406                if index >= NUM_CLASSES {
407                    return Err(MagikaInferenceError::Runtime(format!(
408                        "feature value out of range: {feature}"
409                    )));
410                }
411
412                let start = index * EMBED_DIM;
413                for offset in 0..EMBED_DIM {
414                    embedded.push(
415                        self.embedding_weight[start + offset]
416                            + self.embedding_bias[offset],
417                    );
418                }
419            }
420        }
421
422        let x = Tensor::<B, 3>::from_data(
423            TensorData::new(embedded, [batch_size, SEQ_LEN, EMBED_DIM]),
424            &self.device,
425        );
426        let x = gelu(x);
427        let x: Tensor<B, 3> =
428            x.reshape([batch_size, TOKENS_PER_BLOCK, CHANNELS_PER_TOKEN]);
429        let x = layer_norm_axis_1_3d(
430            x,
431            TOKENS_PER_BLOCK as f32,
432            self.layer_norm_0_weight.clone(),
433            self.layer_norm_0_bias.clone(),
434        );
435        let x = x.permute([0, 2, 1]);
436        let x = conv1d(
437            x,
438            self.conv_weight.clone(),
439            Some(self.conv_bias.clone()),
440            ConvOptions::new([1], [0], [1], 1),
441        );
442        let x = gelu(x);
443        let pooled = x.max_dim(2).squeeze_dim(2);
444
445        let normalized = layer_norm_axis_1_2d(
446            pooled,
447            CONV_OUT_CHANNELS as f32,
448            self.layer_norm_1_weight.clone(),
449            self.layer_norm_1_bias.clone(),
450        );
451        (normalized.matmul(self.dense_weight.clone()) + self.dense_bias.clone())
452            .into_data()
453            .to_vec::<f32>()
454            .map_err(|err| {
455                MagikaInferenceError::Runtime(format!(
456                    "extract logits data: {err}"
457                ))
458            })
459    }
460
461    fn infer_batch(
462        &self,
463        batch_features: &[Vec<i32>],
464    ) -> Result<Vec<Vec<f32>>, MagikaInferenceError> {
465        let mut out = Vec::with_capacity(batch_features.len());
466
467        for features in batch_features {
468            let probs = self.forward(std::slice::from_ref(features))?;
469            let flat = probs.into_data().to_vec::<f32>().map_err(|err| {
470                MagikaInferenceError::Runtime(format!(
471                    "extract tensor data: {err}"
472                ))
473            })?;
474            out.push(flat);
475        }
476
477        Ok(out)
478    }
479
480    fn row_to_content_type(
481        &self,
482        row: Vec<f32>,
483    ) -> Result<ContentType, MagikaInferenceError> {
484        if row.len() != DENSE_OUT {
485            return Err(MagikaInferenceError::Runtime(format!(
486                "unexpected logits row size: {}",
487                row.len()
488            )));
489        }
490
491        let mut indexed: Vec<(usize, f32)> =
492            row.into_iter().enumerate().collect();
493        indexed.sort_by(|a, b| {
494            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
495        });
496
497        let (label_idx, score) = indexed.first().copied().ok_or_else(|| {
498            MagikaInferenceError::Runtime("no alternatives created".to_string())
499        })?;
500
501        self.final_content_type(label_idx, score)
502    }
503
504    fn row_to_detection(
505        &self,
506        row: Vec<f32>,
507    ) -> Result<Detection, MagikaInferenceError> {
508        if row.len() != DENSE_OUT {
509            return Err(MagikaInferenceError::Runtime(format!(
510                "unexpected logits row size: {}",
511                row.len()
512            )));
513        }
514
515        let mut indexed: Vec<(usize, f32)> =
516            row.into_iter().enumerate().collect();
517        indexed.sort_by(|a, b| {
518            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
519        });
520
521        let alternatives = indexed
522            .iter()
523            .take(self.top_k)
524            .map(|(label_idx, score)| {
525                let content_type =
526                    self.final_content_type(*label_idx, *score)?;
527                Ok(alternative_for_content_type(content_type, *score))
528            })
529            .collect::<Result<Vec<_>, MagikaInferenceError>>()?;
530
531        let best = alternatives
532            .first()
533            .ok_or_else(|| {
534                MagikaInferenceError::Runtime(
535                    "no alternatives created".to_string(),
536                )
537            })?
538            .clone();
539
540        Ok(Detection {
541            label: best.label.clone(),
542            mime_type: best.mime_type.clone(),
543            confidence: best.confidence,
544            alternatives,
545        })
546    }
547
548    fn final_content_type(
549        &self,
550        label_idx: usize,
551        score: f32,
552    ) -> Result<ContentType, MagikaInferenceError> {
553        let inferred_type = label_for_index(label_idx)?.content_type();
554        if score < self.config.thresholds[inferred_type as usize] {
555            return Ok(if inferred_type.info().is_text {
556                ContentType::Txt
557            } else {
558                ContentType::Unknown
559            });
560        }
561
562        Ok(self.config.overwrite_map[inferred_type as usize])
563    }
564}
565
566fn tensor_2d_from_flat<B: Backend<FloatElem = f32>>(
567    device: &B::Device,
568    values: Vec<f32>,
569    shape: [usize; 2],
570) -> Result<Tensor<B, 2>, MagikaInferenceError> {
571    Ok(Tensor::<B, 2>::from_data(
572        TensorData::new(values, shape),
573        device,
574    ))
575}
576
577fn tensor_1d_from_flat<B: Backend<FloatElem = f32>>(
578    device: &B::Device,
579    values: Vec<f32>,
580) -> Result<Tensor<B, 1>, MagikaInferenceError> {
581    let len = values.len();
582    Ok(Tensor::<B, 1>::from_data(
583        TensorData::new(values, [len]),
584        device,
585    ))
586}
587
588fn read_conv_weight(
589    initializers: &HashMap<&str, &TensorProto>,
590) -> Result<Vec<f32>, MagikaInferenceError> {
591    let raw = read_tensor_spec(initializers, &CONV_WEIGHT)?;
592    let mut flattened = Vec::with_capacity(
593        CONV_OUT_CHANNELS * CHANNELS_PER_TOKEN * CONV_KERNEL,
594    );
595
596    for out in 0..CONV_OUT_CHANNELS {
597        for channel in 0..CHANNELS_PER_TOKEN {
598            for kernel in 0..CONV_KERNEL {
599                let index =
600                    (out * CHANNELS_PER_TOKEN + channel) * CONV_KERNEL + kernel;
601                flattened.push(raw[index]);
602            }
603        }
604    }
605
606    Ok(flattened)
607}
608
609fn tensor_3d_from_flat<B: Backend<FloatElem = f32>>(
610    device: &B::Device,
611    values: Vec<f32>,
612    shape: [usize; 3],
613) -> Result<Tensor<B, 3>, MagikaInferenceError> {
614    Ok(Tensor::<B, 3>::from_data(
615        TensorData::new(values, shape),
616        device,
617    ))
618}
619
620fn tensor_3d<B: Backend<FloatElem = f32>>(
621    device: &B::Device,
622    initializers: &HashMap<&str, &TensorProto>,
623    spec: &TensorSpec,
624    shape: [usize; 3],
625) -> Result<Tensor<B, 3>, MagikaInferenceError> {
626    Ok(Tensor::<B, 3>::from_data(
627        TensorData::new(read_tensor_spec(initializers, spec)?, shape),
628        device,
629    ))
630}
631
632fn read_tensor_spec(
633    initializers: &HashMap<&str, &TensorProto>,
634    spec: &TensorSpec,
635) -> Result<Vec<f32>, MagikaInferenceError> {
636    read_f32_tensor(initializers, spec.name, &spec.shape[..spec.rank])
637}
638
639fn read_f32_tensor(
640    initializers: &HashMap<&str, &TensorProto>,
641    name: &str,
642    expected_shape: &[usize],
643) -> Result<Vec<f32>, MagikaInferenceError> {
644    let tensor = initializers.get(name).ok_or_else(|| {
645        MagikaInferenceError::Runtime(format!("missing initializer: {name}"))
646    })?;
647
648    if tensor.data_type != 1 {
649        return Err(MagikaInferenceError::Runtime(format!(
650            "initializer {name} has unexpected dtype {}",
651            tensor.data_type
652        )));
653    }
654
655    let actual_shape = tensor
656        .dims
657        .iter()
658        .map(|dim| *dim as usize)
659        .collect::<Vec<_>>();
660    if actual_shape.as_slice() != expected_shape {
661        return Err(MagikaInferenceError::Runtime(format!(
662            "initializer {name} has shape {:?}, expected {:?}",
663            actual_shape, expected_shape
664        )));
665    }
666
667    let values = tensor
668        .raw_data
669        .chunks_exact(4)
670        .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("f32 chunk")))
671        .collect::<Vec<_>>();
672
673    if values.len() != expected_shape.iter().product::<usize>() {
674        return Err(MagikaInferenceError::Runtime(format!(
675            "initializer {name} has {} values, expected {}",
676            values.len(),
677            expected_shape.iter().product::<usize>()
678        )));
679    }
680
681    Ok(values)
682}
683
684fn gelu<B: Backend<FloatElem = f32>, const D: usize>(
685    x: Tensor<B, D>,
686) -> Tensor<B, D> {
687    let cubic = x.clone() * x.clone() * x.clone();
688    let inner = (x.clone() + cubic * 0.044_715) * 0.797_884_6;
689    x * ((inner.tanh() + 1.0) * 0.5)
690}
691
692fn layer_norm_axis_1_3d<B: Backend<FloatElem = f32>>(
693    x: Tensor<B, 3>,
694    axis_len: f32,
695    weight: Tensor<B, 3>,
696    bias: Tensor<B, 3>,
697) -> Tensor<B, 3> {
698    let mean = x.clone().sum_dim(1) * (1.0 / axis_len);
699    let variance = (x.clone() * x.clone()).sum_dim(1) * (1.0 / axis_len)
700        - mean.clone() * mean.clone();
701    let inv_std = (variance.clamp_min(0.0) + 1e-6).sqrt().recip();
702    ((x - mean) * inv_std) * weight + bias
703}
704
705#[allow(dead_code)]
706fn layer_norm_axis_1_2d<B: Backend<FloatElem = f32>>(
707    x: Tensor<B, 2>,
708    axis_len: f32,
709    weight: Tensor<B, 2>,
710    bias: Tensor<B, 2>,
711) -> Tensor<B, 2> {
712    let mean = x.clone().sum_dim(1) * (1.0 / axis_len);
713    let variance = (x.clone() * x.clone()).sum_dim(1) * (1.0 / axis_len)
714        - mean.clone() * mean.clone();
715    let inv_std = (variance.clamp_min(0.0) + 1e-6).sqrt().recip();
716    ((x - mean) * inv_std) * weight + bias
717}
718
719fn label_for_index(
720    index: usize,
721) -> Result<vendor_model::Label, MagikaInferenceError> {
722    if index >= vendor_model::NUM_LABELS {
723        return Err(MagikaInferenceError::Runtime(format!(
724            "label index out of range: {index}"
725        )));
726    }
727
728    Ok(
729        unsafe {
730            std::mem::transmute::<u32, vendor_model::Label>(index as u32)
731        },
732    )
733}
734
735fn detection_for_content_type(content_type: ContentType) -> Detection {
736    let alternative = alternative_for_content_type(content_type, 1.0);
737
738    Detection {
739        label: alternative.label.clone(),
740        mime_type: alternative.mime_type.clone(),
741        confidence: alternative.confidence,
742        alternatives: vec![alternative],
743    }
744}
745
746fn alternative_for_content_type(
747    content_type: ContentType,
748    confidence: f32,
749) -> RankedAlternative {
750    let info = content_type.info();
751
752    RankedAlternative {
753        label: info.label.to_string(),
754        mime_type: Some(info.mime_type.to_string()),
755        confidence,
756    }
757}
758
#[cfg(test)]
mod tests {
    use burn_cpu::{Cpu, CpuDevice};

    use super::MagikaModel;

    /// Builds the embedded model on the CPU backend, panicking with
    /// `context` if construction fails.
    fn embedded(context: &str) -> MagikaModel<Cpu> {
        MagikaModel::<Cpu>::from_embedded(&CpuDevice).expect(context)
    }

    /// The same input classified twice yields equal detections, and a batch
    /// returns one result per input.
    #[test]
    fn classifier_batch_is_deterministic() {
        let classifier = embedded("build classifier");

        let first = classifier
            .detect_bytes(b"abcdef")
            .expect("first inference should succeed");
        let second = classifier
            .detect_bytes(b"abcdef")
            .expect("second inference should succeed");
        assert_eq!(first, second);

        let batch = classifier
            .detect_batch(vec![b"a", b"b", b"c"])
            .expect("batch inference should succeed");
        assert_eq!(batch.len(), 3);
    }

    /// The bundled model loads without error.
    #[test]
    fn embedded_model_builds() {
        embedded("build embedded model");
    }

    /// `with_top_k` controls how many alternatives are reported.
    #[test]
    fn explicit_top_k_is_applied() {
        let classifier = embedded("build model").with_top_k(5);

        let detection = classifier
            .detect_bytes(b"function greet() { return 'hi'; }")
            .expect("detect bytes");
        assert_eq!(detection.alternatives.len(), 5);
    }
}
801}