import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
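
# Trace-friendly re-implementation of an Attention Model (AM)-style
# encoder/decoder for the CVRP, plus an ONNX export path. The greedy
# decode loop below runs for a fixed number of steps with no
# data-dependent Python control flow, so torch.onnx.export can trace
# and unroll it into a static graph.
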
class MultiHeadAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.Wq = nn.Linear(dim, dim, bias=False)
        self.Wk = nn.Linear(dim, dim, bias=False)
        self.Wv = nn.Linear(dim, dim, bias=False)
        self.Wo = nn.Linear(dim, dim, bias=False)

    def forward(self, q: torch.Tensor, kv: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, N_q, _ = q.shape
        N_kv = kv.shape[1]
        Q = self.Wq(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.Wk(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.Wv(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1), float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, V).transpose(1, 2).contiguous().view(B, N_q, self.dim)
        return self.Wo(out)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim: int, num_heads: int, ff_dim: int):
        super().__init__()
        self.mha = MultiHeadAttention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, dim),
        )
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.norm1(x + self.mha(x, x, mask))
        x = self.norm2(x + self.ff(x))
        return x

class AMEncoder(nn.Module):
    def __init__(self, dim: int, num_heads: int, num_layers: int, ff_dim: int):
        super().__init__()
        self.input_proj = nn.Linear(3, dim)  # (x, y, demand) -> embedding
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ])

    def forward(self, locs: torch.Tensor, demands: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = torch.cat([locs, demands], dim=-1)
        x = self.input_proj(x)
        for layer in self.layers:
            x = layer(x)
        graph_embed = x.mean(dim=1, keepdim=True)
        return x, graph_embed

class TraceableDecoder(nn.Module):
    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.tanh_clip = 10.0
        self.context_proj = nn.Linear(dim * 3, dim)
        self.Wq = nn.Linear(dim, dim, bias=False)
        self.Wk = nn.Linear(dim, dim, bias=False)

    def forward(self, node_embeddings: torch.Tensor,
                graph_embed: torch.Tensor,
                demands: torch.Tensor,
                capacity: torch.Tensor,
                max_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:
        B, Np1, _ = node_embeddings.shape
        device = node_embeddings.device
        keys = self.Wk(node_embeddings)
        first_node_embed = node_embeddings[:, 0:1, :]
        actions_list = []
        log_p_list = []
        # visited tracks customers only; the depot column is recomputed
        # every step below so the vehicle can return to refill.
        visited = torch.zeros(B, Np1, dtype=torch.bool, device=device)
        cur_loc = torch.zeros(B, dtype=torch.long, device=device)
        rem_cap = capacity.clone()
        for _ in range(max_steps):
            cur_embed = node_embeddings[torch.arange(B), cur_loc].unsqueeze(1)
            ctx = torch.cat([graph_embed, first_node_embed, cur_embed], dim=-1)
            context = self.context_proj(ctx).squeeze(1)
            query = self.Wq(context)
            compat = (query.unsqueeze(1) * keys).sum(dim=-1) * self.scale
            compat = self.tanh_clip * torch.tanh(compat / self.tanh_clip)
            # Feasibility mask: visited customers and customers whose demand
            # exceeds the remaining capacity are excluded. The depot is
            # masked only while the vehicle is already there, and re-enabled
            # if every node is infeasible so the softmax stays finite
            # (the vehicle then idles at the depot once the tour is done).
            step_mask = visited.clone()
            step_mask[:, 0] = cur_loc == 0
            step_mask = step_mask | (demands > rem_cap)
            all_masked = step_mask.all(dim=-1)
            step_mask[:, 0] = step_mask[:, 0] & ~all_masked
            compat = compat.masked_fill(step_mask, float('-inf'))
            probs = F.softmax(compat, dim=-1)
            action = torch.argmax(probs, dim=-1)
            log_prob = torch.log(probs[torch.arange(B), action] + 1e-20)
            actions_list.append(action)
            log_p_list.append(log_prob)
            visited = visited | (torch.arange(Np1, device=device).unsqueeze(0) == action.unsqueeze(1))
            cur_loc = action
            # Returning to the depot refills the vehicle; visiting a customer
            # consumes that customer's demand.
            at_depot = (action == 0).unsqueeze(1)
            consumed = demands[torch.arange(B), action].unsqueeze(1)
            rem_cap = torch.where(at_depot, capacity, rem_cap - consumed)
        actions = torch.stack(actions_list, dim=-1)
        log_p = torch.stack(log_p_list, dim=-1)
        return actions, log_p

class TraceableCVRPModel(nn.Module):
    def __init__(self, embedding_dim: int = 128, num_heads: int = 8,
                 num_encoder_layers: int = 6, ff_hidden_dim: int = 512):
        super().__init__()
        self.encoder = AMEncoder(embedding_dim, num_heads, num_encoder_layers, ff_hidden_dim)
        self.decoder = TraceableDecoder(embedding_dim, num_heads)

    def forward(self, locs: torch.Tensor, demands: torch.Tensor,
                capacity: torch.Tensor, max_steps: int) -> Tuple[torch.Tensor, torch.Tensor]:
        demands_flat = demands.squeeze(-1)
        node_emb, graph_emb = self.encoder(locs, demands)
        actions, log_p = self.decoder(node_emb, graph_emb, demands_flat, capacity, max_steps)
        return actions, log_p

class ONNXExportWrapper(nn.Module):
    def __init__(self, model: TraceableCVRPModel):
        super().__init__()
        self.model = model

    def forward(self, locs: torch.Tensor, demand: torch.Tensor,
                capacity: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        B = locs.shape[0]
        N = locs.shape[1]
        # Upper bound on decode length (customer visits plus depot returns);
        # baked into the traced graph as a constant.
        max_steps = N * 3
        # The depot is pinned at (0.5, 0.5); this must match the layout the
        # model was trained on.
        depot_loc = torch.full((B, 1, 2), 0.5, device=locs.device, dtype=locs.dtype)
        full_locs = torch.cat([depot_loc, locs], dim=1)
        depot_demand = torch.zeros(B, 1, 1, device=demand.device, dtype=demand.dtype)
        full_demands = torch.cat([depot_demand, demand], dim=1)
        actions, log_p = self.model(full_locs, full_demands, capacity, max_steps)
        return actions, log_p

def export_trained_to_onnx(checkpoint_path: str, onnx_path: str,
                            onnx_problem_size: int = 50):
    from train_cvrp import CVPRouteModel, CVRPConfig
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    train_config = ckpt["config"]
    model = TraceableCVRPModel(
        embedding_dim=train_config.embedding_dim,
        num_heads=train_config.num_heads,
        num_encoder_layers=train_config.num_encoder_layers,
        ff_hidden_dim=train_config.ff_hidden_dim,
    )
    # Copy the trained weights into the trace-friendly model: the encoder
    # transfers wholesale; the decoder shares only these three modules.
    trained_model = CVPRouteModel(train_config)
    trained_model.load_state_dict(ckpt["model_state_dict"])
    model.encoder.load_state_dict(trained_model.encoder.state_dict())
    model.decoder.Wq.load_state_dict(trained_model.decoder.Wq.state_dict())
    model.decoder.Wk.load_state_dict(trained_model.decoder.Wk.state_dict())
    model.decoder.context_proj.load_state_dict(trained_model.decoder.context_proj.state_dict())
    model.eval()
    wrapper = ONNXExportWrapper(model)
    wrapper.eval()
    dummy_locs = torch.randn(1, onnx_problem_size, 2)
    dummy_demand = torch.ones(1, onnx_problem_size, 1)
    dummy_capacity = torch.tensor([[1.0]])
    print(f"Exporting ONNX model to {onnx_path} (problem_size={onnx_problem_size})...")
    # Tracing unrolls the decode loop, so the exported graph is specialized
    # to this problem size and decode length.
    torch.onnx.export(
        wrapper,
        (dummy_locs, dummy_demand, dummy_capacity),
        onnx_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['locs', 'demand', 'capacity'],
        output_names=['actions', 'log_p'],
    )
    import onnx
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    # Smoke-test the exported graph with onnxruntime.
    import onnxruntime as ort
    import numpy as np
    sess = ort.InferenceSession(onnx_path)
    result = sess.run(None, {
        'locs': dummy_locs.numpy().astype(np.float32),
        'demand': dummy_demand.numpy().astype(np.float32),
        'capacity': dummy_capacity.numpy().astype(np.float32),
    })
    print(f"✅ ONNX model exported and validated: {onnx_path}")
    print(f" actions shape: {result[0].shape}, log_p shape: {result[1].shape}")
    print(f" First 15 actions: {result[0][0, :15]}")
    return onnx_path
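

if __name__ == "__main__":
    # Minimal usage sketch; the paths below are placeholders, not files
    # shipped with this module.
    export_trained_to_onnx(
        checkpoint_path="checkpoints/cvrp_model.pt",
        onnx_path="cvrp_model.onnx",
        onnx_problem_size=50,
    )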