use std::collections::HashMap;
use chrono::{DateTime, Local};
use either::Either;
use minijinja::{Error, ErrorKind, Value, value::Kwargs};
use serde::{Deserialize, Serialize};
/// A token described as a structured object rather than as a bare string.
///
/// Only `content` is public; the `eos_tok`/`bos_tok`/`unk_tok` accessors on
/// `ChatTemplate` read it to obtain the token text. The remaining fields are
/// deserialized but are not read anywhere in the code visible here.
// `dead_code` is allowed because the private fields are never read in the
// visible code; presumably they exist only so the complete input object
// parses. TODO confirm against the rest of the crate.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct AddedTokensDecoder {
// Optional type tag. NOTE(review): serde will look for a key literally named
// `__type`; confirm that matches the input files.
__type: Option<String>,
/// The literal text of the token.
pub content: String,
// The four flags below are required (non-`Option`) booleans. Their names
// suggest tokenizer matching options (strip left/right whitespace, match on
// normalized text, match whole words only); that meaning is not established
// by this file, so verify before relying on it.
lstrip: bool,
normalized: bool,
rstrip: bool,
single_word: bool,
// Optional flag; presumably marks the token as a special token (confirm).
special: Option<bool>,
}
pub fn raise_exception(msg: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
}
/// Value of a `bos_token` / `eos_token` / `unk_token` entry.
///
/// The wrapped [`Either`] is `Left` when the token was given as a plain string
/// and `Right` when it was given as a structured [`AddedTokensDecoder`]
/// object. Deserialization goes through `either::serde_untagged`, so the input
/// carries no variant tag; the variant is chosen by which shape parses.
#[derive(Debug, Deserialize)]
pub struct BeginEndUnkTok(
#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
/// Value of the `chat_template` entry.
///
/// The wrapped [`Either`] is `Left` for a single string and `Right` for a list
/// of string-to-string maps. Deserialization goes through
/// `either::serde_untagged`, so the input carries no variant tag.
// NOTE(review): the list form presumably holds several named templates (one
// map per template); the expected map keys are not visible in this file.
#[derive(Debug, Deserialize)]
pub struct ChatTemplateValue(
#[serde(with = "either::serde_untagged")] pub Either<String, Vec<HashMap<String, String>>>,
);
/// Value of the `pad_token` entry: either a plain string (`Left`) or a
/// structured [`AddedTokensDecoder`] object (`Right`), deserialized without a
/// variant tag via `either::serde_untagged`.
// `dead_code` is allowed because nothing in the visible code reads the wrapped
// value; it is only deserialized. TODO confirm against the rest of the crate.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct PadTokenValue(
#[serde(with = "either::serde_untagged")] pub Either<String, AddedTokensDecoder>,
);
/// Deserialized tokenizer/chat-template configuration.
///
/// Every field is an `Option`, and the struct derives `Default`, so a value
/// with all fields `None` can be constructed directly. Only the three token
/// fields and `chat_template` are public; the token fields are normally read
/// through `eos_tok`, `bos_tok` and `unk_tok`.
// NOTE(review): the field names look like the keys of a Hugging Face
// `tokenizer_config.json`; confirm the actual input source with the caller.
// `dead_code` is allowed because the private fields below are not read in the
// visible code; they are only deserialized.
#[allow(dead_code)]
#[derive(Debug, Deserialize, Default)]
pub struct ChatTemplate {
/// Beginning-of-sequence token, as a string or structured token object.
pub bos_token: Option<BeginEndUnkTok>,
/// End-of-sequence token, as a string or structured token object.
pub eos_token: Option<BeginEndUnkTok>,
/// Unknown token, as a string or structured token object.
pub unk_token: Option<BeginEndUnkTok>,
/// The chat template itself: a single string or a list of string maps.
pub chat_template: Option<ChatTemplateValue>,
// Everything below is parsed but not otherwise used in the visible code.
add_bos_token: Option<bool>,
add_eos_token: Option<bool>,
// Structured token objects keyed by string (presumably the token id as text; confirm).
added_tokens_decoder: Option<HashMap<String, AddedTokensDecoder>>,
additional_special_tokens: Option<Vec<String>>,
clean_up_tokenization_spaces: Option<bool>,
device_map: Option<String>,
legacy: Option<bool>,
// Parsed as a float rather than an integer; presumably because some inputs
// store very large values in floating-point notation (confirm).
model_max_length: Option<f64>,
pad_token: Option<PadTokenValue>,
sp_model_kwargs: Option<HashMap<String, String>>,
spaces_between_special_tokens: Option<bool>,
tokenizer_class: Option<String>,
// NOTE(review): the key is spelled `truncation_size`; if the input uses
// `truncation_side` this field will always be `None`. Confirm which is intended.
truncation_size: Option<String>,
use_default_system_prompt: Option<bool>,
}
impl ChatTemplate {
    /// Returns the end-of-sequence token text, or `None` when `eos_token` is
    /// unset.
    ///
    /// A plain-string token is cloned as is; a structured token yields a clone
    /// of its `content`.
    pub fn eos_tok(&self) -> Option<String> {
        self.eos_token.as_ref().map(|tok| match &tok.0 {
            Either::Left(text) => text.clone(),
            Either::Right(decoder) => decoder.content.clone(),
        })
    }

    /// Returns the beginning-of-sequence token text, or `None` when
    /// `bos_token` is unset.
    ///
    /// A plain-string token is cloned as is; a structured token yields a clone
    /// of its `content`.
    pub fn bos_tok(&self) -> Option<String> {
        self.bos_token.as_ref().map(|tok| match &tok.0 {
            Either::Left(text) => text.clone(),
            Either::Right(decoder) => decoder.content.clone(),
        })
    }

    /// Returns the unknown-token text, or `None` when `unk_token` is unset.
    ///
    /// A plain-string token is cloned as is; a structured token yields a clone
    /// of its `content`.
    pub fn unk_tok(&self) -> Option<String> {
        self.unk_token.as_ref().map(|tok| match &tok.0 {
            Either::Left(text) => text.clone(),
            Either::Right(decoder) => decoder.content.clone(),
        })
    }
}
/// Deserialized generation configuration holding the BOS and EOS token ids.
///
/// Each id field accepts either a single `u32` (`Either::Left`) or a list of
/// `u32` (`Either::Right`); `either::serde_untagged` picks the variant from
/// the shape of the input, with no variant tag.
// `dead_code` is allowed because both fields are private and not read in the
// visible code. TODO confirm against the rest of the crate.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
pub struct GenerationConfig {
// Beginning-of-sequence token id, or a list of ids.
#[serde(with = "either::serde_untagged")]
bos_token_id: Either<u32, Vec<u32>>,
// End-of-sequence token id, or a list of ids.
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<u32, Vec<u32>>,
}
pub fn tojson(value: Value, kwargs: Kwargs) -> Result<Value, Error> {
if let Ok(indent) = kwargs.get("indent") {
let mut buf = Vec::new();
let repeat = b" ".repeat(indent);
let formatter = serde_json::ser::PrettyFormatter::with_indent(&repeat);
let mut serializer = serde_json::Serializer::with_formatter(&mut buf, formatter);
value.serialize(&mut serializer).unwrap();
String::from_utf8(buf).map_err(|err| {
Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
})
} else {
serde_json::to_string(&value).map_err(|err| {
Error::new(ErrorKind::BadSerialization, "cannot serialize to JSON").with_source(err)
})
}
.map_err(|err| {
Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
})
.map(|s| {
let mut rv = String::with_capacity(s.len());
for c in s.chars() {
match c {
'<' => rv.push_str("\\u003c"),
'>' => rv.push_str("\\u003e"),
'&' => rv.push_str("\\u0026"),
'\'' => rv.push_str("\\u0027"),
_ => rv.push(c),
}
}
Value::from_safe_string(rv)
})
}
pub fn strftime_now(format_str: &str) -> Result<Value, Error> {
let local: DateTime<Local> = Local::now();
Ok(Value::from_safe_string(
local.format(format_str).to_string(),
))
}