use tracing::warn;
/// Number of characters that [`estimate_tokens`] counts as one token.
pub const CHARS_PER_TOKEN: usize = 4;

/// Estimates how many tokens a text of `char_count` characters occupies.
///
/// The result is `char_count / CHARS_PER_TOKEN` rounded toward zero, so any
/// count smaller than [`CHARS_PER_TOKEN`] (including `0`) yields `0`.
pub fn estimate_tokens(char_count: usize) -> usize {
    // For unsigned operands Euclidean division is identical to `/`.
    char_count.div_euclid(CHARS_PER_TOKEN)
}
/// Maximum number of input bytes that [`truncate_output`] keeps and decodes;
/// any bytes beyond this are discarded.
pub const MAX_OUTPUT_SIZE: usize = 10 * 1024 * 1024;

/// Decodes captured output bytes into a `String`, keeping at most
/// [`MAX_OUTPUT_SIZE`] input bytes and trimming trailing whitespace.
///
/// When `data` is longer than the limit, a `warn!` event is emitted whose
/// message is prefixed with `context`, and only the first
/// [`MAX_OUTPUT_SIZE`] bytes are kept. The kept bytes are decoded as UTF-8;
/// invalid sequences — including a multi-byte character split by the cut —
/// are replaced with `U+FFFD`. Leading whitespace is preserved.
///
/// The limit applies to the *input* bytes. Each replacement character
/// occupies 3 bytes, so the returned string can be longer than
/// [`MAX_OUTPUT_SIZE`] bytes when the input is not valid UTF-8.
pub fn truncate_output(data: &[u8], context: &str) -> String {
    if data.len() > MAX_OUTPUT_SIZE {
        warn!(
            bytes = data.len(),
            max = MAX_OUTPUT_SIZE,
            "{context} output truncated to limit"
        );
    }
    let kept = &data[..data.len().min(MAX_OUTPUT_SIZE)];
    // `from_utf8_lossy` already borrows when `kept` is valid UTF-8 and only
    // allocates when replacements are needed, so a separate `from_utf8` fast
    // path is redundant. Taking ownership first and trimming in place reuses
    // the lossy allocation instead of copying the decoded text a second time.
    let mut text = String::from_utf8_lossy(kept).into_owned();
    // `trim_end` returns a prefix of `text`, so its length is a char boundary.
    text.truncate(text.trim_end().len());
    text
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn estimate_tokens_standard() {
        assert_eq!(estimate_tokens(400), 100);
        assert_eq!(estimate_tokens(800_000), 200_000);
    }

    #[test]
    fn estimate_tokens_zero() {
        assert_eq!(estimate_tokens(0), 0);
    }

    #[test]
    fn estimate_tokens_rounds_down() {
        assert_eq!(estimate_tokens(3), 0);
        assert_eq!(estimate_tokens(5), 1);
        assert_eq!(estimate_tokens(7), 1);
    }

    #[test]
    fn small_input_returned_as_is() {
        let data = b"hello world";
        let result = truncate_output(data, "test");
        assert_eq!(result, "hello world");
    }

    #[test]
    fn input_exactly_at_max_output_size_returned_as_is() {
        let data = vec![b'a'; MAX_OUTPUT_SIZE];
        let result = truncate_output(&data, "test");
        assert_eq!(result.len(), MAX_OUTPUT_SIZE);
        assert!(result.chars().all(|c| c == 'a'));
    }

    #[test]
    fn input_over_limit_gets_truncated() {
        let data = vec![b'x'; MAX_OUTPUT_SIZE + 100];
        let result = truncate_output(&data, "test");
        assert_eq!(result.len(), MAX_OUTPUT_SIZE);
        // Only original leading bytes survive; nothing else is introduced.
        assert!(result.bytes().all(|b| b == b'x'));
    }

    #[test]
    fn way_over_limit_gets_truncated() {
        let data = vec![b'z'; 20 * 1024 * 1024];
        let result = truncate_output(&data, "test");
        assert_eq!(result.len(), MAX_OUTPUT_SIZE);
        assert!(result.bytes().all(|b| b == b'z'));
    }

    #[test]
    fn whitespace_exposed_by_truncation_is_trimmed() {
        // Truncation happens before trimming, so whitespace that ends up at
        // the cut point is removed, and the bytes past the cut never appear.
        let mut data = vec![b'a'; MAX_OUTPUT_SIZE - 2];
        data.extend_from_slice(b"  ");
        data.extend_from_slice(b"bbbbbbbbbb");
        let result = truncate_output(&data, "test");
        assert_eq!(result.len(), MAX_OUTPUT_SIZE - 2);
        assert!(result.bytes().all(|b| b == b'a'));
    }

    #[test]
    fn empty_input_returns_empty_string() {
        let result = truncate_output(b"", "test");
        assert_eq!(result, "");
    }

    #[test]
    fn trailing_whitespace_trimmed() {
        let data = b"hello \n\n ";
        let result = truncate_output(data, "test");
        assert_eq!(result, "hello");
    }

    #[test]
    fn leading_whitespace_preserved() {
        // Only the end of the output is trimmed.
        let result = truncate_output(b"  hi  \n", "test");
        assert_eq!(result, "  hi");
    }

    #[test]
    fn only_whitespace_returns_empty() {
        let data = b" \n\t \r\n ";
        let result = truncate_output(data, "test");
        assert_eq!(result, "");
    }

    #[test]
    fn invalid_utf8_uses_lossy_conversion() {
        let data: &[u8] = &[0xFF, 0xFE, b'h', b'i'];
        let result = truncate_output(data, "test");
        assert!(result.contains('\u{FFFD}'));
        assert!(result.contains("hi"));
    }

    #[test]
    fn invalid_utf8_output_still_trimmed() {
        let data: &[u8] = &[0xFF, b'o', b'k', b' ', b'\n'];
        let result = truncate_output(data, "test");
        assert_eq!(result, "\u{FFFD}ok");
    }

    #[test]
    fn valid_multibyte_utf8_emoji_handled() {
        let data = "\u{1F680} rocket".as_bytes();
        let result = truncate_output(data, "test");
        assert_eq!(result, "\u{1F680} rocket");
    }

    #[test]
    fn truncation_splitting_multibyte_char_handled_via_lossy() {
        let mut data = vec![b'a'; MAX_OUTPUT_SIZE - 1];
        data.extend_from_slice(&[0xF0, 0x9F, 0x9A, 0x80]);
        let result = truncate_output(&data, "test");
        assert!(result.contains('\u{FFFD}'));
        // Only the lead byte 0xF0 is kept; it decodes to a single 3-byte
        // U+FFFD, so the result is 2 bytes longer than the input byte limit.
        assert!(result.ends_with('\u{FFFD}'));
        assert_eq!(result.len(), MAX_OUTPUT_SIZE + 2);
    }

    #[test]
    fn null_bytes_in_data() {
        let data: &[u8] = &[b'a', 0x00, b'b', 0x00, b'c'];
        let result = truncate_output(data, "test");
        assert_eq!(result, "a\0b\0c");
    }
}