cognis 0.2.1

LLM application framework built on cognis-core
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
use super::TextSplitter;
use cognis_core::utils::tokens::{estimate_token_count, get_model_context_window};

/// A text splitter that respects token limits rather than character limits.
///
/// Uses heuristic token estimation to split text into chunks that fit within
/// a specified token budget. Supports configurable overlap between chunks and
/// hierarchical separator-based splitting (paragraphs before lines before
/// sentences before words, by default).
pub struct TokenAwareTextSplitter {
    /// Maximum tokens per chunk; pieces are merged until adding another would
    /// exceed this budget.
    pub max_tokens: usize,
    /// Number of overlap tokens between consecutive chunks. The overlap is
    /// taken from the tail of the previous chunk; 0 disables overlap.
    pub overlap_tokens: usize,
    /// Optional model name for more accurate token estimation.
    // NOTE(review): currently `estimate_tokens` ignores this and always uses
    // the generic heuristic; only `from_model_context` consults the model
    // (for window size). Confirm whether model-aware estimation is intended.
    pub model_name: Option<String>,
    /// Separators to try in priority order (highest priority first). When no
    /// separator is present, splitting falls back to whitespace.
    pub separators: Vec<String>,
}

impl Default for TokenAwareTextSplitter {
    /// Sensible defaults: 500-token chunks, 50-token overlap, no model, and
    /// a paragraph → line → sentence → word separator hierarchy.
    fn default() -> Self {
        let separators = ["\n\n", "\n", ". ", " "]
            .iter()
            .map(|s| s.to_string())
            .collect();
        Self {
            max_tokens: 500,
            overlap_tokens: 50,
            model_name: None,
            separators,
        }
    }
}

impl TokenAwareTextSplitter {
    /// Create a new `TokenAwareTextSplitter` with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum tokens per chunk.
    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
        self.max_tokens = max_tokens;
        self
    }

    /// Set the overlap tokens between chunks.
    pub fn with_overlap_tokens(mut self, overlap_tokens: usize) -> Self {
        self.overlap_tokens = overlap_tokens;
        self
    }

    /// Set the model name for token estimation context.
    pub fn with_model(mut self, model_name: impl Into<String>) -> Self {
        self.model_name = Some(model_name.into());
        self
    }

    /// Set custom separators (highest priority first).
    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
        self.separators = separators;
        self
    }

    /// Create a splitter sized for a specific model's context window.
    ///
    /// Divides the model's context window by `chunks_per_context` to determine
    /// `max_tokens`. Falls back to a default of 2000 tokens if the model is not
    /// recognized. Passing `chunks_per_context == 0` keeps the whole window.
    /// The result is clamped to at least 1 token so an oversized divisor can
    /// never yield a zero-token (unusable) budget.
    pub fn from_model_context(model_name: &str, chunks_per_context: usize) -> Self {
        let context_window = get_model_context_window(model_name).unwrap_or(2000);
        // `checked_div` is `None` only when `chunks_per_context == 0`.
        let max_tokens = context_window
            .checked_div(chunks_per_context)
            .unwrap_or(context_window)
            .max(1); // FIX: integer division could previously produce 0.
        Self {
            max_tokens,
            model_name: Some(model_name.to_string()),
            ..Self::default()
        }
    }

    /// Estimate the token count for a piece of text.
    ///
    /// NOTE(review): `_model` is currently unused — estimation is always the
    /// generic heuristic. The parameter is kept so call sites do not change
    /// when model-aware estimation is added.
    fn estimate_tokens(text: &str, _model: Option<&str>) -> usize {
        estimate_token_count(text)
    }

    /// Push `candidate` (trimmed) onto `chunks`, skipping whitespace-only text.
    fn push_non_empty(chunks: &mut Vec<String>, candidate: &str) {
        let trimmed = candidate.trim();
        if !trimmed.is_empty() {
            chunks.push(trimmed.to_string());
        }
    }

    /// Split text at the highest-priority separator that produces sub-chunks,
    /// then merge small pieces and add overlap.
    fn split_with_separators(&self, text: &str) -> Vec<String> {
        if text.is_empty() {
            return Vec::new();
        }

        // Fast path: the whole text already fits in one chunk.
        let text_tokens = Self::estimate_tokens(text, self.model_name.as_deref());
        if text_tokens <= self.max_tokens {
            return vec![text.to_string()];
        }

        // Find the highest-priority separator present in the text.
        let separator = self
            .separators
            .iter()
            .find(|sep| text.contains(sep.as_str()))
            .cloned();

        let pieces: Vec<&str> = match &separator {
            Some(sep) => text.split(sep.as_str()).collect(),
            None => {
                // No separator found; fall back to word-level splitting on whitespace.
                text.split_whitespace().collect()
            }
        };

        // Filter out empty pieces (adjacent separators produce empty splits).
        let pieces: Vec<&str> = pieces.into_iter().filter(|p| !p.is_empty()).collect();

        // Merge small pieces into chunks that fit within max_tokens.
        let mut chunks: Vec<String> = Vec::new();
        let mut current = String::new();
        let mut current_tokens: usize = 0;
        // Hoisted out of the loop: the separator (and its token cost) is the
        // same for every piece.
        let sep_str = separator.as_deref().unwrap_or(" ");
        let sep_tokens = Self::estimate_tokens(sep_str, self.model_name.as_deref());

        for piece in &pieces {
            let piece_tokens = Self::estimate_tokens(piece, self.model_name.as_deref());

            // If a single piece exceeds max_tokens, try to split it further
            // with lower-priority separators.
            if piece_tokens > self.max_tokens {
                // Flush current buffer first.
                if !current.is_empty() {
                    // FIX: previously a whitespace-only buffer was pushed as
                    // an empty chunk; skip it instead (matches the final flush).
                    Self::push_non_empty(&mut chunks, &current);
                    current.clear();
                    current_tokens = 0;
                }
                chunks.extend(self.split_subsection(piece));
                continue;
            }

            let would_be = if current.is_empty() {
                piece_tokens
            } else {
                current_tokens + sep_tokens + piece_tokens
            };

            if would_be > self.max_tokens && !current.is_empty() {
                Self::push_non_empty(&mut chunks, &current);
                current.clear();
                current_tokens = 0; // FIX: reset explicitly at every flush site.
            }

            if current.is_empty() {
                current = piece.to_string();
                current_tokens = piece_tokens;
            } else {
                current.push_str(sep_str);
                current.push_str(piece);
                // Re-estimate the whole buffer rather than summing, so the
                // heuristic sees the joined text.
                current_tokens = Self::estimate_tokens(&current, self.model_name.as_deref());
            }
        }

        // Final flush (push_non_empty handles the whitespace-only case).
        Self::push_non_empty(&mut chunks, &current);

        // Apply overlap between consecutive chunks.
        if self.overlap_tokens > 0 && chunks.len() > 1 {
            chunks = self.apply_overlap(chunks);
        }

        chunks
    }

    /// Try splitting a subsection using lower-priority separators.
    ///
    /// Recurses on pieces that individually exceed `max_tokens`. Recursion
    /// terminates because each recursive call operates on a strict substring
    /// (a piece is always shorter than the text it was split out of).
    fn split_subsection(&self, text: &str) -> Vec<String> {
        for sep in &self.separators {
            if !text.contains(sep.as_str()) {
                continue;
            }
            let pieces: Vec<&str> = text.split(sep.as_str()).filter(|p| !p.is_empty()).collect();
            if pieces.len() <= 1 {
                // This separator did not actually divide the text; try the next.
                continue;
            }

            let sep_tokens = Self::estimate_tokens(sep, self.model_name.as_deref());
            let mut sub_chunks = Vec::new();
            let mut current = String::new();
            let mut current_tokens: usize = 0;

            for piece in &pieces {
                let piece_tokens = Self::estimate_tokens(piece, self.model_name.as_deref());

                // FIX: the original emitted pieces larger than max_tokens
                // as-is; recurse with lower-priority separators instead.
                if piece_tokens > self.max_tokens {
                    if !current.is_empty() {
                        Self::push_non_empty(&mut sub_chunks, &current);
                        current.clear();
                        current_tokens = 0;
                    }
                    sub_chunks.extend(self.split_subsection(piece));
                    continue;
                }

                let would_be = if current.is_empty() {
                    piece_tokens
                } else {
                    current_tokens + sep_tokens + piece_tokens
                };

                if would_be > self.max_tokens && !current.is_empty() {
                    Self::push_non_empty(&mut sub_chunks, &current);
                    current.clear();
                    current_tokens = 0;
                }

                if current.is_empty() {
                    current = piece.to_string();
                    current_tokens = piece_tokens;
                } else {
                    current.push_str(sep);
                    current.push_str(piece);
                    current_tokens =
                        Self::estimate_tokens(&current, self.model_name.as_deref());
                }
            }

            Self::push_non_empty(&mut sub_chunks, &current);
            return sub_chunks;
        }
        // No separator produced more than one piece; cannot split further.
        vec![text.to_string()]
    }

    /// Add overlap from the end of the previous chunk to the start of the next.
    ///
    /// The first chunk is kept unchanged; each later chunk is prefixed with up
    /// to `overlap_tokens` worth of trailing words from its predecessor, but
    /// only if the merged result still fits within `max_tokens`.
    fn apply_overlap(&self, chunks: Vec<String>) -> Vec<String> {
        if chunks.len() <= 1 {
            return chunks;
        }

        let mut result = Vec::with_capacity(chunks.len());
        result.push(chunks[0].clone());

        for i in 1..chunks.len() {
            let prev = &chunks[i - 1];
            let overlap_text = self.get_overlap_suffix(prev);
            if overlap_text.is_empty() {
                result.push(chunks[i].clone());
            } else {
                let merged = format!("{} {}", overlap_text.trim(), chunks[i].trim());
                // Only use overlap if the merged chunk still fits.
                let merged_tokens = Self::estimate_tokens(&merged, self.model_name.as_deref());
                if merged_tokens <= self.max_tokens {
                    result.push(merged);
                } else {
                    result.push(chunks[i].clone());
                }
            }
        }

        result
    }

    /// Extract the last `overlap_tokens` worth of text from a string.
    ///
    /// Walks words from the end, accumulating whole words until adding the
    /// next one would exceed the overlap budget; returns them in order.
    fn get_overlap_suffix(&self, text: &str) -> String {
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut suffix_words: Vec<&str> = Vec::new();
        let mut token_count = 0;

        for word in words.iter().rev() {
            let word_tokens = Self::estimate_tokens(word, self.model_name.as_deref());
            if token_count + word_tokens > self.overlap_tokens {
                break;
            }
            token_count += word_tokens;
            suffix_words.push(word);
        }

        suffix_words.reverse();
        suffix_words.join(" ")
    }
}

impl TextSplitter for TokenAwareTextSplitter {
    /// Split `text` into token-budgeted chunks via `split_with_separators`.
    fn split_text(&self, text: &str) -> Vec<String> {
        self.split_with_separators(text)
    }

    /// The configured chunk size, expressed in tokens (not characters).
    fn chunk_size(&self) -> usize {
        self.max_tokens
    }

    /// The configured chunk overlap, expressed in tokens (not characters).
    fn chunk_overlap(&self) -> usize {
        self.overlap_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_short_text_returns_single_chunk() {
        // Text under the budget comes back untouched as a single chunk.
        let splitter = TokenAwareTextSplitter::new().with_max_tokens(100);
        let chunks = splitter.split_text("Hello world.");
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "Hello world.");
    }

    #[test]
    fn test_long_text_splits_into_multiple_chunks() {
        let splitter = TokenAwareTextSplitter::new()
            .with_max_tokens(10)
            .with_overlap_tokens(0);

        // Three copies of a ~10-token sentence — well over the 10-token budget.
        let sentence = "The quick brown fox jumps over the lazy dog.";
        let text = format!("{} {} {}", sentence, sentence, sentence);

        let chunks = splitter.split_text(&text);
        assert!(
            chunks.len() > 1,
            "Expected multiple chunks, got {}",
            chunks.len()
        );

        // Every chunk must stay within the token budget (small tolerance).
        for chunk in &chunks {
            let tokens = estimate_token_count(chunk);
            assert!(
                tokens <= splitter.max_tokens + 2,
                "Chunk has {} tokens, max is {}: {:?}",
                tokens,
                splitter.max_tokens,
                chunk
            );
        }
    }

    #[test]
    fn test_overlap_between_chunks() {
        let splitter = TokenAwareTextSplitter::new()
            .with_max_tokens(15)
            .with_overlap_tokens(5);

        let text = "Alpha beta gamma delta. Epsilon zeta eta theta. Iota kappa lambda mu. Nu xi omicron pi.";

        let chunks = splitter.split_text(text);
        assert!(
            chunks.len() > 1,
            "Expected multiple chunks for overlap test"
        );

        // With overlap enabled, some chunk should repeat a non-trivial word
        // from its predecessor.
        let found_overlap = chunks.windows(2).any(|pair| {
            let earlier: Vec<&str> = pair[0].split_whitespace().collect();
            pair[1]
                .split_whitespace()
                .any(|word| word.len() > 3 && earlier.contains(&word))
        });
        assert!(found_overlap, "Expected overlap between consecutive chunks");
    }

    #[test]
    fn test_custom_separators() {
        let splitter = TokenAwareTextSplitter::new()
            .with_max_tokens(10)
            .with_overlap_tokens(0)
            .with_separators(vec!["||".into()]);

        let input = "chunk one text here||chunk two text here||chunk three text here";
        let parts = splitter.split_text(input);
        assert!(
            parts.len() >= 2,
            "Expected at least 2 chunks with custom separator, got {}",
            parts.len()
        );
    }

    #[test]
    fn test_from_model_context_factory() {
        // gpt-4o: 128_000-token window / 10 chunks = 12_800 per chunk.
        let gpt = TokenAwareTextSplitter::from_model_context("gpt-4o", 10);
        assert_eq!(gpt.max_tokens, 12_800);
        assert_eq!(gpt.model_name.as_deref(), Some("gpt-4o"));

        // claude-3-opus: 200_000 / 20 = 10_000.
        let claude = TokenAwareTextSplitter::from_model_context("claude-3-opus", 20);
        assert_eq!(claude.max_tokens, 10_000);

        // Unrecognized models fall back to a 2000-token window: 2000 / 4 = 500.
        let unknown = TokenAwareTextSplitter::from_model_context("unknown-model", 4);
        assert_eq!(unknown.max_tokens, 500);
    }

    #[test]
    fn test_empty_text_returns_empty_vec() {
        // Empty input yields no chunks at all.
        let splitter = TokenAwareTextSplitter::new();
        assert!(splitter.split_text("").is_empty());
    }

    #[test]
    fn test_chunk_size_and_overlap_trait_methods() {
        // Trait accessors must echo the configured token budgets.
        let splitter = TokenAwareTextSplitter::new()
            .with_max_tokens(256)
            .with_overlap_tokens(32);
        assert_eq!(splitter.chunk_size(), 256);
        assert_eq!(splitter.chunk_overlap(), 32);
    }
}