use std::collections::HashMap;
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_filesystem::span::{TextOffset, TextSpan};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{SyntaxNode, TypedSyntaxNode};
use cairo_lang_utils::extract_matches;
/// A node in a tree describing how to produce new code out of existing syntax nodes and
/// free-form text.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum RewriteNode {
/// A syntax node whose text is emitted without its surrounding trivia
/// (see `PatchBuilder::add_trimmed_node`).
Trimmed(SyntaxNode),
/// A syntax node whose full text is emitted as is (see `PatchBuilder::add_node`).
Copied(SyntaxNode),
/// A node made of a sequence of child rewrite nodes, emitted one after the other.
Modified(ModifiedNode),
/// Plain text, emitted verbatim.
Text(String),
}
impl RewriteNode {
    /// Creates a [`RewriteNode::Copied`] node from the syntax node underlying a typed AST node.
    pub fn from_ast<T: TypedSyntaxNode>(node: &T) -> Self {
        RewriteNode::Copied(node.as_syntax_node())
    }

    /// Converts this node, in place, into a [`RewriteNode::Modified`] node and returns a mutable
    /// reference to its [`ModifiedNode`] so that the children can be edited.
    ///
    /// * `Copied` - replaced by a `Modified` node whose children are `Copied` nodes of the
    ///   syntax node's children.
    /// * `Modified` - returned as is.
    /// * `Text` - replaced by a `Modified` node with no children (the text is discarded).
    ///
    /// # Panics
    ///
    /// Panics if the node is a `Trimmed` node.
    pub fn modify(&mut self, db: &dyn SyntaxGroup) -> &mut ModifiedNode {
        match self {
            RewriteNode::Copied(syntax_node) => {
                *self = RewriteNode::Modified(ModifiedNode {
                    children: syntax_node.children(db).map(RewriteNode::Copied).collect(),
                });
                // `self` was just overwritten, so borrow the payload of the new variant.
                extract_matches!(self, RewriteNode::Modified)
            }
            RewriteNode::Trimmed(_) => {
                panic!("Not supported.")
            }
            RewriteNode::Modified(modified) => modified,
            RewriteNode::Text(_) => {
                *self = RewriteNode::Modified(ModifiedNode { children: vec![] });
                // `self` was just overwritten, so borrow the payload of the new variant.
                extract_matches!(self, RewriteNode::Modified)
            }
        }
    }

    /// Prepares the node for modification (see [`RewriteNode::modify`]) and returns a mutable
    /// reference to the child at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `modify` panics, or if `index` is out of bounds of the children.
    pub fn modify_child(&mut self, db: &dyn SyntaxGroup, index: usize) -> &mut RewriteNode {
        &mut self.modify(db).children[index]
    }

    /// Replaces this node with a [`RewriteNode::Text`] node holding `s`.
    pub fn set_str(&mut self, s: String) {
        *self = RewriteNode::Text(s)
    }

    /// Builds a [`RewriteNode::Modified`] node from a code template.
    ///
    /// Every `$name$` placeholder in `code` is replaced by a clone of `patches[name]`, and the
    /// text between placeholders becomes [`RewriteNode::Text`] children. An empty placeholder
    /// (`$$`) yields a literal `$`. A placeholder that is opened but never closed is read up to
    /// the end of `code` and handled as if it were closed there.
    ///
    /// # Panics
    ///
    /// Panics if a placeholder name has no entry in `patches`.
    pub fn interpolate_patched(code: &str, patches: HashMap<String, RewriteNode>) -> RewriteNode {
        let mut chars = code.chars();
        let mut pending_text = String::new();
        let mut children = Vec::new();
        while let Some(c) = chars.next() {
            if c != '$' {
                pending_text.push(c);
                continue;
            }

            // An opening `$` was found. Read the name up to the closing `$`; `take_while` also
            // consumes the closing `$` itself.
            let name: String = chars.by_ref().take_while(|&c| c != '$').collect();

            // An empty name stands for a single literal `$`.
            if name.is_empty() {
                pending_text.push('$');
                continue;
            }

            // Flush the text accumulated so far before emitting the patch.
            if !pending_text.is_empty() {
                children.push(RewriteNode::Text(std::mem::take(&mut pending_text)));
            }

            // The same placeholder may appear more than once, hence the clone.
            let patch = patches.get(&name).unwrap_or_else(|| {
                panic!("No patch was provided for the placeholder `${}$`.", name)
            });
            children.push(patch.clone());
        }

        // Flush the remaining text.
        if !pending_text.is_empty() {
            children.push(RewriteNode::Text(pending_text));
        }

        RewriteNode::Modified(ModifiedNode { children })
    }
}
impl From<SyntaxNode> for RewriteNode {
fn from(node: SyntaxNode) -> Self {
RewriteNode::Copied(node)
}
}
/// The payload of a `RewriteNode::Modified` node: an ordered list of child rewrite nodes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ModifiedNode {
/// The children, in the order in which they are emitted.
pub children: Vec<RewriteNode>,
}
/// A mapping between a span in the generated code and the span of the syntax node it was
/// generated from.
#[derive(Debug, PartialEq, Eq)]
pub struct Patch {
// The span in the generated code (offsets into `PatchBuilder::code`).
span: TextSpan,
// The span of the syntax node the generated text was taken from.
origin_span: TextSpan,
}
/// A collection of `Patch`es, used to translate spans in generated code back to the spans of
/// the syntax nodes they originated from.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct Patches {
// The recorded patches, in the order in which they were added.
patches: Vec<Patch>,
}
impl Patches {
    /// Maps `span`, given in generated-code offsets, to the matching span in the origin code.
    ///
    /// The first recorded patch whose generated span contains `span` is used: the result keeps
    /// the same offset from the start of the patch and the same length as `span`. Returns `None`
    /// when no patch contains `span`. `_db` is currently unused.
    pub fn translate(&self, _db: &dyn DefsGroup, span: TextSpan) -> Option<TextSpan> {
        self.patches.iter().find(|patch| patch.span.contains(span)).map(|patch| {
            let start = patch.origin_span.start.add(span.start - patch.span.start);
            TextSpan { start, end: start.add(span.end - span.start) }
        })
    }
}
/// Accumulates generated code together with the patches that map pieces of it back to the
/// syntax nodes they were taken from.
pub struct PatchBuilder<'a> {
/// The syntax database used to query node spans, text and children.
pub db: &'a dyn SyntaxGroup,
/// The code generated so far.
pub code: String,
/// The patches recorded so far, one per syntax node added to `code`.
pub patches: Patches,
}
impl<'a> PatchBuilder<'a> {
    /// Creates a builder with empty code and no patches.
    pub fn new(db: &'a dyn SyntaxGroup) -> Self {
        Self { db, code: String::default(), patches: Patches::default() }
    }

    /// Appends a single character to the generated code. No patch is recorded.
    pub fn add_char(&mut self, c: char) {
        self.code.push(c);
    }

    /// Appends a string to the generated code. No patch is recorded.
    pub fn add_str(&mut self, s: &str) {
        self.code.push_str(s);
    }

    /// Appends the code described by a [`RewriteNode`], recursing into the children of
    /// `Modified` nodes in order.
    pub fn add_modified(&mut self, node: RewriteNode) {
        match node {
            RewriteNode::Copied(node) => self.add_node(node),
            RewriteNode::Trimmed(node) => self.add_trimmed_node(node),
            RewriteNode::Modified(modified) => {
                for child in modified.children {
                    self.add_modified(child)
                }
            }
            RewriteNode::Text(s) => self.add_str(s.as_str()),
        }
    }

    /// Appends the full text of a syntax node and records a patch from the appended span to the
    /// node's span.
    ///
    /// The length of the recorded generated span is the length of the node's span.
    pub fn add_node(&mut self, node: SyntaxNode) {
        // Query the span once and use it both for the length and as the origin span.
        let origin_span = node.span(self.db);
        let start = TextOffset(self.code.len());
        let width = origin_span.end - origin_span.start;
        self.patches
            .patches
            .push(Patch { span: TextSpan { start, end: start.add(width) }, origin_span });
        self.code.push_str(node.get_text(self.db).as_str());
    }

    /// Appends the text of a syntax node without its trivia and records a patch from the
    /// appended span to the node's span without trivia.
    pub fn add_trimmed_node(&mut self, node: SyntaxNode) {
        let origin_span = node.span_without_trivia(self.db);
        let text = node.get_text_of_span(self.db, origin_span);
        let start = TextOffset(self.code.len());
        self.code.push_str(&text);
        self.patches
            .patches
            .push(Patch { span: TextSpan { start, end: start.add(text.len()) }, origin_span });
    }
}