use crate::memory::SizeClassPool;
use crate::tensor::Tensor;
use std::collections::HashMap;
use std::sync::Mutex;
/// Per-run map from tensor name to [`Tensor`].
///
/// Tensors that are displaced by [`SessionRunState::insert`] or left over after
/// [`SessionRunState::take_outputs`] have their buffers handed to an optional
/// `SizeClassPool` (via `release_to_pool`) instead of simply being freed.
pub(crate) struct SessionRunState {
    /// Tensors produced so far in this run, keyed by name.
    tensors: HashMap<String, Tensor>,
}

impl SessionRunState {
    /// Creates an empty state with room for at least `capacity` tensors.
    pub(crate) fn with_capacity(capacity: usize) -> Self {
        Self {
            tensors: HashMap::with_capacity(capacity),
        }
    }

    /// Returns the tensor stored under `name`, if any.
    #[inline]
    pub(crate) fn get(&self, name: &str) -> Option<&Tensor> {
        self.tensors.get(name)
    }

    /// Stores `tensor` under `name`.
    ///
    /// If a tensor was already stored under `name`, its buffer is handed to
    /// `release_to_pool` with `pool`.
    pub(crate) fn insert(
        &mut self,
        name: String,
        tensor: Tensor,
        pool: Option<&Mutex<SizeClassPool>>,
    ) {
        // `HashMap::insert` returns the displaced value. A single lookup therefore
        // both stores the new tensor and gives us the old one to recycle.
        if let Some(previous) = self.tensors.insert(name, tensor) {
            release_to_pool(previous, pool);
        }
    }

    /// Removes and returns the tensor stored under `name`, if any.
    ///
    /// The removed tensor is handed to the caller, so its buffer is not recycled.
    pub(crate) fn take(&mut self, name: &str) -> Option<Tensor> {
        self.tensors.remove(name)
    }

    /// Borrows the underlying name-to-tensor map.
    #[inline]
    #[cfg_attr(
        not(any(feature = "gpu", feature = "cuda", feature = "directml")),
        allow(dead_code)
    )]
    pub(crate) fn as_map(&self) -> &HashMap<String, Tensor> {
        &self.tensors
    }

    /// Consumes the state and returns the tensors named in `output_names`.
    ///
    /// Names that are not present are skipped, and a repeated name is returned
    /// only once. Every tensor that was not requested is passed to
    /// `release_to_pool` with `pool`.
    pub(crate) fn take_outputs(
        mut self,
        output_names: &[String],
        pool: Option<&Mutex<SizeClassPool>>,
    ) -> HashMap<String, Tensor> {
        let mut result: HashMap<String, Tensor> = HashMap::with_capacity(output_names.len());
        for name in output_names {
            // `remove_entry` also hands back the owned key, so no `clone()` is needed.
            if let Some((key, tensor)) = self.tensors.remove_entry(name.as_str()) {
                result.insert(key, tensor);
            }
        }
        for (_name, tensor) in self.tensors.drain() {
            release_to_pool(tensor, pool);
        }
        result
    }
}
/// Best-effort recycling of a tensor's backing buffer into a size-class pool.
///
/// The buffer in `tensor.data` is moved into `pool` only when three conditions
/// hold: a pool is supplied, its mutex can be locked, and the buffer is
/// non-empty. In every other case (no pool, poisoned mutex, empty buffer) the
/// tensor is simply dropped.
#[inline]
pub(super) fn release_to_pool(mut tensor: Tensor, pool: Option<&Mutex<SizeClassPool>>) {
    // A poisoned lock becomes `None` here, so recycling is skipped instead of panicking.
    let locked = pool.and_then(|pool_mutex| pool_mutex.lock().ok());
    if let Some(mut guard) = locked {
        let buf = std::mem::take(&mut tensor.data);
        if buf.is_empty() {
            return;
        }
        guard.release(buf);
    }
}
/// Per-run storage for `oxionnx_core::TypedTensor` values, keyed by name.
///
/// None of its methods take a pool, so displaced or leftover tensors are
/// simply dropped or kept; nothing is recycled.
pub(super) struct TypedSessionRunState {
    /// Typed tensors produced so far in this run, keyed by name.
    pub(super) slots: HashMap<String, oxionnx_core::TypedTensor>,
}

impl TypedSessionRunState {
    /// Creates an empty state.
    pub(super) fn new() -> Self {
        let slots = HashMap::new();
        Self { slots }
    }

    /// Returns the typed tensor stored under `name`, if any.
    #[inline]
    pub(super) fn get(&self, name: &str) -> Option<&oxionnx_core::TypedTensor> {
        self.slots.get(name)
    }

    /// Stores `tensor` under `name`, dropping any tensor previously stored there.
    #[inline]
    pub(super) fn insert(&mut self, name: String, tensor: oxionnx_core::TypedTensor) {
        self.slots.insert(name, tensor);
    }

    /// Removes and returns the tensors named in `output_names`.
    ///
    /// Missing names are skipped, and a repeated name is returned only once.
    /// Slots that were not requested stay in `self`.
    pub(super) fn take_outputs(
        &mut self,
        output_names: &[String],
    ) -> HashMap<String, oxionnx_core::TypedTensor> {
        let mut outputs = HashMap::new();
        for wanted in output_names {
            if let Some(tensor) = self.slots.remove(wanted) {
                outputs.insert(wanted.clone(), tensor);
            }
        }
        outputs
    }
}