aimo_core/utils/tokens.rs
1use serde_json::Value;
2
3/// Estimate total input token for a `maessages` list in a chat completion request.
4///
5/// Returns error if message is not a valid list.
6pub fn estimate_tokens(messages: &Value) -> Result<u64, String> {
7 match messages {
8 Value::Array(messages) => {
9 let (characters, words) = messages.iter().fold((0, 0), |(c_cnt, w_cnt), msg| {
10 // For each message, count characters and words in `role` and `content` field.
11 ["role", "content"]
12 .iter()
13 .fold((c_cnt, w_cnt), |(f_c_cnt, f_w_cnt), field| {
14 msg.get(field)
15 // If the field is String, count its length, otherwise count the serialized JSON string size
16 .map(|value| value.to_string())
17 .map(|s| {
18 (
19 f_c_cnt + s.chars().count() as u64,
20 f_w_cnt + s.split_whitespace().count() as u64,
21 )
22 })
23 // No field found, add 0
24 .unwrap_or((f_c_cnt, f_w_cnt))
25 })
26
27 // The value returned here is `previous count` + `role count` + `content count`
28 });
29
30 // - `n_chars / 4` — Roughly 1 token per 4 characters in English
31 // - `n_words * (4/3)` — Roughly 3/4 of a word per token
32 // - This covers varied text lengths/types. Using the maximum of these gives a safer upper bound
33 // - Add message-level overhead (~3 tokens) for role/content formatting
34 Ok(std::cmp::max(characters / 4, words * 4 / 3) + 3 * messages.len() as u64)
35 }
36
37 _ => Err("messages should be a list".to_string()),
38 }
39}