use std::rc::Rc;
use crate::autograd::Variable;
use crate::nn::Module;
use super::node::*;
use super::FlowBuilder;
pub struct MapBuilder {
fb: FlowBuilder,
body: Box<dyn Module>,
batched: bool,
}
impl MapBuilder {
pub(super) fn new(fb: FlowBuilder, body: Box<dyn Module>) -> Self {
MapBuilder {
fb,
body,
batched: false,
}
}
pub fn batched(mut self) -> Self {
self.batched = true;
self
}
pub fn each(self) -> FlowBuilder {
let mut fb = self.fb;
if fb.err.is_some() || fb.current.len() != 1 {
return Self::err_if_bad(fb);
}
let source = fb.current[0].clone();
wire_map(&mut fb, self.body, source, self.batched);
fb
}
pub fn over(self, tag: &str) -> FlowBuilder {
let mut fb = self.fb;
if fb.err.is_some() || fb.current.len() != 1 {
return Self::err_if_bad(fb);
}
let tap = fb.taps.get(tag).cloned();
match tap {
Some(source) => {
wire_map(&mut fb, self.body, source, self.batched);
fb
}
None => {
fb.err = Some(format!(
"Map.over({:?}) references unknown tag; Map requires a backward reference",
tag
));
fb
}
}
}
pub fn slices(self, n: i64) -> FlowBuilder {
let mut fb = self.fb;
if fb.err.is_some() || fb.current.len() != 1 {
return Self::err_if_bad(fb);
}
if n < 1 {
fb.err = Some(format!("Map.slices requires n >= 1 (got {})", n));
return fb;
}
let cur = fb.current[0].clone();
let body: Rc<dyn Module> = Rc::from(self.body);
let composite: Rc<dyn Module> = body.clone();
let batched = self.batched;
let run = make_slices_func(body, n, batched);
let id = fb.next_id("map");
fb.nodes.insert(
id.clone(),
Node {
id: id.clone(),
input_ports: vec![DEFAULT_INPUT.into()],
output_ports: vec![DEFAULT_OUTPUT.into()],
run,
module: Some(composite),
ref_forward: None,
trace_buf: None,
loop_ports: None,
},
);
fb.edges.push(Edge {
from_node: cur.node_id,
from_port: cur.port,
to_node: id.clone(),
to_port: DEFAULT_INPUT.into(),
});
let node_ref = NodeRef {
node_id: id,
port: DEFAULT_OUTPUT.into(),
};
fb.current = vec![node_ref.clone()];
fb.on_target = Some(node_ref);
fb
}
fn err_if_bad(mut fb: FlowBuilder) -> FlowBuilder {
if fb.err.is_none() {
fb.err = Some("map requires single stream".into());
}
fb
}
}
fn wire_map(fb: &mut FlowBuilder, body: Box<dyn Module>, source: NodeRef, batched: bool) {
let body: Rc<dyn Module> = Rc::from(body);
let composite: Rc<dyn Module> = body.clone();
let run = make_map_each_func(body, batched);
let id = fb.next_id("map");
fb.nodes.insert(
id.clone(),
Node {
id: id.clone(),
input_ports: vec![DEFAULT_INPUT.into()],
output_ports: vec![DEFAULT_OUTPUT.into()],
run,
module: Some(composite),
ref_forward: None,
trace_buf: None,
loop_ports: None,
},
);
fb.edges.push(Edge {
from_node: source.node_id,
from_port: source.port,
to_node: id.clone(),
to_port: DEFAULT_INPUT.into(),
});
let node_ref = NodeRef {
node_id: id,
port: DEFAULT_OUTPUT.into(),
};
fb.current = vec![node_ref.clone()];
fb.on_target = Some(node_ref);
}
fn make_map_each_func(body: Rc<dyn Module>, batched: bool) -> NodeFn {
Box::new(move |inputs: &[Variable]| {
let stream = &inputs[0];
if batched {
let output = body.forward(stream)?;
return Ok(vec![output]);
}
let n = stream.shape()[0];
if n == 0 {
return Ok(vec![stream.clone()]);
}
let first = body.forward(&stream.narrow(0, 0, 1)?)?;
let mut result = first;
for i in 1..n {
let element = stream.narrow(0, i, 1)?;
let out = body.forward(&element)?;
result = result.cat(&out, 0)?;
}
Ok(vec![result])
})
}
fn make_slices_func(body: Rc<dyn Module>, n: i64, batched: bool) -> NodeFn {
Box::new(move |inputs: &[Variable]| {
let source = &inputs[0];
let shape = source.shape();
if shape.len() < 2 {
return Err(crate::tensor::TensorError::new(&format!(
"map slices: input must be at least 2D (got {}D)",
shape.len()
)));
}
let last_dim = *shape.last().unwrap();
if last_dim % n != 0 {
return Err(crate::tensor::TensorError::new(&format!(
"map slices: last dim {} not divisible by {}",
last_dim, n
)));
}
let slice_dim = last_dim / n;
let orig_dim0 = shape[0];
let decomposed = source.reshape(&[orig_dim0 * n, slice_dim])?;
if batched {
let result = body.forward(&decomposed)?;
let result_shape = result.shape();
let out_features = result_shape[1] * n;
let recomposed = result.reshape(&[orig_dim0, out_features])?;
return Ok(vec![recomposed]);
}
let num_rows = orig_dim0 * n;
let first = body.forward(&decomposed.narrow(0, 0, 1)?)?;
let mut stacked = first;
for i in 1..num_rows {
let elem = decomposed.narrow(0, i, 1)?;
let out = body.forward(&elem)?;
stacked = stacked.cat(&out, 0)?;
}
let stacked_shape = stacked.shape();
let out_features = stacked_shape[1] * n;
let recomposed = stacked.reshape(&[orig_dim0, out_features])?;
Ok(vec![recomposed])
})
}