1use std::path::Path;
2
3use burn::tensor::backend::Backend;
4
5use crate::{Error, FileType, MagikaModel};
6
7pub struct Session<B: Backend> {
8 model: MagikaModel<B>,
9}
10
11impl<B: Backend<FloatElem = f32>> Session<B> {
12 pub fn new(device: &B::Device) -> Result<Self, Error> {
13 let model = MagikaModel::<B>::from_embedded(device)?;
14 Ok(Self { model })
15 }
16
17 pub fn from_file(
18 device: &B::Device,
19 path: impl AsRef<Path>,
20 ) -> Result<Self, Error> {
21 let model = MagikaModel::<B>::from_file(device, path)?;
22 Ok(Self { model })
23 }
24
25 pub fn from_bytes(device: &B::Device, bytes: &[u8]) -> Result<Self, Error> {
26 let model = MagikaModel::<B>::from_bytes(device, bytes)?;
27 Ok(Self { model })
28 }
29
30 pub fn identify_file_sync(
31 &mut self,
32 path: impl AsRef<Path>,
33 ) -> Result<FileType, Error> {
34 self.model.identify_path(path)
35 }
36
37 pub async fn identify_file_async(
38 &mut self,
39 path: impl AsRef<Path>,
40 ) -> Result<FileType, Error> {
41 self.model.identify_path(path)
42 }
43
44 pub fn identify_content_sync(
45 &mut self,
46 bytes: &[u8],
47 ) -> Result<FileType, Error> {
48 self.model.identify_bytes(bytes)
49 }
50
51 pub async fn identify_content_async(
52 &mut self,
53 bytes: &[u8],
54 ) -> Result<FileType, Error> {
55 self.model.identify_bytes(bytes)
56 }
57}