1use alef_core::config::{BridgeBinding, TraitBridgeConfig};
9use alef_core::ir::{FieldDef, FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12
/// Everything a language backend needs to generate a bridge for one Rust trait.
pub struct TraitBridgeSpec<'a> {
    /// IR definition of the trait being bridged; its `name`, `rust_path` and
    /// `methods` drive the generated wrapper and trait impl.
    pub trait_def: &'a TypeDef,
    /// Bridge configuration for this trait (super-trait, registration/unregistration/clear
    /// function names, registry getter, ...).
    pub bridge_config: &'a TraitBridgeConfig,
    /// Path of the core crate; used to qualify bare error and super-trait names.
    pub core_import: &'a str,
    /// Prefix prepended to the wrapper type name (e.g. `Py` + `OcrBackend` -> `PyOcrBackendBridge`).
    pub wrapper_prefix: &'a str,
    /// Map from IR type names to fully-qualified Rust paths, consulted when formatting types.
    pub type_paths: HashMap<String, String>,
    /// Error type for fallible methods: either a bare name (qualified with `core_import`)
    /// or an already-qualified path / generic type used verbatim.
    pub error_type: String,
    /// Template for building an error value; every `{msg}` is replaced by the message expression.
    pub error_constructor: String,
}
30
31impl<'a> TraitBridgeSpec<'a> {
32 pub fn error_path(&self) -> String {
39 if self.error_type.contains("::") || self.error_type.contains('<') {
40 self.error_type.clone()
41 } else {
42 format!("{}::{}", self.core_import, self.error_type)
43 }
44 }
45
46 pub fn make_error(&self, msg_expr: &str) -> String {
48 self.error_constructor.replace("{msg}", msg_expr)
49 }
50
51 pub fn wrapper_name(&self) -> String {
53 format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
54 }
55
56 pub fn trait_snake(&self) -> String {
58 self.trait_def.name.to_snake_case()
59 }
60
61 pub fn trait_path(&self) -> String {
63 self.trait_def.rust_path.replace('-', "_")
64 }
65
66 pub fn required_methods(&self) -> Vec<&'a MethodDef> {
68 self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
69 }
70
71 pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
73 self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
74 }
75}
76
77pub trait TraitBridgeGenerator {
83 fn foreign_object_type(&self) -> &str;
85
86 fn bridge_imports(&self) -> Vec<String>;
88
89 fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
94
95 fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;
99
100 fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;
105
106 fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;
112
113 fn gen_unregistration_fn(&self, _spec: &TraitBridgeSpec) -> String {
120 String::new()
121 }
122
123 fn gen_clear_fn(&self, _spec: &TraitBridgeSpec) -> String {
130 String::new()
131 }
132
133 fn async_trait_is_send(&self) -> bool {
138 true
139 }
140}
141
142pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
156 let wrapper = spec.wrapper_name();
157 let foreign_type = generator.foreign_object_type();
158
159 crate::template_env::render(
160 "generators/trait_bridge/wrapper_struct.jinja",
161 minijinja::context! {
162 wrapper_prefix => spec.wrapper_prefix,
163 trait_name => &spec.trait_def.name,
164 wrapper_name => wrapper,
165 foreign_type => foreign_type,
166 },
167 )
168}
169
170fn gen_bridge_debug_impl(spec: &TraitBridgeSpec) -> String {
176 let wrapper = spec.wrapper_name();
177 format!(
178 "impl std::fmt::Debug for {wrapper} {{\n fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{\n write!(f, \"{wrapper}\")\n }}\n}}"
179 )
180}
181
182pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
190 let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
191
192 let wrapper = spec.wrapper_name();
193 let core_import = spec.core_import;
194
195 let super_trait_path = if super_trait_name.contains("::") {
197 super_trait_name.to_string()
198 } else {
199 format!("{core_import}::{super_trait_name}")
200 };
201
202 let error_path = spec.error_path();
206
207 let version_method = MethodDef {
209 name: "version".to_string(),
210 params: vec![],
211 return_type: alef_core::ir::TypeRef::String,
212 is_async: false,
213 is_static: false,
214 error_type: None,
215 doc: String::new(),
216 receiver: Some(alef_core::ir::ReceiverKind::Ref),
217 sanitized: false,
218 trait_source: None,
219 returns_ref: false,
220 returns_cow: false,
221 return_newtype_wrapper: None,
222 has_default_impl: false,
223 };
224 let version_body = generator.gen_sync_method_body(&version_method, spec);
225
226 let init_method = MethodDef {
228 name: "initialize".to_string(),
229 params: vec![],
230 return_type: alef_core::ir::TypeRef::Unit,
231 is_async: false,
232 is_static: false,
233 error_type: Some(error_path.clone()),
234 doc: String::new(),
235 receiver: Some(alef_core::ir::ReceiverKind::Ref),
236 sanitized: false,
237 trait_source: None,
238 returns_ref: false,
239 returns_cow: false,
240 return_newtype_wrapper: None,
241 has_default_impl: true,
242 };
243 let init_body = generator.gen_sync_method_body(&init_method, spec);
244
245 let shutdown_method = MethodDef {
247 name: "shutdown".to_string(),
248 params: vec![],
249 return_type: alef_core::ir::TypeRef::Unit,
250 is_async: false,
251 is_static: false,
252 error_type: Some(error_path.clone()),
253 doc: String::new(),
254 receiver: Some(alef_core::ir::ReceiverKind::Ref),
255 sanitized: false,
256 trait_source: None,
257 returns_ref: false,
258 returns_cow: false,
259 return_newtype_wrapper: None,
260 has_default_impl: true,
261 };
262 let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
263
264 let version_lines: Vec<&str> = version_body.lines().collect();
266 let init_lines: Vec<&str> = init_body.lines().collect();
267 let shutdown_lines: Vec<&str> = shutdown_body.lines().collect();
268
269 Some(crate::template_env::render(
270 "generators/trait_bridge/plugin_impl.jinja",
271 minijinja::context! {
272 super_trait_path => super_trait_path,
273 wrapper_name => wrapper,
274 error_path => error_path,
275 version_lines => version_lines,
276 init_lines => init_lines,
277 shutdown_lines => shutdown_lines,
278 },
279 ))
280}
281
282pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
288 let wrapper = spec.wrapper_name();
289 let trait_path = spec.trait_path();
290
291 let has_async_methods = spec
294 .trait_def
295 .methods
296 .iter()
297 .any(|m| m.is_async && m.trait_source.is_none() && !m.has_default_impl);
298 let async_trait_is_send = generator.async_trait_is_send();
299
300 let own_methods: Vec<_> = spec
304 .trait_def
305 .methods
306 .iter()
307 .filter(|m| m.trait_source.is_none() && !m.has_default_impl)
308 .collect();
309
310 let mut methods_code = String::with_capacity(1024);
312 for (i, method) in own_methods.iter().enumerate() {
313 if i > 0 {
314 methods_code.push_str("\n\n");
315 }
316
317 let async_kw = if method.is_async { "async " } else { "" };
319 let receiver = match &method.receiver {
320 Some(alef_core::ir::ReceiverKind::Ref) => "&self",
321 Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
322 Some(alef_core::ir::ReceiverKind::Owned) => "self",
323 None => "",
324 };
325
326 let params: Vec<String> = method
328 .params
329 .iter()
330 .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
331 .collect();
332
333 let all_params = if receiver.is_empty() {
334 params.join(", ")
335 } else if params.is_empty() {
336 receiver.to_string()
337 } else {
338 format!("{}, {}", receiver, params.join(", "))
339 };
340
341 let error_override = method.error_type.as_ref().map(|_| spec.error_path());
346 let ret = format_return_type(
347 &method.return_type,
348 error_override.as_deref(),
349 &spec.type_paths,
350 method.returns_ref,
351 );
352
353 let body = if method.is_async {
355 generator.gen_async_method_body(method, spec)
356 } else {
357 generator.gen_sync_method_body(method, spec)
358 };
359
360 let indented_body = body
362 .lines()
363 .map(|line| format!(" {line}"))
364 .collect::<Vec<_>>()
365 .join("\n");
366
367 methods_code.push_str(&crate::template_env::render(
368 "generators/trait_bridge/trait_method.jinja",
369 minijinja::context! {
370 async_kw => async_kw,
371 method_name => &method.name,
372 all_params => all_params,
373 ret => ret,
374 indented_body => &indented_body,
375 },
376 ));
377 }
378
379 crate::template_env::render(
380 "generators/trait_bridge/trait_impl.jinja",
381 minijinja::context! {
382 has_async_methods => has_async_methods,
383 async_trait_is_send => async_trait_is_send,
384 trait_path => trait_path,
385 wrapper_name => wrapper,
386 methods_code => methods_code,
387 },
388 )
389}
390
391pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
398 spec.bridge_config.register_fn.as_deref()?;
399 Some(generator.gen_registration_fn(spec))
400}
401
402pub fn gen_bridge_unregistration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
409 spec.bridge_config.unregister_fn.as_deref()?;
410 let body = generator.gen_unregistration_fn(spec);
411 if body.is_empty() { None } else { Some(body) }
412}
413
414pub fn gen_bridge_clear_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
421 spec.bridge_config.clear_fn.as_deref()?;
422 let body = generator.gen_clear_fn(spec);
423 if body.is_empty() { None } else { Some(body) }
424}
425
426pub fn host_function_path(spec: &TraitBridgeSpec, fn_name: &str) -> String {
444 if let Some(getter) = spec.bridge_config.registry_getter.as_deref() {
445 let last = getter.rsplit("::").next().unwrap_or("");
446 if let Some(sub) = last.strip_prefix("get_").and_then(|s| s.strip_suffix("_registry")) {
447 let prefix_end = getter.len() - last.len();
448 let prefix = &getter[..prefix_end];
449 let prefix = prefix.trim_end_matches("registry::");
450 return format!("{prefix}{sub}::{fn_name}");
451 }
452 }
453 format!("{}::plugins::{}", spec.core_import, fn_name)
454}
455
/// Result of generating a complete trait bridge.
pub struct BridgeOutput {
    /// Import paths requested by the backend (`TraitBridgeGenerator::bridge_imports`).
    pub imports: Vec<String>,
    /// All generated items, separated by blank lines.
    pub code: String,
}
464
465pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
471 let imports = generator.bridge_imports();
472 let mut out = String::with_capacity(4096);
473
474 out.push_str(&gen_bridge_wrapper_struct(spec, generator));
476 out.push_str("\n\n");
477
478 out.push_str(&gen_bridge_debug_impl(spec));
480 out.push_str("\n\n");
481
482 out.push_str(&generator.gen_constructor(spec));
484 out.push_str("\n\n");
485
486 if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
488 out.push_str(&plugin_impl);
489 out.push_str("\n\n");
490 }
491
492 out.push_str(&gen_bridge_trait_impl(spec, generator));
494
495 if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
497 out.push_str("\n\n");
498 out.push_str(®_fn_code);
499 }
500
501 if let Some(unreg_fn_code) = gen_bridge_unregistration_fn(spec, generator) {
504 out.push_str("\n\n");
505 out.push_str(&unreg_fn_code);
506 }
507
508 if let Some(clear_fn_code) = gen_bridge_clear_fn(spec, generator) {
511 out.push_str("\n\n");
512 out.push_str(&clear_fn_code);
513 }
514
515 BridgeOutput { imports, code: out }
516}
517
518pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
527 use alef_core::ir::{PrimitiveType, TypeRef};
528 match ty {
529 TypeRef::Primitive(p) => match p {
530 PrimitiveType::Bool => "bool",
531 PrimitiveType::U8 => "u8",
532 PrimitiveType::U16 => "u16",
533 PrimitiveType::U32 => "u32",
534 PrimitiveType::U64 => "u64",
535 PrimitiveType::I8 => "i8",
536 PrimitiveType::I16 => "i16",
537 PrimitiveType::I32 => "i32",
538 PrimitiveType::I64 => "i64",
539 PrimitiveType::F32 => "f32",
540 PrimitiveType::F64 => "f64",
541 PrimitiveType::Usize => "usize",
542 PrimitiveType::Isize => "isize",
543 }
544 .to_string(),
545 TypeRef::String => "String".to_string(),
546 TypeRef::Char => "char".to_string(),
547 TypeRef::Bytes => "Vec<u8>".to_string(),
548 TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
549 TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
550 TypeRef::Map(k, v) => format!(
551 "std::collections::HashMap<{}, {}>",
552 format_type_ref(k, type_paths),
553 format_type_ref(v, type_paths)
554 ),
555 TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
556 TypeRef::Path => "std::path::PathBuf".to_string(),
557 TypeRef::Unit => "()".to_string(),
558 TypeRef::Json => "serde_json::Value".to_string(),
559 TypeRef::Duration => "std::time::Duration".to_string(),
560 }
561}
562
563pub fn format_return_type(
573 ty: &alef_core::ir::TypeRef,
574 error_type: Option<&str>,
575 type_paths: &HashMap<String, String>,
576 returns_ref: bool,
577) -> String {
578 let inner = if returns_ref {
579 if let alef_core::ir::TypeRef::Vec(elem) = ty {
581 let elem_str = match elem.as_ref() {
582 alef_core::ir::TypeRef::String => "&str".to_string(),
583 alef_core::ir::TypeRef::Bytes => "&[u8]".to_string(),
584 alef_core::ir::TypeRef::Named(name) => {
585 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
586 format!("&{qualified}")
587 }
588 other => format_type_ref(other, type_paths),
589 };
590 format!("&[{elem_str}]")
591 } else {
592 format_type_ref(ty, type_paths)
593 }
594 } else {
595 format_type_ref(ty, type_paths)
596 };
597 match error_type {
598 Some(err) => format!("std::result::Result<{inner}, {err}>"),
599 None => inner,
600 }
601}
602
603pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
615 use alef_core::ir::TypeRef;
616 let base = if param.is_ref {
617 let mutability = if param.is_mut { "mut " } else { "" };
618 match ¶m.ty {
619 TypeRef::String => format!("&{mutability}str"),
620 TypeRef::Bytes => format!("&{mutability}[u8]"),
621 TypeRef::Path => format!("&{mutability}std::path::Path"),
622 TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
623 TypeRef::Named(name) => {
624 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
625 format!("&{mutability}{qualified}")
626 }
627 TypeRef::Optional(inner) => {
628 let inner_type_str = match inner.as_ref() {
632 TypeRef::String => format!("&{mutability}str"),
633 TypeRef::Bytes => format!("&{mutability}[u8]"),
634 TypeRef::Path => format!("&{mutability}std::path::Path"),
635 TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
636 TypeRef::Named(name) => {
637 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
638 format!("&{mutability}{qualified}")
639 }
640 other => format_type_ref(other, type_paths),
642 };
643 return format!("Option<{inner_type_str}>");
645 }
646 other => format_type_ref(other, type_paths),
648 }
649 } else {
650 format_type_ref(¶m.ty, type_paths)
651 };
652
653 if param.optional {
657 format!("Option<{base}>")
658 } else {
659 base
660 }
661}
662
663pub fn prim(p: &PrimitiveType) -> &'static str {
669 use PrimitiveType::*;
670 match p {
671 Bool => "bool",
672 U8 => "u8",
673 U16 => "u16",
674 U32 => "u32",
675 U64 => "u64",
676 I8 => "i8",
677 I16 => "i16",
678 I32 => "i32",
679 I64 => "i64",
680 F32 => "f32",
681 F64 => "f64",
682 Usize => "usize",
683 Isize => "isize",
684 }
685}
686
687pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
691 match ty {
692 TypeRef::Bytes if is_ref => "&[u8]".into(),
693 TypeRef::Bytes => "Vec<u8>".into(),
694 TypeRef::String if is_ref => "&str".into(),
695 TypeRef::String => "String".into(),
696 TypeRef::Path if is_ref => "&std::path::Path".into(),
697 TypeRef::Path => "std::path::PathBuf".into(),
698 TypeRef::Named(n) => {
699 let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
700 if is_ref { format!("&{qualified}") } else { qualified }
701 }
702 TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
703 TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
704 TypeRef::Primitive(p) => prim(p).into(),
705 TypeRef::Unit => "()".into(),
706 TypeRef::Char => "char".into(),
707 TypeRef::Map(k, v) => format!(
708 "std::collections::HashMap<{}, {}>",
709 bridge_param_type(k, ci, false, tp),
710 bridge_param_type(v, ci, false, tp)
711 ),
712 TypeRef::Json => "serde_json::Value".into(),
713 TypeRef::Duration => "std::time::Duration".into(),
714 }
715}
716
717pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
723 if optional && matches!(ty, TypeRef::String) && is_ref {
724 return "Option<&str>".to_string();
725 }
726 if is_ref {
727 if let TypeRef::Vec(inner) = ty {
728 let inner_str = bridge_param_type(inner, "", false, tp);
729 return format!("&[{inner_str}]");
730 }
731 }
732 bridge_param_type(ty, "", is_ref, tp)
733}
734
735pub fn find_bridge_param<'a>(
742 func: &FunctionDef,
743 bridges: &'a [TraitBridgeConfig],
744) -> Option<(usize, &'a TraitBridgeConfig)> {
745 for (idx, param) in func.params.iter().enumerate() {
746 let named = match ¶m.ty {
747 TypeRef::Named(n) => Some(n.as_str()),
748 TypeRef::Optional(inner) => {
749 if let TypeRef::Named(n) = inner.as_ref() {
750 Some(n.as_str())
751 } else {
752 None
753 }
754 }
755 _ => None,
756 };
757 for bridge in bridges {
758 if bridge.bind_via != BridgeBinding::FunctionParam {
759 continue;
760 }
761 if let Some(type_name) = named {
762 if bridge.type_alias.as_deref() == Some(type_name) {
763 return Some((idx, bridge));
764 }
765 }
766 if bridge.param_name.as_deref() == Some(param.name.as_str()) {
767 return Some((idx, bridge));
768 }
769 }
770 }
771 None
772}
773
#[derive(Debug, Clone)]
/// A function parameter whose options struct has a field bound to an `OptionsField` bridge.
pub struct BridgeFieldMatch<'a> {
    /// Position of the options parameter in the function's parameter list.
    pub param_index: usize,
    /// Name of the options parameter.
    pub param_name: String,
    /// IR name of the options type.
    pub options_type: String,
    /// `true` when the parameter is typed `Option<OptionsType>` rather than `OptionsType`.
    pub param_is_optional: bool,
    /// Name of the matched field inside the options type.
    pub field_name: String,
    /// IR definition of the matched field.
    pub field: &'a FieldDef,
    /// Bridge configuration that matched the field.
    pub bridge: &'a TraitBridgeConfig,
}
793
794pub fn find_bridge_field<'a>(
805 func: &FunctionDef,
806 types: &'a [TypeDef],
807 bridges: &'a [TraitBridgeConfig],
808) -> Option<BridgeFieldMatch<'a>> {
809 fn unwrap_named(ty: &TypeRef) -> Option<(&str, bool)> {
810 match ty {
811 TypeRef::Named(n) => Some((n.as_str(), false)),
812 TypeRef::Optional(inner) => {
813 if let TypeRef::Named(n) = inner.as_ref() {
814 Some((n.as_str(), true))
815 } else {
816 None
817 }
818 }
819 _ => None,
820 }
821 }
822
823 for (idx, param) in func.params.iter().enumerate() {
824 let Some((type_name, is_optional)) = unwrap_named(¶m.ty) else {
825 continue;
826 };
827 let Some(type_def) = types.iter().find(|t| t.name == type_name) else {
828 continue;
829 };
830 for bridge in bridges {
831 if bridge.bind_via != BridgeBinding::OptionsField {
832 continue;
833 }
834 if bridge.options_type.as_deref() != Some(type_name) {
835 continue;
836 }
837 let field_name = bridge.resolved_options_field();
838 for field in &type_def.fields {
839 let matches_name = field_name.is_some_and(|n| field.name == n);
840 let matches_alias = bridge
841 .type_alias
842 .as_deref()
843 .is_some_and(|alias| field_type_matches_alias(&field.ty, alias));
844 if matches_name || matches_alias {
845 return Some(BridgeFieldMatch {
846 param_index: idx,
847 param_name: param.name.clone(),
848 options_type: type_name.to_string(),
849 param_is_optional: is_optional,
850 field_name: field.name.clone(),
851 field,
852 bridge,
853 });
854 }
855 }
856 }
857 }
858 None
859}
860
861fn field_type_matches_alias(field_ty: &TypeRef, alias: &str) -> bool {
864 match field_ty {
865 TypeRef::Named(n) => n == alias,
866 TypeRef::Optional(inner) | TypeRef::Vec(inner) => field_type_matches_alias(inner, alias),
867 _ => false,
868 }
869}
870
/// Convert `snake_case` to `lowerCamelCase`.
///
/// Underscores are removed, and the first character after each run of
/// underscores is ASCII-uppercased. The leading segment is kept verbatim, so
/// `_foo` becomes `Foo` and `a__b` becomes `aB`.
pub fn to_camel_case(s: &str) -> String {
    let mut segments = s.split('_');
    let mut camel = String::with_capacity(s.len());
    camel.push_str(segments.next().unwrap_or(""));
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            camel.push(first.to_ascii_uppercase());
            camel.push_str(chars.as_str());
        }
    }
    camel
}
887
888#[cfg(test)]
889mod tests {
890 use super::*;
891 use alef_core::config::TraitBridgeConfig;
892 use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
893
894 fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
899 TraitBridgeConfig {
900 trait_name: "OcrBackend".to_string(),
901 super_trait: super_trait.map(str::to_string),
902 registry_getter: None,
903 register_fn: register_fn.map(str::to_string),
904 unregister_fn: None,
905 clear_fn: None,
906 type_alias: None,
907 param_name: None,
908 register_extra_args: None,
909 exclude_languages: Vec::new(),
910 bind_via: BridgeBinding::FunctionParam,
911 options_type: None,
912 options_field: None,
913 }
914 }
915
916 fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
917 TypeDef {
918 name: name.to_string(),
919 rust_path: rust_path.to_string(),
920 original_rust_path: rust_path.to_string(),
921 fields: vec![],
922 methods,
923 is_opaque: true,
924 is_clone: false,
925 is_copy: false,
926 doc: String::new(),
927 cfg: None,
928 is_trait: true,
929 has_default: false,
930 has_stripped_cfg_fields: false,
931 is_return_type: false,
932 serde_rename_all: None,
933 has_serde: false,
934 super_traits: vec![],
935 }
936 }
937
938 fn make_method(
939 name: &str,
940 params: Vec<ParamDef>,
941 return_type: TypeRef,
942 is_async: bool,
943 has_default_impl: bool,
944 trait_source: Option<&str>,
945 error_type: Option<&str>,
946 ) -> MethodDef {
947 MethodDef {
948 name: name.to_string(),
949 params,
950 return_type,
951 is_async,
952 is_static: false,
953 error_type: error_type.map(str::to_string),
954 doc: String::new(),
955 receiver: Some(ReceiverKind::Ref),
956 sanitized: false,
957 trait_source: trait_source.map(str::to_string),
958 returns_ref: false,
959 returns_cow: false,
960 return_newtype_wrapper: None,
961 has_default_impl,
962 }
963 }
964
965 fn make_func(name: &str, params: Vec<ParamDef>) -> FunctionDef {
966 FunctionDef {
967 name: name.to_string(),
968 rust_path: format!("mylib::{name}"),
969 original_rust_path: String::new(),
970 params,
971 return_type: TypeRef::Unit,
972 is_async: false,
973 error_type: None,
974 doc: String::new(),
975 cfg: None,
976 sanitized: false,
977 return_sanitized: false,
978 returns_ref: false,
979 returns_cow: false,
980 return_newtype_wrapper: None,
981 }
982 }
983
984 fn make_field(name: &str, ty: TypeRef) -> FieldDef {
985 FieldDef {
986 name: name.to_string(),
987 ty,
988 optional: false,
989 default: None,
990 doc: String::new(),
991 sanitized: false,
992 is_boxed: false,
993 type_rust_path: None,
994 cfg: None,
995 typed_default: None,
996 core_wrapper: Default::default(),
997 vec_inner_core_wrapper: Default::default(),
998 newtype_wrapper: None,
999 serde_rename: None,
1000 serde_flatten: false,
1001 }
1002 }
1003
1004 fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
1005 ParamDef {
1006 name: name.to_string(),
1007 ty,
1008 optional: false,
1009 default: None,
1010 sanitized: false,
1011 typed_default: None,
1012 is_ref,
1013 is_mut: false,
1014 newtype_wrapper: None,
1015 original_type: None,
1016 }
1017 }
1018
1019 fn make_spec<'a>(
1020 trait_def: &'a TypeDef,
1021 bridge_config: &'a TraitBridgeConfig,
1022 wrapper_prefix: &'a str,
1023 type_paths: HashMap<String, String>,
1024 ) -> TraitBridgeSpec<'a> {
1025 TraitBridgeSpec {
1026 trait_def,
1027 bridge_config,
1028 core_import: "mylib",
1029 wrapper_prefix,
1030 type_paths,
1031 error_type: "MyError".to_string(),
1032 error_constructor: "MyError::from({msg})".to_string(),
1033 }
1034 }
1035
1036 struct MockBridgeGenerator;
1041
1042 impl TraitBridgeGenerator for MockBridgeGenerator {
1043 fn foreign_object_type(&self) -> &str {
1044 "Py<PyAny>"
1045 }
1046
1047 fn bridge_imports(&self) -> Vec<String> {
1048 vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
1049 }
1050
1051 fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1052 format!("// sync body for {}", method.name)
1053 }
1054
1055 fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1056 format!("// async body for {}", method.name)
1057 }
1058
1059 fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
1060 format!(
1061 "impl {} {{\n pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
1062 spec.wrapper_name()
1063 )
1064 }
1065
1066 fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
1067 let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
1068 format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
1069 }
1070 }
1071
1072 #[test]
1077 fn test_wrapper_name() {
1078 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1079 let config = make_trait_bridge_config(None, None);
1080 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1081 assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
1082 }
1083
1084 #[test]
1085 fn test_trait_snake() {
1086 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1087 let config = make_trait_bridge_config(None, None);
1088 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1089 assert_eq!(spec.trait_snake(), "ocr_backend");
1090 }
1091
1092 #[test]
1093 fn test_trait_path_replaces_hyphens() {
1094 let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
1095 let config = make_trait_bridge_config(None, None);
1096 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1097 assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
1098 }
1099
1100 #[test]
1101 fn test_required_methods_filters_no_default_impl() {
1102 let methods = vec![
1103 make_method("process", vec![], TypeRef::String, false, false, None, None),
1104 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1105 make_method("detect", vec![], TypeRef::String, false, false, None, None),
1106 ];
1107 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1108 let config = make_trait_bridge_config(None, None);
1109 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1110 let required = spec.required_methods();
1111 assert_eq!(required.len(), 2);
1112 assert!(required.iter().any(|m| m.name == "process"));
1113 assert!(required.iter().any(|m| m.name == "detect"));
1114 }
1115
1116 #[test]
1117 fn test_optional_methods_filters_has_default_impl() {
1118 let methods = vec![
1119 make_method("process", vec![], TypeRef::String, false, false, None, None),
1120 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1121 make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
1122 ];
1123 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1124 let config = make_trait_bridge_config(None, None);
1125 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1126 let optional = spec.optional_methods();
1127 assert_eq!(optional.len(), 2);
1128 assert!(optional.iter().any(|m| m.name == "initialize"));
1129 assert!(optional.iter().any(|m| m.name == "shutdown"));
1130 }
1131
1132 #[test]
1133 fn test_error_path() {
1134 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1135 let config = make_trait_bridge_config(None, None);
1136 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1137 assert_eq!(spec.error_path(), "mylib::MyError");
1138 }
1139
1140 #[test]
1145 fn test_format_type_ref_primitives() {
1146 let paths = HashMap::new();
1147 let cases: Vec<(TypeRef, &str)> = vec![
1148 (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
1149 (TypeRef::Primitive(PrimitiveType::U8), "u8"),
1150 (TypeRef::Primitive(PrimitiveType::U16), "u16"),
1151 (TypeRef::Primitive(PrimitiveType::U32), "u32"),
1152 (TypeRef::Primitive(PrimitiveType::U64), "u64"),
1153 (TypeRef::Primitive(PrimitiveType::I8), "i8"),
1154 (TypeRef::Primitive(PrimitiveType::I16), "i16"),
1155 (TypeRef::Primitive(PrimitiveType::I32), "i32"),
1156 (TypeRef::Primitive(PrimitiveType::I64), "i64"),
1157 (TypeRef::Primitive(PrimitiveType::F32), "f32"),
1158 (TypeRef::Primitive(PrimitiveType::F64), "f64"),
1159 (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
1160 (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
1161 ];
1162 for (ty, expected) in cases {
1163 assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
1164 }
1165 }
1166
1167 #[test]
1168 fn test_format_type_ref_string() {
1169 assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
1170 }
1171
1172 #[test]
1173 fn test_format_type_ref_char() {
1174 assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
1175 }
1176
1177 #[test]
1178 fn test_format_type_ref_bytes() {
1179 assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
1180 }
1181
1182 #[test]
1183 fn test_format_type_ref_path() {
1184 assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
1185 }
1186
1187 #[test]
1188 fn test_format_type_ref_unit() {
1189 assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
1190 }
1191
1192 #[test]
1193 fn test_format_type_ref_json() {
1194 assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
1195 }
1196
1197 #[test]
1198 fn test_format_type_ref_duration() {
1199 assert_eq!(
1200 format_type_ref(&TypeRef::Duration, &HashMap::new()),
1201 "std::time::Duration"
1202 );
1203 }
1204
1205 #[test]
1206 fn test_format_type_ref_optional() {
1207 let ty = TypeRef::Optional(Box::new(TypeRef::String));
1208 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
1209 }
1210
1211 #[test]
1212 fn test_format_type_ref_optional_nested() {
1213 let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
1214 PrimitiveType::U32,
1215 )))));
1216 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
1217 }
1218
1219 #[test]
1220 fn test_format_type_ref_vec() {
1221 let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
1222 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
1223 }
1224
1225 #[test]
1226 fn test_format_type_ref_vec_nested() {
1227 let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
1228 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
1229 }
1230
1231 #[test]
1232 fn test_format_type_ref_map() {
1233 let ty = TypeRef::Map(
1234 Box::new(TypeRef::String),
1235 Box::new(TypeRef::Primitive(PrimitiveType::I64)),
1236 );
1237 assert_eq!(
1238 format_type_ref(&ty, &HashMap::new()),
1239 "std::collections::HashMap<String, i64>"
1240 );
1241 }
1242
1243 #[test]
1244 fn test_format_type_ref_map_nested_value() {
1245 let ty = TypeRef::Map(
1246 Box::new(TypeRef::String),
1247 Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
1248 );
1249 assert_eq!(
1250 format_type_ref(&ty, &HashMap::new()),
1251 "std::collections::HashMap<String, Vec<String>>"
1252 );
1253 }
1254
1255 #[test]
1256 fn test_format_type_ref_named_without_type_paths() {
1257 let ty = TypeRef::Named("Config".to_string());
1258 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
1259 }
1260
1261 #[test]
1262 fn test_format_type_ref_named_with_type_paths() {
1263 let ty = TypeRef::Named("Config".to_string());
1264 let mut paths = HashMap::new();
1265 paths.insert("Config".to_string(), "mylib::Config".to_string());
1266 assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
1267 }
1268
1269 #[test]
1270 fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
1271 let ty = TypeRef::Named("Unknown".to_string());
1272 let mut paths = HashMap::new();
1273 paths.insert("Other".to_string(), "mylib::Other".to_string());
1274 assert_eq!(format_type_ref(&ty, &paths), "Unknown");
1275 }
1276
1277 #[test]
1282 fn test_format_param_type_string_ref() {
1283 let param = make_param("input", TypeRef::String, true);
1284 assert_eq!(format_param_type(¶m, &HashMap::new()), "&str");
1285 }
1286
/// An owned `String` parameter stays `String`.
#[test]
fn test_format_param_type_string_owned() {
    let param = make_param("input", TypeRef::String, false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "String");
}
1292
/// A borrowed byte buffer is rendered as a `&[u8]` slice.
#[test]
fn test_format_param_type_bytes_ref() {
    let param = make_param("data", TypeRef::Bytes, true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&[u8]");
}
1298
/// An owned byte buffer is rendered as `Vec<u8>`.
#[test]
fn test_format_param_type_bytes_owned() {
    let param = make_param("data", TypeRef::Bytes, false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<u8>");
}
1304
/// A borrowed path parameter is rendered as `&std::path::Path`.
#[test]
fn test_format_param_type_path_ref() {
    let param = make_param("path", TypeRef::Path, true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&std::path::Path");
}
1310
/// An owned path parameter is rendered as `std::path::PathBuf`.
#[test]
fn test_format_param_type_path_owned() {
    let param = make_param("path", TypeRef::Path, false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
}
1316
/// A borrowed `Vec<T>` parameter is rendered as the slice `&[T]`.
#[test]
fn test_format_param_type_vec_ref() {
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
}
1322
/// An owned `Vec<T>` parameter stays `Vec<T>`.
#[test]
fn test_format_param_type_vec_owned() {
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
}
1328
/// A borrowed named type uses its mapped qualified path behind `&`.
#[test]
fn test_format_param_type_named_ref_with_type_paths() {
    let mut paths = HashMap::new();
    paths.insert("Config".to_string(), "mylib::Config".to_string());
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
}
1336
/// Without a mapping, a borrowed named type is `&` plus the bare name.
#[test]
fn test_format_param_type_named_ref_without_type_paths() {
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
}
1342
/// Primitives ignore the "borrowed" flag and are passed by value.
#[test]
fn test_format_param_type_primitive_ref_passes_by_value() {
    let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
}
1349
/// The unit type ignores the "borrowed" flag and renders as `()`.
#[test]
fn test_format_param_type_unit_ref_passes_by_value() {
    let param = make_param("nothing", TypeRef::Unit, true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "()");
}
1355
/// Without an error type, the return type is emitted unwrapped.
#[test]
fn test_format_return_type_without_error() {
    let no_paths = HashMap::new();
    let rendered = format_return_type(&TypeRef::String, None, &no_paths, false);
    assert_eq!(rendered, "String");
}
1365
/// An error type wraps the return in a fully qualified `std::result::Result`.
#[test]
fn test_format_return_type_with_error() {
    let error = Some("MyError");
    let rendered = format_return_type(&TypeRef::String, error, &HashMap::new(), false);
    assert_eq!(rendered, "std::result::Result<String, MyError>");
}
1371
/// A unit return with an error type becomes `Result<(), E>`, keeping the error text verbatim.
#[test]
fn test_format_return_type_unit_with_error() {
    let boxed_error = Some("Box<dyn std::error::Error>");
    let no_paths = HashMap::new();
    let rendered = format_return_type(&TypeRef::Unit, boxed_error, &no_paths, false);
    assert_eq!(rendered, "std::result::Result<(), Box<dyn std::error::Error>>");
}
1382
/// Path mapping applies to the success type while the error type is passed through.
#[test]
fn test_format_return_type_named_with_type_paths_and_error() {
    let paths = HashMap::from([("Output".to_string(), "mylib::Output".to_string())]);
    let output_ty = TypeRef::Named("Output".to_string());
    let rendered = format_return_type(&output_ty, Some("mylib::MyError"), &paths, false);
    assert_eq!(rendered, "std::result::Result<mylib::Output, mylib::MyError>");
}
1395
/// With `returns_ref` set, `Vec<String>` becomes a borrowed slice of `&str`.
#[test]
fn test_format_return_type_vec_string_with_returns_ref() {
    let strings = TypeRef::Vec(Box::new(TypeRef::String));
    let returns_ref = true;
    let rendered = format_return_type(&strings, None, &HashMap::new(), returns_ref);
    assert_eq!(rendered, "&[&str]", "Vec<String> + returns_ref must yield &[&str]");
}
1404
/// Without `returns_ref`, `Vec<String>` is left untouched.
#[test]
fn test_format_return_type_vec_no_returns_ref_unchanged() {
    let strings = TypeRef::Vec(Box::new(TypeRef::String));
    let returns_ref = false;
    let rendered = format_return_type(&strings, None, &HashMap::new(), returns_ref);
    assert_eq!(
        rendered, "Vec<String>",
        "Vec<String> without returns_ref must stay Vec<String>"
    );
}
1414
/// The wrapper struct is named `<prefix><Trait>Bridge`.
#[test]
fn test_gen_bridge_wrapper_struct_contains_struct_name() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&spec, &MockBridgeGenerator);
    assert!(
        result.contains("pub struct PyOcrBackendBridge"),
        "missing struct declaration in:\n{result}"
    );
}
1431
/// The wrapper holds the foreign object in an `inner` field of the generator's object type.
#[test]
fn test_gen_bridge_wrapper_struct_contains_inner_field() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&spec, &MockBridgeGenerator);
    let has_inner = result.contains("inner: Py<PyAny>");
    assert!(has_inner, "missing inner field in:\n{result}");
}
1441
/// The wrapper declares a `cached_name: String` field.
#[test]
fn test_gen_bridge_wrapper_struct_contains_cached_name() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&spec, &MockBridgeGenerator);
    let has_cached_name = result.contains("cached_name: String");
    assert!(has_cached_name, "missing cached_name field in:\n{result}");
}
1454
/// No super-trait configured means no plugin impl is produced.
#[test]
fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_none());
}
1467
/// Configuring a super-trait produces a plugin impl.
#[test]
fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_some());
}
1476
/// A bare super-trait name is qualified with the core crate import path.
#[test]
fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let qualified = result.contains("impl mylib::Plugin for PyOcrBackendBridge");
    assert!(qualified, "missing qualified super-trait path in:\n{result}");
}
1489
/// A super-trait that already contains a path is used as-is, not re-prefixed.
#[test]
fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
    let config = make_trait_bridge_config(Some("other_crate::Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let kept_path = result.contains("impl other_crate::Plugin for PyOcrBackendBridge");
    assert!(kept_path, "wrong super-trait path in:\n{result}");
}
1502
/// The plugin impl's `name()` is backed by the `cached_name` field.
#[test]
fn test_gen_bridge_plugin_impl_contains_name_fn() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_name_fn = result.contains("fn name(");
    let uses_cache = result.contains("cached_name");
    assert!(
        has_name_fn && uses_cache,
        "missing name() using cached_name in:\n{result}"
    );
}
1515
/// The plugin impl provides `version()`.
#[test]
fn test_gen_bridge_plugin_impl_contains_version_fn() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_version = result.contains("fn version(");
    assert!(has_version, "missing version() in:\n{result}");
}
1525
/// The plugin impl provides `initialize()`.
#[test]
fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_initialize = result.contains("fn initialize(");
    assert!(has_initialize, "missing initialize() in:\n{result}");
}
1535
/// The plugin impl provides `shutdown()`.
#[test]
fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator).unwrap();
    let has_shutdown = result.contains("fn shutdown(");
    assert!(has_shutdown, "missing shutdown() in:\n{result}");
}
1545
/// The trait impl header names the full trait path and the wrapper type.
#[test]
fn test_gen_bridge_trait_impl_includes_impl_header() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_header = result.contains("impl mylib::OcrBackend for PyOcrBackendBridge");
    assert!(has_header, "missing impl header in:\n{result}");
}
1562
/// Each trait method gets a signature in the generated impl.
#[test]
fn test_gen_bridge_trait_impl_includes_method_signatures() {
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_signature = result.contains("fn process(");
    assert!(has_signature, "missing method signature in:\n{result}");
}
1581
/// Sync method bodies come from the generator's `gen_sync_method_body`.
#[test]
fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_sync_body = result.contains("// sync body for process");
    assert!(has_sync_body, "missing sync method body in:\n{result}");
}
1603
/// Async methods get an `async fn` signature and the generator's async body.
#[test]
fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
    let is_async = true;
    let method = make_method("process_async", vec![], TypeRef::String, is_async, false, None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![method]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let has_async_body = result.contains("// async body for process_async");
    assert!(has_async_body, "missing async method body in:\n{result}");

    let has_async_sig = result.contains("async fn process_async(");
    assert!(
        has_async_sig,
        "missing async keyword in method signature in:\n{result}"
    );
}
1629
/// Methods that originate from another trait are excluded from this trait's impl.
#[test]
fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
    let own = make_method("own_method", vec![], TypeRef::String, false, false, None, None);
    let inherited = make_method(
        "inherited_method",
        vec![],
        TypeRef::String,
        false,
        false,
        Some("other_crate::OtherTrait"),
        None,
    );
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![own, inherited]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let own_present = result.contains("fn own_method(");
    assert!(own_present, "own method should be present in:\n{result}");

    let inherited_present = result.contains("fn inherited_method(");
    assert!(
        !inherited_present,
        "inherited method should be filtered out in:\n{result}"
    );
}
1659
/// Parameters are rendered with `format_param_type` rules (borrowed `String` → `&str`).
#[test]
fn test_gen_bridge_trait_impl_method_with_params() {
    let input = make_param("input", TypeRef::String, true);
    let count = make_param("count", TypeRef::Primitive(PrimitiveType::U32), false);
    let process = make_method("process", vec![input, count], TypeRef::String, false, false, None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let has_str_param = result.contains("input: &str");
    assert!(has_str_param, "missing &str param in:\n{result}");
    let has_u32_param = result.contains("count: u32");
    assert!(has_u32_param, "missing u32 param in:\n{result}");
}
1683
/// A method error type is qualified with the core import path inside `std::result::Result`.
#[test]
fn test_gen_bridge_trait_impl_return_type_with_error() {
    let fallible = make_method("process", vec![], TypeRef::String, false, false, None, Some("MyError"));
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![fallible]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_result_ty = result.contains("-> std::result::Result<String, mylib::MyError>");
    assert!(
        has_result_ty,
        "missing std::result::Result return type in:\n{result}"
    );
}
1705
/// With no `register_fn` configured, no registration function is emitted.
#[test]
fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let registration = gen_bridge_registration_fn(&spec, &MockBridgeGenerator);
    assert!(registration.is_none());
}
1718
/// A configured `register_fn` yields a registration function mentioning that name.
#[test]
fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
    let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let result = gen_bridge_registration_fn(&spec, &MockBridgeGenerator);
    assert!(result.is_some());
    let code = result.unwrap();
    let mentions_fn = code.contains("register_ocr_backend");
    assert!(mentions_fn, "missing register fn name in:\n{code}");
}
1733
/// `gen_bridge_all` forwards the generator's bridge imports.
#[test]
fn test_gen_bridge_all_includes_imports() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let has_import = |wanted: &str| output.imports.iter().any(|import| import == wanted);
    assert!(has_import("pyo3::prelude::*"));
    assert!(has_import("pyo3::types::PyString"));
}
1748
/// The combined output contains the wrapper struct declaration.
#[test]
fn test_gen_bridge_all_includes_wrapper_struct() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    assert!(
        code.contains("pub struct PyOcrBackendBridge"),
        "missing struct in:\n{}",
        code
    );
}
1762
/// The combined output contains the generator's constructor.
#[test]
fn test_gen_bridge_all_includes_constructor() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    assert!(code.contains("pub fn new("), "missing constructor in:\n{}", code);
}
1776
/// The combined output contains the trait impl for the wrapper.
#[test]
fn test_gen_bridge_all_includes_trait_impl() {
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
        "missing trait impl in:\n{}",
        code
    );
}
1799
/// With a super-trait set, the combined output includes the plugin impl.
#[test]
fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
    let config = make_trait_bridge_config(Some("Plugin"), None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
        "missing plugin impl in:\n{}",
        code
    );
}
1813
/// Without a configured super-trait, no plugin impl (and therefore no `name()` method) is emitted.
#[test]
fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    // The wrapper struct always declares a `cached_name` field, so its presence
    // cannot tell us anything about the plugin impl. `fn name(` is the marker
    // that only the plugin impl emits.
    assert!(
        !output.code.contains("fn name("),
        "unexpected plugin impl present without super_trait:\n{}",
        output.code
    );
}
1826
/// A configured `register_fn` appears in the combined output.
#[test]
fn test_gen_bridge_all_includes_registration_fn_when_configured() {
    let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    assert!(
        code.contains("register_ocr_backend"),
        "missing registration fn in:\n{}",
        code
    );
}
1840
/// With no `register_fn`, the combined output contains no registration function.
#[test]
fn test_gen_bridge_all_no_registration_fn_when_absent() {
    let config = make_trait_bridge_config(None, None);
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let mock = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &mock);
    let code = &output.code;
    let mentions_register = code.contains("register_ocr_backend");
    assert!(
        !mentions_register,
        "unexpected registration fn present:\n{}",
        code
    );
}
1854
/// The wrapper struct declaration precedes its trait impl in the combined output.
#[test]
fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    let code = &output.code;
    // Report which section is missing, with the generated code, instead of a bare
    // `called Option::unwrap() on a None value` panic.
    let struct_pos = code
        .find("pub struct PyOcrBackendBridge")
        .unwrap_or_else(|| panic!("missing struct declaration in:\n{code}"));
    let impl_pos = code
        .find("impl mylib::OcrBackend for PyOcrBackendBridge")
        .unwrap_or_else(|| panic!("missing trait impl in:\n{code}"));
    assert!(struct_pos < impl_pos, "struct should appear before trait impl");
}
1869
1870 fn make_bridge(
1875 type_alias: Option<&str>,
1876 param_name: Option<&str>,
1877 bind_via: BridgeBinding,
1878 options_type: Option<&str>,
1879 options_field: Option<&str>,
1880 ) -> TraitBridgeConfig {
1881 TraitBridgeConfig {
1882 trait_name: "HtmlVisitor".to_string(),
1883 super_trait: None,
1884 registry_getter: None,
1885 register_fn: None,
1886 unregister_fn: None,
1887 clear_fn: None,
1888 type_alias: type_alias.map(str::to_string),
1889 param_name: param_name.map(str::to_string),
1890 register_extra_args: None,
1891 exclude_languages: vec![],
1892 bind_via,
1893 options_type: options_type.map(str::to_string),
1894 options_field: options_field.map(str::to_string),
1895 }
1896 }
1897
/// In function-param mode, the bridge is matched to the visitor parameter at index 1.
#[test]
fn find_bridge_param_returns_first_param_match_in_function_param_mode() {
    let html = make_param("html", TypeRef::String, true);
    let visitor = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html, visitor]);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    );
    let bridges = vec![bridge];
    let matched = find_bridge_param(&func, &bridges).expect("bridge match");
    assert_eq!(matched.0, 1);
}
1917
/// Bridges bound via an options field are not returned by `find_bridge_param`.
#[test]
fn find_bridge_param_skips_options_field_bridges() {
    let html = make_param("html", TypeRef::String, true);
    let visitor = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html, visitor]);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        Some("visitor"),
    );
    let bridges = vec![bridge];
    let found = find_bridge_param(&func, &bridges);
    assert!(
        found.is_none(),
        "bridges configured with bind_via=options_field must not be returned by find_bridge_param"
    );
}
1939
/// An options-field bridge with no explicit field name finds the field whose type matches the alias.
#[test]
fn find_bridge_field_detects_field_via_alias() {
    let debug_field = make_field("debug", TypeRef::Primitive(PrimitiveType::Bool));
    let visitor_field = make_field(
        "visitor",
        TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string()))),
    );
    let opts_type = TypeDef {
        name: "ConversionOptions".to_string(),
        rust_path: "mylib::ConversionOptions".to_string(),
        original_rust_path: String::new(),
        fields: vec![debug_field, visitor_field],
        methods: vec![],
        doc: String::new(),
        cfg: None,
        super_traits: vec![],
        serde_rename_all: None,
        is_opaque: false,
        is_clone: true,
        is_copy: false,
        is_trait: false,
        is_return_type: false,
        has_default: true,
        has_serde: false,
        has_stripped_cfg_fields: false,
    };
    let options_param = make_param(
        "options",
        TypeRef::Optional(Box::new(TypeRef::Named("ConversionOptions".to_string()))),
        false,
    );
    let func = make_func(
        "convert",
        vec![make_param("html", TypeRef::String, true), options_param],
    );
    let bridges = vec![make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        None,
    )];
    let known_types = [opts_type];
    let m = find_bridge_field(&func, &known_types, &bridges).expect("bridge field match");
    assert_eq!(m.param_index, 1);
    assert_eq!(m.param_name, "options");
    assert_eq!(m.options_type, "ConversionOptions");
    assert!(m.param_is_optional);
    assert_eq!(m.field_name, "visitor");
}
1992
/// A function-param bridge is never reported as an options-field match.
#[test]
fn find_bridge_field_returns_none_for_function_param_bridge() {
    let visitor_field = make_field(
        "visitor",
        TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string()))),
    );
    let opts_type = TypeDef {
        name: "ConversionOptions".to_string(),
        rust_path: "mylib::ConversionOptions".to_string(),
        original_rust_path: String::new(),
        fields: vec![visitor_field],
        methods: vec![],
        doc: String::new(),
        cfg: None,
        super_traits: vec![],
        serde_rename_all: None,
        is_opaque: false,
        is_clone: true,
        is_copy: false,
        is_trait: false,
        is_return_type: false,
        has_default: true,
        has_serde: false,
        has_stripped_cfg_fields: false,
    };
    let options_param = make_param("options", TypeRef::Named("ConversionOptions".to_string()), false);
    let func = make_func("convert", vec![options_param]);
    let bridges = vec![make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    )];
    let known_types = [opts_type];
    let found = find_bridge_field(&func, &known_types, &bridges);
    assert!(found.is_none());
}
2034}