use serde::Serialize;
/// Estimates how many tokens a piece of text of `char_count` characters occupies.
///
/// This is a fixed-ratio heuristic (four characters per token), not the output of
/// a real tokenizer. The division rounds down, so any count below four yields `0`.
#[must_use]
pub fn estimate_tokens(char_count: usize) -> usize {
    // Heuristic ratio; integer division truncates, which the callers/tests rely on.
    const CHARS_PER_TOKEN: usize = 4;
    char_count / CHARS_PER_TOKEN
}
/// Estimates the token footprint of `value` once serialized to compact JSON.
///
/// The value is rendered with [`serde_json::to_string`] and the resulting length
/// is passed to [`estimate_tokens`]. The length is measured in UTF-8 bytes
/// (`String::len`), not in characters, so non-ASCII content produces a larger
/// estimate than its character count alone would.
///
/// Serialization errors are deliberately not surfaced: if `value` cannot be
/// serialized, the estimate is `0`.
#[must_use]
pub fn estimate_response_tokens<T: Serialize>(value: &T) -> usize {
    match serde_json::to_string(value) {
        // `String::len` counts UTF-8 bytes, not chars.
        Ok(json) => estimate_tokens(json.len()),
        // Best-effort estimate: treat an unserializable value as empty.
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_returns_zero_for_zero_chars() {
        assert_eq!(estimate_tokens(0), 0);
    }

    #[test]
    fn it_rounds_down_three_chars_to_zero_tokens() {
        assert_eq!(estimate_tokens(3), 0);
    }

    #[test]
    fn it_counts_four_chars_as_one_token() {
        assert_eq!(estimate_tokens(4), 1);
    }

    #[test]
    fn it_rounds_down_seven_chars_to_one_token() {
        assert_eq!(estimate_tokens(7), 1);
    }

    #[test]
    fn it_counts_eight_chars_as_two_tokens() {
        assert_eq!(estimate_tokens(8), 2);
    }

    #[test]
    fn it_counts_one_hundred_chars_as_twenty_five_tokens() {
        assert_eq!(estimate_tokens(100), 25);
    }

    #[test]
    fn it_estimates_tokens_for_a_serializable_struct() {
        #[derive(Serialize)]
        struct Payload {
            name: &'static str,
            count: usize,
        }

        let payload = Payload {
            name: "manifest",
            count: 42,
        };
        let serialized = serde_json::to_string(&payload).unwrap();
        let expected = estimate_tokens(serialized.len());
        assert_eq!(estimate_response_tokens(&payload), expected);
    }

    #[test]
    fn it_returns_a_positive_estimate_for_a_non_trivial_struct() {
        #[derive(Serialize)]
        struct Payload {
            items: Vec<&'static str>,
        }

        let payload = Payload {
            items: vec!["alpha", "beta", "gamma", "delta"],
        };
        let serialized = serde_json::to_string(&payload).unwrap();
        assert!(serialized.len() >= 4);
        assert!(estimate_response_tokens(&payload) > 0);
        assert_eq!(
            estimate_response_tokens(&payload),
            estimate_tokens(serialized.len())
        );
    }

    #[test]
    fn it_measures_serialized_length_in_bytes_not_chars() {
        // Three U+00E9 ('é') serialize to `"ééé"`: 5 chars but 8 UTF-8 bytes
        // (each 'é' is 2 bytes, plus two quotes). The estimate is byte-based,
        // so it is 8 / 4 = 2 rather than 5 / 4 = 1. Pinning this makes any
        // switch to character counting a deliberate, visible change.
        let text = "\u{e9}\u{e9}\u{e9}";
        assert_eq!(serde_json::to_string(&text).unwrap().len(), 8);
        assert_eq!(estimate_response_tokens(&text), 2);
    }

    #[test]
    fn it_returns_zero_when_serialization_fails() {
        // A value whose `Serialize` impl always errors exercises the
        // best-effort fallback path of `estimate_response_tokens`.
        struct Unserializable;

        impl Serialize for Unserializable {
            fn serialize<S: serde::Serializer>(&self, _serializer: S) -> Result<S::Ok, S::Error> {
                Err(<S::Error as serde::ser::Error>::custom("refusing to serialize"))
            }
        }

        assert!(serde_json::to_string(&Unserializable).is_err());
        assert_eq!(estimate_response_tokens(&Unserializable), 0);
    }
}