code_digest/core/prioritizer.rs

//! File prioritization based on token limits
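//!
//! When a token limit is configured, files are token-counted in parallel,
//! sorted by priority (highest first), and greedily selected until adding the
//! next file would exceed the limit. Without a limit, all files are returned
//! in priority order.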

use crate::core::cache::FileCache;
use crate::core::digest::DigestOptions;
use crate::core::token::{would_exceed_limit, TokenCounter};
use crate::core::walker::FileInfo;
use anyhow::Result;
use rayon::prelude::*;
use std::sync::Arc;

/// File with pre-computed token count
#[derive(Debug, Clone)]
struct FileWithTokens {
    file: FileInfo,
    token_count: usize,
}

/// Prioritize files based on their importance and token limits
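///
/// # Example
///
/// A minimal sketch of a call site (the `files` vector is assumed to come from
/// the crate's directory walker; `DigestOptions::default()` and
/// `FileCache::new()` are used as in the tests below):
///
/// ```ignore
/// let cache = Arc::new(FileCache::new());
/// let options = DigestOptions::default();
/// let selected = prioritize_files(files, &options, cache)?;
/// ```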
pub fn prioritize_files(
    mut files: Vec<FileInfo>,
    options: &DigestOptions,
    cache: Arc<FileCache>,
) -> Result<Vec<FileInfo>> {
    // If no token limit, return all files sorted by priority
    let max_tokens = match options.max_tokens {
        Some(limit) => limit,
        None => {
            files.sort_by(|a, b| {
                b.priority
                    .partial_cmp(&a.priority)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.relative_path.cmp(&b.relative_path))
            });
            return Ok(files);
        }
    };

    // Create token counter
    let counter = TokenCounter::new()?;

    // Calculate overhead for markdown structure
    let structure_overhead = calculate_structure_overhead(options, &files)?;

    // Phase 1: Count tokens for all files in parallel with proper error handling
    let results: Vec<crate::utils::error::Result<FileWithTokens>> = files
        .into_par_iter()
        .map(|file| {
            // Read file content from cache
            let content = cache.get_or_load(&file.path).map_err(|e| {
                crate::utils::error::CodeDigestError::FileProcessingError {
                    path: file.path.display().to_string(),
                    error: format!("Could not read file: {e}"),
                }
            })?;

            // Count tokens for this file
            let file_tokens = counter
                .count_file_tokens(&content, &file.relative_path.to_string_lossy())
                .map_err(|e| crate::utils::error::CodeDigestError::TokenCountingError {
                    path: file.path.display().to_string(),
                    error: e.to_string(),
                })?;

            Ok(FileWithTokens { file, token_count: file_tokens.total_tokens })
        })
        .collect();

    // Use partition_result to separate successes from errors
    use itertools::Itertools;
    let (mut files_with_tokens, errors): (Vec<_>, Vec<_>) = results.into_iter().partition_result();

    // Log errors without failing the entire operation
    if !errors.is_empty() {
        eprintln!("Warning: {} files could not be processed for token counting:", errors.len());
        for error in &errors {
            eprintln!("  {error}");
        }
    }

    // Phase 2: Sort by priority and select files sequentially
    files_with_tokens.sort_by(|a, b| {
        b.file
            .priority
            .partial_cmp(&a.file.priority)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.file.relative_path.cmp(&b.file.relative_path))
    });

    let mut selected_files = Vec::new();
    let mut total_tokens = structure_overhead;

    // Select files until we hit the token limit
    for file_with_tokens in files_with_tokens {
        // Skip this file if adding it would exceed the limit; smaller files
        // later in the priority order may still fit
        if would_exceed_limit(total_tokens, file_with_tokens.token_count, max_tokens) {
            continue;
        }

        // Add the file
        total_tokens += file_with_tokens.token_count;
        selected_files.push(file_with_tokens.file);
    }

    // Log statistics
    if options.include_stats {
        eprintln!("Token limit: {max_tokens}");
        eprintln!("Structure overhead: {structure_overhead} tokens");
        eprintln!(
            "Selected {} files with approximately {} tokens",
            selected_files.len(),
            total_tokens
        );
    }

    Ok(selected_files)
}

/// Calculate token overhead for markdown structure
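///
/// The estimate covers, when enabled in `options`: the document header
/// template, a statistics section (plus a small fixed buffer), the file tree
/// at roughly 20 tokens per file, and one table-of-contents line per file.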
fn calculate_structure_overhead(options: &DigestOptions, files: &[FileInfo]) -> Result<usize> {
    let counter = TokenCounter::new()?;
    let mut overhead = 0;

    // Document header
    if !options.doc_header_template.is_empty() {
        let header = options.doc_header_template.replace("{directory}", ".");
        overhead += counter.count_tokens(&format!("{header}\n\n"))?;
    }

    // Statistics section
    if options.include_stats {
        // Estimate statistics section size
        let stats_estimate = format!(
            "## Statistics\n\n- Total files: {}\n- Total size: X bytes\n\n### Files by type:\n",
            files.len()
        );
        overhead += counter.count_tokens(&stats_estimate)?;
        overhead += 200; // Buffer for file type list
    }

    // File tree
    if options.include_tree {
        overhead += counter.count_tokens("## File Structure\n\n```\n")?;
        // Estimate tree size (rough approximation)
        overhead += files.len() * 20; // ~20 tokens per file in tree
        overhead += counter.count_tokens("```\n\n")?;
    }

    // Table of contents
    if options.include_toc {
        overhead += counter.count_tokens("## Table of Contents\n\n")?;
        for file in files {
            let toc_line = format!("- [{}](#anchor)\n", file.relative_path.display());
            overhead += counter.count_tokens(&toc_line)?;
        }
        overhead += counter.count_tokens("\n")?;
    }

    Ok(overhead)
}

/// Group files by directory for better organization
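///
/// # Example
///
/// A minimal sketch of iterating over the grouped result:
///
/// ```ignore
/// for (dir, files) in group_by_directory(files) {
///     println!("{dir}: {} files", files.len());
/// }
/// ```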
pub fn group_by_directory(files: Vec<FileInfo>) -> Vec<(String, Vec<FileInfo>)> {
    use std::collections::HashMap;

    let mut groups: HashMap<String, Vec<FileInfo>> = HashMap::new();

    for file in files {
        let dir = file
            .relative_path
            .parent()
            .map(|p| p.to_string_lossy().to_string())
            .unwrap_or_else(|| ".".to_string());

        groups.entry(dir).or_default().push(file);
    }

    let mut result: Vec<_> = groups.into_iter().collect();
    result.sort_by(|a, b| a.0.cmp(&b.0));

    // Sort files within each group by priority
    for (_, files) in &mut result {
        files.sort_by(|a, b| {
            b.priority
                .partial_cmp(&a.priority)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.relative_path.cmp(&b.relative_path))
        });
    }

    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::file_ext::FileType;
    use std::fs;
    use std::path::PathBuf;
    use tempfile::TempDir;

    fn create_test_cache() -> Arc<FileCache> {
        Arc::new(FileCache::new())
    }

    fn create_test_files(_temp_dir: &TempDir, files: &[FileInfo]) {
        for file in files {
            if let Some(parent) = file.path.parent() {
                fs::create_dir_all(parent).ok();
            }
            fs::write(&file.path, "test content").ok();
        }
    }

    #[test]
    fn test_prioritize_without_limit() {
        let temp_dir = TempDir::new().unwrap();
        let files = vec![
            FileInfo {
                path: temp_dir.path().join("low.txt"),
                relative_path: PathBuf::from("low.txt"),
                size: 100,
                file_type: FileType::Text,
                priority: 0.3,
            },
            FileInfo {
                path: temp_dir.path().join("high.rs"),
                relative_path: PathBuf::from("high.rs"),
                size: 100,
                file_type: FileType::Rust,
                priority: 1.0,
            },
        ];

        create_test_files(&temp_dir, &files);
        let cache = create_test_cache();
        let options = DigestOptions::default();
        let result = prioritize_files(files, &options, cache).unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].relative_path, PathBuf::from("high.rs"));
        assert_eq!(result[1].relative_path, PathBuf::from("low.txt"));
    }

    #[test]
    fn test_group_by_directory() {
        let files = vec![
            FileInfo {
                path: PathBuf::from("src/main.rs"),
                relative_path: PathBuf::from("src/main.rs"),
                size: 100,
                file_type: FileType::Rust,
                priority: 1.0,
            },
            FileInfo {
                path: PathBuf::from("src/lib.rs"),
                relative_path: PathBuf::from("src/lib.rs"),
                size: 100,
                file_type: FileType::Rust,
                priority: 1.0,
            },
            FileInfo {
                path: PathBuf::from("tests/test.rs"),
                relative_path: PathBuf::from("tests/test.rs"),
                size: 100,
                file_type: FileType::Rust,
                priority: 0.8,
            },
        ];

        let groups = group_by_directory(files);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups[0].0, "src");
        assert_eq!(groups[0].1.len(), 2);
        assert_eq!(groups[1].0, "tests");
        assert_eq!(groups[1].1.len(), 1);
    }

    #[test]
    fn test_prioritize_algorithm_ordering() {
        let temp_dir = TempDir::new().unwrap();
        let files = vec![
            FileInfo {
                path: temp_dir.path().join("test.rs"),
                relative_path: PathBuf::from("test.rs"),
                size: 500,
                file_type: FileType::Rust,
                priority: 0.8,
            },
            FileInfo {
                path: temp_dir.path().join("main.rs"),
                relative_path: PathBuf::from("main.rs"),
                size: 1000,
                file_type: FileType::Rust,
                priority: 1.5,
            },
            FileInfo {
                path: temp_dir.path().join("lib.rs"),
                relative_path: PathBuf::from("lib.rs"),
                size: 800,
                file_type: FileType::Rust,
                priority: 1.2,
            },
        ];

        create_test_files(&temp_dir, &files);
        let cache = create_test_cache();
        let options = DigestOptions::default();
        let result = prioritize_files(files, &options, cache).unwrap();

        // Should return all files when no limit
        assert_eq!(result.len(), 3);

        // Files should be sorted by priority (highest first)
        assert_eq!(result[0].relative_path, PathBuf::from("main.rs"));
        assert_eq!(result[1].relative_path, PathBuf::from("lib.rs"));
        assert_eq!(result[2].relative_path, PathBuf::from("test.rs"));
    }

    #[test]
    fn test_calculate_structure_overhead() {
        let files = vec![FileInfo {
            path: PathBuf::from("main.rs"),
            relative_path: PathBuf::from("main.rs"),
            size: 1000,
            file_type: FileType::Rust,
            priority: 1.5,
        }];

        let options = DigestOptions {
            max_tokens: None,
            include_tree: true,
            include_stats: true,
            group_by_type: true,
            sort_by_priority: true,
            file_header_template: "## {path}".to_string(),
            doc_header_template: "# Code Digest".to_string(),
            include_toc: true,
            enhanced_context: false,
        };

        let overhead = calculate_structure_overhead(&options, &files).unwrap();

        // Should account for headers, tree, stats, TOC
        assert!(overhead > 0);
        assert!(overhead < 10000); // Reasonable upper bound
    }

    #[test]
    fn test_priority_ordering() {
        let mut files = [
            FileInfo {
                path: PathBuf::from("test.rs"),
                relative_path: PathBuf::from("test.rs"),
                size: 500,
                file_type: FileType::Rust,
                priority: 0.8,
            },
            FileInfo {
                path: PathBuf::from("main.rs"),
                relative_path: PathBuf::from("main.rs"),
                size: 1000,
                file_type: FileType::Rust,
                priority: 1.5,
            },
            FileInfo {
                path: PathBuf::from("lib.rs"),
                relative_path: PathBuf::from("lib.rs"),
                size: 800,
                file_type: FileType::Rust,
                priority: 1.2,
            },
        ];

        // Sort by priority (highest first)
        files.sort_by(|a, b| b.priority.partial_cmp(&a.priority).unwrap());

        assert_eq!(files[0].relative_path, PathBuf::from("main.rs"));
        assert_eq!(files[1].relative_path, PathBuf::from("lib.rs"));
        assert_eq!(files[2].relative_path, PathBuf::from("test.rs"));
    }

    #[test]
    fn test_group_by_directory_complex() {
        let files = vec![
            FileInfo {
                path: PathBuf::from("src/core/mod.rs"),
                relative_path: PathBuf::from("src/core/mod.rs"),
                size: 500,
                file_type: FileType::Rust,
                priority: 1.0,
            },
            FileInfo {
                path: PathBuf::from("src/utils/helpers.rs"),
                relative_path: PathBuf::from("src/utils/helpers.rs"),
                size: 300,
                file_type: FileType::Rust,
                priority: 0.9,
            },
            FileInfo {
                path: PathBuf::from("tests/integration.rs"),
                relative_path: PathBuf::from("tests/integration.rs"),
                size: 200,
                file_type: FileType::Rust,
                priority: 0.8,
            },
            FileInfo {
                path: PathBuf::from("main.rs"),
                relative_path: PathBuf::from("main.rs"),
                size: 1000,
                file_type: FileType::Rust,
                priority: 1.5,
            },
        ];

        let grouped = group_by_directory(files);

        // Should have at least 3 groups
        assert!(grouped.len() >= 3);

        // Check that files are correctly grouped by directory
        let has_root_or_main = grouped.iter().any(|(dir, files)| {
            (dir == "." || dir.is_empty())
                && files.iter().any(|f| f.relative_path == PathBuf::from("main.rs"))
        });
        assert!(has_root_or_main);

        let has_src_core = grouped.iter().any(|(dir, files)| {
            dir == "src/core"
                && files.iter().any(|f| f.relative_path == PathBuf::from("src/core/mod.rs"))
        });
        assert!(has_src_core);
    }
}