use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::{boxed::Box, collections::HashMap, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use parking_lot::Mutex;
/// One node of a [`DynamicGraph`]'s execution plan.
///
/// Every node refers to the graph's registries (modules, conditions, combiners,
/// functions) by name; names are resolved only when the node is executed.
#[derive(Debug, Clone)]
pub enum GraphNode {
    /// Runs the registered module with this name on the current input.
    Module(String),
    /// Evaluates the named condition on the input and runs `true_branch` when it holds.
    /// Otherwise it runs `false_branch`, or passes the input through unchanged when
    /// there is no false branch.
    Conditional {
        condition: String,
        true_branch: Box<GraphNode>,
        false_branch: Option<Box<GraphNode>>,
    },
    /// Runs the nodes one after another, feeding each node's output into the next.
    Sequence(Vec<GraphNode>),
    /// Runs every node on the same input and merges the outputs with the named combiner.
    Parallel {
        nodes: Vec<GraphNode>,
        combiner: String,
    },
    /// Re-runs `body` on its own output while the named condition holds, capped at
    /// `max_iterations` passes.
    Loop {
        body: Box<GraphNode>,
        condition: String,
        max_iterations: usize,
    },
    /// Applies the registered tensor function with this name.
    Function(String),
}
pub struct DynamicGraph {
base: ModuleBase,
modules: HashMap<String, Box<dyn Module>>,
conditions: HashMap<String, Box<dyn Fn(&Tensor) -> bool + Send + Sync>>,
combiners: HashMap<String, Box<dyn Fn(Vec<Tensor>) -> Result<Tensor> + Send + Sync>>,
functions: HashMap<String, Box<dyn Fn(&Tensor) -> Result<Tensor> + Send + Sync>>,
graph: GraphNode,
execution_history: Mutex<Vec<String>>,
}
impl DynamicGraph {
    /// Creates an empty dynamic graph.
    ///
    /// The initial plan is an empty `GraphNode::Sequence`, so execution returns the
    /// input unchanged until [`DynamicGraph::set_graph`] installs a real plan. The
    /// built-in combiners `"concat"`, `"add"` and `"mean"` are registered up front.
    pub fn new() -> Self {
        let mut graph = Self {
            base: ModuleBase::new(),
            modules: HashMap::new(),
            conditions: HashMap::new(),
            combiners: HashMap::new(),
            functions: HashMap::new(),
            graph: GraphNode::Sequence(Vec::new()),
            execution_history: Mutex::new(Vec::new()),
        };
        graph.register_builtin_combiners();
        graph
    }

    /// Registers the built-in `"concat"`, `"add"` and `"mean"` combiners.
    ///
    /// Closures are passed directly: `add_combiner` boxes them itself, so wrapping
    /// them in `Box::new` here would allocate a second, redundant box.
    fn register_builtin_combiners(&mut self) {
        // Concatenates along the last axis, using the rank of the first tensor.
        self.add_combiner("concat".to_string(), |tensors: Vec<Tensor>| {
            if tensors.is_empty() {
                return Err(TorshError::InvalidArgument(
                    "No tensors to concatenate".to_string(),
                ));
            }
            let ndim = tensors[0].ndim();
            if ndim == 0 {
                return Err(TorshError::InvalidArgument(
                    "Cannot concatenate 0-dimensional tensors".to_string(),
                ));
            }
            let concat_dim = (ndim - 1) as i32;
            let tensor_refs: Vec<&Tensor> = tensors.iter().collect();
            Tensor::cat(&tensor_refs, concat_dim)
                .map_err(|e| TorshError::Other(format!("Concatenation failed: {}", e)))
        });

        // Sums the tensors left to right with `add_op`.
        self.add_combiner("add".to_string(), |tensors: Vec<Tensor>| {
            if tensors.is_empty() {
                return Err(TorshError::InvalidArgument("No tensors to add".to_string()));
            }
            let mut result = tensors[0].clone();
            for tensor in tensors.iter().skip(1) {
                result = result.add_op(tensor)?;
            }
            Ok(result)
        });

        // Sums with `add_op`, then divides by the number of tensors.
        self.add_combiner("mean".to_string(), |tensors: Vec<Tensor>| {
            if tensors.is_empty() {
                return Err(TorshError::InvalidArgument(
                    "No tensors to average".to_string(),
                ));
            }
            let mut result = tensors[0].clone();
            for tensor in tensors.iter().skip(1) {
                result = result.add_op(tensor)?;
            }
            let count = tensors.len() as f32;
            result = result.div_scalar(count)?;
            Ok(result)
        });
    }

    /// Registers (or replaces) a sub-module reachable via `GraphNode::Module(name)`.
    pub fn add_module<M: Module + 'static>(&mut self, name: String, module: M) {
        self.modules.insert(name, Box::new(module));
    }

    /// Registers (or replaces) a predicate used by `Conditional` and `Loop` nodes.
    pub fn add_condition<F>(&mut self, name: String, condition: F)
    where
        F: Fn(&Tensor) -> bool + Send + Sync + 'static,
    {
        self.conditions.insert(name, Box::new(condition));
    }

    /// Registers (or replaces) a combiner used by `Parallel` nodes.
    pub fn add_combiner<F>(&mut self, name: String, combiner: F)
    where
        F: Fn(Vec<Tensor>) -> Result<Tensor> + Send + Sync + 'static,
    {
        self.combiners.insert(name, Box::new(combiner));
    }

    /// Registers (or replaces) a tensor function used by `Function` nodes.
    pub fn add_function<F>(&mut self, name: String, function: F)
    where
        F: Fn(&Tensor) -> Result<Tensor> + Send + Sync + 'static,
    {
        self.functions.insert(name, Box::new(function));
    }

    /// Replaces the whole execution plan.
    pub fn set_graph(&mut self, graph: GraphNode) {
        self.graph = graph;
    }

    /// Builds a `Sequence` of `Module` nodes, one per name, in the given order.
    pub fn sequential(module_names: Vec<String>) -> GraphNode {
        GraphNode::Sequence(module_names.into_iter().map(GraphNode::Module).collect())
    }

    /// Builds a `Conditional` node; a `None` false branch passes the input through.
    pub fn conditional(
        condition: String,
        true_branch: GraphNode,
        false_branch: Option<GraphNode>,
    ) -> GraphNode {
        GraphNode::Conditional {
            condition,
            true_branch: Box::new(true_branch),
            false_branch: false_branch.map(Box::new),
        }
    }

    /// Builds a `Parallel` node whose outputs are merged by the named combiner.
    pub fn parallel(nodes: Vec<GraphNode>, combiner: String) -> GraphNode {
        GraphNode::Parallel { nodes, combiner }
    }

    /// Builds a `Loop` node that runs `body` at most `max_iterations` times.
    pub fn loop_graph(body: GraphNode, condition: String, max_iterations: usize) -> GraphNode {
        GraphNode::Loop {
            body: Box::new(body),
            condition,
            max_iterations,
        }
    }

    /// Appends one entry to the execution trace.
    ///
    /// The lock is held only for this single push. `parking_lot::Mutex` is not
    /// re-entrant, so holding the guard across the recursive `execute_node` calls
    /// would deadlock on the first nested node.
    fn record_step(&self, entry: String) {
        self.execution_history.lock().push(entry);
    }

    /// Recursively evaluates `node` on `input`, appending a trace entry per step.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` when a referenced module, condition,
    /// combiner or function is not registered. Errors raised by modules, functions
    /// and combiners are propagated unchanged.
    fn execute_node(&self, node: &GraphNode, input: &Tensor) -> Result<Tensor> {
        match node {
            GraphNode::Module(name) => {
                self.record_step(format!("Module: {}", name));
                let module = self.modules.get(name).ok_or_else(|| {
                    TorshError::InvalidArgument(format!("Module '{}' not found", name))
                })?;
                module.forward(input)
            }
            GraphNode::Conditional {
                condition,
                true_branch,
                false_branch,
            } => {
                self.record_step(format!("Conditional: {}", condition));
                let cond_fn = self.conditions.get(condition).ok_or_else(|| {
                    TorshError::InvalidArgument(format!("Condition '{}' not found", condition))
                })?;
                if cond_fn(input) {
                    self.record_step("Taking true branch".to_string());
                    self.execute_node(true_branch, input)
                } else if let Some(false_branch) = false_branch {
                    self.record_step("Taking false branch".to_string());
                    self.execute_node(false_branch, input)
                } else {
                    self.record_step("No false branch, returning input".to_string());
                    Ok(input.clone())
                }
            }
            GraphNode::Sequence(nodes) => {
                self.record_step("Sequence execution".to_string());
                // Each child consumes the previous child's output.
                nodes
                    .iter()
                    .try_fold(input.clone(), |output, child| self.execute_node(child, &output))
            }
            GraphNode::Parallel { nodes, combiner } => {
                self.record_step(format!("Parallel execution with combiner: {}", combiner));
                // Resolve the combiner first so an unknown name fails before any
                // (possibly expensive) branch is evaluated.
                let combiner_fn = self.combiners.get(combiner).ok_or_else(|| {
                    TorshError::InvalidArgument(format!("Combiner '{}' not found", combiner))
                })?;
                // Every branch sees the same input; the first branch error aborts.
                let results = nodes
                    .iter()
                    .map(|child| self.execute_node(child, input))
                    .collect::<Result<Vec<Tensor>>>()?;
                combiner_fn(results)
            }
            GraphNode::Loop {
                body,
                condition,
                max_iterations,
            } => {
                self.record_step(format!("Loop execution with condition: {}", condition));
                let cond_fn = self.conditions.get(condition).ok_or_else(|| {
                    TorshError::InvalidArgument(format!("Condition '{}' not found", condition))
                })?;
                let mut output = input.clone();
                let mut iterations = 0;
                // Check the iteration budget first so the condition is not evaluated
                // once more after the cap has been reached.
                while iterations < *max_iterations && cond_fn(&output) {
                    output = self.execute_node(body, &output)?;
                    iterations += 1;
                    self.record_step(format!("Loop iteration: {}", iterations));
                }
                Ok(output)
            }
            GraphNode::Function(name) => {
                self.record_step(format!("Function: {}", name));
                let function = self.functions.get(name).ok_or_else(|| {
                    TorshError::InvalidArgument(format!("Function '{}' not found", name))
                })?;
                function(input)
            }
        }
    }

    /// Returns a snapshot (copy) of the recorded execution trace.
    pub fn get_execution_history(&self) -> Vec<String> {
        self.execution_history.lock().clone()
    }

    /// Discards the recorded execution trace.
    pub fn clear_execution_history(&self) {
        self.execution_history.lock().clear();
    }

    /// Mutates the execution plan in place through `modifier`.
    pub fn modify_graph<F>(&mut self, modifier: F)
    where
        F: FnOnce(&mut GraphNode),
    {
        modifier(&mut self.graph);
    }

    /// Looks up a registered sub-module by name.
    pub fn get_module(&self, name: &str) -> Option<&dyn Module> {
        self.modules.get(name).map(|m| m.as_ref())
    }

    /// Same as [`DynamicGraph::add_module`]; inserting under an existing name
    /// overwrites the previous module.
    pub fn replace_module<M: Module + 'static>(&mut self, name: String, module: M) {
        self.modules.insert(name, Box::new(module));
    }

    /// Removes and returns the named sub-module. Plan nodes that still reference
    /// it will fail with `InvalidArgument` when executed.
    pub fn remove_module(&mut self, name: &str) -> Option<Box<dyn Module>> {
        self.modules.remove(name)
    }

    /// Number of registered sub-modules.
    pub fn module_count(&self) -> usize {
        self.modules.len()
    }

    /// Names of registered sub-modules, in hash-map iteration order.
    pub fn module_names(&self) -> Vec<&String> {
        self.modules.keys().collect()
    }

    /// Names of registered conditions, in hash-map iteration order.
    pub fn condition_names(&self) -> Vec<&String> {
        self.conditions.keys().collect()
    }

    /// Names of registered combiners, including the built-ins.
    pub fn combiner_names(&self) -> Vec<&String> {
        self.combiners.keys().collect()
    }

    /// Names of registered tensor functions, in hash-map iteration order.
    pub fn function_names(&self) -> Vec<&String> {
        self.functions.keys().collect()
    }
}
impl Default for DynamicGraph {
fn default() -> Self {
Self::new()
}
}
impl Module for DynamicGraph {
    /// Runs the configured plan on `input` after clearing the previous trace.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.clear_execution_history();
        let root = &self.graph;
        self.execute_node(root, input)
    }

    /// Collects every sub-module's parameters, keyed as `"<module>.<param>"`.
    fn parameters(&self) -> HashMap<String, Parameter> {
        self.modules
            .iter()
            .flat_map(|(owner, child)| {
                child
                    .parameters()
                    .into_iter()
                    .map(move |(key, value)| (format!("{}.{}", owner, key), value))
            })
            .collect()
    }

    /// Collects every sub-module's named parameters, keyed as `"<module>.<param>"`.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        self.modules
            .iter()
            .flat_map(|(owner, child)| {
                child
                    .named_parameters()
                    .into_iter()
                    .map(move |(key, value)| (format!("{}.{}", owner, key), value))
            })
            .collect()
    }

    /// Switches this graph and every sub-module (via its own `train`) to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
        self.modules.values_mut().for_each(|child| child.train());
    }

    /// Switches this graph and every sub-module (via its own `eval`) to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
        self.modules.values_mut().for_each(|child| child.eval());
    }

    /// Current training flag of this graph.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Propagates `training` to this graph and every sub-module.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
        self.modules
            .values_mut()
            .for_each(|child| child.set_training(training));
    }

    /// Moves this graph, then each sub-module, to `device`; stops at the first error.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)?;
        self.modules
            .values_mut()
            .try_for_each(|child| child.to_device(device))
    }

    /// All registered sub-modules, in hash-map iteration order.
    fn children(&self) -> Vec<&dyn Module> {
        self.modules.values().map(|child| child.as_ref()).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity module for tests; `_id` only tells instances apart.
    struct MockModule {
        base: ModuleBase,
        _id: i32,
    }

    impl MockModule {
        fn new(id: i32) -> Self {
            MockModule {
                base: ModuleBase::new(),
                _id: id,
            }
        }
    }

    impl Module for MockModule {
        /// Returns a clone of the input unchanged.
        fn forward(&self, input: &Tensor) -> Result<Tensor> {
            let echoed = input.clone();
            Ok(echoed)
        }

        fn parameters(&self) -> HashMap<String, Parameter> {
            HashMap::new()
        }

        fn named_parameters(&self) -> HashMap<String, Parameter> {
            HashMap::new()
        }

        fn train(&mut self) {
            self.set_training(true);
        }

        fn eval(&mut self) {
            self.set_training(false);
        }

        fn training(&self) -> bool {
            self.base.training()
        }

        fn set_training(&mut self, training: bool) {
            self.base.set_training(training);
        }

        fn to_device(&mut self, device: DeviceType) -> Result<()> {
            self.base.to_device(device)
        }
    }

    #[test]
    fn test_dynamic_graph_creation() {
        let fresh = DynamicGraph::new();
        assert_eq!(fresh.module_count(), 0);
        assert!(fresh.module_names().is_empty());
        assert!(fresh.training());
    }

    #[test]
    fn test_module_management() {
        let mut dg = DynamicGraph::new();
        for (label, id) in [("mock1", 1), ("mock2", 2)] {
            dg.add_module(label.to_string(), MockModule::new(id));
        }
        assert_eq!(dg.module_count(), 2);
        assert!(dg.get_module("mock1").is_some());
        assert!(dg.get_module("nonexistent").is_none());

        let taken = dg.remove_module("mock1");
        assert!(taken.is_some());
        assert_eq!(dg.module_count(), 1);
    }

    #[test]
    fn test_graph_node_creation() {
        let seq = DynamicGraph::sequential(vec!["module1".to_string(), "module2".to_string()]);
        if let GraphNode::Sequence(children) = seq {
            assert_eq!(children.len(), 2);
        } else {
            panic!("Expected Sequence node");
        }

        let branch = DynamicGraph::conditional(
            "test_condition".to_string(),
            GraphNode::Module("true_module".to_string()),
            Some(GraphNode::Module("false_module".to_string())),
        );
        if let GraphNode::Conditional { condition, .. } = branch {
            assert_eq!(condition, "test_condition");
        } else {
            panic!("Expected Conditional node");
        }
    }

    #[test]
    fn test_default_combiners() {
        let dg = DynamicGraph::new();
        let registered = dg.combiner_names();
        for wanted in ["add", "mean", "concat"] {
            assert!(registered.iter().any(|name| name.as_str() == wanted));
        }
    }

    #[test]
    fn test_execution_history() {
        let dg = DynamicGraph::new();
        let before = dg.get_execution_history();
        assert!(before.is_empty());

        dg.clear_execution_history();
        let after = dg.get_execution_history();
        assert!(after.is_empty());
    }
}