#pragma once
#include "llama-impl.h"
#include <map>
#include <regex>
#include <string>
#include <vector>
struct llama_vocab;
struct llama_sampling;
enum llama_gretype {
LLAMA_GRETYPE_END = 0,
LLAMA_GRETYPE_ALT = 1,
LLAMA_GRETYPE_RULE_REF = 2,
LLAMA_GRETYPE_CHAR = 3,
LLAMA_GRETYPE_CHAR_NOT = 4,
LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,
LLAMA_GRETYPE_CHAR_ALT = 6,
LLAMA_GRETYPE_CHAR_ANY = 7,
LLAMA_GRETYPE_TOKEN = 8,
LLAMA_GRETYPE_TOKEN_NOT = 9,
};
typedef struct llama_grammar_element {
enum llama_gretype type;
uint32_t value; } llama_grammar_element;
struct llama_partial_utf8 {
uint32_t value; int n_remain; };
struct llama_grammar_candidate {
size_t index;
const uint32_t* code_points;
llama_partial_utf8 partial_utf8;
llama_token id;
};
using llama_grammar_rule = std::vector< llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element*>;
using llama_grammar_rules = std::vector<llama_grammar_rule>;
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;
const llama_grammar_rules& llama_grammar_get_rules(const struct llama_grammar* grammar);
llama_grammar_stacks& llama_grammar_get_stacks(struct llama_grammar* grammar);
void llama_grammar_accept(struct llama_grammar* grammar, uint32_t chr);
std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules& rules,
const llama_grammar_stack& stack,
const llama_grammar_candidates& candidates);
struct llama_grammar_parser {
const llama_vocab * vocab;
std::map<std::string, uint32_t> symbol_ids;
llama_grammar_rules rules;
llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
llama_grammar_stack c_rules() const;
uint32_t get_symbol_id(const char* src, size_t len);
uint32_t generate_symbol_id(const std::string& base_name);
void add_rule(uint32_t rule_id, const llama_grammar_rule& rule);
const char* parse_alternates(
const char* src,
const std::string& rule_name,
uint32_t rule_id,
bool is_nested);
const char* parse_sequence(
const char* src,
const std::string& rule_name,
llama_grammar_rule& rule,
bool is_nested);
const char* parse_rule(const char* src);
bool parse(const char* src);
void print(FILE* file);
};
struct llama_grammar_trigger_pattern {
std::string pattern;
std::regex regex;
size_t find(const std::string & input) const;
};
struct llama_grammar {
using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
const llama_vocab* vocab;
const llama_grammar_rules rules; llama_grammar_stacks stacks;
llama_partial_utf8 partial_utf8;
bool lazy = false;
bool awaiting_trigger = false; std::string trigger_buffer; std::vector<token_pos> trigger_buffer_positions; std::vector<llama_token> trigger_tokens; std::vector<llama_grammar_trigger_pattern> trigger_patterns;
};
struct llama_grammar* llama_grammar_init_impl(
const llama_grammar_element** rules,
size_t n_rules,
size_t start_rule_index);
struct llama_grammar* llama_grammar_init_impl(
const struct llama_vocab* vocab,
const char* grammar_str,
const char* grammar_root,
bool lazy,
const char** trigger_patterns,
size_t num_trigger_patterns,
const llama_token* trigger_tokens,
size_t num_trigger_tokens);
void llama_grammar_free_impl(struct llama_grammar * grammar);
struct llama_grammar* llama_grammar_clone_impl(const struct llama_grammar& grammar);
void llama_grammar_sample_impl(
const struct llama_grammar * grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token_data_array * candidates);
void llama_grammar_accept_impl(
struct llama_grammar & grammar,
const struct llama_vocab * vocab,
const struct llama_sampling * smpl,
llama_token token);
void llama_grammar_accept_str(
struct llama_grammar & grammar,
const std::string & piece);
void llama_grammar_accept_token(
struct llama_grammar & grammar,
llama_token token,
const std::string & piece);