Skip to main content

ai_agent/services/
token_estimation.rs

1// Source: /data/home/swei/claudecode/openclaudecode/src/services/tokenEstimation.ts
2//! Token estimation for text.
3//!
4//! Provides token counting similar to claude code's token estimation.
5//! Includes both rough character-based estimation and API-accurate counting
6//! via `/v1/messages/count_tokens`.
7
8use crate::types::Message;
9use serde::{Deserialize, Serialize};
10
11/// Estimated token count with metadata
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct TokenEstimate {
14    pub tokens: usize,
15    pub characters: usize,
16    pub words: usize,
17    pub method: EstimationMethod,
18}
19
20/// Method used for estimation
21#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub enum EstimationMethod {
23    /// Fast estimation using character ratio
24    CharacterRatio,
25    /// Word-based estimation
26    WordBased,
27    /// Exact TikToken estimation (if available)
28    TikToken,
29}
30
31// ============================================================================
32// Translation of claude code's tokenEstimation.ts - strictly line by line
33// ============================================================================
34
35/// Rough token count estimation - matches original TypeScript:
36/// `export function roughTokenCountEstimation(content: string, bytesPerToken: number = 4): number`
37pub fn rough_token_count_estimation(content: &str, bytes_per_token: f64) -> usize {
38    (content.len() as f64 / bytes_per_token).round() as usize
39}
40
41/// Returns bytes-per-token ratio for a given file extension
42/// Matches original TypeScript:
43/// `export function bytesPerTokenForFileType(fileExtension: string): number`
44/// Dense JSON has many single-character tokens which makes ratio closer to 2
45pub fn bytes_per_token_for_file_type(file_extension: &str) -> f64 {
46    match file_extension {
47        "json" | "jsonl" | "jsonc" => 2.0,
48        _ => 4.0,
49    }
50}
51
52/// Like roughTokenCountEstimation but uses more accurate bytes-per-token ratio
53/// when file type is known - matches original TypeScript:
54/// `export function roughTokenCountEstimationForFileType(content: string, fileExtension: string): number`
55pub fn rough_token_count_estimation_for_file_type(content: &str, file_extension: &str) -> usize {
56    rough_token_count_estimation(content, bytes_per_token_for_file_type(file_extension))
57}
58
59/// Estimate tokens for a single message - matches original TypeScript:
60/// `export function roughTokenCountEstimationForMessage(message: {...}): number`
61pub fn rough_token_count_estimation_for_message(message: &Message) -> usize {
62    rough_token_count_estimation_for_content(&message.content)
63}
64
65/// Estimate tokens for message content (string or array) - matches original TypeScript:
66/// `function roughTokenCountEstimationForContent(content: ...): number`
67pub fn rough_token_count_estimation_for_content(content: &str) -> usize {
68    if content.is_empty() {
69        return 0;
70    }
71    rough_token_count_estimation(content, 4.0)
72}
73
74/// Estimate tokens for an array of messages - matches original TypeScript:
75/// `export function roughTokenCountEstimationForMessages(messages: readonly {...}[]): number`
76pub fn rough_token_count_estimation_for_messages(messages: &[Message]) -> usize {
77    messages
78        .iter()
79        .map(|msg| rough_token_count_estimation_for_message(msg))
80        .sum()
81}
82
83// ============================================================================
84// Legacy estimation functions (kept for backward compatibility)
85// ============================================================================
86
87/// Estimate tokens using character ratio method (faster but less accurate)
88/// Average ratio is ~4 characters per token for English
89pub fn estimate_tokens_characters(text: &str) -> TokenEstimate {
90    let characters = text.len();
91    let words = text.split_whitespace().count();
92
93    // Use 4:1 character to token ratio as baseline
94    // Adjust based on text characteristics
95    let ratio = if text.contains("```") {
96        // Code blocks have more characters per token
97        5.5
98    } else if words > 0 {
99        let avg_word_len = characters as f64 / words as f64;
100        if avg_word_len > 8.0 {
101            // Long words = more characters per token
102            5.0
103        } else if avg_word_len < 3.0 {
104            // Short words = fewer characters per token
105            3.5
106        } else {
107            4.0
108        }
109    } else {
110        4.0
111    };
112
113    let tokens = (characters as f64 / ratio).ceil() as usize;
114
115    TokenEstimate {
116        tokens,
117        characters,
118        words,
119        method: EstimationMethod::CharacterRatio,
120    }
121}
122
123/// Estimate tokens using word-based method
124pub fn estimate_tokens_words(text: &str) -> TokenEstimate {
125    let words = text.split_whitespace().count();
126    let characters = text.len();
127
128    // Average ~1.3 words per token for English
129    let tokens = (words as f64 / 1.3).ceil() as usize;
130
131    TokenEstimate {
132        tokens,
133        characters,
134        words,
135        method: EstimationMethod::WordBased,
136    }
137}
138
139/// Estimate tokens using combined method (best balance of speed and accuracy)
140pub fn estimate_tokens(text: &str) -> TokenEstimate {
141    let char_estimate = estimate_tokens_characters(text);
142    let word_estimate = estimate_tokens_words(text);
143
144    // Use the average of both methods for better accuracy
145    let tokens = (char_estimate.tokens + word_estimate.tokens) / 2;
146
147    TokenEstimate {
148        tokens,
149        characters: char_estimate.characters,
150        words: char_estimate.words,
151        method: EstimationMethod::CharacterRatio,
152    }
153}
154
155/// Estimate tokens in messages (handles role/content format)
156pub fn estimate_message_tokens<T: MessageContent>(messages: &[T]) -> usize {
157    messages
158        .iter()
159        .map(|m| {
160            let content = m.content();
161            // Add overhead for role annotation
162            let role_overhead = 4;
163            estimate_tokens(content).tokens + role_overhead
164        })
165        .sum()
166}
167
168/// Estimate tokens in a conversation string
169pub fn estimate_conversation(conversation: &str) -> TokenEstimate {
170    // Count turns by looking for common patterns
171    let turns = conversation
172        .matches("User:")
173        .count()
174        .max(conversation.matches("Assistant:").count());
175
176    // Each turn has overhead for role prefix
177    let turn_overhead = turns * 10;
178
179    let base = estimate_tokens(conversation);
180    TokenEstimate {
181        tokens: base.tokens + turn_overhead,
182        characters: base.characters,
183        words: base.words,
184        method: base.method,
185    }
186}
187
188/// Estimate tokens for tool definitions
189pub fn estimate_tool_definitions(tools: &[ToolDefinition]) -> usize {
190    tools
191        .iter()
192        .map(|t| {
193            let name_tokens = estimate_tokens(&t.name).tokens;
194            let desc_tokens = t
195                .description
196                .as_ref()
197                .map(|d| estimate_tokens(d).tokens)
198                .unwrap_or(0);
199            let params_tokens = estimate_tokens(&t.input_schema).tokens;
200            name_tokens + desc_tokens + params_tokens + 20 // overhead
201        })
202        .sum()
203}
204
205/// Simple message content for estimation
206pub trait MessageContent {
207    fn content(&self) -> &str;
208}
209
210impl MessageContent for String {
211    fn content(&self) -> &str {
212        self.as_str()
213    }
214}
215
216impl MessageContent for &str {
217    fn content(&self) -> &str {
218        self
219    }
220}
221
222/// Message with role
223#[derive(Debug, Clone)]
224pub struct ChatMessage {
225    pub role: String,
226    pub content: String,
227}
228
229impl MessageContent for ChatMessage {
230    fn content(&self) -> &str {
231        &self.content
232    }
233}
234
235/// Tool definition for estimation
236#[derive(Debug, Clone)]
237pub struct ToolDefinition {
238    pub name: String,
239    pub description: Option<String>,
240    pub input_schema: String,
241}
242
243/// Calculate padding needed for context window
244/// Returns the amount of extra input tokens that could fit given the output token budget
245pub fn calculate_padding(input_tokens: usize, max_tokens: usize, context_limit: usize) -> usize {
246    // Calculate how much room is left for input given the output budget
247    let available_for_input = context_limit.saturating_sub(max_tokens);
248    if input_tokens < available_for_input {
249        available_for_input.saturating_sub(input_tokens)
250    } else {
251        0
252    }
253}
254
255/// Estimate if content fits in context
256pub fn fits_in_context(content_tokens: usize, max_tokens: usize, context_limit: usize) -> bool {
257    content_tokens + max_tokens <= context_limit
258}
259
260/// Token encoding utilities
261pub mod encoding {
262    /// Common tokenization patterns
263    pub const CHARS_PER_TOKEN_EN: f64 = 4.0;
264    pub const CHARS_PER_TOKEN_CODE: f64 = 5.5;
265    pub const CHARS_PER_TOKEN_CJK: f64 = 2.0; // Chinese, Japanese, Korean
266
267    /// Detect if text is primarily code
268    pub fn is_code(text: &str) -> bool {
269        let code_indicators = [
270            "```", "function", "class ", "def ", "const ", "let ", "var ", "import ",
271        ];
272        code_indicators.iter().any(|i| text.contains(i))
273    }
274
275    /// Detect if text is primarily CJK
276    pub fn is_cjk(text: &str) -> bool {
277        text.chars().any(|c| {
278            (c >= '\u{4E00}' && c <= '\u{9FFF}') ||  // CJK Unified Ideographs
279            (c >= '\u{3040}' && c <= '\u{309F}') ||  // Hiragana
280            (c >= '\u{30A0}' && c <= '\u{30FF}') ||  // Katakana
281            (c >= '\u{AC00}' && c <= '\u{D7AF}') // Korean
282        })
283    }
284
285    /// Get appropriate chars per token ratio
286    pub fn chars_per_token(text: &str) -> f64 {
287        if is_code(text) {
288            super::encoding::CHARS_PER_TOKEN_CODE
289        } else if is_cjk(text) {
290            super::encoding::CHARS_PER_TOKEN_CJK
291        } else {
292            super::encoding::CHARS_PER_TOKEN_EN
293        }
294    }
295}
296
297// ============================================================================
298// count_tokens API: /v1/messages/count_tokens
299// Translated from TypeScript countMessagesTokensWithAPI / countTokensWithAPI
300// ============================================================================
301
302/// Minimum thinking budget for token counting when messages contain thinking blocks
303/// API constraint: max_tokens must be greater than thinking.budget_tokens
304pub const TOKEN_COUNT_THINKING_BUDGET: u32 = 1024;
305
306/// Max tokens for token counting requests (used when thinking is enabled)
307pub const TOKEN_COUNT_MAX_TOKENS: u32 = 2048;
308
309/// Error type for count_tokens API operations
310#[derive(Debug, Clone)]
311pub struct CountTokensError(pub String);
312
313impl std::fmt::Display for CountTokensError {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        write!(f, "count_tokens error: {}", self.0)
316    }
317}
318
319impl std::error::Error for CountTokensError {}
320
321/// Check if messages contain thinking or redacted_thinking blocks
322/// Matches TypeScript: hasThinkingBlocks()
323fn has_thinking_blocks(messages: &[serde_json::Value]) -> bool {
324    for msg in messages {
325        let role = msg.get("role").and_then(|r| r.as_str());
326        if role == Some("assistant") {
327            if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
328                for block in content {
329                    let block_type = block.get("type").and_then(|t| t.as_str());
330                    if block_type == Some("thinking") || block_type == Some("redacted_thinking") {
331                        return true;
332                    }
333                }
334            }
335        }
336    }
337    false
338}
339
340/// Get the base API URL from environment, defaulting to Anthropic API
341fn get_base_url() -> String {
342    std::env::var("AI_CODE_API_URL")
343        .or_else(|_| std::env::var("AI_CODE_BASE_URL"))
344        .unwrap_or_else(|_| "https://api.anthropic.com".to_string())
345}
346
347/// Get the API key from environment
348fn get_api_key() -> Option<String> {
349    std::env::var("AI_CODE_API_KEY")
350        .ok()
351        .or_else(|| std::env::var("AI_AUTH_TOKEN").ok())
352        .or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
353}
354
355/// Check if using Vertex provider
356fn is_using_vertex() -> bool {
357    let is_truthy = |v: Option<String>| {
358        v.map(|x| x == "1" || x.to_lowercase() == "true")
359            .unwrap_or(false)
360    };
361    is_truthy(std::env::var("AI_CODE_USE_VERTEX").ok())
362}
363
364/// Normalize model string for API (strip display wrappers)
365fn normalize_model_string_for_api(model: &str) -> String {
366    // Strip common display wrappers like "claude/" prefix if present
367    model.trim_start_matches("claude/").to_string()
368}
369
370/// Count tokens via the Anthropic `/v1/messages/count_tokens` API.
371///
372/// Matches TypeScript: `countMessagesTokensWithAPI(messages, tools)`
373///
374/// # Arguments
375/// * `api_key` - Anthropic API key (or None to read from env)
376/// * `base_url` - Base API URL (or None to read from env)
377/// * `model` - The model to use for counting
378/// * `messages` - Messages in API format (already serialized as JSON)
379/// * `tools` - Optional tool definitions in Anthropic API format
380/// * `betas` - Optional beta headers to include
381///
382/// # Returns
383/// `Some(input_tokens)` on success, `None` on any error (matching TS behavior)
384pub async fn count_messages_tokens_with_api(
385    api_key: Option<String>,
386    base_url: Option<String>,
387    model: &str,
388    messages: &[serde_json::Value],
389    tools: Option<&[serde_json::Value]>,
390    betas: Option<&[String]>,
391) -> Option<u64> {
392    let api_key = api_key.or_else(get_api_key)?;
393    let base_url = base_url.or_else(|| Some(get_base_url()))?;
394    let client = reqwest::Client::new();
395
396    // Build request body
397    let contains_thinking = has_thinking_blocks(messages);
398    let messages_to_send: Vec<serde_json::Value> = if messages.is_empty() {
399        // When we pass tools and no messages, we need a dummy message
400        vec![serde_json::json!({ "role": "user", "content": "foo" })]
401    } else {
402        messages.to_vec()
403    };
404    let mut body = serde_json::json!({
405        "model": normalize_model_string_for_api(model),
406        "messages": messages_to_send
407    });
408
409    // Add tools if provided
410    if let Some(tools_list) = tools {
411        if !tools_list.is_empty() {
412            body["tools"] = serde_json::json!(tools_list);
413        }
414    }
415
416    // Add betas (filter for Vertex if needed)
417    if let Some(betas_list) = betas {
418        let filtered = if is_using_vertex() {
419            let allowed = crate::constants::betas::get_vertex_count_tokens_allowed_betas();
420            betas_list
421                .iter()
422                .filter(|b| allowed.contains(b.as_str()))
423                .cloned()
424                .collect::<Vec<String>>()
425        } else {
426            betas_list.to_vec()
427        };
428        if !filtered.is_empty() {
429            body["betas"] = serde_json::json!(filtered);
430        }
431    }
432
433    // Enable thinking if messages contain thinking blocks
434    if contains_thinking {
435        body["thinking"] = serde_json::json!({
436            "type": "enabled",
437            "budget_tokens": TOKEN_COUNT_THINKING_BUDGET
438        });
439        body["max_tokens"] = serde_json::json!(TOKEN_COUNT_MAX_TOKENS);
440    }
441
442    let url = format!("{}/v1/messages/count_tokens", base_url.trim_end_matches('/'));
443
444    let resp = client
445        .post(&url)
446        .header("x-api-key", &api_key)
447        .header("anthropic-version", "2023-06-01")
448        .header("content-type", "application/json")
449        .json(&body)
450        .send()
451        .await;
452
453    let resp = match resp {
454        Ok(r) => r,
455        Err(e) => {
456            log::debug!("count_tokens API request failed: {}", e);
457            return None;
458        }
459    };
460
461    if !resp.status().is_success() {
462        let status = resp.status();
463        let body_text = resp.text().await.unwrap_or_default();
464        log::debug!("count_tokens API error {}: {}", status, body_text);
465        return None;
466    }
467
468    let json: serde_json::Value = match resp.json().await {
469        Ok(j) => j,
470        Err(e) => {
471            log::debug!("count_tokens failed to parse response: {}", e);
472            return None;
473        }
474    };
475
476    json.get("input_tokens")
477        .and_then(|v| v.as_u64())
478        .or_else(|| {
479            // Vertex / Bedrock may return different shapes
480            log::debug!("count_tokens response missing input_tokens field: {}", json);
481            None
482        })
483}
484
485/// Convenience wrapper: count tokens for a single text content string.
486///
487/// Matches TypeScript: `countTokensWithAPI(content)`
488///
489/// # Arguments
490/// * `content` - The text content to count
491/// * `api_key` - API key (or None to read from env)
492/// * `base_url` - Base API URL (or None to read from env)
493/// * `model` - The model to use for counting
494///
495/// # Returns
496/// `Some(tokens)` on success, `None` on error. Returns `Some(0)` for empty content.
497pub async fn count_tokens_with_api(
498    content: &str,
499    api_key: Option<String>,
500    base_url: Option<String>,
501    model: &str,
502) -> Option<u64> {
503    // API doesn't accept empty messages
504    if content.is_empty() {
505        return Some(0);
506    }
507
508    let message = serde_json::json!({
509        "role": "user",
510        "content": content
511    });
512
513    count_messages_tokens_with_api(api_key, base_url, model, &[message], None, None).await
514}
515
516/// Fallback token counting via a real `messages.create` call with a fast (Haiku) model.
517///
518/// Matches TypeScript: `countTokensViaHaikuFallback(messages, tools)`
519///
520/// Makes an actual API call with `max_tokens: 1` (or TOKEN_COUNT_MAX_TOKENS if thinking
521/// is needed) and reads the `usage.input_tokens` from the response.
522///
523/// # Returns
524/// `Some(input_tokens)` on success, `None` on error.
525pub async fn count_tokens_via_haiku_fallback(
526    api_key: Option<String>,
527    base_url: Option<String>,
528    messages: &[serde_json::Value],
529    tools: Option<&[serde_json::Value]>,
530) -> Option<u64> {
531    let api_key = api_key.or_else(get_api_key)?;
532    let base_url = base_url.or_else(|| Some(get_base_url()))?;
533    let client = reqwest::Client::new();
534
535    let contains_thinking = has_thinking_blocks(messages);
536
537    // Use Haiku for token counting by default (faster / cheaper).
538    // Use Sonnet if messages contain thinking blocks and on Vertex/Bedrock.
539    let model = if contains_thinking && is_using_vertex() {
540        crate::utils::model::get_default_sonnet_model()
541    } else {
542        crate::utils::model::get_small_fast_model()
543    };
544
545    let messages_to_send: Vec<serde_json::Value> = if messages.is_empty() {
546        vec![serde_json::json!({ "role": "user", "content": "count" })]
547    } else {
548        messages.to_vec()
549    };
550    let mut body = serde_json::json!({
551        "model": normalize_model_string_for_api(&model),
552        "max_tokens": if contains_thinking { TOKEN_COUNT_MAX_TOKENS } else { 1 },
553        "messages": messages_to_send
554    });
555
556    // Add tools if provided
557    if let Some(tools_list) = tools {
558        if !tools_list.is_empty() {
559            body["tools"] = serde_json::json!(tools_list);
560        }
561    }
562
563    // Enable thinking if messages contain thinking blocks
564    if contains_thinking {
565        body["thinking"] = serde_json::json!({
566            "type": "enabled",
567            "budget_tokens": TOKEN_COUNT_THINKING_BUDGET
568        });
569    }
570
571    let url = format!("{}/v1/messages", base_url.trim_end_matches('/'));
572
573    let resp = client
574        .post(&url)
575        .header("x-api-key", &api_key)
576        .header("anthropic-version", "2023-06-01")
577        .header("content-type", "application/json")
578        .json(&body)
579        .send()
580        .await;
581
582    let resp = match resp {
583        Ok(r) => r,
584        Err(e) => {
585            log::debug!("count_tokens Haiku fallback request failed: {}", e);
586            return None;
587        }
588    };
589
590    if !resp.status().is_success() {
591        let status = resp.status();
592        let body_text = resp.text().await.unwrap_or_default();
593        log::debug!("count_tokens Haiku fallback error {}: {}", status, body_text);
594        return None;
595    }
596
597    let json: serde_json::Value = match resp.json().await {
598        Ok(j) => j,
599        Err(e) => {
600            log::debug!("count_tokens Haiku fallback parse error: {}", e);
601            return None;
602        }
603    };
604
605    // Extract usage: input_tokens + cache_creation + cache_read
606    let usage = json.get("usage")?;
607    let input_tokens = usage.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0);
608    let cache_creation = usage
609        .get("cache_creation_input_tokens")
610        .and_then(|v| v.as_u64())
611        .unwrap_or(0);
612    let cache_read = usage
613        .get("cache_read_input_tokens")
614        .and_then(|v| v.as_u64())
615        .unwrap_or(0);
616
617    Some(input_tokens + cache_creation + cache_read)
618}
619
620/// Orchestrator: try API count_tokens first, fall back to Haiku if it fails.
621///
622/// Matches TypeScript: `countTokensWithFallback(messages, tools)` from analyzeContext.ts
623///
624/// # Arguments
625/// * `api_key` - API key (or None to read from env)
626/// * `base_url` - Base API URL (or None to read from env)
627/// * `model` - The model to use for counting (primary API call)
628/// * `messages` - Messages in API format
629/// * `tools` - Optional tool definitions in API format
630///
631/// # Returns
632/// `Some(input_tokens)` on success, `None` if both API and fallback fail.
633pub async fn count_tokens_with_fallback(
634    api_key: Option<String>,
635    base_url: Option<String>,
636    model: &str,
637    messages: &[serde_json::Value],
638    tools: Option<&[serde_json::Value]>,
639) -> Option<u64> {
640    // Try primary count_tokens API first
641    if let Some(count) = count_messages_tokens_with_api(api_key.clone(), base_url.clone(), model, messages, tools, None).await {
642        return Some(count);
643    }
644    log::debug!(
645        "count_tokens API returned null, trying Haiku fallback ({} tools)",
646        tools.map(|t| t.len()).unwrap_or(0)
647    );
648
649    // Haiku fallback
650    if let Some(count) = count_tokens_via_haiku_fallback(api_key, base_url, messages, tools).await {
651        return Some(count);
652    }
653    log::debug!("count_tokens Haiku fallback also returned null");
654    None
655}
656
657// ============================================================================
658// FileReadTool token budget validation
659// Translated from TypeScript validateContentTokens
660// ============================================================================
661
662/// Maximum token limit for file read tool output
663pub const DEFAULT_FILE_READ_MAX_TOKENS: u64 = 25_000;
664
665/// Error thrown when file content exceeds token budget
666#[derive(Debug, Clone)]
667pub struct MaxFileReadTokenExceededError {
668    pub token_count: u64,
669    pub max_tokens: u64,
670}
671
672impl std::fmt::Display for MaxFileReadTokenExceededError {
673    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
674        write!(
675            f,
676            "File content ({} tokens) exceeds maximum allowed tokens ({}). Use offset and limit parameters to read specific portions of the file, or search for specific content instead of reading the whole file.",
677            self.token_count, self.max_tokens
678        )
679    }
680}
681
682impl std::error::Error for MaxFileReadTokenExceededError {}
683
684/// Get the default file reading max tokens limit from environment or default.
685/// Matches TypeScript: `getDefaultFileReadingLimits().maxTokens`
686pub fn get_default_file_read_max_tokens() -> u64 {
687    std::env::var("AI_CODE_FILE_READ_MAX_OUTPUT_TOKENS")
688        .ok()
689        .and_then(|v| v.parse().ok())
690        .unwrap_or(DEFAULT_FILE_READ_MAX_TOKENS)
691}
692
693/// Validate that file content does not exceed the token budget.
694///
695/// Two-phase approach matching TypeScript:
696/// 1. Cheap rough estimate — if under `max_tokens / 4`, short-circuit and return OK
697/// 2. If rough estimate exceeds threshold, call count_tokens API for exact count
698/// 3. Throw if exact count exceeds limit
699///
700/// # Arguments
701/// * `content` - The file content to validate
702/// * `ext` - File extension (for bytes-per-token ratio)
703/// * `max_tokens` - Maximum allowed tokens (or None for default limit)
704/// * `api_key` - API key for exact counting (or None to read from env)
705/// * `base_url` - Base API URL (or None to read from env)
706/// * `model` - Model for count_tokens API call
707pub async fn validate_content_tokens(
708    content: &str,
709    ext: &str,
710    max_tokens: Option<u64>,
711    api_key: Option<String>,
712    base_url: Option<String>,
713    model: &str,
714) -> Result<(), MaxFileReadTokenExceededError> {
715    let effective_max = max_tokens.unwrap_or(get_default_file_read_max_tokens());
716
717    // Phase 1: cheap rough estimate
718    let rough_estimate = rough_token_count_estimation_for_file_type(content, ext) as u64;
719    if rough_estimate <= effective_max / 4 {
720        return Ok(());
721    }
722
723    // Phase 2: API-based exact count
724    let exact_count = count_tokens_with_api(content, api_key, base_url, model).await;
725    let effective_count = exact_count.unwrap_or(rough_estimate);
726
727    if effective_count > effective_max {
728        Err(MaxFileReadTokenExceededError {
729            token_count: effective_count,
730            max_tokens: effective_max,
731        })
732    } else {
733        Ok(())
734    }
735}
736
737#[cfg(test)]
738mod tests {
739    use super::*;
740    use crate::types::MessageRole;
741
742    // ============================================================================
743    // Tests for the translated TypeScript functions
744    // ============================================================================
745
746    #[test]
747    fn test_rough_token_count_estimation() {
748        // "Hello world" = 11 chars, 11/4 = 2.75 rounds to 3
749        assert_eq!(rough_token_count_estimation("Hello world", 4.0), 3);
750        // 100 chars / 4 = 25 tokens
751        assert_eq!(rough_token_count_estimation(&"a".repeat(100), 4.0), 25);
752    }
753
754    #[test]
755    fn test_bytes_per_token_for_file_type() {
756        assert_eq!(bytes_per_token_for_file_type("json"), 2.0);
757        assert_eq!(bytes_per_token_for_file_type("jsonl"), 2.0);
758        assert_eq!(bytes_per_token_for_file_type("rs"), 4.0);
759        assert_eq!(bytes_per_token_for_file_type("txt"), 4.0);
760    }
761
762    #[test]
763    fn test_rough_token_count_estimation_for_file_type() {
764        // JSON: 100 chars / 2 = 50 tokens
765        assert_eq!(
766            rough_token_count_estimation_for_file_type(&"a".repeat(100), "json"),
767            50
768        );
769        // Rust: 100 chars / 4 = 25 tokens
770        assert_eq!(
771            rough_token_count_estimation_for_file_type(&"a".repeat(100), "rs"),
772            25
773        );
774    }
775
776    #[test]
777    fn test_rough_token_count_estimation_for_content() {
778        assert_eq!(rough_token_count_estimation_for_content(""), 0);
779        // "Hello" = 5 chars, 5/4 = 1.25 rounds to 1
780        assert_eq!(rough_token_count_estimation_for_content("Hello"), 1);
781    }
782
783    #[test]
784    fn test_rough_token_count_estimation_for_message() {
785        let msg = crate::types::Message {
786            role: MessageRole::User,
787            content: "Hello world".to_string(),
788            ..Default::default()
789        };
790        // "Hello world" = 11 chars, 11/4 = 2.75 rounds to 3
791        assert_eq!(rough_token_count_estimation_for_message(&msg), 3);
792    }
793
794    #[test]
795    fn test_rough_token_count_estimation_for_messages() {
796        let messages = vec![
797            crate::types::Message {
798                role: MessageRole::User,
799                content: "Hello".to_string(),
800                ..Default::default()
801            },
802            crate::types::Message {
803                role: MessageRole::Assistant,
804                content: "Hi there".to_string(),
805                ..Default::default()
806            },
807        ];
808        // "Hello" = 5 chars / 4 = 1.25 -> 1 token
809        // "Hi there" = 8 chars / 4 = 2 tokens
810        // Total = 3 tokens
811        assert_eq!(rough_token_count_estimation_for_messages(&messages), 3);
812    }
813
814    // ============================================================================
815    // Tests for legacy estimation functions
816    // ============================================================================
817
818    #[test]
819    fn test_estimate_tokens_characters() {
820        let result = estimate_tokens_characters("Hello, world!");
821        assert!(result.tokens >= 3);
822        assert_eq!(result.characters, 13);
823    }
824
825    #[test]
826    fn test_estimate_tokens_words() {
827        let result = estimate_tokens_words("Hello world this is a test");
828        assert!(result.tokens > 0);
829        assert_eq!(result.words, 6);
830    }
831
832    #[test]
833    fn test_estimate_tokens() {
834        let result = estimate_tokens("The quick brown fox jumps over the lazy dog");
835        assert!(result.tokens > 0);
836    }
837
838    #[test]
839    fn test_estimate_conversation() {
840        let conv = "User: Hello\nAssistant: Hi there!\nUser: How are you?";
841        let result = estimate_conversation(conv);
842        assert!(result.tokens > 0);
843    }
844
845    #[test]
846    fn test_estimate_tool_definitions() {
847        let tools = vec![ToolDefinition {
848            name: "Read".to_string(),
849            description: Some("Read a file".to_string()),
850            input_schema: r#"{"type":"object","properties":{"path":{"type":"string"}}}"#
851                .to_string(),
852        }];
853        let tokens = estimate_tool_definitions(&tools);
854        assert!(tokens > 0);
855    }
856
857    #[test]
858    fn test_calculate_padding() {
859        assert_eq!(calculate_padding(1000, 500, 2000), 500);
860        assert_eq!(calculate_padding(1500, 500, 2000), 0);
861    }
862
863    #[test]
864    fn test_fits_in_context() {
865        assert!(fits_in_context(1000, 500, 2000));
866        assert!(!fits_in_context(1600, 500, 2000));
867    }
868
869    #[test]
870    fn test_encoding_chars_per_token() {
871        assert_eq!(
872            encoding::chars_per_token("Hello world"),
873            encoding::CHARS_PER_TOKEN_EN
874        );
875        assert_eq!(
876            encoding::chars_per_token("function test() {}"),
877            encoding::CHARS_PER_TOKEN_CODE
878        );
879    }
880
881    #[test]
882    fn test_is_code() {
883        assert!(encoding::is_code("function foo() { return 1; }"));
884        assert!(!encoding::is_code("Hello world"));
885    }
886
887    #[test]
888    fn test_is_cjk() {
889        assert!(encoding::is_cjk("你好世界"));
890        assert!(!encoding::is_cjk("Hello world"));
891    }
892
893    #[test]
894    fn test_message_content_trait() {
895        let msg = ChatMessage {
896            role: "user".to_string(),
897            content: "Hello".to_string(),
898        };
899        assert_eq!(msg.content(), "Hello");
900    }
901
902    // ============================================================================
903    // Tests for count_tokens API helpers
904    // ============================================================================
905
906    #[test]
907    fn test_has_thinking_blocks_detects_thinking() {
908        let messages = vec![serde_json::json!({
909            "role": "assistant",
910            "content": [
911                { "type": "thinking", "thinking": "let me think..." },
912                { "type": "text", "text": "I think the answer is 42" }
913            ]
914        })];
915        assert!(has_thinking_blocks(&messages));
916    }
917
918    #[test]
919    fn test_has_thinking_blocks_detects_redacted_thinking() {
920        let messages = vec![serde_json::json!({
921            "role": "assistant",
922            "content": [
923                { "type": "redacted_thinking", "data": "xxx" }
924            ]
925        })];
926        assert!(has_thinking_blocks(&messages));
927    }
928
929    #[test]
930    fn test_has_thinking_blocks_no_thinking() {
931        let messages = vec![
932            serde_json::json!({ "role": "user", "content": "Hello" }),
933            serde_json::json!({ "role": "assistant", "content": "Hi there" }),
934        ];
935        assert!(!has_thinking_blocks(&messages));
936    }
937
938    #[test]
939    fn test_has_thinking_blocks_empty() {
940        let messages: Vec<serde_json::Value> = vec![];
941        assert!(!has_thinking_blocks(&messages));
942    }
943
944    #[test]
945    fn test_has_thinking_blocks_tool_use_only() {
946        let messages = vec![serde_json::json!({
947            "role": "assistant",
948            "content": [
949                { "type": "tool_use", "id": "tool_1", "name": "Read", "input": {} }
950            ]
951        })];
952        assert!(!has_thinking_blocks(&messages));
953    }
954
955    #[test]
956    fn test_normalize_model_string_for_api() {
957        assert_eq!(normalize_model_string_for_api("claude/sonnet-4-6"), "sonnet-4-6");
958        assert_eq!(
959            normalize_model_string_for_api("claude-sonnet-4-6"),
960            "claude-sonnet-4-6"
961        );
962    }
963
964    #[test]
965    fn test_token_count_constants() {
966        // max_tokens must be greater than thinking budget
967        assert!(TOKEN_COUNT_MAX_TOKENS > TOKEN_COUNT_THINKING_BUDGET);
968        assert_eq!(TOKEN_COUNT_THINKING_BUDGET, 1024);
969        assert_eq!(TOKEN_COUNT_MAX_TOKENS, 2048);
970    }
971
972    #[test]
973    fn test_default_file_read_max_tokens() {
974        assert_eq!(get_default_file_read_max_tokens(), 25_000);
975    }
976
977    #[test]
978    fn test_max_file_read_error_display() {
979        let err = MaxFileReadTokenExceededError {
980            token_count: 30_000,
981            max_tokens: 25_000,
982        };
983        let msg = format!("{}", err);
984        assert!(msg.contains("30000"));
985        assert!(msg.contains("25000"));
986        assert!(msg.contains("tokens"));
987    }
988
989    #[tokio::test]
990    async fn test_validate_content_tokens_short_content() {
991        // Content under max_tokens / 4 → should pass without API call
992        let result = validate_content_tokens(
993            "short content",
994            "txt",
995            Some(25_000),
996            None, // no API key
997            None,
998            "claude-sonnet-4-6",
999        )
1000        .await;
1001        assert!(result.is_ok());
1002    }
1003}