tract-gpu 0.22.1

Tiny, no-nonsense, self contained, TensorFlow and ONNX inference
Documentation
use crate::memory::DeviceMemSchema;
use crate::memory::DeviceMemoryPool;
use crate::tensor::DeviceTensor;
use std::borrow::Borrow;
use tract_core::internal::*;

#[derive(Debug, Clone)]
pub struct DeviceSessionHandler {
    pub mem_schema: DeviceMemSchema,
}

impl DeviceSessionHandler {
    pub fn from_plan<M, P>(plan: P, memory_hint: &SymbolValues) -> TractResult<Self>
    where
        M: Borrow<Graph<TypedFact, Box<dyn TypedOp>>>,
        P: Borrow<TypedSimplePlan<M>> + Clone,
    {
        let mem_schema = DeviceMemSchema::build(
            plan.borrow().model(),
            plan.borrow().order_without_consts(),
            memory_hint,
        )?;
        Ok(Self { mem_schema })
    }
}

impl SessionStateHandler for DeviceSessionHandler {
    fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
        let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
        let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;

        session_state.scratch_extensions.insert(memory_pool);
        ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
        Ok(())
    }

    fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
        session_state.scratch_extensions.remove::<DeviceMemoryPool>();
        Ok(())
    }
}

pub fn make_tensor_for_node(
    session: &SessionState,
    node_id: usize,
    dt: DatumType,
    shape: &[usize],
) -> TractResult<DeviceTensor> {
    session
        .scratch_extensions
        .get::<DeviceMemoryPool>()
        .map(|mem| mem.tensor_for_node(node_id, dt, shape))
        .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
}