hirn_engine/operators/
narrative.rs1use std::sync::Arc;
7
8use arrow_array::Array;
9use arrow_array::RecordBatch;
10use arrow_array::builder::StringBuilder;
11use arrow_array::cast::AsArray;
12use arrow_schema::{DataType, Field, Schema};
13use async_trait::async_trait;
14
15use hirn_core::embed::TokenCounter;
16use hirn_core::error::HirnResult;
17
18use super::{OpContext, Operator};
19
20pub struct NarrativeAssemble {
26 pub max_tokens: usize,
28 pub token_counter: Arc<dyn TokenCounter>,
30}
31
32#[async_trait]
33impl Operator for NarrativeAssemble {
34 async fn execute(
35 &self,
36 input: Vec<RecordBatch>,
37 _ctx: &OpContext,
38 ) -> HirnResult<Vec<RecordBatch>> {
39 let mut narrative = String::new();
40 let mut used_tokens: usize = 0;
41
42 'outer: for batch in &input {
43 let content_col = match batch.column_by_name("content") {
44 Some(c) => c,
45 None => continue,
46 };
47 let str_arr = content_col.as_string::<i32>();
48 for i in 0..str_arr.len() {
49 if str_arr.is_null(i) {
50 continue;
51 }
52 let text = str_arr.value(i);
53 let tokens = self.token_counter.count_tokens(text);
54 if used_tokens + tokens > self.max_tokens {
55 let remaining = self.max_tokens.saturating_sub(used_tokens);
57 if remaining > 0 {
58 let truncated = truncate_to_tokens(text, remaining, &*self.token_counter);
59 if !truncated.is_empty() {
60 if !narrative.is_empty() {
61 narrative.push_str("\n\n");
62 }
63 narrative.push_str(&truncated);
64 }
65 }
66 break 'outer;
67 }
68 if !narrative.is_empty() {
69 narrative.push_str("\n\n");
70 let sep_tokens = self.token_counter.count_tokens("\n\n");
72 used_tokens += sep_tokens;
73 }
74 narrative.push_str(text);
75 used_tokens += tokens;
76 }
77 }
78
79 let schema = Arc::new(Schema::new(vec![Field::new(
80 "narrative",
81 DataType::Utf8,
82 false,
83 )]));
84
85 if narrative.is_empty() {
86 return Ok(vec![RecordBatch::new_empty(schema)]);
87 }
88
89 let mut builder = StringBuilder::new();
90 builder.append_value(&narrative);
91 let batch = RecordBatch::try_new(schema, vec![Arc::new(builder.finish())])
92 .map_err(|e| hirn_core::error::HirnError::storage(e))?;
93 Ok(vec![batch])
94 }
95}
96
97fn truncate_to_tokens(text: &str, max_tokens: usize, counter: &dyn TokenCounter) -> String {
100 if counter.count_tokens(text) <= max_tokens {
101 return text.to_string();
102 }
103 let chars: Vec<char> = text.chars().collect();
105 let mut lo = 0usize;
106 let mut hi = chars.len();
107 while lo < hi {
108 let mid = lo + (hi - lo + 1) / 2;
109 let prefix: String = chars[..mid].iter().collect();
110 if counter.count_tokens(&prefix) <= max_tokens {
111 lo = mid;
112 } else {
113 hi = mid - 1;
114 }
115 }
116 if lo == 0 {
117 return String::new();
118 }
119 chars[..lo].iter().collect()
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::embed::CharEstimateCounter;

    /// The truncated text must stay within the budget, and it must be a real,
    /// non-empty prefix of the input rather than a trivially empty string.
    #[test]
    fn truncate_respects_budget() {
        let counter = CharEstimateCounter;
        let text = "a]".repeat(40);
        let result = truncate_to_tokens(&text, 10, &counter);
        let tokens = counter.count_tokens(&result);
        assert!(tokens <= 10, "tokens={tokens}");
        assert!(text.starts_with(&result), "result must be a prefix of the input");
        assert!(!result.is_empty(), "a 10-token budget should keep some text");
    }

    /// With no budget at all nothing can be kept.
    #[test]
    fn truncate_empty_budget() {
        let counter = CharEstimateCounter;
        let result = truncate_to_tokens("hello world", 0, &counter);
        assert!(result.is_empty());
    }

    /// Multi-byte characters must never be split: the result has to be a valid
    /// prefix of the input that still fits the budget.
    #[test]
    fn truncate_keeps_multibyte_chars_intact() {
        let counter = CharEstimateCounter;
        let text = "é".repeat(40);
        let result = truncate_to_tokens(&text, 3, &counter);
        assert!(text.starts_with(&result), "result must be a prefix of the input");
        let tokens = counter.count_tokens(&result);
        assert!(tokens <= 3, "tokens={tokens}");
    }
}