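//! `boundary_rust` — Rust language analyzer for `boundary` (`lib.rs`).
//!
//! Exposes [`RustAnalyzer`], a tree-sitter based [`LanguageAnalyzer`] for `.rs` files.
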
1use std::path::Path;
2
3use anyhow::{Context, Result};
4use tree_sitter::{Language, Parser, Query, QueryCursor, StreamingIterator};
5
6use boundary_core::analyzer::{LanguageAnalyzer, ParsedFile};
7use boundary_core::types::*;
8
9/// Rust language analyzer using tree-sitter.
10pub struct RustAnalyzer {
11    language: Language,
12    trait_query: Query,
13    struct_query: Query,
14    impl_query: Query,
15    use_query: Query,
16}
17
18impl RustAnalyzer {
19    pub fn new() -> Result<Self> {
20        let language: Language = tree_sitter_rust::LANGUAGE.into();
21
22        let trait_query = Query::new(
23            &language,
24            r#"
25            (trait_item
26              name: (type_identifier) @name
27              body: (declaration_list
28                (function_signature_item
29                  name: (identifier) @method)*))
30            "#,
31        )
32        .context("failed to compile trait query")?;
33
34        let struct_query = Query::new(
35            &language,
36            r#"
37            (struct_item
38              name: (type_identifier) @name
39              body: (field_declaration_list
40                (field_declaration
41                  name: (field_identifier) @field
42                  type: (_) @field_type)*)?)
43            "#,
44        )
45        .context("failed to compile struct query")?;
46
47        let impl_query = Query::new(
48            &language,
49            r#"
50            (impl_item
51              trait: (type_identifier)? @trait_name
52              type: (type_identifier) @type_name)
53            "#,
54        )
55        .context("failed to compile impl query")?;
56
57        let use_query = Query::new(
58            &language,
59            r#"
60            (use_declaration
61              argument: (_) @path)
62            "#,
63        )
64        .context("failed to compile use query")?;
65
66        Ok(Self {
67            language,
68            trait_query,
69            struct_query,
70            impl_query,
71            use_query,
72        })
73    }
74}
75
76impl LanguageAnalyzer for RustAnalyzer {
77    fn language(&self) -> &'static str {
78        "rust"
79    }
80
81    fn file_extensions(&self) -> &[&str] {
82        &["rs"]
83    }
84
85    fn parse_file(&self, path: &Path, content: &str) -> Result<ParsedFile> {
86        let mut parser = Parser::new();
87        parser
88            .set_language(&self.language)
89            .context("failed to set Rust language")?;
90        let tree = parser
91            .parse(content, None)
92            .context("failed to parse Rust file")?;
93        Ok(ParsedFile {
94            path: path.to_path_buf(),
95            tree,
96            content: content.to_string(),
97        })
98    }
99
100    fn extract_components(&self, parsed: &ParsedFile) -> Vec<Component> {
101        let mut components = Vec::new();
102        let module_path = derive_module_path(&parsed.path);
103
104        // Extract traits (ports)
105        extract_traits(&self.trait_query, parsed, &module_path, &mut components);
106
107        // Extract structs
108        extract_structs(&self.struct_query, parsed, &module_path, &mut components);
109
110        // Enrich structs with impl info (adapter classification)
111        enrich_with_impls(&self.impl_query, parsed, &module_path, &mut components);
112
113        components
114    }
115
116    fn extract_dependencies(&self, parsed: &ParsedFile) -> Vec<Dependency> {
117        let mut deps = Vec::new();
118        let module_path = derive_module_path(&parsed.path);
119        let from_id = ComponentId::new(&module_path, "<file>");
120
121        let mut cursor = QueryCursor::new();
122        let path_idx = self
123            .use_query
124            .capture_names()
125            .iter()
126            .position(|n| *n == "path")
127            .unwrap_or(0);
128
129        let mut matches = cursor.matches(
130            &self.use_query,
131            parsed.tree.root_node(),
132            parsed.content.as_bytes(),
133        );
134
135        while let Some(m) = matches.next() {
136            for capture in m.captures {
137                if capture.index as usize == path_idx {
138                    let node = capture.node;
139                    let use_path = node_text(node, &parsed.content);
140
141                    // Skip std library imports
142                    if use_path.starts_with("std::") || use_path.starts_with("core::") {
143                        continue;
144                    }
145
146                    let to_id = ComponentId::new(&use_path, "<module>");
147
148                    deps.push(Dependency {
149                        from: from_id.clone(),
150                        to: to_id,
151                        kind: DependencyKind::Import,
152                        location: SourceLocation {
153                            file: parsed.path.clone(),
154                            line: node.start_position().row + 1,
155                            column: node.start_position().column + 1,
156                        },
157                        import_path: Some(use_path),
158                    });
159                }
160            }
161        }
162
163        deps
164    }
165}
166
167fn extract_traits(
168    query: &Query,
169    parsed: &ParsedFile,
170    module_path: &str,
171    components: &mut Vec<Component>,
172) {
173    let mut cursor = QueryCursor::new();
174    let name_idx = query
175        .capture_names()
176        .iter()
177        .position(|n| *n == "name")
178        .unwrap_or(0);
179    let method_idx = query.capture_names().iter().position(|n| *n == "method");
180
181    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
182
183    while let Some(m) = matches.next() {
184        let mut name = String::new();
185        let mut methods = Vec::new();
186        let mut start_row = 0;
187        let mut start_col = 0;
188
189        for capture in m.captures {
190            if capture.index as usize == name_idx {
191                name = node_text(capture.node, &parsed.content);
192                start_row = capture.node.start_position().row;
193                start_col = capture.node.start_position().column;
194            } else if Some(capture.index as usize) == method_idx {
195                methods.push(MethodInfo {
196                    name: node_text(capture.node, &parsed.content),
197                    parameters: String::new(),
198                    return_type: String::new(),
199                });
200            }
201        }
202
203        if name.is_empty() {
204            continue;
205        }
206
207        components.push(Component {
208            id: ComponentId::new(module_path, &name),
209            name: name.clone(),
210            kind: ComponentKind::Port(PortInfo { name, methods }),
211            layer: None,
212            location: SourceLocation {
213                file: parsed.path.clone(),
214                line: start_row + 1,
215                column: start_col + 1,
216            },
217            is_cross_cutting: false,
218            architecture_mode: ArchitectureMode::default(),
219        });
220    }
221}
222
223fn extract_structs(
224    query: &Query,
225    parsed: &ParsedFile,
226    module_path: &str,
227    components: &mut Vec<Component>,
228) {
229    let mut cursor = QueryCursor::new();
230    let name_idx = query
231        .capture_names()
232        .iter()
233        .position(|n| *n == "name")
234        .unwrap_or(0);
235    let field_idx = query.capture_names().iter().position(|n| *n == "field");
236    let field_type_idx = query
237        .capture_names()
238        .iter()
239        .position(|n| *n == "field_type");
240
241    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
242
243    while let Some(m) = matches.next() {
244        let mut name = String::new();
245        let mut fields = Vec::new();
246        let mut start_row = 0;
247        let mut start_col = 0;
248
249        let mut current_field_name = String::new();
250
251        for capture in m.captures {
252            if capture.index as usize == name_idx {
253                name = node_text(capture.node, &parsed.content);
254                start_row = capture.node.start_position().row;
255                start_col = capture.node.start_position().column;
256            } else if Some(capture.index as usize) == field_idx {
257                current_field_name = node_text(capture.node, &parsed.content);
258            } else if Some(capture.index as usize) == field_type_idx {
259                let type_name = node_text(capture.node, &parsed.content);
260                if !current_field_name.is_empty() {
261                    fields.push(FieldInfo {
262                        name: current_field_name.clone(),
263                        type_name,
264                    });
265                    current_field_name = String::new();
266                }
267            }
268        }
269
270        if name.is_empty() {
271            continue;
272        }
273
274        let kind = classify_struct_kind(&name, &fields);
275
276        components.push(Component {
277            id: ComponentId::new(module_path, &name),
278            name: name.clone(),
279            kind,
280            layer: None,
281            location: SourceLocation {
282                file: parsed.path.clone(),
283                line: start_row + 1,
284                column: start_col + 1,
285            },
286            is_cross_cutting: false,
287            architecture_mode: ArchitectureMode::default(),
288        });
289    }
290}
291
292/// Scan impl blocks and upgrade matching structs to Adapter when they implement a trait.
293fn enrich_with_impls(
294    query: &Query,
295    parsed: &ParsedFile,
296    module_path: &str,
297    components: &mut [Component],
298) {
299    let mut cursor = QueryCursor::new();
300    let trait_name_idx = query
301        .capture_names()
302        .iter()
303        .position(|n| *n == "trait_name");
304    let type_name_idx = query
305        .capture_names()
306        .iter()
307        .position(|n| *n == "type_name")
308        .unwrap_or(0);
309
310    let mut matches = cursor.matches(query, parsed.tree.root_node(), parsed.content.as_bytes());
311
312    while let Some(m) = matches.next() {
313        let mut trait_name: Option<String> = None;
314        let mut type_name = String::new();
315
316        for capture in m.captures {
317            if Some(capture.index as usize) == trait_name_idx {
318                trait_name = Some(node_text(capture.node, &parsed.content));
319            }
320            if capture.index as usize == type_name_idx {
321                type_name = node_text(capture.node, &parsed.content);
322            }
323        }
324
325        if type_name.is_empty() {
326            continue;
327        }
328
329        // If this impl has a trait, mark the struct as an Adapter
330        if let Some(ref trait_name) = trait_name {
331            let id = ComponentId::new(module_path, &type_name);
332            if let Some(comp) = components.iter_mut().find(|c| c.id == id) {
333                match &mut comp.kind {
334                    ComponentKind::Adapter(info) => {
335                        if !info.implements.contains(trait_name) {
336                            info.implements.push(trait_name.clone());
337                        }
338                    }
339                    _ => {
340                        comp.kind = ComponentKind::Adapter(AdapterInfo {
341                            name: type_name.clone(),
342                            implements: vec![trait_name.clone()],
343                            confidence: AdapterConfidence::default(),
344                            returns_concrete: None,
345                        });
346                    }
347                }
348            }
349        }
350    }
351}
352
353/// Classify a struct by its name suffix heuristic.
354fn classify_struct_kind(name: &str, fields: &[FieldInfo]) -> ComponentKind {
355    let lower = name.to_lowercase();
356    if lower.ends_with("repository") || lower.ends_with("repo") {
357        ComponentKind::Repository
358    } else if lower.ends_with("service") || lower.ends_with("svc") {
359        ComponentKind::Service
360    } else if lower.ends_with("handler") || lower.ends_with("controller") {
361        ComponentKind::Adapter(AdapterInfo {
362            name: name.to_string(),
363            implements: Vec::new(),
364            confidence: AdapterConfidence::default(),
365            returns_concrete: None,
366        })
367    } else if lower.ends_with("usecase") || lower.ends_with("interactor") {
368        ComponentKind::UseCase
369    } else if lower.ends_with("event") {
370        ComponentKind::DomainEvent(EventInfo {
371            name: name.to_string(),
372            fields: fields.to_vec(),
373        })
374    } else if !fields.is_empty()
375        && !fields.iter().any(|f| {
376            let fl = f.name.to_lowercase();
377            fl == "id" || fl == "uuid"
378        })
379    {
380        ComponentKind::ValueObject
381    } else {
382        ComponentKind::Entity(EntityInfo {
383            name: name.to_string(),
384            fields: fields.to_vec(),
385            methods: Vec::new(),
386            is_active_record: false,
387            is_anemic_domain_model: false,
388        })
389    }
390}
391
392/// Extract text from a tree-sitter node.
393fn node_text(node: tree_sitter::Node, source: &str) -> String {
394    source[node.byte_range()].to_string()
395}
396
/// Derives a module path from a file path by dropping the file name and
/// normalising `\` separators to `/`.
///
/// e.g. `"src/domain/user/mod.rs"` -> `"src/domain/user"`. A bare file name
/// such as `"lib.rs"` has an empty parent and yields `""`. A path with no
/// parent at all (`""` or `"/"`) is returned as-is, with separators normalised.
fn derive_module_path(path: &Path) -> String {
    // Only the chosen path is converted, so no string is built and discarded.
    path.parent()
        .unwrap_or(path)
        .to_string_lossy()
        .replace('\\', "/")
}
408
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Runs component extraction over an in-memory fixture located at `path`.
    fn components_of(path: &str, content: &str) -> Vec<Component> {
        let analyzer = RustAnalyzer::new().unwrap();
        let parsed = analyzer.parse_file(&PathBuf::from(path), content).unwrap();
        analyzer.extract_components(&parsed)
    }

    /// Returns the first component called `name`, if any.
    fn named<'a>(components: &'a [Component], name: &str) -> Option<&'a Component> {
        components.iter().find(|c| c.name == name)
    }

    #[test]
    fn test_parse_simple_rust_file() {
        let content = r#"
pub trait UserRepository {
    fn save(&self, user: &User) -> Result<(), Error>;
    fn find_by_id(&self, id: &str) -> Result<User, Error>;
}

pub struct User {
    pub id: String,
    pub name: String,
}
"#;
        let components = components_of("src/domain/user/mod.rs", content);

        assert!(
            components.len() >= 2,
            "expected at least 2 components, got {}",
            components.len()
        );

        let Some(port) = named(&components, "UserRepository") else {
            panic!("should find UserRepository trait");
        };
        assert!(matches!(port.kind, ComponentKind::Port(_)));
        if let ComponentKind::Port(info) = &port.kind {
            let declares = |method: &str| info.methods.iter().any(|m| m.name == method);
            assert!(declares("save"));
            assert!(declares("find_by_id"));
        }

        assert!(named(&components, "User").is_some(), "should find User struct");
    }

    #[test]
    fn test_extract_use_statements() {
        let content = r#"
use std::collections::HashMap;
use crate::domain::user::User;
use crate::infrastructure::postgres::PostgresRepo;
"#;
        let analyzer = RustAnalyzer::new().unwrap();
        let parsed = analyzer
            .parse_file(&PathBuf::from("src/application/user_service.rs"), content)
            .unwrap();
        let deps = analyzer.extract_dependencies(&parsed);

        // std imports are filtered out; project imports are kept.
        let imported: Vec<&str> = deps
            .iter()
            .filter_map(|d| d.import_path.as_deref())
            .collect();
        assert!(imported.iter().all(|p| !p.starts_with("std::")));
        assert!(imported.iter().any(|p| p.contains("domain::user::User")));
        assert!(imported
            .iter()
            .any(|p| p.contains("infrastructure::postgres::PostgresRepo")));
    }

    #[test]
    fn test_struct_classification() {
        let content = r#"
pub struct PostgresUserRepository {
    pool: Pool,
}

pub struct UserService {
    repo: Box<dyn UserRepository>,
}

pub struct HttpHandler {
    service: UserService,
}

pub struct CreateUserUseCase {
    repo: Box<dyn UserRepository>,
}
"#;
        let components = components_of("src/lib.rs", content);

        assert!(matches!(
            named(&components, "PostgresUserRepository").unwrap().kind,
            ComponentKind::Repository
        ));
        assert!(matches!(
            named(&components, "UserService").unwrap().kind,
            ComponentKind::Service
        ));
        assert!(matches!(
            named(&components, "HttpHandler").unwrap().kind,
            ComponentKind::Adapter(_)
        ));
        assert!(matches!(
            named(&components, "CreateUserUseCase").unwrap().kind,
            ComponentKind::UseCase
        ));
    }

    #[test]
    fn test_impl_trait_enrichment() {
        let content = r#"
pub trait UserRepository {
    fn save(&self, user: &User);
}

pub struct PostgresRepo {
    pool: Pool,
}

impl UserRepository for PostgresRepo {
    fn save(&self, user: &User) {}
}
"#;
        let components = components_of("src/infrastructure/postgres.rs", content);

        let Some(repo) = named(&components, "PostgresRepo") else {
            panic!("should find PostgresRepo");
        };
        match &repo.kind {
            ComponentKind::Adapter(info) => assert!(
                info.implements.iter().any(|t| t == "UserRepository"),
                "should track implemented trait"
            ),
            other => panic!("expected Adapter, got {:?}", other),
        }
    }
}