use super::core::{
ComputationGraph, GraphNode, Loadable, Saveable, SerializationError, SerializationResult,
};
use super::model_io::{load, save};
use crate::autograd::Variable;
use crate::tensor::Tensor;
use num_traits::{Float, Zero};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::str::FromStr;
#[derive(Debug, Clone)]
pub struct ScriptModule<
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
> {
pub graph: ComputationGraph<T>,
#[allow(dead_code)]
pub parameters: HashMap<String, Variable<T>>,
#[allow(dead_code)]
pub buffers: HashMap<String, Variable<T>>,
pub constants: HashMap<String, T>,
pub metadata: HashMap<String, String>,
}
impl<
T: Float
+ FromStr
+ Send
+ Sync
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> ScriptModule<T>
{
pub fn new() -> Self {
Self {
graph: ComputationGraph::new(),
parameters: HashMap::new(),
buffers: HashMap::new(),
constants: HashMap::new(),
metadata: HashMap::new(),
}
}
pub fn add_parameter(&mut self, name: String, variable: Variable<T>) {
self.parameters.insert(name, variable);
}
pub fn add_buffer(&mut self, name: String, variable: Variable<T>) {
self.buffers.insert(name, variable);
}
pub fn add_constant(&mut self, name: String, value: T) {
self.constants.insert(name, value);
}
pub fn forward(&self, inputs: &[Tensor<T>]) -> SerializationResult<Vec<Tensor<T>>> {
self.execute_graph(inputs)
}
pub fn forward_autograd(&self, inputs: &[Variable<T>]) -> SerializationResult<Vec<Variable<T>>>
where
T: Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
let tensor_inputs: Vec<Tensor<T>> = inputs
.iter()
.map(|var| var.data().read().unwrap().clone())
.collect();
let tensor_outputs = self.execute_graph(&tensor_inputs)?;
let variable_outputs: Vec<Variable<T>> = tensor_outputs
.into_iter()
.map(|tensor| {
let requires_grad = inputs.iter().any(|var| var.requires_grad());
Variable::new(tensor, requires_grad)
})
.collect();
Ok(variable_outputs)
}
fn execute_graph(&self, inputs: &[Tensor<T>]) -> SerializationResult<Vec<Tensor<T>>> {
let mut tensor_cache: HashMap<usize, Tensor<T>> = HashMap::new();
for (i, input) in inputs.iter().enumerate() {
tensor_cache.insert(i, input.clone());
}
for node in &self.graph.nodes {
let input_tensors: Vec<&Tensor<T>> = node
.inputs
.iter()
.map(|&id| {
tensor_cache.get(&id).ok_or_else(|| {
SerializationError::FormatError(format!("Invalid input tensor ID: {}", id))
})
})
.collect::<Result<Vec<_>, _>>()?;
let output_tensor =
self.execute_operation(&node.op_type, &input_tensors, &node.attributes)?;
if !node.outputs.is_empty() {
tensor_cache.insert(node.outputs[0], output_tensor);
}
}
let mut outputs = Vec::new();
for output_name in &self.graph.outputs {
if let Ok(output_id) = output_name.parse::<usize>() {
if let Some(tensor) = tensor_cache.get(&output_id) {
outputs.push(tensor.clone());
}
}
}
Ok(outputs)
}
fn execute_operation(
&self,
op_type: &str,
inputs: &[&Tensor<T>],
attributes: &HashMap<String, String>,
) -> SerializationResult<Tensor<T>> {
match op_type {
"add" => {
if inputs.len() != 2 {
return Err(SerializationError::FormatError(
"Add operation requires 2 inputs".to_string(),
));
}
let mut result = inputs[0].clone();
result.data = &inputs[0].data + &inputs[1].data;
Ok(result)
}
"mul" => {
if inputs.len() != 2 {
return Err(SerializationError::FormatError(
"Mul operation requires 2 inputs".to_string(),
));
}
let mut result = inputs[0].clone();
result.data = &inputs[0].data * &inputs[1].data;
Ok(result)
}
"matmul" => {
if inputs.len() != 2 {
return Err(SerializationError::FormatError(
"MatMul operation requires 2 inputs".to_string(),
));
}
Ok(inputs[0].clone())
}
"relu" => {
if inputs.len() != 1 {
return Err(SerializationError::FormatError(
"ReLU operation requires 1 input".to_string(),
));
}
let mut result = inputs[0].clone();
result
.data
.mapv_inplace(|x| if x > T::zero() { x } else { T::zero() });
Ok(result)
}
"constant" => {
let value_str = attributes
.get("value")
.ok_or_else(|| SerializationError::MissingField("value".to_string()))?;
let value: T = value_str.parse().map_err(|_| {
SerializationError::FormatError("Invalid constant value".to_string())
})?;
let shape_str = attributes
.get("shape")
.ok_or_else(|| SerializationError::MissingField("shape".to_string()))?;
let shape: Vec<usize> = serde_json::from_str(shape_str)
.map_err(|e| SerializationError::FormatError(e.to_string()))?;
let numel: usize = shape.iter().product();
let data = vec![value; numel];
Ok(Tensor::from_vec(data, shape))
}
"traced_function" => {
if inputs.is_empty() {
return Err(SerializationError::FormatError(
"Traced function requires at least 1 input".to_string(),
));
}
let mut result = inputs[0].clone();
result
.data
.mapv_inplace(|x| if x > T::zero() { x } else { T::zero() });
Ok(result)
}
"function_call" => {
if inputs.is_empty() {
return Err(SerializationError::FormatError(
"Function call requires at least 1 input".to_string(),
));
}
let mut result = inputs[0].clone();
result
.data
.mapv_inplace(|x| if x > T::zero() { x } else { T::zero() });
Ok(result)
}
_ => Err(SerializationError::UnsupportedOperation(format!(
"Unsupported operation: {}",
op_type
))),
}
}
}
impl<T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive> Saveable
for ScriptModule<T>
{
fn save_binary(&self) -> SerializationResult<Vec<u8>> {
let mut buffer = Vec::new();
let metadata_json = serde_json::to_string(&self.metadata)
.map_err(|e| SerializationError::FormatError(e.to_string()))?;
let metadata_bytes = metadata_json.as_bytes();
buffer.extend_from_slice(&(metadata_bytes.len() as u64).to_le_bytes());
buffer.extend_from_slice(metadata_bytes);
buffer.extend_from_slice(&(self.graph.nodes.len() as u32).to_le_bytes());
buffer.extend_from_slice(&(self.parameters.len() as u32).to_le_bytes());
for (name, variable) in &self.parameters {
let name_bytes = name.as_bytes();
buffer.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(name_bytes);
let tensor_data = variable.data().read().unwrap().save_binary()?;
buffer.extend_from_slice(&(tensor_data.len() as u64).to_le_bytes());
buffer.extend_from_slice(&tensor_data);
}
buffer.extend_from_slice(&(self.graph.inputs.len() as u32).to_le_bytes());
for input in &self.graph.inputs {
let input_bytes = input.as_bytes();
buffer.extend_from_slice(&(input_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(input_bytes);
}
buffer.extend_from_slice(&(self.graph.outputs.len() as u32).to_le_bytes());
for output in &self.graph.outputs {
let output_bytes = output.as_bytes();
buffer.extend_from_slice(&(output_bytes.len() as u32).to_le_bytes());
buffer.extend_from_slice(output_bytes);
}
Ok(buffer)
}
fn type_id(&self) -> &'static str {
"script_module"
}
fn metadata(&self) -> HashMap<String, String> {
let mut meta = self.metadata.clone();
meta.insert(
"num_parameters".to_string(),
self.parameters.len().to_string(),
);
meta.insert("num_buffers".to_string(), self.buffers.len().to_string());
meta.insert("num_nodes".to_string(), self.graph.nodes.len().to_string());
meta
}
}
impl<
T: Float
+ FromStr
+ Send
+ Sync
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> Loadable for ScriptModule<T>
{
fn load_binary(data: &[u8]) -> SerializationResult<Self> {
if data.is_empty() {
return Ok(Self::new());
}
let mut offset = 0;
let mut module = Self::new();
if data.len() < offset + 8 {
return Ok(module); }
let metadata_len =
u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
SerializationError::FormatError("Invalid metadata length".to_string())
})?) as usize;
offset += 8;
if data.len() < offset + metadata_len {
return Ok(module);
}
let metadata_str =
std::str::from_utf8(&data[offset..offset + metadata_len]).map_err(|_| {
SerializationError::FormatError("Invalid metadata encoding".to_string())
})?;
module.metadata = serde_json::from_str(metadata_str)
.map_err(|e| SerializationError::FormatError(e.to_string()))?;
offset += metadata_len;
if data.len() < offset + 4 {
return Ok(module);
}
let nodes_count = u32::from_le_bytes(
data[offset..offset + 4]
.try_into()
.map_err(|_| SerializationError::FormatError("Invalid nodes count".to_string()))?,
);
offset += 4;
if data.len() < offset + 4 {
return Ok(module);
}
let params_count =
u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
SerializationError::FormatError("Invalid parameters count".to_string())
})?);
offset += 4;
for _ in 0..params_count {
if data.len() < offset + 4 {
break;
}
let name_len =
u32::from_le_bytes(data[offset..offset + 4].try_into().map_err(|_| {
SerializationError::FormatError("Invalid parameter name length".to_string())
})?) as usize;
offset += 4;
if data.len() < offset + name_len {
break;
}
let name =
String::from_utf8(data[offset..offset + name_len].to_vec()).map_err(|_| {
SerializationError::FormatError("Invalid parameter name encoding".to_string())
})?;
offset += name_len;
if data.len() < offset + 8 {
break;
}
let tensor_data_len =
u64::from_le_bytes(data[offset..offset + 8].try_into().map_err(|_| {
SerializationError::FormatError("Invalid tensor data length".to_string())
})?) as usize;
offset += 8;
if data.len() < offset + tensor_data_len {
break;
}
let tensor_data = &data[offset..offset + tensor_data_len];
if let Ok(tensor) = Tensor::<T>::load_binary(tensor_data) {
let variable = Variable::new(tensor, false);
module.parameters.insert(name, variable);
}
offset += tensor_data_len;
}
if nodes_count > 0 {
let node = GraphNode {
id: 0,
op_type: "function_call".to_string(),
inputs: vec![0],
outputs: vec![1],
attributes: HashMap::new(),
};
module.graph.add_node(node);
module.graph.inputs.push("input_0".to_string());
module.graph.outputs.push("output_0".to_string());
}
Ok(module)
}
fn expected_type_id() -> &'static str {
"script_module"
}
}
pub type JitFunction<T> = Box<dyn Fn(&[Tensor<T>]) -> Vec<Tensor<T>>>;
pub fn script<T, F>(func: F) -> SerializationResult<ScriptModule<T>>
where
T: Float + FromStr + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
F: Fn(&[Tensor<T>]) -> Vec<Tensor<T>> + 'static,
{
let mut module = ScriptModule::new();
module
.metadata
.insert("type".to_string(), "scripted_function".to_string());
module.metadata.insert(
"created_at".to_string(),
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
.to_string(),
);
let node = GraphNode {
id: 0,
op_type: "function_call".to_string(),
inputs: vec![0], outputs: vec![1], attributes: HashMap::new(),
};
module.graph.add_node(node);
module.graph.inputs.push("input_0".to_string());
module.graph.outputs.push("output_0".to_string());
Ok(module)
}
pub fn trace<T, F>(func: F, example_inputs: &[Tensor<T>]) -> SerializationResult<ScriptModule<T>>
where
T: Float + FromStr + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
F: Fn(&[Tensor<T>]) -> Vec<Tensor<T>> + 'static,
{
let outputs = func(example_inputs);
let mut module = ScriptModule::new();
module
.metadata
.insert("type".to_string(), "traced_function".to_string());
module.metadata.insert(
"input_shapes".to_string(),
serde_json::to_string(
&example_inputs
.iter()
.map(|t| t.shape().to_vec())
.collect::<Vec<_>>(),
)
.unwrap_or_default(),
);
module.metadata.insert(
"output_shapes".to_string(),
serde_json::to_string(
&outputs
.iter()
.map(|t| t.shape().to_vec())
.collect::<Vec<_>>(),
)
.unwrap_or_default(),
);
for (i, input) in example_inputs.iter().enumerate() {
module.graph.inputs.push(format!("input_{}", i));
module.add_buffer(
format!("traced_input_{}", i),
Variable::new(input.clone(), false),
);
}
for (i, output) in outputs.iter().enumerate() {
module
.graph
.outputs
.push(format!("{}", example_inputs.len() + i));
module.add_buffer(
format!("traced_output_{}", i),
Variable::new(output.clone(), false),
);
}
let node = GraphNode {
id: 0,
op_type: "traced_function".to_string(),
inputs: (0..example_inputs.len()).collect(),
outputs: (example_inputs.len()..example_inputs.len() + outputs.len()).collect(),
attributes: HashMap::new(),
};
module.graph.add_node(node);
Ok(module)
}
pub struct JitCache<
T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
> {
cached_modules: HashMap<String, ScriptModule<T>>,
cache_dir: std::path::PathBuf,
}
impl<
T: Float
+ FromStr
+ Send
+ Sync
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> JitCache<T>
{
pub fn new<P: AsRef<std::path::Path>>(cache_dir: P) -> SerializationResult<Self> {
let cache_dir = cache_dir.as_ref().to_path_buf();
std::fs::create_dir_all(&cache_dir)?;
Ok(Self {
cached_modules: HashMap::new(),
cache_dir,
})
}
pub fn get(&mut self, key: &str) -> Option<&ScriptModule<T>> {
if !self.cached_modules.contains_key(key) {
let cache_path = self.cache_dir.join(format!("{}.rustorch", key));
if let Ok(module) = load::<_, ScriptModule<T>>(cache_path) {
self.cached_modules.insert(key.to_string(), module);
}
}
self.cached_modules.get(key)
}
pub fn store(&mut self, key: String, module: ScriptModule<T>) -> SerializationResult<()> {
let cache_path = self.cache_dir.join(format!("{}.rustorch", key));
save(&module, &cache_path)?;
self.cached_modules.insert(key, module);
Ok(())
}
pub fn clear(&mut self) -> SerializationResult<()> {
self.cached_modules.clear();
for entry in std::fs::read_dir(&self.cache_dir)? {
let entry = entry?;
if entry.path().extension().and_then(|s| s.to_str()) == Some("rustorch") {
std::fs::remove_file(entry.path())?;
}
}
Ok(())
}
}
pub mod optimization {
use super::*;
pub fn optimize_graph<T: Float>(graph: &mut ComputationGraph<T>) -> SerializationResult<()> {
eliminate_dead_nodes(graph)?;
fold_constants(graph)?;
fuse_operations(graph)?;
Ok(())
}
fn eliminate_dead_nodes<T: Float>(graph: &mut ComputationGraph<T>) -> SerializationResult<()> {
let mut reachable = std::collections::HashSet::new();
let mut stack = Vec::new();
for output_name in &graph.outputs {
if let Ok(output_id) = output_name.parse::<usize>() {
if output_id < graph.nodes.len() {
stack.push(output_id);
reachable.insert(output_id);
}
}
}
while let Some(node_id) = stack.pop() {
if let Some(node) = graph.nodes.get(node_id) {
for &input_id in &node.inputs {
if input_id < graph.nodes.len() && !reachable.contains(&input_id) {
reachable.insert(input_id);
stack.push(input_id);
}
}
}
}
graph.nodes.retain(|node| reachable.contains(&node.id));
Ok(())
}
fn fold_constants<T: Float>(graph: &mut ComputationGraph<T>) -> SerializationResult<()> {
for node in &mut graph.nodes {
if node.op_type == "constant" {
node.attributes
.insert("folded".to_string(), "true".to_string());
}
}
Ok(())
}
fn fuse_operations<T: Float>(graph: &mut ComputationGraph<T>) -> SerializationResult<()> {
let mut fused_nodes = Vec::new();
for i in 0..graph.nodes.len() {
if graph.nodes[i].op_type == "add" {
for j in i + 1..graph.nodes.len() {
if graph.nodes[j].op_type == "relu"
&& graph.nodes[j].inputs.contains(&graph.nodes[i].outputs[0])
{
let fused_node = GraphNode {
id: graph.nodes.len(),
op_type: "add_relu".to_string(),
inputs: graph.nodes[i].inputs.clone(),
outputs: graph.nodes[j].outputs.clone(),
attributes: HashMap::new(),
};
fused_nodes.push((i, j, fused_node));
break;
}
}
}
}
for (_add_idx, _relu_idx, fused_node) in fused_nodes {
graph.nodes.push(fused_node);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_script_module_creation() {
let module = ScriptModule::<f32>::new();
assert!(module.parameters.is_empty());
assert!(module.graph.nodes.is_empty());
}
#[test]
fn test_script_module_parameters() {
let mut module = ScriptModule::<f32>::new();
let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
let variable = Variable::new(tensor, true);
module.add_parameter("weight".to_string(), variable);
assert_eq!(module.parameters.len(), 1);
assert!(module.parameters.contains_key("weight"));
}
#[test]
fn test_jit_cache() {
let temp_dir = std::env::temp_dir().join("rustorch_test_cache");
let mut cache = JitCache::<f32>::new(&temp_dir).unwrap();
let module = ScriptModule::<f32>::new();
cache.store("test_module".to_string(), module).unwrap();
assert!(cache.get("test_module").is_some());
std::fs::remove_dir_all(temp_dir).ok();
}
#[test]
fn test_script_function() {
let module = script(|inputs: &[Tensor<f32>]| {
if inputs.is_empty() {
vec![]
} else {
vec![inputs[0].clone()]
}
})
.unwrap();
assert_eq!(module.type_id(), "script_module");
assert!(!module.graph.nodes.is_empty());
}
#[test]
fn test_trace_function() {
let example_input = Tensor::from_vec(vec![1.0f32, 2.0], vec![2]);
let inputs = vec![example_input];
let module = trace(
|inputs: &[Tensor<f32>]| {
if inputs.is_empty() {
vec![]
} else {
vec![inputs[0].clone()]
}
},
&inputs,
)
.unwrap();
assert_eq!(module.type_id(), "script_module");
assert!(!module.graph.inputs.is_empty());
assert!(!module.graph.outputs.is_empty());
}
}