Skip to main content

tract_gpu/
session_handler.rs

1use crate::memory::DeviceMemSchema;
2use crate::memory::DeviceMemoryPool;
3use std::borrow::Borrow;
4use tract_core::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct DeviceSessionHandler {
8    pub mem_schema: DeviceMemSchema,
9}
10
11impl DeviceSessionHandler {
12    pub fn from_plan<M, P>(plan: P, memory_hint: &SymbolValues) -> TractResult<Self>
13    where
14        M: Borrow<Graph<TypedFact, Box<dyn TypedOp>>>,
15        P: Borrow<TypedSimplePlan<M>> + Clone,
16    {
17        let mem_schema = DeviceMemSchema::build(
18            plan.borrow().model(),
19            plan.borrow().order_without_consts(),
20            memory_hint,
21        )?;
22        Ok(Self { mem_schema })
23    }
24}
25
26impl SessionStateHandler for DeviceSessionHandler {
27    fn before_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
28        let resolved_mem_schema = self.mem_schema.resolve(&session_state.resolved_symbols)?;
29        let memory_pool = DeviceMemoryPool::from_schema(resolved_mem_schema)?;
30
31        session_state.scratch_extensions.insert(memory_pool);
32        ensure!(session_state.scratch_extensions.get::<DeviceMemoryPool>().is_some());
33        Ok(())
34    }
35
36    fn after_plan_eval(&self, session_state: &mut SessionState) -> TractResult<()> {
37        session_state.scratch_extensions.remove::<DeviceMemoryPool>();
38        Ok(())
39    }
40}
41
42pub fn get_device_mem_pool(session: &SessionState) -> Option<&DeviceMemoryPool> {
43    session.scratch_extensions.get::<DeviceMemoryPool>()
44}