// SPDX-License-Identifier: Apache-2.0
#include <metal_stdlib>
using namespace metal;
// Uniform parameters for j2k_store_component.
// Describes a copy_width x copy_height window read from a float input plane
// starting at (source_x, source_y) and written to a float output plane starting
// at (output_x, output_y).
// NOTE(review): presumably mirrored field-for-field by a host-side struct; keep
// field order and 32-bit types in sync with it — confirm against the host code.
struct J2kStoreParams {
uint input_width; // input row stride in elements (multiplies the source row index)
uint source_x; // first input column of the window
uint source_y; // first input row of the window
uint copy_width; // window width; threads with gid.x >= copy_width do nothing
uint copy_height; // window height; threads with gid.y >= copy_height do nothing
uint output_width; // output row stride in elements (multiplies the destination row index)
uint output_x; // destination column of the window's first sample
uint output_y; // destination row of the window's first sample
float addend; // added to every sample before it is stored
};
// Uniform parameters for j2k_store_component_repeated: the same window copy as
// J2kStoreParams, applied to batch_count instances selected by gid.z.
// NOTE(review): presumably mirrored field-for-field by a host-side struct; keep
// field order and 32-bit types in sync with it — confirm against the host code.
struct J2kRepeatedStoreParams {
uint input_width; // input row stride in elements
uint input_height; // not read by j2k_store_component_repeated (it uses input_instance_stride)
uint input_instance_stride; // elements between consecutive instances in the input buffer
uint source_x; // first input column of the window
uint source_y; // first input row of the window
uint copy_width; // window width; threads with gid.x >= copy_width do nothing
uint copy_height; // window height; threads with gid.y >= copy_height do nothing
uint output_width; // output row stride in elements
uint output_height; // with output_width, gives the per-instance output plane length
uint output_x; // destination column of the window's first sample
uint output_y; // destination row of the window's first sample
float addend; // added to every sample before it is stored
uint batch_count; // number of instances; threads with gid.z >= batch_count do nothing
};
// Uniform parameters shared by the batched gray store kernels
// (j2k_store_component_repeated_gray_u8/_u16 and their _contiguous variants).
// NOTE(review): presumably mirrored field-for-field by a host-side struct; keep
// field order and 32-bit types in sync with it — confirm against the host code.
struct J2kRepeatedGrayStoreParams {
uint input_width; // input row stride in elements
uint input_height; // with input_width, gives the per-instance input plane length
uint source_x; // first input column of the window (ignored by _contiguous kernels)
uint source_y; // first input row of the window (ignored by _contiguous kernels)
uint copy_width; // window width (ignored by _contiguous kernels)
uint copy_height; // window height (ignored by _contiguous kernels)
uint output_width; // output row stride in elements (ignored by _contiguous kernels)
uint output_height; // with output_width, gives the per-instance output plane length (ignored by _contiguous kernels)
uint output_x; // destination column of the window (ignored by _contiguous kernels)
uint output_y; // destination row of the window (ignored by _contiguous kernels)
float addend; // added to every sample before conversion
uint batch_count; // number of instances in the batch
float max_value; // forwarded to scale_to_u8 / pack_to_u16
float u8_scale; // forwarded to scale_to_u8 by the _u8 kernels
float u16_scale; // forwarded to pack_to_u16 by the _u16 kernels
};
// Uniform parameters for the single-plane gray store kernels
// (j2k_store_component_gray_u8 and j2k_store_component_gray_u16).
// NOTE(review): presumably mirrored field-for-field by a host-side struct; keep
// field order and 32-bit types in sync with it — confirm against the host code.
struct J2kGrayStoreParams {
uint input_width; // input row stride in elements
uint source_x; // first input column of the window
uint source_y; // first input row of the window
uint copy_width; // window width; threads with gid.x >= copy_width do nothing
uint copy_height; // window height; threads with gid.y >= copy_height do nothing
uint output_width; // output row stride in elements
uint output_x; // destination column of the window's first sample
uint output_y; // destination row of the window's first sample
float addend; // added to every sample before conversion
float max_value; // forwarded to scale_to_u8 / pack_to_u16
float u8_scale; // forwarded to scale_to_u8 by the _u8 kernel
float u16_scale; // forwarded to pack_to_u16 by the _u16 kernel
};
// Pair of flat element indices returned by j2k_store_window_indices.
struct J2kStoreWindowIndices {
uint src_idx; // element index into the input buffer
uint dst_idx; // element index into the output buffer
};
// Translates a thread's window-relative coordinate `gid` into flat element
// indices for a row-major input plane and a row-major output plane.
//
//   src_idx = input_offset  + (source_y + gid.y) * input_width  + (source_x + gid.x)
//   dst_idx = output_offset + (output_y + gid.y) * output_width + (output_x + gid.x)
//
// input_offset / output_offset let batched callers select an instance plane.
// All arithmetic is 32-bit unsigned (wraps on overflow).
inline J2kStoreWindowIndices j2k_store_window_indices(
    uint input_width,
    uint output_width,
    uint source_x,
    uint source_y,
    uint output_x,
    uint output_y,
    uint2 gid,
    uint input_offset,
    uint output_offset
) {
    // Absolute (column, row) positions in each plane.
    const uint2 src_pos = uint2(source_x, source_y) + gid;
    const uint2 dst_pos = uint2(output_x, output_y) + gid;

    J2kStoreWindowIndices result;
    result.src_idx = input_offset + src_pos.y * input_width + src_pos.x;
    result.dst_idx = output_offset + dst_pos.y * output_width + dst_pos.x;
    return result;
}
// Copies a copy_width x copy_height window of a float plane into a float
// output plane, adding params.addend to every sample.
//
// Dispatch: 2D grid. gid.x is the window column and gid.y the window row.
// Threads outside [0, copy_width) x [0, copy_height) return without writing,
// so the host may round the grid up to whole threadgroups.
//   reads  input [(source_y + y) * input_width  + source_x + x]
//   writes output[(output_y + y) * output_width + output_x + x]
// NOTE(review): there is no check against buffer lengths; the host presumably
// guarantees the window fits inside both planes — confirm.
kernel void j2k_store_component(
    device const float *input [[buffer(0)]],
    device float *output [[buffer(1)]],
    constant J2kStoreParams &params [[buffer(2)]],
    uint2 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid.
    if (gid.x >= params.copy_width || gid.y >= params.copy_height) {
        return;
    }
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid,
        0u,
        0u
    );
    output[indices.dst_idx] = input[indices.src_idx] + params.addend;
}
// Batched version of j2k_store_component: copies the same window out of each
// of batch_count float input instances into the matching float output plane,
// adding params.addend to every sample.
//
// Dispatch: 3D grid. gid.x is the window column, gid.y the window row and
// gid.z the instance. Out-of-range threads return without writing.
//   input  instance base = gid.z * input_instance_stride
//   output instance base = gid.z * output_width * output_height
// NOTE(review): all index math is 32-bit `uint`. A large batch of large
// planes can wrap gid.z * plane_len, so the host should keep the total
// element count below 2^32 — confirm.
kernel void j2k_store_component_repeated(
    device const float *input [[buffer(0)]],
    device float *output [[buffer(1)]],
    constant J2kRepeatedStoreParams &params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid (any axis).
    if (gid.x >= params.copy_width || gid.y >= params.copy_height || gid.z >= params.batch_count) {
        return;
    }
    // Output instances are packed back to back with no padding between planes.
    const uint output_plane_len = params.output_width * params.output_height;
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid.xy,
        gid.z * params.input_instance_stride,
        gid.z * output_plane_len
    );
    output[indices.dst_idx] = input[indices.src_idx] + params.addend;
}
// Batched window store from float instances to 8-bit output. Each sample is
// (input + addend), converted by scale_to_u8(value, max_value, u8_scale).
// NOTE(review): scale_to_u8 is not defined in this source; presumably it
// comes from a shared header prepended at compile time — confirm its
// clamping/rounding behavior there.
//
// Dispatch: 3D grid. gid.x is the window column, gid.y the window row and
// gid.z the instance. Out-of-range threads return without writing.
// Input and output instances are both packed planes
// (input_width * input_height and output_width * output_height elements).
// NOTE(review): all index math is 32-bit `uint`.
kernel void j2k_store_component_repeated_gray_u8(
    device const float *input [[buffer(0)]],
    device uchar *output [[buffer(1)]],
    constant J2kRepeatedGrayStoreParams &params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid (any axis).
    if (gid.x >= params.copy_width || gid.y >= params.copy_height || gid.z >= params.batch_count) {
        return;
    }
    const uint input_plane_len = params.input_width * params.input_height;
    const uint output_plane_len = params.output_width * params.output_height;
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid.xy,
        gid.z * input_plane_len,
        gid.z * output_plane_len
    );
    output[indices.dst_idx] = scale_to_u8(input[indices.src_idx] + params.addend, params.max_value, params.u8_scale);
}
// Batched window store from float instances to 16-bit output. Each sample is
// (input + addend), converted by pack_to_u16(value, max_value, u16_scale).
// NOTE(review): pack_to_u16 is not defined in this source; presumably it
// comes from a shared header prepended at compile time — confirm its
// clamping/rounding behavior there.
//
// Dispatch: 3D grid. gid.x is the window column, gid.y the window row and
// gid.z the instance. Out-of-range threads return without writing.
// Input and output instances are both packed planes
// (input_width * input_height and output_width * output_height elements).
// NOTE(review): all index math is 32-bit `uint`.
kernel void j2k_store_component_repeated_gray_u16(
    device const float *input [[buffer(0)]],
    device ushort *output [[buffer(1)]],
    constant J2kRepeatedGrayStoreParams &params [[buffer(2)]],
    uint3 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid (any axis).
    if (gid.x >= params.copy_width || gid.y >= params.copy_height || gid.z >= params.batch_count) {
        return;
    }
    const uint input_plane_len = params.input_width * params.input_height;
    const uint output_plane_len = params.output_width * params.output_height;
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid.xy,
        gid.z * input_plane_len,
        gid.z * output_plane_len
    );
    output[indices.dst_idx] = pack_to_u16(input[indices.src_idx] + params.addend, params.max_value, params.u16_scale);
}
// Flat 1:1 conversion of a whole batch from float to 8-bit:
// output[i] = scale_to_u8(input[i] + addend, max_value, u8_scale)
// for i in [0, input_width * input_height * batch_count).
//
// Dispatch: 1D grid over elements. Threads past the end return.
// Only input_width, input_height, batch_count, addend, max_value and u8_scale
// are read. The window origin/size and output dimensions are ignored, so this
// is presumably only dispatched when the window is the full plane and the
// output has the same shape as the input — confirm with the host.
// NOTE(review): total_len is computed in 32-bit `uint` and wraps if the batch
// holds 2^32 or more elements.
kernel void j2k_store_component_repeated_gray_u8_contiguous(
    device const float *input [[buffer(0)]],
    device uchar *output [[buffer(1)]],
    constant J2kRepeatedGrayStoreParams &params [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint plane_len = params.input_width * params.input_height;
    const uint total_len = plane_len * params.batch_count;
    // Discard threads in the padded tail of the grid.
    if (gid >= total_len) {
        return;
    }
    output[gid] = scale_to_u8(input[gid] + params.addend, params.max_value, params.u8_scale);
}
// Flat 1:1 conversion of a whole batch from float to 16-bit:
// output[i] = pack_to_u16(input[i] + addend, max_value, u16_scale)
// for i in [0, input_width * input_height * batch_count).
//
// Dispatch: 1D grid over elements. Threads past the end return.
// Only input_width, input_height, batch_count, addend, max_value and u16_scale
// are read. The window origin/size and output dimensions are ignored, so this
// is presumably only dispatched when the window is the full plane and the
// output has the same shape as the input — confirm with the host.
// NOTE(review): total_len is computed in 32-bit `uint` and wraps if the batch
// holds 2^32 or more elements.
kernel void j2k_store_component_repeated_gray_u16_contiguous(
    device const float *input [[buffer(0)]],
    device ushort *output [[buffer(1)]],
    constant J2kRepeatedGrayStoreParams &params [[buffer(2)]],
    uint gid [[thread_position_in_grid]]
) {
    const uint plane_len = params.input_width * params.input_height;
    const uint total_len = plane_len * params.batch_count;
    // Discard threads in the padded tail of the grid.
    if (gid >= total_len) {
        return;
    }
    output[gid] = pack_to_u16(input[gid] + params.addend, params.max_value, params.u16_scale);
}
// Copies a copy_width x copy_height window of a float plane into an 8-bit
// output plane. Each sample is (input + addend), converted by
// scale_to_u8(value, max_value, u8_scale).
// NOTE(review): scale_to_u8 is not defined in this source; presumably it
// comes from a shared header prepended at compile time — confirm.
//
// Dispatch: 2D grid. gid.x is the window column and gid.y the window row.
// Threads outside the window return without writing.
//   reads  input [(source_y + y) * input_width  + source_x + x]
//   writes output[(output_y + y) * output_width + output_x + x]
kernel void j2k_store_component_gray_u8(
    device const float *input [[buffer(0)]],
    device uchar *output [[buffer(1)]],
    constant J2kGrayStoreParams &params [[buffer(2)]],
    uint2 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid.
    if (gid.x >= params.copy_width || gid.y >= params.copy_height) {
        return;
    }
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid,
        0u,
        0u
    );
    output[indices.dst_idx] = scale_to_u8(input[indices.src_idx] + params.addend, params.max_value, params.u8_scale);
}
// Copies a copy_width x copy_height window of a float plane into a 16-bit
// output plane. Each sample is (input + addend), converted by
// pack_to_u16(value, max_value, u16_scale).
// NOTE(review): pack_to_u16 is not defined in this source; presumably it
// comes from a shared header prepended at compile time — confirm.
//
// Dispatch: 2D grid. gid.x is the window column and gid.y the window row.
// Threads outside the window return without writing.
//   reads  input [(source_y + y) * input_width  + source_x + x]
//   writes output[(output_y + y) * output_width + output_x + x]
kernel void j2k_store_component_gray_u16(
    device const float *input [[buffer(0)]],
    device ushort *output [[buffer(1)]],
    constant J2kGrayStoreParams &params [[buffer(2)]],
    uint2 gid [[thread_position_in_grid]]
) {
    // Discard threads in the padded tail of the grid.
    if (gid.x >= params.copy_width || gid.y >= params.copy_height) {
        return;
    }
    const J2kStoreWindowIndices indices = j2k_store_window_indices(
        params.input_width,
        params.output_width,
        params.source_x,
        params.source_y,
        params.output_x,
        params.output_y,
        gid,
        0u,
        0u
    );
    output[indices.dst_idx] = pack_to_u16(input[indices.src_idx] + params.addend, params.max_value, params.u16_scale);
}