use regex::Regex;
use std::sync::OnceLock;
pub fn rewrite_frb_sealed_variants(source: &str) -> String {
let variant_re = variant_regex();
variant_re
.replace_all(source, |caps: ®ex::Captures<'_>| {
let prefix = &caps["prefix"];
let params = &caps["params"];
let suffix = &caps["suffix"];
let variant_pascal = &caps["variant"];
let rewritten_params = rewrite_param_list(params, variant_pascal);
format!("{prefix}{rewritten_params}{suffix}")
})
.into_owned()
}
/// Returns the lazily compiled regex that recognises one redirecting freezed
/// factory for a sealed-class variant, e.g.
///
/// ```text
/// const factory FormatMetadata.pdf({required PdfMetadata field0}) =
///     FormatMetadata_Pdf;
/// ```
///
/// Named groups:
/// * `prefix`  — `const factory Class.ctor(` through the opening `{`;
/// * `params`  — the brace-free text of the named-parameter list;
/// * `suffix`  — the closing `}` through the redirect target and `;`;
/// * `variant` — the alphanumeric segment of the redirect target after its
///   last `_` (`Pdf` above).
fn variant_regex() -> &'static Regex {
    const PATTERN: &str = r"(?s)(?P<prefix>const\s+factory\s+[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*\s*\(\s*\{)(?P<params>[^{}]*)(?P<suffix>\}\s*\)\s*=\s*[A-Za-z_][A-Za-z0-9_]*_(?P<variant>[A-Za-z][A-Za-z0-9]*)\s*;)";
    static COMPILED: OnceLock<Regex> = OnceLock::new();
    COMPILED.get_or_init(|| Regex::new(PATTERN).expect("variant regex must compile"))
}
/// Renames the positional `fieldN` parameters inside one factory's `{ ... }`
/// parameter list, copying every other byte of `params` verbatim.
///
/// * `params` — the text between the braces of a single `const factory`
///   declaration.
/// * `variant_pascal` — the variant segment of the redirect target (`Pdf` for
///   `FormatMetadata_Pdf`), used to derive a descriptive name.
///
/// Parameters whose names are not `fieldN` (e.g. named-struct fields) are kept
/// as-is; when there are no positional fields at all, `params` is returned
/// unchanged.
fn rewrite_param_list(params: &str, variant_pascal: &str) -> String {
    let matches: Vec<regex::Captures<'_>> = param_regex().captures_iter(params).collect();

    // The naming scheme depends on how many positional fields the variant has
    // (`value` vs `value0`, `value1`, …), so count them before rewriting.
    let total_fields = matches
        .iter()
        .filter(|caps| {
            caps.name("name")
                .is_some_and(|name| is_positional_field(name.as_str()))
        })
        .count();
    if total_fields == 0 {
        return params.to_string();
    }

    let mut out = String::with_capacity(params.len());
    let mut cursor = 0usize;
    for caps in &matches {
        // `name` is a non-optional group of the param regex, so every match has it.
        let name_match = caps.name("name").expect("name capture is required");
        // Copy everything between the previous parameter name and this one verbatim.
        out.push_str(&params[cursor..name_match.start()]);
        match field_index(name_match.as_str()) {
            Some(field_idx) => {
                let type_name = caps.name("type").map_or("", |m| m.as_str()).trim();
                let new_name =
                    payload_param_name(type_name, variant_pascal, field_idx, total_fields);
                out.push_str(&new_name);
            }
            None => out.push_str(name_match.as_str()),
        }
        cursor = name_match.end();
    }
    out.push_str(&params[cursor..]);
    out
}
/// Returns the lazily compiled regex matching one `required <Type> <name>`
/// entry of a named-parameter list, including the `,` (or end of input) that
/// terminates it.
///
/// Groups: `type` is the declared Dart type as written, and `name` is the
/// parameter identifier.
fn param_regex() -> &'static Regex {
    static RE: OnceLock<Regex> = OnceLock::new();
    RE.get_or_init(|| {
        // `type` excludes only braces, not commas. Multi-argument types such as
        // `Map<String, int>` or `(int, String)` must be captured whole instead of
        // making the whole parameter unmatchable (which left `field0` un-renamed
        // and undercounted positional fields). The quantifier is lazy, so the
        // shortest `Type name,` match still wins. A comma inside `<...>`/`(...)`
        // is followed by `ident>` / `ident)`, not by `ident,` or the end of the list.
        Regex::new(r"required\s+(?P<type>[^{}]+?)\s+(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*(?:,|$)")
            .expect("param regex must compile")
    })
}
/// Extracts the numeric index from a positional field name such as `field0`
/// or `field12`.
///
/// Returns `None` when `name` does not start with `field` or when the rest of
/// the name does not parse as a `usize`.
fn field_index(name: &str) -> Option<usize> {
    name.strip_prefix("field")
        .and_then(|digits| digits.parse().ok())
}
/// Reports whether `name` looks like a flutter_rust_bridge positional field:
/// `field` followed by text that `field_index` accepts as an index (e.g. `field0`).
fn is_positional_field(name: &str) -> bool {
    matches!(field_index(name), Some(_))
}
fn payload_param_name(type_name: &str, variant_pascal: &str, field_idx: usize, total_fields: usize) -> String {
if total_fields > 1 {
return format!("value{field_idx}");
}
let stripped_type = type_name.trim_end_matches('?');
let base_type = stripped_type
.split_once('<')
.map(|(head, _)| head)
.unwrap_or(stripped_type)
.trim();
if let Some(remainder) = base_type.strip_prefix(variant_pascal)
&& !remainder.is_empty()
{
return to_lower_camel(remainder);
}
if is_dart_primitive(base_type) {
return "value".to_string();
}
"value".to_string()
}
/// Lower-cases the first character of `s` and keeps the rest unchanged, turning
/// a PascalCase word such as `Metadata` into `metadata`.
///
/// Uses Unicode lower-casing, so the first character may expand to more than
/// one character. An empty input yields an empty string.
fn to_lower_camel(s: &str) -> String {
    let Some(head) = s.chars().next() else {
        return String::new();
    };
    let tail = &s[head.len_utf8()..];
    head.to_lowercase().chain(tail.chars()).collect()
}
/// Returns `true` if `type_name` exactly equals one of the Dart core type names
/// below.
///
/// The comparison is exact and case-sensitive, so callers should pass the bare
/// base type: `int`, not `int?` or `List<int>`.
fn is_dart_primitive(type_name: &str) -> bool {
    const DART_BUILTINS: &[&str] = &[
        "String",
        "int",
        "double",
        "bool",
        "num",
        "void",
        "dynamic",
        "Object",
        "Uint8List",
        "List",
        "Map",
        "Set",
        "BigInt",
        "DateTime",
        "Duration",
    ];
    DART_BUILTINS.contains(&type_name)
}
// Unit tests for `rewrite_frb_sealed_variants`. Each test feeds a small Dart
// snippet through the rewriter and inspects the output text.
#[cfg(test)]
mod tests {
use super::*;
// A single payload whose type starts with the variant name (`PdfMetadata` on
// `Pdf`) is named after the lowerCamelCase remainder (`metadata`).
#[test]
fn named_struct_payload_uses_payload_derived_name() {
let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
}
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required PdfMetadata metadata"),
"PdfMetadata payload should be named `metadata`, got:\n{out}"
);
assert!(
out.contains("required DocxMetadata metadata"),
"DocxMetadata payload should be named `metadata`, got:\n{out}"
);
assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
}
// A single Dart-primitive payload (`String`) is named `value`.
#[test]
fn primitive_payload_uses_value_name() {
let input = r#" const factory OutputFormat.custom({required String field0}) =
OutputFormat_Custom;
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required String value"),
"String payload should be named `value`, got:\n{out}"
);
assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
}
// Variants with several positional fields are numbered `value0`, `value1`, …
// using the original `fieldN` index.
#[test]
fn multi_field_tuple_uses_value0_value1() {
let input = r#" const factory Point.xy({required int field0, required int field1}) =
Point_Xy;
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required int value0"),
"first tuple field should be `value0`, got:\n{out}"
);
assert!(
out.contains("required int value1"),
"second tuple field should be `value1`, got:\n{out}"
);
assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
assert!(!out.contains("field1"), "no `field1` should remain, got:\n{out}");
}
// Parameters that are not named `fieldN` (named struct fields) are left alone.
#[test]
fn named_struct_field_is_preserved() {
let input = r#" const factory Shape.rect({required double width, required double height}) =
Shape_Rect;
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required double width"),
"named field `width` must be preserved, got:\n{out}"
);
assert!(
out.contains("required double height"),
"named field `height` must be preserved, got:\n{out}"
);
}
// Text that is not a `const factory ... = Class_Variant;` declaration must
// round-trip byte-for-byte, even if it mentions `field0`.
#[test]
fn non_variant_lines_are_untouched() {
let input = r#"// This file is automatically generated.
import 'package:freezed_annotation/freezed_annotation.dart';
Future<int> extractBytes({required List<int> content}) =>
RustLib.instance.api.crateExtractBytes(content: content);
class Foo {
final int field0;
Foo({required this.field0});
}
"#;
let out = rewrite_frb_sealed_variants(input);
assert_eq!(out, input, "non-variant code must round-trip unchanged");
}
// A payload type unrelated to the variant name (`Bitmap` on `Image`) falls
// back to `value`.
#[test]
fn fallback_when_prefix_does_not_match_uses_value() {
let input = r#" const factory Drawable.image({required Bitmap field0}) =
Drawable_Image;
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required Bitmap value"),
"unrelated payload type should fall back to `value`, got:\n{out}"
);
}
// A trailing `?` is ignored when matching the variant prefix. Here the
// remainder of `LeftValue` after `Left` is `Value`, which yields `value`.
#[test]
fn nullable_payload_strips_question_mark_for_inference() {
let input = r#" const factory Either.left({required LeftValue? field0}) =
Either_Left;
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required LeftValue? value"),
"nullable payload with prefix-matching type should produce `value`, got:\n{out}"
);
}
// A multi-variant sealed class mixing derived and primitive payloads; the
// surrounding class text must survive the rewrite.
#[test]
fn realistic_kreuzberg_format_metadata_block() {
let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
const factory FormatMetadata.excel({required ExcelMetadata field0}) =
FormatMetadata_Excel;
const factory FormatMetadata.code({required String field0}) =
FormatMetadata_Code;
}
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(out.contains("required PdfMetadata metadata"));
assert!(out.contains("required DocxMetadata metadata"));
assert!(out.contains("required ExcelMetadata metadata"));
assert!(out.contains("required String value"));
assert!(
!out.contains("field0"),
"all `field0` occurrences must be rewritten, got:\n{out}"
);
assert!(out.contains("sealed class FormatMetadata"));
assert!(out.contains("FormatMetadata_Pdf"));
}
// Running the rewriter on its own output must not change it again.
#[test]
fn idempotent_when_run_twice() {
let input = r#" const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
"#;
let once = rewrite_frb_sealed_variants(input);
let twice = rewrite_frb_sealed_variants(&once);
assert_eq!(once, twice, "rewriter must be idempotent");
}
// Several sealed classes in one input are each rewritten using their own
// variant names (e.g. `JsonConfig` on `Json` becomes `config`).
#[test]
fn multiple_distinct_sealed_class_variants_all_rewritten() {
let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
}
sealed class OutputFormat with _$OutputFormat {
const OutputFormat._();
const factory OutputFormat.custom({required String field0}) =
OutputFormat_Custom;
const factory OutputFormat.json({required JsonConfig field0}) =
OutputFormat_Json;
}
"#;
let out = rewrite_frb_sealed_variants(input);
assert!(
out.contains("required PdfMetadata metadata"),
"PdfMetadata should become metadata, got:\n{out}"
);
assert!(
out.contains("required DocxMetadata metadata"),
"DocxMetadata should become metadata, got:\n{out}"
);
assert!(
out.contains("required String value"),
"String should become value, got:\n{out}"
);
assert!(
out.contains("required JsonConfig config"),
"JsonConfig payload (Json prefix → Config remainder) should become `config`, got:\n{out}"
);
assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
}
}