use std::borrow::Cow;
/// Inputs whose `str::len()` is at or below this value are passed through
/// unchanged by `truncate_for_embed`.
///
/// Despite the `CHARS` suffix, the value is compared against `str::len()`,
/// i.e. it is a count of UTF-8 bytes, not of Unicode scalar values.
pub(crate) const EMBED_MAX_CHARS: usize = 32_000;
/// Upper bound, in bytes, on the leading portion kept when an input is truncated.
const EMBED_HEAD_CHARS: usize = 24_000;
/// Upper bound, in bytes, on the trailing portion kept when an input is truncated.
const EMBED_TAIL_CHARS: usize = 8_000;
/// Text spliced between the kept head and the kept tail of a truncated input.
const EMBED_TRUNCATION_MARKER: &str = "\n...[truncated]...\n";
/// Copies every borrowed string slice in `texts` into an owned `String`.
///
/// The output has the same length and order as the input; an empty slice
/// yields an empty vector.
#[inline]
pub(crate) fn owned_strs(texts: &[&str]) -> Vec<String> {
    texts.iter().copied().map(String::from).collect()
}
/// Shortens an over-long embedding input by keeping its head and tail and
/// replacing the middle with `EMBED_TRUNCATION_MARKER`.
///
/// All sizes are byte lengths (`str::len()`), not character counts.
///
/// Returns `Cow::Borrowed(text)` without allocating when the input is at most
/// `EMBED_MAX_CHARS` bytes, or at most
/// `EMBED_HEAD_CHARS + EMBED_TAIL_CHARS + EMBED_TRUNCATION_MARKER.len()` bytes.
/// The second bound is the longest output the splice can produce, so anything
/// at or below it would gain nothing from being spliced.
///
/// Otherwise returns `Cow::Owned` containing up to `EMBED_HEAD_CHARS` leading
/// bytes, the marker, and up to `EMBED_TAIL_CHARS` trailing bytes. Both cut
/// points are moved inward to the nearest UTF-8 character boundary, so either
/// part may be a few bytes shorter than its budget. A warning is logged
/// through `tracing` whenever truncation happens.
pub(crate) fn truncate_for_embed(text: &str) -> Cow<'_, str> {
    let len = text.len();
    let longest_spliced = EMBED_HEAD_CHARS + EMBED_TAIL_CHARS + EMBED_TRUNCATION_MARKER.len();
    if len <= EMBED_MAX_CHARS || len <= longest_spliced {
        return Cow::Borrowed(text);
    }

    // Round the head cut down and the tail cut up so neither slice can split
    // a multibyte character or exceed its byte budget.
    let head = &text[..text.floor_char_boundary(EMBED_HEAD_CHARS)];
    let tail = &text[text.ceil_char_boundary(len.saturating_sub(EMBED_TAIL_CHARS))..];

    tracing::warn!(
        original_chars = len,
        head_chars = head.len(),
        tail_chars = tail.len(),
        "embed input truncated to fit provider limit"
    );

    let mut spliced =
        String::with_capacity(head.len() + EMBED_TRUNCATION_MARKER.len() + tail.len());
    spliced.push_str(head);
    spliced.push_str(EMBED_TRUNCATION_MARKER);
    spliced.push_str(tail);
    Cow::Owned(spliced)
}
/// Unit tests for `truncate_for_embed`: pass-through thresholds, the shape of
/// truncated output, and cut-point handling for multibyte UTF-8 input.
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits truncated output into the kept head and kept tail around the marker.
    fn split_at_marker(s: &str) -> (&str, &str) {
        s.split_once(EMBED_TRUNCATION_MARKER)
            .expect("truncated output must contain the truncation marker")
    }

    #[test]
    fn empty_string_is_borrowed() {
        let result = truncate_for_embed("");
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(result, "");
    }

    #[test]
    fn short_string_is_borrowed() {
        let input = "hello world";
        let result = truncate_for_embed(input);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(result, input);
    }

    #[test]
    fn exactly_at_limit_is_borrowed() {
        let input = "a".repeat(EMBED_MAX_CHARS);
        let result = truncate_for_embed(&input);
        assert!(matches!(result, Cow::Borrowed(_)));
        assert_eq!(result.len(), EMBED_MAX_CHARS);
    }

    #[test]
    fn over_limit_latin_is_owned_with_head_and_tail() {
        let input = "a".repeat(EMBED_MAX_CHARS + 10_000);
        let result = truncate_for_embed(&input);
        assert!(matches!(result, Cow::Owned(_)));
        let s: &str = result.as_ref();
        assert!(s.contains("...[truncated]..."), "marker must be present");
        assert!(s.len() < input.len());
        assert!(s.starts_with(&"a".repeat(100)));
        assert!(s.ends_with(&"a".repeat(100)));
        // Single-byte input never needs boundary rounding, so both budgets
        // must be used in full.
        let (head, tail) = split_at_marker(s);
        assert_eq!(head.len(), EMBED_HEAD_CHARS);
        assert_eq!(tail.len(), EMBED_TAIL_CHARS);
    }

    #[test]
    fn over_limit_utf8_multibyte_is_valid_utf8() {
        let input = "中".repeat(12_000);
        // Reaching the assertions at all proves neither cut landed inside a
        // character: slicing a `str` off a char boundary panics.
        let result = truncate_for_embed(&input);
        let s: &str = result.as_ref();
        assert!(s.len() < input.len());
        assert!(
            s.contains("...[truncated]..."),
            "truncation marker must appear in CJK output"
        );
        let (head, tail) = split_at_marker(s);
        assert!(!head.is_empty() && head.chars().all(|c| c == '中'));
        assert!(!tail.is_empty() && tail.chars().all(|c| c == '中'));
        assert!(head.len() <= EMBED_HEAD_CHARS, "head must stay within its byte budget");
        assert!(tail.len() <= EMBED_TAIL_CHARS, "tail must stay within its byte budget");
    }

    #[test]
    fn misaligned_multibyte_cut_points_round_inward() {
        // The one-byte prefix shifts every 3-byte character so that both raw
        // cut offsets fall in the middle of a character.
        let input = format!("a{}", "中".repeat(12_000));
        assert!(!input.is_char_boundary(EMBED_HEAD_CHARS));
        assert!(!input.is_char_boundary(input.len() - EMBED_TAIL_CHARS));

        let result = truncate_for_embed(&input);
        assert!(matches!(result, Cow::Owned(_)));
        let s: &str = result.as_ref();
        let (head, tail) = split_at_marker(s);

        assert!(input.starts_with(head));
        assert!(input.ends_with(tail));
        // Rounding may only shrink each part, and by less than one character.
        let char_width = '中'.len_utf8();
        assert!(head.len() <= EMBED_HEAD_CHARS);
        assert!(EMBED_HEAD_CHARS - head.len() < char_width);
        assert!(tail.len() <= EMBED_TAIL_CHARS);
        assert!(EMBED_TAIL_CHARS - tail.len() < char_width);
    }

    #[test]
    fn overlap_guard_returns_borrowed() {
        let input = "b".repeat(EMBED_MAX_CHARS + 1);
        let result = truncate_for_embed(&input);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn overlap_guard_upper_boundary_is_borrowed() {
        let boundary = EMBED_HEAD_CHARS + EMBED_TAIL_CHARS + EMBED_TRUNCATION_MARKER.len();
        let input = "c".repeat(boundary);
        let result = truncate_for_embed(&input);
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "text at the overlap guard boundary must be returned borrowed"
        );
    }

    #[test]
    fn one_byte_past_overlap_guard_produces_owned() {
        let boundary = EMBED_HEAD_CHARS + EMBED_TAIL_CHARS + EMBED_TRUNCATION_MARKER.len();
        let input = "d".repeat(boundary + 1);
        let result = truncate_for_embed(&input);
        assert!(
            matches!(result, Cow::Owned(_)),
            "text one byte past the overlap guard must be truncated (Cow::Owned)"
        );
        assert!(result.contains("...[truncated]..."));
    }

    #[test]
    fn truncated_result_has_correct_structure() {
        let head = "H".repeat(EMBED_HEAD_CHARS);
        let middle = "M".repeat(5_000);
        let tail = "T".repeat(EMBED_TAIL_CHARS);
        let input = format!("{head}{middle}{tail}");
        let result = truncate_for_embed(&input);
        assert!(matches!(result, Cow::Owned(_)));
        let s: &str = result.as_ref();
        assert!(!s.contains('M'), "the middle section must be dropped entirely");
        // Exact comparison pins both cut points and the marker position, not
        // just the first and last character.
        let (kept_head, kept_tail) = split_at_marker(s);
        assert!(kept_head == head.as_str(), "head must be preserved verbatim");
        assert!(kept_tail == tail.as_str(), "tail must be preserved verbatim");
    }
}