tract_gpu/
session_handler.rs1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use crate::tensor::DeviceTensor;
4use tract_core::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct DeviceSessionHandler {
8 pub mem_schema: DeviceMemSchema,
9}
10
11impl DeviceSessionHandler {
12 pub fn from_plan(plan: &TypedSimplePlan, memory_hint: &SymbolValues) -> TractResult<Self> {
13 let mem_schema =
14 DeviceMemSchema::build(plan.model(), plan.order_without_consts(), memory_hint)?;
15 Ok(Self { mem_schema })
16 }
17}
18
19impl SessionStateHandler for DeviceSessionHandler {
20 fn before_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()> {
21 let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
22 let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
23
24 session_state.scratch_extensions.insert(memory_pool);
25 ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
26 Ok(())
27 }
28
29 fn after_plan_eval(&self, session_state: &mut TurnState) -> TractResult<()> {
30 session_state.scratch_extensions.remove::<DeviceMemoryPool>();
31 Ok(())
32 }
33}
34
35pub fn make_tensor_for_node(
36 session: &TurnState,
37 node_id: usize,
38 dt: DatumType,
39 shape: &[usize],
40) -> TractResult<DeviceTensor> {
41 session
42 .scratch_extensions
43 .get::<DeviceMemoryPool>()
44 .map(|mem| mem.tensor_for_node(node_id, dt, shape))
45 .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
46}
47
48pub fn make_scalar_opaque_tensor_for_node(
49 session: &TurnState,
50 node_id: usize,
51 opaque_fact: Box<dyn OpaqueFact>,
52) -> TractResult<DeviceTensor> {
53 match session.scratch_extensions.get::<DeviceMemoryPool>() {
54 Some(mem) => mem.scalar_opaque_tensor_for_node(node_id, opaque_fact),
55 None => DeviceTensor::uninitialized_opaque(opaque_fact),
56 }
57}