aimo-core 0.1.6

AiMo Network core protocol Rust specs
Documentation
use serde_json::Value;

/// Estimate the total number of input tokens for the `messages` list of a chat completion
/// request.
///
/// For every element of the list, the `role` and `content` fields are measured:
///
/// - a JSON string contributes its raw text (no surrounding quotes, no escape sequences);
/// - any other JSON value contributes its serialized JSON representation;
/// - a missing field (or an element that is not an object) contributes nothing.
///
/// The result is `max(characters / 4, words * 4 / 3)` plus 3 tokens of overhead per message.
///
/// # Errors
///
/// Returns an error if `messages` is not a JSON array.
pub fn estimate_tokens(messages: &Value) -> Result<u64, String> {
    /// Returns `(character count, whitespace-separated word count)` for `text`.
    fn text_stats(text: &str) -> (u64, u64) {
        (
            text.chars().count() as u64,
            text.split_whitespace().count() as u64,
        )
    }

    match messages {
        Value::Array(items) => {
            let mut characters: u64 = 0;
            let mut words: u64 = 0;

            for msg in items {
                // For each message, count characters and words in the `role` and `content` fields.
                for field in &["role", "content"] {
                    let (c_cnt, w_cnt) = match msg.get(*field) {
                        // Measure string fields on their raw text. Serializing them first would
                        // add quote/escape characters and turn real newlines into `\n`, which
                        // hides word boundaries from `split_whitespace`.
                        Some(Value::String(text)) => text_stats(text),
                        // Non-string values (arrays of content parts, objects, numbers, ...) are
                        // measured by the size of their serialized JSON.
                        Some(other) => text_stats(&other.to_string()),
                        // No field found, add 0.
                        None => (0, 0),
                    };
                    characters += c_cnt;
                    words += w_cnt;
                }
            }

            // - `n_chars / 4` — Roughly 1 token per 4 characters in English
            // - `n_words * (4/3)` — Roughly 3/4 of a word per token
            // - This covers varied text lengths/types. Using the maximum of these gives a safer
            //   upper bound
            // - Add message-level overhead (~3 tokens) for role/content formatting
            Ok(std::cmp::max(characters / 4, words * 4 / 3) + 3 * items.len() as u64)
        }

        _ => Err("messages should be a list".to_string()),
    }
}