Skip to main content

infiniloom_engine/embedding/
type_extraction.rs

//! Type signature extraction from Tree-sitter syntax trees
//!
//! Given a chunk's source code and language, this module parses the AST
//! and extracts type information (parameter types, return types, error types)
//! from the first function/method declaration it finds.
//!
//! Supported languages: Rust, TypeScript, Python, Java, Go.
//! For unsupported languages, returns `None`.
9
10use crate::parser::Language;
11
/// Type details recovered from a single function or method declaration.
///
/// Every field is optional or possibly empty: an extractor fills in only what
/// the declaration actually spells out.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TypeInfo {
    /// Human-readable signature, e.g. `"(i32, &str) -> Result<Claims, AuthError>"`.
    pub type_signature: Option<String>,
    /// Parameter types in declaration order, e.g. `["i32", "&str"]`.
    pub parameter_types: Vec<String>,
    /// Declared return type, e.g. `"Result<Claims, AuthError>"`.
    pub return_type: Option<String>,
    /// Declared error/exception types, e.g. `["AuthError"]`.
    pub error_types: Vec<String>,
}
24
25/// Extract type information from a code chunk's content.
26///
27/// Parses the content with Tree-sitter for the given language and extracts
28/// parameter types, return type, error types, and a clean type signature
29/// from the first function/method declaration found.
30///
31/// Returns `None` if the language is unsupported or no function is found.
32pub fn extract_types(content: &str, language: Language) -> Option<TypeInfo> {
33    let ts_lang = language.tree_sitter_language()?;
34
35    let mut parser = tree_sitter::Parser::new();
36    parser.set_language(&ts_lang).ok()?;
37    let tree = parser.parse(content, None)?;
38    let root = tree.root_node();
39
40    match language {
41        Language::Rust => extract_rust_types(root, content),
42        Language::TypeScript => extract_typescript_types(root, content),
43        Language::Python => extract_python_types(root, content),
44        Language::Java => extract_java_types(root, content),
45        Language::Go => extract_go_types(root, content),
46        _ => None,
47    }
48}
49
50/// Recursively find the first node with one of the given kinds.
51fn find_first_node<'a>(
52    node: tree_sitter::Node<'a>,
53    kinds: &[&str],
54) -> Option<tree_sitter::Node<'a>> {
55    if kinds.contains(&node.kind()) {
56        return Some(node);
57    }
58    let mut cursor = node.walk();
59    for child in node.children(&mut cursor) {
60        if let Some(found) = find_first_node(child, kinds) {
61            return Some(found);
62        }
63    }
64    None
65}
66
67/// Get the text of a node from source content.
68fn node_text<'a>(node: tree_sitter::Node<'_>, source: &'a str) -> &'a str {
69    node.utf8_text(source.as_bytes()).unwrap_or("")
70}
71
72// ---------------------------------------------------------------------------
73// Rust
74// ---------------------------------------------------------------------------
75
76fn extract_rust_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
77    let func_node = find_first_node(root, &["function_item", "function_signature_item"])?;
78
79    let mut param_types = Vec::new();
80
81    // Find parameters node
82    if let Some(params_node) = find_child_by_kind(func_node, "parameters") {
83        let mut cursor = params_node.walk();
84        for child in params_node.children(&mut cursor) {
85            if child.kind() == "parameter" {
86                // Look for the type child (skip the pattern/name)
87                if let Some(type_node) = find_child_by_kind(child, "type_identifier")
88                    .or_else(|| find_child_by_kind(child, "reference_type"))
89                    .or_else(|| find_child_by_kind(child, "generic_type"))
90                    .or_else(|| find_child_by_kind(child, "scoped_type_identifier"))
91                    .or_else(|| find_child_by_kind(child, "primitive_type"))
92                    .or_else(|| find_child_by_kind(child, "array_type"))
93                    .or_else(|| find_child_by_kind(child, "tuple_type"))
94                    .or_else(|| find_child_by_kind(child, "function_type"))
95                    .or_else(|| find_child_by_kind(child, "bounded_type"))
96                    .or_else(|| find_child_by_kind(child, "dynamic_type"))
97                {
98                    param_types.push(node_text(type_node, source).to_owned());
99                }
100            } else if child.kind() == "self_parameter" {
101                param_types.push(node_text(child, source).to_owned());
102            }
103        }
104    }
105
106    // Find return type
107    let mut return_type: Option<String> = None;
108    let mut cursor = func_node.walk();
109    for child in func_node.children(&mut cursor) {
110        // In tree-sitter-rust, the return type appears as a child with
111        // kind "type_identifier", "generic_type", etc. after the "->" token.
112        // Look for a node whose previous sibling is "->".
113        if child.kind() == "->" {
114            // The next sibling is the return type
115            if let Some(next) = child.next_sibling() {
116                return_type = Some(node_text(next, source).trim().to_owned());
117            }
118        }
119    }
120
121    // Extract error types from Result<_, E>
122    let error_types = return_type
123        .as_ref()
124        .map(|rt| extract_rust_error_types(rt))
125        .unwrap_or_default();
126
127    // Build type signature
128    let params_str = param_types
129        .iter()
130        .filter(|p| *p != "&self" && *p != "&mut self" && *p != "self")
131        .cloned()
132        .collect::<Vec<_>>()
133        .join(", ");
134
135    let type_signature = if let Some(ref rt) = return_type {
136        Some(format!("({}) -> {}", params_str, rt))
137    } else if !param_types.is_empty() {
138        Some(format!("({})", params_str))
139    } else {
140        None
141    };
142
143    if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
144        return None;
145    }
146
147    Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
148}
149
/// Extract the error type from a Rust `Result` return type string.
///
/// Accepts bare and path-qualified spellings (`Result<T, E>`,
/// `std::result::Result<T, E>`) and returns the second type argument, e.g.
/// `"Result<Claims, AuthError>"` -> `["AuthError"]`. Single-argument aliases such
/// as `io::Result<T>` name no explicit error type and yield an empty vector, as
/// does any type that is not `Result`.
fn extract_rust_error_types(return_type: &str) -> Vec<String> {
    let trimmed = return_type.trim();
    let Some(open) = trimmed.find('<') else {
        return Vec::new();
    };

    // Compare only the last path segment so `std::result::Result` also matches.
    let head = trimmed[..open].trim_end();
    let last_segment = head.rsplit("::").next().unwrap_or(head).trim();
    if last_segment != "Result" {
        return Vec::new();
    }

    let inner = &trimmed[open + 1..];
    let Some(close) = find_matching_bracket(inner) else {
        return Vec::new();
    };
    let args = &inner[..close];
    let Some(comma) = find_top_level_comma(args) else {
        return Vec::new();
    };

    let error_part = args[comma + 1..].trim();
    if error_part.is_empty() {
        Vec::new()
    } else {
        vec![error_part.to_owned()]
    }
}

/// Return the byte index of the `>` that closes an already-opened `<`, accounting
/// for nested angle brackets.
///
/// The `>` of a `->` arrow (as in `Fn(A) -> B`) is not treated as a bracket.
fn find_matching_bracket(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut prev = None;
    for (i, ch) in s.char_indices() {
        match ch {
            '<' => depth += 1,
            '>' if prev == Some('-') => {},
            '>' if depth == 0 => return Some(i),
            '>' => depth -= 1,
            _ => {},
        }
        prev = Some(ch);
    }
    None
}

/// Return the byte index of the first comma not nested inside `<>`, `()` or `[]`.
///
/// The `>` of a `->` arrow is not treated as a closing bracket, and unbalanced
/// closers never drive the depth below zero.
fn find_top_level_comma(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut prev = None;
    for (i, ch) in s.char_indices() {
        match ch {
            '<' | '(' | '[' => depth += 1,
            '>' if prev == Some('-') => {},
            '>' | ')' | ']' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => return Some(i),
            _ => {},
        }
        prev = Some(ch);
    }
    None
}
207
208// ---------------------------------------------------------------------------
209// TypeScript
210// ---------------------------------------------------------------------------
211
212fn extract_typescript_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
213    let func_node = find_first_node(
214        root,
215        &["function_declaration", "method_definition", "arrow_function", "function_signature"],
216    )?;
217
218    let mut param_types = Vec::new();
219
220    // Find formal_parameters
221    if let Some(params_node) = find_child_by_kind(func_node, "formal_parameters") {
222        let mut cursor = params_node.walk();
223        for child in params_node.children(&mut cursor) {
224            if child.kind() == "required_parameter" || child.kind() == "optional_parameter" {
225                if let Some(ta) = find_child_by_kind(child, "type_annotation") {
226                    // The type is the child after the ":"
227                    let mut ta_cursor = ta.walk();
228                    for ta_child in ta.children(&mut ta_cursor) {
229                        if ta_child.kind() != ":" {
230                            let text = node_text(ta_child, source).trim();
231                            if !text.is_empty() {
232                                param_types.push(text.to_owned());
233                            }
234                        }
235                    }
236                }
237            }
238        }
239    }
240
241    // Find return type annotation on the function itself
242    let return_type = find_child_by_kind(func_node, "type_annotation").and_then(|ta| {
243        let mut cursor = ta.walk();
244        for child in ta.children(&mut cursor) {
245            if child.kind() != ":" {
246                let text = node_text(child, source).trim().to_owned();
247                if !text.is_empty() {
248                    return Some(text);
249                }
250            }
251        }
252        None
253    });
254
255    // Build TS-style type signature
256    let params_str = param_types.join(", ");
257    let type_signature = if let Some(ref rt) = return_type {
258        Some(format!("({}) => {}", params_str, rt))
259    } else if !param_types.is_empty() {
260        Some(format!("({})", params_str))
261    } else {
262        None
263    };
264
265    if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
266        return None;
267    }
268
269    Some(TypeInfo {
270        type_signature,
271        parameter_types: param_types,
272        return_type,
273        error_types: Vec::new(),
274    })
275}
276
277// ---------------------------------------------------------------------------
278// Python
279// ---------------------------------------------------------------------------
280
281fn extract_python_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
282    let func_node = find_first_node(root, &["function_definition"])?;
283
284    let mut param_types = Vec::new();
285
286    // Find parameters
287    if let Some(params_node) = find_child_by_kind(func_node, "parameters") {
288        let mut cursor = params_node.walk();
289        for child in params_node.children(&mut cursor) {
290            // typed_parameter has a type child
291            if child.kind() == "typed_parameter" || child.kind() == "typed_default_parameter" {
292                if let Some(type_node) = find_child_by_kind(child, "type") {
293                    param_types.push(node_text(type_node, source).trim().to_owned());
294                }
295            }
296        }
297    }
298
299    // Find return type (-> annotation)
300    let return_type =
301        find_child_by_kind(func_node, "type").map(|n| node_text(n, source).trim().to_owned());
302
303    // Build Python-style type signature
304    let params_str = param_types.join(", ");
305    let type_signature = if let Some(ref rt) = return_type {
306        Some(format!("({}) -> {}", params_str, rt))
307    } else if !param_types.is_empty() {
308        Some(format!("({})", params_str))
309    } else {
310        None
311    };
312
313    if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
314        return None;
315    }
316
317    Some(TypeInfo {
318        type_signature,
319        parameter_types: param_types,
320        return_type,
321        error_types: Vec::new(),
322    })
323}
324
325// ---------------------------------------------------------------------------
326// Java
327// ---------------------------------------------------------------------------
328
329fn extract_java_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
330    let func_node = find_first_node(root, &["method_declaration", "constructor_declaration"])?;
331
332    let mut param_types = Vec::new();
333
334    // Java: return type appears before the method name
335    // In tree-sitter-java, method_declaration has children like:
336    //   modifiers? type_identifier identifier formal_parameters throws? block
337    let mut return_type: Option<String> = None;
338    let mut cursor = func_node.walk();
339    for child in func_node.children(&mut cursor) {
340        let kind = child.kind();
341        // Type nodes that can appear as return type
342        if kind == "type_identifier"
343            || kind == "generic_type"
344            || kind == "array_type"
345            || kind == "void_type"
346            || kind == "integral_type"
347            || kind == "floating_point_type"
348            || kind == "boolean_type"
349            || kind == "scoped_type_identifier"
350        {
351            return_type = Some(node_text(child, source).trim().to_owned());
352        }
353        // Stop before the method name (identifier) and parameters
354        if kind == "identifier" || kind == "formal_parameters" {
355            break;
356        }
357    }
358
359    // Find formal_parameters
360    if let Some(params_node) = find_child_by_kind(func_node, "formal_parameters") {
361        let mut pcursor = params_node.walk();
362        for child in params_node.children(&mut pcursor) {
363            if child.kind() == "formal_parameter" || child.kind() == "spread_parameter" {
364                // The type is the first type-like child
365                let mut param_cursor = child.walk();
366                for pchild in child.children(&mut param_cursor) {
367                    let pk = pchild.kind();
368                    if pk == "type_identifier"
369                        || pk == "generic_type"
370                        || pk == "array_type"
371                        || pk == "integral_type"
372                        || pk == "floating_point_type"
373                        || pk == "boolean_type"
374                        || pk == "scoped_type_identifier"
375                    {
376                        param_types.push(node_text(pchild, source).trim().to_owned());
377                        break;
378                    }
379                }
380            }
381        }
382    }
383
384    // Extract throws clause for error types
385    let mut error_types = Vec::new();
386    if let Some(throws_node) = find_child_by_kind(func_node, "throws") {
387        let mut tcursor = throws_node.walk();
388        for child in throws_node.children(&mut tcursor) {
389            if child.kind() == "type_identifier" || child.kind() == "scoped_type_identifier" {
390                error_types.push(node_text(child, source).trim().to_owned());
391            }
392        }
393    }
394
395    // Build Java-style type signature
396    let params_str = param_types.join(", ");
397    let mut sig = format!("({}) -> {}", params_str, return_type.as_deref().unwrap_or("void"));
398    if !error_types.is_empty() {
399        sig.push_str(&format!(" throws {}", error_types.join(", ")));
400    }
401    let type_signature = Some(sig);
402
403    Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
404}
405
406// ---------------------------------------------------------------------------
407// Go
408// ---------------------------------------------------------------------------
409
410fn extract_go_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
411    let func_node = find_first_node(root, &["function_declaration", "method_declaration"])?;
412
413    let mut param_types = Vec::new();
414
415    // Find parameter_list
416    // In Go, function_declaration has: "func" name parameter_list result? block
417    // method_declaration has: "func" parameter_list name parameter_list result? block
418    // We want the last parameter_list before the result/block
419    let param_lists: Vec<tree_sitter::Node<'_>> = {
420        let mut cursor = func_node.walk();
421        func_node
422            .children(&mut cursor)
423            .filter(|c| c.kind() == "parameter_list")
424            .collect()
425    };
426
427    // For method_declaration, the first parameter_list is the receiver
428    // The actual params are the second parameter_list
429    let params_node = if func_node.kind() == "method_declaration" {
430        param_lists.get(1).or(param_lists.first())
431    } else {
432        param_lists.first()
433    };
434
435    if let Some(params) = params_node {
436        let mut cursor = params.walk();
437        for child in params.children(&mut cursor) {
438            if child.kind() == "parameter_declaration" {
439                // In Go, parameter_declaration has: name type
440                // or just: type (for unnamed params)
441                // The last type-like child is the type
442                let mut last_type = None;
443                let mut pcursor = child.walk();
444                for pchild in child.children(&mut pcursor) {
445                    let pk = pchild.kind();
446                    if pk == "type_identifier"
447                        || pk == "pointer_type"
448                        || pk == "slice_type"
449                        || pk == "array_type"
450                        || pk == "map_type"
451                        || pk == "channel_type"
452                        || pk == "function_type"
453                        || pk == "interface_type"
454                        || pk == "struct_type"
455                        || pk == "qualified_type"
456                    {
457                        last_type = Some(node_text(pchild, source).trim().to_owned());
458                    }
459                }
460                if let Some(t) = last_type {
461                    param_types.push(t);
462                }
463            }
464        }
465    }
466
467    // Find result (return type)
468    let mut return_type: Option<String> = None;
469    let mut error_types = Vec::new();
470
471    let mut cursor = func_node.walk();
472    for child in func_node.children(&mut cursor) {
473        if child.kind() == "parameter_list" {
474            // Could be return type tuple: (Type1, Type2)
475            // Check if this is after the main params by position
476            if Some(&child) != params_node {
477                let text = node_text(child, source).trim().to_owned();
478                return_type = Some(text);
479
480                // Check if last return is "error"
481                let mut rcursor = child.walk();
482                let return_params: Vec<_> = child
483                    .children(&mut rcursor)
484                    .filter(|c| c.kind() == "parameter_declaration")
485                    .collect();
486                if let Some(last) = return_params.last() {
487                    let last_text = node_text(*last, source).trim();
488                    if last_text == "error" || last_text.ends_with(" error") {
489                        error_types.push("error".to_owned());
490                    }
491                }
492            }
493        }
494        if child.kind() == "type_identifier"
495            || child.kind() == "pointer_type"
496            || child.kind() == "slice_type"
497            || child.kind() == "qualified_type"
498        {
499            // Simple single return type
500            let prev_sibling_is_params = child
501                .prev_sibling()
502                .is_some_and(|s| s.kind() == "parameter_list");
503            if prev_sibling_is_params || return_type.is_none() {
504                let text = node_text(child, source).trim().to_owned();
505                if text == "error" {
506                    error_types.push("error".to_owned());
507                }
508                return_type = Some(text);
509            }
510        }
511    }
512
513    // Build Go-style type signature
514    let params_str = param_types.join(", ");
515    let type_signature = if let Some(ref rt) = return_type {
516        Some(format!("({}) -> {}", params_str, rt))
517    } else if !param_types.is_empty() {
518        Some(format!("({})", params_str))
519    } else {
520        None
521    };
522
523    if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
524        return None;
525    }
526
527    Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
528}
529
530// ---------------------------------------------------------------------------
531// Helpers
532// ---------------------------------------------------------------------------
533
534/// Find the first direct child of `node` with the given kind.
535fn find_child_by_kind<'a>(
536    node: tree_sitter::Node<'a>,
537    kind: &str,
538) -> Option<tree_sitter::Node<'a>> {
539    let count = node.child_count() as u32;
540    for i in 0..count {
541        if let Some(child) = node.child(i) {
542            if child.kind() == kind {
543                return Some(child);
544            }
545        }
546    }
547    None
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rust_typed_function() {
        let source = r#"fn validate(token: &str, max_age: i32) -> Result<Claims, AuthError> {
    todo!()
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert_eq!(info.parameter_types, vec!["&str", "i32"]);
        assert_eq!(info.return_type.as_deref(), Some("Result<Claims, AuthError>"));
        assert_eq!(info.error_types, vec!["AuthError"]);
        assert!(info
            .type_signature
            .as_ref()
            .unwrap()
            .contains("-> Result<Claims, AuthError>"));
    }

    #[test]
    fn test_rust_self_method() {
        let source = r#"fn process(&self, data: Vec<u8>) -> bool {
    true
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert!(info.parameter_types.contains(&"&self".to_owned()));
        assert!(info.parameter_types.contains(&"Vec<u8>".to_owned()));
        assert_eq!(info.return_type.as_deref(), Some("bool"));
    }

    #[test]
    fn test_rust_no_return_type() {
        let source = r#"fn setup(config: Config) {
    // ...
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert_eq!(info.parameter_types, vec!["Config"]);
        assert!(info.return_type.is_none());
    }

    #[test]
    fn test_typescript_function() {
        let source = r#"function greet(name: string, age: number): Promise<void> {
    console.log(name);
}"#;
        let info = extract_types(source, Language::TypeScript).unwrap();
        assert_eq!(info.parameter_types, vec!["string", "number"]);
        assert_eq!(info.return_type.as_deref(), Some("Promise<void>"));
        assert!(info
            .type_signature
            .as_ref()
            .unwrap()
            .contains("=> Promise<void>"));
    }

    #[test]
    fn test_python_function() {
        let source = r#"def process(data: list, count: int) -> dict:
    pass"#;
        let info = extract_types(source, Language::Python).unwrap();
        assert_eq!(info.parameter_types, vec!["list", "int"]);
        assert_eq!(info.return_type.as_deref(), Some("dict"));
        assert!(info.type_signature.as_ref().unwrap().contains("-> dict"));
    }

    #[test]
    fn test_no_types_returns_none() {
        // Python function without type annotations
        let source = r#"def hello(name):
    print(name)"#;
        let result = extract_types(source, Language::Python);
        assert!(result.is_none());
    }

    #[test]
    fn test_rust_error_type_extraction() {
        assert_eq!(extract_rust_error_types("Result<Claims, AuthError>"), vec!["AuthError"]);
        assert_eq!(extract_rust_error_types("Result<(), std::io::Error>"), vec!["std::io::Error"]);
        assert!(extract_rust_error_types("bool").is_empty());
        assert!(extract_rust_error_types("Option<String>").is_empty());
    }

    #[test]
    fn test_unsupported_language_returns_none() {
        let source = "def foo; end";
        let result = extract_types(source, Language::Ruby);
        assert!(result.is_none());
    }

    #[test]
    fn test_java_method() {
        let source = r#"class Foo {
    public String process(int count, List<String> items) throws IOException {
        return "";
    }
}"#;
        let info = extract_types(source, Language::Java);
        // NOTE(review): a `None` result is still tolerated, as before (presumably for
        // builds without the Java grammar — confirm). When extraction succeeds, the
        // exact result is pinned instead of the previous near-tautological checks.
        if let Some(info) = info {
            assert_eq!(info.parameter_types, vec!["int", "List<String>"]);
            assert_eq!(info.return_type.as_deref(), Some("String"));
            assert_eq!(info.error_types, vec!["IOException"]);
            assert_eq!(
                info.type_signature.as_deref(),
                Some("(int, List<String>) -> String throws IOException")
            );
        }
    }

    #[test]
    fn test_go_function() {
        let source = r#"package main

func Process(data []byte, count int) (string, error) {
    return "", nil
}"#;
        let info = extract_types(source, Language::Go);
        // NOTE(review): a `None` result is still tolerated, as before (presumably for
        // builds without the Go grammar — confirm). When extraction succeeds, the
        // exact result is pinned.
        if let Some(info) = info {
            assert_eq!(info.parameter_types, vec!["[]byte", "int"]);
            assert_eq!(info.return_type.as_deref(), Some("(string, error)"));
            assert_eq!(info.error_types, vec!["error"]);
            assert_eq!(info.type_signature.as_deref(), Some("([]byte, int) -> (string, error)"));
        }
    }
}