infiniloom_engine/embedding/
type_extraction.rs1use crate::parser::Language;
11
/// Type information extracted from the first function-like definition found in a snippet.
///
/// All strings are taken verbatim from the source text; nothing is resolved or normalised.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TypeInfo {
    /// Rendered signature such as `(&str, i32) -> bool` (TypeScript uses `=>`), when one
    /// could be built.
    pub type_signature: Option<String>,
    /// Parameter types in declaration order, exactly as written in the source.
    pub parameter_types: Vec<String>,
    /// Declared return type, if the definition has one.
    pub return_type: Option<String>,
    /// Error types the definition can surface: the error argument of a Rust `Result`,
    /// Java `throws` entries, or Go's `error`.
    pub error_types: Vec<String>,
}
24
25pub fn extract_types(content: &str, language: Language) -> Option<TypeInfo> {
33 let ts_lang = language.tree_sitter_language()?;
34
35 let mut parser = tree_sitter::Parser::new();
36 parser.set_language(&ts_lang).ok()?;
37 let tree = parser.parse(content, None)?;
38 let root = tree.root_node();
39
40 match language {
41 Language::Rust => extract_rust_types(root, content),
42 Language::TypeScript => extract_typescript_types(root, content),
43 Language::Python => extract_python_types(root, content),
44 Language::Java => extract_java_types(root, content),
45 Language::Go => extract_go_types(root, content),
46 _ => None,
47 }
48}
49
50fn find_first_node<'a>(
52 node: tree_sitter::Node<'a>,
53 kinds: &[&str],
54) -> Option<tree_sitter::Node<'a>> {
55 if kinds.contains(&node.kind()) {
56 return Some(node);
57 }
58 let mut cursor = node.walk();
59 for child in node.children(&mut cursor) {
60 if let Some(found) = find_first_node(child, kinds) {
61 return Some(found);
62 }
63 }
64 None
65}
66
67fn node_text<'a>(node: tree_sitter::Node<'_>, source: &'a str) -> &'a str {
69 node.utf8_text(source.as_bytes()).unwrap_or("")
70}
71
72fn extract_rust_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
77 let func_node = find_first_node(root, &["function_item", "function_signature_item"])?;
78
79 let mut param_types = Vec::new();
80
81 if let Some(params_node) = find_child_by_kind(func_node, "parameters") {
83 let mut cursor = params_node.walk();
84 for child in params_node.children(&mut cursor) {
85 if child.kind() == "parameter" {
86 if let Some(type_node) = find_child_by_kind(child, "type_identifier")
88 .or_else(|| find_child_by_kind(child, "reference_type"))
89 .or_else(|| find_child_by_kind(child, "generic_type"))
90 .or_else(|| find_child_by_kind(child, "scoped_type_identifier"))
91 .or_else(|| find_child_by_kind(child, "primitive_type"))
92 .or_else(|| find_child_by_kind(child, "array_type"))
93 .or_else(|| find_child_by_kind(child, "tuple_type"))
94 .or_else(|| find_child_by_kind(child, "function_type"))
95 .or_else(|| find_child_by_kind(child, "bounded_type"))
96 .or_else(|| find_child_by_kind(child, "dynamic_type"))
97 {
98 param_types.push(node_text(type_node, source).to_owned());
99 }
100 } else if child.kind() == "self_parameter" {
101 param_types.push(node_text(child, source).to_owned());
102 }
103 }
104 }
105
106 let mut return_type: Option<String> = None;
108 let mut cursor = func_node.walk();
109 for child in func_node.children(&mut cursor) {
110 if child.kind() == "->" {
114 if let Some(next) = child.next_sibling() {
116 return_type = Some(node_text(next, source).trim().to_owned());
117 }
118 }
119 }
120
121 let error_types = return_type
123 .as_ref()
124 .map(|rt| extract_rust_error_types(rt))
125 .unwrap_or_default();
126
127 let params_str = param_types
129 .iter()
130 .filter(|p| *p != "&self" && *p != "&mut self" && *p != "self")
131 .cloned()
132 .collect::<Vec<_>>()
133 .join(", ");
134
135 let type_signature = if let Some(ref rt) = return_type {
136 Some(format!("({}) -> {}", params_str, rt))
137 } else if !param_types.is_empty() {
138 Some(format!("({})", params_str))
139 } else {
140 None
141 };
142
143 if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
144 return None;
145 }
146
147 Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
148}
149
150fn extract_rust_error_types(return_type: &str) -> Vec<String> {
153 let trimmed = return_type.trim();
154 if !trimmed.starts_with("Result<") && !trimmed.starts_with("Result <") {
155 return Vec::new();
156 }
157
158 if let Some(start) = trimmed.find('<') {
160 let inner = &trimmed[start + 1..];
161 if let Some(end) = find_matching_bracket(inner) {
162 let content = &inner[..end];
163 if let Some(comma_pos) = find_top_level_comma(content) {
165 let error_part = content[comma_pos + 1..].trim();
166 if !error_part.is_empty() {
167 return vec![error_part.to_owned()];
168 }
169 }
170 }
171 }
172 Vec::new()
173}
174
/// Given the text that follows an opening `<`, returns the byte index of the `>` that
/// closes it, or `None` if the bracket is never closed.
///
/// Nested `<...>` pairs are skipped. The `>` of a `->` arrow (as in `fn() -> T`) is not
/// a bracket and is ignored.
fn find_matching_bracket(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut prev: Option<char> = None;
    for (i, ch) in s.char_indices() {
        match ch {
            '<' => depth += 1,
            // Part of `->`, not a closing angle bracket.
            '>' if prev == Some('-') => {}
            '>' => {
                if depth == 0 {
                    return Some(i);
                }
                depth -= 1;
            }
            _ => {}
        }
        prev = Some(ch);
    }
    None
}
192
/// Returns the byte index of the first comma in `s` that is not nested inside `<>`,
/// `()` or `[]`, or `None` if there is no such comma.
///
/// All three bracket kinds share one depth counter, and unbalanced closers are ignored.
/// The `>` of a `->` arrow is not a bracket and does not change the depth.
fn find_top_level_comma(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut prev: Option<char> = None;
    for (i, ch) in s.char_indices() {
        match ch {
            // Part of `->`, not a closing angle bracket.
            '>' if prev == Some('-') => {}
            '<' | '(' | '[' => depth += 1,
            '>' | ')' | ']' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => return Some(i),
            _ => {}
        }
        prev = Some(ch);
    }
    None
}
207
208fn extract_typescript_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
213 let func_node = find_first_node(
214 root,
215 &["function_declaration", "method_definition", "arrow_function", "function_signature"],
216 )?;
217
218 let mut param_types = Vec::new();
219
220 if let Some(params_node) = find_child_by_kind(func_node, "formal_parameters") {
222 let mut cursor = params_node.walk();
223 for child in params_node.children(&mut cursor) {
224 if child.kind() == "required_parameter" || child.kind() == "optional_parameter" {
225 if let Some(ta) = find_child_by_kind(child, "type_annotation") {
226 let mut ta_cursor = ta.walk();
228 for ta_child in ta.children(&mut ta_cursor) {
229 if ta_child.kind() != ":" {
230 let text = node_text(ta_child, source).trim();
231 if !text.is_empty() {
232 param_types.push(text.to_owned());
233 }
234 }
235 }
236 }
237 }
238 }
239 }
240
241 let return_type = find_child_by_kind(func_node, "type_annotation").and_then(|ta| {
243 let mut cursor = ta.walk();
244 for child in ta.children(&mut cursor) {
245 if child.kind() != ":" {
246 let text = node_text(child, source).trim().to_owned();
247 if !text.is_empty() {
248 return Some(text);
249 }
250 }
251 }
252 None
253 });
254
255 let params_str = param_types.join(", ");
257 let type_signature = if let Some(ref rt) = return_type {
258 Some(format!("({}) => {}", params_str, rt))
259 } else if !param_types.is_empty() {
260 Some(format!("({})", params_str))
261 } else {
262 None
263 };
264
265 if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
266 return None;
267 }
268
269 Some(TypeInfo {
270 type_signature,
271 parameter_types: param_types,
272 return_type,
273 error_types: Vec::new(),
274 })
275}
276
277fn extract_python_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
282 let func_node = find_first_node(root, &["function_definition"])?;
283
284 let mut param_types = Vec::new();
285
286 if let Some(params_node) = find_child_by_kind(func_node, "parameters") {
288 let mut cursor = params_node.walk();
289 for child in params_node.children(&mut cursor) {
290 if child.kind() == "typed_parameter" || child.kind() == "typed_default_parameter" {
292 if let Some(type_node) = find_child_by_kind(child, "type") {
293 param_types.push(node_text(type_node, source).trim().to_owned());
294 }
295 }
296 }
297 }
298
299 let return_type =
301 find_child_by_kind(func_node, "type").map(|n| node_text(n, source).trim().to_owned());
302
303 let params_str = param_types.join(", ");
305 let type_signature = if let Some(ref rt) = return_type {
306 Some(format!("({}) -> {}", params_str, rt))
307 } else if !param_types.is_empty() {
308 Some(format!("({})", params_str))
309 } else {
310 None
311 };
312
313 if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
314 return None;
315 }
316
317 Some(TypeInfo {
318 type_signature,
319 parameter_types: param_types,
320 return_type,
321 error_types: Vec::new(),
322 })
323}
324
325fn extract_java_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
330 let func_node = find_first_node(root, &["method_declaration", "constructor_declaration"])?;
331
332 let mut param_types = Vec::new();
333
334 let mut return_type: Option<String> = None;
338 let mut cursor = func_node.walk();
339 for child in func_node.children(&mut cursor) {
340 let kind = child.kind();
341 if kind == "type_identifier"
343 || kind == "generic_type"
344 || kind == "array_type"
345 || kind == "void_type"
346 || kind == "integral_type"
347 || kind == "floating_point_type"
348 || kind == "boolean_type"
349 || kind == "scoped_type_identifier"
350 {
351 return_type = Some(node_text(child, source).trim().to_owned());
352 }
353 if kind == "identifier" || kind == "formal_parameters" {
355 break;
356 }
357 }
358
359 if let Some(params_node) = find_child_by_kind(func_node, "formal_parameters") {
361 let mut pcursor = params_node.walk();
362 for child in params_node.children(&mut pcursor) {
363 if child.kind() == "formal_parameter" || child.kind() == "spread_parameter" {
364 let mut param_cursor = child.walk();
366 for pchild in child.children(&mut param_cursor) {
367 let pk = pchild.kind();
368 if pk == "type_identifier"
369 || pk == "generic_type"
370 || pk == "array_type"
371 || pk == "integral_type"
372 || pk == "floating_point_type"
373 || pk == "boolean_type"
374 || pk == "scoped_type_identifier"
375 {
376 param_types.push(node_text(pchild, source).trim().to_owned());
377 break;
378 }
379 }
380 }
381 }
382 }
383
384 let mut error_types = Vec::new();
386 if let Some(throws_node) = find_child_by_kind(func_node, "throws") {
387 let mut tcursor = throws_node.walk();
388 for child in throws_node.children(&mut tcursor) {
389 if child.kind() == "type_identifier" || child.kind() == "scoped_type_identifier" {
390 error_types.push(node_text(child, source).trim().to_owned());
391 }
392 }
393 }
394
395 let params_str = param_types.join(", ");
397 let mut sig = format!("({}) -> {}", params_str, return_type.as_deref().unwrap_or("void"));
398 if !error_types.is_empty() {
399 sig.push_str(&format!(" throws {}", error_types.join(", ")));
400 }
401 let type_signature = Some(sig);
402
403 Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
404}
405
406fn extract_go_types(root: tree_sitter::Node<'_>, source: &str) -> Option<TypeInfo> {
411 let func_node = find_first_node(root, &["function_declaration", "method_declaration"])?;
412
413 let mut param_types = Vec::new();
414
415 let param_lists: Vec<tree_sitter::Node<'_>> = {
420 let mut cursor = func_node.walk();
421 func_node
422 .children(&mut cursor)
423 .filter(|c| c.kind() == "parameter_list")
424 .collect()
425 };
426
427 let params_node = if func_node.kind() == "method_declaration" {
430 param_lists.get(1).or(param_lists.first())
431 } else {
432 param_lists.first()
433 };
434
435 if let Some(params) = params_node {
436 let mut cursor = params.walk();
437 for child in params.children(&mut cursor) {
438 if child.kind() == "parameter_declaration" {
439 let mut last_type = None;
443 let mut pcursor = child.walk();
444 for pchild in child.children(&mut pcursor) {
445 let pk = pchild.kind();
446 if pk == "type_identifier"
447 || pk == "pointer_type"
448 || pk == "slice_type"
449 || pk == "array_type"
450 || pk == "map_type"
451 || pk == "channel_type"
452 || pk == "function_type"
453 || pk == "interface_type"
454 || pk == "struct_type"
455 || pk == "qualified_type"
456 {
457 last_type = Some(node_text(pchild, source).trim().to_owned());
458 }
459 }
460 if let Some(t) = last_type {
461 param_types.push(t);
462 }
463 }
464 }
465 }
466
467 let mut return_type: Option<String> = None;
469 let mut error_types = Vec::new();
470
471 let mut cursor = func_node.walk();
472 for child in func_node.children(&mut cursor) {
473 if child.kind() == "parameter_list" {
474 if Some(&child) != params_node {
477 let text = node_text(child, source).trim().to_owned();
478 return_type = Some(text);
479
480 let mut rcursor = child.walk();
482 let return_params: Vec<_> = child
483 .children(&mut rcursor)
484 .filter(|c| c.kind() == "parameter_declaration")
485 .collect();
486 if let Some(last) = return_params.last() {
487 let last_text = node_text(*last, source).trim();
488 if last_text == "error" || last_text.ends_with(" error") {
489 error_types.push("error".to_owned());
490 }
491 }
492 }
493 }
494 if child.kind() == "type_identifier"
495 || child.kind() == "pointer_type"
496 || child.kind() == "slice_type"
497 || child.kind() == "qualified_type"
498 {
499 let prev_sibling_is_params = child
501 .prev_sibling()
502 .is_some_and(|s| s.kind() == "parameter_list");
503 if prev_sibling_is_params || return_type.is_none() {
504 let text = node_text(child, source).trim().to_owned();
505 if text == "error" {
506 error_types.push("error".to_owned());
507 }
508 return_type = Some(text);
509 }
510 }
511 }
512
513 let params_str = param_types.join(", ");
515 let type_signature = if let Some(ref rt) = return_type {
516 Some(format!("({}) -> {}", params_str, rt))
517 } else if !param_types.is_empty() {
518 Some(format!("({})", params_str))
519 } else {
520 None
521 };
522
523 if type_signature.is_none() && param_types.is_empty() && return_type.is_none() {
524 return None;
525 }
526
527 Some(TypeInfo { type_signature, parameter_types: param_types, return_type, error_types })
528}
529
530fn find_child_by_kind<'a>(
536 node: tree_sitter::Node<'a>,
537 kind: &str,
538) -> Option<tree_sitter::Node<'a>> {
539 let count = node.child_count() as u32;
540 for i in 0..count {
541 if let Some(child) = node.child(i) {
542 if child.kind() == kind {
543 return Some(child);
544 }
545 }
546 }
547 None
548}
549
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rust_typed_function() {
        let source = r#"fn validate(token: &str, max_age: i32) -> Result<Claims, AuthError> {
    todo!()
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert_eq!(info.parameter_types, vec!["&str", "i32"]);
        assert_eq!(info.return_type.as_deref(), Some("Result<Claims, AuthError>"));
        assert_eq!(info.error_types, vec!["AuthError"]);
        assert!(info
            .type_signature
            .as_ref()
            .unwrap()
            .contains("-> Result<Claims, AuthError>"));
    }

    #[test]
    fn test_rust_self_method() {
        let source = r#"fn process(&self, data: Vec<u8>) -> bool {
    true
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert!(info.parameter_types.contains(&"&self".to_owned()));
        assert!(info.parameter_types.contains(&"Vec<u8>".to_owned()));
        assert_eq!(info.return_type.as_deref(), Some("bool"));
    }

    #[test]
    fn test_rust_no_return_type() {
        let source = r#"fn setup(config: Config) {
    // ...
}"#;
        let info = extract_types(source, Language::Rust).unwrap();
        assert_eq!(info.parameter_types, vec!["Config"]);
        assert!(info.return_type.is_none());
    }

    #[test]
    fn test_typescript_function() {
        let source = r#"function greet(name: string, age: number): Promise<void> {
    console.log(name);
}"#;
        let info = extract_types(source, Language::TypeScript).unwrap();
        assert_eq!(info.parameter_types, vec!["string", "number"]);
        assert_eq!(info.return_type.as_deref(), Some("Promise<void>"));
        assert!(info
            .type_signature
            .as_ref()
            .unwrap()
            .contains("=> Promise<void>"));
    }

    #[test]
    fn test_python_function() {
        let source = r#"def process(data: list, count: int) -> dict:
    pass"#;
        let info = extract_types(source, Language::Python).unwrap();
        assert_eq!(info.parameter_types, vec!["list", "int"]);
        assert_eq!(info.return_type.as_deref(), Some("dict"));
        assert!(info.type_signature.as_ref().unwrap().contains("-> dict"));
    }

    #[test]
    fn test_no_types_returns_none() {
        let source = r#"def hello(name):
    print(name)"#;
        let result = extract_types(source, Language::Python);
        assert!(result.is_none());
    }

    #[test]
    fn test_rust_error_type_extraction() {
        assert_eq!(extract_rust_error_types("Result<Claims, AuthError>"), vec!["AuthError"]);
        assert_eq!(extract_rust_error_types("Result<(), std::io::Error>"), vec!["std::io::Error"]);
        assert!(extract_rust_error_types("bool").is_empty());
        assert!(extract_rust_error_types("Option<String>").is_empty());
    }

    #[test]
    fn test_unsupported_language_returns_none() {
        let source = "def foo; end";
        let result = extract_types(source, Language::Ruby);
        assert!(result.is_none());
    }

    #[test]
    fn test_java_method() {
        let source = r#"class Foo {
    public String process(int count, List<String> items) throws IOException {
        return "";
    }
}"#;
        // `None` is tolerated, as in the original test (NOTE(review): presumably for
        // builds without the Java grammar; confirm). When extraction succeeds the
        // result is checked exactly rather than with always-true disjunctions.
        if let Some(info) = extract_types(source, Language::Java) {
            assert_eq!(info.parameter_types, vec!["int", "List<String>"]);
            assert_eq!(info.return_type.as_deref(), Some("String"));
            assert_eq!(info.error_types, vec!["IOException"]);
            assert_eq!(
                info.type_signature.as_deref(),
                Some("(int, List<String>) -> String throws IOException")
            );
        }
    }

    #[test]
    fn test_go_function() {
        let source = r#"package main

func Process(data []byte, count int) (string, error) {
    return "", nil
}"#;
        // Same tolerance for `None` as the Java test; a successful extraction is
        // checked exactly.
        if let Some(info) = extract_types(source, Language::Go) {
            assert_eq!(info.parameter_types, vec!["[]byte", "int"]);
            assert_eq!(info.return_type.as_deref(), Some("(string, error)"));
            assert_eq!(info.error_types, vec!["error"]);
            assert_eq!(
                info.type_signature.as_deref(),
                Some("([]byte, int) -> (string, error)")
            );
        }
    }
}