use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use lerna::omegaconf::{
AnyNode, ConfigValue, DictConfig, ListConfig, Node, NodeContent, NodeType, NodeValue, OmegaConf,
};
use crate::omegaconf::listconfig::PyListConfig;
#[pyclass(name = "DictConfig")]
#[derive(Debug)]
pub struct PyDictConfig {
pub inner: Arc<RwLock<DictConfig>>,
}
#[pymethods]
impl PyDictConfig {
#[new]
#[pyo3(signature = (content=None))]
pub fn new(content: Option<&Bound<PyDict>>) -> PyResult<Self> {
let cfg = match content {
Some(dict) => {
let mut map = HashMap::new();
for (key, value) in dict.iter() {
let key_str: String = key.extract()?;
let config_value = py_to_config_value(&value)?;
map.insert(key_str, config_value);
}
OmegaConf::create_dict(map)
}
None => DictConfig::new(),
};
Ok(Self {
inner: Arc::new(RwLock::new(cfg)),
})
}
fn __getitem__(&self, key: &str) -> PyResult<Py<PyAny>> {
Python::attach(|py| {
let cfg = self.inner.read().map_err(|e| {
PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e))
})?;
match cfg.get(key) {
Some(node) => node_arc_to_py(node, py),
None => Err(PyKeyError::new_err(format!("Key not found: {}", key))),
}
})
}
fn __setitem__(&mut self, key: &str, value: &Bound<PyAny>) -> PyResult<()> {
let config_value = py_to_config_value(value)?;
let node = config_value_to_node(config_value);
let mut cfg = self
.inner
.write()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
cfg.set(key, node)
.map_err(|e| PyRuntimeError::new_err(format!("{}", e)))
}
fn __delitem__(&mut self, key: &str) -> PyResult<()> {
let mut cfg = self
.inner
.write()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
let _ = cfg.remove(key);
Ok(())
}
fn __contains__(&self, key: &str) -> PyResult<bool> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.contains_key(key))
}
fn __len__(&self) -> PyResult<usize> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.len_internal())
}
fn keys(&self) -> PyResult<Vec<String>> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.keys_iter().cloned().collect())
}
fn values(&self, py: Python) -> PyResult<Vec<Py<PyAny>>> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
let mut result = Vec::new();
for key in cfg.keys_iter() {
if let Some(node) = cfg.get(key) {
result.push(node_arc_to_py(node, py)?);
}
}
Ok(result)
}
fn items(&self, py: Python) -> PyResult<Vec<(String, Py<PyAny>)>> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
let mut result = Vec::new();
for key in cfg.keys_iter() {
if let Some(node) = cfg.get(key) {
let value = node_arc_to_py(node, py)?;
result.push((key.clone(), value));
}
}
Ok(result)
}
#[pyo3(signature = (key, default=None))]
fn get(&self, py: Python, key: &str, default: Option<&Bound<PyAny>>) -> PyResult<Py<PyAny>> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
match cfg.get(key) {
Some(node) => node_arc_to_py(node, py),
None => match default {
Some(d) => Ok(d.clone().unbind()),
None => Ok(py.None()),
},
}
}
fn __repr__(&self) -> PyResult<String> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
let keys: Vec<String> = cfg.keys_iter().cloned().collect();
Ok(format!("DictConfig({{{}}})", keys.join(", ")))
}
fn _get_flag(&self, flag: &str) -> PyResult<Option<bool>> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.get_flag(flag))
}
fn _set_flag(&mut self, flag: &str, value: Option<bool>) -> PyResult<()> {
let mut cfg = self
.inner
.write()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
cfg.set_flag(flag, value);
Ok(())
}
fn _is_struct(&self) -> PyResult<bool> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.is_struct())
}
fn _is_readonly(&self) -> PyResult<bool> {
let cfg = self
.inner
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock DictConfig: {}", e)))?;
Ok(cfg.is_readonly())
}
}
pub fn py_to_config_value(obj: &Bound<PyAny>) -> PyResult<ConfigValue> {
if obj.is_none() {
return Ok(ConfigValue::None);
}
if let Ok(s) = obj.extract::<String>() {
if s == "???" {
return Ok(ConfigValue::Missing);
} else if s.starts_with("${") && s.ends_with("}") {
return Ok(ConfigValue::Interpolation(s));
}
return Ok(ConfigValue::String(s));
}
if let Ok(b) = obj.extract::<bool>() {
return Ok(ConfigValue::Bool(b));
}
if let Ok(i) = obj.extract::<i64>() {
return Ok(ConfigValue::Int(i));
}
if let Ok(f) = obj.extract::<f64>() {
return Ok(ConfigValue::Float(f));
}
if let Ok(list) = obj.cast::<PyList>() {
let mut items = Vec::new();
for item in list.iter() {
items.push(py_to_config_value(&item)?);
}
return Ok(ConfigValue::List(items));
}
if let Ok(dict) = obj.cast::<PyDict>() {
let mut map = HashMap::new();
for (key, value) in dict.iter() {
let key_str: String = key.extract()?;
map.insert(key_str, py_to_config_value(&value)?);
}
return Ok(ConfigValue::Dict(map));
}
Err(PyTypeError::new_err(format!(
"Unsupported type for config value: {}",
obj.get_type().name()?
)))
}
pub fn node_arc_to_py(node: Arc<RwLock<dyn Node>>, py: Python) -> PyResult<Py<PyAny>> {
let guard = node
.read()
.map_err(|e| PyRuntimeError::new_err(format!("Failed to lock node: {}", e)))?;
match guard.node_type() {
NodeType::Dict => {
if let Some(dict_config) = guard.as_any().downcast_ref::<DictConfig>() {
let cloned = dict_config.clone();
drop(guard); let py_dict = PyDictConfig {
inner: Arc::new(RwLock::new(cloned)),
};
Ok(py_dict.into_pyobject(py)?.into_any().unbind())
} else {
Err(PyRuntimeError::new_err("Failed to downcast to DictConfig"))
}
}
NodeType::List => {
if let Some(list_config) = guard.as_any().downcast_ref::<ListConfig>() {
let cloned = list_config.clone();
drop(guard); let py_list = PyListConfig {
inner: Arc::new(RwLock::new(cloned)),
};
Ok(py_list.into_pyobject(py)?.into_any().unbind())
} else {
Err(PyRuntimeError::new_err("Failed to downcast to ListConfig"))
}
}
NodeType::Value => {
let content = guard.content();
match content {
NodeContent::None => Ok(py.None()),
NodeContent::Missing => Ok("???".into_pyobject(py)?.into_any().unbind()),
NodeContent::Interpolation(s) => Ok(s.into_pyobject(py)?.into_any().unbind()),
NodeContent::Value(v) => match v {
NodeValue::Bool(b) => {
Ok((*b).into_pyobject(py)?.to_owned().into_any().unbind())
}
NodeValue::Int(i) => Ok((*i).into_pyobject(py)?.to_owned().into_any().unbind()),
NodeValue::Float(f) => {
Ok((*f).into_pyobject(py)?.to_owned().into_any().unbind())
}
NodeValue::String(s) => Ok(s.into_pyobject(py)?.into_any().unbind()),
NodeValue::Bytes(b) => Ok(b.clone().into_pyobject(py)?.into_any().unbind()),
},
}
}
}
}
fn config_value_to_node(value: ConfigValue) -> Arc<RwLock<dyn Node>> {
match value {
ConfigValue::None => Arc::new(RwLock::new(AnyNode::new(None))),
ConfigValue::Missing => Arc::new(RwLock::new(AnyNode::missing())),
ConfigValue::Bool(v) => Arc::new(RwLock::new(AnyNode::with_value(NodeValue::Bool(v)))),
ConfigValue::Int(v) => Arc::new(RwLock::new(AnyNode::with_value(NodeValue::Int(v)))),
ConfigValue::Float(v) => Arc::new(RwLock::new(AnyNode::with_value(NodeValue::Float(v)))),
ConfigValue::String(v) => Arc::new(RwLock::new(AnyNode::with_value(NodeValue::String(v)))),
ConfigValue::Bytes(v) => Arc::new(RwLock::new(AnyNode::with_value(NodeValue::Bytes(v)))),
ConfigValue::List(v) => {
let child = OmegaConf::create_list(v);
Arc::new(RwLock::new(child))
}
ConfigValue::Dict(v) => {
let child = OmegaConf::create_dict(v);
Arc::new(RwLock::new(child))
}
ConfigValue::Interpolation(v) => Arc::new(RwLock::new(AnyNode::interpolation(v))),
}
}