use anyhow::{anyhow, Result};
use glob::glob;
use std::convert::TryInto;
use std::env;
use std::ops::Not;
use std::path::{Path, PathBuf};
use std::process::exit;
use std::vec::Vec;
use tera::Context;
use toml::{map::Map, Table, Value};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyModule};
use pyo3::PyResult;
use std::fs;
use std::fs::File;
use std::io::Write;
use tera::Tera;
use text_io::read;
use reqwest::{self};
use version_compare::Version;
use walkdir::WalkDir;
use log::{debug, error, info, warn};
use pythonize::pythonize;
use std::time::Duration;
/// Converts a Tera [`Context`] into a TOML table holding the same keys and values.
///
/// The context is serialised to JSON first; the resulting JSON object is then
/// converted into a TOML table.
///
/// # Panics
///
/// Panics if the JSON form of the context is not an object, or if one of its values
/// cannot be converted to TOML.
pub fn context_to_map(ctx: Context) -> Map<String, Value> {
    let json = ctx.into_json();
    let object = json.as_object().cloned().unwrap();
    Map::try_from(object).unwrap()
}
pub fn repl_context_from_toml(toml_path: PathBuf, take_input: bool) -> Context {
let defaults = extract_key_defaults(toml_path.clone()).unwrap();
let prompts = extract_prompts(toml_path.clone()).unwrap();
let validations = extract_validation_rules(toml_path).unwrap();
let mut context = Context::new();
for (k, v) in defaults.iter() {
let value = if v.is_str()
&& v.as_str().unwrap().starts_with("{{")
&& v.as_str().unwrap().contains("}}")
{
let temp_value = v.clone();
let rendered_value =
Tera::one_off(temp_value.as_str().unwrap(), &context, false).unwrap();
Value::from(rendered_value)
} else {
v.clone()
};
let input = if take_input {
let prompt_text = prompts
.get(k)
.and_then(|p| p.as_str())
.map(|p| format!("{} [{value}]", p))
.unwrap_or_else(|| format!("{k}? [{value}]"));
let mut valid_input = String::new();
let mut is_valid = false;
while !is_valid {
print!("{}: ", prompt_text);
valid_input = read!("{}\n");
if valid_input.trim().is_empty() {
break;
}
match crate::validation::validate_input(valid_input.trim(), k, &validations) {
Ok(_) => {
is_valid = true;
}
Err(err_msg) => {
println!("Invalid input: {}", err_msg);
is_valid = false;
}
}
}
valid_input
} else {
String::new()
};
if input.trim().is_empty() | take_input.not() {
if value.is_str() {
context.insert(k, &value.as_str().unwrap());
}
if value.is_integer() {
context.insert(k, &value.as_integer().unwrap());
}
if value.is_bool() {
context.insert(k, &value.as_bool().unwrap());
}
if value.is_float() {
context.insert(k, &value.as_float().unwrap());
}
} else {
if value.is_str() {
context.insert(k, &input.trim());
}
if value.is_integer() {
context.insert(
k,
&input.trim().parse::<i32>().unwrap_or_else(|_| {
debug!(
"Could not parse '{}' as integer for key '{}', using default.",
input.trim(),
k
);
let i64_val = value.as_integer().unwrap();
i64_val.try_into().unwrap_or_else(|_| {
debug!("Integer value too large for i32, truncating: {}", i64_val);
i64_val as i32
})
}),
);
}
if value.is_bool() {
context.insert(k, &input.trim());
}
if value.is_float() {
context.insert(
k,
&input.trim().parse::<f64>().unwrap_or_else(|_| {
debug!(
"Could not parse '{}' as float for key '{}', using default.",
input.trim(),
k
);
value.as_float().unwrap()
}),
);
}
}
}
context
}
/// Reports whether a relative path (or path segment) is a Tera template name.
///
/// Returns `true` when it begins with `{{` and a closing `}}` appears somewhere after it.
fn is_templated_segment(segment: &str) -> bool {
    segment
        .strip_prefix("{{")
        .is_some_and(|rest| rest.contains("}}"))
}
/// Returns the name of the single top-level templated directory inside `src`.
///
/// Used by `--in-place` rendering, which strips this directory from every output path.
/// Exits the process with status 1 in three cases:
/// - `src` cannot be read;
/// - no top-level templated directory exists;
/// - more than one top-level templated directory exists.
fn in_place_root(src: &Path) -> String {
    let entries = match fs::read_dir(src) {
        Ok(entries) => entries,
        Err(e) => {
            error!("Failed to read template directory {}: {}", src.display(), e);
            exit(1);
        }
    };

    // Entries that cannot be read are silently skipped.
    let roots: Vec<String> = entries
        .flatten()
        .filter(|entry| entry.path().is_dir())
        .map(|entry| entry.file_name().to_string_lossy().to_string())
        .filter(|name| is_templated_segment(name))
        .collect();

    if let [only] = roots.as_slice() {
        return only.clone();
    }

    if roots.is_empty() {
        error!(
            "--in-place requires the template to have exactly one top-level templated directory (e.g. `{{{{ project_name }}}}`), but none was found in {}.",
            src.display()
        );
    } else {
        error!(
            "--in-place requires exactly one top-level templated directory, but found {} in {}: {}. In-place rendering cannot determine which directory to strip.",
            roots.len(),
            src.display(),
            roots.join(", ")
        );
    }
    exit(1)
}
/// Maps a rendered template path to the path it is written to, relative to the destination.
///
/// Without `in_place`, the rendered path is used unchanged. With `in_place`, the first path
/// component (the templated root directory) is dropped; `None` means nothing remains, i.e.
/// the path was the root directory itself.
fn dest_rel_path(rendered: &str, in_place: bool) -> Option<PathBuf> {
    if in_place {
        let mut parts = Path::new(rendered).components();
        // Discard the leading (templated root) component.
        let _ = parts.next();
        let remainder = parts.as_path();
        (!remainder.as_os_str().is_empty()).then(|| remainder.to_path_buf())
    } else {
        Some(PathBuf::from(rendered))
    }
}
/// Renders the template directory `src` into `dst` and returns every path it created.
///
/// Only entries whose path relative to `src` starts with `{{` and contains `}}` take part:
/// templated directories are created under `dst`, and templated files are rendered with
/// `context` and written to their rendered path. Rendered directory paths starting with `.`
/// are skipped, as are template names equal to `angreal.toml` or starting with `.`.
///
/// When `in_place` is `true`, `src` must contain exactly one top-level templated directory
/// (otherwise the process exits). That first path component is stripped, so its contents
/// land directly in `dst`.
///
/// Unless `force` is `true`, the process exits with status 1 instead of overwriting:
/// - for `in_place`, when any rendered file or directory already exists in `dst`;
/// - otherwise, when a rendered directory already exists.
///
/// # Panics
///
/// Panics on glob/walk errors, non-UTF-8 paths, template errors and I/O failures.
pub fn render_dir(
    src: &Path,
    context: Context,
    dst: &Path,
    force: bool,
    in_place: bool,
) -> Vec<String> {
    let mut rendered_paths: Vec<String> = Vec::new();

    // Templates are registered file by file below, so an empty instance is all we need.
    // There is no round-trip through a shared temp directory, which could race with
    // other processes.
    let mut tera = Tera::default();

    if in_place {
        // Exits the process unless `src` has exactly one top-level templated directory.
        let _ = in_place_root(src);
    }

    let template_glob = src.join("**/*");
    for file in glob(template_glob.to_str().unwrap()).expect("Failed to read glob pattern") {
        let file_path = file.unwrap();
        let rel_path = file_path.strip_prefix(src).unwrap().to_str().unwrap();
        if file_path.is_file() && is_templated_segment(rel_path) {
            debug!(
                "Adding template with relative path {:?} to tera instance.",
                rel_path
            );
            tera.add_template_file(file_path.to_str().unwrap(), Some(rel_path))
                .unwrap();
        }
    }

    // In-place rendering writes into an existing directory, so check everything up front
    // and refuse before touching the filesystem.
    if in_place && force.not() {
        let mut collisions: Vec<String> = Vec::new();
        for entry in WalkDir::new(src)
            .into_iter()
            .filter_entry(|e| e.file_type().is_dir())
        {
            let entry = entry.unwrap();
            let rel = entry.path().strip_prefix(src).unwrap().to_str().unwrap();
            if is_templated_segment(rel) {
                let real_path = Tera::one_off(rel, &context, false).unwrap();
                if let Some(dest_rel) = dest_rel_path(&real_path, true) {
                    if dst.join(&dest_rel).exists() {
                        collisions.push(dest_rel.to_string_lossy().to_string());
                    }
                }
            }
        }
        for template in tera.get_template_names() {
            if template == "angreal.toml" || template.starts_with('.') {
                continue;
            }
            let path = Tera::one_off(template, &context, false).unwrap();
            if let Some(dest_rel) = dest_rel_path(&path, true) {
                if dst.join(&dest_rel).exists() {
                    collisions.push(dest_rel.to_string_lossy().to_string());
                }
            }
        }
        if collisions.is_empty().not() {
            error!(
                "{} already exist(s) in {}. Will not proceed unless `--force`/force=True is used.",
                collisions.join(", "),
                dst.display()
            );
            exit(1);
        }
    }

    // Create directories first. WalkDir yields parents before children, so every
    // rendered file below has its directory in place.
    for entry in WalkDir::new(src)
        .into_iter()
        .filter_entry(|e| e.file_type().is_dir())
    {
        let entry = entry.unwrap();
        let path_template = entry.path().strip_prefix(src).unwrap().to_str().unwrap();
        if !is_templated_segment(path_template) {
            continue;
        }
        let real_path = Tera::one_off(path_template, &context, false).unwrap();
        if real_path.starts_with('.') {
            continue;
        }
        let Some(dest_rel) = dest_rel_path(&real_path, in_place) else {
            continue;
        };
        let destination = dst.join(&dest_rel);
        if destination.is_dir() && force.not() {
            error!(
                "{} already exists. Will not proceed unless `--force`/force=True is used.",
                dest_rel.display()
            );
            // Previously execution continued here and silently overwrote the directory's
            // contents, contradicting the message above.
            exit(1);
        }
        debug!("Creating directory {:?}", destination);
        fs::create_dir_all(&destination).unwrap();
        rendered_paths.push(destination.to_string_lossy().to_string());
    }

    for template in tera.get_template_names() {
        if template == "angreal.toml" || template.starts_with('.') {
            continue;
        }
        let rendered = tera.render(template, &context).unwrap();
        let path = Tera::one_off(template, &context, false).unwrap();
        let Some(dest_rel) = dest_rel_path(&path, in_place) else {
            continue;
        };
        let destination = dst.join(&dest_rel);
        debug!("Rendering file at {:?}", destination);
        let mut output = File::create(&destination).unwrap();
        write!(output, "{}", rendered.as_str()).unwrap();
        rendered_paths.push(destination.to_string_lossy().to_string());
    }

    rendered_paths
}
/// Checks PyPI for a newer angreal release and prints an upgrade hint if one exists.
///
/// Network failures (including the 400 ms timeout) are only logged as warnings and
/// reported as `Ok(())`. The `result_or_return_err!` / `value_or_return_err!` macros handle
/// failures when decoding the response or comparing versions — presumably by returning an
/// `Err`; confirm against their definitions.
pub fn check_up_to_date() -> Result<()> {
    let client = reqwest::blocking::Client::new();
    let request = client
        .get("https://pypi.org/pypi/angreal/json")
        .timeout(Duration::from_millis(400));

    let response = match request.send() {
        Ok(response) => response,
        Err(e) if e.is_timeout() => {
            warn!("Request timed out. Please check your network connection.");
            return Ok(());
        }
        Err(e) => {
            warn!("Error checking for updates: {}", e);
            return Ok(());
        }
    };

    let json_result = response.json::<serde_json::Value>();
    let json = result_or_return_err!(json_result);

    let upstream = value_or_return_err!(json["info"]["version"].as_str());
    let installed = env!("CARGO_PKG_VERSION");
    let installed = value_or_return_err!(Version::from(installed));
    let upstream = value_or_return_err!(Version::from(upstream));

    if upstream > installed {
        println!(
            "A newer version of angreal is available, use pip install --upgrade angreal to upgrade."
        );
    }
    Ok(())
}
pub fn get_task_files(path: PathBuf) -> Result<Vec<PathBuf>> {
let mut tasks = Vec::new();
let mut pattern = path;
pattern.push("task_*.py");
let mut have_tasks = false;
for entry in glob(pattern.to_str().unwrap()).expect("Failed to read glob pattern") {
match entry {
Ok(path) => {
info!("Found task {:?}", path.display());
tasks.push(path);
have_tasks = true;
}
Err(e) => error!("{:?}", e),
}
}
if have_tasks {
Ok(tasks)
} else {
error!("No tasks found for execution.");
Err(anyhow!("No tasks found for execution."))
}
}
/// Adds this module's Python-facing functions to the Python module `m`.
///
/// The functions are `get_root`, `render_template`, `generate_context`,
/// `render_directory` and `get_context`. They are added in this order; the first failure
/// is returned and the remaining functions are not added.
pub fn register(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(get_root, m)?)?;
    m.add_function(wrap_pyfunction!(render_template, m)?)?;
    m.add_function(wrap_pyfunction!(generate_context, m)?)?;
    m.add_function(wrap_pyfunction!(render_directory, m)?)?;
    m.add_function(wrap_pyfunction!(get_context, m)?)?;
    Ok(())
}
/// Python binding for [`render_dir`]: renders the template directory `src` into `dst`.
///
/// Every key and value of the optional `context` dict is converted with Python `str()`
/// before insertion, so all template variables are strings. In-place rendering is not
/// available from this binding. The list of created paths is converted for Python via
/// `pythonize_this!`.
#[pyfunction]
pub fn render_directory(
    src: &str,
    dst: &str,
    force: bool,
    context: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyAny>> {
    let mut ctx = Context::new();
    if let Some(dict) = context {
        for (key, value) in dict.iter() {
            ctx.insert(key.to_string(), &value.to_string());
        }
    }
    let created = render_dir(Path::new(src), ctx, Path::new(dst), force, false);
    Ok(pythonize_this!(created))
}
#[pyfunction]
fn generate_context(path: &str, take_input: bool) -> PyResult<Py<PyAny>> {
let toml_path = Path::new(path).to_path_buf();
let ctx = repl_context_from_toml(toml_path, take_input);
let map = context_to_map(ctx);
Ok(pythonize_this!(map))
}
/// Python binding: returns the path of the `.angreal` directory that marks the project.
///
/// The directory is found by [`is_angreal_project`]. Non-UTF-8 paths are converted
/// lossily. Raises `RuntimeError` when no `.angreal` directory is found.
#[pyfunction]
fn get_root() -> PyResult<String> {
    is_angreal_project()
        .map(|angreal_root| angreal_root.to_string_lossy().into_owned())
        .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
}
#[pyfunction]
fn render_template(template: &str, context: &Bound<'_, PyDict>) -> PyResult<String> {
let mut tera = Tera::default();
let mut ctx = tera::Context::new();
tera.add_raw_template("template", template).unwrap();
for (key, val) in context.iter() {
ctx.insert(key.to_string(), &val.to_string());
}
Ok(tera.render("template", &ctx).unwrap())
}
#[pyfunction]
fn get_context() -> PyResult<Py<PyAny>> {
let angreal_root = match is_angreal_project() {
Ok(root) => root,
Err(_) => {
let empty = toml::Table::new();
return Ok(pythonize_this!(empty));
}
};
let toml_path = angreal_root.join("angreal.toml");
let file_contents = match fs::read_to_string(&toml_path) {
Ok(contents) => contents,
Err(_) => {
let empty = toml::Table::new();
return Ok(pythonize_this!(empty));
}
};
let toml_value = match file_contents.parse::<Table>() {
Ok(value) => value,
Err(_) => {
let empty = toml::Table::new();
return Ok(pythonize_this!(empty));
}
};
Ok(pythonize_this!(toml_value))
}
pub fn is_angreal_project() -> Result<PathBuf> {
let angreal_path = Path::new(".angreal");
let mut check_dir = match env::current_dir() {
Ok(dir) => dir,
Err(_) => return Err(anyhow!("This doesn't appear to be an angreal project.")),
};
check_dir.push(angreal_path);
let found = loop {
if check_dir.is_dir() {
break true;
}
let mut next_dir = check_dir.clone();
next_dir.pop();
next_dir.pop();
next_dir.push(angreal_path);
if next_dir == check_dir {
break false;
}
check_dir = next_dir.clone();
};
if found {
Ok(check_dir)
} else {
Err(anyhow!("This doesn't appear to be an angreal project."))
}
}
pub fn load_python(file: PathBuf) -> Result<(), PyErr> {
let mut dir = file.clone();
dir.pop();
let dir = dir.to_str();
let contents = fs::read_to_string(file.clone()).unwrap();
let r_value = Python::attach(|py| -> PyResult<()> {
let sys = py.import("sys")?;
let path_attr = sys.getattr("path")?;
let syspath = path_attr.cast::<PyList>()?;
syspath.insert(0, dir)?;
use std::ffi::CString;
let contents_cstr = CString::new(contents.as_str()).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid C string: {}", e))
})?;
let result = PyModule::from_code(py, contents_cstr.as_c_str(), c"", c"");
match result {
Ok(_result) => {
debug!("Successfully loaded {:?}", file);
Ok(())
}
Err(err) => {
error!("{:?} failed to load", file);
let formatter =
crate::error_formatter::PythonErrorFormatter::new(err.clone_ref(py));
println!("{}", formatter);
Err(err)
}
}
});
match r_value {
Ok(_ok) => Ok(()),
Err(err) => Err(err),
}
}
pub fn extract_key_defaults(toml_path: PathBuf) -> Result<Map<String, Value>> {
let file_contents = fs::read_to_string(&toml_path)
.unwrap_or_else(|_| panic!("Unable to open {:?}", &toml_path));
let extract = file_contents.parse::<Table>().unwrap();
let mut defaults = Map::new();
for (k, v) in extract
.iter()
.filter(|(key, _)| *key != "prompt" && *key != "validation")
{
defaults.insert(k.clone(), v.clone());
}
Ok(defaults)
}
pub fn extract_validation_rules(toml_path: PathBuf) -> Result<Map<String, Value>> {
let file_contents = fs::read_to_string(&toml_path)
.unwrap_or_else(|_| panic!("Unable to open {:?}", &toml_path));
let extract = file_contents.parse::<Table>().unwrap();
let binding_validation = Table::new();
let validations = extract
.get("validation")
.and_then(|v| v.as_table())
.unwrap_or(&binding_validation);
let mut flattened_validations = Map::new();
for (field, rules) in validations.iter() {
if let Some(rules_table) = rules.as_table() {
for (rule, value) in rules_table.iter() {
let key = format!("{}.{}", field, rule);
flattened_validations.insert(key, value.clone());
}
}
}
Ok(flattened_validations)
}
pub fn extract_prompts(toml_path: PathBuf) -> Result<Map<String, Value>> {
let file_contents = fs::read_to_string(&toml_path)
.unwrap_or_else(|_| panic!("Unable to open {:?}", &toml_path));
let extract = file_contents.parse::<Table>().unwrap();
let binding_prompt = Table::new();
let prompts = extract
.get("prompt")
.and_then(|v| v.as_table())
.unwrap_or(&binding_prompt);
Ok(prompts.clone())
}
#[cfg(test)]
#[path = "../tests"]
mod tests {
    use super::*;
    use pyo3::types::PyDict;
    use std::fs;
    use std::io::Write;
    use std::path::Path;
    use std::path::PathBuf;
    use std::sync::{Mutex, MutexGuard};
    mod common;

    /// Serialises tests that change the process-wide current directory. The default test
    /// harness runs tests on several threads, so unsynchronised `set_current_dir` calls race.
    static CWD_LOCK: Mutex<()> = Mutex::new(());

    /// Moves the process into a fresh temporary directory for the duration of a test.
    ///
    /// While alive it holds `CWD_LOCK`. When dropped, including while unwinding from a
    /// failed assertion, it restores the previous working directory and removes the
    /// temporary directory.
    struct TempCwd {
        original: PathBuf,
        dir: PathBuf,
        _lock: MutexGuard<'static, ()>,
    }

    impl TempCwd {
        fn new() -> Self {
            // A test that panicked while holding the lock poisons it; the guarded value
            // is `()`, so there is no state to be inconsistent and we can continue.
            let lock = CWD_LOCK.lock().unwrap_or_else(|e| e.into_inner());
            let original = std::env::current_dir().unwrap();
            let dir = common::make_tmp_dir();
            std::env::set_current_dir(&dir).unwrap();
            TempCwd {
                original,
                dir,
                _lock: lock,
            }
        }
    }

    impl Drop for TempCwd {
        fn drop(&mut self) {
            let _ = std::env::set_current_dir(&self.original);
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn test_repl_context_from_toml() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let test_toml = root.join("tests/common/test_assets/test_template/angreal.toml");
        let ctx = crate::utils::repl_context_from_toml(test_toml, false);
        assert_eq!(ctx.get("key_1").unwrap(), "value_1");
        assert_eq!(ctx.get("key_2").unwrap(), 1);
        assert_eq!(ctx.get("folder_variable").unwrap(), "folder_name");
        assert_eq!(
            ctx.get("variable_text").unwrap(),
            "Just some text that we want to render"
        );
        assert_eq!(ctx.get("role").unwrap(), "user");
        assert_eq!(ctx.get("age").unwrap(), 25);
        assert_eq!(ctx.get("email").unwrap(), "test@example.com");
        assert_eq!(ctx.get("score").unwrap(), 50);
        assert_eq!(ctx.get("username").unwrap(), "user123");
        assert_eq!(ctx.get("password").unwrap(), "securepass");
        assert_eq!(ctx.get("required_field").unwrap(), "important");
        assert!(ctx.get("prompt").is_none());
        assert!(ctx.get("validation").is_none());
    }

    #[test]
    fn test_load_python() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let should_pass = [
            "tests/common/test_assets/good_init.py",
            "tests/common/test_assets/good_task.py",
            "tests/common/test_assets/no_func_init.py",
            "tests/common/test_assets/no_func_task.py",
            "tests/common/test_assets/exception_init.py",
            "tests/common/test_assets/exception_task.py",
        ];
        for f_name in &should_pass {
            let file = PathBuf::from(String::from(*f_name));
            let rv = crate::utils::load_python(root.join(file)).is_ok();
            assert!(rv);
        }
        let shouldnt_pass = [
            "tests/common/test_assets/bad_import_init.py",
            "tests/common/test_assets/bad_import_task.py",
        ];
        for f_name in &shouldnt_pass {
            let file = PathBuf::from(String::from(*f_name));
            let rv = crate::utils::load_python(root.join(file)).is_err();
            assert!(rv);
        }
    }

    /// A fresh temporary directory without `.angreal` is not an angreal project.
    #[test]
    fn test_is_not_angreal_project() {
        let _cwd = TempCwd::new();
        assert!(crate::utils::is_angreal_project().is_err());
    }

    /// Creating `.angreal` in the current directory makes it an angreal project.
    #[test]
    fn test_is_angreal_project() {
        let _cwd = TempCwd::new();
        fs::create_dir(Path::new(".angreal")).unwrap_or(());
        assert!(crate::utils::is_angreal_project().is_ok());
    }

    #[test]
    fn test_get_task_files() {
        let cwd = TempCwd::new();
        let tmp_dir = cwd.dir.clone();
        fs::create_dir(Path::new(".angreal")).unwrap_or(());
        let files_to_make = ["task_test_task.py", "not_this_file.py", "task_not_this.txt"];
        for f_name in &files_to_make {
            let mut f_path = tmp_dir.clone();
            f_path.push(Path::new(".angreal"));
            f_path.push(Path::new(f_name));
            let _ = fs::File::create(&f_path);
        }
        let files_should_find = vec![tmp_dir.join(".angreal").join("task_test_task.py")];
        let files_found = crate::utils::get_task_files(tmp_dir.join(".angreal")).unwrap();
        assert_eq!(files_found, files_should_find);
    }

    #[test]
    fn test_extract_key_defaults() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let test_toml = root.join("tests/common/test_assets/test_template/angreal.toml");
        let defaults = extract_key_defaults(test_toml).unwrap();
        assert_eq!(defaults.get("key_1").unwrap().as_str().unwrap(), "value_1");
        assert_eq!(defaults.get("key_2").unwrap().as_integer().unwrap(), 1);
        assert_eq!(
            defaults.get("folder_variable").unwrap().as_str().unwrap(),
            "folder_name"
        );
        assert_eq!(defaults.get("role").unwrap().as_str().unwrap(), "user");
        assert_eq!(defaults.get("age").unwrap().as_integer().unwrap(), 25);
        assert_eq!(
            defaults.get("email").unwrap().as_str().unwrap(),
            "test@example.com"
        );
        assert_eq!(defaults.get("score").unwrap().as_integer().unwrap(), 50);
        assert_eq!(
            defaults.get("username").unwrap().as_str().unwrap(),
            "user123"
        );
        assert_eq!(
            defaults.get("password").unwrap().as_str().unwrap(),
            "securepass"
        );
        assert_eq!(
            defaults.get("required_field").unwrap().as_str().unwrap(),
            "important"
        );
        assert!(defaults.get("prompt").is_none());
        assert!(defaults.get("validation").is_none());
    }

    #[test]
    fn test_extract_validation_rules() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let test_toml = root.join("tests/common/test_assets/test_template/angreal.toml");
        let validations = extract_validation_rules(test_toml).unwrap();
        assert_eq!(
            validations
                .get("role.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[0]
                .as_str()
                .unwrap(),
            "admin"
        );
        assert_eq!(
            validations
                .get("role.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[1]
                .as_str()
                .unwrap(),
            "user"
        );
        assert_eq!(
            validations
                .get("role.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[2]
                .as_str()
                .unwrap(),
            "guest"
        );
        assert_eq!(
            validations
                .get("score.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[0]
                .as_integer()
                .unwrap(),
            0
        );
        assert_eq!(
            validations
                .get("score.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[1]
                .as_integer()
                .unwrap(),
            25
        );
        assert_eq!(
            validations
                .get("score.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[2]
                .as_integer()
                .unwrap(),
            50
        );
        assert_eq!(
            validations
                .get("score.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[3]
                .as_integer()
                .unwrap(),
            75
        );
        assert_eq!(
            validations
                .get("score.allowed_values")
                .unwrap()
                .as_array()
                .unwrap()[4]
                .as_integer()
                .unwrap(),
            100
        );
        assert_eq!(
            validations.get("age.min").unwrap().as_integer().unwrap(),
            18
        );
        assert_eq!(
            validations.get("age.max").unwrap().as_integer().unwrap(),
            65
        );
        assert_eq!(
            validations.get("age.type").unwrap().as_str().unwrap(),
            "integer"
        );
        assert_eq!(
            validations.get("key_2.type").unwrap().as_str().unwrap(),
            "integer"
        );
        assert_eq!(
            validations
                .get("email.regex_match")
                .unwrap()
                .as_str()
                .unwrap(),
            "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
        );
        assert!(validations
            .get("email.not_empty")
            .unwrap()
            .as_bool()
            .unwrap());
        assert_eq!(
            validations
                .get("username.length_min")
                .unwrap()
                .as_integer()
                .unwrap(),
            3
        );
        assert_eq!(
            validations
                .get("username.length_max")
                .unwrap()
                .as_integer()
                .unwrap(),
            20
        );
        assert_eq!(
            validations
                .get("username.regex_match")
                .unwrap()
                .as_str()
                .unwrap(),
            "^[a-zA-Z0-9]+$"
        );
        assert_eq!(
            validations
                .get("password.length_min")
                .unwrap()
                .as_integer()
                .unwrap(),
            8
        );
        assert!(validations
            .get("password.not_empty")
            .unwrap()
            .as_bool()
            .unwrap());
        assert!(validations
            .get("required_field.not_empty")
            .unwrap()
            .as_bool()
            .unwrap());
    }

    #[test]
    fn test_extract_prompts() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let test_toml = root.join("tests/common/test_assets/test_template/angreal.toml");
        let prompts = extract_prompts(test_toml).unwrap();
        assert_eq!(
            prompts.get("key_1").unwrap().as_str().unwrap(),
            "Enter the first key value"
        );
        assert_eq!(
            prompts.get("key_2").unwrap().as_str().unwrap(),
            "Enter the second key value (must be a number)"
        );
        assert_eq!(
            prompts.get("folder_variable").unwrap().as_str().unwrap(),
            "What should we name the folder?"
        );
        assert_eq!(
            prompts.get("variable_text").unwrap().as_str().unwrap(),
            "Enter the text you would like to include"
        );
        assert_eq!(
            prompts.get("role").unwrap().as_str().unwrap(),
            "Select a role (admin, user, guest)"
        );
        assert_eq!(
            prompts.get("age").unwrap().as_str().unwrap(),
            "Enter your age (must be between 18 and 65)"
        );
        assert_eq!(
            prompts.get("email").unwrap().as_str().unwrap(),
            "Enter your email address"
        );
        assert_eq!(
            prompts.get("score").unwrap().as_str().unwrap(),
            "Enter your score (0, 25, 50, 75, or 100)"
        );
        assert_eq!(
            prompts.get("username").unwrap().as_str().unwrap(),
            "Choose a username (3-20 characters, alphanumeric)"
        );
        assert_eq!(
            prompts.get("password").unwrap().as_str().unwrap(),
            "Create a password (min 8 characters)"
        );
        assert_eq!(
            prompts.get("required_field").unwrap().as_str().unwrap(),
            "This field is required"
        );
    }

    #[test]
    fn test_read_config() {
        let cwd = TempCwd::new();
        let tmp_dir = cwd.dir.clone();
        fs::create_dir(Path::new(".angreal")).unwrap_or(());
        let toml_content = r#"
key_1 = "value_1"
key_2 = 42
nested = { key = "value" }
array = [1, 2, 3]
"#;
        let mut toml_file =
            fs::File::create(tmp_dir.join(".angreal").join("angreal.toml")).unwrap();
        write!(toml_file, "{}", toml_content).unwrap();
        let config = get_context().unwrap();
        Python::attach(|py| {
            let dict = config.cast_bound::<PyDict>(py).unwrap();
            assert_eq!(
                dict.get_item("key_1")
                    .unwrap()
                    .unwrap()
                    .extract::<String>()
                    .unwrap(),
                "value_1"
            );
            assert_eq!(
                dict.get_item("key_2")
                    .unwrap()
                    .unwrap()
                    .extract::<i64>()
                    .unwrap(),
                42
            );
            let nested_item = dict.get_item("nested").unwrap().unwrap();
            let nested = nested_item.cast::<PyDict>().unwrap();
            assert_eq!(
                nested
                    .get_item("key")
                    .unwrap()
                    .unwrap()
                    .extract::<String>()
                    .unwrap(),
                "value"
            );
            let array_item = dict.get_item("array").unwrap().unwrap();
            let array = array_item.extract::<Vec<i64>>().unwrap();
            assert_eq!(array, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_read_config_python_bindings() {
        let cwd = TempCwd::new();
        let tmp_dir = cwd.dir.clone();
        fs::create_dir(Path::new(".angreal")).unwrap_or(());
        let toml_content = r#"
key_1 = "value_1"
key_2 = 42
"#;
        let mut toml_file =
            fs::File::create(tmp_dir.join(".angreal").join("angreal.toml")).unwrap();
        write!(toml_file, "{}", toml_content).unwrap();
        Python::attach(|py| {
            let module = PyModule::new(py, "test_module").unwrap();
            module
                .add_function(wrap_pyfunction!(get_context, &module).unwrap())
                .unwrap();
            let attr = module.getattr("get_context").unwrap();
            let call_result = attr.call0().unwrap();
            let result = call_result.cast::<PyDict>().unwrap();
            assert_eq!(
                result
                    .get_item("key_1")
                    .unwrap()
                    .unwrap()
                    .extract::<String>()
                    .unwrap(),
                "value_1"
            );
            assert_eq!(
                result
                    .get_item("key_2")
                    .unwrap()
                    .unwrap()
                    .extract::<i64>()
                    .unwrap(),
                42
            );
        });
    }
}