#include "../node_context.h"
#include "../op_table.h"
#include "../utils.h"
#include "ggml-openvino/ggml-openvino-extra.h"
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <openvino/op/add.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/softmax.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <string>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
OutputVector translate_flash_attn_ext(const NodeContext & context) {
num_inputs_check(context, 4, 4);
auto q_f32 = context.get_input(0);
auto k = context.get_input(1);
auto v = context.get_input(2);
auto mask = context.get_input(3);
float * params = reinterpret_cast<float *>(context.get_output_op_params());
float scale = params[0];
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
ov::Output<ov::Node> res;
std::string mask_name = "KQ_mask_sliced";
if (context.get_input_names()[3].find("swa") != std::string::npos) {
mask_name = "KQ_mask_swa_sliced";
}
if (context.has_input(mask_name)) {
mask = context.get_input(mask_name);
}
if (mask.get_element_type() != ov::element::f16) {
mask = std::make_shared<ov::op::v0::Convert>(mask, ov::element::f16);
}
auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
const int64_t num_heads = q_shape[1];
const int64_t num_heads_kv = k_shape[1];
const int64_t head_size = q_shape[3];
const int64_t factor = num_heads / num_heads_kv;
static const bool manual_gqa_enabled = []() {
const char * env = ggml_openvino_getenv_str("GGML_OPENVINO_MANUAL_GQA_ATTN");
if (env != nullptr) {
return ggml_openvino_getenv_int("GGML_OPENVINO_MANUAL_GQA_ATTN") > 0;
}
const char * dev = ggml_openvino_getenv_str("GGML_OPENVINO_DEVICE");
return dev != nullptr && std::string(dev) == "GPU";
}();
const bool use_manual_gqa_attention =
manual_gqa_enabled && factor > 1 && num_heads_kv > 1 && !context.is_stateful();
if (use_manual_gqa_attention) {
auto k_5d_shape = ov::op::v0::Constant::create(ov::element::i64, {5},
std::vector<int64_t>{0, num_heads_kv, 1, -1, head_size});
auto v_5d_shape = ov::op::v0::Constant::create(ov::element::i64, {5},
std::vector<int64_t>{0, num_heads_kv, 1, -1, head_size});
auto q_5d_shape = ov::op::v0::Constant::create(ov::element::i64, {5},
std::vector<int64_t>{0, num_heads_kv, factor, -1, head_size});
auto k_r = std::make_shared<ov::op::v1::Reshape>(k, k_5d_shape, true);
auto v_r = std::make_shared<ov::op::v1::Reshape>(v, v_5d_shape, true);
auto q_r = std::make_shared<ov::op::v1::Reshape>(q, q_5d_shape, true);
auto qk = std::make_shared<ov::op::v0::MatMul>(q_r, k_r, false, true);
auto qk_scaled = std::make_shared<ov::op::v1::Multiply>(qk, scale_node);
auto mask_unsq1 =
std::make_shared<ov::op::v0::Unsqueeze>(mask, ov::op::v0::Constant::create(ov::element::i64, {1}, {2}));
ov::Output<ov::Node> qk_masked = std::make_shared<ov::op::v1::Add>(qk_scaled, mask_unsq1);
auto softmax = std::make_shared<ov::op::v8::Softmax>(qk_masked, -1);
auto attn = std::make_shared<ov::op::v0::MatMul>(softmax, v_r);
auto out_4d_shape =
ov::op::v0::Constant::create(ov::element::i64, {4}, std::vector<int64_t>{0, num_heads, -1, head_size});
auto out_4d = std::make_shared<ov::op::v1::Reshape>(attn, out_4d_shape, true);
res = std::make_shared<ov::op::v1::Transpose>(
out_4d, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
return rename_outputs_with_suffix({res}, context.get_name());
}
auto tile_kv = [&](int64_t n_heads, int64_t n_heads_kv, int64_t hs, ov::Output<Node> kv) {
int64_t f = n_heads / n_heads_kv;
if (f > 1 && n_heads_kv > 1) {
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
kv_broadcast_shape = ov::op::v0::Constant::create(ov::element::i64, {5},
{(int64_t) 1, (int64_t) 1, f, (int64_t) 1, (int64_t) 1});
new_kv_shape =
ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, n_heads, (int64_t) -1, hs});
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
ov::op::BroadcastType::BIDIRECTIONAL);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, true);
}
return kv;
};
k = tile_kv(num_heads, num_heads_kv, head_size, k);
v = tile_kv(num_heads, num_heads_kv, head_size, v);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask, scale_node, false);
res = std::make_shared<ov::op::v1::Transpose>(sdpa,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
return rename_outputs_with_suffix({res}, context.get_name());
}
} } } }