use super::language::{CodeLanguage, TokenContext, TokenType};
use tree_sitter::{Node, Tree};
/// A single leaf token taken from a tree-sitter parse tree, together with the
/// context information gathered while tokenizing.
#[derive(Debug, Clone)]
pub struct CodeToken {
/// Source text of the token (the leaf node's UTF-8 text when built by `CodeTokenizer`).
pub text: String,
/// Start position of the token in the source, in bytes (`Node::start_byte` when built by `CodeTokenizer`).
pub byte_offset: usize,
/// Row of the token's start (tree-sitter `Point::row` when built by `CodeTokenizer`).
pub line: usize,
/// Column of the token's start (tree-sitter `Point::column`, which tree-sitter reports in bytes, when built by `CodeTokenizer`).
pub column: usize,
/// Classification of the token, as returned by the language's `classify_token`.
pub token_type: TokenType,
/// tree-sitter node kind of the leaf the token came from.
pub node_kind: String,
/// Surrounding context. `CodeTokenizer` fills in depth, error-region flag, parent kind and sibling kinds.
pub context: TokenContext,
}
impl CodeToken {
    /// Creates a token from its raw parts.
    ///
    /// The token's `context` is initialised from `token_type` via `TokenContext::new`.
    /// Callers such as `CodeTokenizer` then fill in the remaining context fields.
    ///
    /// # Arguments
    /// * `text` - the token's source text.
    /// * `byte_offset` - byte offset of the token's start in the source.
    /// * `line`, `column` - start position of the token.
    /// * `token_type` - classification of the token.
    /// * `node_kind` - kind of the parse-tree node the token came from.
    pub fn new(
        text: impl Into<String>,
        byte_offset: usize,
        line: usize,
        column: usize,
        token_type: TokenType,
        node_kind: impl Into<String>,
    ) -> Self {
        Self {
            text: text.into(),
            byte_offset,
            line,
            column,
            token_type,
            node_kind: node_kind.into(),
            // `TokenType` is `Copy`, so the value stored above can seed the context
            // directly; no separate copy is needed.
            context: TokenContext::new(token_type),
        }
    }

    /// Returns `true` when the token's context marks it as part of an error region.
    #[must_use]
    pub fn is_in_error(&self) -> bool {
        self.context.in_error_region
    }

    /// Returns whether this token's type is eligible for correction
    /// (delegates to `TokenType::is_correctable`).
    #[must_use]
    pub fn is_correctable(&self) -> bool {
        self.token_type.is_correctable()
    }
}
/// Walks a tree-sitter parse tree and turns its leaf nodes into [`CodeToken`]s,
/// using `language` to classify each leaf.
///
/// Tokens classified as whitespace or comments are dropped unless enabled via
/// [`CodeTokenizer::with_whitespace`] / [`CodeTokenizer::with_comments`].
pub struct CodeTokenizer<'a, L: CodeLanguage> {
/// Language used to classify each leaf's text and node kind into a `TokenType`.
language: &'a L,
/// Keep tokens classified as `TokenType::Whitespace` (off after `new`).
include_whitespace: bool,
/// Keep tokens classified as `TokenType::Comment` (off after `new`).
include_comments: bool,
}
impl<'a, L: CodeLanguage> CodeTokenizer<'a, L> {
    /// Creates a tokenizer for `language` that skips whitespace and comment tokens.
    #[must_use]
    pub fn new(language: &'a L) -> Self {
        Self {
            language,
            include_whitespace: false,
            include_comments: false,
        }
    }

    /// Sets whether tokens classified as [`TokenType::Whitespace`] are kept.
    #[must_use]
    pub fn with_whitespace(mut self, include: bool) -> Self {
        self.include_whitespace = include;
        self
    }

    /// Sets whether tokens classified as [`TokenType::Comment`] are kept.
    #[must_use]
    pub fn with_comments(mut self, include: bool) -> Self {
        self.include_comments = include;
        self
    }

    /// Tokenizes every leaf of `tree` in depth-first, left-to-right order.
    ///
    /// `source` must be the text `tree` was parsed from, because node byte ranges index into it.
    /// The error-region flag is set on leaves that lie inside an `ERROR` node, and on
    /// `MISSING` leaves inserted by the parser during error recovery.
    pub fn tokenize(&self, tree: &Tree, source: &str) -> Vec<CodeToken> {
        let mut tokens = Vec::new();
        self.traverse_node(tree.root_node(), None, source, &mut tokens, 0, false);
        tokens
    }

    /// Tokenizes only the leaves that belong to `ERROR` or `MISSING` subtrees of `tree`.
    ///
    /// Every returned token has its error-region flag set. Depth is still counted from the
    /// tree's root, so it matches what [`CodeTokenizer::tokenize`] reports for the same leaf.
    pub fn tokenize_errors(&self, tree: &Tree, source: &str) -> Vec<CodeToken> {
        let mut tokens = Vec::new();
        self.collect_error_tokens(tree.root_node(), None, source, &mut tokens, 0);
        tokens
    }

    /// Recursively appends tokens for all leaves under `node` to `tokens`.
    ///
    /// `parent` is the node whose children yielded `node` (`None` for the root).
    /// `depth` is `node`'s distance from the root.
    /// `in_error` says whether an ancestor is already an error node.
    fn traverse_node(
        &self,
        node: Node,
        parent: Option<Node>,
        source: &str,
        tokens: &mut Vec<CodeToken>,
        depth: usize,
        in_error: bool,
    ) {
        // MISSING nodes are zero-width placeholders for tokens the parser expected but did
        // not find. Flag them like ERROR nodes so `tokenize` agrees with `tokenize_errors`.
        let in_error = in_error || node.is_error() || node.is_missing();

        if node.child_count() == 0 {
            if let Some(token) = self.create_token(node, parent, source, depth, in_error) {
                tokens.push(token);
            }
            return;
        }

        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.traverse_node(child, Some(node), source, tokens, depth + 1, in_error);
        }
    }

    /// Searches for `ERROR` / `MISSING` nodes under `node` and tokenizes each such subtree
    /// with the error-region flag forced on.
    ///
    /// Subtrees outside those nodes produce no tokens.
    fn collect_error_tokens(
        &self,
        node: Node,
        parent: Option<Node>,
        source: &str,
        tokens: &mut Vec<CodeToken>,
        depth: usize,
    ) {
        if node.is_error() || node.is_missing() {
            self.traverse_node(node, parent, source, tokens, depth, true);
            return;
        }

        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.collect_error_tokens(child, Some(node), source, tokens, depth + 1);
        }
    }

    /// Builds a token for the leaf `node`, whose parent in the traversal is `parent`.
    ///
    /// Returns `None` in two cases:
    /// * the node's byte range is not valid UTF-8;
    /// * the token is whitespace or a comment and this tokenizer is configured to skip it.
    ///
    /// `parent` is passed in rather than looked up with `Node::parent`, because tree-sitter
    /// nodes store no parent link and that lookup searches down from the root again for
    /// every leaf.
    fn create_token(
        &self,
        node: Node,
        parent: Option<Node>,
        source: &str,
        depth: usize,
        in_error: bool,
    ) -> Option<CodeToken> {
        let text = node.utf8_text(source.as_bytes()).ok()?;
        let node_kind = node.kind();
        let token_type = self.language.classify_token(text, node_kind);

        match token_type {
            TokenType::Whitespace if !self.include_whitespace => return None,
            TokenType::Comment if !self.include_comments => return None,
            _ => {}
        }

        let start = node.start_position();
        let mut token = CodeToken::new(
            text,
            node.start_byte(),
            start.row,
            start.column,
            token_type,
            node_kind,
        );
        token.context.depth = depth;
        if in_error {
            token.context.in_error_region = true;
        }

        if let Some(parent) = parent {
            token.context.parent_node_type = Some(parent.kind().to_string());
            // Record the kinds of every other child of the same parent.
            let mut cursor = parent.walk();
            token.context.sibling_types = parent
                .children(&mut cursor)
                .filter(|sibling| sibling.id() != node.id())
                .map(|sibling| sibling.kind().to_string())
                .collect();
        }

        Some(token)
    }
}
/// Iterator over the tokens of a parse tree, yielding owned [`CodeToken`]s in the order
/// produced by [`CodeTokenizer::tokenize`].
///
/// Tokenization happens eagerly in [`TokenIterator::new`]. Iteration then moves the buffered
/// tokens out one at a time instead of cloning them.
pub struct TokenIterator<L: CodeLanguage> {
    /// Tokens not yet yielded.
    tokens: std::vec::IntoIter<CodeToken>,
    _marker: std::marker::PhantomData<L>,
}

impl<L: CodeLanguage> TokenIterator<L> {
    /// Tokenizes `tree` (parsed from `source`) with `tokenizer` and returns an iterator over
    /// the resulting tokens.
    ///
    /// `tree` and `source` are consumed. They are dropped once tokenization finishes, since
    /// every token owns its own text.
    pub fn new(tokenizer: CodeTokenizer<'_, L>, tree: Tree, source: String) -> Self {
        Self {
            tokens: tokenizer.tokenize(&tree, &source).into_iter(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<L: CodeLanguage> Iterator for TokenIterator<L> {
    type Item = CodeToken;

    fn next(&mut self) -> Option<Self::Item> {
        self.tokens.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.tokens.size_hint()
    }
}

// The buffered `vec::IntoIter` has an exact length and keeps returning `None` once drained.
impl<L: CodeLanguage> ExactSizeIterator for TokenIterator<L> {}

impl<L: CodeLanguage> std::iter::FusedIterator for TokenIterator<L> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// `CodeToken::new` should store each argument in its matching field, and the
    /// resulting identifier token should be correctable and outside any error region.
    #[test]
    fn test_code_token_creation() {
        let identifier = CodeToken::new("test", 0, 1, 5, TokenType::Identifier, "identifier");

        assert_eq!(identifier.text, "test");
        assert_eq!(
            (identifier.byte_offset, identifier.line, identifier.column),
            (0, 1, 5)
        );
        assert_eq!(identifier.token_type, TokenType::Identifier);

        assert!(identifier.is_correctable());
        assert!(!identifier.is_in_error());
    }
}