use anyhow::Result;
use std::os::raw::{c_double, c_int, c_uint};
/// A routing graph flattened into C-compatible buffers for the FFI boundary.
///
/// `FlatGraph` frees `nodes` and `edges` in its `Drop` impl, so it must only
/// ever hold buffers allocated by Rust with the exact lengths described below.
#[repr(C)]
#[derive(Debug)]
pub struct FlatGraph {
    /// Interleaved coordinates `[lat0, lon0, lat1, lon1, ...]`, `2 * node_count` values.
    pub nodes: *mut c_double,
    /// Number of nodes (half the length of `nodes`).
    pub node_count: c_uint,
    /// Interleaved undirected edge endpoints `[from0, to0, from1, to1, ...]`,
    /// `2 * edge_count` node indices.
    pub edges: *mut c_uint,
    /// Number of edges (half the length of `edges`).
    pub edge_count: c_uint,
    /// Index of the node a circuit should start from.
    pub start_node: c_uint,
}

// SAFETY: `FlatGraph` holds no thread-affine state. Its buffers are only
// exposed as raw pointers, so any dereference from any thread already
// requires `unsafe` code that must uphold its own aliasing rules. Sharing
// `&FlatGraph` only permits reading the pointer values and counts.
unsafe impl Send for FlatGraph {}
unsafe impl Sync for FlatGraph {}
/// Result of a circuit computation, laid out for the FFI boundary.
///
/// The struct does not free `circuit` on drop. Buffers produced by the
/// Rust fallback are released with `free_verified_result`.
#[repr(C)]
#[derive(Debug)]
pub struct VerifiedResult {
    /// Node indices in visit order, `circuit_length` entries (may be null when empty).
    pub circuit: *mut c_uint,
    /// Number of entries behind `circuit`.
    pub circuit_length: c_uint,
    /// Length of the circuit. The Rust fallback reports metres (haversine,
    /// R = 6 371 000 m); units of other producers are unverified.
    pub total_distance: c_double,
    /// Non-zero on success, `0` on failure.
    pub success: c_int,
}

// SAFETY: the only pointer is exposed raw; dereferencing it from any thread
// already requires `unsafe`, and sharing `&VerifiedResult` only permits
// reading the pointer value and the scalar fields.
unsafe impl Send for VerifiedResult {}
unsafe impl Sync for VerifiedResult {}
/// Entry point for circuit optimisation through the Lean 4 runtime, which is
/// not linked yet. Optimisation falls back to a pure-Rust implementation.
#[derive(Debug)]
pub struct Lean4Bridge {
    /// Whether the Lean 4 runtime can be called; currently always `false`.
    lean4_available: bool,
}
impl Lean4Bridge {
    /// Creates a new bridge.
    ///
    /// The Lean 4 runtime is not linked yet, so the bridge always reports itself
    /// as unavailable. With the `lean4` feature enabled, an informational log
    /// line records that fact.
    ///
    /// # Errors
    /// Currently never returns `Err`.
    pub fn new() -> Result<Self> {
        #[cfg(feature = "lean4")]
        {
            tracing::info!("Lean 4 feature enabled, but runtime not yet linked");
        }
        Ok(Self {
            lean4_available: false,
        })
    }

    /// Returns whether the Lean 4 runtime can be called (currently always `false`).
    pub fn is_available(&self) -> bool {
        self.lean4_available
    }

    /// Computes a circuit over a flattened graph.
    ///
    /// Uses the Lean 4 runtime when available; otherwise logs a warning and uses
    /// the pure-Rust fallback.
    ///
    /// # Safety
    /// `nodes` must be valid for reads of `2 * node_count` values, and `edges`
    /// for reads of `2 * edge_count` values, for the duration of the call.
    ///
    /// # Errors
    /// Fails when the Lean call reports `success == 0`, or when the fallback
    /// rejects the input.
    #[cfg(feature = "lean4")]
    pub unsafe fn optimize_lean4(
        &self,
        nodes: *const c_double,
        node_count: c_uint,
        edges: *const c_uint,
        edge_count: c_uint,
        start_node: c_uint,
    ) -> Result<VerifiedResult> {
        if self.lean4_available {
            // SAFETY: the pointer/length contract is forwarded from this function's caller.
            let result = unsafe {
                self.call_optimize_eulerian(nodes, node_count, edges, edge_count, start_node)
            };
            if result.success == 0 {
                return Err(anyhow::anyhow!("Lean 4 optimization failed"));
            }
            Ok(result)
        } else {
            tracing::warn!("Lean 4 runtime not available, using Rust fallback");
            // SAFETY: the pointer/length contract is forwarded from this function's caller.
            unsafe { self.optimize_rust_fallback(nodes, node_count, edges, edge_count, start_node) }
        }
    }

    /// Placeholder for the Lean 4 call.
    ///
    /// It always reports failure (`success == 0`) with a null, empty circuit.
    #[cfg(feature = "lean4")]
    unsafe fn call_optimize_eulerian(
        &self,
        _nodes: *const c_double,
        _node_count: c_uint,
        _edges: *const c_uint,
        _edge_count: c_uint,
        _start_node: c_uint,
    ) -> VerifiedResult {
        VerifiedResult {
            circuit: std::ptr::null_mut(),
            circuit_length: 0,
            total_distance: 0.0,
            success: 0,
        }
    }

    /// Pure-Rust circuit search (iterative Hierholzer) over a flattened graph.
    ///
    /// On success, returns `success == 1` and the visited node indices as a
    /// Rust-allocated boxed slice (release it with `free_verified_result`). The
    /// summed haversine length of the walk is reported in metres.
    ///
    /// Edges whose endpoints are not valid node indices are ignored. Undirected
    /// edges are identified by their unordered endpoint pair, so parallel edges
    /// between the same two nodes are traversed once. The graph is not checked
    /// for being Eulerian or connected; edges unreachable from `start_node` are
    /// never visited.
    ///
    /// # Safety
    /// `nodes` must be valid for reads of `2 * node_count` `c_double`s, and
    /// `edges` for reads of `2 * edge_count` `c_uint`s, for the duration of the
    /// call.
    ///
    /// # Errors
    /// Fails on an empty graph, a null buffer pointer, or a `start_node` that is
    /// not a valid node index.
    // Only reachable through the `lean4` feature (and tests); silence the
    // dead-code warning in builds without it.
    #[cfg_attr(not(feature = "lean4"), allow(dead_code))]
    unsafe fn optimize_rust_fallback(
        &self,
        nodes: *const c_double,
        node_count: c_uint,
        edges: *const c_uint,
        edge_count: c_uint,
        start_node: c_uint,
    ) -> Result<VerifiedResult> {
        let nc = node_count as usize;
        let ec = edge_count as usize;
        if nc == 0 || ec == 0 {
            return Err(anyhow::anyhow!("Empty graph"));
        }
        if nodes.is_null() || edges.is_null() {
            return Err(anyhow::anyhow!("Null graph buffer"));
        }
        let start = start_node as usize;
        if start >= nc {
            // Previously this indexed `adj[start]` and panicked.
            return Err(anyhow::anyhow!(
                "Start node {} out of range (node count {})",
                start,
                nc
            ));
        }

        // SAFETY: both pointers are non-null (checked above) and the caller
        // guarantees they are readable for `2 * count` elements.
        let node_coords = unsafe { std::slice::from_raw_parts(nodes, nc * 2) };
        // SAFETY: as above.
        let edge_indices = unsafe { std::slice::from_raw_parts(edges, ec * 2) };

        let adj = Self::build_adjacency(nc, edge_indices);
        let circuit = Self::hierholzer_circuit(&adj, start);
        let total_distance = Self::circuit_distance(&circuit, node_coords);

        if circuit.len() > c_uint::MAX as usize {
            return Err(anyhow::anyhow!("Circuit too long for FFI result"));
        }
        // Every vertex is < nc <= c_uint::MAX, so the narrowing cast is lossless.
        // A boxed slice guarantees capacity == length, which
        // `free_verified_result` relies on when reconstructing the allocation.
        let circuit: Box<[c_uint]> = circuit.iter().map(|&v| v as c_uint).collect();
        let circuit_length = circuit.len() as c_uint;
        Ok(VerifiedResult {
            circuit: Box::into_raw(circuit) as *mut c_uint,
            circuit_length,
            total_distance,
            success: 1,
        })
    }

    /// Builds an undirected adjacency list, skipping edges with out-of-range endpoints.
    fn build_adjacency(node_count: usize, edge_indices: &[c_uint]) -> Vec<Vec<usize>> {
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); node_count];
        for pair in edge_indices.chunks_exact(2) {
            let (from, to) = (pair[0] as usize, pair[1] as usize);
            if from < node_count && to < node_count {
                adj[from].push(to);
                adj[to].push(from);
            }
        }
        adj
    }

    /// Iterative Hierholzer walk from `start`; returns the vertex sequence in visit order.
    fn hierholzer_circuit(adj: &[Vec<usize>], start: usize) -> Vec<usize> {
        let mut used_edges: std::collections::HashSet<(usize, usize)> =
            std::collections::HashSet::new();
        // Per-vertex scan position. An edge never becomes unused again, so
        // skipping past examined entries yields the same "first unused
        // neighbour" as rescanning from index 0, but each entry is examined once.
        let mut cursor = vec![0_usize; adj.len()];
        let mut stack = vec![start];
        let mut circuit = Vec::new();
        while let Some(&v) = stack.last() {
            let mut next = None;
            while let Some(&u) = adj[v].get(cursor[v]) {
                cursor[v] += 1;
                // `insert` returns true only if the edge had not been used yet.
                if used_edges.insert((v.min(u), v.max(u))) {
                    next = Some(u);
                    break;
                }
            }
            match next {
                Some(u) => stack.push(u),
                None => {
                    stack.pop();
                    circuit.push(v);
                }
            }
        }
        circuit.reverse();
        circuit
    }

    /// Sums the haversine length (metres) of consecutive legs of `circuit`.
    fn circuit_distance(circuit: &[usize], node_coords: &[c_double]) -> f64 {
        let coords_of = |v: usize| node_coords.get(v * 2..v * 2 + 2);
        circuit.windows(2).fold(0.0_f64, |total, leg| {
            match (coords_of(leg[0]), coords_of(leg[1])) {
                (Some(a), Some(b)) => total + haversine_distance(a[0], a[1], b[0], b[1]),
                _ => total,
            }
        })
    }
}
// No-op destructor: `Lean4Bridge` currently holds only a `bool` flag and owns
// no resources.
// NOTE(review): presumably a placeholder for tearing down the Lean 4 runtime
// once it is linked — confirm, or remove this impl if it stays empty.
impl Drop for Lean4Bridge {
    fn drop(&mut self) {
        // Intentionally empty: nothing to release yet.
    }
}
/// Conversion between an in-memory route graph and its FFI representation.
pub trait FlattenForFFI {
    /// Copies the graph into a [`FlatGraph`] suitable for passing across FFI.
    fn flatten_for_ffi(&self) -> FlatGraph;

    /// Converts a circuit returned across FFI back into an optimisation result.
    ///
    /// # Errors
    /// Implementations report failed or unusable results as `Err`.
    fn from_verified_result(
        &self,
        result: VerifiedResult,
    ) -> Result<super::types::OptimizationResult>;
}
impl FlattenForFFI for super::RouteOptimizer {
    /// Flattens the optimizer's nodes and way segments into a [`FlatGraph`].
    ///
    /// * `nodes` holds `[lat, lon]` pairs in `self.nodes` order.
    /// * Each consecutive node pair of every way becomes one edge. Segments that
    ///   reference an unknown node id are skipped.
    /// * `start_node` is the node closest (haversine) to the depot when both
    ///   depot coordinates are set; otherwise it is `0`.
    ///
    /// Both buffers are boxed slices of exactly twice the respective count.
    /// `FlatGraph`'s `Drop` impl frees them using that length.
    ///
    /// # Panics
    /// Panics if the node or edge count does not fit in `c_uint`. A truncated
    /// count would make the buffers be freed with the wrong length.
    fn flatten_for_ffi(&self) -> FlatGraph {
        let nc = self.nodes.len();
        assert!(nc <= c_uint::MAX as usize, "node count exceeds c_uint::MAX");

        let mut node_coords: Vec<c_double> = Vec::with_capacity(nc * 2);
        for node in &self.nodes {
            node_coords.push(node.lat);
            node_coords.push(node.lon);
        }

        // Borrow the ids instead of cloning every String. As with repeated
        // inserts, a later duplicate id wins.
        let node_index_map: std::collections::HashMap<&str, usize> = self
            .nodes
            .iter()
            .enumerate()
            .map(|(i, node)| (node.id.as_str(), i))
            .collect();

        let mut edge_indices: Vec<c_uint> = Vec::new();
        for way in &self.ways {
            for i in 1..way.nodes.len() {
                let endpoints = (
                    node_index_map.get(way.nodes[i - 1].as_str()),
                    node_index_map.get(way.nodes[i].as_str()),
                );
                if let (Some(&from), Some(&to)) = endpoints {
                    // Indices are < nc, which fits in c_uint (asserted above).
                    edge_indices.push(from as c_uint);
                    edge_indices.push(to as c_uint);
                }
            }
        }
        let ec = edge_indices.len() / 2;
        assert!(ec <= c_uint::MAX as usize, "edge count exceeds c_uint::MAX");

        let start_node = if let (Some(lat), Some(lon)) = (self.depot_lat, self.depot_lon) {
            self.nodes
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    let da = haversine_distance(lat, lon, a.lat, a.lon);
                    let db = haversine_distance(lat, lon, b.lat, b.lon);
                    da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(i, _)| i as c_uint)
                .unwrap_or(0)
        } else {
            0
        };

        let node_coords_ptr = Box::into_raw(node_coords.into_boxed_slice()) as *mut c_double;
        let edge_indices_ptr = Box::into_raw(edge_indices.into_boxed_slice()) as *mut c_uint;
        FlatGraph {
            nodes: node_coords_ptr,
            node_count: nc as c_uint,
            edges: edge_indices_ptr,
            edge_count: ec as c_uint,
            start_node,
        }
    }

    /// Converts a [`VerifiedResult`] into an `OptimizationResult`.
    ///
    /// Circuit indices are mapped back to `self.nodes`; indices outside
    /// `0..self.nodes.len()` are skipped. The total distance is recomputed from
    /// the route with `RoutePoint::distance_to` and divided by 1000, presumably
    /// metres to kilometres (confirm against `distance_to`).
    /// `result.total_distance` is not used.
    ///
    /// Ownership: the circuit buffer is only read, never freed. `result` is
    /// taken by value and `VerifiedResult` is not `Copy`, so callers that must
    /// release the buffer need to keep `circuit`/`circuit_length` beforehand.
    ///
    /// # Errors
    /// Returns an error when `result.success == 0`.
    fn from_verified_result(&self, result: VerifiedResult) -> Result<super::types::OptimizationResult> {
        if result.success == 0 {
            return Err(anyhow::anyhow!("Optimization failed (verified result indicates failure)"));
        }

        let circuit_indices: &[c_uint] = if result.circuit.is_null() || result.circuit_length == 0 {
            &[]
        } else {
            // SAFETY: non-null, and the producer of `result` guarantees
            // `circuit_length` readable entries behind `circuit`.
            unsafe { std::slice::from_raw_parts(result.circuit, result.circuit_length as usize) }
        };

        let mut route: Vec<super::types::RoutePoint> = Vec::with_capacity(circuit_indices.len());
        route.extend(
            circuit_indices
                .iter()
                .filter_map(|&idx| self.nodes.get(idx as usize))
                .map(|node| super::types::RoutePoint::with_node_id(node.lat, node.lon, &node.id)),
        );

        let summed = route
            .windows(2)
            .fold(0.0_f64, |acc, leg| acc + leg[0].distance_to(&leg[1]));
        let total_distance = summed / 1000.0;

        // Key the message on whether any route points were produced. `success`
        // is already known to be non-zero here, so testing `success == 1`
        // rejected other truthy C values.
        let produced_route = !route.is_empty();
        let mut opt_result = super::types::OptimizationResult::new(route, total_distance);
        // NOTE(review): this message names the Rust fallback even if the
        // circuit came from Lean 4 — revisit once the runtime is linked.
        opt_result.message = if produced_route {
            "Verified optimization complete (Rust fallback)".to_string()
        } else {
            "Optimization failed or produced empty result".to_string()
        };
        opt_result.calculate_stats();
        Ok(opt_result)
    }
}
impl Drop for FlatGraph {
fn drop(&mut self) {
unsafe {
if !self.nodes.is_null() && self.node_count > 0 {
let _ = Vec::from_raw_parts(
self.nodes as *mut c_double,
(self.node_count * 2) as usize,
(self.node_count * 2) as usize,
);
}
if !self.edges.is_null() && self.edge_count > 0 {
let _ = Vec::from_raw_parts(
self.edges as *mut c_uint,
(self.edge_count * 2) as usize,
(self.edge_count * 2) as usize,
);
}
}
}
}
pub unsafe fn free_verified_result(result: VerifiedResult) {
if !result.circuit.is_null() && result.circuit_length > 0 {
let _ = Vec::from_raw_parts(
result.circuit as *mut c_uint,
result.circuit_length as usize,
result.circuit_length as usize,
);
}
}
/// Great-circle distance in metres between two `(lat, lon)` points given in
/// degrees, using the haversine formula on a sphere of radius 6 371 000 m.
///
/// NaN inputs propagate to a NaN result.
fn haversine_distance(lat1: f64, lon1: f64, lat2: f64, lon2: f64) -> f64 {
    const R: f64 = 6_371_000.0;
    let lat1_r = lat1.to_radians();
    let lat2_r = lat2.to_radians();
    let dlat = (lat2 - lat1).to_radians();
    let dlon = (lon2 - lon1).to_radians();
    let a = (dlat / 2.0).sin().powi(2)
        + lat1_r.cos() * lat2_r.cos() * (dlon / 2.0).sin().powi(2);
    // Rounding can push `a` a hair above 1 for near-antipodal points, which
    // would make `(1 - a).sqrt()` NaN. `clamp` keeps NaN as NaN.
    let a = a.clamp(0.0, 1.0);
    let c = 2.0 * a.sqrt().atan2((1.0 - a).sqrt());
    R * c
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::optimizer::{Node, RouteOptimizer, Way};

    /// Optimizer holding the two nodes shared by several tests.
    fn optimizer_with_two_nodes() -> RouteOptimizer {
        let mut optimizer = RouteOptimizer::new();
        optimizer.nodes.push(Node::new("n0", 45.5, -73.6));
        optimizer.nodes.push(Node::new("n1", 45.6, -73.7));
        optimizer
    }

    #[test]
    fn test_flat_graph_from_optimizer() {
        let mut optimizer = optimizer_with_two_nodes();
        optimizer
            .ways
            .push(Way::new("w1", vec!["n0".into(), "n1".into()]));

        let flat = optimizer.flatten_for_ffi();
        assert_eq!(flat.node_count, 2);
        assert_eq!(flat.edge_count, 1);
        assert!(!flat.nodes.is_null());
        assert!(!flat.edges.is_null());

        // SAFETY: flatten_for_ffi allocated 2 * node_count coordinates and
        // 2 * edge_count indices; `flat` owns them until the end of this test.
        let (coords, edges) = unsafe {
            (
                std::slice::from_raw_parts(flat.nodes, 4),
                std::slice::from_raw_parts(flat.edges, 2),
            )
        };
        let expected_coords: [c_double; 4] = [45.5, -73.6, 45.6, -73.7];
        let expected_edges: [c_uint; 2] = [0, 1];
        assert_eq!(coords, &expected_coords[..]);
        assert_eq!(edges, &expected_edges[..]);
    }

    #[test]
    fn test_lean4_bridge_creation() {
        let bridge = Lean4Bridge::new().expect("Lean4Bridge::new should succeed");
        assert!(!bridge.is_available());
    }

    #[test]
    fn test_verified_result_structure() {
        let result = VerifiedResult {
            circuit: std::ptr::null_mut(),
            circuit_length: 0,
            total_distance: 1000.0,
            success: 1,
        };
        assert_eq!((result.total_distance, result.success), (1000.0, 1));
    }

    #[test]
    fn test_from_verified_result() {
        let optimizer = optimizer_with_two_nodes();

        let indices: Box<[c_uint]> = vec![0, 1, 0].into_boxed_slice();
        let circuit_length = indices.len() as c_uint;
        let circuit = Box::into_raw(indices) as *mut c_uint;

        let opt_result = optimizer
            .from_verified_result(VerifiedResult {
                circuit,
                circuit_length,
                total_distance: 25000.0,
                success: 1,
            })
            .unwrap();
        assert_eq!(opt_result.route.len(), 3);
        assert_eq!(opt_result.route[0].node_id, Some("n0".to_string()));

        // SAFETY: `circuit` came from Box::into_raw on a slice of
        // `circuit_length` elements and has not been freed.
        let reclaimed = unsafe {
            Box::from_raw(std::ptr::slice_from_raw_parts_mut(
                circuit,
                circuit_length as usize,
            ))
        };
        drop(reclaimed);
    }

    #[test]
    fn test_rust_fallback_optimization() {
        let bridge = Lean4Bridge::new().unwrap();
        // Triangle 0-1-2-0.
        let node_coords: [c_double; 6] = [45.5, -73.6, 45.6, -73.7, 45.55, -73.65];
        let edge_indices: [c_uint; 6] = [0, 1, 1, 2, 2, 0];

        // SAFETY: both arrays hold exactly 2 * 3 elements and outlive the call.
        let outcome = unsafe {
            bridge.optimize_rust_fallback(node_coords.as_ptr(), 3, edge_indices.as_ptr(), 3, 0)
        };
        let verified = outcome.expect("fallback should succeed on a triangle");
        assert_eq!(verified.success, 1);
        assert!(verified.circuit_length > 0);

        // SAFETY: the circuit buffer was allocated by optimize_rust_fallback.
        unsafe { free_verified_result(verified) };
    }

    #[test]
    fn test_haversine_distance() {
        let dist = haversine_distance(45.5017, -73.5673, 45.5088, -73.5542);
        assert!(dist > 800.0 && dist < 2000.0, "unexpected distance {}", dist);
    }
}