#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"
#include <memory>
#include <openvino/core/node_output.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/gelu.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/sigmoid.hpp>
#include <openvino/op/slice.hpp>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_glu_geglu(const NodeContext & context) {
num_inputs_check(context, 1, 2);
ov::Output<ov::Node> src0;
ov::Output<ov::Node> src1;
if (context.get_input_size() == 2) {
src0 = context.get_input(0);
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto combined_shape = combined.get_partial_shape();
int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length();
int64_t nc = last_dim_val / 2;
auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc});
auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc});
src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis);
src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis);
}
int32_t * params = context.get_output_op_params();
const int32_t swapped = params[1];
if (swapped) {
std::swap(src0, src1);
}
auto gelu = std::make_shared<ov::op::v7::Gelu>(src0);
auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1);
return rename_outputs_with_suffix({res}, context.get_name());
}
} } } }