use super::BaseMetadata;
use chat_prompts::PromptTemplateType;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// Fluent builder for [`GgmlMetadata`].
///
/// Seed required fields via [`GgmlMetadataBuilder::new`], chain the
/// `with_*`/`enable_*` setters, then call `build()` to obtain the metadata.
#[derive(Debug)]
pub struct GgmlMetadataBuilder {
    // Metadata being assembled; handed out by `build()`.
    metadata: GgmlMetadata,
}
impl GgmlMetadataBuilder {
    /// Creates a builder seeded with the model name, model alias, and prompt
    /// template; every other field starts at its [`GgmlMetadata`] default.
    pub fn new<S: Into<String>>(model_name: S, model_alias: S, pt: PromptTemplateType) -> Self {
        let metadata = GgmlMetadata {
            model_name: model_name.into(),
            model_alias: model_alias.into(),
            prompt_template: pt,
            ..Default::default()
        };

        Self { metadata }
    }

    /// Overrides the prompt template chosen in [`new`](Self::new).
    #[must_use]
    pub fn with_prompt_template(mut self, template: PromptTemplateType) -> Self {
        self.metadata.prompt_template = template;
        self
    }

    /// Enables or disables plugin logging (`enable-log`).
    #[must_use]
    pub fn enable_plugin_log(mut self, enable: bool) -> Self {
        self.metadata.log_enable = enable;
        self
    }

    /// Enables or disables debug-level logging (`enable-debug-log`).
    #[must_use]
    pub fn enable_debug_log(mut self, enable: bool) -> Self {
        self.metadata.debug_log = enable;
        self
    }

    /// Enables or disables prompt logging (host-side only; not serialized).
    #[must_use]
    pub fn enable_prompts_log(mut self, enable: bool) -> Self {
        self.metadata.log_prompts = enable;
        self
    }

    /// Enables or disables embedding mode (`embedding`).
    #[must_use]
    pub fn enable_embeddings(mut self, enable: bool) -> Self {
        self.metadata.embeddings = enable;
        self
    }

    /// Sets the maximum number of tokens to predict (`n-predict`).
    #[must_use]
    pub fn with_n_predict(mut self, n: i32) -> Self {
        self.metadata.n_predict = n;
        self
    }

    /// Sets the main GPU index (`main-gpu`); `None` leaves the option unset.
    #[must_use]
    pub fn with_main_gpu(mut self, gpu: Option<u64>) -> Self {
        self.metadata.main_gpu = gpu;
        self
    }

    /// Sets how tensors are split across GPUs (`tensor-split`).
    #[must_use]
    pub fn with_tensor_split(mut self, split: Option<String>) -> Self {
        self.metadata.tensor_split = split;
        self
    }

    /// Sets the number of threads to use (`threads`).
    #[must_use]
    pub fn with_threads(mut self, threads: u64) -> Self {
        self.metadata.threads = threads;
        self
    }

    /// Sets the reverse prompt used to stop generation (`reverse-prompt`).
    #[must_use]
    pub fn with_reverse_prompt(mut self, prompt: Option<String>) -> Self {
        self.metadata.reverse_prompt = prompt;
        self
    }

    /// Sets the path to the multimodal projector model (`mmproj`).
    #[must_use]
    pub fn with_mmproj(mut self, path: Option<String>) -> Self {
        self.metadata.mmproj = path;
        self
    }

    /// Sets the path to an input image (`image`).
    #[must_use]
    pub fn with_image(mut self, path: impl Into<String>) -> Self {
        self.metadata.image = Some(path.into());
        self
    }

    /// Sets the number of model layers to offload to the GPU (`n-gpu-layers`).
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n: u64) -> Self {
        self.metadata.n_gpu_layers = n;
        self
    }

    /// Sets whether memory-mapping is disabled. The serialized option is
    /// `use-mmap`, so the flag is inverted here; `None` leaves it unset.
    #[must_use]
    pub fn disable_mmap(mut self, disable: Option<bool>) -> Self {
        self.metadata.use_mmap = disable.map(|v| !v);
        self
    }

    /// Sets how the model is split across GPUs (`split-mode`).
    #[must_use]
    pub fn with_split_mode(mut self, mode: String) -> Self {
        self.metadata.split_mode = mode;
        self
    }

    /// Sets the context size in tokens (`ctx-size`).
    #[must_use]
    pub fn with_ctx_size(mut self, size: u64) -> Self {
        self.metadata.ctx_size = size;
        self
    }

    /// Sets the batch size (`batch-size`).
    #[must_use]
    pub fn with_batch_size(mut self, size: u64) -> Self {
        self.metadata.batch_size = size;
        self
    }

    /// Sets the micro-batch size (`ubatch-size`).
    #[must_use]
    pub fn with_ubatch_size(mut self, size: u64) -> Self {
        self.metadata.ubatch_size = size;
        self
    }

    /// Sets the sampling temperature (`temp`).
    #[must_use]
    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.metadata.temperature = temp;
        self
    }

    /// Sets the top-p sampling value (`top-p`).
    #[must_use]
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.metadata.top_p = top_p;
        self
    }

    /// Sets the repetition penalty (`repeat-penalty`).
    #[must_use]
    pub fn with_repeat_penalty(mut self, penalty: f64) -> Self {
        self.metadata.repeat_penalty = penalty;
        self
    }

    /// Sets the presence penalty (`presence-penalty`).
    #[must_use]
    pub fn with_presence_penalty(mut self, penalty: f64) -> Self {
        self.metadata.presence_penalty = penalty;
        self
    }

    /// Sets the frequency penalty (`frequency-penalty`).
    #[must_use]
    pub fn with_frequency_penalty(mut self, penalty: f64) -> Self {
        self.metadata.frequency_penalty = penalty;
        self
    }

    /// Sets the grammar constraining generation (`grammar`).
    #[must_use]
    pub fn with_grammar(mut self, grammar: impl Into<String>) -> Self {
        self.metadata.grammar = grammar.into();
        self
    }

    /// Sets an optional JSON schema constraining the output.
    #[must_use]
    pub fn with_json_schema(mut self, schema: Option<String>) -> Self {
        self.metadata.json_schema = schema;
        self
    }

    /// Sets whether token-usage information is included in responses.
    #[must_use]
    pub fn include_usage(mut self, include: bool) -> Self {
        self.metadata.include_usage = include;
        self
    }

    /// Consumes the builder and returns the finished [`GgmlMetadata`].
    pub fn build(self) -> GgmlMetadata {
        self.metadata
    }
}
/// Metadata for a GGML chat/embedding model.
///
/// Fields marked `skip_serializing` are host-side bookkeeping only; the
/// remaining fields serialize under kebab-case names — presumably the option
/// keys consumed by the backend plugin (TODO confirm against the plugin docs).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GgmlMetadata {
    // Model name (host-side only; not serialized).
    #[serde(skip_serializing)]
    pub model_name: String,
    // Model alias used to address this model (host-side only; not serialized).
    #[serde(skip_serializing)]
    pub model_alias: String,
    // Whether prompts are logged (host-side only; not serialized).
    #[serde(skip_serializing)]
    pub log_prompts: bool,
    // Prompt template type (host-side only; not serialized).
    #[serde(skip_serializing)]
    pub prompt_template: PromptTemplateType,
    // Enable plugin logging.
    #[serde(rename = "enable-log")]
    pub log_enable: bool,
    // Enable debug-level logging.
    #[serde(rename = "enable-debug-log")]
    pub debug_log: bool,
    // Run the model in embedding mode.
    #[serde(rename = "embedding")]
    pub embeddings: bool,
    // Maximum number of tokens to predict.
    #[serde(rename = "n-predict")]
    pub n_predict: i32,
    // Stop-generation string; omitted from the config when unset.
    #[serde(skip_serializing_if = "Option::is_none", rename = "reverse-prompt")]
    pub reverse_prompt: Option<String>,
    // Path to the multimodal projector model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mmproj: Option<String>,
    // Path to an input image.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image: Option<String>,
    // Number of layers to offload to the GPU.
    #[serde(rename = "n-gpu-layers")]
    pub n_gpu_layers: u64,
    // Main GPU index.
    #[serde(rename = "main-gpu")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub main_gpu: Option<u64>,
    // How tensors are split across GPUs.
    #[serde(rename = "tensor-split")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tensor_split: Option<String>,
    // Whether to memory-map the model. Note: `GgmlMetadataBuilder::disable_mmap`
    // stores the inverted flag here.
    #[serde(skip_serializing_if = "Option::is_none", rename = "use-mmap")]
    pub use_mmap: Option<bool>,
    // How the model is split across GPUs.
    #[serde(rename = "split-mode")]
    pub split_mode: String,
    // Context size in tokens.
    #[serde(rename = "ctx-size")]
    pub ctx_size: u64,
    // Batch size.
    #[serde(rename = "batch-size")]
    pub batch_size: u64,
    // Micro-batch size.
    #[serde(rename = "ubatch-size")]
    pub ubatch_size: u64,
    // Number of threads to use.
    #[serde(rename = "threads")]
    pub threads: u64,
    // Sampling temperature.
    #[serde(rename = "temp")]
    pub temperature: f64,
    // Top-p sampling value.
    #[serde(rename = "top-p")]
    pub top_p: f64,
    // Repetition penalty.
    #[serde(rename = "repeat-penalty")]
    pub repeat_penalty: f64,
    // Presence penalty.
    #[serde(rename = "presence-penalty")]
    pub presence_penalty: f64,
    // Frequency penalty.
    #[serde(rename = "frequency-penalty")]
    pub frequency_penalty: f64,
    // Grammar constraining generation; serialized under its field name.
    pub grammar: String,
    // Optional JSON schema constraining the output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
    // Whether to include token-usage info in responses.
    // NOTE(review): no rename/skip attribute, so this serializes as
    // `include_usage` (snake_case) unlike the kebab-case options above —
    // confirm the consumer expects this key.
    pub include_usage: bool,
}
impl Default for GgmlMetadata {
    /// Baseline configuration; field order mirrors the struct declaration.
    fn default() -> Self {
        Self {
            // Identity fields start empty; the builder fills them in.
            model_name: String::default(),
            model_alias: String::default(),
            log_prompts: false,
            prompt_template: PromptTemplateType::Llama2Chat,
            log_enable: false,
            debug_log: false,
            embeddings: false,
            // Negative sentinel (presumably "no explicit limit" — confirm
            // against the backend's convention).
            n_predict: -1,
            reverse_prompt: None,
            mmproj: None,
            image: None,
            // Large value so GPU offload covers many layers by default.
            n_gpu_layers: 100,
            main_gpu: None,
            tensor_split: None,
            // Memory-mapping on by default.
            use_mmap: Some(true),
            split_mode: String::from("layer"),
            ctx_size: 4096,
            batch_size: 2048,
            ubatch_size: 512,
            threads: 2,
            // Sampling defaults.
            temperature: 0.8,
            top_p: 0.9,
            repeat_penalty: 1.0,
            presence_penalty: 0.0,
            frequency_penalty: 0.0,
            grammar: String::new(),
            json_schema: None,
            include_usage: false,
        }
    }
}
impl BaseMetadata for GgmlMetadata {
    /// Returns the model name.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Returns the model alias.
    fn model_alias(&self) -> &str {
        self.model_alias.as_str()
    }
}
impl GgmlMetadata {
    /// Returns the prompt template type configured for this model.
    ///
    /// Returned by value from `&self`, which requires `PromptTemplateType`
    /// to be `Copy` (this would not compile otherwise).
    pub fn prompt_template(&self) -> PromptTemplateType {
        self.prompt_template
    }
}
/// Fluent builder for [`GgmlTtsMetadata`].
///
/// Seed required fields via [`GgmlTtsMetadataBuilder::new`], chain the
/// setters, then call `build()` to obtain the metadata.
#[derive(Debug)]
pub struct GgmlTtsMetadataBuilder {
    // Metadata being assembled; handed out by `build()`.
    metadata: GgmlTtsMetadata,
}
impl GgmlTtsMetadataBuilder {
    /// Creates a builder seeded with the model name, model alias, and the
    /// path to the codec/vocoder model; every other field starts at its
    /// [`GgmlTtsMetadata`] default.
    pub fn new<S: Into<String>, P: AsRef<Path>>(
        model_name: S,
        model_alias: S,
        codec_model: P,
    ) -> Self {
        let metadata = GgmlTtsMetadata {
            model_name: model_name.into(),
            model_alias: model_alias.into(),
            codec_model: codec_model.as_ref().to_path_buf(),
            ..Default::default()
        };

        Self { metadata }
    }

    /// Enables or disables TTS mode (`tts`).
    #[must_use]
    pub fn enable_tts(mut self, enable: bool) -> Self {
        self.metadata.enable_tts = enable;
        self
    }

    /// Sets the speaker profile file (`tts-speaker-file`).
    #[must_use]
    pub fn with_speaker_file(mut self, speaker_file: Option<PathBuf>) -> Self {
        self.metadata.speaker_file = speaker_file;
        self
    }

    /// Sets the context size in tokens (`ctx-size`).
    #[must_use]
    pub fn with_ctx_size(mut self, size: u64) -> Self {
        self.metadata.ctx_size = size;
        self
    }

    /// Sets the batch size (`batch-size`).
    #[must_use]
    pub fn with_batch_size(mut self, size: u64) -> Self {
        self.metadata.batch_size = size;
        self
    }

    /// Sets the micro-batch size (`ubatch-size`).
    #[must_use]
    pub fn with_ubatch_size(mut self, size: u64) -> Self {
        self.metadata.ubatch_size = size;
        self
    }

    /// Sets the maximum number of tokens to predict.
    #[must_use]
    pub fn with_n_predict(mut self, n: i32) -> Self {
        self.metadata.n_predict = n;
        self
    }

    /// Sets the number of model layers to offload to the GPU.
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n: u64) -> Self {
        self.metadata.n_gpu_layers = n;
        self
    }

    /// Sets the sampling temperature (`temp`).
    #[must_use]
    pub fn with_temperature(mut self, temp: f64) -> Self {
        self.metadata.temperature = temp;
        self
    }

    /// Enables or disables plugin logging (`enable-log`).
    #[must_use]
    pub fn enable_plugin_log(mut self, enable: bool) -> Self {
        self.metadata.log_enable = enable;
        self
    }

    /// Enables or disables debug-level logging (`enable-debug-log`).
    #[must_use]
    pub fn enable_debug_log(mut self, enable: bool) -> Self {
        self.metadata.debug_log = enable;
        self
    }

    /// Consumes the builder and returns the finished [`GgmlTtsMetadata`].
    pub fn build(self) -> GgmlTtsMetadata {
        self.metadata
    }
}
/// Metadata for a GGML text-to-speech model.
///
/// Serialized fields use kebab-case names matching the option keys of
/// [`GgmlMetadata`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GgmlTtsMetadata {
    // NOTE(review): unlike `GgmlMetadata`, these two fields are not marked
    // `skip_serializing`, so they appear in the serialized config — confirm
    // that is intended.
    pub model_name: String,
    pub model_alias: String,
    /// Enable TTS mode.
    #[serde(rename = "tts")]
    pub enable_tts: bool,
    /// Path to the codec/vocoder model.
    #[serde(rename = "model-vocoder")]
    pub codec_model: PathBuf,
    /// Optional speaker profile file.
    #[serde(rename = "tts-speaker-file", skip_serializing_if = "Option::is_none")]
    pub speaker_file: Option<PathBuf>,
    /// Context size in tokens.
    #[serde(rename = "ctx-size")]
    pub ctx_size: u64,
    /// Batch size.
    #[serde(rename = "batch-size")]
    pub batch_size: u64,
    /// Micro-batch size.
    #[serde(rename = "ubatch-size")]
    pub ubatch_size: u64,
    /// Maximum number of tokens to predict.
    // Fix: serialize under the kebab-case option name, consistent with
    // `GgmlMetadata`; `alias` keeps deserializing the old snake_case key.
    #[serde(rename = "n-predict", alias = "n_predict")]
    pub n_predict: i32,
    /// Number of layers to offload to the GPU.
    #[serde(rename = "n-gpu-layers", alias = "n_gpu_layers")]
    pub n_gpu_layers: u64,
    /// Sampling temperature.
    #[serde(rename = "temp")]
    pub temperature: f64,
    /// Enable plugin logging.
    #[serde(rename = "enable-log")]
    pub log_enable: bool,
    /// Enable debug-level logging.
    #[serde(rename = "enable-debug-log")]
    pub debug_log: bool,
}
impl Default for GgmlTtsMetadata {
    /// Baseline TTS configuration; the codec model path defaults to an
    /// empty path and must be supplied before use.
    fn default() -> Self {
        Self {
            model_name: String::from("tts"),
            model_alias: String::from("tts"),
            enable_tts: false,
            // Empty path — the builder's `new` sets the real one.
            codec_model: PathBuf::new(),
            speaker_file: None,
            // Context/batch sizes all default to 8192.
            ctx_size: 8192,
            batch_size: 8192,
            ubatch_size: 8192,
            n_predict: 4096,
            // Large value so GPU offload covers many layers by default.
            n_gpu_layers: 100,
            temperature: 0.8,
            log_enable: false,
            debug_log: false,
        }
    }
}
impl BaseMetadata for GgmlTtsMetadata {
    /// Returns the model name.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Returns the model alias.
    fn model_alias(&self) -> &str {
        self.model_alias.as_str()
    }
}
impl GgmlTtsMetadata {
    /// Returns the prompt template type for TTS models.
    ///
    /// Always [`PromptTemplateType::Tts`]; mirrors
    /// `GgmlMetadata::prompt_template` but carries no per-instance state.
    pub fn prompt_template(&self) -> PromptTemplateType {
        PromptTemplateType::Tts
    }
}