use anyhow::{anyhow, Context, Result};
use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::io::Write;
use std::fs::OpenOptions;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs;
// Top-level command-line arguments shared by every subcommand.
// NOTE: plain `//` comments are used on purpose; clap's derive turns `///` doc comments
// into help text, which would change the CLI output.
#[derive(Parser)]
#[command(version, about = "Model tokenizer registry and token counting CLI")]
struct Cli {
// URL of the remote model index; `sync` issues a GET request against it.
#[arg(long, default_value = "https://aihubmix.com/api/v1/models?type=llm")]
models_url: String,
// Path of the JSON registry file. The default is swapped for the legacy file name by
// `normalize_store_path` when only the legacy file exists.
#[arg(long, default_value = "./token_registry.json")]
local_store: PathBuf,
// Path of the JSONL audit log that `count` appends to. The default is swapped for the
// legacy file name by `normalize_log_path` when only the legacy file exists.
#[arg(long, default_value = "./token_count_log.jsonl")]
local_log: PathBuf,
// The subcommand to run.
#[command(subcommand)]
cmd: CommandKind,
}
// Subcommands of the CLI.
//
// Switches that default to `true` (`only_text`, `log_audit`) take an optional value so they
// can actually be disabled (`--only-text=false`). With clap's implicit `SetTrue` action a
// default-true flag could never become false. A bare `--only-text` still means `true`.
#[derive(Subcommand)]
enum CommandKind {
    /// Fetch the model index and upsert classified rows into the local registry
    Sync {
        /// Probe every model for an open-source tokenizer route (unless --no-probe is given)
        #[arg(long)]
        force_reclassify: bool,
        /// Skip models that are not text models (pass `--only-text=false` to keep them)
        #[arg(
            long,
            default_value_t = true,
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            default_missing_value = "true"
        )]
        only_text: bool,
        /// Never probe Hugging Face for tokenizer artifacts
        #[arg(long)]
        no_probe: bool,
    },
    /// Count tokens for one registered model and print the result as JSON
    Count {
        /// Model key exactly as stored in the local registry
        #[arg(long)]
        model: String,
        // The three input sources below are handed to `collect_count_input`, which is
        // defined outside this part of the file; their precedence is decided there.
        #[arg(long)]
        text: Option<String>,
        #[arg(long)]
        text_file: Option<PathBuf>,
        #[arg(long)]
        messages_file: Option<PathBuf>,
        /// Append an audit record to the local log (pass `--log-audit=false` to disable)
        #[arg(
            long,
            default_value_t = true,
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            default_missing_value = "true"
        )]
        log_audit: bool,
    },
    /// Print registry rows, sorted by model key
    List {
        /// Only show rows whose strategy code equals this value
        #[arg(long)]
        strategy: Option<String>,
        /// Maximum number of rows to print
        #[arg(long, default_value_t = 100)]
        limit: u64,
    },
    /// List rows that need review: low confidence or a strategy that cannot be executed
    Validate {
        /// Rows below this confidence are always listed
        #[arg(long, default_value_t = 70)]
        min_confidence: u8,
        /// Only consider text models (pass `--only-text=false` to include all)
        #[arg(
            long,
            default_value_t = true,
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            default_missing_value = "true"
        )]
        only_text: bool,
    },
    /// Verify tokenizer routes and write JSON and CSV reports
    Verify {
        /// Only consider text models (pass `--only-text=false` to include all)
        #[arg(
            long,
            default_value_t = true,
            action = clap::ArgAction::Set,
            num_args = 0..=1,
            default_missing_value = "true"
        )]
        only_text: bool,
        /// Verify every row instead of only the rows that need review
        #[arg(long)]
        all: bool,
        /// Rows below this confidence are always verified
        #[arg(long, default_value_t = 70)]
        min_confidence: u8,
        /// Run a real token count on the sample text instead of a configuration check
        #[arg(long, default_value_t = false)]
        run_count: bool,
        /// Text that is tokenized when --run-count is given
        #[arg(long, default_value = "test")]
        sample_text: String,
        /// Where to write the JSON report
        #[arg(long, default_value = "./validation_report.json")]
        report_json: PathBuf,
        /// Where to write the CSV report
        #[arg(long, default_value = "./validation_report.csv")]
        report_csv: PathBuf,
        /// Verify at most this many candidates
        #[arg(long)]
        max_models: Option<usize>,
        /// Skip this many candidates before verifying
        #[arg(long)]
        offset: Option<usize>,
    },
}
/// One model entry parsed from the remote model index payload.
///
/// Every field except the identifier is optional and becomes `None` when the payload omits it.
/// The whole struct is also serialized back into `LocalModelRow::raw_payload` during sync.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ApiModel {
/// Model identifier; read from and written to the JSON key `model_id`.
#[serde(rename = "model_id")]
model_key: String,
/// Display name; sync stores an empty string in the registry when it is missing.
#[serde(default)]
model_name: Option<String>,
/// Free-form description (not read anywhere in the visible part of this file).
#[serde(default)]
desc: Option<String>,
/// Numeric developer marker; certain values steer the open-source probing heuristics.
#[serde(default)]
developer_id: Option<i32>,
// NOTE(review): presumably a delimited list of model types — confirm against the API.
#[serde(default)]
types: Option<String>,
/// Feature string, copied verbatim into the registry row.
#[serde(default)]
features: Option<String>,
/// Raw input modalities; sync stores both this value and a `normalize_csv`-normalized copy.
#[serde(default)]
input_modalities: Option<String>,
// NOTE(review): units (tokens?) are not established by this file — confirm.
#[serde(default)]
max_output: Option<i64>,
// NOTE(review): units (tokens?) are not established by this file — confirm.
#[serde(default)]
context_length: Option<i64>,
// NOTE(review): format of the endpoints string is not visible here — confirm.
#[serde(default)]
endpoints: Option<String>,
}
/// Tokenizer classification for one model, as returned by `classify_model`.
///
/// During sync a successful probe overwrites every field except `fallback_strategy_code`;
/// the result is then copied into a `LocalModelRow`.
#[derive(Clone, Debug)]
struct ClassifyResult {
/// Provider the model is attributed to (e.g. `"open_source"` after a probe).
provider_code: String,
/// Strategy used to count tokens; see `local_strategy_type` for the known codes.
strategy_code: String,
/// Secondary strategy code; stored in the registry but not executed in the visible code.
fallback_strategy_code: String,
/// Confidence score for the classification (compared against 60 and `min_confidence`).
confidence: u8,
/// Optional tokenizer hint such as `hf:<repo>`, `hf_candidates:<a>|<b>` or `hf_auto`.
family_hint: Option<String>,
/// Human-readable explanation of how the classification was reached.
notes: String,
}
/// Input for a token count, produced by `collect_count_input`.
#[derive(Debug)]
struct CountInput {
/// Input kind label; logged as `input_kind` and forwarded to the Anthropic counter.
kind: String,
/// The text that is actually tokenized.
text: String,
/// Preview of the input, written to the audit log as `input_preview`.
preview: String,
}
/// One row of the local registry, persisted as an element of the JSON array in the store file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LocalModelRow {
/// Registry key; rows are sorted and de-duplicated by this value.
model_key: String,
/// Display name (empty when the API did not provide one).
model_name: String,
/// Provider the model is attributed to.
provider_code: String,
/// Strategy code that decides how tokens are counted.
strategy_code: String,
/// Secondary strategy code recorded by the classifier.
fallback_strategy_code: String,
/// Tokenizer hint such as `hf:<repo>` or `hf_candidates:<a>|<b>`.
tokenizer_family_hint: Option<String>,
/// Explanation of the classification.
tokenizer_notes: String,
/// Classification confidence.
tokenizer_confidence: u8,
/// Developer marker; defaults to `None` for store files that lack the field.
#[serde(default)]
developer_id: Option<i32>,
/// Result of `is_text_model` at sync time; used by the `only_text` filters.
is_text_model: bool,
/// The full `ApiModel` as it was serialized during sync.
raw_payload: Value,
/// Input modalities after `normalize_csv`.
input_modalities: String,
/// Input modalities exactly as received from the API.
input_modalities_raw: Option<String>,
/// Feature string copied from the API (empty when missing).
features: String,
}
/// Outcome of verifying one registry row; one entry of the JSON report and one CSV line.
#[derive(Debug, Serialize, Clone)]
struct VerifyRecord {
/// Registry key of the verified model.
model_key: String,
/// Display name of the verified model.
model_name: String,
/// Provider code copied from the registry row.
provider_code: String,
/// Strategy code copied from the registry row.
strategy_code: String,
/// Strategy category as returned by `local_strategy_type`.
strategy_type: String,
/// Confidence copied from the registry row.
tokenizer_confidence: u8,
/// Whether the verification succeeded.
verified: bool,
/// How the result was obtained (e.g. `local_count`, `provider_api_config_only`, an endpoint).
verify_source: String,
/// Token count of the sample text, when a count was actually run.
estimate_tokens: Option<i64>,
/// Error message of a failed count, if any.
error: Option<String>,
/// Human-readable summary of the verification step.
note: String,
}
/// Result of a `count` invocation; printed as pretty JSON and written to the audit log.
#[derive(Debug, Serialize)]
struct CountResult {
/// Model key the count was run for.
model: String,
/// Strategy code that was applied.
strategy: String,
/// Provider code of the model.
provider: String,
/// Estimated token count; `0` when the count failed or the strategy is not executable.
estimate_tokens: i64,
/// Label of the counter that produced the estimate (e.g. `local_tiktoken`, `gemini_api_failed`).
estimate_source: String,
/// Input token figure reported by a provider counter, when one was used.
billed_input_tokens: Option<i64>,
/// Output token figure reported by a provider counter, when one was used.
billed_output_tokens: Option<i64>,
/// Additional provider response details.
extra: Option<Value>,
/// Error message when the count could not be produced.
error: Option<String>,
}
#[tokio::main]
async fn main() -> Result<()> {
dotenvy::dotenv().ok();
let cli = Cli::parse();
let store_path = normalize_store_path(cli.local_store);
let log_path = normalize_log_path(cli.local_log);
match cli.cmd {
CommandKind::Sync {
force_reclassify,
only_text,
no_probe,
} => sync_models_local(&cli.models_url, &store_path, force_reclassify, only_text, no_probe).await?,
CommandKind::Count {
model,
text,
text_file,
messages_file,
log_audit,
} => {
let input = collect_count_input(text, text_file, messages_file).await?;
let res = count_for_model_local(&store_path, &log_path, &model, &input, log_audit).await?;
println!("{}", serde_json::to_string_pretty(&res)?);
}
CommandKind::List { strategy, limit } => list_models_local(&store_path, strategy, limit)?,
CommandKind::Validate {
min_confidence,
only_text,
} => validate_mapping_local(&store_path, min_confidence, only_text)?,
CommandKind::Verify {
only_text,
all,
min_confidence,
run_count,
sample_text,
report_json,
report_csv,
max_models,
offset,
} => {
verify_models_local(
&store_path,
only_text,
all,
min_confidence,
run_count,
&sample_text,
&report_json,
&report_csv,
max_models,
offset,
)
.await?
}
}
Ok(())
}
/// Resolves the registry path, falling back to the legacy file name when appropriate.
///
/// The legacy `./aihubmix_token_registry.json` is returned only when the caller passed the
/// default `./token_registry.json`, that default does not exist, and the legacy file does.
/// In every other case `path` is returned untouched.
fn normalize_store_path(path: PathBuf) -> PathBuf {
    let default_location = Path::new("./token_registry.json");
    let legacy_location = Path::new("./aihubmix_token_registry.json");
    if path != default_location || default_location.exists() || !legacy_location.exists() {
        return path;
    }
    legacy_location.to_path_buf()
}
/// Resolves the audit-log path, falling back to the legacy file name when appropriate.
///
/// The legacy `./aihubmix_token_count_log.jsonl` is returned only when the caller passed the
/// default `./token_count_log.jsonl`, that default does not exist, and the legacy file does.
/// In every other case `path` is returned untouched.
fn normalize_log_path(path: PathBuf) -> PathBuf {
    let default_location = Path::new("./token_count_log.jsonl");
    let legacy_location = Path::new("./aihubmix_token_count_log.jsonl");
    if path != default_location || default_location.exists() || !legacy_location.exists() {
        return path;
    }
    legacy_location.to_path_buf()
}
/// Reads the registry file and returns its rows sorted by `model_key`.
///
/// A missing file, or a file containing only whitespace, yields an empty registry.
///
/// # Errors
/// Fails when the file cannot be read or does not contain a JSON array of rows.
fn load_local_store(store_path: &Path) -> Result<Vec<LocalModelRow>> {
    if !store_path.exists() {
        return Ok(Vec::new());
    }
    let contents = std::fs::read_to_string(store_path).context("read local store")?;
    let mut registry: Vec<LocalModelRow> = match contents.trim() {
        "" => Vec::new(),
        _ => serde_json::from_str(&contents)?,
    };
    registry.sort_by(|left, right| left.model_key.cmp(&right.model_key));
    Ok(registry)
}
/// Writes the registry to disk as pretty-printed JSON.
///
/// Rows are sorted by `model_key`; when several rows share a key only the first one
/// (in the stable sort order) is kept. A missing parent directory is created.
///
/// # Errors
/// Fails on serialization errors, or when the directory or file cannot be written.
fn save_local_store(store_path: &Path, mut rows: Vec<LocalModelRow>) -> Result<()> {
    rows.sort_by(|x, y| x.model_key.cmp(&y.model_key));
    rows.dedup_by(|later, earlier| later.model_key == earlier.model_key);
    let serialized = serde_json::to_string_pretty(&rows)?;

    // A bare file name has an empty parent; there is nothing to create in that case.
    let parent_dir = store_path
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty());
    if let Some(dir) = parent_dir {
        std::fs::create_dir_all(dir).with_context(|| format!("create store dir {:?}", dir))?;
    }

    std::fs::write(store_path, serialized).context("write local store")?;
    Ok(())
}
/// Downloads the model index, classifies every model and upserts the rows into the registry.
///
/// * `url` – model index endpoint, fetched with GET. An `x-api-key` header is added when
///   `MODEL_INDEX_API_KEY` or `AIHUBMIX_API_KEY` is available through `first_env`.
/// * `force_reclassify` – makes `should_probe_source_route` accept every model.
/// * `only_text` – models for which `is_text_model` is false are counted as skipped.
/// * `no_probe` – disables the Hugging Face probe regardless of the other flags.
///
/// # Errors
/// Fails on HTTP errors, an unparsable payload, or registry read/write errors.
async fn sync_models_local(
    url: &str,
    store_path: &Path,
    force_reclassify: bool,
    only_text: bool,
    no_probe: bool,
) -> Result<()> {
    let http = reqwest::Client::new();
    let base_request = http.get(url).header("Accept", "application/json");
    let request = match first_env(&["MODEL_INDEX_API_KEY", "AIHUBMIX_API_KEY"]) {
        Some(key) => base_request.header("x-api-key", key),
        None => base_request,
    };
    let response = request.send().await?.error_for_status()?;
    let body = response.text().await?;

    let fetched = parse_models_payload(&body)?;
    let mut registry = load_local_store(store_path)?;
    let (mut upserted, mut skipped) = (0usize, 0usize);

    for model in fetched {
        if only_text && !is_text_model(&model) {
            skipped += 1;
            continue;
        }

        let mut classification = classify_model(&model);
        let probe_wanted =
            !no_probe && should_probe_source_route(&model, &classification, force_reclassify);
        if probe_wanted {
            let probed = probe_open_source_hint(&model.model_key, model.developer_id).await;
            // A probe result replaces everything except the fallback strategy.
            if let Some((provider, strategy, confidence, notes, family)) = probed {
                classification.provider_code = provider;
                classification.strategy_code = strategy;
                classification.confidence = confidence;
                classification.notes = notes;
                classification.family_hint = family;
            }
        }

        let fresh_row = LocalModelRow {
            model_key: model.model_key.clone(),
            model_name: model.model_name.clone().unwrap_or_default(),
            provider_code: classification.provider_code,
            strategy_code: classification.strategy_code,
            fallback_strategy_code: classification.fallback_strategy_code,
            tokenizer_family_hint: classification.family_hint,
            tokenizer_notes: classification.notes,
            tokenizer_confidence: classification.confidence,
            is_text_model: is_text_model(&model),
            raw_payload: serde_json::to_value(&model)?,
            input_modalities: normalize_csv(&model.input_modalities.clone().unwrap_or_default()),
            input_modalities_raw: model.input_modalities.clone(),
            features: model.features.clone().unwrap_or_default(),
            developer_id: model.developer_id,
        };

        // Replace an existing row in place, otherwise append.
        if let Some(slot) = registry
            .iter_mut()
            .find(|known| known.model_key == model.model_key)
        {
            *slot = fresh_row;
        } else {
            registry.push(fresh_row);
        }
        upserted += 1;
    }

    save_local_store(store_path, registry)?;
    println!("sync done. upserted: {upserted}, skipped: {skipped}");
    Ok(())
}
/// Maps a strategy code to its coarse category.
///
/// Returns `"local_python"`, `"provider_api"`, `"private_api"` or `"usage_only"` for the
/// known codes and `"unknown"` for anything else.
fn local_strategy_type(strategy_code: &str) -> &'static str {
    const LOCAL_PYTHON: [&str; 2] = ["local_tiktoken_openai", "local_hf_transformers"];
    const PROVIDER_API: [&str; 4] = [
        "provider_anthropic_count_tokens",
        "provider_gemini_count_tokens",
        "provider_cohere_tokenize",
        "provider_xai_tokenize_text",
    ];

    if LOCAL_PYTHON.contains(&strategy_code) {
        "local_python"
    } else if PROVIDER_API.contains(&strategy_code) {
        "provider_api"
    } else if strategy_code == "provider_private_api" {
        "private_api"
    } else if strategy_code == "usage_only_private" {
        "usage_only"
    } else {
        "unknown"
    }
}
/// Appends one JSON line describing a count to the audit log.
///
/// The parent directory is created when needed and the file is opened in append mode, so
/// existing entries are preserved. `counted_at` is the current Unix time in seconds.
///
/// # Errors
/// Fails when the directory or file cannot be created, the clock is before the Unix epoch,
/// or the entry cannot be written.
fn append_count_log_local(
    log_path: &Path,
    row: &LocalModelRow,
    input: &CountInput,
    res: &CountResult,
) -> Result<()> {
    let parent_dir = log_path
        .parent()
        .filter(|dir| !dir.as_os_str().is_empty());
    if let Some(dir) = parent_dir {
        std::fs::create_dir_all(dir)?;
    }

    let counted_at = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
    let record = json!({
        "model_key": row.model_key,
        "provider_code": row.provider_code,
        "strategy_code": row.strategy_code,
        "input_kind": input.kind,
        "input_preview": input.preview,
        "estimate_tokens": res.estimate_tokens,
        "estimate_source": res.estimate_source,
        "billed_input_tokens": res.billed_input_tokens,
        "billed_output_tokens": res.billed_output_tokens,
        "extra": res.extra,
        "error_message": res.error,
        "counted_at": counted_at,
    });

    let mut log_file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(log_path)
        .context("open local log file")?;
    let line = serde_json::to_string(&record)?;
    writeln!(log_file, "{}", line)?;
    Ok(())
}
async fn count_for_model_local(
store_path: &Path,
log_path: &Path,
model: &str,
input: &CountInput,
log_audit: bool,
) -> Result<CountResult> {
let rows = load_local_store(store_path)?;
let row = rows
.into_iter()
.find(|r| r.model_key == model)
.ok_or_else(|| anyhow!("model not found in local registry, run sync first: {}", model))?;
let (tokens, billed_in, billed_out, details, api_url, err, source) = match row.strategy_code.as_str() {
"local_tiktoken_openai" => (
count_with_python("tiktoken", &row.model_key, &input.text).await?,
None,
None,
None,
None,
None,
"local_tiktoken".to_string(),
),
"local_hf_transformers" => {
let repo = row
.tokenizer_family_hint
.as_deref()
.and_then(|x| x.strip_prefix("hf:"))
.unwrap_or(&row.model_key);
(
count_with_python("hf", repo, &input.text).await?,
None,
None,
None,
None,
None,
"local_hf".to_string(),
)
}
"provider_anthropic_count_tokens" => match count_with_anthropic(&row.model_key, &input.text, input.kind.as_str()).await
{
Ok((n, b_in, b_out, d, endpoint)) => (
n,
Some(b_in),
Some(b_out),
Some(d),
Some(endpoint),
None,
"anthropic_api".to_string(),
),
Err(e) => (
0,
None,
None,
None,
None,
Some(e.to_string()),
"anthropic_api_failed".to_string(),
),
},
"provider_gemini_count_tokens" => match count_with_gemini(&row.model_key, &input.text).await {
Ok((n, b_in, b_out, d)) => (
n,
Some(b_in),
Some(b_out),
Some(d),
Some("gemini_countTokens".to_string()),
None,
"gemini_api".to_string(),
),
Err(e) => (
0,
None,
None,
None,
Some("gemini_countTokens".to_string()),
Some(e.to_string()),
"gemini_api_failed".to_string(),
),
},
"provider_cohere_tokenize" => match count_with_cohere(&row.model_key, &input.text).await {
Ok((n, b_in, b_out, d)) => (
n,
Some(b_in),
Some(b_out),
Some(d),
Some("https://api.cohere.ai/v1/tokenize".to_string()),
None,
"cohere_api".to_string(),
),
Err(e) => (
0,
None,
None,
None,
Some("https://api.cohere.ai/v1/tokenize".to_string()),
Some(e.to_string()),
"cohere_api_failed".to_string(),
),
},
"provider_xai_tokenize_text" => match count_with_xai(&row.model_key, &input.text).await {
Ok((n, b_in, b_out, d, endpoint)) => (
n,
Some(b_in),
Some(b_out),
Some(d),
Some(endpoint),
None,
"xai_tokenize_api".to_string(),
),
Err(e) => (
0,
None,
None,
None,
Some("https://api.x.ai/v1/tokenize-text".to_string()),
Some(e.to_string()),
"xai_tokenize_api_failed".to_string(),
),
},
_ => {
let err = if row.tokenizer_confidence < 60 {
Some(format!(
"no reproducible tokenizer for strategy {} (confidence={})",
row.strategy_code, row.tokenizer_confidence
))
} else {
Some(format!("strategy {} is not directly executable in local mode", row.strategy_code))
};
(
0,
None,
None,
None,
None,
err,
"unresolved".to_string(),
)
}
};
let res = CountResult {
model: row.model_key.clone(),
strategy: row.strategy_code.clone(),
provider: row.provider_code.clone(),
estimate_tokens: tokens,
estimate_source: source,
billed_input_tokens: billed_in,
billed_output_tokens: billed_out,
extra: details,
error: err,
};
if log_audit {
if let Err(e) = append_count_log_local(log_path, &row, input, &res) {
eprintln!("local audit log failed: {e}");
}
let _ = api_url;
}
Ok(res)
}
/// Prints up to `limit` registry rows, sorted by model key.
///
/// When `strategy` is given only rows with exactly that strategy code are shown. Prints
/// `no rows` when nothing matches. Each line has the form
/// `model_key | model_name | provider_code | strategy_code | tokenizer_confidence`.
///
/// # Errors
/// Fails when the registry cannot be loaded.
fn list_models_local(store_path: &Path, strategy: Option<String>, limit: u64) -> Result<()> {
    let mut rows = load_local_store(store_path)?;
    if let Some(target) = strategy {
        rows.retain(|r| r.strategy_code == target);
    }
    rows.sort_by(|a, b| a.model_key.cmp(&b.model_key));
    if rows.is_empty() {
        println!("no rows");
        return Ok(());
    }

    // `limit as usize` would silently truncate on 32-bit targets (e.g. 2^32 becomes 0);
    // saturate instead so an oversized limit simply means "everything".
    let limit = usize::try_from(limit).unwrap_or(usize::MAX);
    for row in rows.into_iter().take(limit) {
        println!(
            "{} | {} | {} | {} | {}",
            row.model_key,
            row.model_name,
            row.provider_code,
            row.strategy_code,
            row.tokenizer_confidence
        );
    }
    Ok(())
}
/// Prints the registry rows that need review and a final count of them.
///
/// A row needs review when its confidence is below `min_confidence`, or when its strategy
/// type is `private_api`, `unknown` or `usage_only`. Rows are listed by ascending
/// confidence, then by model key. With `only_text` set, non-text models are ignored.
///
/// # Errors
/// Fails when the registry cannot be loaded.
fn validate_mapping_local(store_path: &Path, min_confidence: u8, only_text: bool) -> Result<()> {
    let mut considered = load_local_store(store_path)?
        .into_iter()
        .filter(|r| !only_text || r.is_text_model)
        .collect::<Vec<_>>();
    considered.sort_by(|a, b| {
        a.tokenizer_confidence
            .cmp(&b.tokenizer_confidence)
            .then_with(|| a.model_key.cmp(&b.model_key))
    });

    let mut cnt = 0usize;
    for row in &considered {
        let strategy_type = local_strategy_type(&row.strategy_code);
        let confident = row.tokenizer_confidence >= min_confidence;
        let executable = !matches!(strategy_type, "private_api" | "unknown" | "usage_only");
        if confident && executable {
            continue;
        }
        println!(
            "{} | prov={} | strat={} ({}) | conf={} | hint={:?}",
            row.model_key,
            row.provider_code,
            row.strategy_code,
            strategy_type,
            row.tokenizer_confidence,
            row.tokenizer_family_hint
        );
        cnt += 1;
    }
    println!("validate rows: {cnt}");
    Ok(())
}
/// JSON report written by the `verify` subcommand.
#[derive(Debug, Serialize)]
struct VerifyReport {
/// Unix timestamp (seconds) at which the report was generated.
generated_at_unix: u64,
/// Number of verified rows; equals `items.len()`.
total: usize,
/// Rows whose verification succeeded.
passed: usize,
/// Rows whose verification did not succeed.
failed: usize,
/// Rows that were only checked for configuration (`provider_api_config_only`).
/// These rows are also counted in `passed` or `failed`.
skipped: usize,
/// One record per verified row.
items: Vec<VerifyRecord>,
}
async fn verify_models_local(
store_path: &Path,
only_text: bool,
all: bool,
min_confidence: u8,
run_count: bool,
sample_text: &str,
report_json_path: &Path,
report_csv_path: &Path,
max_models: Option<usize>,
offset: Option<usize>,
) -> Result<()> {
let rows = load_local_store(store_path)?;
let mut rows = rows.into_iter().filter(|r| if only_text { r.is_text_model } else { true }).collect::<Vec<_>>();
rows.sort_by(|a, b| a.model_key.cmp(&b.model_key));
let mut candidates = Vec::new();
for r in rows {
let st = local_strategy_type(&r.strategy_code);
if !all && r.tokenizer_confidence >= min_confidence && st != "private_api" && st != "unknown" && st != "usage_only" {
continue;
}
candidates.push(r);
}
let offset = offset.unwrap_or(0);
if offset >= candidates.len() {
println!("verify finished: no models in requested range");
return Ok(());
}
let mut candidates = candidates.into_iter().skip(offset).collect::<Vec<_>>();
if let Some(max_models) = max_models {
candidates.truncate(max_models);
}
let mut items = Vec::with_capacity(candidates.len());
let mut passed = 0usize;
let mut failed = 0usize;
let mut skipped = 0usize;
for row in &candidates {
let rec = verify_single_model(row, run_count, sample_text).await;
if rec.verified {
passed += 1;
} else {
failed += 1;
}
if rec.verify_source == "provider_api_config_only" {
skipped += 1;
}
items.push(rec);
}
let report = VerifyReport {
generated_at_unix: SystemTime::now()
.duration_since(UNIX_EPOCH)
.context("system time")?
.as_secs(),
total: items.len(),
passed,
failed,
skipped,
items: items.clone(),
};
let json = serde_json::to_string_pretty(&report)?;
std::fs::write(report_json_path, json)?;
write_verify_report_csv(report_csv_path, &items)?;
println!("verify done. total: {}, passed: {}, failed: {}, skipped: {}", items.len(), passed, failed, skipped);
for item in items.iter().filter(|i| !i.verified) {
println!(
"{} | {} | {} | {} | {}",
item.model_key, item.provider_code, item.strategy_code, item.verify_source, item.note
);
}
Ok(())
}
/// Writes the verification records as a CSV file with a header line.
///
/// Text columns go through `csv_escape`; a missing `estimate_tokens` or `error` becomes an
/// empty field. The parent directory is created when needed.
///
/// # Errors
/// Fails when the directory or the file cannot be written.
fn write_verify_report_csv(path: &Path, items: &[VerifyRecord]) -> Result<()> {
    const HEADER: &str = "model_key,model_name,provider_code,strategy_code,tokenizer_confidence,strategy_type,verified,verify_source,estimate_tokens,error,note\n";

    let mut csv = String::from(HEADER);
    for record in items {
        let estimate = record
            .estimate_tokens
            .map(|v| v.to_string())
            .unwrap_or_default();
        let error_text = record.error.as_deref().unwrap_or("");
        let fields = [
            csv_escape(&record.model_key),
            csv_escape(&record.model_name),
            csv_escape(&record.provider_code),
            csv_escape(&record.strategy_code),
            record.tokenizer_confidence.to_string(),
            csv_escape(&record.strategy_type),
            record.verified.to_string(),
            csv_escape(&record.verify_source),
            estimate,
            csv_escape(error_text),
            csv_escape(&record.note),
        ];
        csv.push_str(&fields.join(","));
        csv.push('\n');
    }

    let parent_dir = path.parent().filter(|dir| !dir.as_os_str().is_empty());
    if let Some(dir) = parent_dir {
        std::fs::create_dir_all(dir).context("create csv report dir")?;
    }
    std::fs::write(path, csv)?;
    Ok(())
}
/// Escapes one CSV field.
///
/// Fields containing a comma, a double quote, or a line break (`\n` or `\r`) are wrapped in
/// double quotes with embedded quotes doubled; every other value is returned unchanged.
fn csv_escape(v: &str) -> String {
    // A bare carriage return also splits records in many CSV readers, so it must be quoted
    // just like a line feed.
    if v.contains([',', '"', '\n', '\r']) {
        format!("\"{}\"", v.replace('"', "\"\""))
    } else {
        v.to_string()
    }
}
/// Verifies the tokenizer route of a single registry row and returns the resulting record.
///
/// With `run_count` set, the row's strategy is actually executed on `sample_text`; otherwise
/// only a cheaper check is done (route configured, tokenizer artifacts reachable, or an API
/// key present in the environment). Strategies without a known handler are reported as
/// `unresolved`. The function never fails; errors are captured in the record.
async fn verify_single_model(row: &LocalModelRow, run_count: bool, sample_text: &str) -> VerifyRecord {
let strategy_type = local_strategy_type(&row.strategy_code).to_string();
// Prefer the stored developer id; fall back to the raw payload when it fits into an i32.
let dev_id = row
.developer_id
.or_else(|| row.raw_payload.get("developer_id").and_then(Value::as_i64).and_then(|v| i64::try_from(v).ok()).and_then(|v| i32::try_from(v).ok()));
// Start from an unverified record; the strategy branches below fill in the outcome.
let mut rec = VerifyRecord {
model_key: row.model_key.clone(),
model_name: row.model_name.clone(),
provider_code: row.provider_code.clone(),
strategy_code: row.strategy_code.clone(),
strategy_type,
tokenizer_confidence: row.tokenizer_confidence,
verified: false,
verify_source: "unverified".to_string(),
estimate_tokens: None,
error: None,
note: String::new(),
};
// Collects candidate Hugging Face repos: the `hf:` hint, the `hf_candidates:` list
// (separated by `|`), the model key when it looks like a repo id, and the heuristic
// candidates. The result is sorted and de-duplicated, so it is in alphabetical order.
let maybe_hf = |row: &LocalModelRow, dev_id: Option<i32>| {
let mut out = Vec::<String>::new();
if let Some(hint) = row.tokenizer_family_hint.as_deref() {
if let Some(v) = hint.strip_prefix("hf:") {
out.push(v.to_string());
}
if let Some(v) = hint.strip_prefix("hf_candidates:") {
for p in v.split('|').map(|x| x.trim()).filter(|x| !x.is_empty()) {
out.push(p.to_string());
}
}
}
if is_hf_repo_like(&row.model_key) {
out.push(row.model_key.clone());
}
out.extend(open_source_repo_candidates(&row.model_key, dev_id));
out.sort();
out.dedup();
out
};
match row.strategy_code.as_str() {
"local_tiktoken_openai" => {
if run_count {
match count_with_python("tiktoken", &row.model_key, sample_text).await {
Ok(n) => {
rec.verified = true;
rec.verify_source = "local_count".to_string();
rec.estimate_tokens = Some(n);
rec.note = "local tiktoken count succeeded".to_string();
}
Err(e) => {
rec.error = Some(e.to_string());
rec.note = "local tiktoken count failed".to_string();
}
}
} else {
// Without a count the route is accepted as configured; nothing is executed.
rec.verified = true;
rec.verify_source = "local_route".to_string();
rec.note = "openai family strategy is configured for local count".to_string();
}
}
"local_hf_transformers" => {
let repos = maybe_hf(row, dev_id);
if run_count {
// An explicit `hf:` hint wins; otherwise take the first repo-like candidate.
let mut selected_repo: Option<String> = row
.tokenizer_family_hint
.as_deref()
.and_then(|x| x.strip_prefix("hf:"))
.map(|x| x.to_string());
if selected_repo.is_none() {
selected_repo = repos.into_iter().find(|p| is_hf_repo_like(p));
}
if let Some(repo) = selected_repo.as_deref() {
match count_with_python("hf", repo, sample_text).await {
Ok(n) => {
rec.verified = true;
rec.verify_source = "local_count".to_string();
rec.estimate_tokens = Some(n);
rec.note = format!("local hf count succeeded for {repo}");
}
Err(e) => {
rec.error = Some(e.to_string());
rec.note = format!("local hf count failed for {:?}", repo);
}
}
} else {
rec.verify_source = "route_unresolvable".to_string();
rec.note = "no hf repo hint available".to_string();
}
} else {
// Cheap check: the first candidate repo with reachable tokenizer artifacts wins.
let mut found = false;
for repo in repos.iter() {
if has_hf_tokenizer_artifacts(repo).await {
rec.verified = true;
rec.verify_source = "hf_artifacts_ok".to_string();
rec.note = format!("tokenizer artifacts verified at {repo}");
found = true;
break;
}
}
if !found {
rec.verify_source = "hf_artifacts_missing".to_string();
rec.note = "no tokenizer artifacts found on candidate repos".to_string();
}
}
}
"provider_anthropic_count_tokens" => {
if run_count {
match count_with_anthropic(&row.model_key, sample_text, "text").await {
Ok((n, _b_in, _b_out, _d, endpoint)) => {
rec.verified = true;
rec.estimate_tokens = Some(n);
rec.verify_source = endpoint;
rec.note = "anthropic count_tokens request succeeded".to_string();
}
Err(e) => {
// NOTE(review): verify_source stays "unverified" here, whereas the other
// providers set a "*_failed" source on error — confirm whether intended.
rec.error = Some(e.to_string());
rec.note = "anthropic count_tokens request failed".to_string();
}
}
} else {
// Config-only check: passes when any of the listed API keys is available.
rec.verify_source = "provider_api_config_only".to_string();
rec.verified = first_env(&["ANTHROPIC_API_KEY", "MODEL_COUNT_API_KEY", "AIHUBMIX_API_KEY"]).is_some();
rec.note = "run_count=false, not executed request".to_string();
}
}
"provider_gemini_count_tokens" => {
if run_count {
match count_with_gemini(&row.model_key, sample_text).await {
Ok((n, _b_in, _b_out, _d)) => {
rec.verified = true;
rec.estimate_tokens = Some(n);
rec.verify_source = "gemini_count_tokens".to_string();
rec.note = "gemini countTokens request succeeded".to_string();
}
Err(e) => {
rec.error = Some(e.to_string());
rec.verify_source = "gemini_count_tokens_failed".to_string();
rec.note = "gemini countTokens request failed".to_string();
}
}
} else {
// Config-only check: passes when GEMINI_API_KEY is set.
rec.verify_source = "provider_api_config_only".to_string();
rec.verified = std::env::var("GEMINI_API_KEY").is_ok();
rec.note = "run_count=false, not executed request".to_string();
}
}
"provider_cohere_tokenize" => {
if run_count {
match count_with_cohere(&row.model_key, sample_text).await {
Ok((n, _b_in, _b_out, _d)) => {
rec.verified = true;
rec.estimate_tokens = Some(n);
rec.verify_source = "cohere_tokenize".to_string();
rec.note = "cohere tokenize request succeeded".to_string();
}
Err(e) => {
rec.error = Some(e.to_string());
rec.verify_source = "cohere_tokenize_failed".to_string();
rec.note = "cohere tokenize request failed".to_string();
}
}
} else {
// Config-only check: passes when COHERE_API_KEY is set.
rec.verify_source = "provider_api_config_only".to_string();
rec.verified = std::env::var("COHERE_API_KEY").is_ok();
rec.note = "run_count=false, not executed request".to_string();
}
}
"provider_xai_tokenize_text" => {
if run_count {
match count_with_xai(&row.model_key, sample_text).await {
Ok((n, _, _, _, endpoint)) => {
rec.verified = true;
rec.estimate_tokens = Some(n);
rec.verify_source = endpoint;
rec.note = "xai tokenize request succeeded".to_string();
}
Err(e) => {
rec.error = Some(e.to_string());
rec.verify_source = "xai_tokenize_failed".to_string();
rec.note = "xai tokenize request failed".to_string();
}
}
} else {
// Config-only check: passes when XAI_API_KEY is set.
rec.verify_source = "provider_api_config_only".to_string();
rec.verified = std::env::var("XAI_API_KEY").is_ok();
rec.note = "run_count=false, not executed request".to_string();
}
}
_ => {
// No handler for this strategy; the record stays unverified.
rec.verify_source = "unresolved".to_string();
rec.note = format!("strategy {} requires manual mapping", row.strategy_code);
}
}
// Safety net: give verified records without a note a default one. Every verified path
// above already sets a note, so this only matters if a new branch forgets to.
if rec.verified && rec.error.is_none() {
if rec.note.is_empty() {
rec.note = "verified".to_string();
}
}
rec
}
/// Returns true when `s` (after trimming) has the shape of a Hugging Face repo id,
/// i.e. `owner/name`.
///
/// The owner may contain ASCII letters, digits, `-` and `_`; the name may additionally
/// contain `.`. Both parts must be non-empty, so `"/"`, `"owner/"` and `"/name"` are
/// rejected, as is anything with more than one `/` or with whitespace inside.
fn is_hf_repo_like(s: &str) -> bool {
    let (owner, name) = match s.trim().split_once('/') {
        Some(parts) => parts,
        None => return false,
    };
    // `all` is vacuously true for an empty string, so emptiness must be checked explicitly.
    if owner.is_empty() || name.is_empty() {
        return false;
    }
    let owner_ok = owner
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_');
    // A second `/` would end up in `name` and is rejected by this character check.
    let name_ok = name
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.');
    owner_ok && name_ok
}
fn should_probe_source_route(m: &ApiModel, c: &ClassifyResult, force_reclassify: bool) -> bool {
if force_reclassify {
return true;
}
if c.strategy_code == "local_hf_transformers" {
return is_hf_repo_like(&m.model_key)
|| m
.model_name
.as_deref()
.is_some_and(|name| is_hf_repo_like(name))
|| m
.developer_id
.is_some_and(|d| matches!(d, 25 | 31 | 22));
}
if c.provider_code == "minimax"
|| c.provider_code == "qwen"
|| c.provider_code == "k2"
{
return true;
}
if is_hf_repo_like(&m.model_key) {
return true;
}
if m
.model_name
.as_deref()
.is_some_and(|name| is_hf_repo_like(name))
{
return true;
}
m.developer_id
.is_some_and(|d| matches!(d, 4 | 10 | 13 | 15 | 17 | 18 | 22 | 25 | 31))
}
/// Normalizes one `-`-separated segment of a model key to the casing used in repo names.
///
/// * `vl` (any case) becomes `VL`.
/// * Segments made only of digits and dots are returned exactly as given.
/// * Purely alphabetic segments are upper-cased when at most three letters long, otherwise
///   capitalized (`instruct` becomes `Instruct`).
/// * Segments containing a digit get all their letters upper-cased (`a22b` becomes `A22B`).
/// * Anything else is lower-cased.
fn normalize_repo_token_segment(token: &str) -> String {
    let lowered = token.to_ascii_lowercase();
    let numeric_only = lowered.chars().all(|c| c.is_ascii_digit() || c == '.');
    let letters_only = lowered.chars().all(|c| c.is_ascii_alphabetic());
    let contains_digit = lowered.chars().any(|c| c.is_ascii_digit());

    if lowered == "vl" {
        "VL".to_string()
    } else if numeric_only {
        // Also covers the empty segment, which is returned unchanged.
        token.to_string()
    } else if letters_only {
        if lowered.len() <= 3 {
            lowered.to_ascii_uppercase()
        } else {
            // All characters are ASCII letters here, so splitting after one byte is safe.
            let (head, tail) = lowered.split_at(1);
            format!("{}{}", head.to_ascii_uppercase(), tail)
        }
    } else if contains_digit {
        lowered.to_ascii_uppercase()
    } else {
        lowered
    }
}
/// Derives Hugging Face repo candidates of the form `Qwen/<Name>` from a Qwen model key.
///
/// The key is trimmed and a leading `bai-` is dropped. Keys that do not start with `qwen`
/// (case-insensitive) yield no candidates. The first `-` segment becomes `Qwen` plus
/// whatever followed the four letters; every later segment goes through
/// `normalize_repo_token_segment`. A second candidate with `-Vl-` replaced by `-VL-` is
/// added only when that replacement changes the name.
fn qwen_repo_candidates(model_key: &str) -> Vec<String> {
    let trimmed = model_key.trim();
    let key = trimmed.strip_prefix("bai-").unwrap_or(trimmed);
    if !key.to_ascii_lowercase().starts_with("qwen") {
        return Vec::new();
    }

    let mut segments = key.split('-').filter(|segment| !segment.is_empty());
    let head = match segments.next() {
        Some(segment) => segment,
        None => return Vec::new(),
    };
    // The key starts with the four ASCII letters of "qwen", so byte offset 4 is a valid
    // boundary inside the first segment.
    let mut repo_name = format!("Qwen{}", &head[4..]);
    for segment in segments {
        repo_name.push('-');
        repo_name.push_str(&normalize_repo_token_segment(segment));
    }

    let primary = format!("Qwen/{}", repo_name);
    let vl_variant = primary.replace("-Vl-", "-VL-");
    if vl_variant == primary {
        vec![primary]
    } else {
        vec![primary, vl_variant]
    }
}
/// Collects Hugging Face repos that may hold the tokenizer for `model_key`.
///
/// Candidates come from the key itself (when it looks like a repo id), from
/// `qwen_repo_candidates`, and from a list of prefix/substring rules evaluated against the
/// trimmed, lower-cased key; a few rules also look at `developer_id`. The result is sorted
/// and free of duplicates, so the order in which rules fire does not matter.
fn open_source_repo_candidates(model_key: &str, developer_id: Option<i32>) -> Vec<String> {
    let key = model_key.trim().to_ascii_lowercase();
    let starts = |prefixes: &[&str]| prefixes.iter().any(|p| key.starts_with(*p));
    let has = |needle: &str| key.contains(needle);

    let mut out: Vec<String> = Vec::new();
    if is_hf_repo_like(model_key) {
        out.push(model_key.to_string());
    }
    out.extend(qwen_repo_candidates(model_key));

    {
        // Adds every repo of a rule when the rule's condition holds.
        let mut add = |applies: bool, repos: &[&str]| {
            if applies {
                out.extend(repos.iter().map(|repo| repo.to_string()));
            }
        };

        // Moonshot / Kimi
        let cc_kimi = has("cc-kimi");
        let kimi_family = starts(&["k2", "kimi"]) || cc_kimi;
        add(kimi_family, &["MoonshotAI/Kimi-K2-Instruct"]);
        add(
            kimi_family && (key.starts_with("kimi-k2-0905") || key == "k2.6-0905"),
            &["moonshotai/Kimi-K2-Instruct-0905"],
        );
        add(cc_kimi, &["MoonshotAI/Kimi-K2-Instruct"]);
        add(cc_kimi && has("0905"), &["moonshotai/Kimi-K2-Instruct-0905"]);
        add(
            starts(&["k2.6", "kimi-k2"]) || cc_kimi,
            &[
                "moonshotai/Kimi-K2.5",
                "moonshotai/Kimi-K2-Instruct",
                "moonshotai/Kimi-K2-Thinking",
            ],
        );

        // MiniMax (every accepted prefix contains "minimax", so `contains` covers them all)
        add(
            has("minimax"),
            &["MiniMaxAI/MiniMax-M2", "MiniMaxAI/MiniMax-Text-01"],
        );
        add(
            starts(&["minimax", "cc-minimax", "coding-minimax", "mm-minimax"]),
            &[
                "MiniMaxAI/MiniMax-M2.5",
                "MiniMaxAI/MiniMax-M2.7",
                "MiniMaxAI/MiniMax-M2",
            ],
        );

        // Single-repo families
        add(
            starts(&["doubao"]) || has("seed-"),
            &["ByteDance-Seed/Seed-OSS-36B-Instruct"],
        );
        add(starts(&["jina-"]), &["jinaai/jina-embeddings-v2-base-zh"]);
        add(
            starts(&["mistral-large-3"]) || developer_id == Some(10),
            &["mistralai/Mistral-7B-Instruct-v0.2"],
        );
        add(
            starts(&["nvidia-nemotron"]) || developer_id == Some(17),
            &["nvidia/Llama-3.1-Nemotron-70B-Instruct"],
        );
        add(starts(&["kat-dev"]), &["Qwen/KAT", "Qwen/QwQ-32B"]);
        add(
            starts(&["mimo-v2-"]),
            &["XiaomiMiMo/MiMo-V2-Flash", "XiaomiMiMo/MiMo-V2"],
        );
        add(starts(&["step-3.5"]), &["stepfun-ai/Step-3.5-Flash"]);

        // DeepSeek
        add(
            has("deepseek"),
            &[
                "deepseek-ai/DeepSeek-V3",
                "deepseek-ai/DeepSeek-R1",
                "deepseek-ai/DeepSeek-V3.2",
                "deepseek-ai/DeepSeek-V3.2-Speciale",
                "deepseek-ai/DeepSeek-V3.2-Exp",
                "deepseek-ai/DeepSeek-V3.2-Think",
                "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
                "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
                "deepseek-ai/DeepSeek-OCR",
            ],
        );

        // GLM
        add(
            starts(&["glm", "crush-glm", "sophnet-glm", "zai-glm", "aihubmix-p"]),
            &[
                "THUDM/glm-4-9b",
                "THUDM/glm-4v-9b",
                "zai-org/GLM-5",
                "zai-org/GLM-5.1",
                "zai-org/GLM-4.5",
                "zai-org/GLM-4.6",
                "THUDM/GLM-4-32B",
                "THUDM/GLM-4-9B",
                "THUDM/glm-4v",
                "ZhipuAI/glm-4-9b",
                "ZhipuAI/glm-4.5",
                "ZhipuAI/GLM-4.5",
            ],
        );

        // Phi
        add(
            starts(&["phi", "aihub-p", "ahm-p"]),
            &[
                "microsoft/Phi-4",
                "microsoft/Phi-4-mini-instruct",
                "microsoft/Phi-4-multimodal-instruct",
                "microsoft/Phi-4-mini-reasoning",
                "microsoft/Phi-3.5-vision-instruct",
            ],
        );
        add(
            starts(&["aihubmix-p"]),
            &[
                "microsoft/Phi-4",
                "microsoft/Phi-4-mini-instruct",
                "microsoft/Phi-4-mini-reasoning",
                "microsoft/Phi-4-multimodal-instruct",
            ],
        );

        // Llama
        add(
            starts(&["llama-4", "llama2"]),
            &[
                "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
                "meta-llama/Llama-4-Scout-17B-16E-Instruct",
                "meta-llama/Llama-2-70b-chat-hf",
                "meta-llama/Llama-2-70b-hf",
            ],
        );

        // Qwen3 variants that the generic Qwen derivation does not map to a concrete repo
        add(starts(&["qwen3-coder"]), &["Qwen/Qwen3-Coder-Next"]);
        add(starts(&["qwen3.5"]), &["Qwen/Qwen3.5-9B", "Qwen/Qwen3.5-27B"]);
        add(starts(&["qwen3-vl"]), &["Qwen/Qwen3-VL-235B-A22B"]);
        add(starts(&["qwen3-max"]), &["Qwen/Qwen3-235B-A22B"]);
        add(starts(&["qwen3.6"]), &["Qwen/Qwen3.6-30B-A3B"]);

        // ERNIE rules apply only to developer id 25.
        let ernie_dev = developer_id == Some(25);
        add(
            ernie_dev && (has("cc-ernie-4.5-300b-a47b") || (has("ernie") && has("300b"))),
            &[
                "baidu/ERNIE-4.5-300B-A47B-PT",
                "baidu/ERNIE-4.5-300B-A47B-Base-PT",
            ],
        );
        add(
            ernie_dev && has("4.5-0.3b"),
            &["baidu/ERNIE-4.5-0.3B-Paddle", "baidu/ERNIE-4.5-0.3B-PT"],
        );
        add(
            ernie_dev && (has("4.5-vl") || has("turbo-vl")),
            &[
                "baidu/ERNIE-4.5-VL-28B-A3B-PT",
                "baidu/ERNIE-4.5-VL-424B-A47B-PT",
            ],
        );
        add(
            ernie_dev && has("4.5") && !has("0.3b") && !has("vl") && !has("turbo"),
            &[
                "baidu/ERNIE-4.5-21B-A3B-PT",
                "baidu/ERNIE-4.5-21B-A3B-Base-PT",
            ],
        );
        add(
            ernie_dev && has("turbo") && !has("vl"),
            &[
                "baidu/ERNIE-4.5-300B-A47B-PT",
                "baidu/ERNIE-4.5-300B-A47B-Base-PT",
            ],
        );
    }

    out.sort();
    out.dedup();
    out
}
async fn has_hf_tokenizer_artifacts(repo: &str) -> bool {
let repo = repo.replace(':', "/");
let client = match reqwest::Client::builder().timeout(Duration::from_secs(10)).build() {
Ok(c) => c,
Err(_) => reqwest::Client::new(),
};
let model_info_url = format!("https://huggingface.co/api/models/{}", repo);
match client.get(model_info_url).send().await {
Ok(resp) => {
if resp.status().is_success() {
return true;
}
}
Err(_) => {}
}
let file_candidates = [
"tokenizer.json",
"tokenizer_config.json",
"tokenizer.model",
"spiece.model",
"sentencepiece.bpe.model",
"vocab.txt",
];
for f in file_candidates {
let url = format!("https://huggingface.co/{}/resolve/main/{}", repo, f);
if let Ok(resp) = client.get(url).send().await {
if resp.status().is_success() {
return true;
}
}
}
false
}
/// Tries to refine a model's classification into an "open_source" hint.
///
/// Returns `(provider, strategy, confidence, notes, family_hint)` for the
/// first rule that applies, or `None` when no rule does. Rules, in order:
/// 1. a candidate repo passes `has_hf_tokenizer_artifacts` -> local HF, 92;
/// 2. the model key itself looks like an HF repo id -> private API, 56;
/// 3. candidates exist but none was confirmed -> private API, 50;
/// 4. the developer id is in the hard-coded marker list -> private API, 50.
async fn probe_open_source_hint(
    model_key: &str,
    developer_id: Option<i32>,
) -> Option<(String, String, u8, String, Option<String>)> {
    // Every outcome shares the provider code and always carries a family hint.
    let hint = |strategy: &str, confidence: u8, notes: String, family: String| {
        Some((
            "open_source".to_string(),
            strategy.to_string(),
            confidence,
            notes,
            Some(family),
        ))
    };

    let repos = open_source_repo_candidates(model_key, developer_id);
    for repo in &repos {
        if has_hf_tokenizer_artifacts(repo).await {
            return hint(
                "local_hf_transformers",
                92,
                format!("Verified HF tokenizer materials for {repo}."),
                format!("hf:{repo}"),
            );
        }
    }

    if is_hf_repo_like(model_key) {
        return hint(
            "provider_private_api",
            56,
            format!("HF-like model id {model_key} but tokenizer artifacts not resolvable."),
            format!("hf:{model_key}"),
        );
    }

    if !repos.is_empty() {
        return hint(
            "provider_private_api",
            50,
            format!("Likely open model family but tokenizer repo not confirmed in candidates: {:?}", repos),
            format!("hf_candidates:{}", repos.join("|")),
        );
    }

    let open_family_developer = match developer_id {
        Some(id) => [4, 10, 13, 15, 16, 17, 18, 22, 25, 31].contains(&id),
        None => false,
    };
    if open_family_developer {
        return hint(
            "provider_private_api",
            50,
            "Likely open model family by developer marker; requires local model resolution for exact tokenizer.".to_string(),
            "hf_auto".to_string(),
        );
    }
    None
}
fn parse_models_payload(raw: &str) -> Result<Vec<ApiModel>> {
let v: Value = serde_json::from_str(raw)?;
let arr = if let Some(a) = v.get("data").and_then(Value::as_array) {
a
} else if let Some(a) = v.as_array() {
a
} else if let Some(a) = v.get("result").and_then(Value::as_array) {
a
} else {
return Err(anyhow!("invalid model payload: missing array"));
};
let mut out = Vec::new();
for item in arr {
out.push(serde_json::from_value(item.clone())?);
}
Ok(out)
}
/// Decides whether a model record should be treated as a text model.
///
/// Works on the lowercased `types` string and the normalized
/// `input_modalities` CSV (both treated as empty when absent):
/// - rejects records that mention none of llm/chat/text in `types` and have
///   no "text" in the modalities;
/// - rejects image/audio/video types unless the modalities mention "text" or
///   the types mention "multimodal";
/// - with no modalities at all, accepts everything except embedding types;
/// - otherwise requires a modality entry equal to `text`, `text_only`,
///   `multi_modal` or `multimodal`.
fn is_text_model(m: &ApiModel) -> bool {
    let kinds = m.types.as_deref().unwrap_or_default().to_ascii_lowercase();
    let modalities = normalize_csv(m.input_modalities.as_deref().unwrap_or_default());
    // Substring test on the joined CSV, so "text_only" also counts here.
    let mentions_text_modality = modalities.contains("text");

    let looks_textual = ["llm", "chat", "text"].iter().any(|tag| kinds.contains(tag))
        || mentions_text_modality;
    if !looks_textual {
        return false;
    }

    let non_text_kind = ["image", "audio", "video"].iter().any(|tag| kinds.contains(tag));
    if non_text_kind && !mentions_text_modality && !kinds.contains("multimodal") {
        return false;
    }

    if modalities.is_empty() {
        return !kinds.contains("embedding");
    }
    modalities
        .split(',')
        .any(|entry| matches!(entry, "text" | "text_only" | "multi_modal" | "multimodal"))
}
/// Normalizes a comma-separated list: trims each field, lowercases ASCII
/// letters, drops empty fields and re-joins with a plain `,`.
///
/// Both the ASCII comma and the full-width comma (U+FF0C) are accepted as
/// separators, so lists typed with a CJK input method are split correctly.
fn normalize_csv(v: &str) -> String {
    v.split(|c: char| c == ',' || c == '\u{FF0C}')
        .map(|field| field.trim().to_ascii_lowercase())
        .filter(|field| !field.is_empty())
        .collect::<Vec<_>>()
        .join(",")
}
/// Maps one API model record to a token-counting route using ordered heuristics.
///
/// The decision reads the lowercased `model_key`, `model_name` and `desc`, plus
/// `developer_id` (treated as 0 when absent). Rules are evaluated top to bottom
/// and the FIRST matching rule returns, so the order of the `if` blocks below is
/// part of the behaviour. When nothing matches, an "unknown" result with
/// confidence 0 is returned.
fn classify_model(m: &ApiModel) -> ClassifyResult {
let mk = m.model_key.to_ascii_lowercase();
let name = m.model_name.clone().unwrap_or_default().to_ascii_lowercase();
let desc = m.desc.clone().unwrap_or_default().to_ascii_lowercase();
let dev = m.developer_id.unwrap_or(0);
// Description phrases taken as a marker of an openly released model.
let has_open_desc = desc.contains("open-source")
|| desc.contains("open version")
|| desc.contains("publicly available")
|| desc.contains("free and open");
// Substring test: true when `text` contains any of `pats`.
// Callers pass already-lowercased text, so the extra lowercasing is redundant but harmless.
let has = |text: &str, pats: &[&str]| -> bool {
let t = text.to_ascii_lowercase();
pats.iter().any(|p| t.contains(p))
};
// OpenAI family, matched by substring anywhere in the key.
// NOTE(review): short patterns such as "o1"/"o3" match any key containing them — confirm no
// unrelated model ids are captured here.
// NOTE(review): "gpt-oss" falls into the cl100k_base hint below — confirm that is intended.
if has(&mk, &["gpt-5", "gpt-4o", "gpt-4", "gpt-3.5", "o1", "o3", "gpt-oss", "text-ada", "text-davinci"]) {
return ClassifyResult {
provider_code: "openai".to_string(),
strategy_code: "local_tiktoken_openai".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: if has(&mk, &["gpt-5", "gpt-4o", "o1", "o3"]) {
95
} else {
88
},
family_hint: Some(if has(&mk, &["o1", "gpt-5", "o3", "gpt-4o"]) {
"o200k_base".to_string()
} else {
"cl100k_base".to_string()
}),
notes: "openai-family".to_string(),
};
}
// Remaining OpenAI-style keys not caught by the rule above.
if has(&mk, &["o4-", "codex-mini"]) {
return ClassifyResult {
provider_code: "openai".to_string(),
strategy_code: "local_tiktoken_openai".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 92,
family_hint: Some("o200k_base".to_string()),
notes: "openai-family derivative".to_string(),
};
}
// Anthropic: by key substring or developer id 2.
if has(&mk, &["claude", "opus", "sonnet", "haiku"]) || dev == 2 {
return ClassifyResult {
provider_code: "anthropic".to_string(),
strategy_code: "provider_anthropic_count_tokens".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 95,
family_hint: Some("anthropic-messages-format".to_string()),
notes: "anthropic family".to_string(),
};
}
// Developer id 10 or a "mistral-large-3" key.
// NOTE(review): the confidence bump looks for "permissively" in the model NAME, not in the
// description — confirm which field was meant. Every developer-10 model also receives the
// "Mistral Large 3" note.
if dev == 10 || has(&mk, &["mistral-large-3"]) {
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: if has(&name, &["permissively"]) { 90 } else { 84 },
family_hint: Some("hf_auto".to_string()),
notes: "Mistral Large 3 explicitly open-weight; HF tokenizer first-pass".to_string(),
};
}
// Developer id 17 or a "nemotron" key.
if dev == 17 || has(&mk, &["nemotron"]) {
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 90,
family_hint: Some("hf_auto".to_string()),
notes: "Nemotron open-source model, local HF tokenizer first-pass".to_string(),
};
}
// Developer id 13 or a "kat-dev" key; reported under provider code "qwen".
if dev == 13 || has(&mk, &["kat-dev"]) {
return ClassifyResult {
provider_code: "qwen".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 92,
family_hint: Some("hf_auto".to_string()),
notes: "kat-dev is explicitly open-source in API description".to_string(),
};
}
// Jina: key prefix "jina-" or developer id 22; low confidence, no family hint.
if mk.starts_with("jina-") || dev == 22 {
return ClassifyResult {
provider_code: "jina".to_string(),
strategy_code: "provider_private_api".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 40,
family_hint: None,
notes: "Jina Deepsearch stack not confirmed open tokenizer".to_string(),
};
}
// Kimi route: developer id 15 or k2.6-/cc-kimi keys.
// The local HF strategy needs BOTH an open-description marker and "free"/"open" in the key,
// while confidence and notes depend on the description marker alone.
if dev == 15 || mk.starts_with("k2.6-") || mk.starts_with("cc-kimi-k2") || mk.contains("cc-kimi") {
return ClassifyResult {
provider_code: "k2".to_string(),
strategy_code: if has_open_desc && has(&mk, &["free", "open"]) {
"local_hf_transformers".to_string()
} else {
"provider_private_api".to_string()
},
fallback_strategy_code: "provider_private_api".to_string(),
confidence: if has_open_desc { 84 } else { 58 },
family_hint: Some("kimi-k2".to_string()),
notes: if has_open_desc {
"Kimi free/open version hinted; try local tokenizer if repo known."
} else {
"Kimi route. cc-kimi variants are mapped to moonshot repos in probe."
}
.to_string(),
};
}
// Stepfun: developer id 16 or a "step-3.5-" key prefix.
if dev == 16 || mk.starts_with("step-3.5-") {
return ClassifyResult {
provider_code: "stepfun".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 86,
family_hint: Some("hf:stepfun-ai/Step-3.5-Flash".to_string()),
notes: "Stepfun 3.5 variants map to HF tokenizer checkpoint".to_string(),
};
}
// MiniMax: developer id 18 or "minimax" in the key. Treated as open only when the
// description carries an open marker or the key contains "-free".
if dev == 18 || mk.contains("minimax") {
let is_open = has_open_desc || mk.contains("-free");
return ClassifyResult {
provider_code: "minimax".to_string(),
strategy_code: if is_open {
"local_hf_transformers".to_string()
} else {
"provider_private_api".to_string()
},
fallback_strategy_code: "provider_private_api".to_string(),
confidence: if is_open { 84 } else { 35 },
family_hint: Some("minimax".to_string()),
notes: if is_open {
"MiniMax/Open-version marker found; local HF attempt".to_string()
} else {
"Default MiniMax route; use provider/private count or usage".to_string()
},
};
}
// ByteDance: developer id 4 or "doubao"/"seed-" in the key.
if dev == 4 || has(&mk, &["doubao", "seed-"]) {
return ClassifyResult {
provider_code: "bytedance".to_string(),
strategy_code: "provider_private_api".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 35,
family_hint: Some("byte-dance-seed".to_string()),
notes: "ByteDance/Doubao family. Private route unless tokenizer repo discovered.".to_string(),
};
}
// Xiaomi: developer id 31 or a "mimo-v2-" key prefix.
if dev == 31 || mk.starts_with("mimo-v2-") {
return ClassifyResult {
provider_code: "xiaomi".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 90,
family_hint: Some("hf:XiaomiMiMo/MiMo-V2-Flash".to_string()),
notes: "MiMo V2 variants are confirmed available via XiaomiMiMo HF tokenizer family.".to_string(),
};
}
// Developer-25 ERNIE special cases. They must stay ahead of the generic open-source rule
// further down, which also matches "ernie" keys.
if dev == 25 && mk.starts_with("ernie-5.0") {
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "provider_private_api".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 50,
family_hint: Some("hf:baidu/ERNIE-4.5-300B-A47B-PT".to_string()),
notes: "ERNIE-5.0 is not directly discoverable by model id; use Baidu 4.5/300B as best-effort tokenizer anchor.".to_string(),
};
}
if dev == 25 && mk == "ernie-x1-turbo" {
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "provider_private_api".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 50,
family_hint: Some("hf:baidu/ERNIE-4.5-21B-A3B-PT".to_string()),
notes: "ERNIE X1 Turbo has open-family characteristics, but tokenizer checkpoint id not published yet.".to_string(),
};
}
if dev == 25 && mk == "ernie-x1.1-preview" {
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "provider_private_api".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: 50,
family_hint: Some("hf:baidu/ERNIE-4.5-21B-A3B-PT".to_string()),
notes: "ERNIE-X1.1 preview has no confirmed HF repo id in API corpus.".to_string(),
};
}
// Google: "gemini"/"gemma" in the key, "gemini" in the name, or developer id 8.
if has(&mk, &["gemini", "gemma"]) || name.contains("gemini") || dev == 8 {
return ClassifyResult {
provider_code: "google".to_string(),
strategy_code: "provider_gemini_count_tokens".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 90,
family_hint: Some("gemini".to_string()),
notes: "gemini/gemma family".to_string(),
};
}
// Cohere: "command" in the key or developer id 6.
if has(&mk, &["command"]) || dev == 6 {
return ClassifyResult {
provider_code: "cohere".to_string(),
strategy_code: "provider_cohere_tokenize".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 90,
family_hint: Some("cohere".to_string()),
notes: "cohere family".to_string(),
};
}
// Generic open-source rule: well-known family names in the key or name, or any key
// containing '/' (which also raises the confidence).
// "qwen2"/"qwen3" are already covered by the "qwen" pattern.
if has(&mk, &["qwen", "deepseek", "llama", "phi", "kimi", "glm", "ernie", "qwen2", "qwen3"])
|| mk.contains('/')
|| has(&name, &["qwen", "llama", "phi", "deepseek"])
{
return ClassifyResult {
provider_code: "open_source".to_string(),
strategy_code: "local_hf_transformers".to_string(),
fallback_strategy_code: "provider_private_api".to_string(),
confidence: if mk.contains('/') { 85 } else { 72 },
family_hint: Some("hf_auto".to_string()),
notes: "open-source style model; try HF tokenizer".to_string(),
};
}
// xAI: developer id 9 only. Because this comes after the generic rule, a developer-9
// model whose key matched above never reaches this branch.
if dev == 9 {
return ClassifyResult {
provider_code: "xai".to_string(),
strategy_code: "provider_xai_tokenize_text".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 88,
family_hint: None,
notes: "xAI exposes /v1/tokenize-text for token counting.".to_string(),
};
}
// Nothing matched: report the developer id so the gap can be diagnosed.
ClassifyResult {
provider_code: "unknown".to_string(),
strategy_code: "unknown_no_map".to_string(),
fallback_strategy_code: "usage_only_private".to_string(),
confidence: 0,
family_hint: None,
notes: format!("unmatched model. dev_id={dev}"),
}
}
/// Resolves the `count` command's input flags into a single `CountInput`.
///
/// Exactly one source family must be supplied: either `messages_file`, or
/// plain text via `text` / `text_file` (when both plain sources are given,
/// `text` wins and the file is not read). The preview is the first 220
/// characters of the resolved text.
///
/// # Errors
/// Fails when no source is given, when `messages_file` is combined with a
/// plain-text source, when a file cannot be read, or when the messages file is
/// not valid JSON.
async fn collect_count_input(
    text: Option<String>,
    text_file: Option<PathBuf>,
    messages_file: Option<PathBuf>,
) -> Result<CountInput> {
    let has_plain_source = text.is_some() || text_file.is_some();
    match (messages_file, has_plain_source) {
        (None, false) => Err(anyhow!("must provide one of --text, --text-file, or --messages-file")),
        (Some(_), true) => Err(anyhow!("--messages-file cannot be mixed with text inputs")),
        (Some(path), false) => {
            let raw = fs::read_to_string(path).await?;
            let parsed: Value = serde_json::from_str(&raw).context("invalid messages json")?;
            let flattened = messages_to_plain_text(&parsed)?;
            let preview = flattened.chars().take(220).collect();
            Ok(CountInput {
                kind: "messages_json".to_string(),
                text: flattened,
                preview,
            })
        }
        (None, true) => {
            let content = match (text, text_file) {
                (Some(inline), _) => inline,
                (None, Some(path)) => fs::read_to_string(path).await?,
                (None, None) => String::new(),
            };
            let preview = content.chars().take(220).collect();
            Ok(CountInput {
                kind: "text".to_string(),
                text: content,
                preview,
            })
        }
    }
}
/// Counts tokens by running the bundled `scripts/count_tokens_py.py` helper.
///
/// The helper is invoked as
/// `<python> <script> --mode <mode> --model <model_or_repo> --text <text>`
/// with the interpreter chosen by `detect_python`. Its stdout and stderr are
/// scanned from the last line backwards for a JSON object carrying an integer
/// `tokens` field. JSON-looking lines without that field (for example
/// structured log lines) are skipped instead of masking the real result.
///
/// # Errors
/// Fails when the process cannot be started, exits unsuccessfully, prints no
/// JSON object line at all, or prints JSON objects none of which has an
/// integer `tokens` field.
///
/// NOTE(review): `Command::output` blocks the calling thread for the lifetime
/// of the child process even though this function is `async`.
async fn count_with_python(mode: &str, model_or_repo: &str, text: &str) -> Result<i64> {
    let python = detect_python();
    let script = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        .join("scripts")
        .join("count_tokens_py.py");
    let out = Command::new(python)
        .arg(script)
        .arg("--mode")
        .arg(mode)
        .arg("--model")
        .arg(model_or_repo)
        .arg("--text")
        .arg(text)
        .output()
        .context("failed to run python counter helper")?;
    if !out.status.success() {
        // Prefer stderr for the diagnostic; fall back to stdout when stderr is empty.
        let err = if !out.stderr.is_empty() {
            String::from_utf8_lossy(&out.stderr).to_string()
        } else {
            String::from_utf8_lossy(&out.stdout).to_string()
        };
        return Err(anyhow!("python counter failed: {err}"));
    }
    let mut combined = String::from_utf8_lossy(&out.stdout).to_string();
    if !out.stderr.is_empty() {
        combined.push('\n');
        combined.push_str(&String::from_utf8_lossy(&out.stderr));
    }

    let mut saw_json_object = false;
    for line in combined.lines().rev() {
        let trimmed = line.trim();
        if !(trimmed.starts_with('{') && trimmed.ends_with('}')) {
            continue;
        }
        if let Ok(v) = serde_json::from_str::<Value>(trimmed) {
            saw_json_object = true;
            if let Some(tokens) = v.get("tokens").and_then(Value::as_i64) {
                return Ok(tokens);
            }
        }
    }
    if saw_json_object {
        Err(anyhow!("python output missing tokens field"))
    } else {
        Err(anyhow!("python output invalid json: {combined}"))
    }
}
/// Picks the Python interpreter used for the token-counting helper.
///
/// Resolution order:
/// 1. the first non-blank value among `TOKENIZER_PYTHON` / `AIHUBMIX_PYTHON`;
/// 2. the `.venv` interpreter under the crate's manifest directory (path fixed
///    at compile time), if that file exists;
/// 3. `python` on Windows, `python3` elsewhere.
fn detect_python() -> String {
    match first_env(&["TOKENIZER_PYTHON", "AIHUBMIX_PYTHON"]) {
        Some(configured) => configured,
        None => {
            let relative = if cfg!(windows) {
                ".venv\\Scripts\\python.exe"
            } else {
                ".venv/bin/python"
            };
            let candidate = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(relative);
            if candidate.exists() {
                candidate.to_string_lossy().to_string()
            } else if cfg!(windows) {
                "python".to_string()
            } else {
                "python3".to_string()
            }
        }
    }
}
/// Returns the trimmed value of the first environment variable in `keys` that
/// is set, valid Unicode, and not blank after trimming; `None` if there is none.
fn first_env(keys: &[&str]) -> Option<String> {
    keys.iter()
        .filter_map(|name| std::env::var(name).ok())
        .map(|raw| raw.trim().to_string())
        .find(|value| !value.is_empty())
}
/// Counts input tokens through an Anthropic-style `count_tokens` endpoint.
///
/// The API key is the first non-blank value of `ANTHROPIC_API_KEY`,
/// `MODEL_COUNT_API_KEY` or `AIHUBMIX_API_KEY`. Endpoints come from a
/// comma-separated `ANTHROPIC_COUNT_ENDPOINTS` / `ANTHROPIC_COUNT_URLS` list,
/// defaulting to the public Anthropic URL. They are tried in order and the
/// first one that answers with a successful, JSON-decodable response wins.
///
/// When `kind` is `"messages_json"` and `text` parses as JSON, that JSON is
/// sent as the `messages` value; otherwise `text` is wrapped in a single user
/// turn.
///
/// Returns `(input_tokens, input_tokens, 0, raw_response, endpoint_used)`. A
/// missing token field in a successful response yields 0 rather than an error.
///
/// # Errors
/// Fails when no API key is configured or when every endpoint fails; the error
/// carries the last failure seen.
///
/// NOTE(review): the request body includes `max_tokens` — confirm the target
/// count endpoint accepts that field.
async fn count_with_anthropic(model: &str, text: &str, kind: &str) -> Result<(i64, i64, i64, Value, String)> {
    const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com/v1/messages/count_tokens";

    let key = first_env(&["ANTHROPIC_API_KEY", "MODEL_COUNT_API_KEY", "AIHUBMIX_API_KEY"])
        .context("need ANTHROPIC_API_KEY / MODEL_COUNT_API_KEY / AIHUBMIX_API_KEY")?;

    let mut endpoints: Vec<String> = first_env(&["ANTHROPIC_COUNT_ENDPOINTS", "ANTHROPIC_COUNT_URLS"])
        .map(|list| {
            list.split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .map(ToString::to_string)
                .collect()
        })
        .unwrap_or_default();
    if endpoints.is_empty() {
        endpoints.push(DEFAULT_ENDPOINT.to_string());
    }

    let single_user_turn = || json!([{"role": "user", "content": [{"type": "text", "text": text}]}]);
    let messages: Value = if kind == "messages_json" {
        serde_json::from_str(text).unwrap_or_else(|_| single_user_turn())
    } else {
        single_user_turn()
    };
    let body = json!({
        "model": model,
        "messages": messages,
        "max_tokens": 1
    });

    let client = reqwest::Client::new();
    let mut last_error = String::new();
    for base in endpoints {
        let sent = client
            .post(base.as_str())
            .header("x-api-key", &key)
            .header("anthropic-version", "2023-06-01")
            .json(&body)
            .send()
            .await;
        let response = match sent {
            Ok(r) => r,
            Err(e) => {
                last_error = format!("endpoint {base} request failed: {e}");
                continue;
            }
        };
        // Capture the status first: reading the body consumes the response.
        let status = response.status();
        let raw = match response.text().await {
            Ok(t) => t,
            Err(e) => {
                last_error = format!("endpoint {base} body read failed: {e}");
                continue;
            }
        };
        if !status.is_success() {
            last_error = format!("endpoint {base} status {status} body {raw}");
            continue;
        }
        let parsed: Value = match serde_json::from_str(&raw) {
            Ok(v) => v,
            Err(e) => {
                last_error = format!("endpoint {base} response decode failed: {e}");
                continue;
            }
        };

        let input_tokens = extract_first_i64(
            &parsed,
            &[
                "input_tokens",
                "usage.input_tokens",
                "usage.inputTokens",
                "input_tokens_count",
            ],
        )
        .or_else(|| {
            parsed
                .pointer("/message/usage/input_tokens")
                .and_then(Value::as_i64)
        })
        .unwrap_or(0);
        return Ok((input_tokens, input_tokens, 0, parsed, base));
    }
    Err(anyhow!(
        "all anthropic count endpoints failed: {last_error}"
    ))
}
/// Counts tokens through the xAI tokenize endpoint.
///
/// Requires `XAI_API_KEY`. A non-blank `XAI_TOKENIZE_URL` is tried first, then
/// the public endpoint; the public endpoint is never queued twice. For each
/// endpoint three request shapes are attempted in turn (`text`, `texts`,
/// `input`). The first response that yields a count wins; the count is read
/// from the length of `token_ids`, `data.tokens` or `tokens`, or from a
/// numeric `count` / `token_count` field.
///
/// Returns `(tokens, 0, 0, raw_response, endpoint_used)`.
///
/// # Errors
/// Fails when `XAI_API_KEY` is missing or when no endpoint/payload combination
/// produces a count. Transport failures, non-success statuses, undecodable
/// bodies and responses without a recognizable count all move on to the next
/// attempt, and the last such failure is reported.
async fn count_with_xai(model: &str, text: &str) -> Result<(i64, i64, i64, Value, String)> {
    const DEFAULT_ENDPOINT: &str = "https://api.x.ai/v1/tokenize-text";

    let key = std::env::var("XAI_API_KEY").context("need XAI_API_KEY")?;

    let mut endpoints: Vec<String> = Vec::with_capacity(2);
    if let Ok(custom) = std::env::var("XAI_TOKENIZE_URL") {
        let custom = custom.trim();
        // Skip a blank override, and avoid calling the default endpoint twice.
        if !custom.is_empty() && custom != DEFAULT_ENDPOINT {
            endpoints.push(custom.to_string());
        }
    }
    endpoints.push(DEFAULT_ENDPOINT.to_string());

    let payloads = vec![
        json!({ "model": model, "text": text }),
        json!({ "model": model, "texts": [text] }),
        json!({ "model": model, "input": text }),
    ];
    let client = reqwest::Client::new();
    let mut last_error = String::new();
    for endpoint in endpoints {
        for payload in &payloads {
            let resp = client
                .post(&endpoint)
                .bearer_auth(&key)
                .json(payload)
                .send()
                .await;
            let response = match resp {
                Ok(r) => r,
                Err(e) => {
                    last_error = format!("endpoint {endpoint} request failed: {e}");
                    continue;
                }
            };
            let status = response.status();
            let raw = match response.text().await {
                Ok(body) => body,
                Err(e) => {
                    last_error = format!("endpoint {endpoint} body read failed: {e}");
                    continue;
                }
            };
            if !status.is_success() {
                last_error = format!("endpoint {endpoint} status {status} body {raw}");
                continue;
            }
            // A malformed body from one attempt must not abort the remaining fallbacks.
            let v = match serde_json::from_str::<Value>(&raw) {
                Ok(v) => v,
                Err(e) => {
                    last_error = format!("endpoint {endpoint} xai tokenize parse failed: {e}");
                    continue;
                }
            };
            let count = extract_first_array_len(&v, &["token_ids", "data.tokens", "tokens"])
                .or_else(|| v.get("count").and_then(Value::as_i64))
                .or_else(|| v.get("token_count").and_then(Value::as_i64));
            match count {
                Some(n) => return Ok((n, 0, 0, v, endpoint)),
                None => {
                    last_error =
                        format!("endpoint {endpoint} response had no recognizable token count: {raw}");
                }
            }
        }
    }
    Err(anyhow!("xai tokenize failed all payloads: {last_error}"))
}
/// Counts tokens through the Gemini `countTokens` endpoint for `model`.
///
/// Requires `GEMINI_API_KEY`. The key is sent in the `x-goog-api-key` request
/// header rather than as a `?key=` query parameter, which keeps it out of the
/// request URL and therefore out of any error message or log line that prints
/// that URL.
///
/// Returns `(total, 0, 0, raw_response)`, where `total` is read from
/// `totalTokens` (or `total_tokens`) and is 0 when neither field is present.
///
/// # Errors
/// Fails when the key is missing, the request cannot be sent, the server
/// answers with a non-success status, or the body is not valid JSON.
async fn count_with_gemini(model: &str, text: &str) -> Result<(i64, i64, i64, Value)> {
    let key = std::env::var("GEMINI_API_KEY").context("need GEMINI_API_KEY")?;
    let endpoint = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens"
    );
    let body = json!({
        "contents": [{
            "role": "user",
            "parts": [{"text": text}]
        }]
    });
    let resp: Value = reqwest::Client::new()
        .post(&endpoint)
        .header("x-goog-api-key", key)
        .json(&body)
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;
    let total = resp
        .get("totalTokens")
        .and_then(Value::as_i64)
        .or_else(|| resp.get("total_tokens").and_then(Value::as_i64))
        .unwrap_or(0);
    Ok((total, 0, 0, resp))
}
/// Counts tokens through a Cohere-style tokenize endpoint.
///
/// Requires `COHERE_API_KEY`; the URL can be overridden with
/// `COHERE_TOKENIZE_URL`. The count is the length of `results[0].tokens`, or
/// of a top-level `tokens` array, and is 0 when neither is present.
///
/// Returns `(tokens, 0, 0, raw_response)`.
///
/// # Errors
/// Fails when the key is missing, the request cannot be sent, the server
/// answers with a non-success status, or the body is not valid JSON.
///
/// NOTE(review): the request sends the input as `texts: [text]` — confirm the
/// target endpoint accepts that field name.
async fn count_with_cohere(model: &str, text: &str) -> Result<(i64, i64, i64, Value)> {
    // Length of the `tokens` array directly under `node`, if there is one.
    fn token_array_len(node: &Value) -> Option<i64> {
        node.get("tokens")
            .and_then(Value::as_array)
            .map(|tokens| tokens.len() as i64)
    }

    let key = std::env::var("COHERE_API_KEY").context("need COHERE_API_KEY")?;
    let endpoint = match std::env::var("COHERE_TOKENIZE_URL") {
        Ok(custom) => custom,
        Err(_) => "https://api.cohere.ai/v1/tokenize".to_string(),
    };
    let body = json!({
        "model": model,
        "texts": [text]
    });

    let request = reqwest::Client::new()
        .post(&endpoint)
        .bearer_auth(key)
        .json(&body);
    let resp: Value = request.send().await?.error_for_status()?.json().await?;

    let first_result = resp
        .get("results")
        .and_then(Value::as_array)
        .and_then(|results| results.first());
    let n = first_result
        .and_then(|result| token_array_len(result))
        .or_else(|| token_array_len(&resp))
        .unwrap_or(0);
    Ok((n, 0, 0, resp))
}
/// Returns the first integer found at any of the dotted `paths` inside `v`.
///
/// Each path is walked key by key (empty segments are ignored). A path counts
/// when it resolves to a JSON integer, or to a string that parses as `i64`;
/// otherwise the next path is tried. Returns `None` when no path qualifies.
fn extract_first_i64(v: &Value, paths: &[&str]) -> Option<i64> {
    paths.iter().find_map(|path| {
        let node = path
            .split('.')
            .filter(|segment| !segment.is_empty())
            .try_fold(v, |current, segment| current.get(segment))?;
        node.as_i64()
            .or_else(|| node.as_str().and_then(|digits| digits.parse::<i64>().ok()))
    })
}
fn extract_first_array_len(v: &Value, paths: &[&str]) -> Option<i64> {
for p in paths {
let mut cur = v;
let mut ok = true;
for seg in p.split('.') {
if seg.is_empty() {
continue;
}
if let Some(n) = cur.get(seg) {
cur = n;
} else {
ok = false;
break;
}
}
if ok {
if let Some(arr) = cur.as_array() {
return Some(arr.len() as i64);
}
}
}
None
}
/// Flattens a messages JSON document into plain text.
///
/// - Array: one line per element, formatted `role:content` (role defaults to
///   `user`). String content is appended as is; for array content, the string
///   `text` and `content` fields of every part are concatenated.
/// - Object: every string-valued field is emitted on its own line.
/// - Anything else: the value's compact JSON serialization.
fn messages_to_plain_text(v: &Value) -> Result<String> {
    match v {
        Value::Array(turns) => {
            let mut transcript = String::new();
            for turn in turns {
                transcript.push_str(turn.get("role").and_then(Value::as_str).unwrap_or("user"));
                transcript.push(':');
                match turn.get("content") {
                    Some(Value::String(body)) => transcript.push_str(body),
                    Some(Value::Array(parts)) => {
                        for part in parts {
                            for field in ["text", "content"] {
                                if let Some(piece) = part.get(field).and_then(Value::as_str) {
                                    transcript.push_str(piece);
                                }
                            }
                        }
                    }
                    _ => {}
                }
                transcript.push('\n');
            }
            Ok(transcript)
        }
        Value::Object(fields) => {
            let mut joined = String::new();
            for line in fields.values().filter_map(Value::as_str) {
                joined.push_str(line);
                joined.push('\n');
            }
            Ok(joined)
        }
        other => Ok(serde_json::to_string(other)?),
    }
}