Skip to main content

rlx_umap/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Parametric UMAP on RLX — k-NN building blocks + full [`Umap::fit`](umap::Umap::fit) API matching [fast-umap](https://github.com/eugenehp/fast-umap).
17//!
18//! - [`Umap::fit`] / [`fit_with_progress`](umap::Umap::fit_with_progress) — sparse cross-entropy training
19//! - [`FittedUmap::transform`] — inference with training z-score stats
//! - [`FittedUmap::save`] / [`FittedUmap::load`] — safetensors or GGUF (`.ruama` is supported for loading legacy files only)
21//!
//! Call [`register`] (available with the `cpu` feature) once per process, before executing any Session that uses the UMAP ops.
23//!
24//! ## Quick start
25//!
26//! ```ignore
27//! rlx_umap::register();
28//!
29//! use rlx_ir::{DType, Graph, Shape};
30//! use rlx_umap::knn_indices_and_distances;
31//!
32//! let mut g = Graph::new("knn");
33//! let pairwise = g.input("pairwise", Shape::new(&[64, 64], DType::F32));
34//! let (idx, dist) = knn_indices_and_distances(&mut g, pairwise, 5);
35//! g.set_outputs(vec![idx, dist]);
36//! ```
37
38pub mod knn;
39pub mod knn_attrs;
40pub mod pack;
41pub mod pairwise;
42pub mod parity;
43
44#[cfg(feature = "cpu")]
45pub mod ops;
46
47#[cfg(feature = "graph")]
48pub mod graph;
49
50#[cfg(feature = "bench")]
51pub mod session;
52
53#[cfg(feature = "full")]
54pub mod adam;
55#[cfg(feature = "full")]
56pub mod config;
57#[cfg(feature = "full")]
58pub mod data;
59#[cfg(feature = "full")]
60pub mod encoder;
61#[cfg(feature = "full")]
62pub mod fitted;
63#[cfg(feature = "full")]
64pub mod interrupt;
65#[cfg(feature = "full")]
66pub mod model;
67#[cfg(feature = "full")]
68pub mod model_io;
69#[cfg(feature = "nn-descent")]
70pub mod nn_descent;
71#[cfg(feature = "pca")]
72pub mod pca;
73#[cfg(feature = "full")]
74pub mod prelude;
75#[cfg(feature = "full")]
76pub mod serialize;
77#[cfg(feature = "full")]
78pub mod train;
79#[cfg(feature = "full")]
80pub mod training;
81#[cfg(feature = "full")]
82pub mod umap;
83#[cfg(feature = "full")]
84pub mod utils;
85#[cfg(feature = "full")]
86pub mod weights;
87
88#[cfg(all(feature = "optim", feature = "full"))]
89pub mod optim_adapter;
90
91#[cfg(all(feature = "metal", target_os = "macos"))]
92mod metal_kernels;
93
94#[cfg(all(feature = "mlx", target_os = "macos"))]
95mod mlx_kernels;
96
97pub use knn::{knn_backward_pairwise, knn_forward_packed};
98pub use knn_attrs::KnnAttrs;
99pub use pack::unpack_knn_packed;
100pub use pairwise::{cosine_pairwise_reference, euclidean_pairwise_reference};
101pub use parity::{KnnParityReport, compare_knn, max_pairwise_error};
102
103#[cfg(feature = "cpu")]
104pub use ops::{UMAP_KNN, UMAP_KNN_BWD, register_umap_ops};
105
106#[cfg(feature = "graph")]
107pub use graph::{
108    cosine_knn_graph, cosine_knn_packed_graph, knn_graph, knn_indices_and_distances,
109    pairwise_cosine_graph, pairwise_euclidean_graph, split_knn_packed,
110};
111
/// Registers the UMAP custom ops (IR definitions and CPU kernels).
///
/// Convenience alias that forwards to [`register_umap_ops`]. Call it once per
/// process before executing a Session that uses these ops.
#[cfg(feature = "cpu")]
pub fn register() {
    // All registration logic lives in the `ops` module; this is only a shorter entry point.
    ops::register_umap_ops();
}
117
118#[cfg(feature = "full")]
119pub use config::UmapConfig;
120#[cfg(feature = "full")]
121pub use data::{load_csv, load_f64_matrix, load_synthetic, write_embedding_csv};
122#[cfg(feature = "full")]
123pub use fitted::FittedUmap;
124#[cfg(feature = "full")]
125pub use model_io::{
126    EXT_GGUF, EXT_RUAMA, EXT_SAFETENSORS, MODEL_EXT, format_from_path, model_path,
127    model_path_with_ext, weight_shapes,
128};
129#[cfg(feature = "full")]
130pub use rlx_driver::Device;
131#[cfg(feature = "full")]
132pub use serialize::{
133    LoadedModel, ModelMetadata, SaveBundle, load_model, load_weights, save_model, save_weights,
134};
135#[cfg(feature = "full")]
136pub use train::EpochProgress;
137#[cfg(feature = "full")]
138pub use training::{FitOptions, TrainResult, fit, fit_with_progress, train_only};
139#[cfg(feature = "full")]
140pub use umap::Umap;
141#[cfg(feature = "full")]
142pub use utils::NormStats;
143#[cfg(feature = "full")]
144pub use weights::WeightStore;