1use crate::config::UmapConfig;
2use crate::embedding::FittedUmap;
3use crate::layout::optimize_layout_euclidean::optimize_layout_euclidean_single_epoch_stateful;
4use crate::layout::optimize_layout_generic::optimize_layout_generic_single_epoch_stateful;
5use crate::manifold::LearnedManifold;
6use crate::metric::Metric;
7use crate::metric::MetricType;
8use crate::umap::make_epochs_per_sample::make_epochs_per_sample;
9use crate::utils::parallel_vec::ParallelVec;
10use ndarray::Array1;
11use ndarray::Array2;
12use ndarray::ArrayView2;
13use rayon::prelude::*;
14use serde::Deserialize;
15use serde::Serialize;
16use std::time::Instant;
17use tracing::info;
18
19#[derive(Debug, Serialize, Deserialize)]
28pub struct Optimizer {
29 manifold: LearnedManifold,
31
32 head: Array1<u32>,
34 tail: Array1<u32>,
35 epochs_per_sample: Array1<f64>,
36
37 embedding: Array2<f32>,
39
40 epoch_of_next_sample: Array1<f64>,
42 epoch_of_next_negative_sample: Array1<f64>,
43 epochs_per_negative_sample: Array1<f64>,
44
45 current_epoch: usize,
47 total_epochs: usize,
48
49 gamma: f32,
51 initial_alpha: f32,
52 negative_sample_rate: f64,
53
54 metric_type: MetricType,
56}
57
58impl Optimizer {
59 pub fn new(
75 manifold: LearnedManifold,
76 init: Array2<f32>,
77 total_epochs: usize,
78 opt_params: &UmapConfig,
79 metric_type: MetricType,
80 ) -> Self {
81 let gamma = opt_params.optimization.repulsion_strength;
82 let initial_alpha = opt_params.optimization.learning_rate;
83 let negative_sample_rate = opt_params.optimization.negative_sample_rate;
84
85 let graph = &manifold.graph;
86 let n_samples = graph.shape().0;
87
88 let started = Instant::now();
90 let max_val = graph
91 .data()
92 .par_iter()
93 .copied()
94 .reduce(|| 0.0f32, |a, b| a.max(b));
95
96 let default_epochs = if n_samples <= 10000 { 500 } else { 200 };
97 let threshold_epochs = if total_epochs > 10 {
98 total_epochs
99 } else {
100 default_epochs
101 };
102 let threshold = max_val / threshold_epochs as f32;
103 info!(
104 duration_ms = started.elapsed().as_millis(),
105 max_val, threshold, "optimizer threshold computed"
106 );
107
108 let started = Instant::now();
110 let row_counts: Vec<usize> = (0..n_samples)
111 .into_par_iter()
112 .map(|row| {
113 let row_start = graph.indptr().index(row);
114 let row_end = graph.indptr().index(row + 1);
115 let row_data = &graph.data()[row_start..row_end];
116 row_data.iter().filter(|&&v| v >= threshold).count()
117 })
118 .collect();
119
120 let mut edge_offsets: Vec<usize> = Vec::with_capacity(n_samples + 1);
122 edge_offsets.push(0);
123 let mut total_edges = 0usize;
124 for &count in &row_counts {
125 total_edges += count;
126 edge_offsets.push(total_edges);
127 }
128 info!(
129 duration_ms = started.elapsed().as_millis(),
130 total_edges, "optimizer edge filtering complete"
131 );
132
133 let started = Instant::now();
135 let head_vec = ParallelVec::new(vec![0u32; total_edges]);
136 let tail_vec = ParallelVec::new(vec![0u32; total_edges]);
137 let weights_vec = ParallelVec::new(vec![0.0f32; total_edges]);
138
139 (0..n_samples).into_par_iter().for_each(|row| {
140 let row_start = graph.indptr().index(row);
141 let row_end = graph.indptr().index(row + 1);
142 let row_indices = &graph.indices()[row_start..row_end];
143 let row_data = &graph.data()[row_start..row_end];
144
145 let out_start = edge_offsets[row];
146 let mut offset = 0;
147
148 for (&col, &val) in row_indices.iter().zip(row_data) {
149 if val >= threshold {
150 unsafe {
152 head_vec.write(out_start + offset, row as u32);
153 tail_vec.write(out_start + offset, col);
154 weights_vec.write(out_start + offset, val);
155 }
156 offset += 1;
157 }
158 }
159 });
160
161 let head = head_vec.into_inner();
162 let tail = tail_vec.into_inner();
163 let weights = weights_vec.into_inner();
164 info!(
165 duration_ms = started.elapsed().as_millis(),
166 "optimizer edge extraction complete"
167 );
168
169 let started = Instant::now();
171 let weights_array = Array1::from(weights);
172 let epochs_per_sample = make_epochs_per_sample(&weights_array.view(), total_epochs);
173
174 let head = Array1::from(head);
175 let tail = Array1::from(tail);
176 info!(
177 duration_ms = started.elapsed().as_millis(),
178 "optimizer epochs_per_sample complete"
179 );
180
181 let started = Instant::now();
184 let mut embedding = init;
185 let n_rows = embedding.shape()[0];
186 let n_dims = embedding.shape()[1];
187
188 let (mins, maxs) = (0..n_rows)
190 .into_par_iter()
191 .fold(
192 || (vec![f32::INFINITY; n_dims], vec![f32::NEG_INFINITY; n_dims]),
193 |(mut mins, mut maxs), i| {
194 let row = embedding.row(i);
195 for (d, &v) in row.iter().enumerate() {
196 mins[d] = mins[d].min(v);
197 maxs[d] = maxs[d].max(v);
198 }
199 (mins, maxs)
200 },
201 )
202 .reduce(
203 || (vec![f32::INFINITY; n_dims], vec![f32::NEG_INFINITY; n_dims]),
204 |(mut mins1, mut maxs1), (mins2, maxs2)| {
205 for d in 0..mins1.len() {
206 mins1[d] = mins1[d].min(mins2[d]);
207 maxs1[d] = maxs1[d].max(maxs2[d]);
208 }
209 (mins1, maxs1)
210 },
211 );
212
213 let scales: Vec<f32> = mins
215 .iter()
216 .zip(&maxs)
217 .map(|(&min, &max)| {
218 let range = max - min;
219 if range > 0.0 { 10.0 / range } else { 0.0 }
220 })
221 .collect();
222
223 let flat = embedding.as_slice_mut().unwrap();
225 flat.par_iter_mut().enumerate().for_each(|(idx, v)| {
226 let d = idx % n_dims;
227 if scales[d] > 0.0 {
228 *v = (*v - mins[d]) * scales[d];
229 }
230 });
231 info!(
232 duration_ms = started.elapsed().as_millis(),
233 "optimizer embedding normalization complete"
234 );
235
236 let started = Instant::now();
239 let neg_rate = negative_sample_rate as f64;
240 let eps_slice = epochs_per_sample.as_slice().unwrap();
241
242 let epoch_of_next_sample = Array1::from(eps_slice.par_iter().copied().collect::<Vec<_>>());
243 info!(
244 duration_ms = started.elapsed().as_millis(),
245 "optimizer epoch_of_next_sample complete"
246 );
247
248 let started = Instant::now();
249 let epochs_per_negative_sample = Array1::from(
250 eps_slice
251 .par_iter()
252 .map(|&eps| eps / neg_rate)
253 .collect::<Vec<_>>(),
254 );
255 info!(
256 duration_ms = started.elapsed().as_millis(),
257 "optimizer epochs_per_negative_sample complete"
258 );
259
260 let started = Instant::now();
261 let epoch_of_next_negative_sample = Array1::from(
262 epochs_per_negative_sample
263 .as_slice()
264 .unwrap()
265 .par_iter()
266 .copied()
267 .collect::<Vec<_>>(),
268 );
269 info!(
270 duration_ms = started.elapsed().as_millis(),
271 "optimizer epoch_of_next_negative_sample complete"
272 );
273
274 Self {
275 manifold,
276 head,
277 tail,
278 epochs_per_sample,
279 embedding,
280 epoch_of_next_sample,
281 epoch_of_next_negative_sample,
282 epochs_per_negative_sample,
283 current_epoch: 0,
284 total_epochs,
285 gamma,
286 initial_alpha,
287 negative_sample_rate: negative_sample_rate as f64,
288 metric_type,
289 }
290 }
291
292 pub fn step_epochs(&mut self, n: usize, output_metric: &dyn Metric) {
298 assert!(
299 self.current_epoch + n <= self.total_epochs,
300 "Cannot step {} epochs: would exceed total_epochs {} (current: {})",
301 n,
302 self.total_epochs,
303 self.current_epoch
304 );
305
306 let start_epoch = self.current_epoch;
307 let end_epoch = self.current_epoch + n;
308
309 let n_vertices = self.manifold.n_vertices;
310 let a = self.manifold.a;
311 let b = self.manifold.b;
312
313 let mut embedding_copy = self.embedding.clone();
315
316 for epoch in start_epoch..end_epoch {
317 let alpha = self.initial_alpha * (1.0 - (epoch as f32 / self.total_epochs as f32));
318
319 match self.metric_type {
320 MetricType::Euclidean => {
321 optimize_layout_euclidean_single_epoch_stateful(
323 &mut self.embedding.view_mut(),
324 &mut embedding_copy.view_mut(),
325 &self.head.view(),
326 &self.tail.view(),
327 n_vertices,
328 &self.epochs_per_sample.view(),
329 a,
330 b,
331 self.gamma,
332 alpha,
333 &mut self.epochs_per_negative_sample,
334 &mut self.epoch_of_next_sample,
335 &mut self.epoch_of_next_negative_sample,
336 epoch,
337 true, true, );
340 }
341 MetricType::Generic => {
342 optimize_layout_generic_single_epoch_stateful(
344 &mut self.embedding.view_mut(),
345 &mut embedding_copy.view_mut(),
346 &self.head.view(),
347 &self.tail.view(),
348 n_vertices,
349 &self.epochs_per_sample.view(),
350 a,
351 b,
352 self.gamma,
353 alpha,
354 &mut self.epochs_per_negative_sample,
355 &mut self.epoch_of_next_sample,
356 &mut self.epoch_of_next_negative_sample,
357 epoch,
358 true, output_metric,
360 );
361 }
362 }
363 }
364
365 self.current_epoch = end_epoch;
366 }
367
368 pub fn current_epoch(&self) -> usize {
370 self.current_epoch
371 }
372
373 pub fn total_epochs(&self) -> usize {
375 self.total_epochs
376 }
377
378 pub fn remaining_epochs(&self) -> usize {
380 self.total_epochs - self.current_epoch
381 }
382
383 pub fn embedding(&self) -> ArrayView2<'_, f32> {
385 self.embedding.view()
386 }
387
388 pub fn manifold(&self) -> &LearnedManifold {
390 &self.manifold
391 }
392
393 pub fn into_fitted(self, config: UmapConfig) -> FittedUmap {
398 FittedUmap {
399 manifold: self.manifold,
400 embedding: self.embedding,
401 config,
402 }
403 }
404}