hirn_exec/udfs/
token_count.rs1use std::any::Any;
6use std::sync::Arc;
7
8use arrow_array::Array;
9use arrow_array::cast::AsArray;
10use arrow_array::{ArrayRef, UInt32Array};
11use datafusion_common::Result;
12use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
13
14use arrow_schema::DataType;
15
16#[derive(Debug, PartialEq, Eq, Hash)]
17pub struct TokenCountUdf {
18 signature: Signature,
19}
20
21impl Default for TokenCountUdf {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl TokenCountUdf {
28 pub fn new() -> Self {
29 Self {
30 signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable),
31 }
32 }
33}
34
/// Roughly estimates how many tokens `text` occupies.
///
/// Uses a fixed ratio of four characters per token and rounds up. An empty
/// string therefore yields 0, and any non-empty string yields at least 1.
///
/// Characters are counted as Unicode scalar values (`char`s), not UTF-8 bytes.
/// Counting bytes would inflate the estimate for multi-byte text such as
/// accented Latin, CJK or emoji (2-4 bytes per character).
/// NOTE(review): the 4-chars-per-token ratio is an English-oriented heuristic.
/// Confirm it is acceptable for the non-Latin scripts this is applied to.
///
/// The result saturates at `u32::MAX` instead of wrapping.
fn estimate_tokens(text: &str) -> u32 {
    let char_count = text.chars().count();
    // Integer ceiling division: exact for every length, no f64 round trip.
    let tokens = char_count.div_ceil(4);
    u32::try_from(tokens).unwrap_or(u32::MAX)
}
43
44impl ScalarUDFImpl for TokenCountUdf {
45 fn as_any(&self) -> &dyn Any {
46 self
47 }
48
49 fn name(&self) -> &str {
50 "token_count"
51 }
52
53 fn signature(&self) -> &Signature {
54 &self.signature
55 }
56
57 fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
58 Ok(DataType::UInt32)
59 }
60
61 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
62 let num_rows = args.number_rows;
63 let arrays: Vec<ArrayRef> = args
64 .args
65 .iter()
66 .map(|a| a.to_array(num_rows))
67 .collect::<Result<Vec<_>>>()?;
68
69 let text_arr = arrays[0].as_string::<i32>();
70 let len = text_arr.len();
71 let mut results = Vec::with_capacity(len);
72
73 for i in 0..len {
74 if text_arr.is_null(i) {
75 results.push(None);
76 } else {
77 results.push(Some(estimate_tokens(text_arr.value(i))));
78 }
79 }
80
81 Ok(ColumnarValue::Array(Arc::new(UInt32Array::from(results))))
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::StringArray;
    use arrow_schema::Field;
    use datafusion_common::config::ConfigOptions;

    /// Runs `token_count` over a `Utf8` column built from `texts` and returns
    /// the resulting `UInt32` column.
    fn invoke(texts: Vec<Option<&str>>) -> UInt32Array {
        let udf = TokenCountUdf::new();
        let number_rows = texts.len();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Array(Arc::new(StringArray::from(texts)))],
            number_rows,
            return_field: Arc::new(Field::new("result", DataType::UInt32, true)),
            arg_fields: vec![],
            config_options: Arc::new(ConfigOptions::new()),
        };
        let result = udf
            .invoke_with_args(args)
            .expect("token_count should not fail on a Utf8 array");
        match result {
            ColumnarValue::Array(a) => a
                .as_any()
                .downcast_ref::<UInt32Array>()
                .expect("token_count returns UInt32")
                .clone(),
            _ => panic!("expected array"),
        }
    }

    #[test]
    fn known_text() {
        let vals = invoke(vec![Some("Hello, world! This is a test.")]);
        assert!(vals.value(0) > 0);
        assert!(vals.value(0) < 20);
    }

    #[test]
    fn empty_string() {
        let vals = invoke(vec![Some("")]);
        assert_eq!(vals.value(0), 0);
    }

    #[test]
    fn null_returns_null() {
        let vals = invoke(vec![None]);
        assert!(vals.is_null(0));
    }

    /// Partial groups of four characters round up to a whole token.
    #[test]
    fn rounds_partial_groups_up() {
        let vals = invoke(vec![Some("a"), Some("abcd"), Some("abcde"), Some("abcdefgh")]);
        assert_eq!(
            vals.iter().collect::<Vec<_>>(),
            vec![Some(1), Some(1), Some(2), Some(2)]
        );
    }

    /// Nulls in a mixed batch stay null and in position; other rows are unaffected.
    #[test]
    fn mixed_batch_keeps_nulls_in_place() {
        let vals = invoke(vec![Some("abcd"), None, Some(""), None, Some("abcdefghi")]);
        assert_eq!(vals.len(), 5);
        assert_eq!(
            vals.iter().collect::<Vec<_>>(),
            vec![Some(1), None, Some(0), None, Some(3)]
        );
    }
}