use anyhow::{Context, Result};
use ort::session::Session;
use ort::value::Value;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Input for a single neural routing solve.
///
/// All numeric values are narrowed to `f32` before being handed to the model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NeuralRouteRequest {
    /// Two-component coordinates, one entry per node; component order
    /// (x/y vs. lat/lon) is whatever the model was trained on — confirm.
    pub locations: Vec<[f64; 2]>,
    /// Per-node demand, expected to line up one-to-one with `locations`.
    pub demands: Vec<f64>,
    /// Vehicle capacity, sent to the model as a single scalar.
    pub capacity: f64,
    /// Filesystem path of the ONNX model loaded by `solve_neural`.
    pub model_path: String,
}
/// Result of a neural routing solve.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct NeuralRouteResponse {
    /// Node indices (into the request's `locations`) in the order produced by
    /// the model's `actions` output, with every index `0` removed
    /// (presumably the depot — confirm against the model).
    pub visit_sequence: Vec<usize>,
    /// Route cost. NOTE(review): not computed by the solver; currently always `0.0`.
    pub total_cost: f64,
    /// Wall-clock duration of the solve, in milliseconds.
    pub solve_time_ms: u64,
}
/// Holds a loaded ONNX Runtime session so one model can serve many solves
/// without being reloaded from disk each time.
pub struct NeuralInferenceEngine {
    /// Session created from the model file passed to `new`.
    session: Session,
}
impl NeuralInferenceEngine {
pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
let session = Session::builder()?
.commit_from_file(model_path)
.context("Failed to load ONNX model")?;
Ok(Self { session })
}
pub fn solve(&mut self, req: &NeuralRouteRequest) -> Result<NeuralRouteResponse> {
let start_time = std::time::Instant::now();
let n = req.locations.len();
let mut locs_data = Vec::with_capacity(n * 2);
for loc in &req.locations {
locs_data.push(loc[0] as f32);
locs_data.push(loc[1] as f32);
}
let mut demands_data = Vec::with_capacity(n);
for &demand in &req.demands {
demands_data.push(demand as f32);
}
let capacity_data = vec![req.capacity as f32];
let locs_value = Value::from_array(([1, n, 2], locs_data))?;
let demands_value = Value::from_array(([1, n, 1], demands_data))?;
let capacity_value = Value::from_array(([1, 1], capacity_data))?;
let inputs = ort::inputs![
"locs" => locs_value,
"demand" => demands_value,
"capacity" => capacity_value,
];
let outputs = self.session.run(inputs)?;
let output_tensor_value = outputs
.get("actions")
.context("Model output 'actions' not found")?;
let (_shape, data) = output_tensor_value.try_extract_tensor::<i64>()?;
let visit_sequence: Vec<usize> = data
.iter()
.map(|&id| id as usize)
.filter(|&id| id != 0) .collect();
Ok(NeuralRouteResponse {
visit_sequence,
total_cost: 0.0,
solve_time_ms: start_time.elapsed().as_millis() as u64,
})
}
}
/// One-shot convenience: loads the model named by `req.model_path`, then solves `req`.
///
/// The model is loaded from disk on every call; for repeated solves, build a
/// `NeuralInferenceEngine` once and call `solve` on it instead.
///
/// # Errors
/// Propagates any error from loading the model or from the solve itself.
pub fn solve_neural(req: &NeuralRouteRequest) -> Result<NeuralRouteResponse> {
    NeuralInferenceEngine::new(&req.model_path)?.solve(req)
}