use super::ast::{AstNode, AstNodeType, Token, TokenType};
/// A path through the AST connecting two terminal nodes.
///
/// Holds the token of the terminal the path starts at, the node types
/// encountered along the way, and the token of the terminal it ends at.
#[derive(Debug, Clone)]
pub struct AstPath {
// Token of the terminal the path starts from.
source: Token,
// Node types visited between the two terminals, in traversal order.
path_nodes: Vec<AstNodeType>,
// Token of the terminal the path ends at.
target: Token,
}
impl AstPath {
#[must_use]
pub fn new(source: Token, path_nodes: Vec<AstNodeType>, target: Token) -> Self {
Self {
source,
path_nodes,
target,
}
}
#[must_use]
pub fn source(&self) -> &Token {
&self.source
}
#[must_use]
pub fn path_nodes(&self) -> &[AstNodeType] {
&self.path_nodes
}
#[must_use]
pub fn target(&self) -> &Token {
&self.target
}
#[must_use]
pub fn len(&self) -> usize {
self.path_nodes.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.path_nodes.is_empty()
}
#[must_use]
pub fn to_path_string(&self) -> String {
let path_str: String = self
.path_nodes
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join("↑↓");
format!(
"{}|{}|{}",
self.source.value(),
path_str,
self.target.value()
)
}
}
/// An [`AstPath`] together with the positions of its two endpoints.
///
/// As produced by `PathExtractor::extract_with_context`, the indices refer
/// to the list of terminals in the order the extractor collected them.
#[derive(Debug, Clone)]
pub struct PathContext {
// The extracted path itself.
pub path: AstPath,
// Position of the path's source terminal in the collected terminal list.
pub source_index: usize,
// Position of the path's target terminal in the collected terminal list.
pub target_index: usize,
}
impl PathContext {
#[must_use]
pub fn new(path: AstPath, source_index: usize, target_index: usize) -> Self {
Self {
path,
source_index,
target_index,
}
}
}
/// Extracts paths between pairs of terminal nodes of an AST.
///
/// Two limits bound the output: paths with more than `max_path_length`
/// node types are discarded, and extraction stops once `max_paths` paths
/// have been produced.
#[derive(Debug, Clone)]
pub struct PathExtractor {
// Largest number of node types a path may contain and still be kept.
max_path_length: usize,
// Number of kept paths at which extraction stops.
max_paths: usize,
}
impl PathExtractor {
    /// Creates an extractor that keeps paths of at most `max_path_length`
    /// node types and caps the output at `super::MAX_PATHS_PER_METHOD` paths.
    #[must_use]
    pub fn new(max_path_length: usize) -> Self {
        Self {
            max_path_length,
            max_paths: super::MAX_PATHS_PER_METHOD,
        }
    }

    /// Replaces the cap on the number of paths returned by one extraction.
    #[must_use]
    pub fn with_max_paths(mut self, max_paths: usize) -> Self {
        self.max_paths = max_paths;
        self
    }

    /// Extracts the paths between every ordered pair of terminals under `root`.
    ///
    /// Returns exactly the paths of [`Self::extract_with_context`], without
    /// the terminal indices.
    #[must_use]
    pub fn extract(&self, root: &AstNode) -> Vec<AstPath> {
        self.extract_with_context(root)
            .into_iter()
            .map(|context| context.path)
            .collect()
    }

    /// Extracts paths between terminals under `root`, tagging each path with
    /// the indices of its endpoints in the collected terminal list.
    ///
    /// Terminals are collected in depth-first, child order. Every pair
    /// `(i, j)` with `i < j` is considered; a path is kept only when it has
    /// at most `max_path_length` node types. Extraction stops as soon as
    /// `max_paths` paths have been kept. A tree with fewer than two
    /// terminals, or a `max_paths` of zero, yields an empty vector.
    #[must_use]
    pub fn extract_with_context(&self, root: &AstNode) -> Vec<PathContext> {
        let mut contexts = Vec::new();
        // The cap is checked after each push below, so a zero cap has to be
        // handled up front or one path would slip through.
        if self.max_paths == 0 {
            return contexts;
        }

        let terminals = Self::collect_terminals_with_paths(root);
        for (source_index, source) in terminals.iter().enumerate() {
            for (target_index, target) in terminals.iter().enumerate().skip(source_index + 1) {
                let path = Self::extract_path_between(source, target);
                if path.len() > self.max_path_length {
                    continue;
                }
                contexts.push(PathContext::new(path, source_index, target_index));
                if contexts.len() >= self.max_paths {
                    return contexts;
                }
            }
        }
        contexts
    }

    /// Collects every terminal under `root` together with its root-to-terminal path.
    fn collect_terminals_with_paths(root: &AstNode) -> Vec<TerminalWithPath<'_>> {
        let mut terminals = Vec::new();
        let mut stack = Vec::new();
        Self::visit(root, 0, &mut stack, &mut terminals);
        terminals
    }

    /// Depth-first walk that records a snapshot of the current ancestor
    /// stack for each terminal it reaches.
    ///
    /// `child_index` is the position of `node` among its parent's children
    /// (the root uses 0). A single stack is pushed/popped during the walk so
    /// the path is only copied when a terminal is actually found.
    fn visit<'a>(
        node: &'a AstNode,
        child_index: usize,
        stack: &mut Vec<PathStep>,
        terminals: &mut Vec<TerminalWithPath<'a>>,
    ) {
        stack.push(PathStep {
            child_index,
            node_type: node.node_type(),
        });
        if node.is_terminal() {
            terminals.push(TerminalWithPath {
                node,
                path_from_root: stack.clone(),
            });
        } else {
            for (index, child) in node.children().iter().enumerate() {
                Self::visit(child, index, stack, terminals);
            }
        }
        stack.pop();
    }

    /// Builds the path that climbs from `source` to the lowest common
    /// ancestor of the two terminals and then descends to `target`.
    ///
    /// The resulting node-type list contains the source terminal's type, the
    /// ancestors up to and including the common ancestor, and then the nodes
    /// below the common ancestor down to and including the target terminal.
    fn extract_path_between(
        source: &TerminalWithPath<'_>,
        target: &TerminalWithPath<'_>,
    ) -> AstPath {
        let lca_depth = Self::find_lca_depth(
            source.path_from_root.as_slice(),
            target.path_from_root.as_slice(),
        );

        // Upward leg: source terminal back up to (and including) the LCA.
        let upward = source.path_from_root[lca_depth..].iter().rev();
        // Downward leg: strictly below the LCA, down to the target terminal.
        let downward = target.path_from_root[(lca_depth + 1)..].iter();
        let path_nodes = upward.chain(downward).map(|step| step.node_type).collect();

        AstPath::new(
            Self::node_to_token(source.node),
            path_nodes,
            Self::node_to_token(target.node),
        )
    }

    /// Returns the index of the last element of the common prefix of the two
    /// paths, or 0 when they have no common prefix.
    ///
    /// For root-to-node paths made of [`PathStep`]s this is the depth of the
    /// lowest common ancestor: steps compare equal only when they occupy the
    /// same position under the same parent, so equal node types in different
    /// subtrees are not mistaken for a shared ancestor.
    fn find_lca_depth<T: PartialEq>(path1: &[T], path2: &[T]) -> usize {
        let shared_prefix_len = path1
            .iter()
            .zip(path2)
            .take_while(|(left, right)| left == right)
            .count();
        shared_prefix_len.saturating_sub(1)
    }

    /// Returns the token attached to `node`, or synthesises one from the
    /// node's type and value when it carries none.
    fn node_to_token(node: &AstNode) -> Token {
        if let Some(token) = node.token() {
            token.clone()
        } else {
            // No token on the node: pick a token type from the node type and
            // fall back to `Identifier` for anything not listed.
            let token_type = match node.node_type() {
                AstNodeType::Literal => TokenType::Number,
                AstNodeType::TypeAnnotation | AstNodeType::Generic => TokenType::TypeName,
                AstNodeType::BinaryOp | AstNodeType::UnaryOp => TokenType::Operator,
                _ => TokenType::Identifier,
            };
            Token::new(token_type, node.value())
        }
    }
}

/// One step on a root-to-terminal path: which child of its parent the node
/// is, and the node's type.
///
/// The child index is what makes two steps at the same depth distinguishable
/// when their node types happen to be equal.
#[derive(Clone, Copy, PartialEq)]
struct PathStep {
    child_index: usize,
    node_type: AstNodeType,
}

/// A terminal node paired with the steps leading to it from the root
/// (root first, the terminal itself last).
struct TerminalWithPath<'a> {
    node: &'a AstNode,
    path_from_root: Vec<PathStep>,
}
impl Default for PathExtractor {
fn default() -> Self {
Self::new(super::MAX_PATH_LENGTH)
}
}
// Unit tests for this module are loaded from `path_tests.rs` and are only
// compiled when building with `cfg(test)`.
#[cfg(test)]
#[path = "path_tests.rs"]
mod tests;