use tree_sitter::Node;
/// Borrows the slice of `source` that `node` spans.
///
/// # Panics
///
/// Panics if the node's byte range lies outside `source` or does not fall on
/// UTF-8 character boundaries (e.g. the node came from a different source).
pub fn get_node_text<'a>(node: &Node, source: &'a str) -> &'a str {
    let (start, end) = (node.start_byte(), node.end_byte());
    &source[start..end]
}
/// Returns an owned copy of the text of `source` that `node` spans.
///
/// # Panics
///
/// Panics under the same conditions as slicing `source` with the node's byte
/// range (out of bounds or not on a character boundary).
pub fn get_node_text_owned(node: &Node, source: &str) -> String {
    let span = node.start_byte()..node.end_byte();
    String::from(&source[span])
}
/// Returns the nearest `function_definition` enclosing `node`.
///
/// The search starts at `node` itself, so passing a `function_definition`
/// returns that same node. Returns `None` when no ancestor is a function.
pub fn find_containing_function<'a>(node: &Node<'a>) -> Option<Node<'a>> {
    std::iter::successors(Some(*node), |ancestor| ancestor.parent())
        .find(|ancestor| ancestor.kind() == "function_definition")
}
pub fn is_inside_loop(node: &Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
match parent.kind() {
"for_statement" | "while_statement" | "do_statement" => return true,
"function_definition" => return false, _ => current = parent.parent(),
}
}
false
}
#[allow(dead_code)]
pub fn is_inside_conditional(node: &Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
match parent.kind() {
"if_statement" | "switch_statement" => return true,
"function_definition" => return false, _ => current = parent.parent(),
}
}
false
}
/// Extracts the declared name from a C declarator node.
///
/// A bare `identifier` yields its own text. Pointer, array, function and
/// parenthesized declarators are searched child by child (depth-first) for
/// the first identifier. Any other node kind, or a wrapper containing no
/// identifier, yields an empty string.
pub fn get_identifier_from_declarator(declarator: &Node, source: &str) -> String {
    const WRAPPER_KINDS: [&str; 4] = [
        "pointer_declarator",
        "array_declarator",
        "function_declarator",
        "parenthesized_declarator",
    ];

    let kind = declarator.kind();
    if kind == "identifier" {
        return get_node_text_owned(declarator, source);
    }
    if !WRAPPER_KINDS.contains(&kind) {
        return String::new();
    }

    (0..declarator.child_count())
        .filter_map(|idx| declarator.child(idx))
        .find_map(|child| {
            // A direct identifier child wins immediately, even before recursing.
            if child.kind() == "identifier" {
                return Some(get_node_text_owned(&child, source));
            }
            let nested = get_identifier_from_declarator(&child, source);
            (!nested.is_empty()).then_some(nested)
        })
        .unwrap_or_default()
}
/// Searches the children of `declarator` for the first identifier.
///
/// The node itself is not checked, only its descendants. Array, pointer,
/// function and parenthesized declarators are searched recursively. Other
/// children (such as parameter lists) are skipped. Returns `None` when no
/// identifier is found.
pub fn find_identifier_in_declarator(declarator: &Node, source: &str) -> Option<String> {
    (0..declarator.child_count())
        .filter_map(|idx| declarator.child(idx))
        .find_map(|child| match child.kind() {
            "identifier" => Some(get_node_text_owned(&child, source)),
            "array_declarator"
            | "pointer_declarator"
            | "function_declarator"
            | "parenthesized_declarator" => find_identifier_in_declarator(&child, source),
            _ => None,
        })
}
/// Returns `(name, full declaration text)` pairs for the parameters of a
/// function definition or declaration node.
///
/// The function declarator is looked up among `function_node`'s direct
/// children. A function returning a pointer (`char *dup(const char *s)`) has
/// its `function_declarator` nested inside one or more `pointer_declarator`s.
/// Those wrappers are peeled off so such functions are handled too.
///
/// Returns `None` in two cases:
/// - no function declarator is found;
/// - no named parameters could be extracted (e.g. `()` or `(void)`).
pub fn get_function_parameters(
    function_node: &Node,
    source: &str,
) -> Option<Vec<(String, String)>> {
    for i in 0..function_node.child_count() {
        let Some(mut candidate) = function_node.child(i) else {
            continue;
        };
        // Unwrap `*f(...)`, `**f(...)`, ... down to the function declarator.
        while candidate.kind() == "pointer_declarator" {
            match candidate.child_by_field_name("declarator") {
                Some(inner) => candidate = inner,
                None => break,
            }
        }
        if candidate.kind() == "function_declarator" {
            return extract_parameters(&candidate, source);
        }
    }
    None
}
/// Collects `(name, declaration text)` for every named `parameter_declaration`
/// in the `parameter_list` children of a function declarator.
///
/// Parameters without an extractable name (e.g. `void`, or abstract
/// declarators like `int *`) are skipped. Returns `None` instead of an empty
/// vector when nothing was collected.
fn extract_parameters(declarator_node: &Node, source: &str) -> Option<Vec<(String, String)>> {
    let mut parameters = Vec::new();
    for i in 0..declarator_node.child_count() {
        let Some(param_list) = declarator_node.child(i) else {
            continue;
        };
        if param_list.kind() != "parameter_list" {
            continue;
        }
        for j in 0..param_list.child_count() {
            let Some(param) = param_list.child(j) else {
                continue;
            };
            if param.kind() == "parameter_declaration" {
                // `Option` is iterable: pushes the pair only when a name was found.
                parameters.extend(extract_parameter_info(&param, source));
            }
        }
    }
    (!parameters.is_empty()).then_some(parameters)
}
/// Returns `(parameter name, whole parameter_declaration text)` for one
/// parameter node.
///
/// Children are inspected in order. The first one that yields a name wins:
/// - a bare `identifier` child;
/// - an identifier found inside an array, pointer or function declarator child.
///
/// Returns `None` for unnamed parameters.
fn extract_parameter_info(param_node: &Node, source: &str) -> Option<(String, String)> {
    (0..param_node.child_count())
        .filter_map(|idx| param_node.child(idx))
        .find_map(|child| match child.kind() {
            "array_declarator" | "pointer_declarator" | "function_declarator" => {
                find_identifier_in_declarator(&child, source)
            }
            "identifier" => Some(get_node_text(&child, source).to_string()),
            _ => None,
        })
        .map(|name| (name, get_node_text(param_node, source).to_string()))
}
pub fn is_function_parameter(function_node: &Node, var_name: &str, source: &str) -> bool {
for i in 0..function_node.child_count() {
if let Some(child) = function_node.child(i) {
if child.kind() == "function_declarator" {
for j in 0..child.child_count() {
if let Some(param_list) = child.child(j) {
if param_list.kind() == "parameter_list" {
let param_text = get_node_text(¶m_list, source);
let words: Vec<&str> = param_text
.split(|c: char| !c.is_alphanumeric() && c != '_')
.collect();
if words.contains(&var_name) {
return true;
}
}
}
}
}
}
}
false
}
/// Heuristically decides whether a parameter declaration denotes an array.
///
/// Anything containing `[` counts. Pointers (`*`) also count, except C-string
/// parameters declared as `const char *` / `char const *`.
///
/// Whitespace is ignored when matching the string form. `const char* s`,
/// `const char *s` and `const char*s` are therefore classified the same way.
pub fn is_array_parameter_type(param_type: &str) -> bool {
    if param_type.contains('[') {
        return true;
    }
    if !param_type.contains('*') {
        return false;
    }
    let compact: String = param_type.chars().filter(|c| !c.is_whitespace()).collect();
    !(compact.contains("constchar*") || compact.contains("charconst*"))
}
/// Reports whether a type string mentions a pointer (`*`) anywhere.
pub fn is_pointer_type(type_str: &str) -> bool {
    type_str.chars().any(|c| c == '*')
}
// Allowed because this helper may currently have no callers in the crate.
#[allow(dead_code)]
/// Reports whether a (trimmed) type name is a known signed integer type.
///
/// Covers the standard spellings of `short`, `int`, `long` and `long long`,
/// with or without `signed` and a trailing `int`. It also covers the signed
/// fixed-width / pointer-sized aliases, mirroring the unsigned ones that
/// `is_unsigned_type` recognises.
///
/// Plain `char` is treated as signed, as before. Unknown names return `false`.
pub fn is_signed_type(type_str: &str) -> bool {
    matches!(
        type_str.trim(),
        "int"
            | "short"
            | "short int"
            | "long"
            | "long int"
            | "long long"
            | "long long int"
            | "char"
            | "signed"
            | "signed int"
            | "signed short"
            | "signed short int"
            | "signed long"
            | "signed long int"
            | "signed long long"
            | "signed long long int"
            | "signed char"
            | "int8_t"
            | "int16_t"
            | "int32_t"
            | "int64_t"
            | "intptr_t"
            | "intmax_t"
            | "ptrdiff_t"
            | "ssize_t"
    )
}
// Allowed because this helper may currently have no callers in the crate.
#[allow(dead_code)]
/// Reports whether a type string names an unsigned integer type.
///
/// Any string containing the substring `unsigned` matches. Otherwise the
/// trimmed string must be one of the common unsigned aliases (`size_t`,
/// `uintN_t`, `uintptr_t`, `uintmax_t`).
pub fn is_unsigned_type(type_str: &str) -> bool {
    const UNSIGNED_ALIASES: [&str; 7] = [
        "size_t",
        "uint8_t",
        "uint16_t",
        "uint32_t",
        "uint64_t",
        "uintptr_t",
        "uintmax_t",
    ];
    if type_str.contains("unsigned") {
        return true;
    }
    UNSIGNED_ALIASES.contains(&type_str.trim())
}
pub fn get_binary_operator<'a>(node: &Node, source: &'a str) -> Option<&'a str> {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
let kind = child.kind();
if matches!(
kind,
"+" | "-"
| "*"
| "/"
| "%"
| "=="
| "!="
| "<"
| ">"
| "<="
| ">="
| "&&"
| "||"
| "&"
| "|"
| "^"
| "<<"
| ">>"
| "="
| "+="
| "-="
| "*="
| "/="
| "%="
| "&="
| "|="
| "^="
| "<<="
| ">>="
) {
return Some(get_node_text(&child, source));
}
}
}
None
}
// Allowed because this helper may currently have no callers in the crate.
#[allow(dead_code)]
/// Finds the element count written in the last `array_name[...]` occurrence
/// within `preceding_text`.
///
/// The occurrence must start at an identifier boundary. A search for `buf`
/// therefore ignores `mybuf[...]`.
///
/// The bracket contents must be either an integer literal (`[10]`) or a
/// product of two literals (`[3 * 4]`). Returns `None` in these cases:
/// - the last occurrence has no closing `]`;
/// - its contents are anything else (e.g. an index variable);
/// - the product overflows `usize`.
pub fn find_array_size(array_name: &str, preceding_text: &str) -> Option<usize> {
    let pattern = format!("{}[", array_name);
    let pos = preceding_text
        .rmatch_indices(pattern.as_str())
        .map(|(idx, _)| idx)
        .find(|&idx| {
            // Reject matches that are the tail of a longer identifier.
            !preceding_text[..idx]
                .chars()
                .next_back()
                .is_some_and(|c| c.is_alphanumeric() || c == '_')
        })?;

    let after_bracket = &preceding_text[pos + pattern.len()..];
    let close_bracket = after_bracket.find(']')?;
    let size_str = after_bracket[..close_bracket].trim();

    if let Ok(size) = size_str.parse::<usize>() {
        return Some(size);
    }

    // `a * b` with exactly two integer factors.
    let (lhs, rhs) = size_str.split_once('*')?;
    if rhs.contains('*') {
        return None;
    }
    let a = lhs.trim().parse::<usize>().ok()?;
    let b = rhs.trim().parse::<usize>().ok()?;
    a.checked_mul(b)
}
// Allowed because this helper may currently have no callers in the crate.
#[allow(dead_code)]
/// Estimates `sizeof` for a C type name, assuming an LP64 target.
///
/// On LP64, `long`, pointers and the pointer-sized aliases are 8 bytes.
/// Any name ending in `*` is treated as a pointer. Unrecognised names
/// default to 4 bytes.
pub fn get_type_size(type_name: &str) -> usize {
    match type_name.trim() {
        "char" | "signed char" | "unsigned char" | "int8_t" | "uint8_t" | "_Bool" | "bool" => 1,
        "short" | "signed short" | "unsigned short" | "short int" | "signed short int"
        | "unsigned short int" | "int16_t" | "uint16_t" => 2,
        "int" | "signed" | "unsigned" | "signed int" | "unsigned int" | "int32_t" | "uint32_t"
        | "float" => 4,
        "long" | "signed long" | "unsigned long" | "long int" | "signed long int"
        | "unsigned long int" | "long long" | "signed long long" | "unsigned long long"
        | "long long int" | "signed long long int" | "unsigned long long int" | "int64_t"
        | "uint64_t" | "double" | "size_t" | "ssize_t" | "ptrdiff_t" | "intptr_t"
        | "uintptr_t" | "intmax_t" | "uintmax_t" => 8,
        "long double" => 16,
        t if t.ends_with('*') => 8,
        _ => 4,
    }
}
/// Reports whether `node` is the target being written by an assignment.
///
/// Plain and compound assignments are both `assignment_expression`s. The
/// node counts as written when it is the assignment's `left` operand. It
/// also counts when it is the array being indexed on that side, so `arr` in
/// `arr[i] = x` and `a` in `a[i][j] = x` are writes.
///
/// The subscript *index* (`i` above) is only read, so the walk stops there.
/// Any other parent kind also ends the search with `false`.
pub fn is_write_context(node: &Node) -> bool {
    let mut current = *node;
    while let Some(parent) = current.parent() {
        match parent.kind() {
            "assignment_expression" => {
                return parent
                    .child_by_field_name("left")
                    .is_some_and(|left| left.id() == current.id());
            }
            "subscript_expression" => {
                let is_index = parent
                    .child_by_field_name("index")
                    .is_some_and(|index| index.id() == current.id());
                if is_index {
                    return false;
                }
                current = parent;
            }
            _ => return false,
        }
    }
    false
}
#[allow(dead_code)]
pub fn is_in_sizeof(node: &Node) -> bool {
let mut current = node.parent();
while let Some(parent) = current {
if parent.kind() == "sizeof_expression" {
return true;
}
if parent.kind() == "function_definition" {
return false;
}
current = parent.parent();
}
false
}
pub fn find_containing_for_loop<'a>(node: &Node<'a>) -> Option<Node<'a>> {
let mut current = node.parent();
while let Some(n) = current {
if n.kind() == "for_statement" {
return Some(n);
}
current = n.parent();
}
None
}
/// Returns the nearest `if_statement` strictly above `node`, searching all
/// the way to the root.
pub fn find_containing_if_statement<'a>(node: &Node<'a>) -> Option<Node<'a>> {
    let mut cursor = node.parent();
    while let Some(candidate) = cursor {
        if candidate.kind() == "if_statement" {
            break;
        }
        cursor = candidate.parent();
    }
    cursor
}
/// Extracts the struct (or typedef) name a type string refers to.
///
/// Pointer stars and `const`/`volatile` qualifiers are removed from both
/// ends. They are peeled in any order and only as whole tokens, so these
/// reduce as follows:
/// - `struct foo * const` → `foo`;
/// - `myconst *` → `myconst`.
///
/// Returns `None` in these cases:
/// - built-in scalar types;
/// - `int*_t`/`uint*_t`/`size*_t` aliases;
/// - anything that is not a single identifier after an optional `struct `
///   prefix.
pub fn extract_struct_name_from_type(type_str: &str) -> Option<&str> {
    /// Removes `word` from the end of `s` only when it is a whole token,
    /// not the tail of a longer identifier.
    fn strip_trailing_word<'s>(s: &'s str, word: &str) -> Option<&'s str> {
        let rest = s.strip_suffix(word)?;
        match rest.chars().next_back() {
            Some(c) if c.is_alphanumeric() || c == '_' => None,
            _ => Some(rest),
        }
    }

    // Peel trailing `*`, `const` and `volatile` until none remain.
    let mut base = type_str.trim();
    loop {
        let trimmed = base.trim_end();
        match trimmed
            .strip_suffix('*')
            .or_else(|| strip_trailing_word(trimmed, "const"))
            .or_else(|| strip_trailing_word(trimmed, "volatile"))
        {
            Some(rest) => base = rest,
            None => {
                base = trimmed;
                break;
            }
        }
    }
    // Peel leading qualifiers (`const struct foo` -> `struct foo`).
    loop {
        let next = base
            .strip_prefix("const ")
            .or_else(|| base.strip_prefix("volatile "))
            .unwrap_or(base)
            .trim();
        if next == base {
            break;
        }
        base = next;
    }

    if matches!(
        base,
        "int"
            | "unsigned int"
            | "signed int"
            | "short"
            | "unsigned short"
            | "long"
            | "unsigned long"
            | "long long"
            | "unsigned long long"
            | "char"
            | "unsigned char"
            | "signed char"
            | "float"
            | "double"
            | "void"
            | "_Bool"
    ) {
        return None;
    }
    if base.ends_with("_t")
        && (base.starts_with("int") || base.starts_with("uint") || base.starts_with("size"))
    {
        return None;
    }
    if let Some(name) = base.strip_prefix("struct ") {
        let name = name.trim();
        if !name.is_empty() && name.chars().all(|c| c.is_alphanumeric() || c == '_') {
            return Some(name);
        }
        return None;
    }
    if !base.is_empty()
        && base
            .chars()
            .next()
            .is_some_and(|c| c.is_alphabetic() || c == '_')
        && base.chars().all(|c| c.is_alphanumeric() || c == '_')
    {
        return Some(base);
    }
    None
}
/// Resolves the declared type of a `field_expression` (`a.b`, `p->b`).
///
/// The base expression's type is found as follows. Redundant parentheses
/// around the base are looked through first.
/// - an identifier is looked up in `type_map`;
/// - a nested field expression is resolved recursively;
/// - `*p` (as in `(*p).b`) strips one pointer level from `p`'s type;
/// - `&s` (as in `(&s)->b`) uses `s`'s type unchanged.
///
/// The struct name from that type selects the field table in
/// `struct_field_types`, and the field's type is returned. Returns `None`
/// whenever any lookup fails or the base expression has another shape.
pub fn resolve_field_expression_type(
    node: &Node,
    source: &str,
    type_map: &std::collections::HashMap<String, String>,
    struct_field_types: &std::collections::HashMap<
        String,
        std::collections::HashMap<String, String>,
    >,
) -> Option<String> {
    /// Looks through any number of `( … )` wrappers to the inner expression.
    fn strip_parens<'t>(mut expr: Node<'t>) -> Node<'t> {
        while expr.kind() == "parenthesized_expression" {
            match expr.named_child(0) {
                Some(inner) => expr = inner,
                None => break,
            }
        }
        expr
    }

    let field_node = node.child_by_field_name("field")?;
    let field_name = field_node.utf8_text(source.as_bytes()).ok()?;
    // `(*p).x` parses with a parenthesized argument; unwrap it first.
    let argument = strip_parens(node.child_by_field_name("argument")?);
    let base_type = match argument.kind() {
        "identifier" => {
            let base_name = argument.utf8_text(source.as_bytes()).ok()?;
            type_map.get(base_name)?.clone()
        }
        "field_expression" => {
            resolve_field_expression_type(&argument, source, type_map, struct_field_types)?
        }
        "pointer_expression" => {
            let inner = strip_parens(argument.child_by_field_name("argument")?);
            let inner_name = inner.utf8_text(source.as_bytes()).ok()?;
            let t = type_map.get(inner_name)?;
            let is_address_of = argument
                .child_by_field_name("operator")
                .is_some_and(|op| op.kind() == "&");
            if is_address_of {
                // `(&s)->f`: the pointee is `s` itself.
                t.clone()
            } else {
                // `(*p).f`: drop one level of indirection from `p`'s type.
                t.strip_suffix(" *")
                    .or_else(|| t.strip_suffix('*'))
                    .map(|s| s.trim().to_string())?
            }
        }
        _ => return None,
    };
    let struct_name = extract_struct_name_from_type(&base_type)?;
    struct_field_types
        .get(struct_name)
        .and_then(|fields| fields.get(field_name))
        .cloned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use tree_sitter::Parser;

    /// Parses `code` as C, returning the syntax tree and an owned copy of the
    /// source it was parsed from.
    fn parse_c_code(code: &str) -> (tree_sitter::Tree, String) {
        let language = tree_sitter_c::language();
        let mut parser = Parser::new();
        parser.set_language(&language).unwrap();
        let tree = parser.parse(code, None).unwrap();
        (tree, String::from(code))
    }

    #[test]
    fn test_get_node_text() {
        let (tree, source) = parse_c_code("int x = 5;");
        assert_eq!(get_node_text(&tree.root_node(), &source), "int x = 5;");
    }

    #[test]
    fn test_find_containing_function() {
        let (tree, _source) = parse_c_code("void foo() { int x = 5; }");
        let func_def = tree.root_node().child(0).unwrap();
        assert_eq!(func_def.kind(), "function_definition");

        // Child 0 of the body is `{`; child 1 is the declaration.
        let body = func_def.child_by_field_name("body").unwrap();
        let decl = body.child(1).unwrap();
        let found = find_containing_function(&decl).expect("declaration should be inside foo");
        assert_eq!(found.kind(), "function_definition");
    }

    #[test]
    fn test_find_array_size() {
        assert_eq!(find_array_size("arr", "int main() { int arr[10]; }"), Some(10));
    }

    #[test]
    fn test_is_signed_type() {
        for ty in ["int", "signed int", "int32_t"] {
            assert!(is_signed_type(ty), "{ty} should be signed");
        }
        for ty in ["unsigned int", "size_t"] {
            assert!(!is_signed_type(ty), "{ty} should not be signed");
        }
    }

    #[test]
    fn test_is_unsigned_type() {
        for ty in ["unsigned int", "size_t", "uint32_t"] {
            assert!(is_unsigned_type(ty), "{ty} should be unsigned");
        }
        for ty in ["int", "signed int"] {
            assert!(!is_unsigned_type(ty), "{ty} should not be unsigned");
        }
    }

    #[test]
    fn test_get_type_size() {
        for (ty, size) in [("char", 1), ("short", 2), ("int", 4), ("long", 8), ("int *", 8)] {
            assert_eq!(get_type_size(ty), size, "size of {ty}");
        }
    }
}