use crate::execution_providers::OpPlacement;
use crate::memory::SizeClassPool;
use crate::tensor::Tensor;
use oxionnx_core::OperatorRegistry;
use std::collections::HashMap;
use std::sync::Mutex;
mod accessors;
mod builder;
#[cfg(feature = "gpu")]
mod gpu_dispatch;
mod loading;
pub(crate) mod mixed_precision;
mod run;
mod tests;
pub mod types;
pub use types::{ModelInfo, ModelMetadata, NodeProfile, OptLevel};
#[cfg(feature = "gpu")]
pub use gpu_dispatch::GpuExecutionProvider;
pub use builder::SessionBuilder;
/// A loaded, ready-to-run inference session for a single model.
///
/// Constructed via [`SessionBuilder`] (see the `builder` module); the execution
/// logic lives in the sibling submodules (`run`, `loading`, `accessors`,
/// `mixed_precision`, and the feature-gated `gpu_dispatch`).
pub struct Session {
    /// Graph nodes in execution order — presumably topologically sorted by the
    /// loading step; TODO confirm against `loading`.
    pub(crate) sorted_nodes: Vec<crate::graph::Node>,
    /// Model initializers (weights/constants), keyed by tensor name.
    pub(crate) weights: HashMap<String, Tensor>,
    /// Names of the graph's external inputs, in model order.
    pub(crate) input_names: Vec<String>,
    /// Names of the graph's external outputs, in model order.
    pub(crate) output_names: Vec<String>,
    /// Type/shape metadata for each external input (parallel to `input_names`).
    pub(crate) input_infos: Vec<oxionnx_core::TensorInfo>,
    /// Type/shape metadata for each external output (parallel to `output_names`).
    pub(crate) output_infos: Vec<oxionnx_core::TensorInfo>,
    /// Producer/version/custom metadata read from the model file.
    pub(crate) metadata: ModelMetadata,
    /// Registry used to resolve each node's op type to a kernel implementation.
    pub(crate) registry: OperatorRegistry,
    /// Per-node timing records; `Some` only when profiling is enabled.
    /// Mutex allows recording during runs that take `&self`.
    pub(crate) profiling_data: Option<Mutex<Vec<NodeProfile>>>,
    /// Buffer pool for intermediate tensors, bucketed by size class;
    /// `None` when pooling is disabled. NOTE(review): assumed to be an
    /// allocation-reuse arena — confirm against `crate::memory`.
    pub(crate) pool: Option<Mutex<SizeClassPool>>,
    /// Precomputed static shapes keyed by tensor name; `None` when shape
    /// inference/caching is not available for this model.
    pub(crate) shape_cache: Option<HashMap<String, Vec<usize>>>,
    /// Whether to execute eligible work in parallel (see `thread_pool` below).
    pub(crate) parallel: bool,
    /// Whether mixed-precision execution is enabled (see the
    /// `mixed_precision` submodule).
    pub(crate) mixed_precision: bool,
    /// Per-operator device/provider placement decisions.
    pub(crate) op_placement: OpPlacement,
    /// Current values bound to symbolic (dynamic) dimensions, keyed by
    /// dimension name — presumably updated per run; TODO confirm in `run`.
    pub(crate) dynamic_dims: Mutex<HashMap<String, usize>>,
    /// Concrete shapes resolved at run time, keyed by tensor name.
    /// Interior mutability lets runs taking `&self` update the cache.
    pub(crate) resolved_shapes: Mutex<HashMap<String, Vec<usize>>>,
    // Rayon thread pool for parallel execution; absent on wasm32, where
    // native threads are unavailable. `None` when `parallel` is off —
    // TODO confirm the None condition against `builder`.
    #[cfg(not(target_arch = "wasm32"))]
    pub(crate) thread_pool: Option<rayon::ThreadPool>,
    // Generic GPU context (see `crate::gpu`); populated only when the
    // session was built with a GPU device available.
    #[cfg(feature = "gpu")]
    pub(crate) gpu: Option<crate::gpu::GpuContext>,
    // CUDA execution context, when built with the `cuda` feature and a
    // usable device.
    #[cfg(feature = "cuda")]
    pub(crate) cuda: Option<oxionnx_cuda::CudaContext>,
    // DirectML execution context (Windows), when built with the
    // `directml` feature and a usable device.
    #[cfg(feature = "directml")]
    pub(crate) dml: Option<oxionnx_directml::DirectMLContext>,
}