use std::path::Path;
use ndarray::Array2;
use crate::future::{exec, AsyncEnv, Env, SyncEnv};
use crate::input::AsyncInputApi;
use crate::{AsyncInput, Builder, Features, FeaturesOrRuled, FileType, Result, SyncInput};
/// A file-type identification session.
///
/// Owns the underlying `ort` inference session that the `identify_*` methods
/// run the model on. Construct one with [`Session::new`] or, for custom
/// configuration, [`Session::builder`].
#[derive(Debug)]
pub struct Session {
/// The wrapped `ort` inference session; crate-visible so other modules of
/// this crate can construct and drive it.
pub(crate) session: ort::session::Session,
}
impl Session {
    /// Creates a session from the default [`Builder`] configuration.
    ///
    /// Equivalent to `Session::builder().build()`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Builder::build`] reports.
    pub fn new() -> Result<Self> {
        Self::builder().build()
    }

    /// Returns a default-initialised [`Builder`] for configuring a session before building it.
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Identifies the type of the file at `file` (synchronous variant).
    ///
    /// Directories and symbolic links are reported as [`FileType::Directory`] and
    /// [`FileType::Symlink`] without reading any content; every other path is opened and its
    /// content is identified.
    ///
    /// # Errors
    ///
    /// Fails if the metadata lookup, opening the file, feature extraction, or inference fails.
    pub fn identify_file_sync(&mut self, file: impl AsRef<Path>) -> Result<FileType> {
        exec(self.identify_file::<SyncEnv>(file.as_ref()))
    }

    /// Identifies the type of the file at `file` (asynchronous variant).
    ///
    /// Directories and symbolic links are reported as [`FileType::Directory`] and
    /// [`FileType::Symlink`] without reading any content; every other path is opened and its
    /// content is identified.
    ///
    /// # Errors
    ///
    /// Fails if the metadata lookup, opening the file, feature extraction, or inference fails.
    pub async fn identify_file_async(&mut self, file: impl AsRef<Path>) -> Result<FileType> {
        self.identify_file::<AsyncEnv>(file.as_ref()).await
    }

    /// Environment-generic implementation shared by the file identification entry points.
    async fn identify_file<E: Env>(&mut self, file: &Path) -> Result<FileType> {
        let metadata = E::symlink_metadata(file).await?;
        if metadata.is_dir() {
            return Ok(FileType::Directory);
        }
        if metadata.is_symlink() {
            return Ok(FileType::Symlink);
        }
        // Everything else is identified by content. The path comes from the caller, so it may
        // well be something other than a regular file (FIFO, socket, device node, ...); that is
        // not a bug in this crate and must not panic, in debug builds or otherwise. Any failure
        // to open or read such a path surfaces through the returned `Result`.
        let input = E::open(file).await?;
        self.identify_content::<E>(input).await
    }

    /// Identifies the type of already-opened content (synchronous variant).
    ///
    /// # Errors
    ///
    /// Fails if feature extraction from `file` or inference fails.
    pub fn identify_content_sync(&mut self, file: impl SyncInput) -> Result<FileType> {
        exec(self.identify_content::<SyncEnv>(file))
    }

    /// Identifies the type of already-opened content (asynchronous variant).
    ///
    /// # Errors
    ///
    /// Fails if feature extraction from `file` or inference fails.
    pub async fn identify_content_async(&mut self, file: impl AsyncInput) -> Result<FileType> {
        self.identify_content::<AsyncEnv>(file).await
    }

    /// Extracts features from `file` and runs inference on them, unless extraction already
    /// settled the content type by rule, in which case the model is not run.
    async fn identify_content<E: Env>(&mut self, file: impl AsyncInputApi) -> Result<FileType> {
        let extracted = FeaturesOrRuled::extract(file).await?;
        match extracted {
            FeaturesOrRuled::Ruled(content_type) => Ok(FileType::Ruled(content_type)),
            FeaturesOrRuled::Features(features) => self.identify_features::<E>(&features).await,
        }
    }

    /// Runs inference on a single set of pre-extracted features (synchronous variant).
    ///
    /// # Errors
    ///
    /// Fails if building the model input or running inference fails.
    pub fn identify_features_sync(&mut self, features: &Features) -> Result<FileType> {
        exec(self.identify_features::<SyncEnv>(features))
    }

    /// Runs inference on a single set of pre-extracted features (asynchronous variant).
    ///
    /// # Errors
    ///
    /// Fails if building the model input or running inference fails.
    pub async fn identify_features_async(&mut self, features: &Features) -> Result<FileType> {
        self.identify_features::<AsyncEnv>(features).await
    }

    /// Single-item convenience wrapper around the batch implementation.
    async fn identify_features<E: Env>(&mut self, features: &Features) -> Result<FileType> {
        let results = self.identify_features_batch::<E>(std::slice::from_ref(features)).await?;
        // A batch of exactly one input is expected to yield exactly one result; anything else
        // is an internal invariant violation rather than a recoverable error.
        match <[FileType; 1]>::try_from(results) {
            Ok([result]) => Ok(result),
            Err(results) => unreachable!(
                "a batch of one feature set produced {} results",
                results.len()
            ),
        }
    }

    /// Runs inference on a batch of pre-extracted features (synchronous variant).
    ///
    /// An empty batch yields an empty vector without running the model.
    ///
    /// # Errors
    ///
    /// Fails if building the model input, running inference, or reading the output fails.
    ///
    /// # Panics
    ///
    /// Panics if the model output does not contain a `target_label` entry.
    pub fn identify_features_batch_sync(&mut self, features: &[Features]) -> Result<Vec<FileType>> {
        exec(self.identify_features_batch::<SyncEnv>(features))
    }

    /// Runs inference on a batch of pre-extracted features (asynchronous variant).
    ///
    /// An empty batch yields an empty vector without running the model.
    ///
    /// # Errors
    ///
    /// Fails if building the model input, running inference, or reading the output fails.
    ///
    /// # Panics
    ///
    /// Panics if the model output does not contain a `target_label` entry.
    pub async fn identify_features_batch_async(
        &mut self, features: &[Features],
    ) -> Result<Vec<FileType>> {
        self.identify_features_batch::<AsyncEnv>(features).await
    }

    /// Environment-generic batch inference shared by all `identify_*` entry points.
    async fn identify_features_batch<E: Env>(
        &mut self, features: &[Features],
    ) -> Result<Vec<FileType>> {
        if features.is_empty() {
            return Ok(Vec::new());
        }
        let features_size = crate::model::CONFIG.features_size();
        // `flat_map` cannot report a useful size hint, so reserve the full buffer up front
        // instead of letting it grow by repeated reallocation.
        let mut flat = Vec::with_capacity(features.len() * features_size);
        flat.extend(features.iter().flat_map(|x| &x.0).cloned());
        // One row per feature set; a length mismatch is reported as an error here.
        let input = Array2::from_shape_vec([features.len(), features_size], flat)?;
        let mut outputs = E::ort_session_run(&mut self.session, input).await?;
        let labels = outputs
            .remove("target_label")
            .expect("model output must contain the `target_label` tensor");
        let labels = labels.try_extract_array()?;
        Ok(FileType::convert(labels))
    }
}