encoderfile/runtime/
state.rs1use std::{marker::PhantomData, sync::Arc};
2
3use ort::session::Session;
4use parking_lot::Mutex;
5
6use crate::{
7 common::{Config, ModelConfig, ModelType, model_type::ModelTypeSpec},
8 runtime::TokenizerService,
9 transforms::DEFAULT_LIBS,
10};
11
12pub type AppState<T> = Arc<EncoderfileState<T>>;
13
14#[derive(Debug)]
15pub struct EncoderfileState<T: ModelTypeSpec> {
16 pub config: Config,
17 pub session: Mutex<Session>,
18 pub tokenizer: TokenizerService,
19 pub model_config: ModelConfig,
20 pub lua_libs: Vec<mlua::StdLib>,
21 _marker: PhantomData<T>,
22}
23
24impl<T: ModelTypeSpec> EncoderfileState<T> {
25 pub fn new(
26 config: Config,
27 session: Mutex<Session>,
28 tokenizer: TokenizerService,
29 model_config: ModelConfig,
30 ) -> EncoderfileState<T> {
31 let lua_libs = match config.lua_libs {
32 Some(ref libs) => Vec::<mlua::StdLib>::from(libs),
33 None => DEFAULT_LIBS.to_vec(),
34 };
35 EncoderfileState {
36 config,
37 session,
38 tokenizer,
39 model_config,
40 lua_libs,
41 _marker: PhantomData,
42 }
43 }
44
45 pub fn transform_str(&self) -> Option<String> {
46 self.config.transform.clone()
47 }
48
49 pub fn lua_libs(&self) -> &Vec<mlua::StdLib> {
50 &self.lua_libs
51 }
52
53 pub fn model_type() -> ModelType {
54 T::enum_val()
55 }
56}