mod ast;
mod embedding;
mod mpnn;
mod path;
pub use ast::{AstNode, AstNodeType, Token, TokenType};
pub use embedding::{Code2VecEncoder, CodeEmbedding};
pub use mpnn::{
pooling, CodeEdgeType, CodeGraph, CodeGraphEdge, CodeGraphNode, CodeMPNN, CodeMPNNLayer,
};
pub use path::{AstPath, PathContext, PathExtractor};
/// Default dimensionality of code embeddings, intended as the argument to
/// [`Code2VecEncoder::new`]; an encoder built with it yields vectors of this length.
pub const DEFAULT_EMBEDDING_DIM: usize = 128;

/// Default maximum AST path length, intended as the argument to [`PathExtractor::new`].
///
/// NOTE(review): whether the length counts nodes or edges is decided inside
/// `PathExtractor` — confirm there.
pub const MAX_PATH_LENGTH: usize = 8;

/// Cap on the number of path contexts considered for a single method.
///
/// NOTE(review): not referenced in this file; presumably enforced by the path
/// extraction / encoding code — verify before relying on it.
pub const MAX_PATHS_PER_METHOD: usize = 200;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an identifier token — shorthand for the repeated
    /// `Token::new(TokenType::Identifier, ..)` calls these tests need.
    fn ident(name: &'static str) -> Token {
        Token::new(TokenType::Identifier, name)
    }

    /// The module-level constants must all be non-zero.
    ///
    /// The checks are `const` items, so a zero value fails compilation rather than this
    /// test. A runtime `assert!` on a constant is flagged by `clippy::assertions_on_constants`.
    /// The test function is kept so the check still shows up in the harness output.
    #[test]
    fn test_module_constants() {
        const _: () = assert!(MAX_PATH_LENGTH > 0, "MAX_PATH_LENGTH must be non-zero");
        const _: () = assert!(MAX_PATHS_PER_METHOD > 0, "MAX_PATHS_PER_METHOD must be non-zero");
        const _: () = assert!(DEFAULT_EMBEDDING_DIM > 0, "DEFAULT_EMBEDDING_DIM must be non-zero");
    }

    /// A freshly created node reports its type and value and has no children.
    #[test]
    fn test_ast_node_creation() {
        let node = AstNode::new(AstNodeType::Function, "calculate_sum");
        assert_eq!(node.node_type(), AstNodeType::Function);
        assert_eq!(node.value(), "calculate_sum");
        assert!(node.children().is_empty());
    }

    /// A token reports the type and value it was created with.
    #[test]
    fn test_token_creation() {
        let token = Token::new(TokenType::Identifier, "foo");
        assert_eq!(token.token_type(), TokenType::Identifier);
        assert_eq!(token.value(), "foo");
    }

    /// A function with two parameters and a return node yields at least one path.
    #[test]
    fn test_path_extractor_simple() {
        let mut func = AstNode::new(AstNodeType::Function, "add");
        func.add_child(AstNode::new(AstNodeType::Parameter, "x"));
        func.add_child(AstNode::new(AstNodeType::Parameter, "y"));
        func.add_child(AstNode::new(AstNodeType::Return, "sum"));

        let extractor = PathExtractor::new(MAX_PATH_LENGTH);
        let paths = extractor.extract(&func);
        assert!(!paths.is_empty());
    }

    /// Encoding a single path produces a vector of the configured dimension.
    #[test]
    fn test_code2vec_encoder() {
        let encoder = Code2VecEncoder::new(DEFAULT_EMBEDDING_DIM);
        let path = AstPath::new(
            ident("x"),
            vec![
                AstNodeType::Parameter,
                AstNodeType::Function,
                AstNodeType::Return,
            ],
            ident("result"),
        );

        let embedding = encoder.encode_path(&path);
        assert_eq!(embedding.len(), DEFAULT_EMBEDDING_DIM);
    }

    /// Aggregating several paths produces an embedding of the configured dimension.
    #[test]
    fn test_code_embedding_aggregation() {
        let encoder = Code2VecEncoder::new(DEFAULT_EMBEDDING_DIM);
        let paths = vec![
            AstPath::new(
                ident("a"),
                vec![AstNodeType::Parameter, AstNodeType::Function],
                ident("b"),
            ),
            AstPath::new(
                ident("c"),
                vec![AstNodeType::Return, AstNodeType::Function],
                ident("d"),
            ),
        ];

        let embedding = encoder.aggregate_paths(&paths);
        assert_eq!(embedding.dim(), DEFAULT_EMBEDDING_DIM);
    }
}