aimo_core/utils/tokens.rs

1use serde_json::Value;
2
3/// Estimate total input token for a `maessages` list in a chat completion request.
4///
5/// Returns error if message is not a valid list.
6pub fn estimate_tokens(messages: &Value) -> Result<u64, String> {
7    match messages {
8        Value::Array(messages) => {
9            let (characters, words) = messages.iter().fold((0, 0), |(c_cnt, w_cnt), msg| {
10                // For each message, count characters and words in `role` and `content` field.
11                ["role", "content"]
12                    .iter()
13                    .fold((c_cnt, w_cnt), |(f_c_cnt, f_w_cnt), field| {
14                        msg.get(field)
15                            // If the field is String, count its length, otherwise count the serialized JSON string size
16                            .map(|value| value.to_string())
17                            .map(|s| {
18                                (
19                                    f_c_cnt + s.chars().count() as u64,
20                                    f_w_cnt + s.split_whitespace().count() as u64,
21                                )
22                            })
23                            // No field found, add 0
24                            .unwrap_or((f_c_cnt, f_w_cnt))
25                    })
26
27                // The value returned here is `previous count` + `role count` + `content count`
28            });
29
30            // - `n_chars / 4` — Roughly 1 token per 4 characters in English
31            // - `n_words * (4/3)` — Roughly 3/4 of a word per token
32            // - This covers varied text lengths/types. Using the maximum of these gives a safer upper bound
33            // - Add message-level overhead (~3 tokens) for role/content formatting
34            Ok(std::cmp::max(characters / 4, words * 4 / 3) + 3 * messages.len() as u64)
35        }
36
37        _ => Err("messages should be a list".to_string()),
38    }
39}