use std::io::Read;
use std::path::PathBuf;
/// Builds a multi-layer perceptron from explicit per-layer parameters,
/// traces it to TorchScript, and persists it to a named temp file.
///
/// `layer_specs` is a list of `(in_dim, out_dim, weights, biases)` tuples.
/// `weights` is expected in row-major `[in_dim, out_dim]` order and is
/// transposed here into tch's native `[out_dim, in_dim]` layout. A ReLU is
/// inserted between layers but not after the final layer.
///
/// Returns the serialized `.pt` bytes together with the path of the
/// persisted temp file. The file is intentionally kept on disk (auto-delete
/// is disabled); the caller is responsible for removing it.
///
/// # Errors
/// Returns `Err` on empty specs, weight/bias length mismatches,
/// inconsistent consecutive layer dimensions, or any tch / file-system
/// failure.
#[cfg(feature = "tch-model")]
pub fn build_pt_mlp_temp(
    layer_specs: &[(usize, usize, Vec<f32>, Vec<f32>)],
) -> Result<(Vec<u8>, PathBuf), String> {
    use tch::nn::Module;
    use tch::{Device, Kind, Tensor, nn};

    if layer_specs.is_empty() {
        return Err("Empty layer specs".to_string());
    }

    // Validate shapes up front so a malformed spec yields an Err instead of
    // a tch panic inside reshape/copy_ below.
    for (idx, (layer_in, layer_out, weights, biases)) in layer_specs.iter().enumerate() {
        if weights.len() != layer_in * layer_out {
            return Err(format!(
                "Layer {}: expected {} weights ({}x{}), got {}",
                idx,
                layer_in * layer_out,
                layer_in,
                layer_out,
                weights.len()
            ));
        }
        if biases.len() != *layer_out {
            return Err(format!(
                "Layer {}: expected {} biases, got {}",
                idx,
                layer_out,
                biases.len()
            ));
        }
        if idx > 0 && layer_specs[idx - 1].1 != *layer_in {
            return Err(format!(
                "Layer {}: input dim {} does not match previous layer output dim {}",
                idx,
                layer_in,
                layer_specs[idx - 1].1
            ));
        }
    }

    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();
    let mut seq = nn::seq();
    for (idx, (layer_in, layer_out, weights, biases)) in layer_specs.iter().enumerate() {
        let layer_path = root.sub(&format!("layer_{}", idx));
        // `bias: true` is already the default, but keep it explicit: the
        // copy_ into `linear.bs` below relies on the bias being present.
        let linear_config = nn::LinearConfig {
            bias: true,
            ..Default::default()
        };
        let mut linear = nn::linear(
            &layer_path,
            *layer_in as i64,
            *layer_out as i64,
            linear_config,
        );
        // Incoming weights are row-major [in, out]; tch stores [out, in].
        let weight_tensor = Tensor::from_slice(weights)
            .reshape([*layer_in as i64, *layer_out as i64])
            .transpose(0, 1);
        let bias_tensor = Tensor::from_slice(biases);
        // Overwrite the randomly-initialized parameters without tracking
        // gradients.
        tch::no_grad(|| {
            linear.ws.copy_(&weight_tensor);
            if let Some(ref mut bs) = linear.bs {
                bs.copy_(&bias_tensor);
            }
        });
        seq = seq.add(linear);
        // ReLU between layers only; the final layer stays linear.
        if idx < layer_specs.len() - 1 {
            seq = seq.add_fn(|x| x.relu());
        }
    }

    let temp_file = tempfile::Builder::new()
        .prefix("relayrl_pt_model_")
        .suffix(".pt")
        .tempfile()
        .map_err(|e| format!("Failed to create temp file: {}", e))?;

    // Trace the sequential net with a dummy batch of one input row.
    let in_dim = layer_specs[0].0 as i64;
    let example_input = Tensor::zeros([1, in_dim], (Kind::Float, Device::Cpu));
    let mut trace_closure = |inputs: &[Tensor]| -> Vec<Tensor> { vec![seq.forward(&inputs[0])] };
    let module =
        tch::CModule::create_by_tracing("mlp", "forward", &[example_input], &mut trace_closure)
            .map_err(|e| format!("Failed to create traced module: {}", e))?;

    // Persist the file (disable auto-delete) via keep() and let the returned
    // handle drop, closing the descriptor. The previous mem::forget approach
    // leaked the open file descriptor for the life of the process.
    let (_file, temp_path) = temp_file
        .keep()
        .map_err(|e| format!("Failed to persist temp file: {}", e))?;

    module
        .save(&temp_path)
        .map_err(|e| format!("Failed to save model: {}", e))?;

    // fs::read pre-sizes the buffer from file metadata in a single call.
    let bytes = std::fs::read(&temp_path)
        .map_err(|e| format!("Failed to read model bytes: {}", e))?;

    Ok((bytes, temp_path))
}
/// Fallback used when the `tch-model` feature is disabled: building a
/// PyTorch model is impossible, so every call fails with an explanatory
/// error.
#[cfg(not(feature = "tch-model"))]
pub fn build_pt_mlp_temp(
    _layer_specs: &[(usize, usize, Vec<f32>, Vec<f32>)],
) -> Result<(Vec<u8>, PathBuf), String> {
    Err(String::from("tch-model feature not enabled"))
}
#[cfg(all(test, feature = "tch-model"))]
mod tests {
    use super::*;

    /// A single 2x2 identity layer should trace and serialize cleanly.
    #[test]
    fn test_build_pt_mlp_single_layer() {
        let identity_weights = vec![1.0f32, 0.0, 0.0, 1.0];
        let zero_biases = vec![0.0f32, 0.0];
        let layer_specs = vec![(2usize, 2usize, identity_weights, zero_biases)];

        let outcome = build_pt_mlp_temp(&layer_specs);
        assert!(outcome.is_ok(), "Should successfully build PT model");

        let (model_bytes, model_path) = outcome.unwrap();
        assert!(!model_bytes.is_empty(), "PT bytes should not be empty");
        assert!(model_path.exists(), "Temp file should exist");
        let _ = std::fs::remove_file(model_path);
    }

    /// An empty spec list is invalid and must be rejected.
    #[test]
    fn test_build_pt_mlp_empty_layers() {
        assert!(
            build_pt_mlp_temp(&[]).is_err(),
            "Should fail on empty layer specs"
        );
    }

    /// A 4 -> 8 -> 2 network with constant weights serializes to a
    /// non-trivial payload.
    #[test]
    fn test_build_pt_mlp_two_layers() {
        let hidden_layer = (4, 8, vec![0.1f32; 32], vec![0.0f32; 8]);
        let output_layer = (8, 2, vec![0.2f32; 16], vec![0.0f32; 2]);
        let layer_specs = vec![hidden_layer, output_layer];

        let outcome = build_pt_mlp_temp(&layer_specs);
        assert!(outcome.is_ok(), "Should successfully build 2-layer PT model");

        let (model_bytes, model_path) = outcome.unwrap();
        assert!(model_bytes.len() > 100, "Expected non-trivial PT model bytes");
        let _ = std::fs::remove_file(model_path);
    }
}