use crate::backend::AutodiffBackend;
#[cfg(feature = "plotters")]
use crate::chart;
use crate::{serialize, train, utils, UMAP};
#[cfg(feature = "plotters")]
pub use chart::{chart_tensor, chart_vector};
use crossbeam_channel::unbounded;
use num::Float;
pub use crate::{
GraphParams, ManifoldParams, Metric, OptimizationParams, UmapConfig,
};
pub use crate::FittedUmap as FittedUmapExport;
pub use train::{TrainingConfig, TrainingConfigBuilder};
pub use utils::generate_test_data;
pub fn umap<B: AutodiffBackend, F: Float>(data: Vec<Vec<F>>) -> UMAP<B>
where
F: num::FromPrimitive + burn::tensor::Element,
{
let (exit_tx, exit_rx) = unbounded();
ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel."))
.expect("Error setting Ctrl-C handler");
let output_size = 2;
let device = Default::default();
let model = UMAP::<B>::fit(data, device, output_size, exit_rx);
model
}
pub fn umap_size<B: AutodiffBackend, F: Float>(data: Vec<Vec<F>>, output_size: usize) -> UMAP<B>
where
F: num::FromPrimitive + burn::tensor::Element,
{
let (exit_tx, exit_rx) = unbounded();
ctrlc::set_handler(move || exit_tx.send(()).expect("Could not send signal on channel."))
.expect("Error setting Ctrl-C handler");
let device = Default::default();
let model = UMAP::<B>::fit(data, device, output_size, exit_rx);
model
}