pub fn predict_broadcast_shape(
    a_shape: &[i64],
    b_shape: &[i64],
) -> Result<Shape, TensorError>
Predicts the broadcasted shape resulting from broadcasting two arrays.
The predict_broadcast_shape function computes the resulting shape when two arrays with shapes
a_shape and b_shape are broadcast together. Broadcasting is a technique that allows arrays of
different shapes to be used together in arithmetic operations by “stretching” one or both arrays
so that they have compatible shapes.
§Parameters
- a_shape: A slice of i64 representing the shape of the first array.
- b_shape: A slice of i64 representing the shape of the second array.
§Returns
- Ok(Shape): The resulting broadcasted shape as a Shape object, if broadcasting is possible.
- Err(TensorError): An error if the shapes cannot be broadcast together.
§Broadcasting Rules
The broadcasting rules determine how two arrays of different shapes can be broadcast together:
- Alignment: The shapes are right-aligned, meaning that the last dimensions are compared first. If one shape has fewer dimensions, it is left-padded with ones to match the other shape’s length.
- Dimension Compatibility: For each dimension, from the last to the first:
  - If the dimensions are equal, they are compatible.
  - If one of the dimensions is 1, that array is stretched along this dimension to match the other size.
  - If the dimensions are not equal and neither is 1, broadcasting is not possible.
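The per-dimension compatibility rule above can be captured in a few lines. The following is a minimal sketch; broadcast_dim is a hypothetical helper written for illustration and is not part of this crate's documented API.

```rust
/// Hypothetical helper (not part of this crate's documented API): applies the
/// compatibility rule to one right-aligned pair of dimension sizes.
/// Returns `Some(size)` with the broadcast size, or `None` if incompatible.
fn broadcast_dim(a: i64, b: i64) -> Option<i64> {
    if a == b {
        Some(a) // equal sizes are compatible as-is
    } else if a == 1 {
        Some(b) // a size-1 dimension is stretched to the other size
    } else if b == 1 {
        Some(a)
    } else {
        None // sizes differ and neither is 1
    }
}
```

For example, broadcast_dim(1, 5) gives Some(5), while broadcast_dim(3, 4) gives None.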
§Example
```rust
// Assuming `Shape` and `predict_broadcast_shape` are in scope, and that
// `Shape` implements `From<Vec<i64>>`, `PartialEq`, and `Debug`.
let a_shape: &[i64] = &[8, 1, 6, 1];
let b_shape: &[i64] = &[7, 1, 5];
match predict_broadcast_shape(a_shape, b_shape) {
    Ok(result_shape) => {
        // b_shape is treated as [1, 7, 1, 5], so the result is [8, 7, 6, 5].
        assert_eq!(result_shape, Shape::from(vec![8, 7, 6, 5]));
        println!("Broadcasted shape: {:?}", result_shape);
    }
    Err(e) => {
        println!("Error: {}", e);
    }
}
```
In this example:
- a_shape has shape [8, 1, 6, 1].
- b_shape has shape [7, 1, 5].
- After padding b_shape to [1, 7, 1, 5], the shapes are compared element-wise from the last dimension.
- The resulting broadcasted shape is [8, 7, 6, 5].
§Notes
- The function assumes that shapes are represented as slices of i64.
- The function uses a helper function, try_pad_shape, to pad the shorter shape with ones on the left.
- If broadcasting is not possible, the function returns an error indicating the dimension at which the incompatibility occurs.
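The left-padding step performed by try_pad_shape can be sketched as follows. The real signature of try_pad_shape is not documented here, so pad_left_with_ones is a hypothetical stand-in that only illustrates the idea of prepending ones.

```rust
/// Hypothetical stand-in for the crate's `try_pad_shape` helper (its real
/// signature is not shown in this documentation): left-pads `shape` with
/// ones until it has `target_len` dimensions. Returns `None` if `shape`
/// already has more than `target_len` dimensions.
fn pad_left_with_ones(shape: &[i64], target_len: usize) -> Option<Vec<i64>> {
    if shape.len() > target_len {
        return None;
    }
    // Start with the required number of leading ones, then append the shape.
    let mut padded = vec![1; target_len - shape.len()];
    padded.extend_from_slice(shape);
    Some(padded)
}
```

With the shapes from the example, pad_left_with_ones(&[7, 1, 5], 4) produces Some(vec![1, 7, 1, 5]).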
§Errors
- Returns an error if at any dimension the sizes differ and neither is 1, indicating that broadcasting cannot be performed.
§Implementation Details
- The function first determines which of the two shapes is longer and which is shorter.
- The shorter shape is padded on the left with ones to match the length of the longer shape.
- It then iterates over the dimensions, comparing corresponding dimensions from each shape:
  - If the dimensions are equal or one of them is 1, the resulting dimension is set to the maximum of the two.
  - If neither condition is met, an error is returned.
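The implementation steps above can be sketched end to end. This is a minimal sketch under stated assumptions, not the crate's actual code: it returns a plain Vec<i64> instead of Shape, and BroadcastError and broadcast_shapes are hypothetical names standing in for TensorError and predict_broadcast_shape. One deliberate difference from the steps above is that both shapes are padded to the common length (the longer one receives no ones), so the error can report sizes in a_shape/b_shape order.

```rust
use std::fmt;

/// Hypothetical error type for this sketch; the crate's real `TensorError`
/// is not reproduced here.
#[derive(Debug, PartialEq)]
struct BroadcastError {
    axis: usize, // index into the padded shapes where the mismatch occurred
    lhs: i64,    // size from the first shape
    rhs: i64,    // size from the second shape
}

impl fmt::Display for BroadcastError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "cannot broadcast dimension {}: {} vs {}",
            self.axis, self.lhs, self.rhs
        )
    }
}

/// Left-pads `shape` with ones up to `ndim` dimensions.
fn pad(shape: &[i64], ndim: usize) -> Vec<i64> {
    let mut padded = vec![1; ndim - shape.len()];
    padded.extend_from_slice(shape);
    padded
}

/// Sketch of the broadcasting algorithm described above.
fn broadcast_shapes(a: &[i64], b: &[i64]) -> Result<Vec<i64>, BroadcastError> {
    // Pad both shapes to the longer length; the longer one is unchanged.
    let ndim = a.len().max(b.len());
    let (pa, pb) = (pad(a, ndim), pad(b, ndim));

    let mut result = Vec::with_capacity(ndim);
    for (axis, (&x, &y)) in pa.iter().zip(pb.iter()).enumerate() {
        // For positive sizes, each accepted case equals max(x, y).
        let dim = if x == y || y == 1 {
            x
        } else if x == 1 {
            y
        } else {
            return Err(BroadcastError { axis, lhs: x, rhs: y });
        };
        result.push(dim);
    }
    Ok(result)
}
```

Selecting the non-1 size explicitly, rather than taking the maximum, gives the same result for positive sizes and stays correct if a dimension of size 0 is paired with a dimension of size 1.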