use std::sync::Arc;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use crate::tuner::TrialResult;
use super::results::PyTrialResult;
pub struct PyProgressCallback {
callback: Arc<Py<PyAny>>,
}
impl PyProgressCallback {
pub fn new(callback: Py<PyAny>) -> Self {
Self {
callback: Arc::new(callback),
}
}
pub fn call(&self, trial: &TrialResult, current: usize, total: usize) {
Python::with_gil(|py| {
let py_trial = PyTrialResult::from(trial.clone());
let py_trial_obj = Py::new(py, py_trial).unwrap();
let args = PyTuple::new(
py,
&[
py_trial_obj.into_any(),
current.into_pyobject(py).unwrap().into_any().unbind(),
total.into_pyobject(py).unwrap().into_any().unbind(),
],
)
.unwrap();
if let Err(e) = self.callback.call1(py, args) {
eprintln!("Warning: Progress callback failed: {}", e);
}
});
}
}
unsafe impl Send for PyProgressCallback {}
unsafe impl Sync for PyProgressCallback {}
/// Checks that `obj` can be called from Python.
///
/// # Errors
///
/// Returns a `TypeError` with the message `"callback must be callable"` when
/// `obj` is not callable.
pub fn validate_callable(py: Python<'_>, obj: &Py<PyAny>) -> PyResult<()> {
    let callable = obj.bind(py).is_callable();
    if callable {
        Ok(())
    } else {
        Err(pyo3::exceptions::PyTypeError::new_err(
            "callback must be callable",
        ))
    }
}