Skip to main content

oak_pretty_print/formatter/mod.rs

1use crate::{CommentProcessor, Document, FormatConfig, FormatResult, RuleSet, create_builtin_rules};
2use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec};
3use oak_core::{
4    language::Language,
5    tree::{RedLeaf, RedNode, RedTree},
6};
7
8/// 格式化输出
9#[derive(Debug, Clone, serde::Serialize)]
10#[serde(rename_all = "camelCase")]
11pub struct FormatOutput {
12    /// 格式化后的代码
13    pub content: String,
14    /// 是否有变化
15    pub changed: bool,
16}
17
18impl FormatOutput {
19    /// 创建新的格式化输出
20    pub fn new(content: String, changed: bool) -> Self {
21        Self { content, changed }
22    }
23}
24
25/// 路径节点,用于高效记录格式化路径
26#[derive(Debug)]
27pub struct PathNode<L: Language> {
28    pub kind: L::ElementType,
29    pub parent: Option<Arc<PathNode<L>>>,
30}
31
32/// 格式化上下文,管理格式化过程中的状态
33#[derive(Debug, Clone)]
34pub struct FormatContext<L: Language> {
35    /// 格式化配置
36    pub config: Arc<FormatConfig>,
37    /// 注释处理器
38    pub comment_processor: Arc<CommentProcessor>,
39    /// 源码内容
40    pub source: Option<Arc<str>>,
41    /// 当前嵌套深度
42    pub depth: usize,
43    /// 父节点类型路径
44    pub path: Option<Arc<PathNode<L>>>,
45}
46
47impl<L: Language> FormatContext<L> {
48    /// 创建新的格式化上下文
49    pub fn new(config: FormatConfig) -> Self {
50        let config = Arc::new(config);
51        Self { config: config.clone(), comment_processor: Arc::new(CommentProcessor::new().with_preserve_comments(config.format_comments).with_format_comments(config.format_comments)), source: None, depth: 0, path: None }
52    }
53
54    /// 进入子节点,增加深度并记录路径
55    pub fn enter(&self, kind: L::ElementType) -> Self {
56        let path = Some(Arc::new(PathNode { kind, parent: self.path.clone() }));
57        Self { config: self.config.clone(), comment_processor: self.comment_processor.clone(), source: self.source.clone(), depth: self.depth + 1, path }
58    }
59
60    /// 检查是否处于特定类型的节点内部
61    pub fn is_inside(&self, kind: L::ElementType) -> bool {
62        let mut current = self.path.as_ref();
63        while let Some(node) = current {
64            if node.kind == kind {
65                return true;
66            }
67            current = node.parent.as_ref();
68        }
69        false
70    }
71
72    /// 获取父节点类型
73    pub fn parent_kind(&self) -> Option<L::ElementType> {
74        self.path.as_ref().map(|n| n.kind.clone())
75    }
76}
77
78/// 通用格式化器
79pub struct Formatter<L: Language + 'static> {
80    /// 格式化规则集合
81    rules: RuleSet<L>,
82    /// 初始格式化上下文
83    pub context: FormatContext<L>,
84}
85
86impl<L: Language + 'static> Formatter<L>
87where
88    L::ElementType: oak_core::language::TokenType,
89{
90    /// 创建新的格式化器
91    pub fn new(config: FormatConfig) -> Self {
92        let mut formatter = Self { rules: RuleSet::new(), context: FormatContext::new(config) };
93
94        // 添加内置规则
95        for rule in create_builtin_rules::<L>() {
96            let _ = formatter.rules.add_rule(rule);
97        }
98
99        formatter
100    }
101
102    /// 添加格式化规则
103    pub fn add_rule(&mut self, rule: Box<dyn crate::FormatRule<L>>) -> FormatResult<()> {
104        self.rules.add_rule(rule)
105    }
106
107    /// 格式化 AST 节点
108    pub fn format<'a>(&mut self, root: &RedNode<L>, source: &'a str) -> FormatResult<FormatOutput> {
109        self.context.source = Some(Arc::from(source));
110        let doc = self.format_node(root, &self.context, source)?;
111        let content = doc.render((*self.context.config).clone());
112        let changed = content != source;
113        Ok(FormatOutput::new(content, changed))
114    }
115
116    /// 递归格式化节点并生成 Document
117    fn format_node<'a>(&self, node: &RedNode<L>, context: &FormatContext<L>, source: &'a str) -> FormatResult<Document<'a>> {
118        // 创建一个新的上下文,记录当前路径和深度
119        let new_context = context.enter(node.green.kind.clone());
120
121        // 创建一个用于格式化子节点的闭包
122        let format_children = |n: &RedNode<L>| {
123            let mut children_docs = Vec::new();
124            for child in n.children() {
125                match child {
126                    RedTree::Node(child_node) => children_docs.push(self.format_node(&child_node, &new_context, source)?),
127                    RedTree::Leaf(child_token) => children_docs.push(self.format_token(&child_token, &new_context, source)?),
128                }
129            }
130            Ok(Document::Concat(children_docs))
131        };
132
133        // 应用节点规则
134        if let Some(doc) = self.rules.apply_node_rules(node, &new_context, source, &format_children)? {
135            return Ok(doc);
136        }
137
138        // 默认逻辑:格式化所有子节点并连接
139        format_children(node)
140    }
141
142    /// 递归格式化 Token 并生成 Document
143    fn format_token<'a>(&self, token: &RedLeaf<L>, context: &FormatContext<L>, source: &'a str) -> FormatResult<Document<'a>> {
144        // 应用 Token 规则
145        if let Some(doc) = self.rules.apply_token_rules(token, context, source)? {
146            return Ok(doc);
147        }
148
149        // 默认逻辑:原样输出
150        let text = &source[token.span.start..token.span.end];
151        Ok(Document::Text(text.into()))
152    }
153}