use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use lerna::omegaconf::Node;
use super::dictconfig::PyDictConfig;
/// Python context manager `open_dict(cfg)`: temporarily sets the `"struct"`
/// flag of a `DictConfig` to `false`.
///
/// `__enter__` saves the current `"struct"` flag (as reported by `get_flag`)
/// and sets it to `Some(false)`. `__exit__` writes the saved value back,
/// whether or not the `with` body raised, and never suppresses the exception.
///
/// The saved value lives in a single field. Re-entering the same instance
/// before it has exited therefore overwrites it.
#[pyclass(name = "open_dict")]
pub struct PyOpenDict {
    /// Config whose `"struct"` flag is toggled.
    config: Py<PyDictConfig>,
    /// Flag value captured by the most recent `__enter__` (starts as `None`).
    previous_struct: Option<bool>,
}

#[pymethods]
impl PyOpenDict {
    /// Wraps `config`; the flag is not touched until `__enter__`.
    #[new]
    fn new(config: Py<PyDictConfig>) -> Self {
        Self {
            config,
            previous_struct: None,
        }
    }

    /// Saves the current `"struct"` flag, sets it to `false`, and returns the
    /// same config handle, so `with open_dict(cfg) as c:` binds `c` to `cfg`.
    ///
    /// # Errors
    /// Fails if the config object cannot be borrowed, or with `RuntimeError`
    /// if its lock is poisoned.
    fn __enter__<'py>(
        mut slf: PyRefMut<'py, Self>,
        py: Python<'py>,
    ) -> PyResult<Py<PyDictConfig>> {
        // Read-and-set under a single write lock. The lock is released at the
        // end of this scope, before `slf` is updated.
        let previous = {
            // Shared borrow is enough: mutation goes through the inner lock.
            let config = slf.config.bind(py).try_borrow()?;
            let mut inner = config.inner.write().map_err(|e| {
                PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e))
            })?;
            let previous = inner.get_flag("struct");
            inner.set_flag("struct", Some(false));
            previous
        };
        slf.previous_struct = previous;
        Ok(slf.config.clone_ref(py))
    }

    /// Restores the `"struct"` flag captured by `__enter__`.
    ///
    /// Always returns `false`, so any exception from the `with` body propagates.
    ///
    /// # Errors
    /// Fails if the config object cannot be borrowed, or with `RuntimeError`
    /// if its lock is poisoned.
    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
    fn __exit__(
        &self,
        py: Python,
        _exc_type: Option<&Bound<PyAny>>,
        _exc_val: Option<&Bound<PyAny>>,
        _exc_tb: Option<&Bound<PyAny>>,
    ) -> PyResult<bool> {
        // A shared (non-panicking) borrow suffices; the RwLock provides mutation.
        let config = self.config.bind(py).try_borrow()?;
        let mut inner = config
            .inner
            .write()
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
        inner.set_flag("struct", self.previous_struct);
        Ok(false)
    }
}
/// Python context manager `read_write(cfg)`: temporarily sets the
/// `"readonly"` flag of a `DictConfig` to `false`.
///
/// `__enter__` saves the current `"readonly"` flag (as reported by
/// `get_flag`) and sets it to `Some(false)`. `__exit__` writes the saved value
/// back, whether or not the `with` body raised, and never suppresses the
/// exception.
///
/// The saved value lives in a single field. Re-entering the same instance
/// before it has exited therefore overwrites it.
#[pyclass(name = "read_write")]
pub struct PyReadWrite {
    /// Config whose `"readonly"` flag is toggled.
    config: Py<PyDictConfig>,
    /// Flag value captured by the most recent `__enter__` (starts as `None`).
    previous_readonly: Option<bool>,
}

#[pymethods]
impl PyReadWrite {
    /// Wraps `config`; the flag is not touched until `__enter__`.
    #[new]
    fn new(config: Py<PyDictConfig>) -> Self {
        Self {
            config,
            previous_readonly: None,
        }
    }

    /// Saves the current `"readonly"` flag, sets it to `false`, and returns
    /// the same config handle, so `with read_write(cfg) as c:` binds `c` to `cfg`.
    ///
    /// # Errors
    /// Fails if the config object cannot be borrowed, or with `RuntimeError`
    /// if its lock is poisoned.
    fn __enter__<'py>(
        mut slf: PyRefMut<'py, Self>,
        py: Python<'py>,
    ) -> PyResult<Py<PyDictConfig>> {
        // Read-and-set under a single write lock. The lock is released at the
        // end of this scope, before `slf` is updated.
        let previous = {
            // Shared borrow is enough: mutation goes through the inner lock.
            let config = slf.config.bind(py).try_borrow()?;
            let mut inner = config.inner.write().map_err(|e| {
                PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e))
            })?;
            let previous = inner.get_flag("readonly");
            inner.set_flag("readonly", Some(false));
            previous
        };
        slf.previous_readonly = previous;
        Ok(slf.config.clone_ref(py))
    }

    /// Restores the `"readonly"` flag captured by `__enter__`.
    ///
    /// Always returns `false`, so any exception from the `with` body propagates.
    ///
    /// # Errors
    /// Fails if the config object cannot be borrowed, or with `RuntimeError`
    /// if its lock is poisoned.
    #[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
    fn __exit__(
        &self,
        py: Python,
        _exc_type: Option<&Bound<PyAny>>,
        _exc_val: Option<&Bound<PyAny>>,
        _exc_tb: Option<&Bound<PyAny>>,
    ) -> PyResult<bool> {
        // A shared (non-panicking) borrow suffices; the RwLock provides mutation.
        let config = self.config.bind(py).try_borrow()?;
        let mut inner = config
            .inner
            .write()
            .map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
        inner.set_flag("readonly", self.previous_readonly);
        Ok(false)
    }
}
#[pyclass(name = "flag_override")]
pub struct PyFlagOverride {
config: Py<PyDictConfig>,
flag_name: String,
previous_value: Option<bool>,
}
#[pymethods]
impl PyFlagOverride {
#[new]
fn new(
config: Py<PyDictConfig>,
flag_name: String,
new_value: Option<bool>,
py: Python,
) -> PyResult<Self> {
let config_ref = config.bind(py);
let config_borrow = config_ref.borrow();
let inner = config_borrow
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
let previous_value = inner.get_flag(&flag_name);
drop(inner);
drop(config_borrow);
let config_mut = config_ref.borrow_mut();
let mut inner = config_mut
.inner
.write()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
inner.set_flag(&flag_name, new_value);
Ok(Self {
config,
flag_name,
previous_value,
})
}
fn __enter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> Py<PyDictConfig> {
slf.config.clone_ref(py)
}
#[pyo3(signature = (_exc_type=None, _exc_val=None, _exc_tb=None))]
fn __exit__(
&mut self,
py: Python,
_exc_type: Option<&Bound<PyAny>>,
_exc_val: Option<&Bound<PyAny>>,
_exc_tb: Option<&Bound<PyAny>>,
) -> PyResult<bool> {
let config_ref = self.config.bind(py);
let config = config_ref.borrow_mut();
let mut inner = config
.inner
.write()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
inner.set_flag(&self.flag_name, self.previous_value);
Ok(false) }
}