use crate::core::ml::features::InstanceFeatures;
pub use crate::core::vrp::types::SolverHyperparams;
use anyhow::{Context, Result};
use candle_core::{Device, DType, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};
use std::path::Path;
/// Width of the model input — assumed to match the length of the vector
/// returned by `InstanceFeatures::to_vector` (TODO confirm against that impl).
const NUM_FEATURES: usize = 28;
/// Width of the single hidden layer of the MLP.
const HIDDEN1: usize = 64;
/// One output head per predicted hyperparameter (see `predict`: max_iterations,
/// temperature, tabu_tenure, cooling_rate, neighbourhood_radius).
const NUM_OUTPUTS: usize = 5;
/// Two-layer MLP (NUM_FEATURES -> HIDDEN1 -> NUM_OUTPUTS) that maps instance
/// features to solver hyperparameters. Weights are loaded from a safetensors
/// file; see `from_file`.
pub struct HyperparamPredictor {
    lin1: Linear,   // input -> hidden layer
    lin2: Linear,   // hidden -> output heads
    device: Device, // device the weights were loaded on; inputs are built on it too
}
impl HyperparamPredictor {
    /// Loads MLP weights from `path` (safetensors format) onto the best
    /// available device and wires up the two linear layers.
    ///
    /// # Errors
    /// Fails when the file cannot be read or parsed, or when the expected
    /// tensors under the `lin1` / `lin2` prefixes are missing or mis-shaped.
    pub fn from_file(path: &Path) -> Result<Self> {
        let device = crate::core::ml::best_device()?;
        let weights = candle_core::safetensors::load(path, &device)
            .with_context(|| format!("Failed to load safetensors from {}", path.display()))?;
        let builder = VarBuilder::from_tensors(weights, DType::F32, &device);
        Ok(Self {
            lin1: linear(NUM_FEATURES, HIDDEN1, builder.pp("lin1"))?,
            lin2: linear(HIDDEN1, NUM_OUTPUTS, builder.pp("lin2"))?,
            device,
        })
    }

    /// Runs a forward pass on a single instance and rescales each raw output
    /// (clamped to [0, 1]) into its hyperparameter range.
    ///
    /// # Errors
    /// Propagates any tensor-construction or forward-pass failure.
    pub fn predict(&self, features: &InstanceFeatures) -> Result<SolverHyperparams> {
        // Batch of one: shape (1, NUM_FEATURES).
        let batch = Tensor::from_vec(features.to_vector(), (1, NUM_FEATURES), &self.device)?;
        let hidden = self.lin1.forward(&batch)?.relu()?;
        let raw: Vec<f32> = self.lin2.forward(&hidden)?.squeeze(0)?.to_vec1()?;
        // Affine rescale of head `i`: lo + clamp(raw[i], 0, 1) * span.
        let scaled = |i: usize, lo: f64, span: f64| lo + raw[i].clamp(0.0, 1.0) as f64 * span;
        Ok(SolverHyperparams {
            max_iterations: scaled(0, 100.0, 9900.0) as u32,       // 100 ..= 10_000
            temperature: scaled(1, 1.0, 999.0),                    // 1.0 ..= 1000.0
            tabu_tenure: scaled(2, 1.0, 49.0).round() as usize,    // 1 ..= 50
            cooling_rate: scaled(3, 0.8, 0.199),                   // 0.8 ..= 0.999
            neighbourhood_radius: scaled(4, 1.0, 19.0).round() as usize, // 1 ..= 20
            model_used: true,
        })
    }
}
/// Resolves the path of the bundled AutoML model file.
///
/// Resolution order:
/// 1. `<exe_dir>/models/automl.safetensors`, if it exists;
/// 2. `models/automl.safetensors` relative to the working directory, if it exists;
/// 3. the exe-relative candidate from (1) even though it is missing (callers
///    such as `predict_hyperparams` re-check `exists()` before loading), or the
///    cwd-relative path when the executable location cannot be determined.
pub fn default_model_path() -> std::path::PathBuf {
    // Compute the exe-relative candidate exactly once; the original code
    // rebuilt it a second time for the final fallback.
    let exe_candidate = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.join("models").join("automl.safetensors")));
    let cwd_candidate = std::path::PathBuf::from("models/automl.safetensors");
    match exe_candidate {
        Some(p) if p.exists() => p,
        Some(p) => {
            if cwd_candidate.exists() {
                cwd_candidate
            } else {
                p // last resort: exe-relative path, existence not guaranteed
            }
        }
        None => cwd_candidate,
    }
}
/// Predicts solver hyperparameters for `features`, falling back to the static
/// defaults whenever the model file is absent, fails to load, or fails at
/// inference time. Load/inference failures are logged as warnings, never
/// propagated.
pub fn predict_hyperparams(features: &InstanceFeatures) -> SolverHyperparams {
    let path = default_model_path();
    // Guard clause: no model on disk means no attempt is made at all.
    if !path.exists() {
        return SolverHyperparams::default_fallback();
    }
    match HyperparamPredictor::from_file(&path) {
        Err(e) => {
            tracing::warn!("Failed to load hyperparam predictor: {}. Falling back to defaults.", e);
            SolverHyperparams::default_fallback()
        }
        Ok(model) => match model.predict(features) {
            Ok(params) => params,
            Err(e) => {
                tracing::warn!("Hyperparam predictor inference failed: {}. Falling back to defaults.", e);
                SolverHyperparams::default_fallback()
            }
        },
    }
}
impl SolverHyperparams {
    /// Static default hyperparameters used when no trained model is available
    /// or when model inference fails. `model_used` is false so callers can
    /// distinguish these defaults from model predictions.
    pub fn default_fallback() -> Self {
        Self {
            max_iterations: 1000,
            temperature: 100.0,
            tabu_tenure: 7,
            cooling_rate: 0.995,
            neighbourhood_radius: 3,
            model_used: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::vrp::test_utils::{make_input, make_stop};

    /// Whether a trained model is present or not, the returned parameters must
    /// lie within the documented ranges (the fallback values do too).
    #[test]
    fn test_predict_hyperparams_fallback() {
        let depot = make_stop(0.0, 0.0, "depot");
        let first = make_stop(1.0, 0.0, "a");
        let second = make_stop(0.0, 1.0, "b");
        let input = make_input(vec![depot, first, second], 1);
        let params = predict_hyperparams(&InstanceFeatures::from_input(&input));
        assert!(params.max_iterations >= 100);
        assert!(params.temperature >= 1.0);
        assert!(params.tabu_tenure >= 1);
        assert!(params.cooling_rate >= 0.8 && params.cooling_rate <= 1.0);
        assert!(params.neighbourhood_radius >= 1);
    }
}