1use std::collections::BTreeMap;
2
3use lightningcss::{
4 properties::custom::{TokenList, TokenOrValue},
5 rules::CssRule,
6 selector::{Component, PseudoClass, Selector},
7 stylesheet::{ParserOptions, PrinterOptions},
8 traits::{ParseWithOptions, ToCss as _},
9 values::ident::Ident,
10 visit_types,
11 visitor::{VisitTypes, Visitor},
12};
13use thiserror::Error;
14
15use crate::rcss_at_rule::RcssAtRuleConfig;
16
17pub(crate) struct SelectorVisitor {
18 pub append_class: String,
21 pub class_modify: Box<dyn FnMut(String) -> String>,
23 pub collect_classes: BTreeMap<String, String>,
26 pub extend: Option<syn::Path>,
28 pub declare: Option<syn::ItemStruct>,
30
31 pub state: SelectorState,
33}
34
/// Traversal state used while rewriting one selector.
#[derive(Default, Clone, Debug)]
pub struct SelectorState {
    /// The compound being processed already contains a class.
    class_found: bool,
    /// Inside `:global(...)`: classes are kept as-is and no scope class is appended.
    global_selector: bool,
    /// Inside `:deep(...)`: no scope class is appended.
    deep_selector: bool,
}

impl SelectorState {
    /// Marks the current compound as containing a class.
    fn handle_class(&mut self) {
        self.set_class_found(true);
    }

    /// A combinator starts a new compound, which has not seen a class yet.
    fn handle_combinator(&mut self) {
        self.set_class_found(false);
    }

    fn set_class_found(&mut self, found: bool) {
        self.class_found = found;
    }
}
49
50type GenericParseError = lightningcss::error::Error<lightningcss::error::ParserError<'static>>;
51#[derive(Error, Debug)]
52pub enum Error {
53 #[error("Failed to print token as css")]
54 PrintFailed(#[from] lightningcss::error::PrinterError),
55 #[error("Failed to parse token as css selector: {0}")]
56 ParseSelectorError(String),
57 #[error("Failed to parse tokens in css: {0}")]
58 GenericParser(#[from] GenericParseError),
59 #[error("Not allowed token in selector list: {0}")]
60 NotAllowedToken(String),
61}
62
63impl SelectorVisitor {
64 fn token_list_to_selector<'i>(token_list: TokenList<'i>) -> Result<Selector<'i>, Error> {
65 let mut result = String::new();
66 for token in token_list.0 {
67 match token {
68 TokenOrValue::Angle(ref angle) => {
69 result.push_str(&angle.to_css_string(PrinterOptions::default())?)
70 }
71 TokenOrValue::Token(ref token) => {
72 result.push_str(&token.to_css_string(PrinterOptions::default())?)
73 }
74 TokenOrValue::Color(ref color) => {
75 result.push_str(&color.to_css_string(PrinterOptions::default())?)
76 }
77 TokenOrValue::DashedIdent(ref ident) => {
78 result.push_str(&ident.to_css_string(PrinterOptions::default())?)
79 }
80 TokenOrValue::Length(ref length) => {
81 result.push_str(&length.to_css_string(PrinterOptions::default())?)
82 }
83 TokenOrValue::Resolution(ref resolution) => {
84 result.push_str(&resolution.to_css_string(PrinterOptions::default())?)
85 }
86 TokenOrValue::Time(ref time) => {
87 result.push_str(&time.to_css_string(PrinterOptions::default())?)
88 }
89 TokenOrValue::Url(ref url) => {
90 result.push_str(&url.to_css_string(PrinterOptions::default())?)
91 }
92 _ => return Err(Error::NotAllowedToken(format!("{:?}", token))),
93 }
94 }
95 let selector = Selector::parse_string_with_options(&result, ParserOptions::default())
96 .map_err(|e| Error::ParseSelectorError(format!("{}", e)))?;
97 use lightningcss::traits::IntoOwned;
98
99 Ok(selector.into_owned())
100 }
101 fn try_modify_parts(&mut self, selectors: &mut Selector<'_>) -> Result<(), Error> {
102 let class_name = self.append_class.clone();
103
104 let mut combinators = selectors
107 .iter_raw_match_order()
108 .rev()
109 .filter_map(|x| x.as_combinator());
110 let chunks = selectors
111 .iter_raw_match_order()
112 .as_slice()
113 .split(|x| x.is_combinator())
114 .rev();
115
116 let mut processed_selector = vec![];
118
119 for chunk in chunks {
120 if chunk.is_empty() {
121 continue;
122 }
123 for part in chunk.into_iter().cloned() {
124 let part = match part {
126 Component::Class(mut class) => {
127 self.state.handle_class();
128 if !self.state.global_selector {
130 self.modify_classes(&mut class)?;
131 }
132 Component::Class(class)
134 }
135 Component::NonTSPseudoClass(pseudo_class) => match pseudo_class {
136 PseudoClass::Global { mut selector } => {
138 self.match_global(&mut processed_selector, &mut selector)?;
139 continue;
140 }
141 PseudoClass::CustomFunction { name, arguments } => {
142 if &*name == "deep" {
143 let mut selector =
144 SelectorVisitor::token_list_to_selector(arguments.clone())?;
145
146 self.match_deep(&mut processed_selector, &mut selector)?;
147
148 continue;
149 }
150 if &*name == "global" {
151 let mut selector =
152 SelectorVisitor::token_list_to_selector(arguments.clone())?;
153
154 self.match_global(&mut processed_selector, &mut selector)?;
155
156 continue;
157 }
158 Component::NonTSPseudoClass(PseudoClass::CustomFunction {
159 name,
160 arguments,
161 })
162 }
163 pseudo_class => Component::NonTSPseudoClass(pseudo_class),
164 },
165 rest => rest,
166 };
167 processed_selector.push(part)
168 }
169 if !self.state.class_found {
170 Self::append_class(&self.state, &mut processed_selector, &class_name)?;
171 }
172 if let Some(combinator) = combinators.next() {
173 processed_selector.push(Component::Combinator(combinator));
174 }
175 self.state.handle_combinator();
176 }
177 *selectors = Selector::from(processed_selector);
179 Ok(())
180 }
181 fn append_class(
182 state: &SelectorState,
183 selector_components: &mut Vec<Component>,
184 class_name: &String,
185 ) -> Result<(), Error> {
186 if !state.deep_selector && !state.global_selector {
188 selector_components.push(Component::Class(class_name.clone().into()));
189 }
190 Ok(())
191 }
192 fn match_global<'i>(
193 &mut self,
194 selector_components: &mut Vec<Component<'i>>,
195 selector: &mut Selector<'i>,
196 ) -> Result<(), Error> {
197 let mut child_state = self.state.clone();
198 child_state.global_selector = true;
199 std::mem::swap(&mut self.state, &mut child_state);
200 self.visit_selector(selector)?;
201 std::mem::swap(&mut self.state, &mut child_state);
202
203 selector_components.extend(selector.iter_raw_parse_order_from(0).cloned());
204 self.state.class_found = true;
205 Ok(())
206 }
207 fn match_deep<'i>(
208 &mut self,
209 selector_components: &mut Vec<Component<'i>>,
210 selector: &mut Selector<'i>,
211 ) -> Result<(), Error> {
212 let mut child_state = self.state.clone();
213 child_state.deep_selector = true;
214 std::mem::swap(&mut self.state, &mut child_state);
215 self.visit_selector(selector)?;
216 std::mem::swap(&mut self.state, &mut child_state);
217
218 selector_components.extend(selector.iter_raw_parse_order_from(0).cloned());
219 self.state.class_found = true;
220 Ok(())
221 }
222 fn modify_classes(&mut self, class: &mut Ident<'_>) -> Result<(), Error> {
223 let class_string = class.to_css_string(PrinterOptions::default())?;
224 let modified = (*self.class_modify)(class_string.clone());
225 self.collect_classes.insert(class_string, modified.clone());
226 *class = modified.into();
227 Ok(())
228 }
229 fn save_rcss_rule(&mut self, rcss_rule: RcssAtRuleConfig) {
230 match rcss_rule {
232 RcssAtRuleConfig::Struct(item_struct) => self.declare = Some(item_struct),
233 RcssAtRuleConfig::Extend(path) => self.extend = Some(path),
234 }
235 }
236}
237impl<'i> lightningcss::visitor::Visitor<'i, crate::rcss_at_rule::RcssAtRuleConfig>
238 for SelectorVisitor
239{
240 type Error = Error;
241 fn visit_types(&self) -> VisitTypes {
242 visit_types!(SELECTORS | RULES)
243 }
244
245 fn visit_selector(&mut self, fragment: &mut Selector<'i>) -> Result<(), Self::Error> {
246 self.state.class_found = false;
248 self.try_modify_parts(fragment)?;
249
250 Ok(())
251 }
252 fn visit_rule(
253 &mut self,
254 rule: &mut CssRule<'i, crate::rcss_at_rule::RcssAtRuleConfig>,
255 ) -> Result<(), Self::Error> {
256 match rule {
257 CssRule::Custom(rcss) => {
258 self.save_rcss_rule(rcss.clone());
259 *rule = CssRule::Ignored;
260 }
261 rule => {
262 use lightningcss::visitor::Visit;
263 rule.visit_children(self)?;
264 }
265 }
266 Ok(())
267 }
268}