tract_gpu/
session_handler.rs1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use std::borrow::Borrow;
4use tract_core::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct DeviceSessionHandler {
8 pub mem_schema: DeviceMemSchema,
9}
10
11impl DeviceSessionHandler {
12 pub fn from_plan<M, P>(plan: P, memory_hint: &SymbolValues) -> TractResult<Self>
13 where
14 M: Borrow<Graph<TypedFact, Box<dyn TypedOp>>>,
15 P: Borrow<TypedSimplePlan<M>> + Clone,
16 {
17 let mem_schema = DeviceMemSchema::build(
18 plan.borrow().model(),
19 plan.borrow().order_without_consts(),
20 memory_hint,
21 )?;
22 Ok(Self { mem_schema })
23 }
24}
25
26impl SessionStateHandler for DeviceSessionHandler {
27 fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
28 let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
29 let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
30
31 session_state.scratch_extensions.insert(memory_pool);
32 ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
33 Ok(())
34 }
35
36 fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
37 session_state.scratch_extensions.remove::<DeviceMemoryPool>();
38 Ok(())
39 }
40}
41
42pub fn get_device_mem_pool(session: &SessionState) -> Option<&DeviceMemoryPool> {
43 session.scratch_extensions.get::<DeviceMemoryPool>()
44}