Skip to main content

rspack_plugin_css_chunking/
lib.rs

1use std::{
2  collections::HashSet,
3  sync::atomic::{AtomicBool, Ordering},
4};
5
6use rspack_collections::{
7  Identifier, IdentifierIndexMap, IdentifierIndexSet, IdentifierMap, IdentifierSet, UkeyMap,
8  UkeySet,
9};
10use rspack_core::{
11  ChunkUkey, Compilation, CompilationOptimizeChunks, CompilationParams, CompilerCompilation,
12  Logger, Module, ModuleIdentifier, Plugin, SourceType,
13};
14use rspack_error::Result;
15use rspack_hook::{plugin, plugin_hook};
16use rspack_plugin_css::CssPlugin;
17use rspack_regex::RspackRegex;
18
/// Default value for `min_size` when the options leave it unset: 30 * 1024,
/// in the units returned by `Module::size`.
const MIN_CSS_CHUNK_SIZE: f64 = 30.0 * 1024.0;
/// Default value for `max_size` when the options leave it unset: 100 * 1024,
/// in the units returned by `Module::size`.
const MAX_CSS_CHUNK_SIZE: f64 = 100.0 * 1024.0;
21
/// Returns `true` when a module's `name_for_condition` looks like a global
/// stylesheet, i.e. it does not end in `.module.css`, `.module.scss` or
/// `.module.sass`.
///
/// NOTE(review): a module without a `name_for_condition` is reported as
/// *not* global here; confirm this is the intended classification.
fn is_global_css(name_for_condition: &Option<Box<str>>) -> bool {
  const CSS_MODULE_SUFFIXES: [&str; 3] = [".module.css", ".module.scss", ".module.sass"];
  match name_for_condition.as_deref() {
    Some(name) => !CSS_MODULE_SUFFIXES
      .iter()
      .any(|&suffix| name.ends_with(suffix)),
    None => false,
  }
}
27
28#[derive(Debug)]
29pub struct CssChunkingPluginOptions {
30  pub strict: bool,
31  pub min_size: Option<f64>,
32  pub max_size: Option<f64>,
33  pub exclude: Option<RspackRegex>,
34}
35
/// Regroups the CSS modules of chunks into new, size-bounded CSS chunks
/// during `optimize_chunks`.
#[plugin]
#[derive(Debug)]
pub struct CssChunkingPlugin {
  // Set when `optimize_chunks` runs and cleared by the `compilation` hook, so
  // the chunking runs at most once per compilation.
  once: AtomicBool,
  // Strict (order-preserving) vs loose (guessed dependencies) merging.
  strict: bool,
  // Size threshold consulted in strict mode when a candidate adds chunks.
  min_size: f64,
  // Upper bound on the accumulated size of a new CSS chunk.
  max_size: f64,
  // Named chunks matching this regex are left untouched.
  // NOTE(review): field order presumably defines the positional arguments of
  // the `#[plugin]`-generated `new_inner`; do not reorder.
  exclude: Option<RspackRegex>,
}
45
46impl CssChunkingPlugin {
47  pub fn new(options: CssChunkingPluginOptions) -> Self {
48    Self::new_inner(
49      AtomicBool::new(false),
50      options.strict,
51      options.min_size.unwrap_or(MIN_CSS_CHUNK_SIZE),
52      options.max_size.unwrap_or(MAX_CSS_CHUNK_SIZE),
53      options.exclude,
54    )
55  }
56}
57
// Compiler `compilation` hook: clears the `once` flag for every new
// compilation, so that `optimize_chunks` (which returns early while the flag
// is set) performs the CSS chunking once for this compilation.
#[plugin_hook(CompilerCompilation for CssChunkingPlugin)]
async fn compilation(
  &self,
  _compilation: &mut Compilation,
  _params: &mut CompilationParams,
) -> Result<()> {
  self.once.store(false, Ordering::Relaxed);
  Ok(())
}
67
/// Bookkeeping for one original chunk that contains CSS modules.
#[derive(Debug)]
struct ChunkState {
  /// The original chunk the CSS modules were collected from.
  chunk: ChunkUkey,
  /// The chunk's CSS modules, in the order given by `CssPlugin::get_modules_in_order`.
  modules: Vec<ModuleIdentifier>,
  /// Starts at `modules.len()` and is decremented whenever a merged module
  /// lands in this chunk while the chunk is already covered by the new CSS
  /// chunk being built (presumably the CSS request count after merging).
  requests: usize,
}
74
75#[plugin_hook(CompilationOptimizeChunks for CssChunkingPlugin, stage = 5)]
76async fn optimize_chunks(&self, compilation: &mut Compilation) -> Result<Option<bool>> {
77  let strict = self.strict;
78
79  if self.once.load(Ordering::Relaxed) {
80    return Ok(None);
81  }
82  self.once.store(true, Ordering::Relaxed);
83
84  let logger = compilation.get_logger("rspack.CssChunkingPlugin");
85
86  let start = logger.time("collect all css modules and the execpted order of them");
87  let mut chunk_states: UkeyMap<ChunkUkey, ChunkState> = Default::default();
88  let mut chunk_states_by_module: IdentifierIndexMap<UkeyMap<ChunkUkey, usize>> =
89    Default::default();
90
91  // Collect all css modules in chunks and the execpted order of them
92  let chunk_graph = &compilation.chunk_graph;
93  let chunks = &compilation.chunk_by_ukey;
94  let module_graph = compilation.get_module_graph();
95
96  for (chunk_ukey, chunk) in chunks.iter() {
97    if let Some(name) = chunk.name()
98      && let Some(exclude) = &self.exclude
99      && exclude.test(name)
100    {
101      continue;
102    }
103
104    let modules: Vec<&dyn Module> = chunk_graph
105      .get_chunk_modules(chunk_ukey, module_graph)
106      .into_iter()
107      .filter(|module| {
108        module.source_types(module_graph).iter().any(|t| match t {
109          SourceType::Css => true,
110          SourceType::CssImport => true,
111          SourceType::Custom(str) => str == "css/mini-extract",
112          _ => false,
113        })
114      })
115      .map(|module| module.as_ref())
116      .collect();
117    if modules.is_empty() {
118      continue;
119    }
120    let chunk = compilation.chunk_by_ukey.expect_get(chunk_ukey);
121    let (ordered_modules, _) = CssPlugin::get_modules_in_order(chunk, modules, compilation);
122    let mut module_identifiers: Vec<ModuleIdentifier> = Vec::with_capacity(ordered_modules.len());
123    for (i, module) in ordered_modules.iter().enumerate() {
124      let module_identifier = module.identifier();
125      module_identifiers.push(module_identifier);
126
127      match chunk_states_by_module.entry(module_identifier) {
128        indexmap::map::Entry::Occupied(mut occupied_entry) => {
129          let module_chunk_states = occupied_entry.get_mut();
130          module_chunk_states.insert(*chunk_ukey, i);
131        }
132        indexmap::map::Entry::Vacant(vacant_entry) => {
133          let mut module_chunk_states = UkeyMap::default();
134          module_chunk_states.insert(*chunk_ukey, i);
135          vacant_entry.insert(module_chunk_states);
136        }
137      };
138    }
139    let requests = module_identifiers.len();
140    let chunk_state = ChunkState {
141      chunk: *chunk_ukey,
142      modules: module_identifiers,
143      requests,
144    };
145    chunk_states.insert(*chunk_ukey, chunk_state);
146  }
147
148  let module_infos: IdentifierMap<(f64, Option<Box<str>>)> = {
149    let module_graph = compilation.get_module_graph();
150    let mut result = IdentifierMap::default();
151    for module_identifier in chunk_states_by_module.keys() {
152      #[allow(clippy::unwrap_used)]
153      let module = module_graph
154        .module_by_identifier(module_identifier)
155        .unwrap();
156      let size = module.size(None, None);
157      result.insert(*module_identifier, (size, module.name_for_condition()));
158    }
159    result
160  };
161  logger.time_end(start);
162
163  // Sort modules by their index sum
164  let start = logger.time("sort modules by their index sum");
165  let mut ordered_modules: Vec<(ModuleIdentifier, usize)> = chunk_states_by_module
166    .iter()
167    .map(|(module_identifier, module_states)| {
168      let sum = module_states.values().sum();
169      (*module_identifier, sum)
170    })
171    .collect();
172  ordered_modules.sort_by_key(|&(_module, sum)| sum);
173  let mut remaining_modules: IdentifierIndexSet = ordered_modules
174    .into_iter()
175    .map(|(module_identifier, _)| module_identifier)
176    .collect();
177  logger.time_end(start);
178
179  // In loose mode we guess the dependents of modules from the order
180  // assuming that when a module is a dependency of another module
181  // it will always appear before it in every chunk.
182  let mut all_dependents: IdentifierMap<HashSet<ModuleIdentifier>> = IdentifierMap::default();
183  if !self.strict {
184    let start = logger.time("guess the dependents of modules from the order");
185    for b in &remaining_modules {
186      let mut dependents = HashSet::new();
187      'outer: for a in &remaining_modules {
188        if a == b {
189          continue;
190        }
191        let a_states = &chunk_states_by_module[a];
192        let b_states = &chunk_states_by_module[b];
193        // check if a depends on b
194        for (chunk_ukey, ia) in a_states {
195          match b_states.get(chunk_ukey) {
196            // If a would depend on b, it would be included in that chunk group too
197            None => continue 'outer,
198            // If a would depend on b, b would be before a in order
199            Some(&ib) if ib > *ia => continue 'outer,
200            _ => {}
201          }
202        }
203        dependents.insert(*a);
204      }
205      if !dependents.is_empty() {
206        all_dependents.insert(*b, dependents);
207      }
208    }
209    logger.time_end(start);
210  }
211
212  // Stores the new chunk for every module
213  let mut new_chunks_by_module: IdentifierMap<ChunkUkey> = IdentifierMap::default();
214
215  // Process through all modules
216  let start = logger.time("process through all modules");
217  loop {
218    let Some(start_module_identifier) = remaining_modules.iter().next().cloned() else {
219      break;
220    };
221    remaining_modules.shift_remove(&start_module_identifier);
222
223    #[allow(clippy::unwrap_used)]
224    let mut global_css_mode = is_global_css(&module_infos.get(&start_module_identifier).unwrap().1);
225
226    // The current position of processing in all selected chunks
227    #[allow(clippy::unwrap_used)]
228    let all_chunk_states = chunk_states_by_module
229      .get(&start_module_identifier)
230      .unwrap();
231
232    // The list of modules that goes into the new chunk
233    let mut new_chunk_modules = IdentifierSet::default();
234    new_chunk_modules.insert(start_module_identifier);
235
236    // The current size of the new chunk
237    #[allow(clippy::unwrap_used)]
238    let mut current_size = module_infos.get(&start_module_identifier).unwrap().0;
239
240    // A pool of potential modules where the next module is selected from.
241    // It's filled from the next module of the selected modules in every chunk.
242    // It also keeps some metadata to improve performance [size, chunkStates].
243    let mut potential_next_modules: IdentifierIndexMap<f64> = Default::default();
244    for (chunk_ukey, i) in all_chunk_states {
245      #[allow(clippy::unwrap_used)]
246      let chunk_state = chunk_states.get(chunk_ukey).unwrap();
247      if let Some(next_module_identifier) = chunk_state.modules.get(i + 1)
248        && remaining_modules.contains(next_module_identifier)
249      {
250        #[allow(clippy::unwrap_used)]
251        let next_module_size = module_infos.get(next_module_identifier).unwrap().0;
252        potential_next_modules.insert(*next_module_identifier, next_module_size);
253      }
254    }
255
256    // Try to add modules to the chunk until a break condition is met
257    let mut cont = true;
258    while cont {
259      cont = false;
260
261      // We try to select a module that reduces request count and
262      // has the highest number of requests
263      #[allow(clippy::unwrap_used)]
264      let all_chunk_states = chunk_states_by_module
265        .get(&start_module_identifier)
266        .unwrap();
267      let mut ordered_potential_next_modules: Vec<(Identifier, f64, usize)> =
268        potential_next_modules
269          .iter()
270          .map(|(next_module_identifier, size)| {
271            #[allow(clippy::unwrap_used)]
272            let next_chunk_states = chunk_states_by_module.get(next_module_identifier).unwrap();
273            let mut max_requests = 0;
274            for next_chunk_ukey in next_chunk_states.keys() {
275              // There is always some overlap
276              if all_chunk_states.contains_key(next_chunk_ukey) {
277                #[allow(clippy::unwrap_used)]
278                let chunk_state = chunk_states.get(next_chunk_ukey).unwrap();
279                max_requests = max_requests.max(chunk_state.requests);
280              }
281            }
282            (*next_module_identifier, *size, max_requests)
283          })
284          .collect();
285      ordered_potential_next_modules.sort_by(|a, b| b.2.cmp(&a.2).then_with(|| a.0.cmp(&b.0)));
286
287      // Try every potential module
288      'outer: for (next_module_identifier, size, _) in ordered_potential_next_modules {
289        if current_size + size > self.max_size {
290          // Chunk would be too large
291          continue;
292        }
293        #[allow(clippy::unwrap_used)]
294        let next_chunk_states = chunk_states_by_module
295          .get(&next_module_identifier)
296          .cloned()
297          .unwrap();
298        if !strict {
299          // In loose mode we only check if the dependencies are not violated
300          if let Some(deps) = all_dependents.get(&next_module_identifier) {
301            let new_chunk_modules_ref = &new_chunk_modules;
302            if deps.iter().any(|d| new_chunk_modules_ref.contains(d)) {
303              continue;
304            }
305          }
306        } else {
307          // In strict mode we check that none of the order in any chunk is changed by adding the module
308          for (chunk_ukey, i) in &next_chunk_states {
309            match all_chunk_states.get(chunk_ukey) {
310              None => {
311                // New chunk group, can add it, but should we?
312                // We only add that if below min size
313                if current_size < self.min_size {
314                  continue;
315                } else {
316                  continue 'outer;
317                }
318              }
319              Some(&prev_idx) if prev_idx + 1 == *i => {}
320              _ => continue 'outer,
321            }
322          }
323        }
324
325        // Global CSS must not leak into unrelated chunks
326        #[allow(clippy::unwrap_used)]
327        let is_global = is_global_css(&module_infos.get(&next_module_identifier).unwrap().1);
328        if is_global && global_css_mode && all_chunk_states.len() != next_chunk_states.len() {
329          // Fast check: chunk groups need to be identical
330          continue;
331        }
332        if global_css_mode
333          && next_chunk_states
334            .keys()
335            .any(|cs| !all_chunk_states.contains_key(cs))
336        {
337          continue;
338        }
339        if is_global
340          && all_chunk_states
341            .keys()
342            .any(|cs| !next_chunk_states.contains_key(cs))
343        {
344          continue;
345        }
346        potential_next_modules.shift_remove(&next_module_identifier);
347        current_size += size;
348        if is_global {
349          global_css_mode = true;
350        }
351        #[allow(clippy::unwrap_used)]
352        let all_chunk_states = chunk_states_by_module
353          .get_mut(&start_module_identifier)
354          .unwrap();
355        for (chunk_ukey, i) in next_chunk_states {
356          #[allow(clippy::unwrap_used)]
357          let chunk_state = chunk_states.get_mut(&chunk_ukey).unwrap();
358          if all_chunk_states.contains_key(&chunk_ukey) {
359            // This reduces the request count of the chunk group
360            chunk_state.requests -= 1;
361          }
362          all_chunk_states.insert(chunk_ukey, i);
363          if let Some(next_module_identifier) = chunk_state.modules.get(i + 1)
364            && remaining_modules.contains(next_module_identifier)
365            && !new_chunk_modules.contains(next_module_identifier)
366          {
367            #[allow(clippy::unwrap_used)]
368            let next_module_size = module_infos.get(next_module_identifier).unwrap().0;
369            potential_next_modules.insert(*next_module_identifier, next_module_size);
370          }
371        }
372        new_chunk_modules.insert(next_module_identifier);
373        cont = true;
374        break;
375      }
376    }
377    let new_chunk_ukey = Compilation::add_chunk(&mut compilation.chunk_by_ukey);
378    #[allow(clippy::unwrap_used)]
379    let new_chunk = compilation.chunk_by_ukey.get_mut(&new_chunk_ukey).unwrap();
380    new_chunk.prevent_integration();
381    new_chunk.add_id_name_hints("css".to_string());
382    let chunk_graph = &mut compilation.chunk_graph;
383    for module_identifier in &new_chunk_modules {
384      remaining_modules.shift_remove(module_identifier);
385      chunk_graph.connect_chunk_and_module(new_chunk_ukey, *module_identifier);
386      new_chunks_by_module.insert(*module_identifier, new_chunk_ukey);
387    }
388  }
389  logger.time_end(start);
390
391  let start = logger.time("apply split chunks");
392  let chunk_graph = &mut compilation.chunk_graph;
393  for chunk_state in chunk_states.values() {
394    let mut chunks: UkeySet<ChunkUkey> = UkeySet::default();
395    for module_identifier in &chunk_state.modules {
396      if let Some(new_chunk_ukey) = new_chunks_by_module.get(module_identifier) {
397        chunk_graph.disconnect_chunk_and_module(&chunk_state.chunk, *module_identifier);
398        if chunks.contains(new_chunk_ukey) {
399          continue;
400        }
401        chunks.insert(*new_chunk_ukey);
402        let chunk_by_ukey = &mut compilation.chunk_by_ukey;
403        let [chunk, new_chunk] = chunk_by_ukey.get_many_mut([&chunk_state.chunk, new_chunk_ukey]);
404        #[allow(clippy::unwrap_used)]
405        chunk
406          .unwrap()
407          .split(new_chunk.unwrap(), &mut compilation.chunk_group_by_ukey);
408      }
409    }
410  }
411  logger.time_end(start);
412
413  Ok(None)
414}
415
416impl Plugin for CssChunkingPlugin {
417  fn name(&self) -> &'static str {
418    "rspack.CssChunkingPlugin"
419  }
420
421  fn apply(&self, ctx: &mut rspack_core::ApplyContext<'_>) -> Result<()> {
422    ctx.compiler_hooks.compilation.tap(compilation::new(self));
423
424    ctx
425      .compilation_hooks
426      .optimize_chunks
427      .tap(optimize_chunks::new(self));
428
429    Ok(())
430  }
431}