use serde_json::Value;
/// Roughly estimates how many tokens a list of chat-style messages will consume.
///
/// `messages` must be a JSON array. For every element, the `"role"` and `"content"`
/// fields (when present) contribute their character count and their
/// whitespace-separated word count. The estimate is the larger of
/// `characters / 4` and `words * 4 / 3`, plus a fixed overhead of 3 tokens per
/// message. Elements lacking either field (or that are not objects) contribute
/// only the per-message overhead for the missing parts.
///
/// String fields are measured by their raw text. Any other JSON value (number,
/// array, object, `null`) is measured by its compact JSON serialization.
///
/// # Errors
///
/// Returns `Err("messages should be a list")` if `messages` is not a JSON array.
pub fn estimate_tokens(messages: &Value) -> Result<u64, String> {
    match messages {
        Value::Array(messages) => {
            let mut characters: u64 = 0;
            let mut words: u64 = 0;

            for msg in messages {
                for field in ["role", "content"] {
                    if let Some(value) = msg.get(field) {
                        // `Value::to_string` JSON-encodes strings (adding quotes and turning
                        // e.g. a newline into the two characters `\n`), which inflates the
                        // character count and hides escaped whitespace from the word count.
                        // Measure the raw text for strings; fall back to serialization
                        // for structured values.
                        let serialized;
                        let text = match value.as_str() {
                            Some(s) => s,
                            None => {
                                serialized = value.to_string();
                                serialized.as_str()
                            }
                        };
                        characters += text.chars().count() as u64;
                        words += text.split_whitespace().count() as u64;
                    }
                }
            }

            // Take the more pessimistic of the character- and word-based heuristics,
            // then add a fixed per-message overhead.
            Ok(std::cmp::max(characters / 4, words * 4 / 3) + 3 * messages.len() as u64)
        }
        _ => Err("messages should be a list".to_string()),
    }
}