Skip to main content

akuna_infer/
session.rs

1use std::path::Path;
2
3use burn::tensor::backend::Backend;
4
5use crate::{Error, FileType, MagikaModel};
6
7pub struct Session<B: Backend> {
8    model: MagikaModel<B>,
9}
10
11impl<B: Backend<FloatElem = f32>> Session<B> {
12    pub fn new(device: &B::Device) -> Result<Self, Error> {
13        let model = MagikaModel::<B>::from_embedded(device)?;
14        Ok(Self { model })
15    }
16
17    pub fn from_file(
18        device: &B::Device,
19        path: impl AsRef<Path>,
20    ) -> Result<Self, Error> {
21        let model = MagikaModel::<B>::from_file(device, path)?;
22        Ok(Self { model })
23    }
24
25    pub fn from_bytes(device: &B::Device, bytes: &[u8]) -> Result<Self, Error> {
26        let model = MagikaModel::<B>::from_bytes(device, bytes)?;
27        Ok(Self { model })
28    }
29
30    pub fn identify_file_sync(
31        &mut self,
32        path: impl AsRef<Path>,
33    ) -> Result<FileType, Error> {
34        self.model.identify_path(path)
35    }
36
37    pub async fn identify_file_async(
38        &mut self,
39        path: impl AsRef<Path>,
40    ) -> Result<FileType, Error> {
41        self.model.identify_path(path)
42    }
43
44    pub fn identify_content_sync(
45        &mut self,
46        bytes: &[u8],
47    ) -> Result<FileType, Error> {
48        self.model.identify_bytes(bytes)
49    }
50
51    pub async fn identify_content_async(
52        &mut self,
53        bytes: &[u8],
54    ) -> Result<FileType, Error> {
55        self.model.identify_bytes(bytes)
56    }
57}