use crate::{error::PyResult, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;
#[pyclass(name = "Optimizer", subclass)]
pub struct PyOptimizer {
}
#[pymethods]
impl PyOptimizer {
#[new]
fn new() -> Self {
Self {}
}
fn step(&mut self) -> PyResult<()> {
Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
"Subclasses must implement step method",
))
}
fn zero_grad(&mut self, set_to_none: Option<bool>) {
let _set_to_none = set_to_none.unwrap_or(false);
}
fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
Ok(HashMap::new())
}
fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
let _state_dict = state_dict;
Ok(())
}
fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
Ok(Vec::new())
}
fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
Ok(HashMap::new())
}
fn add_param_group(&mut self, param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
let _param_group = param_group;
Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
"Subclasses must implement add_param_group method",
))
}
fn __repr__(&self) -> String {
"Optimizer()".to_string()
}
fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
Ok(HashMap::new())
}
}
/// Unwraps Python tensor wrappers into their underlying `f32` tensors.
///
/// Consumes `params` and moves each wrapper's `tensor` field out, preserving order.
///
/// # Errors
///
/// Currently never fails. The `PyResult` return type is kept for API compatibility.
pub fn extract_parameters(params: Vec<PyTensor>) -> PyResult<Vec<torsh_tensor::Tensor<f32>>> {
    // The per-element step is infallible, so wrap the collected Vec once instead of
    // wrapping every element in `Ok` and collecting through `Result`.
    Ok(params.into_iter().map(|p| p.tensor).collect())
}
pub fn create_param_group(
params: Vec<PyTensor>,
lr: f32,
extra_params: HashMap<String, Py<PyAny>>,
) -> PyResult<HashMap<String, Py<PyAny>>> {
let mut param_group = HashMap::new();
Python::attach(|py| {
let py_params: Vec<Py<PyAny>> = params
.into_iter()
.map(|p| {
p.into_pyobject(py)
.expect("Python object conversion should succeed")
.into()
})
.collect();
param_group.insert(
"params".to_string(),
py_params
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into(),
);
param_group.insert(
"lr".to_string(),
lr.into_pyobject(py)
.expect("Python object conversion should succeed")
.into(),
);
for (key, value) in extra_params {
param_group.insert(key, value);
}
Ok(param_group)
})
}