Skip to main content

collet_treemap/
lib.rs

1use serde::{Deserialize, Serialize};
2use std::path::{Path, PathBuf};
3use thiserror::Error;
4
/// Errors that can occur during map generation.
#[derive(Error, Debug)]
pub enum MapError {
    /// An underlying I/O operation failed. The `#[from]` attribute lets `?`
    /// convert a [`std::io::Error`] into this variant automatically.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The given path does not exist; carries the offending path.
    #[error("File not found: {0}")]
    FileNotFound(PathBuf),

    /// A language was requested that is not supported.
    /// NOTE(review): not constructed anywhere in this file — confirm callers.
    #[error("Unsupported language: {0}")]
    UnsupportedLanguage(String),

    /// Input could not be processed. `generate_map` also returns this when
    /// the root path exists but is not a directory.
    #[error("Parse error: {0}")]
    ParseError(String),
}
20
/// Supported programming languages.
///
/// Each variant covers a family of file extensions (see
/// [`Language::from_extension`]); for example, TypeScript sources are
/// classified as [`Language::JavaScript`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum Language {
    /// Rust sources.
    Rust,
    /// Python sources.
    Python,
    /// JavaScript-family sources, including TypeScript.
    JavaScript,
}
28
29impl Language {
30    pub fn from_extension(ext: &str) -> Option<Self> {
31        match ext {
32            "rs" => Some(Language::Rust),
33            "py" => Some(Language::Python),
34            "js" | "jsx" | "ts" | "tsx" => Some(Language::JavaScript),
35            _ => None,
36        }
37    }
38
39    pub fn as_str(&self) -> &'static str {
40        match self {
41            Language::Rust => "rust",
42            Language::Python => "python",
43            Language::JavaScript => "javascript",
44        }
45    }
46}
47
48/// Configuration for map generation
49#[derive(Debug, Clone, Copy)]
50pub struct Config {
51    pub format: OutputFormat,
52    pub language: Option<Language>,
53    pub max_depth: usize,
54}
55
56impl Default for Config {
57    fn default() -> Self {
58        Self {
59            format: OutputFormat::Json,
60            language: None,
61            max_depth: 10,
62        }
63    }
64}
65
/// Output format for generated maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OutputFormat {
    /// JSON output; this is the format chosen by `Config::default`.
    Json,
    /// Text output.
    Text,
}
72
/// A code symbol in the repository.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Symbol {
    /// The symbol's name.
    pub name: String,
    /// What kind of definition the symbol is.
    pub kind: SymbolKind,
    /// Position of the symbol; presumably within `file` — TODO confirm.
    pub range: FileRange,
    /// Path of the file the symbol belongs to.
    pub file: PathBuf,
}
81
/// Kind of code symbol.
///
/// `SymbolKind::as_str` yields the lowercase variant name (e.g. `"function"`)
/// for the fixed variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SymbolKind {
    Module,
    Function,
    Class,
    Struct,
    Interface,
    Method,
    Property,
    Variable,
    Constant,
    Enum,
    Trait,
    Impl,
    /// A kind not covered by the variants above; the string is used verbatim
    /// as its label.
    Other(String),
}
99
100impl SymbolKind {
101    pub fn as_str(&self) -> &str {
102        match self {
103            SymbolKind::Module => "module",
104            SymbolKind::Function => "function",
105            SymbolKind::Class => "class",
106            SymbolKind::Struct => "struct",
107            SymbolKind::Interface => "interface",
108            SymbolKind::Method => "method",
109            SymbolKind::Property => "property",
110            SymbolKind::Variable => "variable",
111            SymbolKind::Constant => "constant",
112            SymbolKind::Enum => "enum",
113            SymbolKind::Trait => "trait",
114            SymbolKind::Impl => "impl",
115            SymbolKind::Other(s) => s,
116        }
117    }
118}
119
/// File position information: a span from a start line/column to an end
/// line/column.
///
/// NOTE(review): whether lines/columns are 0- or 1-based, and whether the end
/// position is inclusive, is not established here — confirm with the producer.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct FileRange {
    pub start_line: usize,
    pub start_col: usize,
    pub end_line: usize,
    pub end_col: usize,
}
128
/// Repository code map, as returned by `generate_map`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepoMap {
    /// The directory path that was mapped.
    pub root: PathBuf,
    /// Symbols found in the repository. `generate_map` currently leaves this
    /// empty.
    pub symbols: Vec<Symbol>,
    /// One entry per recognised source file.
    pub files: Vec<FileInfo>,
}
136
/// Summary of a single source file included in a [`RepoMap`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileInfo {
    /// Path of the file as discovered while walking the repository.
    pub path: PathBuf,
    /// Language name as produced by `Language::as_str` (e.g. `"rust"`);
    /// `generate_map` always sets this to `Some`.
    pub language: Option<String>,
    /// Number of lines in the file; `generate_map` records 0 if counting fails.
    pub lines: usize,
}
143
144/// Generate a code map for the given directory
145pub fn generate_map(path: &Path, config: Config) -> Result<RepoMap, MapError> {
146    if !path.exists() {
147        return Err(MapError::FileNotFound(path.to_path_buf()));
148    }
149
150    if !path.is_dir() {
151        return Err(MapError::ParseError("Path must be a directory".to_string()));
152    }
153
154    let mut files = Vec::new();
155
156    // Walk directory and collect symbols
157    for entry in walkdir::WalkDir::new(path)
158        .into_iter()
159        .filter_map(Result::ok)
160        .filter(|e| e.path().is_file())
161    {
162        let file_path = entry.path();
163
164        // Skip common non-source directories
165        if let Some(parent) = file_path.parent() {
166            let parent_str = parent.to_string_lossy();
167            if parent_str.contains("node_modules")
168                || parent_str.contains(".git")
169                || parent_str.contains("target")
170                || parent_str.contains("__pycache__")
171            {
172                continue;
173            }
174        }
175
176        let ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
177
178        let language = Language::from_extension(ext);
179
180        // Skip if language filter is set and doesn't match
181        if let Some(filter_lang) = config.language {
182            if language != Some(filter_lang) {
183                continue;
184            }
185        }
186
187        if let Some(lang) = language {
188            let file_info = FileInfo {
189                path: file_path.to_path_buf(),
190                language: Some(lang.as_str().to_string()),
191                lines: count_lines(file_path).unwrap_or(0),
192            };
193            files.push(file_info);
194        }
195    }
196
197    Ok(RepoMap {
198        root: path.to_path_buf(),
199        symbols: Vec::new(),
200        files,
201    })
202}
203
204fn count_lines(path: &Path) -> Result<usize, MapError> {
205    use std::io::BufRead;
206    let file = std::fs::File::open(path)?;
207    let reader = std::io::BufReader::new(file);
208    Ok(reader.lines().count())
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_language_from_extension() {
        assert_eq!(Language::from_extension("rs"), Some(Language::Rust));
        assert_eq!(Language::from_extension("py"), Some(Language::Python));
        assert_eq!(Language::from_extension("js"), Some(Language::JavaScript));
        assert_eq!(Language::from_extension("unknown"), None);
    }

    #[test]
    fn test_language_from_extension_js_family_and_empty() {
        // TypeScript and JSX are grouped with JavaScript.
        for ext in ["jsx", "ts", "tsx"] {
            assert_eq!(
                Language::from_extension(ext),
                Some(Language::JavaScript),
                "extension {}",
                ext
            );
        }
        // A file without an extension is looked up as "".
        assert_eq!(Language::from_extension(""), None);
    }

    #[test]
    fn test_language_as_str() {
        assert_eq!(Language::Rust.as_str(), "rust");
        assert_eq!(Language::Python.as_str(), "python");
        assert_eq!(Language::JavaScript.as_str(), "javascript");
    }

    #[test]
    fn test_symbol_kind_as_str() {
        assert_eq!(SymbolKind::Function.as_str(), "function");
        assert_eq!(SymbolKind::Class.as_str(), "class");
        assert_eq!(SymbolKind::Other("macro".to_string()).as_str(), "macro");
    }

    #[test]
    fn test_config_default() {
        let cfg = Config::default();
        assert_eq!(cfg.format, OutputFormat::Json);
        assert_eq!(cfg.language, None);
        assert_eq!(cfg.max_depth, 10);
    }

    #[test]
    fn test_generate_map_missing_path() {
        let missing = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("no-such-directory-for-collet-treemap-tests");
        match generate_map(&missing, Config::default()) {
            Err(MapError::FileNotFound(p)) => assert_eq!(p, missing),
            other => panic!("expected FileNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_generate_map_rejects_file() {
        // Cargo.toml always exists in the crate root and is not a directory.
        let file = Path::new(env!("CARGO_MANIFEST_DIR")).join("Cargo.toml");
        match generate_map(&file, Config::default()) {
            Err(MapError::ParseError(_)) => {}
            other => panic!("expected ParseError, got {:?}", other),
        }
    }
}