use alef_core::config::{BridgeBinding, TraitBridgeConfig};
use alef_core::ir::{FieldDef, FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
use heck::ToSnakeCase;
use std::collections::HashMap;
/// Everything a [`TraitBridgeGenerator`] needs to emit a bridge wrapper for one trait.
pub struct TraitBridgeSpec<'a> {
/// The trait being bridged; its `name`, `rust_path` and `methods` drive the generated code.
pub trait_def: &'a TypeDef,
/// Per-trait bridge configuration (super-trait, register/unregister/clear function names,
/// registry getter, ...).
pub bridge_config: &'a TraitBridgeConfig,
/// Path of the core crate; used to qualify bare error and super-trait names and as the
/// fallback prefix in `host_function_path`.
pub core_import: &'a str,
/// Prefix prepended to the generated wrapper name (`{prefix}{Trait}Bridge`).
pub wrapper_prefix: &'a str,
/// IR type name -> fully-qualified Rust path, consulted when rendering `TypeRef::Named`.
pub type_paths: HashMap<String, String>,
/// Error type name; a bare name is qualified with `core_import` by `error_path`.
pub error_type: String,
/// Expression template used to build an error; every `{msg}` is replaced by `make_error`.
pub error_constructor: String,
}
impl<'a> TraitBridgeSpec<'a> {
pub fn error_path(&self) -> String {
if self.error_type.contains("::") || self.error_type.contains('<') {
self.error_type.clone()
} else {
format!("{}::{}", self.core_import, self.error_type)
}
}
pub fn make_error(&self, msg_expr: &str) -> String {
self.error_constructor.replace("{msg}", msg_expr)
}
pub fn wrapper_name(&self) -> String {
format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
}
pub fn trait_snake(&self) -> String {
self.trait_def.name.to_snake_case()
}
pub fn trait_path(&self) -> String {
self.trait_def.rust_path.replace('-', "_")
}
pub fn required_methods(&self) -> Vec<&'a MethodDef> {
self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
}
pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
}
}
/// Backend-specific hooks used to generate a bridge that implements a Rust trait by
/// delegating to an object owned by a foreign runtime.
pub trait TraitBridgeGenerator {
    /// Rust type of the foreign object stored in the wrapper struct (e.g. `Py<PyAny>`).
    fn foreign_object_type(&self) -> &str;

    /// Import paths the generated bridge code needs.
    fn bridge_imports(&self) -> Vec<String>;

    /// Body of a synchronous forwarding method.
    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Body of an `async` forwarding method.
    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Constructor (`impl` block) for the wrapper struct.
    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;

    /// Registration function; only requested when `register_fn` is configured.
    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;

    /// Unregistration function. The default returns an empty string, which callers treat
    /// as "nothing to emit".
    fn gen_unregistration_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::default()
    }

    /// Clear-all function. The default returns an empty string, which callers treat as
    /// "nothing to emit".
    fn gen_clear_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::default()
    }

    /// Flag forwarded to the trait-impl template as `async_trait_is_send`; defaults to `true`.
    fn async_trait_is_send(&self) -> bool {
        true
    }
}
/// Renders the wrapper struct declaration that holds the foreign object.
///
/// The struct name comes from [`TraitBridgeSpec::wrapper_name`] and the stored object type
/// from [`TraitBridgeGenerator::foreign_object_type`]; layout lives in `wrapper_struct.jinja`.
pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
    let template_ctx = minijinja::context! {
        wrapper_prefix => spec.wrapper_prefix,
        trait_name => &spec.trait_def.name,
        wrapper_name => spec.wrapper_name(),
        foreign_type => generator.foreign_object_type(),
    };
    crate::template_env::render("generators/trait_bridge/wrapper_struct.jinja", template_ctx)
}
/// Renders a `Debug` impl for the wrapper that prints only the wrapper's type name.
fn gen_bridge_debug_impl(spec: &TraitBridgeSpec) -> String {
    format!(
        "impl std::fmt::Debug for {wrapper} {{\n fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{\n write!(f, \"{wrapper}\")\n }}\n}}",
        wrapper = spec.wrapper_name()
    )
}
pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
let wrapper = spec.wrapper_name();
let core_import = spec.core_import;
let super_trait_path = if super_trait_name.contains("::") {
super_trait_name.to_string()
} else {
format!("{core_import}::{super_trait_name}")
};
let error_path = spec.error_path();
let version_method = MethodDef {
name: "version".to_string(),
params: vec![],
return_type: alef_core::ir::TypeRef::String,
is_async: false,
is_static: false,
error_type: None,
doc: String::new(),
receiver: Some(alef_core::ir::ReceiverKind::Ref),
sanitized: false,
trait_source: None,
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
has_default_impl: false,
};
let version_body = generator.gen_sync_method_body(&version_method, spec);
let init_method = MethodDef {
name: "initialize".to_string(),
params: vec![],
return_type: alef_core::ir::TypeRef::Unit,
is_async: false,
is_static: false,
error_type: Some(error_path.clone()),
doc: String::new(),
receiver: Some(alef_core::ir::ReceiverKind::Ref),
sanitized: false,
trait_source: None,
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
has_default_impl: true,
};
let init_body = generator.gen_sync_method_body(&init_method, spec);
let shutdown_method = MethodDef {
name: "shutdown".to_string(),
params: vec![],
return_type: alef_core::ir::TypeRef::Unit,
is_async: false,
is_static: false,
error_type: Some(error_path.clone()),
doc: String::new(),
receiver: Some(alef_core::ir::ReceiverKind::Ref),
sanitized: false,
trait_source: None,
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
has_default_impl: true,
};
let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
let version_lines: Vec<&str> = version_body.lines().collect();
let init_lines: Vec<&str> = init_body.lines().collect();
let shutdown_lines: Vec<&str> = shutdown_body.lines().collect();
Some(crate::template_env::render(
"generators/trait_bridge/plugin_impl.jinja",
minijinja::context! {
super_trait_path => super_trait_path,
wrapper_name => wrapper,
error_path => error_path,
version_lines => version_lines,
init_lines => init_lines,
shutdown_lines => shutdown_lines,
},
))
}
pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
let wrapper = spec.wrapper_name();
let trait_path = spec.trait_path();
let has_async_methods = spec
.trait_def
.methods
.iter()
.any(|m| m.is_async && m.trait_source.is_none() && !m.has_default_impl);
let async_trait_is_send = generator.async_trait_is_send();
let own_methods: Vec<_> = spec
.trait_def
.methods
.iter()
.filter(|m| m.trait_source.is_none() && !m.has_default_impl)
.collect();
let mut methods_code = String::with_capacity(1024);
for (i, method) in own_methods.iter().enumerate() {
if i > 0 {
methods_code.push_str("\n\n");
}
let async_kw = if method.is_async { "async " } else { "" };
let receiver = match &method.receiver {
Some(alef_core::ir::ReceiverKind::Ref) => "&self",
Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
Some(alef_core::ir::ReceiverKind::Owned) => "self",
None => "",
};
let params: Vec<String> = method
.params
.iter()
.map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
.collect();
let all_params = if receiver.is_empty() {
params.join(", ")
} else if params.is_empty() {
receiver.to_string()
} else {
format!("{}, {}", receiver, params.join(", "))
};
let error_override = method.error_type.as_ref().map(|_| spec.error_path());
let ret = format_return_type(
&method.return_type,
error_override.as_deref(),
&spec.type_paths,
method.returns_ref,
);
let body = if method.is_async {
generator.gen_async_method_body(method, spec)
} else {
generator.gen_sync_method_body(method, spec)
};
let indented_body = body
.lines()
.map(|line| format!(" {line}"))
.collect::<Vec<_>>()
.join("\n");
methods_code.push_str(&crate::template_env::render(
"generators/trait_bridge/trait_method.jinja",
minijinja::context! {
async_kw => async_kw,
method_name => &method.name,
all_params => all_params,
ret => ret,
indented_body => &indented_body,
},
));
}
crate::template_env::render(
"generators/trait_bridge/trait_impl.jinja",
minijinja::context! {
has_async_methods => has_async_methods,
async_trait_is_send => async_trait_is_send,
trait_path => trait_path,
wrapper_name => wrapper,
methods_code => methods_code,
},
)
}
/// Registration function code, or `None` when no `register_fn` is configured.
pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
    spec.bridge_config
        .register_fn
        .as_ref()
        .map(|_| generator.gen_registration_fn(spec))
}
/// Unregistration function code, or `None` when `unregister_fn` is not configured or the
/// generator produced an empty string.
pub fn gen_bridge_unregistration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
    if spec.bridge_config.unregister_fn.is_none() {
        return None;
    }
    let code = generator.gen_unregistration_fn(spec);
    (!code.is_empty()).then_some(code)
}
/// Clear-all function code, or `None` when `clear_fn` is not configured or the generator
/// produced an empty string.
pub fn gen_bridge_clear_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
    spec.bridge_config
        .clear_fn
        .as_ref()
        .map(|_| generator.gen_clear_fn(spec))
        .filter(|code| !code.is_empty())
}
/// Resolves the fully-qualified path of a host-side plugin function (e.g. a `register_*` fn).
///
/// When `registry_getter` follows the `<prefix>::get_<sub>_registry` convention, the function
/// is placed in the module `<prefix>::<sub>`, where one trailing `registry` *path segment* is
/// dropped from `<prefix>`. Example:
/// `lib::plugins::registry::get_ocr_backend_registry` + `register` ->
/// `lib::plugins::ocr_backend::register`.
///
/// Otherwise (no getter, or a getter not following the convention, or an empty `<sub>`),
/// falls back to `<core_import>::plugins::<fn_name>`.
pub fn host_function_path(spec: &TraitBridgeSpec, fn_name: &str) -> String {
    if let Some(getter) = spec.bridge_config.registry_getter.as_deref() {
        // `rsplit` always yields at least one item, so this is the last path segment.
        let last = getter.rsplit("::").next().unwrap_or("");
        let sub = last
            .strip_prefix("get_")
            .and_then(|s| s.strip_suffix("_registry"))
            .filter(|s| !s.is_empty());
        if let Some(sub) = sub {
            // `last` is a suffix of `getter`, so this slice ends on a char boundary.
            let prefix = &getter[..getter.len() - last.len()];
            // Remove exactly one whole `registry` segment. A plain `trim_end_matches` would
            // also eat into segments that merely end in "registry" (`my_registry::` -> `my_`)
            // and strip repeated occurrences.
            let prefix = match prefix.strip_suffix("registry::") {
                Some(parent) if parent.is_empty() || parent.ends_with("::") => parent,
                _ => prefix,
            };
            return format!("{prefix}{sub}::{fn_name}");
        }
    }
    format!("{}::plugins::{}", spec.core_import, fn_name)
}
/// Result of [`gen_bridge_all`]: the generator's imports plus the generated bridge source.
pub struct BridgeOutput {
/// Import paths returned by `TraitBridgeGenerator::bridge_imports`.
pub imports: Vec<String>,
/// All generated bridge items, separated by blank lines.
pub code: String,
}
/// Generates the complete bridge for one trait.
///
/// Section order is: wrapper struct, `Debug` impl, constructor, optional super-trait impl,
/// trait impl, then the optional register/unregister/clear functions. Sections are joined
/// with a blank line.
pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
    let imports = generator.bridge_imports();

    let mut sections: Vec<String> = vec![
        gen_bridge_wrapper_struct(spec, generator),
        gen_bridge_debug_impl(spec),
        generator.gen_constructor(spec),
    ];
    sections.extend(gen_bridge_plugin_impl(spec, generator));
    sections.push(gen_bridge_trait_impl(spec, generator));
    sections.extend(gen_bridge_registration_fn(spec, generator));
    sections.extend(gen_bridge_unregistration_fn(spec, generator));
    sections.extend(gen_bridge_clear_fn(spec, generator));

    BridgeOutput {
        imports,
        code: sections.join("\n\n"),
    }
}
pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
use alef_core::ir::{PrimitiveType, TypeRef};
match ty {
TypeRef::Primitive(p) => match p {
PrimitiveType::Bool => "bool",
PrimitiveType::U8 => "u8",
PrimitiveType::U16 => "u16",
PrimitiveType::U32 => "u32",
PrimitiveType::U64 => "u64",
PrimitiveType::I8 => "i8",
PrimitiveType::I16 => "i16",
PrimitiveType::I32 => "i32",
PrimitiveType::I64 => "i64",
PrimitiveType::F32 => "f32",
PrimitiveType::F64 => "f64",
PrimitiveType::Usize => "usize",
PrimitiveType::Isize => "isize",
}
.to_string(),
TypeRef::String => "String".to_string(),
TypeRef::Char => "char".to_string(),
TypeRef::Bytes => "Vec<u8>".to_string(),
TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
TypeRef::Map(k, v) => format!(
"std::collections::HashMap<{}, {}>",
format_type_ref(k, type_paths),
format_type_ref(v, type_paths)
),
TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
TypeRef::Path => "std::path::PathBuf".to_string(),
TypeRef::Unit => "()".to_string(),
TypeRef::Json => "serde_json::Value".to_string(),
TypeRef::Duration => "std::time::Duration".to_string(),
}
}
pub fn format_return_type(
ty: &alef_core::ir::TypeRef,
error_type: Option<&str>,
type_paths: &HashMap<String, String>,
returns_ref: bool,
) -> String {
let inner = if returns_ref {
if let alef_core::ir::TypeRef::Vec(elem) = ty {
let elem_str = match elem.as_ref() {
alef_core::ir::TypeRef::String => "&str".to_string(),
alef_core::ir::TypeRef::Bytes => "&[u8]".to_string(),
alef_core::ir::TypeRef::Named(name) => {
let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
format!("&{qualified}")
}
other => format_type_ref(other, type_paths),
};
format!("&[{elem_str}]")
} else {
format_type_ref(ty, type_paths)
}
} else {
format_type_ref(ty, type_paths)
};
match error_type {
Some(err) => format!("std::result::Result<{inner}, {err}>"),
None => inner,
}
}
/// Renders a trait-method parameter type.
///
/// Reference parameters of string, bytes, path, vec and named types are borrowed
/// (`&str`, `&[u8]`, `&std::path::Path`, `&[T]`, `&Named`, with `mut` when `is_mut`). Other
/// reference types are passed by value. A reference `Option<T>` borrows its inner type and
/// is returned without applying the `optional` flag again. Otherwise, `optional` wraps the
/// result in `Option<...>`.
pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
    use alef_core::ir::TypeRef;

    /// Borrowed spelling of `ty`, or `None` when the type is passed by value even if the
    /// IR marks it as a reference.
    fn borrowed(ty: &TypeRef, mutability: &str, type_paths: &HashMap<String, String>) -> Option<String> {
        let target = match ty {
            TypeRef::String => "str".to_string(),
            TypeRef::Bytes => "[u8]".to_string(),
            TypeRef::Path => "std::path::Path".to_string(),
            TypeRef::Vec(elem) => format!("[{}]", format_type_ref(elem, type_paths)),
            TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
            _ => return None,
        };
        Some(format!("&{mutability}{target}"))
    }

    let base = if !param.is_ref {
        format_type_ref(&param.ty, type_paths)
    } else {
        let mutability = if param.is_mut { "mut " } else { "" };
        if let TypeRef::Optional(inner) = &param.ty {
            let inner_str =
                borrowed(inner, mutability, type_paths).unwrap_or_else(|| format_type_ref(inner, type_paths));
            return format!("Option<{inner_str}>");
        }
        borrowed(&param.ty, mutability, type_paths).unwrap_or_else(|| format_type_ref(&param.ty, type_paths))
    };

    if param.optional {
        format!("Option<{base}>")
    } else {
        base
    }
}
pub fn prim(p: &PrimitiveType) -> &'static str {
use PrimitiveType::*;
match p {
Bool => "bool",
U8 => "u8",
U16 => "u16",
U32 => "u32",
U64 => "u64",
I8 => "i8",
I16 => "i16",
I32 => "i32",
I64 => "i64",
F32 => "f32",
F64 => "f64",
Usize => "usize",
Isize => "isize",
}
}
/// Renders `ty` as a Rust parameter type for generated bridge code.
///
/// * `ci` is the core import path used to qualify named types missing from `tp`. When `ci`
///   is empty the bare type name is emitted; `format!("{ci}::{n}")` would otherwise produce
///   `::Name`, which Rust 2018+ resolves as an extern crate and fails to compile.
/// * `is_ref` borrows the top-level type where a borrowed form exists (`&[u8]`, `&str`,
///   `&std::path::Path`, `&Named`). Nested types are always rendered owned.
pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
    match ty {
        TypeRef::Bytes if is_ref => "&[u8]".into(),
        TypeRef::Bytes => "Vec<u8>".into(),
        TypeRef::String if is_ref => "&str".into(),
        TypeRef::String => "String".into(),
        TypeRef::Path if is_ref => "&std::path::Path".into(),
        TypeRef::Path => "std::path::PathBuf".into(),
        TypeRef::Named(n) => {
            let qualified = tp.get(n).cloned().unwrap_or_else(|| {
                if ci.is_empty() {
                    n.clone()
                } else {
                    format!("{ci}::{n}")
                }
            });
            if is_ref { format!("&{qualified}") } else { qualified }
        }
        TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
        TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
        TypeRef::Primitive(p) => prim(p).into(),
        TypeRef::Unit => "()".into(),
        TypeRef::Char => "char".into(),
        TypeRef::Map(k, v) => format!(
            "std::collections::HashMap<{}, {}>",
            bridge_param_type(k, ci, false, tp),
            bridge_param_type(v, ci, false, tp)
        ),
        TypeRef::Json => "serde_json::Value".into(),
        TypeRef::Duration => "std::time::Duration".into(),
    }
}
/// Parameter type for visitor-style callbacks, rendered without a core-import prefix.
///
/// An optional borrowed string becomes `Option<&str>` and a borrowed `Vec<T>` becomes the
/// slice `&[T]`. Everything else is delegated to [`bridge_param_type`].
pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
    match ty {
        TypeRef::String if is_ref && optional => "Option<&str>".to_string(),
        TypeRef::Vec(inner) if is_ref => format!("&[{}]", bridge_param_type(inner, "", false, tp)),
        _ => bridge_param_type(ty, "", is_ref, tp),
    }
}
/// Finds the first parameter of `func` bound to a `FunctionParam` bridge.
///
/// A bridge matches a parameter when its `type_alias` equals the parameter's named type
/// (looking through one `Option<...>`), or when its `param_name` equals the parameter name.
/// Parameters are scanned in order and, for each, bridges in order. The parameter index and
/// the matching bridge are returned.
pub fn find_bridge_param<'a>(
    func: &FunctionDef,
    bridges: &'a [TraitBridgeConfig],
) -> Option<(usize, &'a TraitBridgeConfig)> {
    func.params.iter().enumerate().find_map(|(idx, param)| {
        let type_name = match &param.ty {
            TypeRef::Named(n) => Some(n.as_str()),
            TypeRef::Optional(inner) => match inner.as_ref() {
                TypeRef::Named(n) => Some(n.as_str()),
                _ => None,
            },
            _ => None,
        };
        bridges
            .iter()
            .filter(|bridge| bridge.bind_via == BridgeBinding::FunctionParam)
            .find(|bridge| {
                let alias_hit = type_name.is_some_and(|name| bridge.type_alias.as_deref() == Some(name));
                alias_hit || bridge.param_name.as_deref() == Some(param.name.as_str())
            })
            .map(|bridge| (idx, bridge))
    })
}
/// A bridge bound through a field of an options struct, as located by [`find_bridge_field`].
#[derive(Debug, Clone)]
pub struct BridgeFieldMatch<'a> {
/// Position of the options parameter in the function's parameter list.
pub param_index: usize,
/// Name of that parameter.
pub param_name: String,
/// Name of the options type (the parameter's named type).
pub options_type: String,
/// `true` when the parameter's type is `Option<OptionsType>`.
pub param_is_optional: bool,
/// Name of the matching field on the options type.
pub field_name: String,
/// The matching field definition.
pub field: &'a FieldDef,
/// The bridge configuration that matched.
pub bridge: &'a TraitBridgeConfig,
}
/// Finds the first function parameter whose options type carries an `OptionsField` bridge.
///
/// For each parameter whose type is a named type (or `Option` of one) present in `types`,
/// every `OptionsField` bridge whose `options_type` equals that name is tried in order.
/// Within the options type, a field whose name equals the bridge's resolved options field
/// takes precedence. Only if no such field exists is the first field whose type (through
/// `Option`/`Vec`) is the bridge's `type_alias` used. An explicit field name is therefore
/// never shadowed by an earlier field that merely has the aliased type.
pub fn find_bridge_field<'a>(
    func: &FunctionDef,
    types: &'a [TypeDef],
    bridges: &'a [TraitBridgeConfig],
) -> Option<BridgeFieldMatch<'a>> {
    /// Named type of `ty`, looking through one `Option`; the flag reports that wrapper.
    fn unwrap_named(ty: &TypeRef) -> Option<(&str, bool)> {
        match ty {
            TypeRef::Named(n) => Some((n.as_str(), false)),
            TypeRef::Optional(inner) => {
                if let TypeRef::Named(n) = inner.as_ref() {
                    Some((n.as_str(), true))
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    for (idx, param) in func.params.iter().enumerate() {
        let Some((type_name, is_optional)) = unwrap_named(&param.ty) else {
            continue;
        };
        let Some(type_def) = types.iter().find(|t| t.name == type_name) else {
            continue;
        };
        for bridge in bridges {
            if bridge.bind_via != BridgeBinding::OptionsField {
                continue;
            }
            if bridge.options_type.as_deref() != Some(type_name) {
                continue;
            }
            let field_name = bridge.resolved_options_field();
            let by_name = type_def
                .fields
                .iter()
                .find(|field| field_name.is_some_and(|n| field.name == n));
            let matched = by_name.or_else(|| {
                let alias = bridge.type_alias.as_deref()?;
                type_def
                    .fields
                    .iter()
                    .find(|field| field_type_matches_alias(&field.ty, alias))
            });
            if let Some(field) = matched {
                return Some(BridgeFieldMatch {
                    param_index: idx,
                    param_name: param.name.clone(),
                    options_type: type_name.to_string(),
                    param_is_optional: is_optional,
                    field_name: field.name.clone(),
                    field,
                    bridge,
                });
            }
        }
    }
    None
}
/// Whether `field_ty` is the named type `alias`, possibly nested inside any number of
/// `Option<...>` / `Vec<...>` wrappers.
fn field_type_matches_alias(field_ty: &TypeRef, alias: &str) -> bool {
    let mut current = field_ty;
    loop {
        match current {
            TypeRef::Named(name) => return name == alias,
            TypeRef::Optional(inner) | TypeRef::Vec(inner) => current = inner.as_ref(),
            _ => return false,
        }
    }
}
/// Converts a `snake_case` identifier to `lowerCamelCase`.
///
/// Underscores are removed, and the first character after any run of underscores is
/// ASCII-uppercased (non-ASCII characters are left unchanged). A leading underscore
/// therefore capitalises the first letter (`_foo` -> `Foo`).
pub fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());
    if let Some(head) = segments.next() {
        camel.push_str(head);
    }
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            camel.push(first.to_ascii_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
#[cfg(test)]
mod tests {
use super::*;
use alef_core::config::TraitBridgeConfig;
use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
TraitBridgeConfig {
trait_name: "OcrBackend".to_string(),
super_trait: super_trait.map(str::to_string),
registry_getter: None,
register_fn: register_fn.map(str::to_string),
unregister_fn: None,
clear_fn: None,
type_alias: None,
param_name: None,
register_extra_args: None,
exclude_languages: Vec::new(),
bind_via: BridgeBinding::FunctionParam,
options_type: None,
options_field: None,
}
}
fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
TypeDef {
name: name.to_string(),
rust_path: rust_path.to_string(),
original_rust_path: rust_path.to_string(),
fields: vec![],
methods,
is_opaque: true,
is_clone: false,
is_copy: false,
doc: String::new(),
cfg: None,
is_trait: true,
has_default: false,
has_stripped_cfg_fields: false,
is_return_type: false,
serde_rename_all: None,
has_serde: false,
super_traits: vec![],
}
}
fn make_method(
name: &str,
params: Vec<ParamDef>,
return_type: TypeRef,
is_async: bool,
has_default_impl: bool,
trait_source: Option<&str>,
error_type: Option<&str>,
) -> MethodDef {
MethodDef {
name: name.to_string(),
params,
return_type,
is_async,
is_static: false,
error_type: error_type.map(str::to_string),
doc: String::new(),
receiver: Some(ReceiverKind::Ref),
sanitized: false,
trait_source: trait_source.map(str::to_string),
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
has_default_impl,
}
}
fn make_func(name: &str, params: Vec<ParamDef>) -> FunctionDef {
FunctionDef {
name: name.to_string(),
rust_path: format!("mylib::{name}"),
original_rust_path: String::new(),
params,
return_type: TypeRef::Unit,
is_async: false,
error_type: None,
doc: String::new(),
cfg: None,
sanitized: false,
return_sanitized: false,
returns_ref: false,
returns_cow: false,
return_newtype_wrapper: None,
}
}
fn make_field(name: &str, ty: TypeRef) -> FieldDef {
FieldDef {
name: name.to_string(),
ty,
optional: false,
default: None,
doc: String::new(),
sanitized: false,
is_boxed: false,
type_rust_path: None,
cfg: None,
typed_default: None,
core_wrapper: Default::default(),
vec_inner_core_wrapper: Default::default(),
newtype_wrapper: None,
serde_rename: None,
serde_flatten: false,
}
}
fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
ParamDef {
name: name.to_string(),
ty,
optional: false,
default: None,
sanitized: false,
typed_default: None,
is_ref,
is_mut: false,
newtype_wrapper: None,
original_type: None,
}
}
fn make_spec<'a>(
trait_def: &'a TypeDef,
bridge_config: &'a TraitBridgeConfig,
wrapper_prefix: &'a str,
type_paths: HashMap<String, String>,
) -> TraitBridgeSpec<'a> {
TraitBridgeSpec {
trait_def,
bridge_config,
core_import: "mylib",
wrapper_prefix,
type_paths,
error_type: "MyError".to_string(),
error_constructor: "MyError::from({msg})".to_string(),
}
}
struct MockBridgeGenerator;
impl TraitBridgeGenerator for MockBridgeGenerator {
fn foreign_object_type(&self) -> &str {
"Py<PyAny>"
}
fn bridge_imports(&self) -> Vec<String> {
vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
}
fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
format!("// sync body for {}", method.name)
}
fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
format!("// async body for {}", method.name)
}
fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
format!(
"impl {} {{\n pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
spec.wrapper_name()
)
}
fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
}
}
#[test]
fn test_wrapper_name() {
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
}
#[test]
fn test_trait_snake() {
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
assert_eq!(spec.trait_snake(), "ocr_backend");
}
#[test]
fn test_trait_path_replaces_hyphens() {
let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
}
#[test]
fn test_required_methods_filters_no_default_impl() {
let methods = vec![
make_method("process", vec![], TypeRef::String, false, false, None, None),
make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
make_method("detect", vec![], TypeRef::String, false, false, None, None),
];
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
let required = spec.required_methods();
assert_eq!(required.len(), 2);
assert!(required.iter().any(|m| m.name == "process"));
assert!(required.iter().any(|m| m.name == "detect"));
}
#[test]
fn test_optional_methods_filters_has_default_impl() {
let methods = vec![
make_method("process", vec![], TypeRef::String, false, false, None, None),
make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
];
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
let optional = spec.optional_methods();
assert_eq!(optional.len(), 2);
assert!(optional.iter().any(|m| m.name == "initialize"));
assert!(optional.iter().any(|m| m.name == "shutdown"));
}
#[test]
fn test_error_path() {
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
assert_eq!(spec.error_path(), "mylib::MyError");
}
#[test]
fn test_format_type_ref_primitives() {
let paths = HashMap::new();
let cases: Vec<(TypeRef, &str)> = vec![
(TypeRef::Primitive(PrimitiveType::Bool), "bool"),
(TypeRef::Primitive(PrimitiveType::U8), "u8"),
(TypeRef::Primitive(PrimitiveType::U16), "u16"),
(TypeRef::Primitive(PrimitiveType::U32), "u32"),
(TypeRef::Primitive(PrimitiveType::U64), "u64"),
(TypeRef::Primitive(PrimitiveType::I8), "i8"),
(TypeRef::Primitive(PrimitiveType::I16), "i16"),
(TypeRef::Primitive(PrimitiveType::I32), "i32"),
(TypeRef::Primitive(PrimitiveType::I64), "i64"),
(TypeRef::Primitive(PrimitiveType::F32), "f32"),
(TypeRef::Primitive(PrimitiveType::F64), "f64"),
(TypeRef::Primitive(PrimitiveType::Usize), "usize"),
(TypeRef::Primitive(PrimitiveType::Isize), "isize"),
];
for (ty, expected) in cases {
assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
}
}
#[test]
fn test_format_type_ref_string() {
assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
}
#[test]
fn test_format_type_ref_char() {
assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
}
#[test]
fn test_format_type_ref_bytes() {
assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
}
#[test]
fn test_format_type_ref_path() {
assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
}
#[test]
fn test_format_type_ref_unit() {
assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
}
#[test]
fn test_format_type_ref_json() {
assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
}
#[test]
fn test_format_type_ref_duration() {
assert_eq!(
format_type_ref(&TypeRef::Duration, &HashMap::new()),
"std::time::Duration"
);
}
#[test]
fn test_format_type_ref_optional() {
let ty = TypeRef::Optional(Box::new(TypeRef::String));
assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
}
#[test]
fn test_format_type_ref_optional_nested() {
let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
PrimitiveType::U32,
)))));
assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
}
#[test]
fn test_format_type_ref_vec() {
let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
}
#[test]
fn test_format_type_ref_vec_nested() {
let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
}
#[test]
fn test_format_type_ref_map() {
let ty = TypeRef::Map(
Box::new(TypeRef::String),
Box::new(TypeRef::Primitive(PrimitiveType::I64)),
);
assert_eq!(
format_type_ref(&ty, &HashMap::new()),
"std::collections::HashMap<String, i64>"
);
}
#[test]
fn test_format_type_ref_map_nested_value() {
let ty = TypeRef::Map(
Box::new(TypeRef::String),
Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
);
assert_eq!(
format_type_ref(&ty, &HashMap::new()),
"std::collections::HashMap<String, Vec<String>>"
);
}
#[test]
fn test_format_type_ref_named_without_type_paths() {
let ty = TypeRef::Named("Config".to_string());
assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
}
#[test]
fn test_format_type_ref_named_with_type_paths() {
let ty = TypeRef::Named("Config".to_string());
let mut paths = HashMap::new();
paths.insert("Config".to_string(), "mylib::Config".to_string());
assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
}
#[test]
fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
let ty = TypeRef::Named("Unknown".to_string());
let mut paths = HashMap::new();
paths.insert("Other".to_string(), "mylib::Other".to_string());
assert_eq!(format_type_ref(&ty, &paths), "Unknown");
}
#[test]
fn test_format_param_type_string_ref() {
let param = make_param("input", TypeRef::String, true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "&str");
}
#[test]
fn test_format_param_type_string_owned() {
let param = make_param("input", TypeRef::String, false);
assert_eq!(format_param_type(¶m, &HashMap::new()), "String");
}
#[test]
fn test_format_param_type_bytes_ref() {
let param = make_param("data", TypeRef::Bytes, true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "&[u8]");
}
#[test]
fn test_format_param_type_bytes_owned() {
let param = make_param("data", TypeRef::Bytes, false);
assert_eq!(format_param_type(¶m, &HashMap::new()), "Vec<u8>");
}
#[test]
fn test_format_param_type_path_ref() {
let param = make_param("path", TypeRef::Path, true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "&std::path::Path");
}
#[test]
fn test_format_param_type_path_owned() {
let param = make_param("path", TypeRef::Path, false);
assert_eq!(format_param_type(¶m, &HashMap::new()), "std::path::PathBuf");
}
#[test]
fn test_format_param_type_vec_ref() {
let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "&[String]");
}
#[test]
fn test_format_param_type_vec_owned() {
let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
assert_eq!(format_param_type(¶m, &HashMap::new()), "Vec<String>");
}
#[test]
fn test_format_param_type_named_ref_with_type_paths() {
let mut paths = HashMap::new();
paths.insert("Config".to_string(), "mylib::Config".to_string());
let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
assert_eq!(format_param_type(¶m, &paths), "&mylib::Config");
}
#[test]
fn test_format_param_type_named_ref_without_type_paths() {
let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "&Config");
}
#[test]
fn test_format_param_type_primitive_ref_passes_by_value() {
let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "u32");
}
#[test]
fn test_format_param_type_unit_ref_passes_by_value() {
let param = make_param("nothing", TypeRef::Unit, true);
assert_eq!(format_param_type(¶m, &HashMap::new()), "()");
}
#[test]
fn test_format_return_type_without_error() {
let result = format_return_type(&TypeRef::String, None, &HashMap::new(), false);
assert_eq!(result, "String");
}
#[test]
fn test_format_return_type_with_error() {
let result = format_return_type(&TypeRef::String, Some("MyError"), &HashMap::new(), false);
assert_eq!(result, "std::result::Result<String, MyError>");
}
#[test]
fn test_format_return_type_unit_with_error() {
let result = format_return_type(
&TypeRef::Unit,
Some("Box<dyn std::error::Error>"),
&HashMap::new(),
false,
);
assert_eq!(result, "std::result::Result<(), Box<dyn std::error::Error>>");
}
#[test]
fn test_format_return_type_named_with_type_paths_and_error() {
let mut paths = HashMap::new();
paths.insert("Output".to_string(), "mylib::Output".to_string());
let result = format_return_type(
&TypeRef::Named("Output".to_string()),
Some("mylib::MyError"),
&paths,
false,
);
assert_eq!(result, "std::result::Result<mylib::Output, mylib::MyError>");
}
#[test]
fn test_format_return_type_vec_string_with_returns_ref() {
let result = format_return_type(&TypeRef::Vec(Box::new(TypeRef::String)), None, &HashMap::new(), true);
assert_eq!(result, "&[&str]", "Vec<String> + returns_ref must yield &[&str]");
}
#[test]
fn test_format_return_type_vec_no_returns_ref_unchanged() {
let result = format_return_type(&TypeRef::Vec(Box::new(TypeRef::String)), None, &HashMap::new(), false);
assert_eq!(
result, "Vec<String>",
"Vec<String> without returns_ref must stay Vec<String>"
);
}
#[test]
fn test_gen_bridge_wrapper_struct_contains_struct_name() {
let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
let config = make_trait_bridge_config(None, None);
let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
let generator = MockBridgeGenerator;
let result = gen_bridge_wrapper_struct(&spec, &generator);
assert!(
result.contains("pub struct PyOcrBackendBridge"),
"missing struct declaration in:\n{result}"
);
}
#[test]
/// The wrapper struct stores the foreign object in an `inner` field typed by the generator.
fn test_gen_bridge_wrapper_struct_contains_inner_field() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&spec, &MockBridgeGenerator);
    let has_inner = result.contains("inner: Py<PyAny>");
    assert!(has_inner, "missing inner field in:\n{result}");
}
#[test]
/// The wrapper struct carries a `cached_name: String` field.
fn test_gen_bridge_wrapper_struct_contains_cached_name() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&spec, &MockBridgeGenerator);
    let has_cached_name = result.contains("cached_name: String");
    assert!(has_cached_name, "missing cached_name field in:\n{result}");
}
#[test]
/// Without a configured super-trait, no plugin impl is generated.
fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_none());
}
#[test]
/// With a configured super-trait, a plugin impl is generated.
fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_some());
}
#[test]
/// A bare super-trait name is qualified with the trait's crate path (`mylib::`).
fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let qualified = result.contains("impl mylib::Plugin for PyOcrBackendBridge");
    assert!(qualified, "missing qualified super-trait path in:\n{result}");
}
#[test]
/// A super-trait that is already path-qualified is used as-is.
fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("other_crate::Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let kept_as_is = result.contains("impl other_crate::Plugin for PyOcrBackendBridge");
    assert!(kept_as_is, "wrong super-trait path in:\n{result}");
}
#[test]
/// The plugin impl's `name()` is backed by the wrapper's `cached_name`.
fn test_gen_bridge_plugin_impl_contains_name_fn() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_name_fn = result.contains("fn name(");
    let uses_cache = result.contains("cached_name");
    assert!(
        has_name_fn && uses_cache,
        "missing name() using cached_name in:\n{result}"
    );
}
#[test]
/// The plugin impl defines `version()`.
fn test_gen_bridge_plugin_impl_contains_version_fn() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_version = result.contains("fn version(");
    assert!(has_version, "missing version() in:\n{result}");
}
#[test]
/// The plugin impl defines `initialize()`.
fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_initialize = result.contains("fn initialize(");
    assert!(has_initialize, "missing initialize() in:\n{result}");
}
#[test]
/// The plugin impl defines `shutdown()`.
fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_shutdown = result.contains("fn shutdown(");
    assert!(has_shutdown, "missing shutdown() in:\n{result}");
}
#[test]
/// The trait impl header names the trait's full path and the wrapper struct.
fn test_gen_bridge_trait_impl_includes_impl_header() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_header = result.contains("impl mylib::OcrBackend for PyOcrBackendBridge");
    assert!(has_header, "missing impl header in:\n{result}");
}
#[test]
/// Each trait method gets a signature in the generated impl.
fn test_gen_bridge_trait_impl_includes_method_signatures() {
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_signature = result.contains("fn process(");
    assert!(has_signature, "missing method signature in:\n{result}");
}
#[test]
/// Sync methods get their body from the generator's `gen_sync_method_body`.
fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_sync_body = result.contains("// sync body for process");
    assert!(has_sync_body, "missing sync method body in:\n{result}");
}
#[test]
/// Async methods are emitted as `async fn` and get the generator's async body.
fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
    let process_async =
        make_method("process_async", Vec::new(), TypeRef::String, true, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process_async]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let has_async_body = result.contains("// async body for process_async");
    assert!(has_async_body, "missing async method body in:\n{result}");

    let has_async_keyword = result.contains("async fn process_async(");
    assert!(
        has_async_keyword,
        "missing async keyword in method signature in:\n{result}"
    );
}
#[test]
/// Methods whose `trait_source` names another trait are left out of this trait's impl.
fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
    let own = make_method("own_method", Vec::new(), TypeRef::String, false, false, None, None);
    let inherited = make_method(
        "inherited_method",
        Vec::new(),
        TypeRef::String,
        false,
        false,
        Some("other_crate::OtherTrait"),
        None,
    );
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![own, inherited]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let keeps_own = result.contains("fn own_method(");
    assert!(keeps_own, "own method should be present in:\n{result}");

    let leaks_inherited = result.contains("fn inherited_method(");
    assert!(
        !leaks_inherited,
        "inherited method should be filtered out in:\n{result}"
    );
}
#[test]
/// A by-ref `String` param renders as `&str`; a by-value `u32` param renders as `u32`.
fn test_gen_bridge_trait_impl_method_with_params() {
    let input = make_param("input", TypeRef::String, true);
    let count = make_param("count", TypeRef::Primitive(PrimitiveType::U32), false);
    let process = make_method(
        "process",
        vec![input, count],
        TypeRef::String,
        false,
        false,
        None,
        None,
    );
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let has_str_param = result.contains("input: &str");
    assert!(has_str_param, "missing &str param in:\n{result}");
    let has_u32_param = result.contains("count: u32");
    assert!(has_u32_param, "missing u32 param in:\n{result}");
}
#[test]
/// A bare method error type is qualified with the crate path inside `std::result::Result`.
fn test_gen_bridge_trait_impl_return_type_with_error() {
    let process = make_method(
        "process",
        Vec::new(),
        TypeRef::String,
        false,
        false,
        None,
        Some("MyError"),
    );
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_result_return = result.contains("-> std::result::Result<String, mylib::MyError>");
    assert!(
        has_result_return,
        "missing std::result::Result return type in:\n{result}"
    );
}
#[test]
/// Without a configured `register_fn`, no registration function is generated.
fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let registration = gen_bridge_registration_fn(&spec, &MockBridgeGenerator);
    assert!(registration.is_none());
}
#[test]
/// With a configured `register_fn`, a registration function mentioning that name is generated.
fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    // A single `expect` replaces the separate `is_some()` assert + `unwrap()`, and states why
    // `Some` is required.
    let code = gen_bridge_registration_fn(&spec, &generator)
        .expect("register_fn is configured, so a registration fn must be generated");
    assert!(
        code.contains("register_ocr_backend"),
        "missing register fn name in:\n{code}"
    );
}
#[test]
/// `gen_bridge_all` surfaces the generator's bridge imports.
fn test_gen_bridge_all_includes_imports() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    // Compare against borrowed literals instead of allocating a String per lookup.
    let imports = &output.imports;
    assert!(imports.iter().any(|import| import == "pyo3::prelude::*"));
    assert!(imports.iter().any(|import| import == "pyo3::types::PyString"));
}
#[test]
/// The combined output contains the wrapper struct declaration.
fn test_gen_bridge_all_includes_wrapper_struct() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("pub struct PyOcrBackendBridge"),
        "missing struct in:\n{}",
        code
    );
}
#[test]
/// The combined output contains the generator-provided constructor.
fn test_gen_bridge_all_includes_constructor() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(code.contains("pub fn new("), "missing constructor in:\n{}", code);
}
#[test]
/// The combined output contains the trait impl for the wrapper.
fn test_gen_bridge_all_includes_trait_impl() {
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
        "missing trait impl in:\n{}",
        code
    );
}
#[test]
/// With a super-trait configured, the combined output also contains the plugin impl.
fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
        "missing plugin impl in:\n{}",
        code
    );
}
#[test]
/// Without a super-trait, the combined output must not contain the plugin impl.
fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    // `cached_name` is always emitted as a wrapper-struct field, so a check on it says nothing
    // about the plugin impl. `fn name(` is the plugin-impl marker, so check that alone.
    assert!(
        !output.code.contains("fn name("),
        "unexpected plugin impl present without super_trait:\n{}",
        output.code
    );
}
#[test]
/// With a `register_fn` configured, the combined output contains the registration fn.
fn test_gen_bridge_all_includes_registration_fn_when_configured() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("register_ocr_backend"),
        "missing registration fn in:\n{}",
        code
    );
}
#[test]
/// Without a `register_fn`, the combined output contains no registration fn name.
fn test_gen_bridge_all_no_registration_fn_when_absent() {
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    let mentions_register = code.contains("register_ocr_backend");
    assert!(
        !mentions_register,
        "unexpected registration fn present:\n{}",
        code
    );
}
#[test]
/// The wrapper struct is emitted before the trait impl in the combined output.
fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    let code = &output.code;
    // Report the generated code when a marker is missing, instead of a context-free
    // `unwrap()` panic on `None`.
    let struct_pos = code
        .find("pub struct PyOcrBackendBridge")
        .unwrap_or_else(|| panic!("missing struct in:\n{code}"));
    let impl_pos = code
        .find("impl mylib::OcrBackend for PyOcrBackendBridge")
        .unwrap_or_else(|| panic!("missing trait impl in:\n{code}"));
    assert!(
        struct_pos < impl_pos,
        "struct should appear before trait impl in:\n{code}"
    );
}
/// Builds an `HtmlVisitor` bridge config for the `find_bridge_*` tests.
///
/// Only the alias, parameter name, binding mode, and options type/field vary. Every other
/// registration-related setting is left unset or empty.
fn make_bridge(
    type_alias: Option<&str>,
    param_name: Option<&str>,
    bind_via: BridgeBinding,
    options_type: Option<&str>,
    options_field: Option<&str>,
) -> TraitBridgeConfig {
    let owned = |value: Option<&str>| value.map(String::from);
    TraitBridgeConfig {
        trait_name: String::from("HtmlVisitor"),
        bind_via,
        type_alias: owned(type_alias),
        param_name: owned(param_name),
        options_type: owned(options_type),
        options_field: owned(options_field),
        super_trait: None,
        registry_getter: None,
        register_fn: None,
        unregister_fn: None,
        clear_fn: None,
        register_extra_args: None,
        exclude_languages: Vec::new(),
    }
}
#[test]
/// In function-param mode, the bridge matches the `visitor` parameter at index 1.
fn find_bridge_param_returns_first_param_match_in_function_param_mode() {
    let html_param = make_param("html", TypeRef::String, true);
    let visitor_param = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html_param, visitor_param]);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    );
    let bridges = vec![bridge];
    let matched = find_bridge_param(&func, &bridges).expect("bridge match");
    assert_eq!(matched.0, 1);
}
#[test]
/// `find_bridge_param` ignores bridges bound via an options field.
fn find_bridge_param_skips_options_field_bridges() {
    let html_param = make_param("html", TypeRef::String, true);
    let visitor_param = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html_param, visitor_param]);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        Some("visitor"),
    );
    let bridges = vec![bridge];
    let found = find_bridge_param(&func, &bridges);
    assert!(
        found.is_none(),
        "bridges configured with bind_via=options_field must not be returned by find_bridge_param"
    );
}
#[test]
fn find_bridge_field_detects_field_via_alias() {
let opts_type = TypeDef {
name: "ConversionOptions".to_string(),
rust_path: "mylib::ConversionOptions".to_string(),
original_rust_path: String::new(),
fields: vec![
make_field("debug", TypeRef::Primitive(PrimitiveType::Bool)),
make_field(
"visitor",
TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string()))),
),
],
methods: vec![],
is_opaque: false,
is_clone: true,
is_copy: false,
doc: String::new(),
cfg: None,
is_trait: false,
has_default: true,
has_stripped_cfg_fields: false,
is_return_type: false,
serde_rename_all: None,
has_serde: false,
super_traits: vec![],
};
let func = make_func(
"convert",
vec![
make_param("html", TypeRef::String, true),
make_param(
"options",
TypeRef::Optional(Box::new(TypeRef::Named("ConversionOptions".to_string()))),
false,
),
],
);
let bridges = vec![make_bridge(
Some("VisitorHandle"),
Some("visitor"),
BridgeBinding::OptionsField,
Some("ConversionOptions"),
None,
)];
let m = find_bridge_field(&func, std::slice::from_ref(&opts_type), &bridges).expect("bridge field match");
assert_eq!(m.param_index, 1);
assert_eq!(m.param_name, "options");
assert_eq!(m.options_type, "ConversionOptions");
assert!(m.param_is_optional);
assert_eq!(m.field_name, "visitor");
}
#[test]
fn find_bridge_field_returns_none_for_function_param_bridge() {
let opts_type = TypeDef {
name: "ConversionOptions".to_string(),
rust_path: "mylib::ConversionOptions".to_string(),
original_rust_path: String::new(),
fields: vec![make_field(
"visitor",
TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string()))),
)],
methods: vec![],
is_opaque: false,
is_clone: true,
is_copy: false,
doc: String::new(),
cfg: None,
is_trait: false,
has_default: true,
has_stripped_cfg_fields: false,
is_return_type: false,
serde_rename_all: None,
has_serde: false,
super_traits: vec![],
};
let func = make_func(
"convert",
vec![make_param(
"options",
TypeRef::Named("ConversionOptions".to_string()),
false,
)],
);
let bridges = vec![make_bridge(
Some("VisitorHandle"),
Some("visitor"),
BridgeBinding::FunctionParam,
None,
None,
)];
assert!(find_bridge_field(&func, std::slice::from_ref(&opts_type), &bridges).is_none());
}
}