use async_trait::async_trait;
use clap::{CommandFactory, Parser};
use tiktoken_rs::{cl100k_base, o200k_base, p50k_base};
use crate::interpreter::{ExecResult, OutputData, OutputNode};
use crate::tools::{schema_from_clap, ExecContext, ToolCtx, GlobalFlags, Tool, ToolArgs, ToolSchema};
pub struct Tokens;
#[derive(Parser, Debug)]
#[command(name = "tokens", about = "Count BPE tokens using tiktoken tokenization")]
struct TokensArgs {
#[arg(short = 'e', long = "encoding")]
encoding: Option<String>,
#[arg(short = 'v', long = "verbose")]
verbose: bool,
#[command(flatten)]
global: GlobalFlags,
text: Vec<String>,
}
#[async_trait]
impl Tool for Tokens {
fn name(&self) -> &str {
"tokens"
}
fn schema(&self) -> ToolSchema {
schema_from_clap(
&TokensArgs::command(),
"tokens",
"Count BPE tokens using tiktoken tokenization",
[
("Count tokens (default: cl100k)", "tokens \"Hello, Kaijutsu! 会術\""),
("From stdin", "cat file.txt | tokens"),
("Verbose (show token IDs)", "tokens -v \"Hello\""),
("Different encoding", "tokens -e o200k \"Hello\""),
("JSON output", "tokens --json \"Hello world\""),
],
)
}
async fn execute(&self, args: ToolArgs, ctx: &mut dyn ToolCtx) -> ExecResult {
let Some(ctx) = ctx.as_any_mut().downcast_mut::<ExecContext>() else {
return ExecResult::failure(1, "internal error: kernel builtin requires ExecContext");
};
let parsed = match TokensArgs::try_parse_from(
std::iter::once("tokens".to_string()).chain(args.to_argv()),
) {
Ok(p) => p,
Err(e) => return ExecResult::failure(2, format!("tokens: {e}")),
};
parsed.global.apply(ctx);
let text = match args.get_string("text", 0) {
Some(t) => t,
None => match ctx.read_stdin_to_string().await {
Some(s) => s,
None => return ExecResult::failure(1, "tokens: no input provided"),
},
};
let encoding = parsed.encoding.clone()
.or_else(|| args.get_string("encoding", usize::MAX))
.unwrap_or_else(|| "cl100k".into());
let bpe = match encoding.as_str() {
"cl100k" | "cl100k_base" => match cl100k_base() {
Ok(b) => b,
Err(e) => return ExecResult::failure(1, format!("tokens: failed to load cl100k encoding: {}", e)),
},
"o200k" | "o200k_base" => match o200k_base() {
Ok(b) => b,
Err(e) => return ExecResult::failure(1, format!("tokens: failed to load o200k encoding: {}", e)),
},
"p50k" | "p50k_base" => match p50k_base() {
Ok(b) => b,
Err(e) => return ExecResult::failure(1, format!("tokens: failed to load p50k encoding: {}", e)),
},
_ => {
return ExecResult::failure(
1,
format!("tokens: unknown encoding '{}' (use cl100k, o200k, or p50k)", encoding),
)
}
};
let token_ids = bpe.encode_with_special_tokens(&text);
let count = token_ids.len();
let verbose = parsed.verbose;
if verbose {
let mut lines = format!("count: {}\nids: [", count);
for (i, id) in token_ids.iter().enumerate() {
if i > 0 { lines.push_str(", "); }
lines.push_str(&id.to_string());
}
lines.push(']');
return ExecResult::with_output(OutputData::text(lines));
}
let nodes: Vec<OutputNode> = token_ids
.iter()
.enumerate()
.map(|(i, id)| OutputNode::new(i.to_string()).with_cells(vec![id.to_string()]))
.collect();
let table = OutputData::table(
vec!["INDEX".to_string(), "ID".to_string()],
nodes,
);
ExecResult::with_output_and_text(table, count.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ast::Value;
use crate::vfs::{MemoryFs, VfsRouter};
use std::sync::Arc;
async fn make_ctx() -> ExecContext {
let mut vfs = VfsRouter::new();
vfs.mount("/", MemoryFs::new());
ExecContext::new(Arc::new(vfs))
}
#[tokio::test]
async fn test_tokens_basic() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello world".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
assert_eq!(result.text_out().trim(), "2");
}
#[tokio::test]
async fn test_tokens_stdin() {
let mut ctx = make_ctx().await;
ctx.set_stdin("Hello world".to_string());
let args = ToolArgs::new();
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
assert_eq!(result.text_out().trim(), "2");
}
#[tokio::test]
async fn test_tokens_no_input() {
let mut ctx = make_ctx().await;
let args = ToolArgs::new();
let result = Tokens.execute(args, &mut ctx).await;
assert!(!result.ok());
assert!(result.err.contains("no input"));
}
#[tokio::test]
async fn test_tokens_verbose() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello".into()));
args.flags.insert("v".to_string());
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
assert!(result.text_out().contains("count:"));
assert!(result.text_out().contains("ids:"));
}
#[tokio::test]
async fn test_tokens_json_via_global_flag() {
use crate::interpreter::{apply_output_format, OutputFormat};
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
assert_eq!(result.text_out().trim(), "1");
assert!(result.has_output());
let result = apply_output_format(result, OutputFormat::Json);
let json: serde_json::Value = serde_json::from_str(&result.text_out()).expect("valid JSON");
let arr = json.as_array().expect("should be array");
assert_eq!(arr.len(), 1); assert!(arr[0].get("INDEX").is_some());
assert!(arr[0].get("ID").is_some());
}
#[tokio::test]
async fn test_tokens_unicode() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("会術".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
let count: usize = result.text_out().trim().parse().expect("valid number");
assert!(count >= 1);
}
#[tokio::test]
async fn test_tokens_empty_string() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
assert_eq!(result.text_out().trim(), "0");
}
#[tokio::test]
async fn test_tokens_unknown_encoding() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello".into()));
args.named
.insert("encoding".to_string(), Value::String("invalid".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(!result.ok());
assert!(result.err.contains("unknown encoding"));
}
#[tokio::test]
async fn test_tokens_o200k_encoding() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello world".into()));
args.named
.insert("encoding".to_string(), Value::String("o200k".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
let count: usize = result.text_out().trim().parse().expect("valid number");
assert!(count >= 1);
}
#[tokio::test]
async fn test_tokens_p50k_encoding() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional.push(Value::String("Hello world".into()));
args.named
.insert("encoding".to_string(), Value::String("p50k".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
let count: usize = result.text_out().trim().parse().expect("valid number");
assert!(count >= 1);
}
#[tokio::test]
async fn test_tokens_long_text() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
let text = "The quick brown fox jumps over the lazy dog. ".repeat(10);
args.positional.push(Value::String(text.into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
let count: usize = result.text_out().trim().parse().expect("valid number");
assert!(count > 10);
}
#[tokio::test]
async fn test_tokens_special_characters() {
let mut ctx = make_ctx().await;
let mut args = ToolArgs::new();
args.positional
.push(Value::String("Hello\nWorld\t!@#$%".into()));
let result = Tokens.execute(args, &mut ctx).await;
assert!(result.ok());
let count: usize = result.text_out().trim().parse().expect("valid number");
assert!(count >= 1);
}
}