Struct multistochgrad::svrg::SVRGDescent
source · pub struct SVRGDescent { /* private fields */ }
Expand description
Provides the Stochastic Variance Reduced Gradient (SVRG) descent described in:
Rie Johnson and Tong Zhang, “Accelerating Stochastic Gradient Descent using Predictive Variance Reduction”.
Advances in Neural Information Processing Systems, pages 315–323, 2013.
The algorithm alternates a full gradient computation of the sum with a sequence of mini batches, each made of one term selected at random, for which only the gradient of that term is computed.
During the mini batch sequence, the gradient of the sum is approximated by updating only the component of the full gradient that corresponds to the randomly selected term.
Precisely we have the following sequence:
- computation of batch_gradient, the full gradient at the current position
- storage of this gradient and of the position before the mini batch sequence
- then, for each of the `nb_mini_batch` iterations:
  - uniform sampling of a term of the summation
- computation of the gradient of the term at current position and the gradient at position before mini batch
- computation of direction of propagation as the batch gradient + gradient of term at current position - gradient of term at position before mini batch sequence
- update of position with adequate step size
The step size used in the algorithm is constant. According to the reference paper, it must be smaller than 1/(4L), where L is the Lipschitz constant of the gradient of the function to minimize.
Implementations§
source§impl SVRGDescent
impl SVRGDescent
sourcepub fn new(nb_mini_batch: usize, step_size: f64) -> SVRGDescent
pub fn new(nb_mini_batch: usize, step_size: f64) -> SVRGDescent
nb_mini_batch : number of mini batch iterations performed after each full gradient computation.
step_size : constant step size used in the position update.
Examples found in repository?
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
/// Example: fits a multiclass logistic regression on MNIST with SVRG, then dumps each
/// row of the fitted coefficient matrix to `classe_svrg.img<k>` as raw native-endian f64s.
fn main() {
    let _ = env_logger::init();
    log::set_max_level(log::LevelFilter::Trace);
    //
    // check that the image and label files can be opened before loading them
    //
    let image_path = PathBuf::from(IMAGE_FNAME_STR);
    if OpenOptions::new().read(true).open(&image_path).is_err() {
        println!("could not open image file : {:?}", IMAGE_FNAME_STR);
        return;
    }
    let label_path = PathBuf::from(LABEL_FNAME_STR);
    if OpenOptions::new().read(true).open(&label_path).is_err() {
        println!("could not open label file : {:?}", LABEL_FNAME_STR);
        return;
    }
    //
    // load mnist data
    //
    let mnist_data =
        MnistData::new(String::from(IMAGE_FNAME_STR), String::from(LABEL_FNAME_STR)).unwrap();
    let images = mnist_data.get_images();
    let labels = mnist_data.get_labels();
    // the number of images is the length of the third component of the array dimension
    let (nb_row, nb_column, nb_images) = images.dim();
    assert_eq!(nb_images, labels.shape()[0]);
    //
    // transform into logistic regression observations: each image becomes a vector
    // [1, pixels...] (leading 1. is the intercept term) with pixels scaled by 1/256.
    //
    let nb_features = 1 + nb_row * nb_column;
    let mut observations = Vec::<(Array1<f64>, usize)>::with_capacity(nb_images);
    for k in 0..nb_images {
        let mut image = Array1::<f64>::zeros(nb_features);
        image[0] = 1.;
        let mut index = 1;
        for i in 0..nb_row {
            for j in 0..nb_column {
                image[index] = images[[i, j, k]] as f64 / 256.;
                index += 1;
            }
        }
        observations.push((image, labels[k] as usize));
    }
    //
    let regr_l = LogisticRegression::new(10, observations);
    //
    // minimize
    //
    let nb_iter = 50;
    let svrg_pb = SVRGDescent::new(
        2000, // nb_mini_batch: number of mini batch iterations after each full gradient
        0.05, // step_size: constant step used in position updates
    );
    // 9 coefficient rows for the 10 classes given to LogisticRegression::new above
    // (presumably one class is a reference class — confirm against LogisticRegression);
    // columns are the intercept followed by pixel values. Zero initialization works best.
    let nb_coeff_rows = 9;
    let initial_position = Array2::<f64>::zeros((nb_coeff_rows, nb_features));
    let solution = svrg_pb.minimize(&regr_l, &initial_position, Some(nb_iter));
    println!(" solution with minimized value = {:2.4E}", solution.value);
    //
    // dump each coefficient row to its own file so it can be viewed as an image.
    //
    let image_fname = "classe_svrg.img";
    for (k, coefficients) in solution.position.outer_iter().enumerate() {
        let k_image_fname = format!("{}{}", image_fname, k);
        let image_path = PathBuf::from(&k_image_fname);
        // truncate(true): without it, rerunning over a larger existing file would
        // leave stale trailing bytes after the newly written data.
        let image_file = match OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .open(&image_path)
        {
            Ok(file) => file,
            Err(_) => {
                println!("could not open image file : {:?}", k_image_fname);
                return;
            }
        };
        let mut out = io::BufWriter::new(image_file);
        // Same bytes as reinterpreting the f64 buffer as u8 (native endianness),
        // but safe and independent of the row view being contiguous.
        for x in coefficients.iter() {
            out.write_all(&x.to_ne_bytes()).unwrap();
        }
        out.flush().unwrap();
    }
}