Skip to main content

hirn_exec/udfs/
token_count.rs

//! `token_count` UDF — estimate the number of tokens in a text value from its length (no tokenizer is run).
//!
//! `token_count(text: Utf8) → UInt32`
4
5use std::any::Any;
6use std::sync::Arc;
7
8use arrow_array::Array;
9use arrow_array::cast::AsArray;
10use arrow_array::{ArrayRef, UInt32Array};
11use datafusion_common::Result;
12use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
13
14use arrow_schema::DataType;
15
16#[derive(Debug, PartialEq, Eq, Hash)]
17pub struct TokenCountUdf {
18    signature: Signature,
19}
20
21impl Default for TokenCountUdf {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl TokenCountUdf {
28    pub fn new() -> Self {
29        Self {
30            signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
31        }
32    }
33}
34
/// Estimate the number of tokens in `text` as `ceil(len / 4)`, where `len` is
/// the UTF-8 **byte** length of the string (not its `char` count).
///
/// This avoids the overhead of loading a full BPE tokenizer for every row.
/// For ASCII text (e.g. plain English) bytes and characters coincide, which
/// gives the commonly cited rule of thumb of ~4 characters per token. Non-ASCII
/// characters span 2–4 bytes each, so they push the estimate up rather than
/// down — the conservative direction when the result is used as a budget.
///
/// Returns `0` for an empty string. The result saturates at `u32::MAX`; that
/// cannot happen for Arrow `Utf8` values, whose `i32` offsets cap a value at
/// under 2^31 bytes.
fn estimate_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: usize = 4;
    // Integer ceiling division: no round trip through `f64`.
    let tokens = text.len().div_ceil(BYTES_PER_TOKEN);
    u32::try_from(tokens).unwrap_or(u32::MAX)
}
43
44impl ScalarUDFImpl for TokenCountUdf {
45    fn as_any(&self) -> &dyn Any {
46        self
47    }
48
49    fn name(&self) -> &str {
50        "token_count"
51    }
52
53    fn signature(&self) -> &Signature {
54        &self.signature
55    }
56
57    fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
58        Ok(DataType::UInt32)
59    }
60
61    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
62        let num_rows = args.number_rows;
63        let arrays: Vec<ArrayRef> = args
64            .args
65            .iter()
66            .map(|a| a.to_array(num_rows))
67            .collect::<Result<Vec<_>>>()?;
68
69        let text_arr = arrays[0].as_string::<i32>();
70        let len = text_arr.len();
71        let mut results = Vec::with_capacity(len);
72
73        for i in 0..len {
74            if text_arr.is_null(i) {
75                results.push(None);
76            } else {
77                results.push(Some(estimate_tokens(text_arr.value(i))));
78            }
79        }
80
81        Ok(ColumnarValue::Array(Arc::new(UInt32Array::from(results))))
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;

    /// Runs `token_count` over a `Utf8` array built from `texts` and returns
    /// the resulting `UInt32` array.
    fn invoke(texts: Vec<Option<&str>>) -> UInt32Array {
        let number_rows = texts.len();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(StringArray::from(texts)))],
            // Keep the argument metadata consistent with the single Utf8 argument.
            arg_fields: vec![Arc::new(Field::new("text", DataType::Utf8, true))],
            number_rows,
            return_field: Arc::new(Field::new("result", DataType::UInt32, true)),
            config_options: Arc::new(ConfigOptions::new()),
        };
        match TokenCountUdf::new().invoke_with_args(args).unwrap() {
            ColumnarValue::Array(a) => a.as_any().downcast_ref::<UInt32Array>().unwrap().clone(),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn known_text() {
        let vals = invoke(vec![Some("Hello, world! This is a test.")]);
        // 29 bytes / 4, rounded up.
        assert_eq!(vals.value(0), 8);
    }

    #[test]
    fn rounds_up_to_whole_tokens() {
        let vals = invoke(vec![Some("a"), Some("abcd"), Some("abcde")]);
        assert_eq!(vals.value(0), 1);
        assert_eq!(vals.value(1), 1);
        assert_eq!(vals.value(2), 2);
    }

    #[test]
    fn counts_utf8_bytes() {
        // 3 characters, 9 UTF-8 bytes.
        let vals = invoke(vec![Some("日本語")]);
        assert_eq!(vals.value(0), 3);
    }

    #[test]
    fn empty_string() {
        let vals = invoke(vec![Some("")]);
        assert_eq!(vals.value(0), 0);
    }

    #[test]
    fn null_returns_null() {
        let vals = invoke(vec![None]);
        assert!(vals.is_null(0));
    }

    #[test]
    fn mixed_rows_keep_positions() {
        let vals = invoke(vec![Some("abcd"), None, Some("abcdefgh")]);
        assert_eq!(vals.len(), 3);
        assert_eq!(vals.value(0), 1);
        assert!(vals.is_null(1));
        assert_eq!(vals.value(2), 2);
    }

    #[test]
    fn empty_batch() {
        let vals = invoke(vec![]);
        assert_eq!(vals.len(), 0);
    }
}
129}