1use std::{collections::HashMap, fmt, fs, path::Path};
2
3use burn::tensor::{
4 Tensor, TensorData, activation::softmax, backend::Backend, module::conv1d,
5 ops::ConvOptions,
6};
7use onnx_ir::{ModelProto, TensorProto};
8use protobuf::Message;
9
10use crate::{
11 config::{ModelConfig, runtime_config},
12 detection::{Detection, RankedAlternative},
13 file::FileType,
14 preprocess::{PreparedInput, prepare_input},
15 vendor::{content::ContentType, model as vendor_model},
16};
17
18pub(crate) use crate::vendor::model::{CONFIG, Label, NUM_LABELS};
/// Number of rows in the embedding table; every input feature value must be
/// in `0..NUM_CLASSES`.
const NUM_CLASSES: usize = 257;
/// Number of feature values expected per input by the forward pass.
const SEQ_LEN: usize = 2048;
/// Width of one embedding row.
const EMBED_DIM: usize = 64;
/// Second dimension after the embedded `[SEQ_LEN, EMBED_DIM]` activations are
/// reshaped; `TOKENS_PER_BLOCK * CHANNELS_PER_TOKEN == SEQ_LEN * EMBED_DIM`.
const TOKENS_PER_BLOCK: usize = 512;
/// Third dimension after that reshape; also the convolution's input channels.
const CHANNELS_PER_TOKEN: usize = 256;
/// Number of output channels of the 1-D convolution.
const CONV_OUT_CHANNELS: usize = 512;
/// Kernel width of the 1-D convolution.
const CONV_KERNEL: usize = 5;
/// Number of scores produced per input by the final dense layer.
const DENSE_OUT: usize = vendor_model::NUM_LABELS;
/// Bytes of the ONNX model compiled into the binary; used by
/// `MagikaModel::from_embedded`.
const EMBEDDED_MODEL: &[u8] =
    include_bytes!("vendor/assets/models/standard_v3_3/model.onnx");
29
/// Name and expected shape of one initializer tensor in the ONNX graph.
struct TensorSpec {
    /// Initializer name looked up in the graph.
    name: &'static str,
    /// Expected dimensions, stored in a fixed array of four entries; only the
    /// first `rank` entries are compared against the model.
    shape: &'static [usize; 4],
    /// Number of leading entries of `shape` that are meaningful.
    rank: usize,
}
35
// Expected initializers of the embedded model. Entries of `shape` beyond
// `rank` are padding and are never read.

/// Embedding table, one row of `EMBED_DIM` values per feature value.
const EMBEDDING_WEIGHT: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/Const:0",
    shape: &[NUM_CLASSES, EMBED_DIM, 0, 0],
    rank: 2,
};

/// Bias added to every embedding row.
const EMBEDDING_BIAS: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Dense_0/Reshape:0",
    shape: &[1, 1, EMBED_DIM, 0],
    rank: 3,
};

/// Scale applied by the first layer normalisation.
const LAYER_NORM_0_WEIGHT: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_0/Reshape_2:0",
    shape: &[1, TOKENS_PER_BLOCK, 1, 0],
    rank: 3,
};

/// Offset applied by the first layer normalisation.
const LAYER_NORM_0_BIAS: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_0/Reshape_3:0",
    shape: &[1, TOKENS_PER_BLOCK, 1, 0],
    rank: 3,
};

/// Convolution kernel; stored with rank 4 and a trailing dimension of 1.
const CONV_WEIGHT: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Conv_0/transpose_3:0",
    shape: &[CONV_OUT_CHANNELS, CHANNELS_PER_TOKEN, CONV_KERNEL, 1],
    rank: 4,
};

/// Convolution bias, one value per output channel.
const CONV_BIAS: TensorSpec = TensorSpec {
    name: "const_fold_opt__209",
    shape: &[1, CONV_OUT_CHANNELS, 1, 0],
    rank: 3,
};

/// Scale applied by the second layer normalisation.
const LAYER_NORM_1_WEIGHT: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_1/Reshape_2:0",
    shape: &[1, CONV_OUT_CHANNELS, 0, 0],
    rank: 2,
};

/// Offset applied by the second layer normalisation.
const LAYER_NORM_1_BIAS: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/LayerNorm_1/Reshape_3:0",
    shape: &[1, CONV_OUT_CHANNELS, 0, 0],
    rank: 2,
};

/// Weight matrix of the final dense layer.
const DENSE_WEIGHT: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/Const_24:0",
    shape: &[CONV_OUT_CHANNELS, DENSE_OUT, 0, 0],
    rank: 2,
};

/// Bias of the final dense layer.
const DENSE_BIAS: TensorSpec = TensorSpec {
    name: "jax2tf_get_logits_/pjit_get_logits_/MagikaV2/Dense_1/Reshape:0",
    shape: &[1, DENSE_OUT, 0, 0],
    rank: 2,
};
95
/// Errors reported while loading the model or running inference.
#[derive(Debug)]
pub enum MagikaInferenceError {
    /// A filesystem operation failed; wraps the underlying I/O error.
    Io(std::io::Error),
    /// The configuration was rejected; the payload describes why.
    InvalidConfig(String),
    /// Model parsing or the forward pass failed; the payload describes why.
    Runtime(String),
}

impl fmt::Display for MagikaInferenceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Io(e) => write!(f, "io error: {e}"),
            Self::InvalidConfig(e) => write!(f, "invalid configuration: {e}"),
            Self::Runtime(e) => write!(f, "inference runtime error: {e}"),
        }
    }
}

impl std::error::Error for MagikaInferenceError {
    /// Exposes the wrapped I/O error so error-chain reporters can walk to the
    /// root cause; the string-only variants have no underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::InvalidConfig(_) | Self::Runtime(_) => None,
        }
    }
}

impl From<std::io::Error> for MagikaInferenceError {
    /// Wraps an I/O error so `?` can be used on filesystem calls.
    fn from(value: std::io::Error) -> Self {
        Self::Io(value)
    }
}
120
/// Classifier holding the weights read from the ONNX model's initializers,
/// ready to run the forward pass on backend `B`.
pub struct MagikaModel<B: Backend> {
    /// Device on which all tensors are created.
    device: B::Device,
    /// Runtime configuration; supplies the per-type thresholds and the
    /// overwrite map, and is passed to input preparation.
    config: ModelConfig,
    /// Maximum number of ranked alternatives reported per detection.
    top_k: usize,
    /// Embedding table as a flat row-major vector of
    /// `NUM_CLASSES * EMBED_DIM` values.
    embedding_weight: Vec<f32>,
    /// Bias added to each embedding row (`EMBED_DIM` values).
    embedding_bias: Vec<f32>,
    /// Scale of the first layer normalisation, shape `[1, TOKENS_PER_BLOCK, 1]`.
    layer_norm_0_weight: Tensor<B, 3>,
    /// Offset of the first layer normalisation, shape `[1, TOKENS_PER_BLOCK, 1]`.
    layer_norm_0_bias: Tensor<B, 3>,
    /// Convolution kernel, shape
    /// `[CONV_OUT_CHANNELS, CHANNELS_PER_TOKEN, CONV_KERNEL]`.
    conv_weight: Tensor<B, 3>,
    /// Convolution bias as a 1-D tensor.
    conv_bias: Tensor<B, 1>,
    /// Scale of the second layer normalisation, shape `[1, CONV_OUT_CHANNELS]`.
    layer_norm_1_weight: Tensor<B, 2>,
    /// Offset of the second layer normalisation, shape `[1, CONV_OUT_CHANNELS]`.
    layer_norm_1_bias: Tensor<B, 2>,
    /// Final dense layer weights, shape `[CONV_OUT_CHANNELS, DENSE_OUT]`.
    dense_weight: Tensor<B, 2>,
    /// Final dense layer bias, shape `[1, DENSE_OUT]`.
    dense_bias: Tensor<B, 2>,
}
136
137impl<B: Backend<FloatElem = f32>> MagikaModel<B> {
138 pub fn from_embedded(
139 device: &B::Device,
140 ) -> Result<Self, MagikaInferenceError> {
141 Self::from_bytes(device, EMBEDDED_MODEL)
142 }
143
144 pub fn from_file(
145 device: &B::Device,
146 path: impl AsRef<Path>,
147 ) -> Result<Self, MagikaInferenceError> {
148 let model_bytes = fs::read(path)?;
149 Self::from_bytes(device, &model_bytes)
150 }
151
152 pub fn from_bytes(
153 device: &B::Device,
154 model_bytes: &[u8],
155 ) -> Result<Self, MagikaInferenceError> {
156 let model =
157 ModelProto::parse_from_bytes(model_bytes).map_err(|err| {
158 MagikaInferenceError::Runtime(format!("parse model: {err}"))
159 })?;
160 let graph = model.graph.as_ref().ok_or_else(|| {
161 MagikaInferenceError::Runtime("model graph missing".to_string())
162 })?;
163
164 let initializers = graph
165 .initializer
166 .iter()
167 .map(|tensor| (tensor.name.as_str(), tensor))
168 .collect::<HashMap<_, _>>();
169
170 Ok(Self {
171 device: (*device).clone(),
172 config: runtime_config(),
173 top_k: 3,
174 embedding_weight: read_tensor_spec(
175 &initializers,
176 &EMBEDDING_WEIGHT,
177 )?,
178 embedding_bias: read_tensor_spec(&initializers, &EMBEDDING_BIAS)?,
179 layer_norm_0_weight: tensor_3d(
180 device,
181 &initializers,
182 &LAYER_NORM_0_WEIGHT,
183 [1, TOKENS_PER_BLOCK, 1],
184 )?,
185 layer_norm_0_bias: tensor_3d(
186 device,
187 &initializers,
188 &LAYER_NORM_0_BIAS,
189 [1, TOKENS_PER_BLOCK, 1],
190 )?,
191 conv_weight: tensor_3d_from_flat(
192 device,
193 read_conv_weight(&initializers)?,
194 [CONV_OUT_CHANNELS, CHANNELS_PER_TOKEN, CONV_KERNEL],
195 )?,
196 conv_bias: tensor_1d_from_flat(
197 device,
198 read_tensor_spec(&initializers, &CONV_BIAS)?,
199 )?,
200 layer_norm_1_weight: tensor_2d_from_flat(
201 device,
202 read_tensor_spec(&initializers, &LAYER_NORM_1_WEIGHT)?,
203 [1, CONV_OUT_CHANNELS],
204 )?,
205 layer_norm_1_bias: tensor_2d_from_flat(
206 device,
207 read_tensor_spec(&initializers, &LAYER_NORM_1_BIAS)?,
208 [1, CONV_OUT_CHANNELS],
209 )?,
210 dense_weight: tensor_2d_from_flat(
211 device,
212 read_tensor_spec(&initializers, &DENSE_WEIGHT)?,
213 [CONV_OUT_CHANNELS, DENSE_OUT],
214 )?,
215 dense_bias: tensor_2d_from_flat(
216 device,
217 read_tensor_spec(&initializers, &DENSE_BIAS)?,
218 [1, DENSE_OUT],
219 )?,
220 })
221 }
222
223 pub fn with_top_k(mut self, top_k: usize) -> Self {
224 self.top_k = top_k.max(1);
225 self
226 }
227
228 pub fn detect_path(
229 &self,
230 path: impl AsRef<Path>,
231 ) -> Result<Detection, MagikaInferenceError> {
232 let bytes = fs::read(path)?;
233 self.detect_bytes(&bytes)
234 }
235
236 pub fn identify_path(
237 &self,
238 path: impl AsRef<Path>,
239 ) -> Result<FileType, MagikaInferenceError> {
240 let path = path.as_ref();
241 let metadata = fs::symlink_metadata(path)?;
242 if metadata.is_dir() {
243 return Ok(FileType::Directory);
244 }
245 if metadata.file_type().is_symlink() {
246 return Ok(FileType::Symlink);
247 }
248
249 let bytes = fs::read(path)?;
250 self.identify_bytes(&bytes)
251 }
252
253 pub fn detect_bytes(
254 &self,
255 bytes: &[u8],
256 ) -> Result<Detection, MagikaInferenceError> {
257 let mut all = self.detect_batch(vec![bytes])?;
258 Ok(all.remove(0))
259 }
260
261 pub fn identify_bytes(
262 &self,
263 bytes: &[u8],
264 ) -> Result<FileType, MagikaInferenceError> {
265 let mut all = self.detect_content_type_batch(vec![bytes])?;
266 let content_type = all.remove(0);
267 Ok(FileType::Ruled(content_type))
268 }
269
270 pub fn detect_batch(
271 &self,
272 inputs: Vec<&[u8]>,
273 ) -> Result<Vec<Detection>, MagikaInferenceError> {
274 if inputs.is_empty() {
275 return Ok(Vec::new());
276 }
277
278 let mut detections = vec![None; inputs.len()];
279 let mut pending_positions = Vec::new();
280 let mut pending_features = Vec::new();
281
282 for (index, bytes) in inputs.into_iter().enumerate() {
283 match prepare_input(bytes, &self.config) {
284 PreparedInput::Ruled(content_type) => {
285 detections[index] =
286 Some(detection_for_content_type(content_type));
287 }
288 PreparedInput::Features(features) => {
289 pending_positions.push(index);
290 pending_features.push(features);
291 }
292 }
293 }
294
295 if !pending_features.is_empty() {
296 let rows = self.infer_batch(&pending_features)?;
297 if rows.len() != pending_positions.len() {
298 return Err(MagikaInferenceError::Runtime(
299 "runtime returned mismatched batch size".to_string(),
300 ));
301 }
302
303 for (position, row) in pending_positions.into_iter().zip(rows) {
304 detections[position] = Some(self.row_to_detection(row)?);
305 }
306 }
307
308 detections
309 .into_iter()
310 .map(|detection| {
311 detection.ok_or_else(|| {
312 MagikaInferenceError::Runtime(
313 "missing detection result".to_string(),
314 )
315 })
316 })
317 .collect()
318 }
319
320 fn detect_content_type_batch(
321 &self,
322 inputs: Vec<&[u8]>,
323 ) -> Result<Vec<ContentType>, MagikaInferenceError> {
324 if inputs.is_empty() {
325 return Ok(Vec::new());
326 }
327
328 let mut detections = vec![None; inputs.len()];
329 let mut pending_positions = Vec::new();
330 let mut pending_features = Vec::new();
331
332 for (index, bytes) in inputs.into_iter().enumerate() {
333 match prepare_input(bytes, &self.config) {
334 PreparedInput::Ruled(content_type) => {
335 detections[index] = Some(content_type);
336 }
337 PreparedInput::Features(features) => {
338 pending_positions.push(index);
339 pending_features.push(features);
340 }
341 }
342 }
343
344 if !pending_features.is_empty() {
345 let rows = self.infer_batch(&pending_features)?;
346 if rows.len() != pending_positions.len() {
347 return Err(MagikaInferenceError::Runtime(
348 "runtime returned mismatched batch size".to_string(),
349 ));
350 }
351
352 for (position, row) in pending_positions.into_iter().zip(rows) {
353 detections[position] = Some(self.row_to_content_type(row)?);
354 }
355 }
356
357 detections
358 .into_iter()
359 .map(|detection| {
360 detection.ok_or_else(|| {
361 MagikaInferenceError::Runtime(
362 "missing detection result".to_string(),
363 )
364 })
365 })
366 .collect()
367 }
368
369 fn forward(
370 &self,
371 batch_features: &[Vec<i32>],
372 ) -> Result<Tensor<B, 2>, MagikaInferenceError> {
373 let logits = self.forward_logits(batch_features)?;
374 let logits = tensor_2d_from_flat(
375 &self.device,
376 logits,
377 [batch_features.len(), DENSE_OUT],
378 )?;
379 Ok(softmax(logits, 1))
380 }
381
382 fn forward_logits(
383 &self,
384 batch_features: &[Vec<i32>],
385 ) -> Result<Vec<f32>, MagikaInferenceError> {
386 let batch_size = batch_features.len();
387 let flat = batch_features
388 .iter()
389 .flat_map(|features| features.iter().map(|value| *value as f32))
390 .collect::<Vec<_>>();
391
392 if flat.len() != batch_size * SEQ_LEN {
393 return Err(MagikaInferenceError::Runtime(
394 "unexpected feature batch shape".to_string(),
395 ));
396 }
397
398 let mut embedded = Vec::with_capacity(batch_size * SEQ_LEN * EMBED_DIM);
399 for features in batch_features {
400 for &feature in features {
401 let index = usize::try_from(feature).map_err(|_| {
402 MagikaInferenceError::Runtime(format!(
403 "negative feature value: {feature}"
404 ))
405 })?;
406 if index >= NUM_CLASSES {
407 return Err(MagikaInferenceError::Runtime(format!(
408 "feature value out of range: {feature}"
409 )));
410 }
411
412 let start = index * EMBED_DIM;
413 for offset in 0..EMBED_DIM {
414 embedded.push(
415 self.embedding_weight[start + offset]
416 + self.embedding_bias[offset],
417 );
418 }
419 }
420 }
421
422 let x = Tensor::<B, 3>::from_data(
423 TensorData::new(embedded, [batch_size, SEQ_LEN, EMBED_DIM]),
424 &self.device,
425 );
426 let x = gelu(x);
427 let x: Tensor<B, 3> =
428 x.reshape([batch_size, TOKENS_PER_BLOCK, CHANNELS_PER_TOKEN]);
429 let x = layer_norm_axis_1_3d(
430 x,
431 TOKENS_PER_BLOCK as f32,
432 self.layer_norm_0_weight.clone(),
433 self.layer_norm_0_bias.clone(),
434 );
435 let x = x.permute([0, 2, 1]);
436 let x = conv1d(
437 x,
438 self.conv_weight.clone(),
439 Some(self.conv_bias.clone()),
440 ConvOptions::new([1], [0], [1], 1),
441 );
442 let x = gelu(x);
443 let pooled = x.max_dim(2).squeeze_dim(2);
444
445 let normalized = layer_norm_axis_1_2d(
446 pooled,
447 CONV_OUT_CHANNELS as f32,
448 self.layer_norm_1_weight.clone(),
449 self.layer_norm_1_bias.clone(),
450 );
451 (normalized.matmul(self.dense_weight.clone()) + self.dense_bias.clone())
452 .into_data()
453 .to_vec::<f32>()
454 .map_err(|err| {
455 MagikaInferenceError::Runtime(format!(
456 "extract logits data: {err}"
457 ))
458 })
459 }
460
461 fn infer_batch(
462 &self,
463 batch_features: &[Vec<i32>],
464 ) -> Result<Vec<Vec<f32>>, MagikaInferenceError> {
465 let mut out = Vec::with_capacity(batch_features.len());
466
467 for features in batch_features {
468 let probs = self.forward(std::slice::from_ref(features))?;
469 let flat = probs.into_data().to_vec::<f32>().map_err(|err| {
470 MagikaInferenceError::Runtime(format!(
471 "extract tensor data: {err}"
472 ))
473 })?;
474 out.push(flat);
475 }
476
477 Ok(out)
478 }
479
480 fn row_to_content_type(
481 &self,
482 row: Vec<f32>,
483 ) -> Result<ContentType, MagikaInferenceError> {
484 if row.len() != DENSE_OUT {
485 return Err(MagikaInferenceError::Runtime(format!(
486 "unexpected logits row size: {}",
487 row.len()
488 )));
489 }
490
491 let mut indexed: Vec<(usize, f32)> =
492 row.into_iter().enumerate().collect();
493 indexed.sort_by(|a, b| {
494 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
495 });
496
497 let (label_idx, score) = indexed.first().copied().ok_or_else(|| {
498 MagikaInferenceError::Runtime("no alternatives created".to_string())
499 })?;
500
501 self.final_content_type(label_idx, score)
502 }
503
504 fn row_to_detection(
505 &self,
506 row: Vec<f32>,
507 ) -> Result<Detection, MagikaInferenceError> {
508 if row.len() != DENSE_OUT {
509 return Err(MagikaInferenceError::Runtime(format!(
510 "unexpected logits row size: {}",
511 row.len()
512 )));
513 }
514
515 let mut indexed: Vec<(usize, f32)> =
516 row.into_iter().enumerate().collect();
517 indexed.sort_by(|a, b| {
518 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
519 });
520
521 let alternatives = indexed
522 .iter()
523 .take(self.top_k)
524 .map(|(label_idx, score)| {
525 let content_type =
526 self.final_content_type(*label_idx, *score)?;
527 Ok(alternative_for_content_type(content_type, *score))
528 })
529 .collect::<Result<Vec<_>, MagikaInferenceError>>()?;
530
531 let best = alternatives
532 .first()
533 .ok_or_else(|| {
534 MagikaInferenceError::Runtime(
535 "no alternatives created".to_string(),
536 )
537 })?
538 .clone();
539
540 Ok(Detection {
541 label: best.label.clone(),
542 mime_type: best.mime_type.clone(),
543 confidence: best.confidence,
544 alternatives,
545 })
546 }
547
548 fn final_content_type(
549 &self,
550 label_idx: usize,
551 score: f32,
552 ) -> Result<ContentType, MagikaInferenceError> {
553 let inferred_type = label_for_index(label_idx)?.content_type();
554 if score < self.config.thresholds[inferred_type as usize] {
555 return Ok(if inferred_type.info().is_text {
556 ContentType::Txt
557 } else {
558 ContentType::Unknown
559 });
560 }
561
562 Ok(self.config.overwrite_map[inferred_type as usize])
563 }
564}
565
566fn tensor_2d_from_flat<B: Backend<FloatElem = f32>>(
567 device: &B::Device,
568 values: Vec<f32>,
569 shape: [usize; 2],
570) -> Result<Tensor<B, 2>, MagikaInferenceError> {
571 Ok(Tensor::<B, 2>::from_data(
572 TensorData::new(values, shape),
573 device,
574 ))
575}
576
577fn tensor_1d_from_flat<B: Backend<FloatElem = f32>>(
578 device: &B::Device,
579 values: Vec<f32>,
580) -> Result<Tensor<B, 1>, MagikaInferenceError> {
581 let len = values.len();
582 Ok(Tensor::<B, 1>::from_data(
583 TensorData::new(values, [len]),
584 device,
585 ))
586}
587
588fn read_conv_weight(
589 initializers: &HashMap<&str, &TensorProto>,
590) -> Result<Vec<f32>, MagikaInferenceError> {
591 let raw = read_tensor_spec(initializers, &CONV_WEIGHT)?;
592 let mut flattened = Vec::with_capacity(
593 CONV_OUT_CHANNELS * CHANNELS_PER_TOKEN * CONV_KERNEL,
594 );
595
596 for out in 0..CONV_OUT_CHANNELS {
597 for channel in 0..CHANNELS_PER_TOKEN {
598 for kernel in 0..CONV_KERNEL {
599 let index =
600 (out * CHANNELS_PER_TOKEN + channel) * CONV_KERNEL + kernel;
601 flattened.push(raw[index]);
602 }
603 }
604 }
605
606 Ok(flattened)
607}
608
609fn tensor_3d_from_flat<B: Backend<FloatElem = f32>>(
610 device: &B::Device,
611 values: Vec<f32>,
612 shape: [usize; 3],
613) -> Result<Tensor<B, 3>, MagikaInferenceError> {
614 Ok(Tensor::<B, 3>::from_data(
615 TensorData::new(values, shape),
616 device,
617 ))
618}
619
620fn tensor_3d<B: Backend<FloatElem = f32>>(
621 device: &B::Device,
622 initializers: &HashMap<&str, &TensorProto>,
623 spec: &TensorSpec,
624 shape: [usize; 3],
625) -> Result<Tensor<B, 3>, MagikaInferenceError> {
626 Ok(Tensor::<B, 3>::from_data(
627 TensorData::new(read_tensor_spec(initializers, spec)?, shape),
628 device,
629 ))
630}
631
632fn read_tensor_spec(
633 initializers: &HashMap<&str, &TensorProto>,
634 spec: &TensorSpec,
635) -> Result<Vec<f32>, MagikaInferenceError> {
636 read_f32_tensor(initializers, spec.name, &spec.shape[..spec.rank])
637}
638
639fn read_f32_tensor(
640 initializers: &HashMap<&str, &TensorProto>,
641 name: &str,
642 expected_shape: &[usize],
643) -> Result<Vec<f32>, MagikaInferenceError> {
644 let tensor = initializers.get(name).ok_or_else(|| {
645 MagikaInferenceError::Runtime(format!("missing initializer: {name}"))
646 })?;
647
648 if tensor.data_type != 1 {
649 return Err(MagikaInferenceError::Runtime(format!(
650 "initializer {name} has unexpected dtype {}",
651 tensor.data_type
652 )));
653 }
654
655 let actual_shape = tensor
656 .dims
657 .iter()
658 .map(|dim| *dim as usize)
659 .collect::<Vec<_>>();
660 if actual_shape.as_slice() != expected_shape {
661 return Err(MagikaInferenceError::Runtime(format!(
662 "initializer {name} has shape {:?}, expected {:?}",
663 actual_shape, expected_shape
664 )));
665 }
666
667 let values = tensor
668 .raw_data
669 .chunks_exact(4)
670 .map(|chunk| f32::from_le_bytes(chunk.try_into().expect("f32 chunk")))
671 .collect::<Vec<_>>();
672
673 if values.len() != expected_shape.iter().product::<usize>() {
674 return Err(MagikaInferenceError::Runtime(format!(
675 "initializer {name} has {} values, expected {}",
676 values.len(),
677 expected_shape.iter().product::<usize>()
678 )));
679 }
680
681 Ok(values)
682}
683
684fn gelu<B: Backend<FloatElem = f32>, const D: usize>(
685 x: Tensor<B, D>,
686) -> Tensor<B, D> {
687 let cubic = x.clone() * x.clone() * x.clone();
688 let inner = (x.clone() + cubic * 0.044_715) * 0.797_884_6;
689 x * ((inner.tanh() + 1.0) * 0.5)
690}
691
692fn layer_norm_axis_1_3d<B: Backend<FloatElem = f32>>(
693 x: Tensor<B, 3>,
694 axis_len: f32,
695 weight: Tensor<B, 3>,
696 bias: Tensor<B, 3>,
697) -> Tensor<B, 3> {
698 let mean = x.clone().sum_dim(1) * (1.0 / axis_len);
699 let variance = (x.clone() * x.clone()).sum_dim(1) * (1.0 / axis_len)
700 - mean.clone() * mean.clone();
701 let inv_std = (variance.clamp_min(0.0) + 1e-6).sqrt().recip();
702 ((x - mean) * inv_std) * weight + bias
703}
704
705#[allow(dead_code)]
706fn layer_norm_axis_1_2d<B: Backend<FloatElem = f32>>(
707 x: Tensor<B, 2>,
708 axis_len: f32,
709 weight: Tensor<B, 2>,
710 bias: Tensor<B, 2>,
711) -> Tensor<B, 2> {
712 let mean = x.clone().sum_dim(1) * (1.0 / axis_len);
713 let variance = (x.clone() * x.clone()).sum_dim(1) * (1.0 / axis_len)
714 - mean.clone() * mean.clone();
715 let inv_std = (variance.clamp_min(0.0) + 1e-6).sqrt().recip();
716 ((x - mean) * inv_std) * weight + bias
717}
718
/// Converts a model output index into the corresponding
/// `vendor_model::Label`.
///
/// # Errors
/// Returns `Runtime` when `index` is not below `vendor_model::NUM_LABELS`.
fn label_for_index(
    index: usize,
) -> Result<vendor_model::Label, MagikaInferenceError> {
    if index >= vendor_model::NUM_LABELS {
        return Err(MagikaInferenceError::Runtime(format!(
            "label index out of range: {index}"
        )));
    }

    // SAFETY: NOTE(review) - soundness depends on `Label`, which is defined
    // outside this file. This assumes `Label` is a 4-byte fieldless enum whose
    // discriminants are exactly `0..NUM_LABELS`, so that every index passing
    // the range check above is a valid value. TODO confirm against the
    // definition of `vendor_model::Label`.
    Ok(
        unsafe {
            std::mem::transmute::<u32, vendor_model::Label>(index as u32)
        },
    )
}
734
735fn detection_for_content_type(content_type: ContentType) -> Detection {
736 let alternative = alternative_for_content_type(content_type, 1.0);
737
738 Detection {
739 label: alternative.label.clone(),
740 mime_type: alternative.mime_type.clone(),
741 confidence: alternative.confidence,
742 alternatives: vec![alternative],
743 }
744}
745
746fn alternative_for_content_type(
747 content_type: ContentType,
748 confidence: f32,
749) -> RankedAlternative {
750 let info = content_type.info();
751
752 RankedAlternative {
753 label: info.label.to_string(),
754 mime_type: Some(info.mime_type.to_string()),
755 confidence,
756 }
757}
758
#[cfg(test)]
mod tests {
    use burn_cpu::{Cpu, CpuDevice};

    use super::MagikaModel;

    /// Builds the classifier from the embedded model, panicking with
    /// `context` when construction fails.
    fn embedded(context: &str) -> MagikaModel<Cpu> {
        MagikaModel::<Cpu>::from_embedded(&CpuDevice).expect(context)
    }

    /// Two runs on the same bytes must agree, and a batch must return one
    /// result per input.
    #[test]
    fn classifier_batch_is_deterministic() {
        let model = embedded("build classifier");

        let first = model
            .detect_bytes(b"abcdef")
            .expect("first inference should succeed");
        let second = model
            .detect_bytes(b"abcdef")
            .expect("second inference should succeed");
        assert_eq!(first, second);

        let results = model
            .detect_batch(vec![b"a", b"b", b"c"])
            .expect("batch inference should succeed");
        assert_eq!(results.len(), 3);
    }

    /// The embedded model must load without error.
    #[test]
    fn embedded_model_builds() {
        embedded("build embedded model");
    }

    /// A configured `top_k` controls how many alternatives are reported.
    #[test]
    fn explicit_top_k_is_applied() {
        let model = embedded("build model").with_top_k(5);

        let result = model
            .detect_bytes(b"function greet() { return 'hi'; }")
            .expect("detect bytes");
        assert_eq!(result.alternatives.len(), 5);
    }
}