use super::base::{create_param_group, extract_parameters, PyOptimizer};
use crate::{error::PyResult, tensor::PyTensor};
use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;
use std::sync::Arc;
use torsh_optim::{adagrad::AdaGrad, Optimizer};
/// Python-facing wrapper around [`torsh_optim::adagrad::AdaGrad`].
///
/// Exposed to Python under the class name `Adagrad` and derived from the
/// shared `PyOptimizer` base class. The hyperparameters are kept alongside the
/// Rust optimizer so they can be reported back to Python (getters,
/// `defaults`, `__repr__`) directly from this struct.
#[pyclass(name = "Adagrad", extends = PyOptimizer)]
pub struct PyAdaGrad {
    /// Learning rate.
    lr: f32,
    /// Learning-rate decay factor.
    lr_decay: f32,
    /// Weight-decay coefficient.
    weight_decay: f32,
    /// Numerical-stability term.
    eps: f32,
    /// Rust optimizer that performs the parameter updates.
    adagrad: AdaGrad,
    /// Python-visible parameter groups: string keys mapped to Python objects.
    param_groups: Vec<HashMap<String, Py<PyAny>>>,
}
#[pymethods]
impl PyAdaGrad {
#[new]
fn new(
params: Vec<PyTensor>,
lr: Option<f32>,
lr_decay: Option<f32>,
weight_decay: Option<f32>,
eps: Option<f32>,
) -> (Self, PyOptimizer) {
let lr = lr.unwrap_or(0.01);
let lr_decay = lr_decay.unwrap_or(0.0);
let weight_decay = weight_decay.unwrap_or(0.0);
let eps = eps.unwrap_or(1e-10);
let tensor_params =
extract_parameters(params.clone()).expect("parameter extraction should succeed");
let wrapped_params: Vec<Arc<RwLock<_>>> = tensor_params
.into_iter()
.map(|tensor| Arc::new(RwLock::new(tensor)))
.collect();
let adagrad = AdaGrad::new(
wrapped_params,
Some(lr),
Some(lr_decay),
Some(weight_decay),
Some(0.0),
Some(eps),
);
let mut param_group_data = HashMap::new();
Python::attach(|py| {
param_group_data.insert(
"lr_decay".to_string(),
lr_decay
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
param_group_data.insert(
"weight_decay".to_string(),
weight_decay
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
param_group_data.insert(
"eps".to_string(),
eps.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
});
let param_groups = vec![create_param_group(params, lr, param_group_data)
.expect("param group creation should succeed")];
(
Self {
adagrad,
param_groups,
lr,
lr_decay,
weight_decay,
eps,
},
PyOptimizer {},
)
}
fn step(&mut self) -> PyResult<()> {
self.adagrad.step().map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
"Optimizer step failed: {}",
e
))
})?;
Ok(())
}
fn zero_grad(&mut self, set_to_none: Option<bool>) {
let _set_to_none = set_to_none.unwrap_or(false);
self.adagrad.zero_grad();
}
fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
Python::attach(|py| {
let cloned_groups = self
.param_groups
.iter()
.map(|group| {
group
.iter()
.map(|(k, v)| (k.clone(), v.clone_ref(py)))
.collect()
})
.collect();
Ok(cloned_groups)
})
}
fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
let mut state = HashMap::new();
Python::attach(|py| {
state.insert(
"step".to_string(),
0i64.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
state.insert(
"sum".to_string(),
"{}".into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
});
Ok(state)
}
fn __repr__(&self) -> String {
format!(
"Adagrad(lr={}, lr_decay={}, eps={}, weight_decay={})",
self.lr, self.lr_decay, self.eps, self.weight_decay
)
}
fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
let mut defaults = HashMap::new();
Python::attach(|py| {
defaults.insert(
"lr".to_string(),
self.lr
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
defaults.insert(
"lr_decay".to_string(),
self.lr_decay
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
defaults.insert(
"weight_decay".to_string(),
self.weight_decay
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
defaults.insert(
"eps".to_string(),
self.eps
.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
});
Ok(defaults)
}
#[getter]
fn lr(&self) -> f32 {
self.lr
}
#[setter]
fn set_lr(&mut self, lr: f32) {
self.lr = lr;
Python::attach(|py| {
for param_group in &mut self.param_groups {
param_group.insert(
"lr".to_string(),
lr.into_pyobject(py)
.expect("Python object conversion should succeed")
.into_any()
.unbind(),
);
}
});
}
#[getter]
fn lr_decay(&self) -> f32 {
self.lr_decay
}
#[getter]
fn weight_decay(&self) -> f32 {
self.weight_decay
}
#[getter]
fn eps(&self) -> f32 {
self.eps
}
}