1use std::collections::HashMap;
16use std::collections::HashSet;
17
18use react_compiler_diagnostics::CompilerDiagnostic;
19use react_compiler_diagnostics::CompilerDiagnosticDetail;
20use react_compiler_diagnostics::ErrorCategory;
21use react_compiler_hir::ArrayElement;
22use react_compiler_hir::DependencyPathEntry;
23use react_compiler_hir::Effect;
24use react_compiler_hir::EvaluationOrder;
25use react_compiler_hir::HirFunction;
26use react_compiler_hir::IdentifierId;
27use react_compiler_hir::IdentifierName;
28use react_compiler_hir::Instruction;
29use react_compiler_hir::InstructionId;
30use react_compiler_hir::InstructionValue;
31use react_compiler_hir::ManualMemoDependency;
32use react_compiler_hir::ManualMemoDependencyRoot;
33use react_compiler_hir::NonLocalBinding;
34use react_compiler_hir::Place;
35use react_compiler_hir::PlaceOrSpread;
36use react_compiler_hir::PropertyLiteral;
37use react_compiler_hir::SourceLocation;
38use react_compiler_hir::environment::Environment;
39use react_compiler_lowering::create_temporary_place;
40use react_compiler_lowering::mark_instruction_ids;
41
/// Which manual memoization hook a call site uses.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ManualMemoKind {
    /// `useMemo(fn, deps)` — the memoized value is the result of calling `fn`.
    UseMemo,
    /// `useCallback(fn, deps)` — the memoized value is `fn` itself.
    UseCallback,
}
51
/// A temporary known to hold `useMemo` or `useCallback`, plus where it was loaded.
#[derive(Debug, Clone)]
struct ManualMemoCallee {
    /// Which hook the temporary refers to.
    kind: ManualMemoKind,
    /// The `LoadGlobal` / `React.useX` `PropertyLoad` instruction that produced
    /// the temporary; the `StartMemoize` marker is queued right after it.
    load_instr_id: InstructionId,
}
58
/// Facts gathered about temporaries while walking the function's instructions.
struct IdentifierSidemap {
    /// Lvalues of `FunctionExpression` instructions.
    functions: HashSet<IdentifierId>,
    /// Temporaries that hold `useMemo` / `useCallback`, keyed by the temporary.
    manual_memos: HashMap<IdentifierId, ManualMemoCallee>,
    /// Temporaries that hold the `React` binding (used to detect `React.useMemo`).
    react: HashSet<IdentifierId>,
    /// Array-literal temporaries whose elements are all plain places, i.e.
    /// candidate dependency lists.
    maybe_deps_lists: HashMap<IdentifierId, MaybeDepsListInfo>,
    /// Identifiers resolved to a dependency root plus property path.
    maybe_deps: HashMap<IdentifierId, ManualMemoDependency>,
    /// Identifiers produced by an optional (`?.`) step, from `find_optional_places`.
    optionals: HashSet<IdentifierId>,
}
73
/// An array literal that may serve as a hook's dependency list.
#[derive(Debug, Clone)]
struct MaybeDepsListInfo {
    /// Location of the array expression.
    loc: Option<SourceLocation>,
    /// The array's elements, in source order.
    deps: Vec<Place>,
}
79
/// Arguments pulled out of a `useMemo` / `useCallback` call.
struct ExtractedMemoArgs {
    /// The first argument (the memoized function).
    fn_place: Place,
    /// Resolved dependencies; `None` when no second argument was passed.
    deps_list: Option<Vec<ManualMemoDependency>>,
    /// Location of the dependency array literal, when there is one.
    deps_loc: Option<SourceLocation>,
}
85
86pub fn drop_manual_memoization(
93 func: &mut HirFunction,
94 env: &mut Environment,
95) -> Result<(), CompilerDiagnostic> {
96 let is_validation_enabled = env.validate_preserve_existing_memoization_guarantees
97 || env.validate_no_set_state_in_render
98 || env.enable_preserve_existing_memoization_guarantees;
99
100 let optionals = find_optional_places(func)?;
101 let mut sidemap = IdentifierSidemap {
102 functions: HashSet::new(),
103 manual_memos: HashMap::new(),
104 react: HashSet::new(),
105 maybe_deps: HashMap::new(),
106 maybe_deps_lists: HashMap::new(),
107 optionals,
108 };
109 let mut next_manual_memo_id: u32 = 0;
110
111 let mut queued_inserts: HashMap<InstructionId, Instruction> = HashMap::new();
117
118 let all_block_instructions: Vec<Vec<InstructionId>> = func
121 .body
122 .blocks
123 .values()
124 .map(|block| block.instructions.clone())
125 .collect();
126
127 for block_instructions in &all_block_instructions {
128 for &instr_id in block_instructions {
129 let instr = &func.instructions[instr_id.0 as usize];
130
131 let lookup_id = match &instr.value {
133 InstructionValue::CallExpression { callee, .. } => Some(callee.identifier),
134 InstructionValue::MethodCall { property, .. } => Some(property.identifier),
135 _ => None,
136 };
137
138 let manual_memo = lookup_id.and_then(|id| sidemap.manual_memos.get(&id).cloned());
139
140 if let Some(manual_memo) = manual_memo {
141 process_manual_memo_call(
142 func,
143 env,
144 instr_id,
145 &manual_memo,
146 &mut sidemap,
147 is_validation_enabled,
148 &mut next_manual_memo_id,
149 &mut queued_inserts,
150 );
151 } else {
152 collect_temporaries(func, env, instr_id, &mut sidemap);
153 }
154 }
155 }
156
157 if !queued_inserts.is_empty() {
159 let mut has_changes = false;
160 for block in func.body.blocks.values_mut() {
161 let mut next_instructions: Option<Vec<InstructionId>> = None;
162 for i in 0..block.instructions.len() {
163 let instr_id = block.instructions[i];
164 if let Some(insert_instr) = queued_inserts.remove(&instr_id) {
165 if next_instructions.is_none() {
166 next_instructions = Some(block.instructions[..i].to_vec());
167 }
168 let ni = next_instructions.as_mut().unwrap();
169 ni.push(instr_id);
170 let new_instr_id = InstructionId(func.instructions.len() as u32);
172 func.instructions.push(insert_instr);
173 ni.push(new_instr_id);
174 } else if let Some(ni) = next_instructions.as_mut() {
175 ni.push(instr_id);
176 }
177 }
178 if let Some(ni) = next_instructions {
179 block.instructions = ni;
180 has_changes = true;
181 }
182 }
183
184 if has_changes {
185 mark_instruction_ids(&mut func.body, &mut func.instructions);
186 }
187 }
188
189 Ok(())
190}
191
192#[allow(clippy::too_many_arguments)]
197fn process_manual_memo_call(
198 func: &mut HirFunction,
199 env: &mut Environment,
200 instr_id: InstructionId,
201 manual_memo: &ManualMemoCallee,
202 sidemap: &mut IdentifierSidemap,
203 is_validation_enabled: bool,
204 next_manual_memo_id: &mut u32,
205 queued_inserts: &mut HashMap<InstructionId, Instruction>,
206) {
207 let instr = &func.instructions[instr_id.0 as usize];
208
209 let memo_details = extract_manual_memoization_args(instr, manual_memo.kind, sidemap, env);
210
211 let Some(memo_details) = memo_details else {
212 return;
213 };
214
215 let ExtractedMemoArgs {
216 fn_place,
217 deps_list,
218 deps_loc,
219 } = memo_details;
220
221 let loc = func.instructions[instr_id.0 as usize].value.loc().cloned();
222
223 let replacement = get_manual_memoization_replacement(&fn_place, loc.clone(), manual_memo.kind);
225 func.instructions[instr_id.0 as usize].value = replacement;
226
227 if is_validation_enabled {
228 if !sidemap.functions.contains(&fn_place.identifier) {
230 let mut diag = CompilerDiagnostic::new(
231 ErrorCategory::UseMemo,
232 "Expected the first argument to be an inline function expression",
233 Some("Expected the first argument to be an inline function expression".to_string()),
234 )
235 .with_detail(CompilerDiagnosticDetail::Error {
236 loc: fn_place.loc.clone(),
237 message: Some(
238 "Expected the first argument to be an inline function expression".to_string(),
239 ),
240 identifier_name: None,
241 });
242 diag.suggestions = Some(vec![]);
244 env.record_diagnostic(diag);
245 return;
246 }
247
248 let memo_decl: Place = if manual_memo.kind == ManualMemoKind::UseMemo {
249 func.instructions[instr_id.0 as usize].lvalue.clone()
250 } else {
251 Place {
252 identifier: fn_place.identifier,
253 effect: Effect::Unknown,
254 reactive: false,
255 loc: fn_place.loc.clone(),
256 }
257 };
258
259 let manual_memo_id = *next_manual_memo_id;
260 *next_manual_memo_id += 1;
261
262 let (start_marker, finish_marker) = make_manual_memoization_markers(
263 &fn_place,
264 env,
265 deps_list,
266 deps_loc,
267 &memo_decl,
268 manual_memo_id,
269 );
270
271 queued_inserts.insert(manual_memo.load_instr_id, start_marker);
272 queued_inserts.insert(instr_id, finish_marker);
273 }
274}
275
276fn collect_temporaries(
277 func: &HirFunction,
278 env: &Environment,
279 instr_id: InstructionId,
280 sidemap: &mut IdentifierSidemap,
281) {
282 let instr = &func.instructions[instr_id.0 as usize];
283 let lvalue_id = instr.lvalue.identifier;
284
285 match &instr.value {
286 InstructionValue::FunctionExpression { .. } => {
287 sidemap.functions.insert(lvalue_id);
288 }
289 InstructionValue::LoadGlobal { binding, .. } => {
290 let hook_name = get_hook_detection_name(binding);
291 let mut detected = false;
292 if let Some(name) = hook_name {
293 if name == "useMemo" {
294 sidemap.manual_memos.insert(
295 lvalue_id,
296 ManualMemoCallee {
297 kind: ManualMemoKind::UseMemo,
298 load_instr_id: instr_id,
299 },
300 );
301 detected = true;
302 } else if name == "useCallback" {
303 sidemap.manual_memos.insert(
304 lvalue_id,
305 ManualMemoCallee {
306 kind: ManualMemoKind::UseCallback,
307 load_instr_id: instr_id,
308 },
309 );
310 detected = true;
311 }
312 }
313 if !detected && binding.name() == "React" {
314 sidemap.react.insert(lvalue_id);
315 }
316 }
317 InstructionValue::PropertyLoad {
318 object, property, ..
319 } => {
320 if sidemap.react.contains(&object.identifier) {
321 if let PropertyLiteral::String(prop_name) = property {
322 if prop_name == "useMemo" {
323 sidemap.manual_memos.insert(
324 lvalue_id,
325 ManualMemoCallee {
326 kind: ManualMemoKind::UseMemo,
327 load_instr_id: instr_id,
328 },
329 );
330 } else if prop_name == "useCallback" {
331 sidemap.manual_memos.insert(
332 lvalue_id,
333 ManualMemoCallee {
334 kind: ManualMemoKind::UseCallback,
335 load_instr_id: instr_id,
336 },
337 );
338 }
339 }
340 }
341 }
342 InstructionValue::ArrayExpression { elements, .. } => {
343 let all_places: Option<Vec<Place>> = elements
345 .iter()
346 .map(|e| match e {
347 ArrayElement::Place(p) => Some(p.clone()),
348 _ => None,
349 })
350 .collect();
351
352 if let Some(deps) = all_places {
353 sidemap.maybe_deps_lists.insert(
354 lvalue_id,
355 MaybeDepsListInfo {
356 loc: instr.value.loc().cloned(),
357 deps,
358 },
359 );
360 }
361 }
362 _ => {}
363 }
364
365 let is_optional = sidemap.optionals.contains(&lvalue_id);
366 let maybe_dep =
367 collect_maybe_memo_dependencies(&instr.value, &sidemap.maybe_deps, is_optional, env);
368 if let Some(dep) = maybe_dep {
369 if let InstructionValue::StoreLocal { lvalue, .. } = &instr.value {
373 sidemap
374 .maybe_deps
375 .insert(lvalue.place.identifier, dep.clone());
376 }
377 sidemap.maybe_deps.insert(lvalue_id, dep);
378 }
379}
380
381pub fn collect_maybe_memo_dependencies(
388 value: &InstructionValue,
389 maybe_deps: &HashMap<IdentifierId, ManualMemoDependency>,
390 optional: bool,
391 env: &Environment,
392) -> Option<ManualMemoDependency> {
393 match value {
394 InstructionValue::LoadGlobal { binding, loc, .. } => Some(ManualMemoDependency {
395 root: ManualMemoDependencyRoot::Global {
396 identifier_name: binding.name().to_string(),
397 },
398 path: vec![],
399 loc: loc.clone(),
400 }),
401 InstructionValue::PropertyLoad {
402 object,
403 property,
404 loc,
405 ..
406 } => {
407 if let Some(object_dep) = maybe_deps.get(&object.identifier) {
408 Some(ManualMemoDependency {
409 root: object_dep.root.clone(),
410 path: {
411 let mut path = object_dep.path.clone();
412 path.push(DependencyPathEntry {
413 property: property.clone(),
414 optional,
415 loc: loc.clone(),
416 });
417 path
418 },
419 loc: loc.clone(),
420 })
421 } else {
422 None
423 }
424 }
425 InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => {
426 if let Some(source) = maybe_deps.get(&place.identifier) {
427 Some(source.clone())
428 } else if matches!(
429 &env.identifiers[place.identifier.0 as usize].name,
430 Some(IdentifierName::Named(_))
431 ) {
432 Some(ManualMemoDependency {
433 root: ManualMemoDependencyRoot::NamedLocal {
434 value: place.clone(),
435 constant: false,
436 },
437 path: vec![],
438 loc: place.loc.clone(),
439 })
440 } else {
441 None
442 }
443 }
444 InstructionValue::StoreLocal {
445 lvalue, value: val, ..
446 } => {
447 let lvalue_id = lvalue.place.identifier;
451 let rvalue_id = val.identifier;
452 if let Some(aliased) = maybe_deps.get(&rvalue_id) {
453 let lvalue_name = &env.identifiers[lvalue_id.0 as usize].name;
454 if !matches!(lvalue_name, Some(IdentifierName::Named(_))) {
455 return Some(aliased.clone());
458 }
459 }
460 None
461 }
462 _ => None,
463 }
464}
465
466fn get_manual_memoization_replacement(
471 fn_place: &Place,
472 loc: Option<SourceLocation>,
473 kind: ManualMemoKind,
474) -> InstructionValue {
475 if kind == ManualMemoKind::UseMemo {
476 InstructionValue::CallExpression {
478 callee: fn_place.clone(),
479 args: vec![],
480 loc,
481 }
482 } else {
483 InstructionValue::LoadLocal {
485 place: Place {
486 identifier: fn_place.identifier,
487 effect: Effect::Unknown,
488 reactive: false,
489 loc: loc.clone(),
490 },
491 loc,
492 }
493 }
494}
495
496fn make_manual_memoization_markers(
497 fn_expr: &Place,
498 env: &mut Environment,
499 deps_list: Option<Vec<ManualMemoDependency>>,
500 deps_loc: Option<SourceLocation>,
501 memo_decl: &Place,
502 manual_memo_id: u32,
503) -> (Instruction, Instruction) {
504 let start = Instruction {
505 id: EvaluationOrder(0),
506 lvalue: create_temporary_place(env, fn_expr.loc.clone()),
507 value: InstructionValue::StartMemoize {
508 manual_memo_id,
509 deps: deps_list,
510 deps_loc: Some(deps_loc),
511 has_invalid_deps: false,
512 loc: fn_expr.loc.clone(),
513 },
514 loc: fn_expr.loc.clone(),
515 effects: None,
516 };
517 let finish = Instruction {
518 id: EvaluationOrder(0),
519 lvalue: create_temporary_place(env, fn_expr.loc.clone()),
520 value: InstructionValue::FinishMemoize {
521 manual_memo_id,
522 decl: memo_decl.clone(),
523 pruned: false,
524 loc: fn_expr.loc.clone(),
525 },
526 loc: fn_expr.loc.clone(),
527 effects: None,
528 };
529 (start, finish)
530}
531
532fn extract_manual_memoization_args(
533 instr: &Instruction,
534 kind: ManualMemoKind,
535 sidemap: &IdentifierSidemap,
536 env: &mut Environment,
537) -> Option<ExtractedMemoArgs> {
538 let args: &[PlaceOrSpread] = match &instr.value {
539 InstructionValue::CallExpression { args, .. } => args,
540 InstructionValue::MethodCall { args, .. } => args,
541 _ => return None,
542 };
543
544 let kind_name = match kind {
545 ManualMemoKind::UseMemo => "useMemo",
546 ManualMemoKind::UseCallback => "useCallback",
547 };
548
549 let fn_place = match args.first() {
551 Some(PlaceOrSpread::Place(p)) => p.clone(),
552 _ => {
553 let loc = instr.value.loc().cloned();
554 env.record_diagnostic(
555 CompilerDiagnostic::new(
556 ErrorCategory::UseMemo,
557 format!("Expected a callback function to be passed to {kind_name}"),
558 Some(if kind == ManualMemoKind::UseCallback {
559 "The first argument to useCallback() must be a function to cache".to_string()
560 } else {
561 "The first argument to useMemo() must be a function that calculates a result to cache".to_string()
562 }),
563 )
564 .with_detail(CompilerDiagnosticDetail::Error {
565 loc,
566 message: Some(if kind == ManualMemoKind::UseCallback {
567 "Expected a callback function".to_string()
568 } else {
569 "Expected a memoization function".to_string()
570 }),
571 identifier_name: None,
572 }),
573 );
574 return None;
575 }
576 };
577
578 let deps_list_place = args.get(1);
580 if deps_list_place.is_none() {
581 return Some(ExtractedMemoArgs {
582 fn_place,
583 deps_list: None,
584 deps_loc: None,
585 });
586 }
587
588 let deps_list_id = match deps_list_place {
589 Some(PlaceOrSpread::Place(p)) => Some(p.identifier),
590 _ => None,
591 };
592
593 let maybe_deps_list = deps_list_id.and_then(|id| sidemap.maybe_deps_lists.get(&id));
594
595 if maybe_deps_list.is_none() {
596 let loc = match deps_list_place {
597 Some(PlaceOrSpread::Place(p)) => p.loc.clone(),
598 _ => instr.loc.clone(),
599 };
600 env.record_diagnostic(
601 CompilerDiagnostic::new(
602 ErrorCategory::UseMemo,
603 format!("Expected the dependency list for {kind_name} to be an array literal"),
604 Some(format!(
605 "Expected the dependency list for {kind_name} to be an array literal"
606 )),
607 )
608 .with_detail(CompilerDiagnosticDetail::Error {
609 loc,
610 message: Some(format!(
611 "Expected the dependency list for {kind_name} to be an array literal"
612 )),
613 identifier_name: None,
614 }),
615 );
616 return None;
617 }
618
619 let deps_info = maybe_deps_list.unwrap();
620 let mut deps_list: Vec<ManualMemoDependency> = Vec::new();
621 for dep in &deps_info.deps {
622 let maybe_dep = sidemap.maybe_deps.get(&dep.identifier);
623 if let Some(d) = maybe_dep {
624 deps_list.push(d.clone());
625 } else {
626 env.record_diagnostic(
627 CompilerDiagnostic::new(
628 ErrorCategory::UseMemo,
629 "Expected the dependency list to be an array of simple expressions (e.g. `x`, `x.y.z`, `x?.y?.z`)",
630 Some("Expected the dependency list to be an array of simple expressions (e.g. `x`, `x.y.z`, `x?.y?.z`)".to_string()),
631 )
632 .with_detail(CompilerDiagnosticDetail::Error {
633 loc: dep.loc.clone(),
634 message: Some("Expected the dependency list to be an array of simple expressions (e.g. `x`, `x.y.z`, `x?.y?.z`)".to_string()),
635 identifier_name: None,
636 }),
637 );
638 }
639 }
640
641 Some(ExtractedMemoArgs {
642 fn_place,
643 deps_list: Some(deps_list),
644 deps_loc: deps_info.loc.clone(),
645 })
646}
647
648fn find_optional_places(func: &HirFunction) -> Result<HashSet<IdentifierId>, CompilerDiagnostic> {
653 use react_compiler_hir::Terminal;
654
655 let mut optionals = HashSet::new();
656 for block in func.body.blocks.values() {
657 if let Terminal::Optional {
658 optional: true,
659 test,
660 fallthrough,
661 ..
662 } = &block.terminal
663 {
664 let optional_fallthrough = *fallthrough;
665 let mut test_block_id = *test;
666 loop {
667 let test_block = &func.body.blocks[&test_block_id];
668 match &test_block.terminal {
669 Terminal::Branch {
670 consequent,
671 fallthrough,
672 ..
673 } => {
674 if *fallthrough == optional_fallthrough {
675 let consequent_block = &func.body.blocks[consequent];
677 if let Some(&last_instr_id) = consequent_block.instructions.last() {
678 let last_instr = &func.instructions[last_instr_id.0 as usize];
679 if let InstructionValue::StoreLocal { value, .. } =
680 &last_instr.value
681 {
682 optionals.insert(value.identifier);
683 }
684 }
685 break;
686 } else {
687 test_block_id = *fallthrough;
688 }
689 }
690 Terminal::Optional { fallthrough, .. }
691 | Terminal::Logical { fallthrough, .. }
692 | Terminal::Sequence { fallthrough, .. }
693 | Terminal::Ternary { fallthrough, .. } => {
694 test_block_id = *fallthrough;
695 }
696 Terminal::MaybeThrow { continuation, .. } => {
697 test_block_id = *continuation;
698 }
699 other => {
700 return Err(CompilerDiagnostic::new(
703 ErrorCategory::Invariant,
704 format!(
705 "Unexpected terminal kind in optional: {:?}",
706 std::mem::discriminant(other)
707 ),
708 None,
709 ));
710 }
711 }
712 }
713 }
714 }
715 Ok(optionals)
716}
717
/// Returns `true` when `module` is `react` or `react-dom`, ignoring case
/// (compared after Unicode lowercasing).
fn is_known_react_module(module: &str) -> bool {
    matches!(module.to_lowercase().as_str(), "react" | "react-dom")
}
722
723fn get_hook_detection_name(binding: &NonLocalBinding) -> Option<&str> {
734 match binding {
735 NonLocalBinding::Global { name } => Some(name.as_str()),
736 NonLocalBinding::ImportSpecifier {
737 imported, module, ..
738 } => {
739 if is_known_react_module(module) {
740 Some(imported.as_str())
741 } else {
742 None
743 }
744 }
745 NonLocalBinding::ImportDefault { name, module }
746 | NonLocalBinding::ImportNamespace { name, module } => {
747 if is_known_react_module(module) {
748 Some(name.as_str())
749 } else {
750 None
751 }
752 }
753 NonLocalBinding::ModuleLocal { .. } => None,
754 }
755}