tract_gpu/
session_handler.rs1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use crate::tensor::DeviceTensor;
4use std::borrow::Borrow;
5use tract_core::internal::*;
6
7#[derive(Debug, Clone)]
8pub struct DeviceSessionHandler {
9 pub mem_schema: DeviceMemSchema,
10}
11
12impl DeviceSessionHandler {
13 pub fn from_plan<M, P>(plan: P, memory_hint: &SymbolValues) -> TractResult<Self>
14 where
15 M: Borrow<Graph<TypedFact, Box<dyn TypedOp>>>,
16 P: Borrow<TypedSimplePlan<M>> + Clone,
17 {
18 let mem_schema = DeviceMemSchema::build(
19 plan.borrow().model(),
20 plan.borrow().order_without_consts(),
21 memory_hint,
22 )?;
23 Ok(Self { mem_schema })
24 }
25}
26
27impl SessionStateHandler for DeviceSessionHandler {
28 fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
29 let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
30 let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
31
32 session_state.scratch_extensions.insert(memory_pool);
33 ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
34 Ok(())
35 }
36
37 fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
38 session_state.scratch_extensions.remove::<DeviceMemoryPool>();
39 Ok(())
40 }
41}
42
43pub fn make_tensor_for_node(
44 session: &SessionState,
45 node_id: usize,
46 dt: DatumType,
47 shape: &[usize],
48) -> TractResult<DeviceTensor> {
49 session
50 .scratch_extensions
51 .get::<DeviceMemoryPool>()
52 .map(|mem| mem.tensor_for_node(node_id, dt, shape))
53 .unwrap_or_else(|| DeviceTensor::uninitialized_dt(dt, shape))
54}