#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"
#include "ggml-impl.h"
#include <cstddef>
#include <memory>
#include <openvino/core/shape.hpp>
#include <openvino/core/strides.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/extractimagepatches.hpp>
#include <openvino/op/pad.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/util/attr_types.hpp>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_im2col(const NodeContext & context) {
num_inputs_check(context, 2, 2);
const int32_t * params = context.get_output_op_params();
int32_t s0 = params[0];
int32_t s1 = params[1];
int32_t p0 = params[2];
int32_t p1 = params[3];
int32_t d0 = params[4];
int32_t d1 = params[5];
bool is_2D = params[6] == 1;
ov::Output<Node> res;
ov::Output<Node> image = context.get_input(1);
const ov::Shape kernel_shape = context.get_input(0).get_shape();
const size_t IC = is_2D ? kernel_shape[1] : kernel_shape[2];
const size_t KH = is_2D ? kernel_shape[2] : 1;
const size_t KW = kernel_shape[3];
int32_t stride_w = s0;
int32_t stride_h = is_2D ? s1 : 1;
int32_t pad_w = p0;
int32_t pad_h = is_2D ? p1 : 0;
int32_t dil_w = d0;
int32_t dil_h = is_2D ? d1 : 1;
if (!is_2D) {
const ov::Shape image_shape = image.get_shape();
const size_t N = image_shape[1];
const size_t IW = image_shape[3];
auto image_reshape_shape = ov::op::v0::Constant::create(
ov::element::i64, ov::Shape{4},
std::vector<int64_t>{static_cast<int64_t>(N), static_cast<int64_t>(IC), 1, static_cast<int64_t>(IW)});
image = std::make_shared<ov::op::v1::Reshape>(image, image_reshape_shape, false);
}
const ov::Shape patch_sizes = {KH, KW};
const ov::Strides strides = {static_cast<size_t>(stride_h), static_cast<size_t>(stride_w)};
const ov::Shape rates = {static_cast<size_t>(dil_h), static_cast<size_t>(dil_w)};
auto pads_begin =
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, pad_h, pad_w});
auto pads_end =
ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 0, pad_h, pad_w});
auto pad = std::make_shared<ov::op::v1::Pad>(image, pads_begin, pads_end, ov::op::PadMode::CONSTANT);
auto patches =
std::make_shared<ov::op::v3::ExtractImagePatches>(pad, patch_sizes, strides, rates, ov::op::PadType::VALID);
auto perm1 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 2, 3, 1});
auto t1 = std::make_shared<ov::op::v1::Transpose>(patches, perm1);
const ov::Shape out_shape = t1->get_output_shape(0);
const size_t N = out_shape[0];
const size_t OH = out_shape[1];
const size_t OW = out_shape[2];
auto reshape1_shape = ov::op::v0::Constant::create(
ov::element::i64, ov::Shape{5},
std::vector<int64_t>{static_cast<int64_t>(N), static_cast<int64_t>(OH), static_cast<int64_t>(OW),
static_cast<int64_t>(KH * KW), static_cast<int64_t>(IC)});
auto r1 = std::make_shared<ov::op::v1::Reshape>(t1, reshape1_shape, false);
auto perm2 = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, std::vector<int64_t>{0, 1, 2, 4, 3});
auto t2 = std::make_shared<ov::op::v1::Transpose>(r1, perm2);
auto r2_shape = ov::op::v0::Constant::create(
ov::element::i64, ov::Shape{4},
std::vector<int64_t>{static_cast<int64_t>(N), static_cast<int64_t>(OH), static_cast<int64_t>(OW),
static_cast<int64_t>(IC * KH * KW)});
res = std::make_shared<ov::op::v1::Reshape>(t2, r2_shape, false);
if (!is_2D) {
auto final_reshape_shape = ov::op::v0::Constant::create(
ov::element::i64, ov::Shape{4},
std::vector<int64_t>{1, static_cast<int64_t>(N), static_cast<int64_t>(OW), static_cast<int64_t>(IC * KW)});
res = std::make_shared<ov::op::v1::Reshape>(res, final_reshape_shape, false);
}
auto output_type = context.get_output_type();
if (res.get_element_type() != output_type) {
res = std::make_shared<ov::op::v0::Convert>(res, output_type);
}
return rename_outputs_with_suffix({res}, context.get_name());
}
} } } }