use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use crate::autograd::Variable;
use crate::nn::Module;
use crate::tensor::Result;
/// Default name for a node's input port.
///
/// NOTE(review): presumably used when a node is declared without explicit input
/// port names — confirm at the call sites.
pub(crate) const DEFAULT_INPUT: &str = "input";
/// Default name for a node's output port.
///
/// NOTE(review): presumably used when a node is declared without explicit output
/// port names — confirm at the call sites.
pub(crate) const DEFAULT_OUTPUT: &str = "output";
/// Boxed callable that computes a node's outputs from its inputs.
///
/// Receives the node's input variables as a slice and returns its output
/// variables (see `wrap_module` / `wrap_ref_module`, which each return a
/// one-element vector).
pub(crate) type NodeFn = Box<dyn Fn(&[Variable]) -> Result<Vec<Variable>>>;
/// Shared, reference-aware forward function.
///
/// Called with the primary input and a map of named reference inputs (the map
/// built by `extract_refs`, keyed by port name with the `ref_` prefix removed).
/// Wrapped in `Rc` so the same function can be held by several owners.
pub(crate) type RefForwardFn =
    Rc<dyn Fn(&Variable, &HashMap<String, Variable>) -> Result<Variable>>;
/// Shared, interior-mutable map from a name to a list of recorded variables.
///
/// NOTE(review): presumably per-name traces collected during execution — confirm
/// against the code that fills it.
pub(crate) type NamedTraceStore = Rc<RefCell<HashMap<String, Vec<Variable>>>>;
/// A single computation node in the graph.
///
/// NOTE(review): the meaning of the optional fields below is inferred from their
/// names and types only; confirm against the graph builder and executor.
pub(crate) struct Node {
    /// Identifier of this node.
    ///
    /// NOTE(review): presumably what `Edge`/`NodeRef` node-id fields refer to; confirm.
    pub id: String,
    /// Names of this node's input ports.
    pub input_ports: Vec<String>,
    /// Names of this node's output ports.
    pub output_ports: Vec<String>,
    /// Callable that computes this node's outputs from its inputs.
    pub run: NodeFn,
    /// The module backing this node, if any.
    pub module: Option<Rc<dyn Module>>,
    /// Reference-aware forward function, if any.
    ///
    /// NOTE(review): presumably set for nodes built with `wrap_ref_module`; confirm.
    pub ref_forward: Option<RefForwardFn>,
    /// Optional shared buffer of recorded variables.
    ///
    /// NOTE(review): presumably collects trace outputs during execution; confirm.
    pub trace_buf: Option<Rc<RefCell<Vec<Variable>>>>,
    /// Optional shared store of recorded variables keyed by name.
    pub named_trace_buf: Option<NamedTraceStore>,
    /// Optional shared, mutable list of port names.
    ///
    /// NOTE(review): presumably the ports carried across loop iterations; confirm.
    pub loop_ports: Option<Rc<RefCell<Vec<String>>>>,
}
/// Reference to a specific port on a specific node.
#[derive(Clone, Debug)]
pub(crate) struct NodeRef {
    /// Identifier of the referenced node.
    pub node_id: String,
    /// Name of the referenced port on that node.
    pub port: String,
}
/// Directed connection from an output port of one node to an input port of another.
#[derive(Clone, Debug)]
pub(crate) struct Edge {
    /// Identifier of the source node.
    pub from_node: String,
    /// Port on the source node that the value comes from.
    pub from_port: String,
    /// Identifier of the destination node.
    pub to_node: String,
    /// Port on the destination node that receives the value.
    pub to_port: String,
}
/// A node port exposed under a public name.
///
/// NOTE(review): presumably a graph-level input or output mapped onto an internal
/// node port — confirm against the graph builder.
#[derive(Clone, Debug)]
pub(crate) struct ExposedPort {
    // `name` appears not to be read anywhere yet (it is kept for identification and
    // `Debug` output), so the unused-field lint is silenced here.
    #[allow(dead_code)]
    pub name: String,
    /// Identifier of the node that owns the exposed port.
    pub node_id: String,
    /// Name of the exposed port on that node.
    pub port: String,
}
/// Describes a named reference between a reading node and a writing node's port.
///
/// NOTE(review): presumably a forward reference, where `reader_id` consumes the
/// value produced on `writer_port` of `writer_id` — confirm against its users.
pub(crate) struct ForwardRefSpec {
    // `name` appears not to be read anywhere yet, so the unused-field lint is
    // silenced here.
    #[allow(dead_code)]
    pub name: String,
    /// Identifier of the node that reads the referenced value.
    pub reader_id: String,
    /// Identifier of the node that produces the referenced value.
    pub writer_id: String,
    /// Port on the writer node that produces the referenced value.
    pub writer_port: String,
}
/// A not-yet-resolved "using" declaration for a reader node.
///
/// NOTE(review): the resolution semantics are not visible here; confirm against the
/// code that creates and consumes these.
pub(crate) struct PendingUsing {
    /// Identifier of the node awaiting resolution.
    pub reader_id: String,
}
pub(crate) fn extract_refs(
ports: &[String],
inputs: &[Variable],
) -> HashMap<String, Variable> {
let mut refs = HashMap::new();
for (i, port) in ports.iter().enumerate() {
if let Some(name) = port.strip_prefix("ref_")
&& i < inputs.len()
{
refs.insert(name.to_string(), inputs[i].clone());
}
}
refs
}
/// Adapts a single-input [`Module`] into a graph [`NodeFn`].
///
/// The returned closure passes the first element of its input slice through
/// `module.forward` and yields the result as a one-element vector. Any further
/// inputs are ignored, and errors from `forward` are returned unchanged.
///
/// # Panics
///
/// The returned closure panics if it is called with an empty input slice.
pub(crate) fn wrap_module(module: Rc<dyn Module>) -> NodeFn {
    Box::new(move |inputs: &[Variable]| {
        let primary = &inputs[0];
        module.forward(primary).map(|out| vec![out])
    })
}
/// Adapts a [`Module`] that can optionally consume named references into a [`NodeFn`].
///
/// `ports` lists the node's input port names, in the same order as the inputs the
/// returned closure receives. On each call, inputs on `ref_`-prefixed ports are
/// gathered with `extract_refs`:
/// - If no reference was supplied, the first input goes through `module.forward`.
/// - Otherwise, the first input and the reference map go through `ref_forward`.
///
/// The single output is returned as a one-element vector, and errors are returned
/// unchanged.
///
/// # Panics
///
/// The returned closure panics if it is called with an empty input slice.
pub(crate) fn wrap_ref_module(
    module: Rc<dyn Module>,
    ref_forward: RefForwardFn,
    ports: Vec<String>,
) -> NodeFn {
    Box::new(move |inputs: &[Variable]| {
        let refs = extract_refs(&ports, inputs);
        let primary = &inputs[0];
        // Without any wired references, fall back to the plain forward pass.
        if refs.is_empty() {
            return module.forward(primary).map(|out| vec![out]);
        }
        ref_forward(primary, &refs).map(|out| vec![out])
    })
}