// hirn_engine/operators/narrative.rs

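//! Narrative assembly operator.
//!
//! Packs memory content into a single narrative `RecordBatch` that fits
//! within a given token budget. Content items are joined by blank lines, and
//! the last item that does not fit is truncated to fill the remaining budget.
//! Used to construct LLM-ready context windows.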
5
6use std::sync::Arc;
7
8use arrow_array::Array;
9use arrow_array::RecordBatch;
10use arrow_array::builder::StringBuilder;
11use arrow_array::cast::AsArray;
12use arrow_schema::{DataType, Field, Schema};
13use async_trait::async_trait;
14
15use hirn_core::embed::TokenCounter;
16use hirn_core::error::HirnResult;
17
18use super::{OpContext, Operator};
19
20/// Operator that assembles input content into a token-budgeted narrative.
21///
22/// Reads the `content` column from input batches in order, concatenating
23/// until the token budget is exhausted. Produces a single-row `RecordBatch`
24/// with a `narrative` column containing the assembled text.
25pub struct NarrativeAssemble {
26    /// Maximum number of tokens in the assembled narrative.
27    pub max_tokens: usize,
28    /// Token counter implementation.
29    pub token_counter: Arc<dyn TokenCounter>,
30}
31
32#[async_trait]
33impl Operator for NarrativeAssemble {
34    async fn execute(
35        &self,
36        input: Vec<RecordBatch>,
37        _ctx: &OpContext,
38    ) -> HirnResult<Vec<RecordBatch>> {
39        let mut narrative = String::new();
40        let mut used_tokens: usize = 0;
41
42        'outer: for batch in &input {
43            let content_col = match batch.column_by_name("content") {
44                Some(c) => c,
45                None => continue,
46            };
47            let str_arr = content_col.as_string::<i32>();
48            for i in 0..str_arr.len() {
49                if str_arr.is_null(i) {
50                    continue;
51                }
52                let text = str_arr.value(i);
53                let tokens = self.token_counter.count_tokens(text);
54                if used_tokens + tokens > self.max_tokens {
55                    // Try to fit a truncated version.
56                    let remaining = self.max_tokens.saturating_sub(used_tokens);
57                    if remaining > 0 {
58                        let truncated = truncate_to_tokens(text, remaining, &*self.token_counter);
59                        if !truncated.is_empty() {
60                            if !narrative.is_empty() {
61                                narrative.push_str("\n\n");
62                            }
63                            narrative.push_str(&truncated);
64                        }
65                    }
66                    break 'outer;
67                }
68                if !narrative.is_empty() {
69                    narrative.push_str("\n\n");
70                    // Account for separator tokens.
71                    let sep_tokens = self.token_counter.count_tokens("\n\n");
72                    used_tokens += sep_tokens;
73                }
74                narrative.push_str(text);
75                used_tokens += tokens;
76            }
77        }
78
79        let schema = Arc::new(Schema::new(vec![Field::new(
80            "narrative",
81            DataType::Utf8,
82            false,
83        )]));
84
85        if narrative.is_empty() {
86            return Ok(vec![RecordBatch::new_empty(schema)]);
87        }
88
89        let mut builder = StringBuilder::new();
90        builder.append_value(&narrative);
91        let batch = RecordBatch::try_new(schema, vec![Arc::new(builder.finish())])
92            .map_err(|e| hirn_core::error::HirnError::storage(e))?;
93        Ok(vec![batch])
94    }
95}
96
97/// Truncate text to fit approximately `max_tokens` tokens by binary search
98/// on character boundaries.
99fn truncate_to_tokens(text: &str, max_tokens: usize, counter: &dyn TokenCounter) -> String {
100    if counter.count_tokens(text) <= max_tokens {
101        return text.to_string();
102    }
103    // Binary search for the longest prefix that fits.
104    let chars: Vec<char> = text.chars().collect();
105    let mut lo = 0usize;
106    let mut hi = chars.len();
107    while lo < hi {
108        let mid = lo + (hi - lo + 1) / 2;
109        let prefix: String = chars[..mid].iter().collect();
110        if counter.count_tokens(&prefix) <= max_tokens {
111            lo = mid;
112        } else {
113            hi = mid - 1;
114        }
115    }
116    if lo == 0 {
117        return String::new();
118    }
119    chars[..lo].iter().collect()
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::embed::CharEstimateCounter;

    // CharEstimateCounter: ceil(len / 4). All ASCII inputs below have
    // len == char count, so the expected values are exact.

    #[test]
    fn truncate_respects_budget() {
        let counter = CharEstimateCounter;
        let text = "ab".repeat(40); // 80 chars → 20 tokens
        let result = truncate_to_tokens(&text, 10, &counter);
        let tokens = counter.count_tokens(&result);
        assert!(tokens <= 10, "tokens={tokens}");
        assert!(text.starts_with(&result), "result must be a prefix");
    }

    #[test]
    fn truncate_keeps_longest_fitting_prefix() {
        let counter = CharEstimateCounter;
        let text = "ab".repeat(40);
        // 40 chars → 10 tokens fits; 41 chars → 11 tokens does not.
        assert_eq!(truncate_to_tokens(&text, 10, &counter), "ab".repeat(20));
    }

    #[test]
    fn truncate_returns_text_unchanged_when_it_fits() {
        let counter = CharEstimateCounter;
        assert_eq!(truncate_to_tokens("hello", 10, &counter), "hello");
    }

    #[test]
    fn truncate_empty_budget() {
        let counter = CharEstimateCounter;
        let result = truncate_to_tokens("hello world", 0, &counter);
        assert!(result.is_empty());
    }

    #[test]
    fn truncate_non_ascii_stays_on_char_boundaries() {
        let counter = CharEstimateCounter;
        let text = "é".repeat(40);
        let result = truncate_to_tokens(&text, 3, &counter);
        assert!(counter.count_tokens(&result) <= 3);
        assert!(text.starts_with(&result), "result must be a prefix");
        assert!(!result.is_empty());
    }
}