predict_broadcast_shape

Function predict_broadcast_shape 

pub fn predict_broadcast_shape(
    a_shape: &[i64],
    b_shape: &[i64],
) -> Result<Shape, TensorError>

Predicts the shape that results from broadcasting two arrays together.

The predict_broadcast_shape function computes the resulting shape when two arrays with shapes a_shape and b_shape are broadcast together. Broadcasting is a technique that allows arrays of different shapes to be used together in arithmetic operations by “stretching” one or both arrays so that they have compatible shapes.

§Parameters

  • a_shape: A slice of i64 representing the shape of the first array.
  • b_shape: A slice of i64 representing the shape of the second array.

§Returns

  • Ok(Shape): The resulting broadcasted shape as a Shape object if broadcasting is possible.
  • Err(TensorError): An error if the shapes cannot be broadcast together.

§Broadcasting Rules

The broadcasting rules determine how two arrays of different shapes can be broadcast together:

  1. Alignment: The shapes are right-aligned, meaning that the last dimensions are compared first. If one shape has fewer dimensions, it is left-padded with ones to match the other shape’s length.

  2. Dimension Compatibility: For each dimension from the last to the first:

    • If the dimensions are equal, they are compatible.
    • If one of the dimensions is 1, that array is stretched along the dimension to match the other array's size.
    • If the dimensions are not equal and neither is 1, broadcasting is not possible.
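The per-dimension rule can be sketched as a small standalone function. This is only an illustration of the rule as stated above, not the crate's code; the name broadcast_dim is hypothetical, and Option is used here in place of the real error type.

```rust
/// Combines one pair of aligned dimension sizes according to the rules above.
/// Returns `None` when the sizes differ and neither is 1.
/// `broadcast_dim` is a hypothetical name used only for this illustration.
fn broadcast_dim(a: i64, b: i64) -> Option<i64> {
    if a == b {
        Some(a) // equal sizes are compatible as-is
    } else if a == 1 {
        Some(b) // `a` is stretched to match `b`
    } else if b == 1 {
        Some(a) // `b` is stretched to match `a`
    } else {
        None // sizes differ and neither is 1
    }
}
```

Under these assumptions, broadcast_dim(6, 1) yields Some(6), while broadcast_dim(3, 4) yields None.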

§Example

// Assuming Shape and the necessary imports are defined appropriately.

let a_shape = &[8, 1, 6, 1];
let b_shape = &[7, 1, 5];

match predict_broadcast_shape(a_shape, b_shape) {
    Ok(result_shape) => {
        assert_eq!(result_shape, Shape::from(vec![8, 7, 6, 5]));
        println!("Broadcasted shape: {:?}", result_shape);
    },
    Err(e) => {
        println!("Error: {}", e);
    },
}

In this example:

  • a_shape is [8, 1, 6, 1].
  • b_shape is [7, 1, 5].
  • After padding b_shape to [1, 7, 1, 5], the shapes are compared element-wise from the last dimension.
  • The resulting broadcasted shape is [8, 7, 6, 5].

§Notes

  • The function assumes that shapes are represented as slices of i64.
  • The function uses a helper function try_pad_shape to pad the shorter shape with ones on the left.
  • If broadcasting is not possible, the function returns an error indicating the dimension at which the incompatibility occurs.

§Errors

  • Returns an error if at any dimension the sizes differ and neither is 1, indicating that broadcasting cannot be performed.

§Implementation Details

  • The function first determines which of the two shapes is longer and which is shorter.
  • The shorter shape is padded on the left with ones to match the length of the longer shape.
  • It then iterates over the dimensions, comparing corresponding dimensions from each shape:
    • If the dimensions are equal or one of them is 1, the resulting dimension is set to the maximum of the two.
    • If neither condition is met, an error is returned.
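The steps above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the actual implementation: pad_left is a hypothetical stand-in for the padding step (it is not try_pad_shape), the result is a plain Vec<i64> rather than Shape, and the error is a String rather than TensorError. For brevity the sketch pads both inputs to the common length, which leaves the longer shape unchanged, and it walks the padded shapes from the first dimension onward, so the reported index is the leftmost mismatch.

```rust
/// Hypothetical stand-in for the padding step: left-pads `shape` with ones
/// until it has `len` dimensions. Shapes already that long are returned as-is.
fn pad_left(shape: &[i64], len: usize) -> Vec<i64> {
    let mut padded = vec![1; len.saturating_sub(shape.len())];
    padded.extend_from_slice(shape);
    padded
}

/// Sketch of the pad-then-compare procedure. Returns the broadcast shape, or a
/// message naming the leftmost incompatible dimension of the padded shapes.
fn broadcast_shape_sketch(a: &[i64], b: &[i64]) -> Result<Vec<i64>, String> {
    let len = a.len().max(b.len());
    let (a, b) = (pad_left(a, len), pad_left(b, len));
    let mut out = Vec::with_capacity(len);
    for (i, (&x, &y)) in a.iter().zip(b.iter()).enumerate() {
        if x == y || x == 1 || y == 1 {
            // Compatible: the result takes the larger of the two sizes.
            out.push(x.max(y));
        } else {
            return Err(format!("dimension {i}: sizes {x} and {y} are incompatible"));
        }
    }
    Ok(out)
}
```

With the shapes from the example, broadcast_shape_sketch(&[8, 1, 6, 1], &[7, 1, 5]) returns Ok(vec![8, 7, 6, 5]), and a pair such as &[2, 3] and &[4, 3] returns an Err because dimension 0 has sizes 2 and 4.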