Skip to main content

encoderfile/runtime/
state.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use ort::session::Session;
4use parking_lot::Mutex;
5
6use crate::{
7    common::{Config, ModelConfig, ModelType, model_type::ModelTypeSpec},
8    runtime::TokenizerService,
9    transforms::DEFAULT_LIBS,
10};
11
12pub type AppState<T> = Arc<EncoderfileState<T>>;
13
14#[derive(Debug)]
15pub struct EncoderfileState<T: ModelTypeSpec> {
16    pub config: Config,
17    pub session: Mutex<Session>,
18    pub tokenizer: TokenizerService,
19    pub model_config: ModelConfig,
20    pub lua_libs: Vec<mlua::StdLib>,
21    _marker: PhantomData<T>,
22}
23
24impl<T: ModelTypeSpec> EncoderfileState<T> {
25    pub fn new(
26        config: Config,
27        session: Mutex<Session>,
28        tokenizer: TokenizerService,
29        model_config: ModelConfig,
30    ) -> EncoderfileState<T> {
31        let lua_libs = match config.lua_libs {
32            Some(ref libs) => Vec::<mlua::StdLib>::from(libs),
33            None => DEFAULT_LIBS.to_vec(),
34        };
35        EncoderfileState {
36            config,
37            session,
38            tokenizer,
39            model_config,
40            lua_libs,
41            _marker: PhantomData,
42        }
43    }
44
45    pub fn transform_str(&self) -> Option<String> {
46        self.config.transform.clone()
47    }
48
49    pub fn lua_libs(&self) -> &Vec<mlua::StdLib> {
50        &self.lua_libs
51    }
52
53    pub fn model_type() -> ModelType {
54        T::enum_val()
55    }
56}