use pyo3::prelude::*;
use std::ffi::CString;
fn init_py() {
Python::initialize();
}
fn run_python_code<F, T>(code: &str, extract_fn: F) -> PyResult<T>
where
F: FnOnce(&Bound<'_, PyAny>) -> PyResult<T>,
{
init_py();
Python::attach(|py| {
let code_str = format!(
"import sys\nsys.path.insert(0, '{}')\nimport rstorch_python as rstorch\n\n{}",
env!("CARGO_MANIFEST_DIR"),
code
);
let code_cstr = CString::new(code_str).unwrap();
let filename = CString::new("").unwrap();
let module_name = CString::new("").unwrap();
let module = PyModule::from_code(py, &code_cstr, &filename, &module_name)?;
let result = module.getattr("result")?;
extract_fn(&result)
})
}
#[test]
fn test_torsh_error_creation() {
let result: String = run_python_code(
r#"
error = rstorch.TorshError("Test error message")
result = str(error)
"#,
|r| r.extract(),
)
.unwrap();
assert_eq!(result, "Test error message");
}
#[test]
fn test_torsh_error_repr() {
let result: String = run_python_code(
r#"
error = rstorch.TorshError("Test error message")
result = repr(error)
"#,
|r| r.extract(),
)
.unwrap();
assert_eq!(result, "TorshError('Test error message')");
}
#[test]
fn test_torsh_error_type_registered() {
init_py();
Python::attach(|py| {
let code = format!(
r#"
import sys
sys.path.insert(0, '{}')
import rstorch_python as rstorch
result = hasattr(rstorch, 'TorshError')
"#,
env!("CARGO_MANIFEST_DIR")
);
let code_cstr = CString::new(code).unwrap();
let filename = CString::new("").unwrap();
let module_name = CString::new("").unwrap();
let module = PyModule::from_code(py, &code_cstr, &filename, &module_name).unwrap();
let result: bool = module.getattr("result").unwrap().extract().unwrap();
assert!(result, "TorshError should be registered in the module");
});
}
#[test]
fn test_torsh_error_is_type() {
init_py();
Python::attach(|py| {
let code = format!(
r#"
import sys
sys.path.insert(0, '{}')
import rstorch_python as rstorch
error = rstorch.TorshError("Test")
result = type(error).__name__
"#,
env!("CARGO_MANIFEST_DIR")
);
let code_cstr = CString::new(code).unwrap();
let filename = CString::new("").unwrap();
let module_name = CString::new("").unwrap();
let module = PyModule::from_code(py, &code_cstr, &filename, &module_name).unwrap();
let result: String = module.getattr("result").unwrap().extract().unwrap();
assert_eq!(result, "TorshPyError");
});
}
#[test]
fn test_torsh_error_empty_message() {
let result: String = run_python_code(
r#"
error = rstorch.TorshError("")
result = str(error)
"#,
|r| r.extract(),
)
.unwrap();
assert_eq!(result, "");
}
#[test]
fn test_torsh_error_multiline_message() {
let result: String = run_python_code(
r#"
error = rstorch.TorshError("Line 1\nLine 2\nLine 3")
result = str(error)
"#,
|r| r.extract(),
)
.unwrap();
assert_eq!(result, "Line 1\nLine 2\nLine 3");
}
#[test]
fn test_torsh_error_with_special_characters() {
let result: String = run_python_code(
r#"
error = rstorch.TorshError("Error: 'quoted' and \"double quoted\" with symbols !@#$%^&*()")
result = str(error)
"#,
|r| r.extract(),
)
.unwrap();
assert!(result.contains("Error: 'quoted'"));
assert!(result.contains("symbols !@#$%^&*()"));
}