use crate::task::AngrealGroup;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pyfunction;
use version_compare::{compare, Cmp};
/// Verify that the running angreal build satisfies a version requirement.
///
/// `specifier` is an optional comparison operator (`>=`, `<=`, `==`, `!=`,
/// `>`, `<`) followed by a version string. With no operator the installed
/// version must compare equal to the given version. Whitespace around the
/// specifier and between the operator and the version is ignored.
///
/// # Errors
///
/// Returns a Python `EnvironmentError` when the installed version does not
/// satisfy the requirement, or when `version_compare::compare` reports that
/// the version text cannot be compared.
#[pyfunction]
pub fn required_version(specifier: &str) -> PyResult<()> {
    let current_version = env!("CARGO_PKG_VERSION");
    let spec = specifier.trim();

    // Two-character operators are tested before their one-character prefixes
    // so that ">=" is never mistaken for ">" followed by "=…".
    // Each branch yields Ok(satisfied) or Err when the comparison itself failed,
    // so an unparsable requirement can never be silently accepted.
    let verdict = if let Some(required) = spec.strip_prefix(">=") {
        compare(current_version, required.trim()).map(|ord| ord != Cmp::Lt)
    } else if let Some(required) = spec.strip_prefix("<=") {
        compare(current_version, required.trim()).map(|ord| ord != Cmp::Gt)
    } else if let Some(required) = spec.strip_prefix("==") {
        compare(current_version, required.trim()).map(|ord| ord == Cmp::Eq)
    } else if let Some(required) = spec.strip_prefix("!=") {
        compare(current_version, required.trim()).map(|ord| ord != Cmp::Eq)
    } else if let Some(required) = spec.strip_prefix('>') {
        compare(current_version, required.trim()).map(|ord| ord == Cmp::Gt)
    } else if let Some(required) = spec.strip_prefix('<') {
        compare(current_version, required.trim()).map(|ord| ord == Cmp::Lt)
    } else {
        compare(current_version, spec).map(|ord| ord == Cmp::Eq)
    };

    match verdict {
        Ok(true) => Ok(()),
        Ok(false) => Err(PyErr::new::<pyo3::exceptions::PyEnvironmentError, _>(
            format!(
                "You require angreal {} but have {} installed.",
                specifier, current_version
            ),
        )),
        // Same exception type as the mismatch case so existing handlers still
        // catch it, but with a message that points at the real problem.
        Err(_) => Err(PyErr::new::<pyo3::exceptions::PyEnvironmentError, _>(
            format!(
                "Could not interpret version requirement '{}' (angreal {} is installed).",
                specifier, current_version
            ),
        )),
    }
}
/// Register the decorator functions and decorator classes defined in this
/// file on the Python module `m`.
///
/// Functions are added first (`required_version`, `group`, `command_group`,
/// `command`, `argument`), followed by the three decorator classes.
///
/// # Errors
///
/// Propagates the first error returned by `add_function` / `add_class`;
/// registrations performed before the failure are not rolled back.
pub fn register_decorators(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(required_version, m)?)?;
m.add_function(wrap_pyfunction!(group, m)?)?;
m.add_function(wrap_pyfunction!(command_group, m)?)?;
m.add_function(wrap_pyfunction!(command, m)?)?;
m.add_function(wrap_pyfunction!(argument, m)?)?;
// The decorator objects returned by the functions above are instances of
// these classes.
m.add_class::<GroupDecorator>()?;
m.add_class::<CommandDecorator>()?;
m.add_class::<ArgumentDecorator>()?;
Ok(())
}
/// Decorator object produced by `group` and `command_group`.
///
/// Calling it with a function builds an `AngrealGroup` from the stored fields
/// and adds it to the command already attached to that function.
#[pyclass]
#[derive(Clone)]
pub struct GroupDecorator {
/// Group name, passed as the first argument to the `AngrealGroup` constructor.
name: String,
/// Optional description, passed as the second constructor argument.
about: Option<String>,
}
#[pymethods]
impl GroupDecorator {
    /// Apply the group to a command-decorated function.
    ///
    /// * With `func` supplied, the function must already carry a `__command`
    ///   attribute. A new `AngrealGroup` built from this decorator's `name`
    ///   and `about` is passed to that command's `add_group` method, and the
    ///   same function object is returned.
    /// * With no `func`, a new Python object holding a clone of this decorator
    ///   is returned.
    ///
    /// # Errors
    ///
    /// Returns a Python `SyntaxError` when `func` has no readable `__command`
    /// attribute, and propagates any error raised while constructing the group
    /// or calling `add_group`.
    #[pyo3(signature = (func = None,))]
    fn __call__(&self, func: Option<Py<PyAny>>) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let Some(func) = func else {
                return Ok(Py::new(py, self.clone())?.into_any());
            };

            // A single lookup both validates that a command is present and
            // yields the object that is needed below.
            let command = func.getattr(py, "__command").map_err(|_| {
                PyErr::new::<pyo3::exceptions::PySyntaxError, _>(
                    "The group decorator must be applied before a command.",
                )
            })?;

            let group = py
                .get_type::<AngrealGroup>()
                .call1((&self.name, self.about.as_deref()))?;
            command.call_method1(py, "add_group", (group,))?;
            Ok(func)
        })
    }
}
/// Build a [`GroupDecorator`] from keyword arguments.
///
/// Recognised keywords are `name` (falls back to `"default"` when absent) and
/// `about` (optional). Both must be extractable as strings when present.
///
/// # Errors
///
/// Propagates the extraction error when a supplied `name` or `about` value
/// cannot be converted to a `String`.
#[pyfunction]
#[pyo3(signature = (**kwargs))]
pub fn group(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<GroupDecorator> {
    // Fetch one keyword as an owned string; a missing key (or a failed dict
    // lookup) counts as "not supplied".
    let text_kwarg = |key: &str| -> PyResult<Option<String>> {
        match kwargs.and_then(|d| d.get_item(key).ok().flatten()) {
            Some(value) => value.extract::<String>().map(Some),
            None => Ok(None),
        }
    };

    let name = match text_kwarg("name")? {
        Some(given) => given,
        None => "default".to_string(),
    };
    let about = text_kwarg("about")?;

    Ok(GroupDecorator { name, about })
}
/// Build a [`GroupDecorator`] from a positional `name` and an optional `about`
/// description.
#[pyfunction]
#[pyo3(signature = (name, about = None))]
pub fn command_group(name: &str, about: Option<&str>) -> PyResult<GroupDecorator> {
    let decorator = GroupDecorator {
        name: name.to_owned(),
        about: about.map(str::to_owned),
    };
    Ok(decorator)
}
/// Decorator object produced by `command`.
///
/// Calling it with a function constructs an `AngrealCommand` for that function
/// and stores it on the function as the `__command` attribute.
#[pyclass]
pub struct CommandDecorator {
/// Explicit command name. When `None`, the name is derived from the
/// function's `__name__`, lower-cased with `_` replaced by `-`.
name: Option<String>,
/// Optional short description forwarded to the `AngrealCommand` constructor.
about: Option<String>,
/// Optional long description forwarded to the `AngrealCommand` constructor.
long_about: Option<String>,
/// Optional tool description; when present it is rebuilt as a Python
/// `ToolDescription` object and forwarded to the `AngrealCommand` constructor.
tool: Option<crate::task::ToolDescription>,
}
#[pymethods]
impl CommandDecorator {
#[pyo3(signature = (func,))]
fn __call__(&self, func: Py<PyAny>) -> PyResult<Py<PyAny>> {
Python::attach(|py| {
let name = match &self.name {
Some(name) => name.clone(),
None => {
func.getattr(py, "__name__")?
.extract::<String>(py)?
.to_lowercase()
.replace("_", "-")
}
};
if func.getattr(py, "__arguments").is_err() {
func.setattr(py, "__arguments", py.None())?;
}
let tool_py = match &self.tool {
Some(tool) => {
let tool_class = py.get_type::<crate::task::ToolDescription>();
let kwargs = pyo3::types::PyDict::new(py);
kwargs.set_item("risk_level", tool.risk_level.as_str())?;
tool_class
.call((&tool.description,), Some(&kwargs))?
.into_any()
.unbind()
}
None => py.None(),
};
let command_class = py.get_type::<crate::task::AngrealCommand>();
let command = command_class.call1((
&name,
func.clone_ref(py),
self.about.as_deref(),
self.long_about.as_deref(),
py.None(), tool_py,
))?;
func.setattr(py, "__command", command)?;
let arguments = func.getattr(py, "__arguments")?;
if !arguments.is_none(py) {
if let Ok(args_list) = arguments.extract::<Vec<Py<PyAny>>>(py) {
for arg_kwargs_obj in args_list {
let bound_arg = arg_kwargs_obj.bind(py);
if let Ok(kwargs_dict) = bound_arg.cast::<pyo3::types::PyDict>() {
let arg_class = py.get_type::<crate::task::AngrealArg>();
let arg_name = kwargs_dict
.get_item("name")
.ok()
.flatten()
.map(|v| v.extract::<String>())
.transpose()?
.unwrap_or_else(|| "default".to_string());
let arg_kwargs = pyo3::types::PyDict::new(py);
arg_kwargs.set_item("name", &arg_name)?;
arg_kwargs.set_item("command_name", &name)?;
for (key, value) in kwargs_dict.iter() {
let key_str = key.extract::<String>()?;
match key_str.as_str() {
"name" => arg_kwargs.set_item("name", value)?,
"short" => {
if let Ok(s) = value.extract::<String>() {
if let Some(c) = s.chars().next() {
arg_kwargs.set_item("short", c)?;
} else {
arg_kwargs.set_item("short", py.None())?;
}
} else {
arg_kwargs.set_item("short", py.None())?;
}
}
"long" => arg_kwargs.set_item("long", value)?,
"help" => arg_kwargs.set_item("help", value)?,
"long_help" => arg_kwargs.set_item("long_help", value)?,
"required" => arg_kwargs.set_item("required", value)?,
"takes_value" => arg_kwargs.set_item("takes_value", value)?,
"is_flag" => arg_kwargs.set_item("is_flag", value)?,
"default_value" => {
arg_kwargs.set_item("default_value", value)?
}
"multiple_values" => {
arg_kwargs.set_item("multiple_values", value)?
}
"number_of_values" => {
arg_kwargs.set_item("number_of_values", value)?
}
"max_values" => arg_kwargs.set_item("max_values", value)?,
"min_values" => arg_kwargs.set_item("min_values", value)?,
"require_equals" => {
arg_kwargs.set_item("require_equals", value)?
}
"python_type" => arg_kwargs.set_item("python_type", value)?,
_ => {} }
}
if !arg_kwargs.contains("default_value")? {
arg_kwargs.set_item("default_value", py.None())?;
}
if !arg_kwargs.contains("is_flag")? {
arg_kwargs.set_item("is_flag", py.None())?;
}
if !arg_kwargs.contains("require_equals")? {
arg_kwargs.set_item("require_equals", py.None())?;
}
if !arg_kwargs.contains("multiple_values")? {
arg_kwargs.set_item("multiple_values", py.None())?;
}
if !arg_kwargs.contains("number_of_values")? {
arg_kwargs.set_item("number_of_values", py.None())?;
}
if !arg_kwargs.contains("max_values")? {
arg_kwargs.set_item("max_values", py.None())?;
}
if !arg_kwargs.contains("min_values")? {
arg_kwargs.set_item("min_values", py.None())?;
}
if !arg_kwargs.contains("short")? {
arg_kwargs.set_item("short", py.None())?;
}
if !arg_kwargs.contains("long")? {
arg_kwargs.set_item("long", py.None())?;
}
if !arg_kwargs.contains("long_help")? {
arg_kwargs.set_item("long_help", py.None())?;
}
if !arg_kwargs.contains("help")? {
arg_kwargs.set_item("help", py.None())?;
}
if !arg_kwargs.contains("required")? {
arg_kwargs.set_item("required", py.None())?;
}
if !arg_kwargs.contains("takes_value")? {
arg_kwargs.set_item("takes_value", py.None())?;
}
if !arg_kwargs.contains("python_type")? {
arg_kwargs.set_item("python_type", py.None())?;
}
let _arg = arg_class.call((), Some(&arg_kwargs))?;
}
}
}
}
Ok(func)
})
}
}
/// Build a [`CommandDecorator`] from keyword arguments.
///
/// Recognised keywords: `name`, `about`, `long_about` (each an optional
/// string) and `tool` (an optional `ToolDescription` instance, which is
/// cloned into the decorator).
///
/// # Errors
///
/// Propagates the extraction error when a supplied string keyword is not a
/// string, and the cast error when `tool` is not a `ToolDescription`.
#[pyfunction]
#[pyo3(signature = (**kwargs))]
pub fn command(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<CommandDecorator> {
    // Fetch one keyword as an owned string; a missing key (or a failed dict
    // lookup) counts as "not supplied".
    let text_kwarg = |key: &str| -> PyResult<Option<String>> {
        match kwargs.and_then(|d| d.get_item(key).ok().flatten()) {
            Some(value) => value.extract::<String>().map(Some),
            None => Ok(None),
        }
    };

    let name = text_kwarg("name")?;
    let about = text_kwarg("about")?;
    let long_about = text_kwarg("long_about")?;

    let tool = match kwargs.and_then(|d| d.get_item("tool").ok().flatten()) {
        Some(candidate) => {
            let described = candidate
                .cast::<crate::task::ToolDescription>()
                .map_err(pyo3::PyErr::from)?;
            Some(described.borrow().clone())
        }
        None => None,
    };

    Ok(CommandDecorator {
        name,
        about,
        long_about,
        tool,
    })
}
/// Build an [`ArgumentDecorator`] from keyword arguments.
///
/// The `name` keyword (defaulting to `"default"`) is stored separately, and a
/// new handle to the complete keyword dict is kept so it can later be appended
/// to the decorated function's `__arguments` list.
///
/// # Errors
///
/// Propagates the extraction error when a supplied `name` is not a string.
#[pyfunction]
#[pyo3(signature = (**kwargs))]
pub fn argument(kwargs: Option<&Bound<'_, PyDict>>) -> PyResult<ArgumentDecorator> {
    let name = match kwargs.and_then(|d| d.get_item("name").ok().flatten()) {
        Some(value) => value.extract::<String>()?,
        None => "default".to_string(),
    };

    let kwargs_dict = match kwargs {
        Some(dict) => Some(dict.clone().into_any().unbind()),
        None => None,
    };

    Ok(ArgumentDecorator { name, kwargs_dict })
}
/// Decorator object produced by `argument`.
///
/// Calling it with a function appends the stored keyword dict to that
/// function's `__arguments` list.
#[pyclass]
pub struct ArgumentDecorator {
// Stored at construction (the `name` keyword, or "default") but not read by
// any code visible in this file, hence the lint suppression.
#[allow(dead_code)]
name: String,
/// The keyword dict passed to `argument`, held as an untyped Python object;
/// `None` when no keywords were supplied.
kwargs_dict: Option<Py<PyAny>>,
}
impl Clone for ArgumentDecorator {
    /// Duplicate the decorator. The held Python object is not copied; the
    /// clone receives a new reference to the same object, which requires
    /// attaching to the interpreter (only done when a dict is present).
    fn clone(&self) -> Self {
        let kwargs_dict = self
            .kwargs_dict
            .as_ref()
            .map(|held| Python::attach(|py| held.clone_ref(py)));
        let name = self.name.clone();
        Self { name, kwargs_dict }
    }
}
#[pymethods]
impl ArgumentDecorator {
    /// Record this argument on `func`.
    ///
    /// Reads the function's current `__arguments` (treating a missing
    /// attribute, `None`, or a value that cannot be read as a sequence as an
    /// empty list), appends the stored keyword dict when there is one, and
    /// writes the result back as a new Python list. Returns the same function
    /// object.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the list or setting the attribute.
    #[pyo3(signature = (func,))]
    fn __call__(&self, func: Py<PyAny>) -> PyResult<Py<PyAny>> {
        Python::attach(|py| {
            let mut collected: Vec<Py<PyAny>> = match func.getattr(py, "__arguments") {
                Ok(existing) if !existing.is_none(py) => existing
                    .extract::<Vec<Py<PyAny>>>(py)
                    .unwrap_or_default(),
                _ => Vec::new(),
            };

            // Option is iterable: pushes zero or one element.
            collected.extend(self.kwargs_dict.as_ref().map(|held| held.clone_ref(py)));

            let updated = pyo3::types::PyList::new(py, &collected)?;
            func.setattr(py, "__arguments", updated)?;
            Ok(func)
        })
    }
}