//! Documentation: token counting, token-based truncation, and result-budget helpers.
use std::sync::OnceLock;

use tiktoken_rs::{o200k_base, CoreBPE};

/// Returns the shared `o200k_base` encoder, constructing it on first use.
///
/// The encoder is stored in a function-local `OnceLock`, so it is built at most once and
/// every later call returns the same instance.
///
/// # Panics
///
/// Panics with the message `load o200k_base` if the encoder cannot be constructed.
fn bpe() -> &'static CoreBPE {
    static ENCODER: OnceLock<CoreBPE> = OnceLock::new();

    /// Builds the encoder; only invoked by the first caller.
    fn load() -> CoreBPE {
        o200k_base().expect("load o200k_base")
    }

    ENCODER.get_or_init(load)
}

/// Returns the number of tokens in `text`, as reported by the shared encoder's
/// `count_ordinary`.
pub fn count_tokens(text: &str) -> usize {
    let encoder = bpe();
    encoder.count_ordinary(text)
}

/// Cuts `text` down to its first `max` tokens.
///
/// Returns `(head, omitted)`:
///
/// * When `text` encodes to `max` tokens or fewer, `head` is a copy of `text` and `omitted`
///   is `0`.
/// * Otherwise `head` is the decoded form of the first `max` tokens and `omitted` is the
///   number of tokens left out.
///
/// The decoded bytes are converted with `String::from_utf8_lossy`, so a cut that does not
/// produce valid UTF-8 still yields a string instead of an error. If decoding the token
/// prefix fails, `head` is empty.
pub fn truncate_to_tokens(text: &str, max: usize) -> (String, usize) {
    let tokens = bpe().encode_ordinary(text);

    // Zero exactly when the text already fits within `max` tokens.
    let omitted = tokens.len().saturating_sub(max);
    if omitted == 0 {
        return (text.to_string(), 0);
    }

    let head_bytes = bpe().decode_bytes(&tokens[..max]).unwrap_or_default();
    let head = String::from_utf8_lossy(&head_bytes).into_owned();
    (head, omitted)
}

/// Builds the trailer text announcing that `omitted` tokens were cut.
///
/// The marker starts with a blank line so it stands apart from the text it follows.
pub fn truncation_marker(omitted: usize) -> String {
    let count = omitted.to_string();
    ["\n\n⋯ truncated (", count.as_str(), " tokens omitted)"].concat()
}

/// Optional caps applied to a list of results.
///
/// Every cap is optional; a value of `None` or `Some(0)` means the cap is not in effect.
#[derive(Debug, Clone, Copy, Default)]
pub struct Budget {
    /// Maximum number of items to keep.
    pub limit: Option<usize>,
    /// Maximum number of tokens across all kept items.
    pub max_tokens: Option<usize>,
    /// Maximum number of tokens for any single item.
    pub max_document_tokens: Option<usize>,
}

impl Budget {
    /// Returns `true` when at least one cap is set to a value greater than zero.
    pub fn is_active(&self) -> bool {
        let caps = [self.limit, self.max_tokens, self.max_document_tokens];
        caps.iter().flatten().any(|&cap| cap > 0)
    }
}

/// Summary of what a budget removed from a result list.
#[derive(Debug, Clone, Default)]
pub struct Truncation {
    /// Number of items that remain after the budget was applied.
    pub emitted: usize,
    /// Number of items that matched before any budget was applied.
    pub matched: usize,
    /// Keys of the items whose content was shortened.
    pub clipped: Vec<String>,
    /// Total token count of the remaining items.
    pub tokens: usize,
    /// The total-token cap that was in effect, if any.
    pub budget: Option<usize>,
}

impl Truncation {
    /// Returns `true` when fewer items were emitted than matched, or when at least one
    /// item had its content clipped.
    pub fn is_truncated(&self) -> bool {
        let dropped_items = self.matched > self.emitted;
        let clipped_items = !self.clipped.is_empty();
        dropped_items || clipped_items
    }
}

/// Applies `budget` to `items` in place and reports what was cut.
///
/// The three caps are applied in this order; a cap that is `None` or `0` is skipped:
///
/// 1. `budget.limit` — `items` is truncated to at most that many entries.
/// 2. `budget.max_document_tokens` — `cap_content` is called once per remaining item with
///    the cap. Every item for which it returns `Some(_)` is recorded in the report's
///    `clipped` list under the key produced by `key_of` *after* the call.
/// 3. `budget.max_tokens` — items are kept from the front while the running sum of
///    `content_tokens` stays within the cap. The first item that would push the sum over
///    the cap is dropped together with everything after it. An item is never dropped
///    while the running sum is still zero, so the first item with a non-zero token count
///    is kept even when it alone exceeds the cap.
///
/// # Parameters
///
/// * `items` — the list to trim; modified in place.
/// * `budget` — the caps to apply.
/// * `matched` — copied unchanged into the report's `matched` field.
/// * `key_of` — produces the identifier stored for a clipped item.
/// * `content_tokens` — returns the token count of an item.
/// * `cap_content` — shortens an item to the given token cap and returns `Some(_)` when it
///   changed the item, `None` otherwise.
///
/// # Returns
///
/// A `Truncation` holding the number of items left, `matched`, the clipped keys, the total
/// token count of the items left, and the active `max_tokens` cap (if any).
///
/// Token sums use saturating arithmetic, so an extremely large count from `content_tokens`
/// pins the total at `usize::MAX` instead of overflowing.
pub fn apply_budget<T, K, C, Cap>(
    items: &mut Vec<T>,
    budget: &Budget,
    matched: usize,
    key_of: K,
    content_tokens: C,
    mut cap_content: Cap,
) -> Truncation
where
    K: Fn(&T) -> String,
    C: Fn(&T) -> usize,
    Cap: FnMut(&mut T, usize) -> Option<usize>,
{
    // A cap of zero is treated the same as an absent cap.
    let item_limit = budget.limit.filter(|&l| l > 0);
    let per_document_cap = budget.max_document_tokens.filter(|&m| m > 0);
    let total_cap = budget.max_tokens.filter(|&m| m > 0);

    if let Some(limit) = item_limit {
        items.truncate(limit);
    }

    let mut clipped = Vec::new();
    if let Some(max_doc) = per_document_cap {
        for item in items.iter_mut() {
            if cap_content(item, max_doc).is_some() {
                // The key is taken after capping, so it reflects the shortened item.
                clipped.push(key_of(item));
            }
        }
    }

    let mut total = 0usize;
    match total_cap {
        Some(max_total) => {
            let mut kept = items.len();
            for (index, item) in items.iter().enumerate() {
                // Saturate so a huge per-item count cannot wrap around and sneak under
                // the cap (or panic in builds with overflow checks).
                let next = total.saturating_add(content_tokens(item));
                // `total > 0` guarantees that something non-empty is always emitted.
                if total > 0 && next > max_total {
                    kept = index;
                    break;
                }
                total = next;
            }
            items.truncate(kept);
        }
        None => {
            for item in items.iter() {
                total = total.saturating_add(content_tokens(item));
            }
        }
    }

    Truncation {
        emitted: items.len(),
        matched,
        clipped,
        tokens: total,
        budget: total_cap,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into owned documents.
    fn docs(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| text.to_string()).collect()
    }

    /// Per-document capping callback that never clips anything.
    fn no_cap(_doc: &mut String, _max: usize) -> Option<usize> {
        None
    }

    // Token counts for a few fixed strings under the o200k_base encoding.
    #[test]
    fn count_tokens_counts_known_strings() {
        for (text, expected) in [("", 0), ("hello", 1), ("hello world", 2)] {
            assert_eq!(count_tokens(text), expected);
        }
    }

    // A text of exactly `max` tokens comes back untouched.
    #[test]
    fn truncate_to_tokens_exact_fit_returns_input_and_zero() {
        let (kept, dropped) = truncate_to_tokens("hello world", 2);
        assert_eq!(kept, "hello world");
        assert_eq!(dropped, 0);
    }

    // A text shorter than `max` tokens comes back untouched.
    #[test]
    fn truncate_to_tokens_under_limit_returns_input_and_zero() {
        let (kept, dropped) = truncate_to_tokens("hello world", 10);
        assert_eq!(kept, "hello world");
        assert_eq!(dropped, 0);
    }

    // Every possible cut point in multibyte text must still yield a non-empty head.
    #[test]
    fn truncate_to_tokens_multibyte_never_drops_the_body() {
        let text = "日本語のテキストをここにたくさん書いています。";
        let total = count_tokens(text);
        assert!(total > 1);
        for cap in 1..total {
            let (kept, dropped) = truncate_to_tokens(text, cap);
            assert!(
                !kept.is_empty(),
                "truncating to {} tokens dropped the whole body",
                cap
            );
            assert_eq!(dropped, total - cap);
        }
    }

    // An over-long text is cut to the head and the remainder is counted.
    #[test]
    fn truncate_to_tokens_over_limit_returns_head_and_omitted_count() {
        let text = "one two three four five";
        let total = count_tokens(text);
        let (kept, dropped) = truncate_to_tokens(text, 2);
        assert_eq!(kept, "one two");
        assert_eq!(dropped, total - 2);
    }

    // The item limit keeps the leading entries and reports the truncation.
    #[test]
    fn apply_budget_limit_keeps_prefix() {
        let mut items = docs(&["a", "b", "c"]);
        let budget = Budget {
            limit: Some(2),
            ..Budget::default()
        };
        let report = apply_budget(
            &mut items,
            &budget,
            3,
            |doc| doc.clone(),
            |doc| count_tokens(doc),
            no_cap,
        );
        assert_eq!(items, docs(&["a", "b"]));
        assert_eq!(report.emitted, 2);
        assert_eq!(report.matched, 3);
        assert!(report.is_truncated());
    }

    // The total-token cap removes whole documents from the tail.
    #[test]
    fn apply_budget_max_tokens_drops_trailing_whole_documents() {
        let mut items = docs(&["one two three", "four five six", "seven eight nine"]);
        let budget = Budget {
            max_tokens: Some(4),
            ..Budget::default()
        };
        let report = apply_budget(
            &mut items,
            &budget,
            3,
            |doc| doc.clone(),
            |doc| count_tokens(doc),
            no_cap,
        );
        assert_eq!(items, docs(&["one two three"]));
        assert_eq!(report.emitted, 1);
        assert_eq!(report.tokens, 3);
        assert_eq!(report.budget, Some(4));
    }

    // The first document survives even when it alone exceeds the total cap.
    #[test]
    fn apply_budget_always_keeps_first_document() {
        let mut items = docs(&["one two three four five", "six"]);
        let budget = Budget {
            max_tokens: Some(1),
            ..Budget::default()
        };
        let report = apply_budget(
            &mut items,
            &budget,
            2,
            |doc| doc.clone(),
            |doc| count_tokens(doc),
            no_cap,
        );
        assert_eq!(items, docs(&["one two three four five"]));
        assert_eq!(report.emitted, 1);
    }

    // The per-document cap shortens content and records the clipped key.
    #[test]
    fn apply_budget_max_document_tokens_caps_and_records_clipped() {
        let mut items = docs(&["one two three four five"]);
        let budget = Budget {
            max_document_tokens: Some(2),
            ..Budget::default()
        };
        let report = apply_budget(
            &mut items,
            &budget,
            1,
            |doc| doc.clone(),
            |doc| count_tokens(doc),
            |doc, max| match truncate_to_tokens(doc, max) {
                (_, 0) => None,
                (head, omitted) => {
                    *doc = head;
                    Some(omitted)
                }
            },
        );
        assert_eq!(items, docs(&["one two"]));
        assert_eq!(report.clipped, docs(&["one two"]));
        assert!(report.is_truncated());
    }
}