Skip to main content

tract_gpu/
session_handler.rs

1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use crate::tensor::DeviceTensor;
4use tract_core::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct DeviceSessionHandler {
8    pub mem_schema: DeviceMemSchema,
9}
10
11impl DeviceSessionHandler {
12    pub fn from_plan(plan: &TypedSimplePlan, memory_hint: &SymbolValues) -> TractResult<Self> {
13        let mem_schema =
14            DeviceMemSchema::build(plan.model(), plan.order_without_consts(), memory_hint)?;
15        Ok(Self { mem_schema })
16    }
17}
18
19impl SessionStateHandler for DeviceSessionHandler {
20    fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()> {
21        let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
22        let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
23
24        session_state.scratch_extensions.insert(memory_pool);
25        ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
26        Ok(())
27    }
28
29    fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()> {
30        session_state.scratch_extensions.remove::<DeviceMemoryPool>();
31        Ok(())
32    }
33}
34
35pub fn make_tensor_for_node(
36    session: &TurnState,
37    node_id: usize,
38    dt: DatumType,
39    shape: &[usize],
40) -> TractResult<DeviceTensor> {
41    session
42        .scratch_extensions
43        .get::<DeviceMemoryPool>()
44        .map(|mem| mem.tensor_for_node(node_id, dt, shape))
45        .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
46}
47
48pub fn make_scalar_opaque_tensor_for_node(
49    session: &TurnState,
50    node_id: usize,
51    opaque_fact: Box<dyn OpaqueFact>,
52) -> TractResult<DeviceTensor> {
53    match session.scratch_extensions.get::<DeviceMemoryPool>() {
54        Some(mem) => mem.scalar_opaque_tensor_for_node(node_id, opaque_fact),
55        None => DeviceTensor::uninitialized_opaque(opaque_fact),
56    }
57}