1use alef_core::config::{BridgeBinding, TraitBridgeConfig};
9use alef_core::ir::{FieldDef, FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12
/// Everything a [`TraitBridgeGenerator`] needs to emit a bridge for one trait.
pub struct TraitBridgeSpec<'a> {
    /// IR definition of the trait being bridged (name, `rust_path`, methods).
    pub trait_def: &'a TypeDef,
    /// User configuration for this bridge (super trait, register/unregister/clear fns, ...).
    pub bridge_config: &'a TraitBridgeConfig,
    /// Path prefix of the core crate, used to qualify bare names such as the
    /// error type and the super trait.
    pub core_import: &'a str,
    /// Prefix prepended to `<TraitName>Bridge` to form the wrapper struct name.
    pub wrapper_prefix: &'a str,
    /// Maps IR type names to the fully-qualified paths used in generated signatures.
    pub type_paths: HashMap<String, String>,
    /// Error type name or path; bare names are qualified via `error_path()`.
    pub error_type: String,
    /// Error construction template; every `{msg}` is replaced by `make_error()`.
    pub error_constructor: String,
}
30
31impl<'a> TraitBridgeSpec<'a> {
32 pub fn error_path(&self) -> String {
39 if self.error_type.contains("::") || self.error_type.contains('<') {
40 self.error_type.clone()
41 } else {
42 format!("{}::{}", self.core_import, self.error_type)
43 }
44 }
45
46 pub fn make_error(&self, msg_expr: &str) -> String {
48 self.error_constructor.replace("{msg}", msg_expr)
49 }
50
51 pub fn wrapper_name(&self) -> String {
53 format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
54 }
55
56 pub fn trait_snake(&self) -> String {
58 self.trait_def.name.to_snake_case()
59 }
60
61 pub fn trait_path(&self) -> String {
63 self.trait_def.rust_path.replace('-', "_")
64 }
65
66 pub fn required_methods(&self) -> Vec<&'a MethodDef> {
68 self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
69 }
70
71 pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
73 self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
74 }
75}
76
/// Language-specific hooks used by `gen_bridge_all` to emit a trait bridge.
pub trait TraitBridgeGenerator {
    /// Rust type of the foreign object held by the wrapper struct; passed to the
    /// wrapper-struct template as `foreign_type` (e.g. `Py<PyAny>` in the tests).
    fn foreign_object_type(&self) -> &str;

    /// Import paths required by the generated code; returned verbatim in
    /// `BridgeOutput::imports`.
    fn bridge_imports(&self) -> Vec<String>;

    /// Body for a non-async trait method. Also used for the synthesized
    /// `version`/`initialize`/`shutdown` methods of the super-trait impl.
    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Body for an `async` trait method.
    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Constructor code for the wrapper struct, emitted after the `Debug` impl.
    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;

    /// Registration function; only invoked when `bridge_config.register_fn` is set.
    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;

    /// Unregistration function; only invoked when `bridge_config.unregister_fn`
    /// is set. An empty string (the default) means nothing is emitted.
    fn gen_unregistration_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::new()
    }

    /// Registry-clearing function; only invoked when `bridge_config.clear_fn`
    /// is set. An empty string (the default) means nothing is emitted.
    fn gen_clear_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::new()
    }

    /// Forwarded to the trait-impl template as `async_trait_is_send`.
    /// NOTE(review): presumably selects `#[async_trait]` vs `#[async_trait(?Send)]`
    /// — confirm in `trait_impl.jinja`.
    fn async_trait_is_send(&self) -> bool {
        true
    }
}
141
142pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
156 let wrapper = spec.wrapper_name();
157 let foreign_type = generator.foreign_object_type();
158
159 crate::template_env::render(
160 "generators/trait_bridge/wrapper_struct.jinja",
161 minijinja::context! {
162 wrapper_prefix => spec.wrapper_prefix,
163 trait_name => &spec.trait_def.name,
164 wrapper_name => wrapper,
165 foreign_type => foreign_type,
166 },
167 )
168}
169
170fn gen_bridge_debug_impl(spec: &TraitBridgeSpec) -> String {
176 let wrapper = spec.wrapper_name();
177 format!(
178 "impl std::fmt::Debug for {wrapper} {{\n fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{\n write!(f, \"{wrapper}\")\n }}\n}}"
179 )
180}
181
182pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
190 let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
191
192 let wrapper = spec.wrapper_name();
193 let core_import = spec.core_import;
194
195 let super_trait_path = if super_trait_name.contains("::") {
197 super_trait_name.to_string()
198 } else {
199 format!("{core_import}::{super_trait_name}")
200 };
201
202 let error_path = spec.error_path();
206
207 let version_method = MethodDef {
209 name: "version".to_string(),
210 params: vec![],
211 return_type: alef_core::ir::TypeRef::String,
212 is_async: false,
213 is_static: false,
214 error_type: None,
215 doc: String::new(),
216 receiver: Some(alef_core::ir::ReceiverKind::Ref),
217 sanitized: false,
218 trait_source: None,
219 returns_ref: false,
220 returns_cow: false,
221 return_newtype_wrapper: None,
222 has_default_impl: false,
223 };
224 let version_body = generator.gen_sync_method_body(&version_method, spec);
225
226 let init_method = MethodDef {
228 name: "initialize".to_string(),
229 params: vec![],
230 return_type: alef_core::ir::TypeRef::Unit,
231 is_async: false,
232 is_static: false,
233 error_type: Some(error_path.clone()),
234 doc: String::new(),
235 receiver: Some(alef_core::ir::ReceiverKind::Ref),
236 sanitized: false,
237 trait_source: None,
238 returns_ref: false,
239 returns_cow: false,
240 return_newtype_wrapper: None,
241 has_default_impl: true,
242 };
243 let init_body = generator.gen_sync_method_body(&init_method, spec);
244
245 let shutdown_method = MethodDef {
247 name: "shutdown".to_string(),
248 params: vec![],
249 return_type: alef_core::ir::TypeRef::Unit,
250 is_async: false,
251 is_static: false,
252 error_type: Some(error_path.clone()),
253 doc: String::new(),
254 receiver: Some(alef_core::ir::ReceiverKind::Ref),
255 sanitized: false,
256 trait_source: None,
257 returns_ref: false,
258 returns_cow: false,
259 return_newtype_wrapper: None,
260 has_default_impl: true,
261 };
262 let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
263
264 let version_lines: Vec<&str> = version_body.lines().collect();
266 let init_lines: Vec<&str> = init_body.lines().collect();
267 let shutdown_lines: Vec<&str> = shutdown_body.lines().collect();
268
269 Some(crate::template_env::render(
270 "generators/trait_bridge/plugin_impl.jinja",
271 minijinja::context! {
272 super_trait_path => super_trait_path,
273 wrapper_name => wrapper,
274 error_path => error_path,
275 version_lines => version_lines,
276 init_lines => init_lines,
277 shutdown_lines => shutdown_lines,
278 },
279 ))
280}
281
282pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
287 let wrapper = spec.wrapper_name();
288 let trait_path = spec.trait_path();
289
290 let has_async_methods = spec
292 .trait_def
293 .methods
294 .iter()
295 .any(|m| m.is_async && m.trait_source.is_none());
296 let async_trait_is_send = generator.async_trait_is_send();
297
298 let own_methods: Vec<_> = spec
300 .trait_def
301 .methods
302 .iter()
303 .filter(|m| m.trait_source.is_none())
304 .collect();
305
306 let mut methods_code = String::with_capacity(1024);
308 for (i, method) in own_methods.iter().enumerate() {
309 if i > 0 {
310 methods_code.push_str("\n\n");
311 }
312
313 let async_kw = if method.is_async { "async " } else { "" };
315 let receiver = match &method.receiver {
316 Some(alef_core::ir::ReceiverKind::Ref) => "&self",
317 Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
318 Some(alef_core::ir::ReceiverKind::Owned) => "self",
319 None => "",
320 };
321
322 let params: Vec<String> = method
324 .params
325 .iter()
326 .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
327 .collect();
328
329 let all_params = if receiver.is_empty() {
330 params.join(", ")
331 } else if params.is_empty() {
332 receiver.to_string()
333 } else {
334 format!("{}, {}", receiver, params.join(", "))
335 };
336
337 let error_override = method.error_type.as_ref().map(|_| spec.error_path());
341 let ret = format_return_type(&method.return_type, error_override.as_deref(), &spec.type_paths);
342
343 let body = if method.is_async {
345 generator.gen_async_method_body(method, spec)
346 } else {
347 generator.gen_sync_method_body(method, spec)
348 };
349
350 let indented_body = body
352 .lines()
353 .map(|line| format!(" {line}"))
354 .collect::<Vec<_>>()
355 .join("\n");
356
357 methods_code.push_str(&crate::template_env::render(
358 "generators/trait_bridge/trait_method.jinja",
359 minijinja::context! {
360 async_kw => async_kw,
361 method_name => &method.name,
362 all_params => all_params,
363 ret => ret,
364 indented_body => &indented_body,
365 },
366 ));
367 }
368
369 crate::template_env::render(
370 "generators/trait_bridge/trait_impl.jinja",
371 minijinja::context! {
372 has_async_methods => has_async_methods,
373 async_trait_is_send => async_trait_is_send,
374 trait_path => trait_path,
375 wrapper_name => wrapper,
376 methods_code => methods_code,
377 },
378 )
379}
380
381pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
388 spec.bridge_config.register_fn.as_deref()?;
389 Some(generator.gen_registration_fn(spec))
390}
391
392pub fn gen_bridge_unregistration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
399 spec.bridge_config.unregister_fn.as_deref()?;
400 let body = generator.gen_unregistration_fn(spec);
401 if body.is_empty() { None } else { Some(body) }
402}
403
404pub fn gen_bridge_clear_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
411 spec.bridge_config.clear_fn.as_deref()?;
412 let body = generator.gen_clear_fn(spec);
413 if body.is_empty() { None } else { Some(body) }
414}
415
416pub fn host_function_path(spec: &TraitBridgeSpec, fn_name: &str) -> String {
434 if let Some(getter) = spec.bridge_config.registry_getter.as_deref() {
435 let last = getter.rsplit("::").next().unwrap_or("");
436 if let Some(sub) = last.strip_prefix("get_").and_then(|s| s.strip_suffix("_registry")) {
437 let prefix_end = getter.len() - last.len();
438 let prefix = &getter[..prefix_end];
439 let prefix = prefix.trim_end_matches("registry::");
440 return format!("{prefix}{sub}::{fn_name}");
441 }
442 }
443 format!("{}::plugins::{}", spec.core_import, fn_name)
444}
445
/// Result of `gen_bridge_all`: the generated bridge code and the imports it needs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BridgeOutput {
    /// Import paths returned by the generator's `bridge_imports()`
    /// (e.g. `pyo3::prelude::*` in the tests).
    pub imports: Vec<String>,
    /// Generated code sections, separated by blank lines.
    pub code: String,
}
454
455pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
461 let imports = generator.bridge_imports();
462 let mut out = String::with_capacity(4096);
463
464 out.push_str(&gen_bridge_wrapper_struct(spec, generator));
466 out.push_str("\n\n");
467
468 out.push_str(&gen_bridge_debug_impl(spec));
470 out.push_str("\n\n");
471
472 out.push_str(&generator.gen_constructor(spec));
474 out.push_str("\n\n");
475
476 if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
478 out.push_str(&plugin_impl);
479 out.push_str("\n\n");
480 }
481
482 out.push_str(&gen_bridge_trait_impl(spec, generator));
484
485 if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
487 out.push_str("\n\n");
488 out.push_str(®_fn_code);
489 }
490
491 if let Some(unreg_fn_code) = gen_bridge_unregistration_fn(spec, generator) {
494 out.push_str("\n\n");
495 out.push_str(&unreg_fn_code);
496 }
497
498 if let Some(clear_fn_code) = gen_bridge_clear_fn(spec, generator) {
501 out.push_str("\n\n");
502 out.push_str(&clear_fn_code);
503 }
504
505 BridgeOutput { imports, code: out }
506}
507
508pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
517 use alef_core::ir::{PrimitiveType, TypeRef};
518 match ty {
519 TypeRef::Primitive(p) => match p {
520 PrimitiveType::Bool => "bool",
521 PrimitiveType::U8 => "u8",
522 PrimitiveType::U16 => "u16",
523 PrimitiveType::U32 => "u32",
524 PrimitiveType::U64 => "u64",
525 PrimitiveType::I8 => "i8",
526 PrimitiveType::I16 => "i16",
527 PrimitiveType::I32 => "i32",
528 PrimitiveType::I64 => "i64",
529 PrimitiveType::F32 => "f32",
530 PrimitiveType::F64 => "f64",
531 PrimitiveType::Usize => "usize",
532 PrimitiveType::Isize => "isize",
533 }
534 .to_string(),
535 TypeRef::String => "String".to_string(),
536 TypeRef::Char => "char".to_string(),
537 TypeRef::Bytes => "Vec<u8>".to_string(),
538 TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
539 TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
540 TypeRef::Map(k, v) => format!(
541 "std::collections::HashMap<{}, {}>",
542 format_type_ref(k, type_paths),
543 format_type_ref(v, type_paths)
544 ),
545 TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
546 TypeRef::Path => "std::path::PathBuf".to_string(),
547 TypeRef::Unit => "()".to_string(),
548 TypeRef::Json => "serde_json::Value".to_string(),
549 TypeRef::Duration => "std::time::Duration".to_string(),
550 }
551}
552
553pub fn format_return_type(
555 ty: &alef_core::ir::TypeRef,
556 error_type: Option<&str>,
557 type_paths: &HashMap<String, String>,
558) -> String {
559 let inner = format_type_ref(ty, type_paths);
560 match error_type {
561 Some(err) => format!("std::result::Result<{inner}, {err}>"),
562 None => inner,
563 }
564}
565
566pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
578 use alef_core::ir::TypeRef;
579 let base = if param.is_ref {
580 let mutability = if param.is_mut { "mut " } else { "" };
581 match ¶m.ty {
582 TypeRef::String => format!("&{mutability}str"),
583 TypeRef::Bytes => format!("&{mutability}[u8]"),
584 TypeRef::Path => format!("&{mutability}std::path::Path"),
585 TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
586 TypeRef::Named(name) => {
587 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
588 format!("&{mutability}{qualified}")
589 }
590 TypeRef::Optional(inner) => {
591 let inner_type_str = match inner.as_ref() {
595 TypeRef::String => format!("&{mutability}str"),
596 TypeRef::Bytes => format!("&{mutability}[u8]"),
597 TypeRef::Path => format!("&{mutability}std::path::Path"),
598 TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
599 TypeRef::Named(name) => {
600 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
601 format!("&{mutability}{qualified}")
602 }
603 other => format_type_ref(other, type_paths),
605 };
606 return format!("Option<{inner_type_str}>");
608 }
609 other => format_type_ref(other, type_paths),
611 }
612 } else {
613 format_type_ref(¶m.ty, type_paths)
614 };
615
616 if param.optional {
620 format!("Option<{base}>")
621 } else {
622 base
623 }
624}
625
626pub fn prim(p: &PrimitiveType) -> &'static str {
632 use PrimitiveType::*;
633 match p {
634 Bool => "bool",
635 U8 => "u8",
636 U16 => "u16",
637 U32 => "u32",
638 U64 => "u64",
639 I8 => "i8",
640 I16 => "i16",
641 I32 => "i32",
642 I64 => "i64",
643 F32 => "f32",
644 F64 => "f64",
645 Usize => "usize",
646 Isize => "isize",
647 }
648}
649
650pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
654 match ty {
655 TypeRef::Bytes if is_ref => "&[u8]".into(),
656 TypeRef::Bytes => "Vec<u8>".into(),
657 TypeRef::String if is_ref => "&str".into(),
658 TypeRef::String => "String".into(),
659 TypeRef::Path if is_ref => "&std::path::Path".into(),
660 TypeRef::Path => "std::path::PathBuf".into(),
661 TypeRef::Named(n) => {
662 let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
663 if is_ref { format!("&{qualified}") } else { qualified }
664 }
665 TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
666 TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
667 TypeRef::Primitive(p) => prim(p).into(),
668 TypeRef::Unit => "()".into(),
669 TypeRef::Char => "char".into(),
670 TypeRef::Map(k, v) => format!(
671 "std::collections::HashMap<{}, {}>",
672 bridge_param_type(k, ci, false, tp),
673 bridge_param_type(v, ci, false, tp)
674 ),
675 TypeRef::Json => "serde_json::Value".into(),
676 TypeRef::Duration => "std::time::Duration".into(),
677 }
678}
679
680pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
686 if optional && matches!(ty, TypeRef::String) && is_ref {
687 return "Option<&str>".to_string();
688 }
689 if is_ref {
690 if let TypeRef::Vec(inner) = ty {
691 let inner_str = bridge_param_type(inner, "", false, tp);
692 return format!("&[{inner_str}]");
693 }
694 }
695 bridge_param_type(ty, "", is_ref, tp)
696}
697
698pub fn find_bridge_param<'a>(
705 func: &FunctionDef,
706 bridges: &'a [TraitBridgeConfig],
707) -> Option<(usize, &'a TraitBridgeConfig)> {
708 for (idx, param) in func.params.iter().enumerate() {
709 let named = match ¶m.ty {
710 TypeRef::Named(n) => Some(n.as_str()),
711 TypeRef::Optional(inner) => {
712 if let TypeRef::Named(n) = inner.as_ref() {
713 Some(n.as_str())
714 } else {
715 None
716 }
717 }
718 _ => None,
719 };
720 for bridge in bridges {
721 if bridge.bind_via != BridgeBinding::FunctionParam {
722 continue;
723 }
724 if let Some(type_name) = named {
725 if bridge.type_alias.as_deref() == Some(type_name) {
726 return Some((idx, bridge));
727 }
728 }
729 if bridge.param_name.as_deref() == Some(param.name.as_str()) {
730 return Some((idx, bridge));
731 }
732 }
733 }
734 None
735}
736
/// A function parameter whose options struct carries a bridged trait object,
/// as located by `find_bridge_field`.
#[derive(Debug, Clone)]
pub struct BridgeFieldMatch<'a> {
    /// Position of the options parameter in the function's parameter list.
    pub param_index: usize,
    /// Name of the options parameter.
    pub param_name: String,
    /// Type name of the options parameter (the `Named` type, unwrapped from `Optional`).
    pub options_type: String,
    /// Whether the parameter was declared as `Optional<Named>`.
    pub param_is_optional: bool,
    /// Name of the matched field inside the options type.
    pub field_name: String,
    /// IR definition of the matched field.
    pub field: &'a FieldDef,
    /// Bridge configuration that produced the match.
    pub bridge: &'a TraitBridgeConfig,
}
756
757pub fn find_bridge_field<'a>(
768 func: &FunctionDef,
769 types: &'a [TypeDef],
770 bridges: &'a [TraitBridgeConfig],
771) -> Option<BridgeFieldMatch<'a>> {
772 fn unwrap_named(ty: &TypeRef) -> Option<(&str, bool)> {
773 match ty {
774 TypeRef::Named(n) => Some((n.as_str(), false)),
775 TypeRef::Optional(inner) => {
776 if let TypeRef::Named(n) = inner.as_ref() {
777 Some((n.as_str(), true))
778 } else {
779 None
780 }
781 }
782 _ => None,
783 }
784 }
785
786 for (idx, param) in func.params.iter().enumerate() {
787 let Some((type_name, is_optional)) = unwrap_named(¶m.ty) else {
788 continue;
789 };
790 let Some(type_def) = types.iter().find(|t| t.name == type_name) else {
791 continue;
792 };
793 for bridge in bridges {
794 if bridge.bind_via != BridgeBinding::OptionsField {
795 continue;
796 }
797 if bridge.options_type.as_deref() != Some(type_name) {
798 continue;
799 }
800 let field_name = bridge.resolved_options_field();
801 for field in &type_def.fields {
802 let matches_name = field_name.is_some_and(|n| field.name == n);
803 let matches_alias = bridge
804 .type_alias
805 .as_deref()
806 .is_some_and(|alias| field_type_matches_alias(&field.ty, alias));
807 if matches_name || matches_alias {
808 return Some(BridgeFieldMatch {
809 param_index: idx,
810 param_name: param.name.clone(),
811 options_type: type_name.to_string(),
812 param_is_optional: is_optional,
813 field_name: field.name.clone(),
814 field,
815 bridge,
816 });
817 }
818 }
819 }
820 }
821 None
822}
823
824fn field_type_matches_alias(field_ty: &TypeRef, alias: &str) -> bool {
827 match field_ty {
828 TypeRef::Named(n) => n == alias,
829 TypeRef::Optional(inner) | TypeRef::Vec(inner) => field_type_matches_alias(inner, alias),
830 _ => false,
831 }
832}
833
/// Convert `snake_case` to `camelCase`.
///
/// Underscores are removed, and the character after each run of underscores
/// is ASCII-uppercased. The first segment is kept as-is, so a leading
/// underscore still uppercases the following character (`_foo` -> `Foo`).
pub fn to_camel_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut segments = s.split('_');
    if let Some(head) = segments.next() {
        out.push_str(head);
    }
    for segment in segments {
        let mut chars = segment.chars();
        if let Some(first) = chars.next() {
            out.push(first.to_ascii_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
850
851#[cfg(test)]
852mod tests {
853 use super::*;
854 use alef_core::config::TraitBridgeConfig;
855 use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
856
857 fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
862 TraitBridgeConfig {
863 trait_name: "OcrBackend".to_string(),
864 super_trait: super_trait.map(str::to_string),
865 registry_getter: None,
866 register_fn: register_fn.map(str::to_string),
867 unregister_fn: None,
868 clear_fn: None,
869 type_alias: None,
870 param_name: None,
871 register_extra_args: None,
872 exclude_languages: Vec::new(),
873 bind_via: BridgeBinding::FunctionParam,
874 options_type: None,
875 options_field: None,
876 }
877 }
878
879 fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
880 TypeDef {
881 name: name.to_string(),
882 rust_path: rust_path.to_string(),
883 original_rust_path: rust_path.to_string(),
884 fields: vec![],
885 methods,
886 is_opaque: true,
887 is_clone: false,
888 is_copy: false,
889 doc: String::new(),
890 cfg: None,
891 is_trait: true,
892 has_default: false,
893 has_stripped_cfg_fields: false,
894 is_return_type: false,
895 serde_rename_all: None,
896 has_serde: false,
897 super_traits: vec![],
898 }
899 }
900
901 fn make_method(
902 name: &str,
903 params: Vec<ParamDef>,
904 return_type: TypeRef,
905 is_async: bool,
906 has_default_impl: bool,
907 trait_source: Option<&str>,
908 error_type: Option<&str>,
909 ) -> MethodDef {
910 MethodDef {
911 name: name.to_string(),
912 params,
913 return_type,
914 is_async,
915 is_static: false,
916 error_type: error_type.map(str::to_string),
917 doc: String::new(),
918 receiver: Some(ReceiverKind::Ref),
919 sanitized: false,
920 trait_source: trait_source.map(str::to_string),
921 returns_ref: false,
922 returns_cow: false,
923 return_newtype_wrapper: None,
924 has_default_impl,
925 }
926 }
927
928 fn make_func(name: &str, params: Vec<ParamDef>) -> FunctionDef {
929 FunctionDef {
930 name: name.to_string(),
931 rust_path: format!("mylib::{name}"),
932 original_rust_path: String::new(),
933 params,
934 return_type: TypeRef::Unit,
935 is_async: false,
936 error_type: None,
937 doc: String::new(),
938 cfg: None,
939 sanitized: false,
940 return_sanitized: false,
941 returns_ref: false,
942 returns_cow: false,
943 return_newtype_wrapper: None,
944 }
945 }
946
947 fn make_field(name: &str, ty: TypeRef) -> FieldDef {
948 FieldDef {
949 name: name.to_string(),
950 ty,
951 optional: false,
952 default: None,
953 doc: String::new(),
954 sanitized: false,
955 is_boxed: false,
956 type_rust_path: None,
957 cfg: None,
958 typed_default: None,
959 core_wrapper: Default::default(),
960 vec_inner_core_wrapper: Default::default(),
961 newtype_wrapper: None,
962 serde_rename: None,
963 serde_flatten: false,
964 }
965 }
966
967 fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
968 ParamDef {
969 name: name.to_string(),
970 ty,
971 optional: false,
972 default: None,
973 sanitized: false,
974 typed_default: None,
975 is_ref,
976 is_mut: false,
977 newtype_wrapper: None,
978 original_type: None,
979 }
980 }
981
982 fn make_spec<'a>(
983 trait_def: &'a TypeDef,
984 bridge_config: &'a TraitBridgeConfig,
985 wrapper_prefix: &'a str,
986 type_paths: HashMap<String, String>,
987 ) -> TraitBridgeSpec<'a> {
988 TraitBridgeSpec {
989 trait_def,
990 bridge_config,
991 core_import: "mylib",
992 wrapper_prefix,
993 type_paths,
994 error_type: "MyError".to_string(),
995 error_constructor: "MyError::from({msg})".to_string(),
996 }
997 }
998
999 struct MockBridgeGenerator;
1004
1005 impl TraitBridgeGenerator for MockBridgeGenerator {
1006 fn foreign_object_type(&self) -> &str {
1007 "Py<PyAny>"
1008 }
1009
1010 fn bridge_imports(&self) -> Vec<String> {
1011 vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
1012 }
1013
1014 fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1015 format!("// sync body for {}", method.name)
1016 }
1017
1018 fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1019 format!("// async body for {}", method.name)
1020 }
1021
1022 fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
1023 format!(
1024 "impl {} {{\n pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
1025 spec.wrapper_name()
1026 )
1027 }
1028
1029 fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
1030 let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
1031 format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
1032 }
1033 }
1034
1035 #[test]
1040 fn test_wrapper_name() {
1041 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1042 let config = make_trait_bridge_config(None, None);
1043 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1044 assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
1045 }
1046
1047 #[test]
1048 fn test_trait_snake() {
1049 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1050 let config = make_trait_bridge_config(None, None);
1051 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1052 assert_eq!(spec.trait_snake(), "ocr_backend");
1053 }
1054
1055 #[test]
1056 fn test_trait_path_replaces_hyphens() {
1057 let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
1058 let config = make_trait_bridge_config(None, None);
1059 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1060 assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
1061 }
1062
1063 #[test]
1064 fn test_required_methods_filters_no_default_impl() {
1065 let methods = vec![
1066 make_method("process", vec![], TypeRef::String, false, false, None, None),
1067 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1068 make_method("detect", vec![], TypeRef::String, false, false, None, None),
1069 ];
1070 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1071 let config = make_trait_bridge_config(None, None);
1072 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1073 let required = spec.required_methods();
1074 assert_eq!(required.len(), 2);
1075 assert!(required.iter().any(|m| m.name == "process"));
1076 assert!(required.iter().any(|m| m.name == "detect"));
1077 }
1078
1079 #[test]
1080 fn test_optional_methods_filters_has_default_impl() {
1081 let methods = vec![
1082 make_method("process", vec![], TypeRef::String, false, false, None, None),
1083 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1084 make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
1085 ];
1086 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1087 let config = make_trait_bridge_config(None, None);
1088 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1089 let optional = spec.optional_methods();
1090 assert_eq!(optional.len(), 2);
1091 assert!(optional.iter().any(|m| m.name == "initialize"));
1092 assert!(optional.iter().any(|m| m.name == "shutdown"));
1093 }
1094
1095 #[test]
1096 fn test_error_path() {
1097 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1098 let config = make_trait_bridge_config(None, None);
1099 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1100 assert_eq!(spec.error_path(), "mylib::MyError");
1101 }
1102
1103 #[test]
1108 fn test_format_type_ref_primitives() {
1109 let paths = HashMap::new();
1110 let cases: Vec<(TypeRef, &str)> = vec![
1111 (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
1112 (TypeRef::Primitive(PrimitiveType::U8), "u8"),
1113 (TypeRef::Primitive(PrimitiveType::U16), "u16"),
1114 (TypeRef::Primitive(PrimitiveType::U32), "u32"),
1115 (TypeRef::Primitive(PrimitiveType::U64), "u64"),
1116 (TypeRef::Primitive(PrimitiveType::I8), "i8"),
1117 (TypeRef::Primitive(PrimitiveType::I16), "i16"),
1118 (TypeRef::Primitive(PrimitiveType::I32), "i32"),
1119 (TypeRef::Primitive(PrimitiveType::I64), "i64"),
1120 (TypeRef::Primitive(PrimitiveType::F32), "f32"),
1121 (TypeRef::Primitive(PrimitiveType::F64), "f64"),
1122 (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
1123 (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
1124 ];
1125 for (ty, expected) in cases {
1126 assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
1127 }
1128 }
1129
1130 #[test]
1131 fn test_format_type_ref_string() {
1132 assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
1133 }
1134
1135 #[test]
1136 fn test_format_type_ref_char() {
1137 assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
1138 }
1139
1140 #[test]
1141 fn test_format_type_ref_bytes() {
1142 assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
1143 }
1144
1145 #[test]
1146 fn test_format_type_ref_path() {
1147 assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
1148 }
1149
1150 #[test]
1151 fn test_format_type_ref_unit() {
1152 assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
1153 }
1154
1155 #[test]
1156 fn test_format_type_ref_json() {
1157 assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
1158 }
1159
1160 #[test]
1161 fn test_format_type_ref_duration() {
1162 assert_eq!(
1163 format_type_ref(&TypeRef::Duration, &HashMap::new()),
1164 "std::time::Duration"
1165 );
1166 }
1167
1168 #[test]
1169 fn test_format_type_ref_optional() {
1170 let ty = TypeRef::Optional(Box::new(TypeRef::String));
1171 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
1172 }
1173
1174 #[test]
1175 fn test_format_type_ref_optional_nested() {
1176 let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
1177 PrimitiveType::U32,
1178 )))));
1179 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
1180 }
1181
1182 #[test]
1183 fn test_format_type_ref_vec() {
1184 let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
1185 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
1186 }
1187
1188 #[test]
1189 fn test_format_type_ref_vec_nested() {
1190 let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
1191 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
1192 }
1193
1194 #[test]
1195 fn test_format_type_ref_map() {
1196 let ty = TypeRef::Map(
1197 Box::new(TypeRef::String),
1198 Box::new(TypeRef::Primitive(PrimitiveType::I64)),
1199 );
1200 assert_eq!(
1201 format_type_ref(&ty, &HashMap::new()),
1202 "std::collections::HashMap<String, i64>"
1203 );
1204 }
1205
1206 #[test]
1207 fn test_format_type_ref_map_nested_value() {
1208 let ty = TypeRef::Map(
1209 Box::new(TypeRef::String),
1210 Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
1211 );
1212 assert_eq!(
1213 format_type_ref(&ty, &HashMap::new()),
1214 "std::collections::HashMap<String, Vec<String>>"
1215 );
1216 }
1217
1218 #[test]
1219 fn test_format_type_ref_named_without_type_paths() {
1220 let ty = TypeRef::Named("Config".to_string());
1221 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
1222 }
1223
1224 #[test]
1225 fn test_format_type_ref_named_with_type_paths() {
1226 let ty = TypeRef::Named("Config".to_string());
1227 let mut paths = HashMap::new();
1228 paths.insert("Config".to_string(), "mylib::Config".to_string());
1229 assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
1230 }
1231
1232 #[test]
1233 fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
1234 let ty = TypeRef::Named("Unknown".to_string());
1235 let mut paths = HashMap::new();
1236 paths.insert("Other".to_string(), "mylib::Other".to_string());
1237 assert_eq!(format_type_ref(&ty, &paths), "Unknown");
1238 }
1239
1240 #[test]
1245 fn test_format_param_type_string_ref() {
1246 let param = make_param("input", TypeRef::String, true);
1247 assert_eq!(format_param_type(¶m, &HashMap::new()), "&str");
1248 }
1249
1250 #[test]
1251 fn test_format_param_type_string_owned() {
1252 let param = make_param("input", TypeRef::String, false);
1253 assert_eq!(format_param_type(¶m, &HashMap::new()), "String");
1254 }
1255
1256 #[test]
1257 fn test_format_param_type_bytes_ref() {
1258 let param = make_param("data", TypeRef::Bytes, true);
1259 assert_eq!(format_param_type(¶m, &HashMap::new()), "&[u8]");
1260 }
1261
1262 #[test]
1263 fn test_format_param_type_bytes_owned() {
1264 let param = make_param("data", TypeRef::Bytes, false);
1265 assert_eq!(format_param_type(¶m, &HashMap::new()), "Vec<u8>");
1266 }
1267
1268 #[test]
1269 fn test_format_param_type_path_ref() {
1270 let param = make_param("path", TypeRef::Path, true);
1271 assert_eq!(format_param_type(¶m, &HashMap::new()), "&std::path::Path");
1272 }
1273
#[test]
fn test_format_param_type_path_owned() {
    // An owned path parameter is rendered as `std::path::PathBuf`.
    let param = make_param("path", TypeRef::Path, false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
}
1279
#[test]
fn test_format_param_type_vec_ref() {
    // A borrowed `Vec<T>` parameter is rendered as the slice `&[T]`.
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
}
1285
#[test]
fn test_format_param_type_vec_owned() {
    // An owned `Vec<T>` parameter keeps its `Vec<T>` type.
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
}
1291
#[test]
fn test_format_param_type_named_ref_with_type_paths() {
    // A borrowed named type is rendered as `&` plus its qualified path.
    let mut paths = HashMap::new();
    paths.insert("Config".to_string(), "mylib::Config".to_string());
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
}
1299
#[test]
fn test_format_param_type_named_ref_without_type_paths() {
    // Without a path mapping, a borrowed named type uses its bare name.
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
}
1305
#[test]
fn test_format_param_type_primitive_ref_passes_by_value() {
    // Primitives are rendered by value even when the parameter is flagged as a reference.
    let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
}
1312
#[test]
fn test_format_param_type_unit_ref_passes_by_value() {
    // The unit type is rendered by value even when flagged as a reference.
    let param = make_param("nothing", TypeRef::Unit, true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "()");
}
1318
#[test]
fn test_format_return_type_without_error() {
    // With no error type the value type is returned unwrapped.
    let no_paths = HashMap::new();
    assert_eq!(format_return_type(&TypeRef::String, None, &no_paths), "String");
}
1328
#[test]
fn test_format_return_type_with_error() {
    // An error type wraps the value in a fully qualified `std::result::Result`.
    let no_paths = HashMap::new();
    let error = Some("MyError");
    assert_eq!(
        format_return_type(&TypeRef::String, error, &no_paths),
        "std::result::Result<String, MyError>"
    );
}
1334
#[test]
fn test_format_return_type_unit_with_error() {
    // A unit return with an error becomes `Result<(), E>`, with E passed through verbatim.
    let no_paths = HashMap::new();
    let error = Some("Box<dyn std::error::Error>");
    let rendered = format_return_type(&TypeRef::Unit, error, &no_paths);
    assert_eq!(rendered, "std::result::Result<(), Box<dyn std::error::Error>>");
}
1340
#[test]
fn test_format_return_type_named_with_type_paths_and_error() {
    // The value type is qualified through the path map; the error type is used as given.
    let paths = HashMap::from([("Output".to_string(), "mylib::Output".to_string())]);
    let output_ty = TypeRef::Named(String::from("Output"));
    let rendered = format_return_type(&output_ty, Some("mylib::MyError"), &paths);
    assert_eq!(rendered, "std::result::Result<mylib::Output, mylib::MyError>");
}
1348
#[test]
fn test_gen_bridge_wrapper_struct_contains_struct_name() {
    // The wrapper is named `<prefix><Trait>Bridge`.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_wrapper_struct(&spec, &mock);
    let declared = code.contains("pub struct PyOcrBackendBridge");
    assert!(declared, "missing struct declaration in:\n{code}");
}
1365
#[test]
fn test_gen_bridge_wrapper_struct_contains_inner_field() {
    // The wrapper stores the foreign object in an `inner` field.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_wrapper_struct(&spec, &mock);
    assert!(code.contains("inner: Py<PyAny>"), "missing inner field in:\n{code}");
}
1375
#[test]
fn test_gen_bridge_wrapper_struct_contains_cached_name() {
    // The wrapper declares a `cached_name` field even when no super-trait is configured.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_wrapper_struct(&spec, &mock);
    let has_field = code.contains("cached_name: String");
    assert!(has_field, "missing cached_name field in:\n{code}");
}
1388
#[test]
fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
    // With no super-trait there is nothing to implement, so no code is produced.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_none());
}
1401
#[test]
fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
    // Configuring a super-trait causes a plugin impl to be generated.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_some());
}
1410
#[test]
fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
    // A bare super-trait name is qualified with the core crate path.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    let qualified = code.contains("impl mylib::Plugin for PyOcrBackendBridge");
    assert!(qualified, "missing qualified super-trait path in:\n{code}");
}
1423
#[test]
fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
    // A super-trait that already contains `::` is emitted as-is.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("other_crate::Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    let kept_as_is = code.contains("impl other_crate::Plugin for PyOcrBackendBridge");
    assert!(kept_as_is, "wrong super-trait path in:\n{code}");
}
1436
#[test]
fn test_gen_bridge_plugin_impl_contains_name_fn() {
    // The plugin `name()` implementation reads the wrapper's `cached_name`.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    let uses_cached_name = code.contains("fn name(") && code.contains("cached_name");
    assert!(uses_cached_name, "missing name() using cached_name in:\n{code}");
}
1449
#[test]
fn test_gen_bridge_plugin_impl_contains_version_fn() {
    // The plugin impl provides `version()`.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    assert!(code.contains("fn version("), "missing version() in:\n{code}");
}
1459
#[test]
fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
    // The plugin impl provides `initialize()`.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    assert!(code.contains("fn initialize("), "missing initialize() in:\n{code}");
}
1469
#[test]
fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
    // The plugin impl provides `shutdown()`.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_plugin_impl(&spec, &mock).unwrap();
    assert!(code.contains("fn shutdown("), "missing shutdown() in:\n{code}");
}
1479
#[test]
fn test_gen_bridge_trait_impl_includes_impl_header() {
    // The trait impl targets the trait's rust path for the generated wrapper.
    let mock = MockBridgeGenerator;
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &mock);
    let has_header = code.contains("impl mylib::OcrBackend for PyOcrBackendBridge");
    assert!(has_header, "missing impl header in:\n{code}");
}
1496
#[test]
fn test_gen_bridge_trait_impl_includes_method_signatures() {
    // Every trait method gets a signature in the generated impl.
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    assert!(code.contains("fn process("), "missing method signature in:\n{code}");
}
1515
#[test]
fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
    // A non-async method's body comes from the generator's sync body hook.
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_sync_body = code.contains("// sync body for process");
    assert!(has_sync_body, "missing sync method body in:\n{code}");
}
1537
#[test]
fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
    // An async method uses the async body hook and keeps the `async` keyword.
    let process_async =
        make_method("process_async", Vec::new(), TypeRef::String, true, false, None, None);
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process_async]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let has_async_body = code.contains("// async body for process_async");
    assert!(has_async_body, "missing async method body in:\n{code}");

    let has_async_sig = code.contains("async fn process_async(");
    assert!(has_async_sig, "missing async keyword in method signature in:\n{code}");
}
1563
#[test]
fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
    // Methods with a `trait_source` come from another trait and are not emitted here.
    let own = make_method("own_method", Vec::new(), TypeRef::String, false, false, None, None);
    let inherited = make_method(
        "inherited_method",
        Vec::new(),
        TypeRef::String,
        false,
        false,
        Some("other_crate::OtherTrait"),
        None,
    );
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![own, inherited]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);

    let own_present = code.contains("fn own_method(");
    assert!(own_present, "own method should be present in:\n{code}");

    let inherited_present = code.contains("fn inherited_method(");
    assert!(!inherited_present, "inherited method should be filtered out in:\n{code}");
}
1593
#[test]
fn test_gen_bridge_trait_impl_method_with_params() {
    // Parameters are rendered with their formatted types (borrowed String -> &str).
    let input = make_param("input", TypeRef::String, true);
    let count = make_param("count", TypeRef::Primitive(PrimitiveType::U32), false);
    let process = make_method("process", vec![input, count], TypeRef::String, false, false, None, None);
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    assert!(code.contains("input: &str"), "missing &str param in:\n{code}");
    assert!(code.contains("count: u32"), "missing u32 param in:\n{code}");
}
1617
#[test]
fn test_gen_bridge_trait_impl_return_type_with_error() {
    // A bare error type on a method is qualified with the core crate path in the return type.
    let process = make_method(
        "process",
        Vec::new(),
        TypeRef::String,
        false,
        false,
        None,
        Some("MyError"),
    );
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let code = gen_bridge_trait_impl(&spec, &MockBridgeGenerator);
    let has_result_return = code.contains("-> std::result::Result<String, mylib::MyError>");
    assert!(has_result_return, "missing std::result::Result return type in:\n{code}");
}
1639
#[test]
fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
    // Without `register_fn` configured, no registration function is generated.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let registration = gen_bridge_registration_fn(&spec, &MockBridgeGenerator);
    assert!(registration.is_none());
}
1652
#[test]
fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
    // With `register_fn` configured, a registration function mentioning that name is generated.
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    // A single `expect` replaces the separate `is_some()` check followed by `unwrap()`.
    let code = gen_bridge_registration_fn(&spec, &generator)
        .expect("registration fn should be generated when register_fn is configured");
    assert!(
        code.contains("register_ocr_backend"),
        "missing register fn name in:\n{code}"
    );
}
1667
#[test]
fn test_gen_bridge_all_includes_imports() {
    // The combined output carries both pyo3 imports expected from the mock generator.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    for expected in ["pyo3::prelude::*", "pyo3::types::PyString"] {
        assert!(output.imports.contains(&expected.to_string()));
    }
}
1682
#[test]
fn test_gen_bridge_all_includes_wrapper_struct() {
    // The combined output includes the wrapper struct declaration.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(code.contains("pub struct PyOcrBackendBridge"), "missing struct in:\n{code}");
}
1696
#[test]
fn test_gen_bridge_all_includes_constructor() {
    // The combined output includes the generator-provided constructor.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(code.contains("pub fn new("), "missing constructor in:\n{code}");
}
1710
#[test]
fn test_gen_bridge_all_includes_trait_impl() {
    // The combined output includes the trait impl for the wrapper.
    let process = make_method("process", Vec::new(), TypeRef::String, false, false, None, None);
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    let has_impl = code.contains("impl mylib::OcrBackend for PyOcrBackendBridge");
    assert!(has_impl, "missing trait impl in:\n{code}");
}
1733
#[test]
fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
    // A configured super-trait adds its impl to the combined output.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    let has_plugin = code.contains("impl mylib::Plugin for PyOcrBackendBridge");
    assert!(has_plugin, "missing plugin impl in:\n{code}");
}
1747
#[test]
fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
    // Without a super-trait, none of the plugin lifecycle methods may be emitted.
    //
    // The wrapper struct always declares a `cached_name` field, so the previous
    // `!name || !cached_name` form only ever checked `fn name(`. Check each
    // plugin-only method signature explicitly and show the generated code on failure.
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    for plugin_fn in ["fn name(", "fn version(", "fn initialize(", "fn shutdown("] {
        assert!(
            !output.code.contains(plugin_fn),
            "unexpected plugin impl present without super_trait (found `{plugin_fn}`) in:\n{}",
            output.code
        );
    }
}
1760
#[test]
fn test_gen_bridge_all_includes_registration_fn_when_configured() {
    // A configured `register_fn` is reflected in the combined output.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(code.contains("register_ocr_backend"), "missing registration fn in:\n{code}");
}
1774
#[test]
fn test_gen_bridge_all_no_registration_fn_when_absent() {
    // Without `register_fn`, the registration name never appears.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    let mentions_register = code.contains("register_ocr_backend");
    assert!(!mentions_register, "unexpected registration fn present:\n{code}");
}
1788
#[test]
fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
    // The struct must be declared before the trait impl that targets it.
    let ocr = make_type_def("OcrBackend", "mylib::OcrBackend", Vec::new());
    let bridge_cfg = make_trait_bridge_config(None, None);
    let spec = make_spec(&ocr, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&spec, &MockBridgeGenerator);
    let code = &output.code;
    let struct_pos = code.find("pub struct PyOcrBackendBridge").unwrap();
    let impl_pos = code.find("impl mylib::OcrBackend for PyOcrBackendBridge").unwrap();
    assert!(impl_pos > struct_pos, "struct should appear before trait impl");
}
1803
1804 fn make_bridge(
1809 type_alias: Option<&str>,
1810 param_name: Option<&str>,
1811 bind_via: BridgeBinding,
1812 options_type: Option<&str>,
1813 options_field: Option<&str>,
1814 ) -> TraitBridgeConfig {
1815 TraitBridgeConfig {
1816 trait_name: "HtmlVisitor".to_string(),
1817 super_trait: None,
1818 registry_getter: None,
1819 register_fn: None,
1820 unregister_fn: None,
1821 clear_fn: None,
1822 type_alias: type_alias.map(str::to_string),
1823 param_name: param_name.map(str::to_string),
1824 register_extra_args: None,
1825 exclude_languages: vec![],
1826 bind_via,
1827 options_type: options_type.map(str::to_string),
1828 options_field: options_field.map(str::to_string),
1829 }
1830 }
1831
#[test]
fn find_bridge_param_returns_first_param_match_in_function_param_mode() {
    // In function-param mode, the `visitor` parameter (index 1) is reported as the match.
    let params = vec![
        make_param("html", TypeRef::String, true),
        make_param("visitor", TypeRef::Named(String::from("VisitorHandle")), false),
    ];
    let func = make_func("convert", params);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    );
    let bridges = vec![bridge];
    let matched_index = find_bridge_param(&func, &bridges).expect("bridge match").0;
    assert_eq!(matched_index, 1);
}
1851
#[test]
fn find_bridge_param_skips_options_field_bridges() {
    // A bridge bound via an options field is ignored, even if a parameter name matches.
    let params = vec![
        make_param("html", TypeRef::String, true),
        make_param("visitor", TypeRef::Named(String::from("VisitorHandle")), false),
    ];
    let func = make_func("convert", params);
    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        Some("visitor"),
    );
    let bridges = vec![bridge];
    let found = find_bridge_param(&func, &bridges);
    assert!(
        found.is_none(),
        "bridges configured with bind_via=options_field must not be returned by find_bridge_param"
    );
}
1873
#[test]
fn find_bridge_field_detects_field_via_alias() {
    // The `visitor` field of an optional `ConversionOptions` parameter is found
    // through the `VisitorHandle` type alias, even without an explicit options_field.
    let debug_field = make_field("debug", TypeRef::Primitive(PrimitiveType::Bool));
    let visitor_field = make_field(
        "visitor",
        TypeRef::Optional(Box::new(TypeRef::Named(String::from("VisitorHandle")))),
    );
    let opts_type = TypeDef {
        name: String::from("ConversionOptions"),
        rust_path: String::from("mylib::ConversionOptions"),
        original_rust_path: String::new(),
        doc: String::new(),
        fields: vec![debug_field, visitor_field],
        methods: Vec::new(),
        super_traits: Vec::new(),
        cfg: None,
        serde_rename_all: None,
        is_opaque: false,
        is_copy: false,
        is_trait: false,
        is_return_type: false,
        has_stripped_cfg_fields: false,
        has_serde: false,
        is_clone: true,
        has_default: true,
    };

    let options_param = make_param(
        "options",
        TypeRef::Optional(Box::new(TypeRef::Named(String::from("ConversionOptions")))),
        false,
    );
    let func = make_func("convert", vec![make_param("html", TypeRef::String, true), options_param]);

    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        None,
    );
    let bridges = vec![bridge];

    let hit = find_bridge_field(&func, std::slice::from_ref(&opts_type), &bridges).expect("bridge field match");
    assert_eq!(hit.param_index, 1);
    assert_eq!(hit.param_name, "options");
    assert_eq!(hit.options_type, "ConversionOptions");
    assert!(hit.param_is_optional);
    assert_eq!(hit.field_name, "visitor");
}
1926
#[test]
fn find_bridge_field_returns_none_for_function_param_bridge() {
    // A bridge bound as a function parameter is never reported as an options-field match.
    let visitor_field = make_field(
        "visitor",
        TypeRef::Optional(Box::new(TypeRef::Named(String::from("VisitorHandle")))),
    );
    let opts_type = TypeDef {
        name: String::from("ConversionOptions"),
        rust_path: String::from("mylib::ConversionOptions"),
        original_rust_path: String::new(),
        doc: String::new(),
        fields: vec![visitor_field],
        methods: Vec::new(),
        super_traits: Vec::new(),
        cfg: None,
        serde_rename_all: None,
        is_opaque: false,
        is_copy: false,
        is_trait: false,
        is_return_type: false,
        has_stripped_cfg_fields: false,
        has_serde: false,
        is_clone: true,
        has_default: true,
    };

    let options_param = make_param("options", TypeRef::Named(String::from("ConversionOptions")), false);
    let func = make_func("convert", vec![options_param]);

    let bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    );
    let bridges = vec![bridge];

    let found = find_bridge_field(&func, std::slice::from_ref(&opts_type), &bridges);
    assert!(found.is_none());
}
1968}