Skip to main content

boundary_rust/
lib.rs

1use std::path::Path;
2
3use anyhow::{Context, Result};
4use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
5
6use boundary_core::analyzer::{LanguageAnalyzer, ParsedFile};
7use boundary_core::types::*;
8
/// Rust language analyzer using tree-sitter.
///
/// Holds the tree-sitter Rust grammar together with the queries compiled once
/// in [`RustAnalyzer::new`], so they can be reused for every analyzed file.
pub struct RustAnalyzer {
    /// Grammar handed to the parser created in `parse_file`.
    language: Language,
    /// Matches `trait_item` nodes; captures the trait name and method names.
    trait_query: Query,
    /// Matches `struct_item` nodes; captures the name plus field names and types.
    struct_query: Query,
    /// Matches `impl_item` nodes; captures the optional trait and the implementing type.
    impl_query: Query,
    /// Matches `use_declaration` nodes; captures the imported path.
    use_query: Query,
}
17
18impl RustAnalyzer {
19    pub fn new() -> Result<Self> {
20        let language: Language = tree_sitter_rust::LANGUAGE.into();
21
22        let trait_query = Query::new(
23            &language,
24            r#"
25            (trait_item
26              name: (type_identifier) @name
27              body: (declaration_list
28                (function_signature_item
29                  name: (identifier) @method)*))
30            "#,
31        )
32        .context("failed to compile trait query")?;
33
34        let struct_query = Query::new(
35            &language,
36            r#"
37            (struct_item
38              name: (type_identifier) @name
39              body: (field_declaration_list
40                (field_declaration
41                  name: (field_identifier) @field
42                  type: (_) @field_type)*)?)
43            "#,
44        )
45        .context("failed to compile struct query")?;
46
47        let impl_query = Query::new(
48            &language,
49            r#"
50            (impl_item
51              trait: (type_identifier)? @trait_name
52              type: (type_identifier) @type_name)
53            "#,
54        )
55        .context("failed to compile impl query")?;
56
57        let use_query = Query::new(
58            &language,
59            r#"
60            (use_declaration
61              argument: (_) @path)
62            "#,
63        )
64        .context("failed to compile use query")?;
65
66        Ok(Self {
67            language,
68            trait_query,
69            struct_query,
70            impl_query,
71            use_query,
72        })
73    }
74}
75
76impl LanguageAnalyzer for RustAnalyzer {
77    fn language(&self) -> &'static str {
78        "rust"
79    }
80
81    fn file_extensions(&self) -> &[&str] {
82        &["rs"]
83    }
84
85    fn parse_file(&self, path: &Path, content: &str) -> Result<ParsedFile> {
86        let mut parser = Parser::new();
87        parser
88            .set_language(&self.language)
89            .context("failed to set Rust language")?;
90        let tree = parser
91            .parse(content, None)
92            .context("failed to parse Rust file")?;
93        Ok(ParsedFile {
94            path: path.to_path_buf(),
95            tree,
96            content: content.to_string(),
97        })
98    }
99
100    fn extract_components(&self, parsed: &ParsedFile) -> Vec<Component> {
101        let mut components = Vec::new();
102        let module_path = derive_module_path(&parsed.path);
103
104        // Extract traits (ports)
105        extract_traits(&self.trait_query, parsed, &module_path, &mut components);
106
107        // Extract structs
108        extract_structs(&self.struct_query, parsed, &module_path, &mut components);
109
110        // Enrich structs with impl info (adapter classification)
111        enrich_with_impls(&self.impl_query, parsed, &module_path, &mut components);
112
113        components
114    }
115
116    fn extract_dependencies(&self, parsed: &ParsedFile) -> Vec<Dependency> {
117        let mut deps = Vec::new();
118        let module_path = derive_module_path(&parsed.path);
119        let from_id = ComponentId::new(&module_path, "<file>");
120
121        let mut cursor = QueryCursor::new();
122        let path_idx = self
123            .use_query
124            .capture_names()
125            .iter()
126            .position(|n| *n == "path")
127            .unwrap_or(0);
128
129        let mut matches = cursor.matches(
130            &self.use_query,
131            parsed.tree.root_node(),
132            parsed.content.as_bytes(),
133        );
134
135        while let Some(m) = matches.next() {
136            for capture in m.captures {
137                if capture.index as usize == path_idx {
138                    let node = capture.node;
139                    let use_path = node_text(node, &parsed.content);
140
141                    // Skip std library imports
142                    if use_path.starts_with("std::") || use_path.starts_with("core::") {
143                        continue;
144                    }
145
146                    let to_id = ComponentId::new(&use_path, "<module>");
147
148                    deps.push(Dependency {
149                        from: from_id.clone(),
150                        to: to_id,
151                        kind: DependencyKind::Import,
152                        location: SourceLocation {
153                            file: parsed.path.clone(),
154                            line: node.start_position().row + 1,
155                            column: node.start_position().column + 1,
156                        },
157                        import_path: Some(use_path),
158                    });
159                }
160            }
161        }
162
163        deps
164    }
165}
166
167fn extract_traits(
168    query: &Query,
169    parsed: &ParsedFile,
170    module_path: &str,
171    components: &mut Vec<Component>,
172) {
173    let mut cursor = QueryCursor::new();
174    let name_idx = query
175        .capture_names()
176        .iter()
177        .position(|n| *n == "name")
178        .unwrap_or(0);
179    let method_idx = query.capture_names().iter().position(|n| *n == "method");
180
181    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
182
183    while let Some(m) = matches.next() {
184        let mut name = String::new();
185        let mut methods = Vec::new();
186        let mut start_row = 0;
187        let mut start_col = 0;
188
189        for capture in m.captures {
190            if capture.index as usize == name_idx {
191                name = node_text(capture.node, &parsed.content);
192                start_row = capture.node.start_position().row;
193                start_col = capture.node.start_position().column;
194            } else if Some(capture.index as usize) == method_idx {
195                methods.push(MethodInfo {
196                    name: node_text(capture.node, &parsed.content),
197                    parameters: String::new(),
198                    return_type: String::new(),
199                });
200            }
201        }
202
203        if name.is_empty() {
204            continue;
205        }
206
207        components.push(Component {
208            id: ComponentId::new(module_path, &name),
209            name: name.clone(),
210            kind: ComponentKind::Port(PortInfo { name, methods }),
211            layer: None,
212            location: SourceLocation {
213                file: parsed.path.clone(),
214                line: start_row + 1,
215                column: start_col + 1,
216            },
217            is_cross_cutting: false,
218            architecture_mode: ArchitectureMode::default(),
219        });
220    }
221}
222
223fn extract_structs(
224    query: &Query,
225    parsed: &ParsedFile,
226    module_path: &str,
227    components: &mut Vec<Component>,
228) {
229    let mut cursor = QueryCursor::new();
230    let name_idx = query
231        .capture_names()
232        .iter()
233        .position(|n| *n == "name")
234        .unwrap_or(0);
235    let field_idx = query.capture_names().iter().position(|n| *n == "field");
236    let field_type_idx = query
237        .capture_names()
238        .iter()
239        .position(|n| *n == "field_type");
240
241    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
242
243    while let Some(m) = matches.next() {
244        let mut name = String::new();
245        let mut fields = Vec::new();
246        let mut start_row = 0;
247        let mut start_col = 0;
248
249        let mut current_field_name = String::new();
250
251        for capture in m.captures {
252            if capture.index as usize == name_idx {
253                name = node_text(capture.node, &parsed.content);
254                start_row = capture.node.start_position().row;
255                start_col = capture.node.start_position().column;
256            } else if Some(capture.index as usize) == field_idx {
257                current_field_name = node_text(capture.node, &parsed.content);
258            } else if Some(capture.index as usize) == field_type_idx {
259                let type_name = node_text(capture.node, &parsed.content);
260                if !current_field_name.is_empty() {
261                    fields.push(FieldInfo {
262                        name: current_field_name.clone(),
263                        type_name,
264                    });
265                    current_field_name = String::new();
266                }
267            }
268        }
269
270        if name.is_empty() {
271            continue;
272        }
273
274        let kind = classify_struct_kind(&name, &fields);
275
276        components.push(Component {
277            id: ComponentId::new(module_path, &name),
278            name: name.clone(),
279            kind,
280            layer: None,
281            location: SourceLocation {
282                file: parsed.path.clone(),
283                line: start_row + 1,
284                column: start_col + 1,
285            },
286            is_cross_cutting: false,
287            architecture_mode: ArchitectureMode::default(),
288        });
289    }
290}
291
292/// Scan impl blocks and upgrade matching structs to Adapter when they implement a trait.
293fn enrich_with_impls(
294    query: &Query,
295    parsed: &ParsedFile,
296    module_path: &str,
297    components: &mut [Component],
298) {
299    let mut cursor = QueryCursor::new();
300    let trait_name_idx = query
301        .capture_names()
302        .iter()
303        .position(|n| *n == "trait_name");
304    let type_name_idx = query
305        .capture_names()
306        .iter()
307        .position(|n| *n == "type_name")
308        .unwrap_or(0);
309
310    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
311
312    while let Some(m) = matches.next() {
313        let mut trait_name: Option<String> = None;
314        let mut type_name = String::new();
315
316        for capture in m.captures {
317            if Some(capture.index as usize) == trait_name_idx {
318                trait_name = Some(node_text(capture.node, &parsed.content));
319            }
320            if capture.index as usize == type_name_idx {
321                type_name = node_text(capture.node, &parsed.content);
322            }
323        }
324
325        if type_name.is_empty() {
326            continue;
327        }
328
329        // If this impl has a trait, mark the struct as an Adapter
330        if let Some(ref trait_name) = trait_name {
331            let id = ComponentId::new(module_path, &type_name);
332            if let Some(comp) = components.iter_mut().find(|c| c.id == id) {
333                match &mut comp.kind {
334                    ComponentKind::Adapter(info) => {
335                        if !info.implements.contains(trait_name) {
336                            info.implements.push(trait_name.clone());
337                        }
338                    }
339                    _ => {
340                        comp.kind = ComponentKind::Adapter(AdapterInfo {
341                            name: type_name.clone(),
342                            implements: vec![trait_name.clone()],
343                        });
344                    }
345                }
346            }
347        }
348    }
349}
350
351/// Classify a struct by its name suffix heuristic.
352fn classify_struct_kind(name: &str, fields: &[FieldInfo]) -> ComponentKind {
353    let lower = name.to_lowercase();
354    if lower.ends_with("repository") || lower.ends_with("repo") {
355        ComponentKind::Repository
356    } else if lower.ends_with("service") || lower.ends_with("svc") {
357        ComponentKind::Service
358    } else if lower.ends_with("handler") || lower.ends_with("controller") {
359        ComponentKind::Adapter(AdapterInfo {
360            name: name.to_string(),
361            implements: Vec::new(),
362        })
363    } else if lower.ends_with("usecase") || lower.ends_with("interactor") {
364        ComponentKind::UseCase
365    } else if lower.ends_with("event") {
366        ComponentKind::DomainEvent(EventInfo {
367            name: name.to_string(),
368            fields: fields.to_vec(),
369        })
370    } else if !fields.is_empty()
371        && !fields.iter().any(|f| {
372            let fl = f.name.to_lowercase();
373            fl == "id" || fl == "uuid"
374        })
375    {
376        ComponentKind::ValueObject
377    } else {
378        ComponentKind::Entity(EntityInfo {
379            name: name.to_string(),
380            fields: fields.to_vec(),
381            methods: Vec::new(),
382            is_active_record: false,
383        })
384    }
385}
386
387/// Extract text from a tree-sitter node.
388fn node_text(node: tree_sitter::Node, source: &str) -> String {
389    source[node.byte_range()].to_string()
390}
391
/// Derive a module path from a file path by dropping the file name.
/// e.g., "src/domain/user/mod.rs" -> "src/domain/user"
///
/// Backslashes in the result are rewritten to forward slashes. A path with
/// no parent (empty or a bare root) is returned as-is, after the same
/// rewrite.
fn derive_module_path(path: &Path) -> String {
    let normalize = |p: &Path| p.to_string_lossy().replace('\\', "/");

    match path.parent() {
        Some(directory) => normalize(directory),
        None => normalize(path),
    }
}
403
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Build an analyzer and parse `source` as if it lived at `path`.
    fn parse(path: &str, source: &str) -> (RustAnalyzer, ParsedFile) {
        let analyzer = RustAnalyzer::new().unwrap();
        let file = PathBuf::from(path);
        let parsed = analyzer.parse_file(&file, source).unwrap();
        (analyzer, parsed)
    }

    /// Parse `source` and return the components extracted from it.
    fn components_of(path: &str, source: &str) -> Vec<Component> {
        let (analyzer, parsed) = parse(path, source);
        analyzer.extract_components(&parsed)
    }

    /// Kind of the component called `name`; panics if there is none.
    fn kind_of<'a>(components: &'a [Component], name: &str) -> &'a ComponentKind {
        &components.iter().find(|c| c.name == name).unwrap().kind
    }

    /// A trait and a struct in one file are both reported, and the trait
    /// becomes a port listing its methods.
    #[test]
    fn test_parse_simple_rust_file() {
        let components = components_of(
            "src/domain/user/mod.rs",
            r#"
pub trait UserRepository {
    fn save(&self, user: &User) -> Result<(), Error>;
    fn find_by_id(&self, id: &str) -> Result<User, Error>;
}

pub struct User {
    pub id: String,
    pub name: String,
}
"#,
        );

        assert!(
            components.len() >= 2,
            "expected at least 2 components, got {}",
            components.len()
        );

        let trait_comp = components.iter().find(|c| c.name == "UserRepository");
        assert!(trait_comp.is_some(), "should find UserRepository trait");
        let port = trait_comp.unwrap();
        assert!(matches!(port.kind, ComponentKind::Port(_)));

        if let ComponentKind::Port(info) = &port.kind {
            for expected in ["save", "find_by_id"].iter() {
                assert!(info.methods.iter().any(|m| m.name == *expected));
            }
        }

        let entity = components.iter().find(|c| c.name == "User");
        assert!(entity.is_some(), "should find User struct");
    }

    /// `use` declarations become dependencies, except standard-library ones.
    #[test]
    fn test_extract_use_statements() {
        let (analyzer, parsed) = parse(
            "src/application/user_service.rs",
            r#"
use std::collections::HashMap;
use crate::domain::user::User;
use crate::infrastructure::postgres::PostgresRepo;
"#,
        );
        let deps = analyzer.extract_dependencies(&parsed);

        // Should skip std imports
        let imported: Vec<&str> = deps
            .iter()
            .filter_map(|dep| dep.import_path.as_deref())
            .collect();
        assert!(!imported.iter().any(|p| p.starts_with("std::")));
        assert!(imported.iter().any(|p| p.contains("domain::user::User")));
        assert!(imported
            .iter()
            .any(|p| p.contains("infrastructure::postgres::PostgresRepo")));
    }

    /// Struct kinds follow the name-suffix heuristic.
    #[test]
    fn test_struct_classification() {
        let components = components_of(
            "src/lib.rs",
            r#"
pub struct PostgresUserRepository {
    pool: Pool,
}

pub struct UserService {
    repo: Box<dyn UserRepository>,
}

pub struct HttpHandler {
    service: UserService,
}

pub struct CreateUserUseCase {
    repo: Box<dyn UserRepository>,
}
"#,
        );

        assert!(matches!(
            kind_of(&components, "PostgresUserRepository"),
            ComponentKind::Repository
        ));
        assert!(matches!(
            kind_of(&components, "UserService"),
            ComponentKind::Service
        ));
        assert!(matches!(
            kind_of(&components, "HttpHandler"),
            ComponentKind::Adapter(_)
        ));
        assert!(matches!(
            kind_of(&components, "CreateUserUseCase"),
            ComponentKind::UseCase
        ));
    }

    /// A struct with an `impl Trait for` block is upgraded to an adapter that
    /// records the implemented trait.
    #[test]
    fn test_impl_trait_enrichment() {
        let components = components_of(
            "src/infrastructure/postgres.rs",
            r#"
pub trait UserRepository {
    fn save(&self, user: &User);
}

pub struct PostgresRepo {
    pool: Pool,
}

impl UserRepository for PostgresRepo {
    fn save(&self, user: &User) {}
}
"#,
        );

        let repo = components.iter().find(|c| c.name == "PostgresRepo");
        assert!(repo.is_some(), "should find PostgresRepo");
        match &repo.unwrap().kind {
            ComponentKind::Adapter(info) => {
                assert!(
                    info.implements.contains(&"UserRepository".to_string()),
                    "should track implemented trait"
                );
            }
            other => panic!("expected Adapter, got {:?}", other),
        }
    }
}