use crate::device::context::{DeviceContext, DEVICE_MANAGER};
use crate::{Device, Result};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Execution context holding the default device, a per-device context cache,
/// free-form string attributes, and the eager/profiling flags.
pub struct Context {
// Device returned by `default_device()`; `new()` initialises it to `Device::Cpu`.
default_device: Device,
// Cache of device contexts keyed by device, filled on demand by
// `get_device_context`. Guarded by a lock so it can be updated through `&self`.
device_contexts: RwLock<HashMap<Device, Arc<dyn DeviceContext>>>,
// String key/value attributes, readable and writable through `&self`.
attributes: RwLock<HashMap<String, String>>,
// Flag reported by `is_eager()`; `new()` initialises it to `true`.
eager_mode: bool,
// Flag written by `set_profiling()`; `new()` initialises it to `false`.
profiling_enabled: bool,
}
impl Context {
    /// Creates a context with `Device::Cpu` as the default device, eager mode
    /// enabled, profiling disabled, and empty device-context and attribute maps.
    ///
    /// # Errors
    ///
    /// The construction shown here cannot fail; the `Result` return type is
    /// kept so the signature stays stable for callers.
    pub fn new() -> Result<Self> {
        Ok(Self {
            default_device: Device::Cpu,
            device_contexts: RwLock::new(HashMap::new()),
            attributes: RwLock::new(HashMap::new()),
            eager_mode: true,
            profiling_enabled: false,
        })
    }

    /// Creates a context like [`Context::new`] but with `device` as the
    /// default device.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Context::new`].
    pub fn with_device(device: Device) -> Result<Self> {
        let mut ctx = Self::new()?;
        ctx.default_device = device;
        Ok(ctx)
    }

    /// Returns the default device of this context.
    pub fn default_device(&self) -> Device {
        self.default_device
    }

    /// Replaces the default device of this context.
    pub fn set_default_device(&mut self, device: Device) {
        self.default_device = device;
    }

    /// Returns the current value of the eager-mode flag.
    pub fn is_eager(&self) -> bool {
        self.eager_mode
    }

    /// Sets the eager-mode flag.
    pub fn set_eager_mode(&mut self, eager: bool) {
        self.eager_mode = eager;
    }

    /// Sets the profiling flag.
    pub fn set_profiling(&mut self, enabled: bool) {
        self.profiling_enabled = enabled;
    }

    /// Returns the device context for `device`, fetching it from
    /// `DEVICE_MANAGER` and caching it on the first request.
    ///
    /// All callers receive the instance stored in the cache: if two threads
    /// miss the cache at the same time, the first one to take the write lock
    /// wins and the other thread's freshly fetched context is discarded.
    ///
    /// # Errors
    ///
    /// Propagates the error from `DEVICE_MANAGER.get_context` when the device
    /// context is not cached and cannot be obtained.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn get_device_context(&self, device: &Device) -> Result<Arc<dyn DeviceContext>> {
        // Fast path: shared lock only. The guard is confined to this block so
        // it is released before the write lock is requested below.
        {
            let contexts = self
                .device_contexts
                .read()
                .expect("read lock should not be poisoned");
            if let Some(ctx) = contexts.get(device) {
                return Ok(Arc::clone(ctx));
            }
        }

        // Slow path: fetch without holding any lock of this context.
        let fetched = DEVICE_MANAGER.get_context(device)?;

        let mut contexts = self
            .device_contexts
            .write()
            .expect("write lock should not be poisoned");
        // Another thread may have populated the entry between the read above
        // and this write; keep the existing entry instead of overwriting it.
        let cached = contexts.entry(*device).or_insert(fetched);
        Ok(Arc::clone(cached))
    }

    /// Stores `value` under `key`, replacing any previous value for that key.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn set_attribute(&self, key: String, value: String) {
        let mut attrs = self
            .attributes
            .write()
            .expect("write lock should not be poisoned");
        attrs.insert(key, value);
    }

    /// Returns a copy of the value stored under `key`, or `None` when the key
    /// is absent.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock has been poisoned.
    pub fn get_attribute(&self, key: &str) -> Option<String> {
        let attrs = self
            .attributes
            .read()
            .expect("read lock should not be poisoned");
        attrs.get(key).cloned()
    }
}
// Process-wide slot for the current context. It starts out as `None`;
// `get_context` fills it with a default context on first use and
// `set_context` replaces whatever is stored.
lazy_static::lazy_static! {
static ref GLOBAL_CONTEXT: RwLock<Option<Arc<Context>>> = RwLock::new(None);
}
/// Returns the global context, creating and installing a default one via
/// `Context::new` if none has been set yet.
///
/// Initialisation is checked twice: once under the read lock (fast path) and
/// again after taking the write lock, so a context installed by another
/// thread in between is returned instead of being overwritten.
///
/// # Errors
///
/// Propagates any error returned by `Context::new`.
///
/// # Panics
///
/// Panics if the global lock has been poisoned.
pub fn get_context() -> Result<Arc<Context>> {
    // Fast path: shared lock only. The guard is confined to this block so it
    // is released before the write lock is requested below.
    {
        let slot = GLOBAL_CONTEXT
            .read()
            .expect("read lock should not be poisoned");
        if let Some(ctx) = slot.as_ref() {
            return Ok(Arc::clone(ctx));
        }
    }

    let mut slot = GLOBAL_CONTEXT
        .write()
        .expect("write lock should not be poisoned");
    // Re-check under the exclusive lock: another thread may have installed a
    // context after the read guard above was released.
    if let Some(ctx) = slot.as_ref() {
        return Ok(Arc::clone(ctx));
    }

    let ctx = Arc::new(Context::new()?);
    *slot = Some(Arc::clone(&ctx));
    Ok(ctx)
}
/// Installs `ctx` as the global context, replacing any context that was
/// stored before.
///
/// # Panics
///
/// Panics if the global lock has been poisoned.
pub fn set_context(ctx: Arc<Context>) {
    let mut slot = GLOBAL_CONTEXT
        .write()
        .expect("write lock should not be poisoned");
    // The displaced value (if any) is released at the end of this function,
    // before the write guard, i.e. while the lock is still held.
    let _displaced = slot.replace(ctx);
}
/// Guard that makes a given device the default device of the global context
/// for as long as the guard is alive.
///
/// Creating the guard installs a new global context with the requested
/// default device; dropping it reinstalls a context carrying the previous
/// default device. The guard must therefore be bound to a variable — a value
/// that is discarded immediately undoes the switch right away.
#[must_use = "the previous default device is restored as soon as the scope is dropped"]
pub struct DeviceScope {
    /// Default device that was active when the scope was created.
    previous_device: Device,
    /// Global context that was active when the scope was created.
    context: Arc<Context>,
}
impl DeviceScope {
    /// Switches the global context's default device to `device` and returns a
    /// guard that remembers the context and default device active beforehand.
    ///
    /// The switch is done by cloning the current global context, changing the
    /// clone's default device, and installing the clone as the global context.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `get_context`.
    pub fn new(device: Device) -> Result<Self> {
        let context = get_context()?;
        let previous_device = context.default_device();

        // Build the scoped context from a copy so the saved one is untouched.
        let mut scoped = (*context).clone();
        scoped.set_default_device(device);
        set_context(Arc::new(scoped));

        Ok(Self {
            previous_device,
            context,
        })
    }
}
impl Drop for DeviceScope {
    /// Reinstalls the context that was active before this scope was created.
    ///
    /// When the saved context still has the remembered default device, the
    /// saved `Arc` itself is put back, so its cached device contexts and its
    /// identity (for anyone else holding the same `Arc`) are preserved.
    /// Cloning it instead would start with an empty device-context cache on
    /// every scope exit, because `Context::clone` does not copy that cache.
    fn drop(&mut self) {
        let restored = if self.context.default_device() == self.previous_device {
            Arc::clone(&self.context)
        } else {
            // Defensive fallback for a guard whose saved context and saved
            // device disagree: rebuild a context with the remembered device.
            let mut ctx = (*self.context).clone();
            ctx.set_default_device(self.previous_device);
            Arc::new(ctx)
        };
        set_context(restored);
    }
}
impl Clone for Context {
    /// Produces a copy with the same default device, flags, and a snapshot of
    /// the attributes. The device-context cache is not copied: the clone
    /// starts with an empty one.
    ///
    /// # Panics
    ///
    /// Panics if the attribute lock has been poisoned.
    fn clone(&self) -> Self {
        // Snapshot the attributes first; the read guard is released at the
        // end of this statement.
        let attribute_snapshot = self
            .attributes
            .read()
            .expect("read lock should not be poisoned")
            .clone();

        Self {
            default_device: self.default_device,
            device_contexts: RwLock::new(HashMap::new()),
            attributes: RwLock::new(attribute_snapshot),
            eager_mode: self.eager_mode,
            profiling_enabled: self.profiling_enabled,
        }
    }
}
/// Runs `$body` with `$device` as the default device of the global context.
///
/// A `DeviceScope` guard is created before the body and dropped when the
/// enclosing block ends, which restores the previous default device. Creating
/// the guard uses `?`, so the macro can only be used inside a function whose
/// return type can absorb the error from `DeviceScope::new`.
#[macro_export]
macro_rules! with_device {
    ($device:expr, $body:block) => {{
        // Held only for its drop behaviour; never read by the body.
        let _device_scope_guard = $crate::context::DeviceScope::new($device)?;
        $body
    }};
}