use crate::error::Result;
use crate::ingest::imports::{ImportExtractor, ImportKind};
use crate::symbol::Language;
use std::path::Path;
/// Import extractor backed by the tree-sitter C++ grammar.
///
/// The type carries no state; its behaviour comes from its [`ImportExtractor`] implementation,
/// which collects `#include` directives.
pub struct CppExtractor;
/// Connects the tree-sitter C++ grammar and the include walker to [`ImportExtractor`].
impl ImportExtractor for CppExtractor {
/// Returns the grammar exported by the `tree_sitter_cpp` crate.
fn language() -> tree_sitter::Language {
tree_sitter_cpp::language()
}
/// Returns the language tag for this extractor, `Language::Cpp`.
fn language_enum() -> Language {
Language::Cpp
}
/// Appends the includes found in the subtree rooted at `node` to `imports`.
///
/// `source` is the byte buffer used to read the text of the visited nodes. The work is
/// delegated to `extract_include_statements`.
fn extract_from_node(
node: tree_sitter::Node,
source: &[u8],
imports: &mut Vec<super::ImportFact>,
) {
extract_include_statements(node, source, imports);
}
}
/// Removes one pair of enclosing include delimiters from `text`.
///
/// The recognised pairs are `<...>`, `"..."` and `'...'`. Only the outermost pair is removed.
/// Input that is not wrapped in a matching pair, including a lone delimiter character or an
/// empty string, is returned unchanged.
fn strip_quotes_or_angle_brackets(text: &str) -> String {
    const DELIMITERS: [(char, char); 3] = [('<', '>'), ('"', '"'), ('\'', '\'')];

    DELIMITERS
        .iter()
        // The suffix is checked on what remains after the prefix, so a single delimiter
        // character can never count as both the opening and the closing one.
        .find_map(|&(open, close)| text.strip_prefix(open)?.strip_suffix(close))
        .unwrap_or(text)
        .to_string()
}
/// Extracts the `#include` directives from `source` by delegating to `CppExtractor::extract`.
///
/// `path` and `source` are forwarded unchanged.
/// NOTE(review): `path` is presumably used to fill in `ImportFact::file_path`, which this file
/// leaves empty — confirm in `ImportExtractor::extract`.
///
/// # Errors
///
/// Returns whatever error `CppExtractor::extract` reports.
pub fn extract_cpp_imports(path: &Path, source: &[u8]) -> Result<Vec<super::ImportFact>> {
CppExtractor::extract(path, source)
}
/// Recursively walks the syntax tree rooted at `node` and records every include it finds.
///
/// A node of kind `preproc_include` is handed to `extract_preproc_include`; when that produces
/// a fact it is appended to `imports`. Every child is then visited the same way, so a parent is
/// always processed before its children.
fn extract_include_statements(
    node: tree_sitter::Node,
    source: &[u8],
    imports: &mut Vec<super::ImportFact>,
) {
    if node.kind() == "preproc_include" {
        // `Option` is iterable: this appends the fact when there is one and nothing otherwise.
        imports.extend(extract_preproc_include(node, source));
    }

    let mut walker = node.walk();
    for child_node in node.children(&mut walker) {
        extract_include_statements(child_node, source, imports);
    }
}
/// Builds an `ImportFact` from a single `preproc_include` node.
///
/// The direct children of `node` are scanned for the include target:
/// a `system_lib_string` child yields `ImportKind::CppSystemInclude`, a `string_literal` child
/// yields `ImportKind::CppLocalInclude`. If several such children exist, the last one wins.
/// The target has its surrounding delimiters stripped, is split on `/` into `path`, and is
/// stored whole as the only entry of `imported_names`.
///
/// Returns `None` when no target text could be read or the stripped target is empty.
///
/// `file_path` is left empty here.
/// NOTE(review): presumably the caller fills it in — confirm in `ImportExtractor::extract`.
fn extract_preproc_include(node: tree_sitter::Node, source: &[u8]) -> Option<super::ImportFact> {
    let mut walker = node.walk();
    let mut include_path = String::new();
    // Only observable once a target was read; a non-empty `include_path` implies that some
    // matching child has already overwritten this initial value.
    let mut import_kind = ImportKind::CppLocalInclude;

    for child in node.children(&mut walker) {
        import_kind = match child.kind() {
            "system_lib_string" => ImportKind::CppSystemInclude,
            "string_literal" => ImportKind::CppLocalInclude,
            _ => continue,
        };
        // Text that is not valid UTF-8 leaves any previously read target untouched.
        if let Ok(raw) = child.utf8_text(source) {
            include_path = strip_quotes_or_angle_brackets(raw);
        }
    }

    if include_path.is_empty() {
        return None;
    }

    let segments: Vec<String> = include_path.split('/').map(str::to_owned).collect();

    Some(super::ImportFact {
        file_path: std::path::PathBuf::new(),
        import_kind,
        path: segments,
        imported_names: vec![include_path],
        is_glob: false,
        is_reexport: false,
        byte_span: (node.start_byte(), node.end_byte()),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Result type shared by the tests so `?` can surface extraction errors.
    type TestResult = std::result::Result<(), Box<dyn std::error::Error>>;

    /// An angle-bracket include is reported as a system include with its bare path.
    #[test]
    fn test_extract_system_include() -> TestResult {
        let outcome = extract_cpp_imports(Path::new("test.c"), b"#include <stdio.h>\n");
        assert!(outcome.is_ok());

        let found = outcome?;
        assert_eq!(found.len(), 1);

        let include = &found[0];
        assert_eq!(include.import_kind, ImportKind::CppSystemInclude);
        assert_eq!(include.path, vec!["stdio.h"]);
        Ok(())
    }

    /// A quoted include is reported as a local include with its bare path.
    #[test]
    fn test_extract_local_include() -> TestResult {
        let outcome = extract_cpp_imports(Path::new("test.c"), b"#include \"myheader.h\"\n");
        assert!(outcome.is_ok());

        let found = outcome?;
        assert_eq!(found.len(), 1);

        let include = &found[0];
        assert_eq!(include.import_kind, ImportKind::CppLocalInclude);
        assert_eq!(include.path, vec!["myheader.h"]);
        Ok(())
    }
}