use anyhow::{anyhow, Result};
use rlx_ir::{DType, Graph, NodeId, Shape};
/// Concatenates the outputs of several hybrid sub-descriptors along the last (feature) axis.
///
/// Each entry of `outputs` must be a rank-3 node laid out as `[nf, nloc, d_i]`. Its last
/// dimension `d_i` must be static. The result is a single `concat` node on axis 2, declared
/// with shape `[nf, nloc, sum(d_i)]` and dtype `F32`.
///
/// # Returns
///
/// The new concat node, together with the total feature width `sum(d_i)`.
///
/// # Errors
///
/// Returns an error if:
/// - `outputs` is empty;
/// - any output is not rank 3;
/// - any output's last dimension is not static;
/// - any output's first or second dimension is statically known and differs from `nf` or
///   `nloc`. The declared output shape would be wrong in that case.
pub fn concat_descriptor_outputs(
    g: &mut Graph,
    outputs: &[NodeId],
    nf: usize,
    nloc: usize,
) -> Result<(NodeId, usize)> {
    if outputs.is_empty() {
        return Err(anyhow!("hybrid: at least one sub-descriptor is required"));
    }

    let mut total = 0usize;
    for (idx, &node) in outputs.iter().enumerate() {
        let s = g.shape(node);
        if s.rank() != 3 {
            return Err(anyhow!(
                "hybrid: each sub-descriptor must have rank 3 [nf, nloc, d]"
            ));
        }

        // The output shape below is declared as [nf, nloc, total] without looking at the
        // inputs. Reject any statically-known leading dim that contradicts it. Non-static
        // leading dims are accepted unchanged, as before.
        for (axis, expected) in [(0, nf), (1, nloc)] {
            if let rlx_ir::Dim::Static(actual) = s.dim(axis) {
                if actual != expected {
                    return Err(anyhow!(
                        "hybrid: sub-descriptor {idx} has dim {axis} = {actual}, expected {expected}"
                    ));
                }
            }
        }

        let width = match s.dim(2) {
            rlx_ir::Dim::Static(w) => w,
            _ => return Err(anyhow!("hybrid: sub-descriptor last dim must be static")),
        };
        total += width;
    }

    // NOTE(review): the output dtype is hard-coded to F32. The inputs' dtype is not
    // checked here; confirm that all sub-descriptors are F32.
    let out_shape = Shape::new(&[nf, nloc, total], DType::F32);
    let node = g.concat(outputs.to_vec(), 2, out_shape);
    Ok((node, total))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two static-width sub-descriptors concatenate into one whose width is the sum.
    #[test]
    fn hybrid_concat_dims_match() {
        let mut g = Graph::new("hybrid");
        let nf = 1;
        let nloc = 4;
        let a = g.input("a", Shape::new(&[nf, nloc, 16], DType::F32));
        let b = g.input("b", Shape::new(&[nf, nloc, 24], DType::F32));
        let (out, total) = concat_descriptor_outputs(&mut g, &[a, b], nf, nloc).unwrap();
        assert_eq!(total, 40);
        let s = g.shape(out);
        assert_eq!(s.rank(), 3);
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(40));
    }

    // With no inputs there is nothing to concatenate, so the call must fail.
    #[test]
    fn hybrid_concat_rejects_empty() {
        let mut g = Graph::new("hybrid");
        assert!(concat_descriptor_outputs(&mut g, &[], 1, 4).is_err());
    }

    // A sub-descriptor that is not rank 3 must be rejected.
    #[test]
    fn hybrid_concat_rejects_wrong_rank() {
        let mut g = Graph::new("hybrid");
        let a = g.input("a", Shape::new(&[1, 4], DType::F32));
        assert!(concat_descriptor_outputs(&mut g, &[a], 1, 4).is_err());
    }
}