use regex::Regex;
use std::sync::OnceLock;
/// Name of the Dart helper method injected into the FRB-generated entrypoint.
///
/// Its presence anywhere in a source file is also used to detect that the
/// loader rewrite has already been applied, so the rewrite is skipped.
const ALEF_LOADER_MARKER: &str = "_alefResolveExternalLibrary";
/// Injects a package-relative native-library resolver into FRB's `RustLib.init`.
///
/// The `/// Initialize flutter_rust_bridge` + `static Future<void> init(…) async {`
/// prologue is replaced with a resolver helper plus an `init` whose first
/// statement is `externalLibrary ??= await <helper>();`. After that, the
/// `dart:io` / `dart:isolate` imports the helper needs are added if they are
/// missing.
///
/// `source` is returned unchanged in two cases:
/// * the helper is already present, which keeps the rewrite idempotent;
/// * no init prologue can be found.
pub fn rewrite_frb_external_library_loader(source: &str, package_name: &str, module_name: &str, stem: &str) -> String {
    let needs_rewrite = !source.contains(ALEF_LOADER_MARKER);
    // Only search for the prologue when the file has not been patched yet.
    let prologue = needs_rewrite.then(|| frb_init_prologue(source)).flatten();
    prologue.map_or_else(
        || source.to_owned(),
        |prologue| {
            let replacement = frb_init_prologue_replacement(package_name, module_name, stem);
            let patched = source.replacen(prologue.as_str(), &replacement, 1);
            ensure_loader_imports(&patched)
        },
    )
}
/// Returns the text of the first FRB `init` prologue in `source`, if any.
///
/// The match runs from the `/// Initialize flutter_rust_bridge` doc line,
/// including any whitespace before it, through the `}) async {` line and its
/// trailing newline.
fn frb_init_prologue(source: &str) -> Option<String> {
    let found = init_prologue_regex().find(source)?;
    Some(found.as_str().to_owned())
}
/// Lazily compiled regex matching FRB's `init` doc line, signature and opening brace.
fn init_prologue_regex() -> &'static Regex {
    const PATTERN: &str =
        r"(?m)^\s*/// Initialize flutter_rust_bridge\n\s*static Future<void> init\((?s:.)*?\}\) async \{\n";
    static PROLOGUE_RE: OnceLock<Regex> = OnceLock::new();
    PROLOGUE_RE.get_or_init(|| Regex::new(PATTERN).expect("init prologue regex must compile"))
}
/// Builds the Dart text that replaces FRB's `init` prologue.
///
/// The output contains two pieces:
/// * a `static Future<ExternalLibrary?> _alefResolveExternalLibrary()` helper;
/// * the original `init` signature, whose body starts with
///   `externalLibrary ??= await _alefResolveExternalLibrary();`.
///
/// The helper probes `lib<stem>.dylib`, `lib<stem>.so` and `<stem>.dll` in the
/// following locations, in order:
/// 1. `$FRB_DART_LOAD_EXTERNAL_LIBRARY_NATIVE_LIB_DIR`;
/// 2. `package:<package>/src/native/<rid>/`;
/// 3. `package:<package>/src/<module>_bridge_generated/`.
///
/// It returns `null` when nothing is found or when any step throws.
///
/// The RID's architecture is derived from Dart's `Platform.version`. That
/// string reports the ABI as `on "<os>_<arch>"` with Dart ABI names (`x64`,
/// `arm64`, `arm`), not GNU triples. The old `x86_64`/`armv7` probes therefore
/// never matched on x64 or 32-bit arm hosts. The GNU spellings are kept as a
/// fallback.
fn frb_init_prologue_replacement(package_name: &str, module_name: &str, stem: &str) -> String {
    format!(
        r#"  /// Resolve the prebuilt native library shipped with this package.
  ///
  /// Checks in order:
  /// 1. FRB_DART_LOAD_EXTERNAL_LIBRARY_NATIVE_LIB_DIR environment variable
  ///    (allows test harnesses to point to development build paths)
  /// 2. Package-installed location with RID subdirectory (lib/src/native/<rid>/)
  ///    (for published pub.dev packages with platform-specific bundled native libraries)
  /// 3. Package-installed location (lib/src/{module}_bridge_generated/)
  ///    (legacy fallback for development or packages without per-platform binaries)
  /// 4. Returns null so flutter_rust_bridge falls back to its default loader.
  static Future<ExternalLibrary?> {marker}() async {{
    try {{
      const candidates = <String>[
        'lib{stem}.dylib',
        'lib{stem}.so',
        '{stem}.dll',
      ];
      // Check FRB_DART_LOAD_EXTERNAL_LIBRARY_NATIVE_LIB_DIR env var first.
      // This allows test harnesses to override library location for development.
      final envDir = Platform.environment['FRB_DART_LOAD_EXTERNAL_LIBRARY_NATIVE_LIB_DIR'];
      if (envDir != null && envDir.isNotEmpty) {{
        final libDir = Directory(envDir);
        if (libDir.existsSync()) {{
          for (final candidate in candidates) {{
            final libPath = '$envDir/$candidate';
            if (File(libPath).existsSync()) {{
              return ExternalLibrary.open(libPath);
            }}
          }}
        }}
      }}
      // Compute RID (runtime identifier) from platform and architecture.
      String? computeRid() {{
        final os = Platform.operatingSystem;
        // Dart's Platform.version typically ends with `on "<os>_<arch>"` using
        // Dart ABI names, e.g. `on "linux_x64"`, `on "macos_arm64"`,
        // `on "linux_arm"`. GNU-style spellings are accepted as a fallback.
        final version = Platform.version;
        final arch = version.contains('_x64') || version.contains('x86_64') ? 'x64'
            : version.contains('arm64') || version.contains('aarch64') ? 'arm64'
            : version.contains('_arm') || version.contains('armv7') ? 'arm'
            : null;
        if (arch == null) return null;
        switch (os) {{
          case 'linux':
            return 'linux-$arch';
          case 'macos':
            return 'macos-$arch';
          case 'windows':
            return 'windows-$arch';
          default:
            return null;
        }}
      }}
      // Resolve the package once; both package-relative checks share it.
      final packageUri =
          await Isolate.resolvePackageUri(Uri.parse('package:{package}/{package}.dart'));
      if (packageUri != null) {{
        // Prefer the RID-specific directory with bundled per-platform binaries.
        final rid = computeRid();
        if (rid != null) {{
          final ridDir = packageUri.resolve('src/native/$rid/');
          for (final candidate in candidates) {{
            final libPath = ridDir.resolve(candidate).toFilePath();
            if (File(libPath).existsSync()) {{
              return ExternalLibrary.open(libPath);
            }}
          }}
        }}
        // Legacy fallback: next to the generated bridge sources.
        final libDir = packageUri.resolve('src/{module}_bridge_generated/');
        for (final candidate in candidates) {{
          final libPath = libDir.resolve(candidate).toFilePath();
          if (File(libPath).existsSync()) {{
            return ExternalLibrary.open(libPath);
          }}
        }}
      }}
    }} catch (_) {{
      // Fall through to the default loader on any resolution failure.
    }}
    return null;
  }}
  /// Initialize flutter_rust_bridge
  static Future<void> init({{
    RustLibApi? api,
    BaseHandler? handler,
    ExternalLibrary? externalLibrary,
    bool forceSameCodegenVersion = true,
  }}) async {{
    externalLibrary ??= await {marker}();
"#,
        marker = ALEF_LOADER_MARKER,
        package = package_name,
        module = module_name,
        stem = stem,
    )
}
/// Ensures the Dart libraries used by the injected loader are imported.
///
/// The loader needs two libraries:
/// * `dart:io` for `Platform`, `File` and `Directory`;
/// * `dart:isolate` for `Isolate.resolvePackageUri`.
///
/// Each one is added only when its exact `import '…';` line is absent. The
/// missing imports are inserted as one contiguous group, in the order listed
/// (`dart:io` first, as Dart's `directives_ordering` expects). The group goes
/// immediately before the first `import` directive that follows a newline;
/// when there is none, it is prepended to the file.
///
/// Previously each import was inserted at the same fixed offset, so they came
/// out in reverse order.
fn ensure_loader_imports(source: &str) -> String {
    const NEEDED: [(&str, &str); 2] = [
        ("import 'dart:io';", "import 'dart:io';\n"),
        ("import 'dart:isolate';", "import 'dart:isolate';\n"),
    ];
    let missing: String = NEEDED
        .iter()
        .filter(|(probe, _)| !source.contains(*probe))
        .map(|(_, line)| *line)
        .collect();
    if missing.is_empty() {
        return source.to_string();
    }
    // Insert at the start of the first `import` line that follows a newline.
    let anchor = source.find("\nimport ").map_or(0, |i| i + 1);
    let mut result = String::with_capacity(source.len() + missing.len());
    result.push_str(&source[..anchor]);
    result.push_str(&missing);
    result.push_str(&source[anchor..]);
    result
}
/// Extracts the native library stem from a `stem: '<name>'` entry, if present.
fn extract_loader_stem(source: &str) -> Option<String> {
    stem_regex()
        .captures(source)
        .and_then(|caps| caps.name("stem"))
        .map(|stem| stem.as_str().to_owned())
}
/// Lazily compiled regex capturing the quoted identifier after `stem:`.
fn stem_regex() -> &'static Regex {
    const PATTERN: &str = r"stem:\s*'(?P<stem>[A-Za-z0-9_]+)'";
    static STEM_RE: OnceLock<Regex> = OnceLock::new();
    STEM_RE.get_or_init(|| Regex::new(PATTERN).expect("stem regex must compile"))
}
/// Applies the external-library loader rewrite using names derived from the
/// file's `stem: '…'` entry.
///
/// A trailing `_dart` is stripped from the stem, and the result is used as
/// both the Dart package name and the module name (e.g. `spikard_dart` gives
/// `spikard`). Files without a stem are returned unchanged.
fn apply_loader_fix_from_stem(source: &str) -> String {
    match extract_loader_stem(source) {
        None => source.to_string(),
        Some(stem) => {
            let crate_base = stem.strip_suffix("_dart").unwrap_or(stem.as_str());
            rewrite_frb_external_library_loader(source, crate_base, crate_base, &stem)
        }
    }
}
pub fn rewrite_frb_sealed_variants(source: &str) -> String {
let source = apply_loader_fix_from_stem(source);
let source = source.as_str();
let variant_re = variant_regex();
variant_re
.replace_all(source, |caps: ®ex::Captures<'_>| {
let prefix = &caps["prefix"];
let params = &caps["params"];
let suffix = &caps["suffix"];
let variant_pascal = &caps["variant"];
let rewritten_params = rewrite_param_list(params, variant_pascal);
format!("{prefix}{rewritten_params}{suffix}")
})
.into_owned()
}
/// Lazily compiled regex for a freezed `const factory X.y({ … }) = X_Variant;`.
///
/// The regex captures four groups:
/// * `prefix`: everything up to and including the opening `{`;
/// * `params`: the brace-free parameter list;
/// * `suffix`: from the closing `}` to the `;`;
/// * `variant`: the PascalCase name after the last `_` of the redirect target.
fn variant_regex() -> &'static Regex {
    const PATTERN: &str = r"(?s)(?P<prefix>const\s+factory\s+[A-Za-z_][A-Za-z0-9_]*\.[A-Za-z_][A-Za-z0-9_]*\s*\(\s*\{)(?P<params>[^{}]*)(?P<suffix>\}\s*\)\s*=\s*[A-Za-z_][A-Za-z0-9_]*_(?P<variant>[A-Za-z][A-Za-z0-9]*)\s*;)";
    static VARIANT_RE: OnceLock<Regex> = OnceLock::new();
    VARIANT_RE.get_or_init(|| Regex::new(PATTERN).expect("variant regex must compile"))
}
/// Renames positional `fieldN` parameters inside one factory's `{ … }` list.
///
/// Only `required <Type> <name>` parameters matched by `param_regex` are
/// considered. Names that are not `field<digits>`, and all text between
/// matches, are copied through verbatim. The total positional-field count is
/// passed to `payload_param_name`, so tuple variants get `value0`, `value1`, …
/// and single payloads get a type-derived name. Returns `params` unchanged
/// when it contains no positional field.
fn rewrite_param_list(params: &str, variant_pascal: &str) -> String {
    let matches: Vec<regex::Captures<'_>> = param_regex().captures_iter(params).collect();
    let total_fields = matches
        .iter()
        .filter_map(|caps| caps.name("name"))
        .filter(|name| is_positional_field(name.as_str()))
        .count();
    if total_fields == 0 {
        return params.to_string();
    }
    let mut out = String::with_capacity(params.len());
    let mut cursor = 0usize;
    for caps in &matches {
        let name_match = caps.name("name").expect("name capture is required");
        let raw_name = name_match.as_str();
        // Copy everything between the previous parameter name and this one.
        // (Previously a mis-encoded `¶ms` slice that did not compile.)
        out.push_str(&params[cursor..name_match.start()]);
        match field_index(raw_name) {
            Some(field_idx) => {
                let type_name = caps.name("type").map_or("", |m| m.as_str()).trim();
                out.push_str(&payload_param_name(type_name, variant_pascal, field_idx, total_fields));
            }
            None => out.push_str(raw_name),
        }
        cursor = name_match.end();
    }
    out.push_str(&params[cursor..]);
    out
}
/// Lazily compiled regex for one `required <Type> <name>` parameter.
///
/// The parameter must be followed by a comma or the end of the list. The
/// `type` group cannot contain `,`, `{` or `}`.
fn param_regex() -> &'static Regex {
    const PATTERN: &str = r"required\s+(?P<type>[^,{}]+?)\s+(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*(?:,|$)";
    static PARAM_RE: OnceLock<Regex> = OnceLock::new();
    PARAM_RE.get_or_init(|| Regex::new(PATTERN).expect("param regex must compile"))
}
/// Parses the numeric suffix of a positional FRB field name (`field0` → `0`).
///
/// Returns `None` for any name that is not `field` followed by a `usize`
/// literal, e.g. named struct fields such as `width`.
fn field_index(name: &str) -> Option<usize> {
    name.strip_prefix("field").and_then(|digits| digits.parse().ok())
}

/// Returns `true` when `name` is a positional FRB field (`field0`, `field1`, …).
fn is_positional_field(name: &str) -> bool {
    matches!(field_index(name), Some(_))
}
/// Chooses a descriptive Dart parameter name for a positional `fieldN` payload.
///
/// * Multi-field tuple variants get `value<field_idx>` (`value0`, `value1`, …).
/// * A single payload whose type is a Dart core type is named `value`. The
///   type is compared after stripping a trailing `?` and any `<…>` generics.
/// * A single payload whose type starts with the variant name at a PascalCase
///   word boundary uses the remainder in lowerCamelCase. For example, `Pdf` +
///   `PdfMetadata` gives `metadata`, and `Json` + `JsonConfig` gives `config`.
/// * Everything else falls back to `value`. This includes remainders that
///   would not be a safe parameter name: digit-/underscore-led (`Point3D`),
///   mid-word splits (`Pdfium`), and names that collide with Dart keywords or
///   core members (`FooIn` → `in`, `FooInt` → `int`).
fn payload_param_name(type_name: &str, variant_pascal: &str, field_idx: usize, total_fields: usize) -> String {
    if total_fields > 1 {
        return format!("value{field_idx}");
    }
    // `LeftValue?` → `LeftValue`, `List<Foo>` → `List`.
    let stripped_type = type_name.trim_end_matches('?');
    let base_type = stripped_type
        .split_once('<')
        .map_or(stripped_type, |(head, _)| head)
        .trim();
    // Core types are never split into prefix + remainder: `Big` + `BigInt`
    // must not produce a parameter (and freezed getter) named `int`.
    if is_dart_primitive(base_type) {
        return "value".to_string();
    }
    base_type
        .strip_prefix(variant_pascal)
        // Split only at a PascalCase word boundary; this also rejects
        // remainders that start with a digit or `_` (invalid/private names).
        .filter(|remainder| remainder.starts_with(|c: char| c.is_ascii_uppercase()))
        .map(to_lower_camel)
        .filter(|name| !collides_with_dart_name(name))
        .unwrap_or_else(|| "value".to_string())
}

/// Returns `true` if `name` cannot safely be a freezed factory parameter name.
///
/// The rejected names are:
/// * Dart reserved words;
/// * lowercase core type names (a getter named `int` would shadow the type in
///   the generated class);
/// * `Object` members;
/// * the union helpers freezed generates.
fn collides_with_dart_name(name: &str) -> bool {
    matches!(
        name,
        "assert"
            | "break"
            | "case"
            | "catch"
            | "class"
            | "const"
            | "continue"
            | "default"
            | "do"
            | "else"
            | "enum"
            | "extends"
            | "false"
            | "final"
            | "finally"
            | "for"
            | "if"
            | "in"
            | "is"
            | "new"
            | "null"
            | "rethrow"
            | "return"
            | "super"
            | "switch"
            | "this"
            | "throw"
            | "true"
            | "try"
            | "var"
            | "void"
            | "while"
            | "with"
            | "int"
            | "double"
            | "num"
            | "bool"
            | "dynamic"
            | "hashCode"
            | "runtimeType"
            | "toString"
            | "noSuchMethod"
            | "copyWith"
            | "map"
            | "maybeMap"
            | "mapOrNull"
            | "when"
            | "maybeWhen"
            | "whenOrNull"
    )
}

/// Lowercases the first character of `s` (PascalCase → lowerCamelCase).
fn to_lower_camel(s: &str) -> String {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => first.to_lowercase().collect::<String>() + chars.as_str(),
        None => String::new(),
    }
}

/// Returns `true` for Dart core type names that should always map to `value`.
fn is_dart_primitive(type_name: &str) -> bool {
    matches!(
        type_name,
        "String"
            | "int"
            | "double"
            | "bool"
            | "num"
            | "void"
            | "dynamic"
            | "Object"
            | "Uint8List"
            | "List"
            | "Map"
            | "Set"
            | "BigInt"
            | "DateTime"
            | "Duration"
    )
}
/// Removes the declarations of excluded functions from generated Dart source.
///
/// `exclude_functions` holds Rust (snake_case) function names; each is
/// converted to its Dart lowerCamelCase spelling with `snake_to_camel`. A line
/// is treated as the start of an excluded declaration when it:
/// * is not a comment;
/// * is not empty;
/// * does not start with `class` or `enum`;
/// * contains ` <name>(`.
///
/// The declaration is dropped up to and including the first line containing
/// `;`, starting with the declaration line itself. Previously that line was
/// not checked, so a one-line `… => …;` declaration also swallowed the
/// declaration after it. Comment lines directly preceding a dropped
/// declaration are dropped with it.
///
/// The output is rebuilt line by line, so every emitted line ends with `\n`.
pub fn filter_excluded_functions(source: &str, exclude_functions: &std::collections::HashSet<&str>) -> String {
    if exclude_functions.is_empty() {
        return source.to_string();
    }
    // Build the ` fooBar(` needles once rather than once per line.
    let patterns: Vec<String> = exclude_functions
        .iter()
        .map(|excluded| format!(" {}(", snake_to_camel(excluded)))
        .collect();
    let mut lines = source.lines();
    let mut result = String::with_capacity(source.len());
    // Comment lines are held back until we know whether the next declaration is kept.
    let mut doc_buffer: Vec<&str> = Vec::new();
    while let Some(line) = lines.next() {
        let trimmed = line.trim_start();
        let is_comment = trimmed.starts_with("//") || (trimmed.starts_with('*') && !trimmed.starts_with("**/"));
        if is_comment {
            doc_buffer.push(line);
            continue;
        }
        let is_excluded = !trimmed.is_empty()
            && !trimmed.starts_with("class")
            && !trimmed.starts_with("enum")
            && patterns.iter().any(|pattern| line.contains(pattern.as_str()));
        if is_excluded {
            doc_buffer.clear();
            if !line.contains(';') {
                // Multi-line declaration: skip through the line that terminates it.
                for continuation in lines.by_ref() {
                    if continuation.contains(';') {
                        break;
                    }
                }
            }
            continue;
        }
        for doc_line in doc_buffer.drain(..) {
            result.push_str(doc_line);
            result.push('\n');
        }
        result.push_str(line);
        result.push('\n');
    }
    for doc_line in doc_buffer {
        result.push_str(doc_line);
        result.push('\n');
    }
    result
}

/// Converts a snake_case Rust identifier to Dart lowerCamelCase.
///
/// Underscores are dropped and the character after each one is uppercased.
/// Characters appended while the output is still empty are lowercased, so a
/// leading capital becomes lowercase (`Foo_bar` → `fooBar`).
fn snake_to_camel(name: &str) -> String {
    let mut result = String::new();
    let mut capitalize_next = false;
    for c in name.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else if result.is_empty() {
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---------------------------------------------------------------------
    // rewrite_frb_sealed_variants: payload parameter naming
    // ---------------------------------------------------------------------

    /// A single payload whose type starts with the variant name takes the
    /// lowerCamelCase remainder (`Pdf` + `PdfMetadata` → `metadata`).
    #[test]
    fn named_struct_payload_uses_payload_derived_name() {
        let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
}
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required PdfMetadata metadata"),
            "PdfMetadata payload should be named `metadata`, got:\n{out}"
        );
        assert!(
            out.contains("required DocxMetadata metadata"),
            "DocxMetadata payload should be named `metadata`, got:\n{out}"
        );
        assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
    }

    /// A Dart core-type payload (`String`) falls back to `value`.
    #[test]
    fn primitive_payload_uses_value_name() {
        let input = r#" const factory OutputFormat.custom({required String field0}) =
OutputFormat_Custom;
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required String value"),
            "String payload should be named `value`, got:\n{out}"
        );
        assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
    }

    /// Tuple variants with several positional fields become `value0`, `value1`, …
    #[test]
    fn multi_field_tuple_uses_value0_value1() {
        let input = r#" const factory Point.xy({required int field0, required int field1}) =
Point_Xy;
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required int value0"),
            "first tuple field should be `value0`, got:\n{out}"
        );
        assert!(
            out.contains("required int value1"),
            "second tuple field should be `value1`, got:\n{out}"
        );
        assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
        assert!(!out.contains("field1"), "no `field1` should remain, got:\n{out}");
    }

    /// Non-positional (named struct) fields are never renamed.
    #[test]
    fn named_struct_field_is_preserved() {
        let input = r#" const factory Shape.rect({required double width, required double height}) =
Shape_Rect;
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required double width"),
            "named field `width` must be preserved, got:\n{out}"
        );
        assert!(
            out.contains("required double height"),
            "named field `height` must be preserved, got:\n{out}"
        );
    }

    /// Code that is not a `const factory … = X_Variant;` declaration (and has
    /// no `stem:` entry) must round-trip byte-for-byte.
    #[test]
    fn non_variant_lines_are_untouched() {
        let input = r#"// This file is automatically generated.
import 'package:freezed_annotation/freezed_annotation.dart';
Future<int> extractBytes({required List<int> content}) =>
RustLib.instance.api.crateExtractBytes(content: content);
class Foo {
final int field0;
Foo({required this.field0});
}
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert_eq!(out, input, "non-variant code must round-trip unchanged");
    }

    /// A payload type that does not start with the variant name falls back to `value`.
    #[test]
    fn fallback_when_prefix_does_not_match_uses_value() {
        let input = r#" const factory Drawable.image({required Bitmap field0}) =
Drawable_Image;
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required Bitmap value"),
            "unrelated payload type should fall back to `value`, got:\n{out}"
        );
    }

    /// The trailing `?` is ignored when deriving the name; the type text itself is kept.
    #[test]
    fn nullable_payload_strips_question_mark_for_inference() {
        let input = r#" const factory Either.left({required LeftValue? field0}) =
Either_Left;
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required LeftValue? value"),
            "nullable payload with prefix-matching type should produce `value`, got:\n{out}"
        );
    }

    /// End-to-end check on a block shaped like real kreuzberg output.
    #[test]
    fn realistic_kreuzberg_format_metadata_block() {
        let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
const factory FormatMetadata.excel({required ExcelMetadata field0}) =
FormatMetadata_Excel;
const factory FormatMetadata.code({required String field0}) =
FormatMetadata_Code;
}
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(out.contains("required PdfMetadata metadata"));
        assert!(out.contains("required DocxMetadata metadata"));
        assert!(out.contains("required ExcelMetadata metadata"));
        assert!(out.contains("required String value"));
        assert!(
            !out.contains("field0"),
            "all `field0` occurrences must be rewritten, got:\n{out}"
        );
        assert!(out.contains("sealed class FormatMetadata"));
        assert!(out.contains("FormatMetadata_Pdf"));
    }

    /// A second pass finds no `fieldN` names left and must change nothing.
    #[test]
    fn idempotent_when_run_twice() {
        let input = r#" const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
"#;
        let once = rewrite_frb_sealed_variants(input);
        let twice = rewrite_frb_sealed_variants(&once);
        assert_eq!(once, twice, "rewriter must be idempotent");
    }

    /// Variants of several sealed classes in one file are all rewritten.
    #[test]
    fn multiple_distinct_sealed_class_variants_all_rewritten() {
        let input = r#"sealed class FormatMetadata with _$FormatMetadata {
const FormatMetadata._();
const factory FormatMetadata.pdf({required PdfMetadata field0}) =
FormatMetadata_Pdf;
const factory FormatMetadata.docx({required DocxMetadata field0}) =
FormatMetadata_Docx;
}
sealed class OutputFormat with _$OutputFormat {
const OutputFormat._();
const factory OutputFormat.custom({required String field0}) =
OutputFormat_Custom;
const factory OutputFormat.json({required JsonConfig field0}) =
OutputFormat_Json;
}
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            out.contains("required PdfMetadata metadata"),
            "PdfMetadata should become metadata, got:\n{out}"
        );
        assert!(
            out.contains("required DocxMetadata metadata"),
            "DocxMetadata should become metadata, got:\n{out}"
        );
        assert!(
            out.contains("required String value"),
            "String should become value, got:\n{out}"
        );
        assert!(
            out.contains("required JsonConfig config"),
            "JsonConfig payload (Json prefix → Config remainder) should become `config`, got:\n{out}"
        );
        assert!(!out.contains("field0"), "no `field0` should remain, got:\n{out}");
    }

    // ---------------------------------------------------------------------
    // rewrite_frb_external_library_loader
    // ---------------------------------------------------------------------

    /// Minimal `frb_generated.dart`-like file: an `init` prologue plus an
    /// `ExternalLibraryLoaderConfig` with `stem: 'spikard_dart'`.
    fn frb_generated_fixture() -> &'static str {
        r#"// @generated by `flutter_rust_bridge`@ 2.12.0.
import 'dart:async';
import 'dart:convert';
import 'frb_generated.dart';
import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart';
class RustLib extends BaseEntrypoint<RustLibApi, RustLibApiImpl, RustLibWire> {
RustLib._();
/// Initialize flutter_rust_bridge
static Future<void> init({
RustLibApi? api,
BaseHandler? handler,
ExternalLibrary? externalLibrary,
bool forceSameCodegenVersion = true,
}) async {
await instance.initImpl(
api: api,
handler: handler,
externalLibrary: externalLibrary,
forceSameCodegenVersion: forceSameCodegenVersion,
);
}
static const kDefaultExternalLibraryLoaderConfig =
ExternalLibraryLoaderConfig(
stem: 'spikard_dart',
ioDirectory: 'rust/target/release/',
webPrefix: 'pkg/',
wasmBindgenName: 'wasm_bindgen',
);
}
"#
    }

    /// The rewrite wires the helper into `init` and emits the package URI,
    /// legacy directory, per-OS candidates and required imports.
    #[test]
    fn loader_rewrite_injects_package_relative_resolution() {
        let out = rewrite_frb_external_library_loader(frb_generated_fixture(), "spikard", "spikard", "spikard_dart");
        assert!(
            out.contains("externalLibrary ??= await _alefResolveExternalLibrary();"),
            "init must resolve the package-relative library, got:\n{out}"
        );
        assert!(
            out.contains("Isolate.resolvePackageUri(Uri.parse('package:spikard/spikard.dart'))"),
            "loader must resolve the package URI, got:\n{out}"
        );
        assert!(
            out.contains("src/spikard_bridge_generated/"),
            "loader must target the bridge-generated dir, got:\n{out}"
        );
        assert!(
            out.contains("'libspikard_dart.dylib'"),
            "missing macOS candidate, got:\n{out}"
        );
        assert!(
            out.contains("'libspikard_dart.so'"),
            "missing linux candidate, got:\n{out}"
        );
        assert!(
            out.contains("'spikard_dart.dll'"),
            "missing windows candidate, got:\n{out}"
        );
        assert!(out.contains("import 'dart:io';"), "must import dart:io, got:\n{out}");
        assert!(
            out.contains("import 'dart:isolate';"),
            "must import dart:isolate, got:\n{out}"
        );
    }

    /// The marker guard makes a second rewrite a no-op: no duplicate imports or helpers.
    #[test]
    fn loader_rewrite_is_idempotent() {
        let once = rewrite_frb_external_library_loader(frb_generated_fixture(), "spikard", "spikard", "spikard_dart");
        let twice = rewrite_frb_external_library_loader(&once, "spikard", "spikard", "spikard_dart");
        assert_eq!(once, twice, "loader rewrite must be idempotent");
        assert_eq!(
            twice.matches("import 'dart:io';").count(),
            1,
            "imports must not duplicate"
        );
        assert_eq!(
            twice.matches("_alefResolveExternalLibrary() async").count(),
            1,
            "helper must not be injected twice"
        );
    }

    /// Files without the FRB `init` prologue are returned unchanged.
    #[test]
    fn loader_rewrite_is_noop_without_init_prologue() {
        let input = "// just some dart\nFuture<int> foo() async => 1;\n";
        assert_eq!(
            rewrite_frb_external_library_loader(input, "spikard", "spikard", "spikard_dart"),
            input
        );
    }

    /// The sealed-variant pass also injects the loader, deriving the package
    /// name from the stem (`spikard_dart` → `spikard`).
    #[test]
    fn sealed_variant_rewrite_also_applies_loader_fix_via_stem() {
        let out = rewrite_frb_sealed_variants(frb_generated_fixture());
        assert!(
            out.contains("externalLibrary ??= await _alefResolveExternalLibrary();"),
            "sealed-variant pass must also inject the loader, got:\n{out}"
        );
        assert!(
            out.contains("Isolate.resolvePackageUri(Uri.parse('package:spikard/spikard.dart'))"),
            "package derived from stem must be `spikard`, got:\n{out}"
        );
    }

    /// A file without a `stem:` entry (e.g. lib.dart) never receives the loader.
    #[test]
    fn sealed_variant_rewrite_leaves_lib_dart_loader_untouched() {
        let input = r#"import 'frb_generated.dart';
Future<int> extractBytes({required List<int> content}) =>
RustLib.instance.api.crateExtractBytes(content: content);
"#;
        let out = rewrite_frb_sealed_variants(input);
        assert!(
            !out.contains("_alefResolveExternalLibrary"),
            "lib.dart must not get a loader, got:\n{out}"
        );
    }

    /// The generated loader includes the RID-aware lookup, and the first
    /// mention of `src/native/` precedes the first mention of the legacy
    /// directory.
    #[test]
    fn loader_rewrite_includes_rid_aware_path() {
        let out = rewrite_frb_external_library_loader(frb_generated_fixture(), "spikard", "spikard", "spikard_dart");
        assert!(
            out.contains("src/native/"),
            "loader must check RID-aware path (src/native/<rid>/), got:\n{out}"
        );
        assert!(
            out.contains("computeRid()"),
            "loader must compute RID from platform and arch, got:\n{out}"
        );
        assert!(
            out.contains("Platform.operatingSystem"),
            "loader must detect operating system, got:\n{out}"
        );
        assert!(
            out.contains("'linux-x64'") || out.contains("linux-"),
            "loader must support linux RID variants, got:\n{out}"
        );
        assert!(
            out.contains("'macos-arm64'") || out.contains("macos-"),
            "loader must support macos RID variants, got:\n{out}"
        );
        assert!(
            out.contains("'windows-x64'") || out.contains("windows-"),
            "loader must support windows RID variants, got:\n{out}"
        );
        let rid_pos = out.find("src/native/").expect("RID path must exist");
        let legacy_pos = out.find("src/spikard_bridge_generated/").expect("legacy path must exist");
        assert!(
            rid_pos < legacy_pos,
            "RID-aware check must come before legacy fallback, got:\n{out}"
        );
    }
}