codeprism_core/ast/
mod.rs

//! Universal Abstract Syntax Tree (U-AST) types
//!
//! This module defines language-agnostic AST node and edge types that can
//! represent code structures from any supported programming language.
5
6use blake3::Hasher;
7use serde::{Deserialize, Serialize};
8use std::fmt;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12/// Unique identifier for AST nodes
13#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14pub struct NodeId([u8; 16]);
15
16impl NodeId {
17    /// Create a new NodeId from components
18    pub fn new(repo_id: &str, file_path: &Path, span: &Span, kind: &NodeKind) -> Self {
19        let mut hasher = Hasher::new();
20        hasher.update(repo_id.as_bytes());
21        hasher.update(file_path.to_string_lossy().as_bytes());
22        hasher.update(&span.start_byte.to_le_bytes());
23        hasher.update(&span.end_byte.to_le_bytes());
24        hasher.update(format!("{kind:?}").as_bytes());
25
26        let hash = hasher.finalize();
27        let mut id = [0u8; 16];
28        id.copy_from_slice(&hash.as_bytes()[..16]);
29        Self(id)
30    }
31
32    /// Get the ID as a hex string
33    pub fn to_hex(&self) -> String {
34        hex::encode(self.0)
35    }
36
37    /// Parse a NodeId from a hex string
38    pub fn from_hex(hex_str: &str) -> Result<Self, hex::FromHexError> {
39        let bytes = hex::decode(hex_str)?;
40        if bytes.len() != 16 {
41            return Err(hex::FromHexError::InvalidStringLength);
42        }
43        let mut id = [0u8; 16];
44        id.copy_from_slice(&bytes);
45        Ok(Self(id))
46    }
47}
48
49impl fmt::Debug for NodeId {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(f, "NodeId({})", &self.to_hex()[..8])
52    }
53}
54
55impl fmt::Display for NodeId {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(f, "{}", &self.to_hex()[..8])
58    }
59}
60
61/// Types of nodes in the Universal AST
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
63#[serde(rename_all = "snake_case")]
64pub enum NodeKind {
65    /// A module or file
66    Module,
67    /// A class definition
68    Class,
69    /// A function definition
70    Function,
71    /// A method definition
72    Method,
73    /// A function/method parameter
74    Parameter,
75    /// A variable declaration
76    Variable,
77    /// A function/method call
78    Call,
79    /// An import statement
80    Import,
81    /// A literal value
82    Literal,
83    /// An HTTP route definition
84    Route,
85    /// A SQL query
86    SqlQuery,
87    /// An event emission
88    Event,
89    /// Unknown node type
90    Unknown,
91}
92
93impl fmt::Display for NodeKind {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        match self {
96            NodeKind::Module => write!(f, "Module"),
97            NodeKind::Class => write!(f, "Class"),
98            NodeKind::Function => write!(f, "Function"),
99            NodeKind::Method => write!(f, "Method"),
100            NodeKind::Parameter => write!(f, "Parameter"),
101            NodeKind::Variable => write!(f, "Variable"),
102            NodeKind::Call => write!(f, "Call"),
103            NodeKind::Import => write!(f, "Import"),
104            NodeKind::Literal => write!(f, "Literal"),
105            NodeKind::Route => write!(f, "Route"),
106            NodeKind::SqlQuery => write!(f, "SqlQuery"),
107            NodeKind::Event => write!(f, "Event"),
108            NodeKind::Unknown => write!(f, "Unknown"),
109        }
110    }
111}
112
113/// Types of edges between nodes
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
116pub enum EdgeKind {
117    /// Function/method call
118    Calls,
119    /// Variable/field read
120    Reads,
121    /// Variable/field write
122    Writes,
123    /// Module import
124    Imports,
125    /// Event emission
126    Emits,
127    /// HTTP route mapping
128    RoutesTo,
129    /// Exception raising
130    Raises,
131    /// Type inheritance
132    Extends,
133    /// Interface implementation
134    Implements,
135}
136
137impl fmt::Display for EdgeKind {
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        match self {
140            EdgeKind::Calls => write!(f, "CALLS"),
141            EdgeKind::Reads => write!(f, "READS"),
142            EdgeKind::Writes => write!(f, "WRITES"),
143            EdgeKind::Imports => write!(f, "IMPORTS"),
144            EdgeKind::Emits => write!(f, "EMITS"),
145            EdgeKind::RoutesTo => write!(f, "ROUTES_TO"),
146            EdgeKind::Raises => write!(f, "RAISES"),
147            EdgeKind::Extends => write!(f, "EXTENDS"),
148            EdgeKind::Implements => write!(f, "IMPLEMENTS"),
149        }
150    }
151}
152
153/// Source code location
154#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
155pub struct Span {
156    /// Starting byte offset
157    pub start_byte: usize,
158    /// Ending byte offset (exclusive)
159    pub end_byte: usize,
160    /// Starting line (1-indexed)
161    pub start_line: usize,
162    /// Ending line (1-indexed)
163    pub end_line: usize,
164    /// Starting column (1-indexed)
165    pub start_column: usize,
166    /// Ending column (1-indexed)
167    pub end_column: usize,
168}
169
170impl Span {
171    /// Create a new span
172    pub fn new(
173        start_byte: usize,
174        end_byte: usize,
175        start_line: usize,
176        end_line: usize,
177        start_column: usize,
178        end_column: usize,
179    ) -> Self {
180        Self {
181            start_byte,
182            end_byte,
183            start_line,
184            end_line,
185            start_column,
186            end_column,
187        }
188    }
189
190    /// Get the length in bytes
191    pub fn len(&self) -> usize {
192        self.end_byte - self.start_byte
193    }
194
195    /// Check if the span is empty
196    pub fn is_empty(&self) -> bool {
197        self.start_byte == self.end_byte
198    }
199}
200
201impl fmt::Display for Span {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        write!(
204            f,
205            "{}:{}-{}:{}",
206            self.start_line, self.start_column, self.end_line, self.end_column
207        )
208    }
209}
210
211/// Programming language
212#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
213#[serde(rename_all = "lowercase")]
214pub enum Language {
215    /// JavaScript
216    JavaScript,
217    /// TypeScript
218    TypeScript,
219    /// Python
220    Python,
221    /// Java
222    Java,
223    /// Go
224    Go,
225    /// Rust
226    Rust,
227    /// C
228    C,
229    /// C++
230    Cpp,
231    /// Unknown language
232    Unknown,
233}
234
235impl Language {
236    /// Get language from file extension
237    pub fn from_extension(ext: &str) -> Self {
238        match ext.to_lowercase().as_str() {
239            "js" | "mjs" | "cjs" => Language::JavaScript,
240            "ts" | "tsx" => Language::TypeScript,
241            "py" | "pyw" => Language::Python,
242            "java" => Language::Java,
243            "go" => Language::Go,
244            "rs" => Language::Rust,
245            "c" | "h" => Language::C,
246            "cpp" | "cc" | "cxx" | "hpp" | "hxx" => Language::Cpp,
247            _ => Language::Unknown,
248        }
249    }
250}
251
252impl fmt::Display for Language {
253    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254        match self {
255            Language::JavaScript => write!(f, "JavaScript"),
256            Language::TypeScript => write!(f, "TypeScript"),
257            Language::Python => write!(f, "Python"),
258            Language::Java => write!(f, "Java"),
259            Language::Go => write!(f, "Go"),
260            Language::Rust => write!(f, "Rust"),
261            Language::C => write!(f, "C"),
262            Language::Cpp => write!(f, "C++"),
263            Language::Unknown => write!(f, "Unknown"),
264        }
265    }
266}
267
268/// A node in the Universal AST
269#[derive(Debug, Clone, Serialize, Deserialize)]
270pub struct Node {
271    /// Unique identifier
272    pub id: NodeId,
273    /// Node type
274    pub kind: NodeKind,
275    /// Node name (e.g., function name)
276    pub name: String,
277    /// Programming language
278    pub lang: Language,
279    /// Source file path
280    pub file: PathBuf,
281    /// Source location
282    pub span: Span,
283    /// Optional type signature
284    pub signature: Option<String>,
285    /// Additional metadata
286    pub metadata: serde_json::Value,
287}
288
289impl Node {
290    /// Create a new node
291    pub fn new(
292        repo_id: &str,
293        kind: NodeKind,
294        name: String,
295        lang: Language,
296        file: PathBuf,
297        span: Span,
298    ) -> Self {
299        let id = NodeId::new(repo_id, &file, &span, &kind);
300        Self {
301            id,
302            kind,
303            name,
304            lang,
305            file,
306            span,
307            signature: None,
308            metadata: serde_json::Value::Null,
309        }
310    }
311
312    /// Create a new node with an `Arc<PathBuf>`
313    pub fn with_arc(
314        repo_id: &str,
315        kind: NodeKind,
316        name: String,
317        lang: Language,
318        file: Arc<PathBuf>,
319        span: Span,
320    ) -> Self {
321        let id = NodeId::new(repo_id, &file, &span, &kind);
322        Self {
323            id,
324            kind,
325            name,
326            lang,
327            file: (*file).clone(),
328            span,
329            signature: None,
330            metadata: serde_json::Value::Null,
331        }
332    }
333
334    /// Set the type signature
335    pub fn with_signature(mut self, sig: String) -> Self {
336        self.signature = Some(sig);
337        self
338    }
339
340    /// Set metadata
341    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
342        self.metadata = metadata;
343        self
344    }
345}
346
347impl fmt::Display for Node {
348    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349        write!(
350            f,
351            "{} {} '{}' at {}:{}",
352            self.lang,
353            self.kind,
354            self.name,
355            self.file.display(),
356            self.span
357        )
358    }
359}
360
361/// Builder for creating nodes
362pub struct NodeBuilder {
363    repo_id: String,
364    kind: NodeKind,
365    name: String,
366    lang: Language,
367    file: PathBuf,
368    span: Span,
369    signature: Option<String>,
370    metadata: serde_json::Value,
371}
372
373impl NodeBuilder {
374    /// Create a new node builder
375    pub fn new(repo_id: impl Into<String>, kind: NodeKind) -> Self {
376        Self {
377            repo_id: repo_id.into(),
378            kind,
379            name: String::new(),
380            lang: Language::Unknown,
381            file: PathBuf::new(),
382            span: Span::new(0, 0, 1, 1, 1, 1),
383            signature: None,
384            metadata: serde_json::Value::Null,
385        }
386    }
387
388    /// Set the node name
389    pub fn name(mut self, name: impl Into<String>) -> Self {
390        self.name = name.into();
391        self
392    }
393
394    /// Set the language
395    pub fn language(mut self, lang: Language) -> Self {
396        self.lang = lang;
397        self
398    }
399
400    /// Set the file path
401    pub fn file(mut self, file: impl Into<PathBuf>) -> Self {
402        self.file = file.into();
403        self
404    }
405
406    /// Set the span
407    pub fn span(mut self, span: Span) -> Self {
408        self.span = span;
409        self
410    }
411
412    /// Set the type signature
413    pub fn signature(mut self, sig: impl Into<String>) -> Self {
414        self.signature = Some(sig.into());
415        self
416    }
417
418    /// Set metadata
419    pub fn metadata(mut self, metadata: serde_json::Value) -> Self {
420        self.metadata = metadata;
421        self
422    }
423
424    /// Build the node
425    pub fn build(self) -> Node {
426        let id = NodeId::new(&self.repo_id, &self.file, &self.span, &self.kind);
427        Node {
428            id,
429            kind: self.kind,
430            name: self.name,
431            lang: self.lang,
432            file: self.file,
433            span: self.span,
434            signature: self.signature,
435            metadata: self.metadata,
436        }
437    }
438}
439
440/// An edge between nodes
441#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
442pub struct Edge {
443    /// Source node ID
444    pub source: NodeId,
445    /// Target node ID
446    pub target: NodeId,
447    /// Edge type
448    pub kind: EdgeKind,
449}
450
451impl Edge {
452    /// Create a new edge
453    pub fn new(source: NodeId, target: NodeId, kind: EdgeKind) -> Self {
454        Self {
455            source,
456            target,
457            kind,
458        }
459    }
460
461    /// Get a unique ID for this edge
462    pub fn id(&self) -> String {
463        format!("{}>{}>:{:?}", self.source, self.target, self.kind)
464    }
465}
466
467impl fmt::Display for Edge {
468    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469        write!(f, "{} --{}-> {}", self.source, self.kind, self.target)
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Span used by most tests: bytes 0..10, entirely on line 1.
    fn sample_span() -> Span {
        Span::new(0, 10, 1, 1, 1, 11)
    }

    /// JavaScript function node shared by the serialization and `with_*` tests.
    fn sample_function_node() -> Node {
        Node::new(
            "test_repo",
            NodeKind::Function,
            "test_func".to_string(),
            Language::JavaScript,
            PathBuf::from("test.js"),
            sample_span(),
        )
    }

    #[test]
    fn test_node_id_generation() {
        let span = sample_span();
        let file = Path::new("file.js");
        let id_for = |repo: &str| NodeId::new(repo, file, &span, &NodeKind::Function);

        // Identical inputs produce identical IDs.
        let first = id_for("repo1");
        let second = id_for("repo1");
        assert_eq!(first, second);

        // A different repository produces a different ID.
        let other_repo = id_for("repo2");
        assert_ne!(first, other_repo);
    }

    #[test]
    fn test_node_id_edge_cases() {
        let span = sample_span();

        // (path, message if hex is empty, message if hex length is wrong)
        let cases = [
            // Empty path
            (
                "",
                "Node ID hex representation should not be empty",
                "Node ID hex should be 32 characters (16 bytes * 2)",
            ),
            // Path with special characters
            (
                "src/@types/index.d.ts",
                "Node ID with special chars should have valid hex",
                "Node ID hex should be 32 characters",
            ),
            // Unicode in path
            (
                "src/文件.js",
                "Node ID with Unicode should have valid hex",
                "Node ID hex should be 32 characters",
            ),
        ];

        for &(path, empty_msg, length_msg) in cases.iter() {
            let id = NodeId::new("repo", Path::new(path), &span, &NodeKind::Module);
            assert!(!id.to_hex().is_empty(), "{}", empty_msg);
            assert_eq!(id.to_hex().len(), 32, "{}", length_msg);
        }
    }

    #[test]
    fn test_language_detection() {
        let expectations = [
            ("js", Language::JavaScript),
            ("ts", Language::TypeScript),
            ("py", Language::Python),
            ("java", Language::Java),
            ("unknown", Language::Unknown),
        ];

        for &(ext, expected) in expectations.iter() {
            assert_eq!(Language::from_extension(ext), expected);
        }
    }

    #[test]
    fn test_language_detection_edge_cases() {
        let expectations = [
            // Case insensitive
            ("JS", Language::JavaScript),
            ("Py", Language::Python),
            // Multiple extensions
            ("mjs", Language::JavaScript),
            ("cjs", Language::JavaScript),
            ("tsx", Language::TypeScript),
            // C++ variations
            ("cpp", Language::Cpp),
            ("cc", Language::Cpp),
            ("cxx", Language::Cpp),
            ("hpp", Language::Cpp),
            // Empty and unknown
            ("", Language::Unknown),
            ("xyz", Language::Unknown),
        ];

        for &(ext, expected) in expectations.iter() {
            assert_eq!(Language::from_extension(ext), expected);
        }
    }

    #[test]
    fn test_span_utilities() {
        // Non-empty span covering bytes 10..20.
        let filled = Span::new(10, 20, 2, 3, 5, 15);
        assert_eq!(filled.len(), 10, "Should have 10 items");
        assert!(
            !filled.is_empty(),
            "Span with non-zero range should not be empty"
        );
        assert_eq!(filled.start_byte, 10, "Span should start at byte 10");
        assert_eq!(filled.end_byte, 20, "Span should end at byte 20");

        // Zero-length span at byte 10.
        let hollow = Span::new(10, 10, 2, 2, 5, 5);
        assert_eq!(hollow.len(), 0, "Should have 0 items");
        assert!(
            hollow.is_empty(),
            "Span with same start and end should be empty"
        );
        assert_eq!(
            hollow.start_byte, hollow.end_byte,
            "Empty span should have equal start and end bytes"
        );
    }

    #[test]
    fn test_node_serialization() {
        let original = sample_function_node();

        // Serialize to JSON and back again.
        let json = serde_json::to_string(&original).unwrap();
        let restored: Node = serde_json::from_str(&json).unwrap();

        assert_eq!(original.id, restored.id);
        assert_eq!(original.name, restored.name);
        assert_eq!(original.file, restored.file);
    }

    #[test]
    fn test_node_with_methods() {
        let node = sample_function_node()
            .with_signature("(a: number, b: number) => number".to_string())
            .with_metadata(serde_json::json!({ "async": true }));

        let expected_signature = Some("(a: number, b: number) => number".to_string());
        assert_eq!(node.signature, expected_signature);
        assert_eq!(node.metadata["async"], true);
    }

    #[test]
    fn test_node_builder() {
        let span = sample_span();

        let builder = NodeBuilder::new("test_repo", NodeKind::Function)
            .name("myFunction")
            .language(Language::TypeScript)
            .file("src/index.ts")
            .span(span.clone())
            .signature("() => void")
            .metadata(serde_json::json!({ "exported": true }));
        let node = builder.build();

        assert_eq!(node.name, "myFunction");
        assert_eq!(node.lang, Language::TypeScript);
        assert_eq!(node.file, PathBuf::from("src/index.ts"));
        assert_eq!(node.span, span);
        assert_eq!(node.signature, Some("() => void".to_string()));
        assert_eq!(node.metadata["exported"], true);
    }

    #[test]
    fn test_edge_creation_and_serialization() {
        let file = Path::new("file.js");
        let caller_span = Span::new(0, 10, 1, 1, 1, 11);
        let callee_span = Span::new(20, 30, 2, 1, 2, 11);

        let caller = NodeId::new("repo", file, &caller_span, &NodeKind::Function);
        let callee = NodeId::new("repo", file, &callee_span, &NodeKind::Function);

        let edge = Edge::new(caller, callee, EdgeKind::Calls);
        assert_eq!(edge.source, caller);
        assert_eq!(edge.target, callee);
        assert_eq!(edge.kind, EdgeKind::Calls);

        // Round-trip through JSON.
        let json = serde_json::to_string(&edge).unwrap();
        let restored: Edge = serde_json::from_str(&json).unwrap();
        assert_eq!(edge, restored);

        // The edge ID mentions both endpoints and the kind.
        let edge_id = edge.id();
        assert!(edge_id.contains(&caller.to_string()));
        assert!(edge_id.contains(&callee.to_string()));
        assert!(edge_id.contains("Calls"));
    }

    #[test]
    fn test_display_traits() {
        // NodeKind display
        assert_eq!(NodeKind::Function.to_string(), "Function");
        assert_eq!(NodeKind::Module.to_string(), "Module");

        // EdgeKind display
        assert_eq!(EdgeKind::Calls.to_string(), "CALLS");
        assert_eq!(EdgeKind::Imports.to_string(), "IMPORTS");

        // Language display
        assert_eq!(Language::JavaScript.to_string(), "JavaScript");
        assert_eq!(Language::Cpp.to_string(), "C++");

        // Span display
        let span = Span::new(0, 10, 1, 5, 2, 15);
        assert_eq!(span.to_string(), "1:2-5:15");

        // Node display
        let node = Node::new(
            "repo",
            NodeKind::Function,
            "myFunc".to_string(),
            Language::JavaScript,
            PathBuf::from("test.js"),
            span.clone(),
        );
        let rendered_node = node.to_string();
        for fragment in ["JavaScript", "Function", "myFunc", "test.js"].iter() {
            assert!(rendered_node.contains(fragment));
        }

        // Edge display
        let file = Path::new("file.js");
        let reader = NodeId::new("repo", file, &span, &NodeKind::Function);
        let variable = NodeId::new("repo", file, &span, &NodeKind::Variable);
        let rendered_edge = Edge::new(reader, variable, EdgeKind::Reads).to_string();
        assert!(rendered_edge.contains("READS"));
    }
}
712        assert!(edge_display.contains("READS"));
713    }
714}