1use alef_core::config::{BridgeBinding, TraitBridgeConfig};
9use alef_core::ir::{FieldDef, FunctionDef, MethodDef, ParamDef, PrimitiveType, TypeDef, TypeRef};
10use heck::ToSnakeCase;
11use std::collections::HashMap;
12
/// Inputs shared by the bridge-generation functions in this module for one bridged trait.
pub struct TraitBridgeSpec<'a> {
    /// IR definition of the trait being bridged; its `name`, `rust_path`, and `methods`
    /// drive what gets generated.
    pub trait_def: &'a TypeDef,
    /// Bridge configuration for this trait (super-trait, registry getter, and the names of
    /// the register / unregister / clear functions).
    pub bridge_config: &'a TraitBridgeConfig,
    /// Path prefix used to qualify a bare `error_type`, a bare super-trait name, and the
    /// fallback host function path (`<core_import>::plugins::<fn>`).
    pub core_import: &'a str,
    /// Prefix prepended to the trait name to form the wrapper name (`{prefix}{Trait}Bridge`).
    pub wrapper_prefix: &'a str,
    /// Map from IR type name to the fully-qualified Rust path used when rendering `Named` types.
    pub type_paths: HashMap<String, String>,
    /// Error type used in generated fallible signatures; qualified with `core_import`
    /// unless it already contains `::` or `<`.
    pub error_type: String,
    /// Expression template for constructing an error; every `{msg}` occurrence is replaced
    /// by the message expression passed to `make_error`.
    pub error_constructor: String,
}
30
31impl<'a> TraitBridgeSpec<'a> {
32 pub fn error_path(&self) -> String {
39 if self.error_type.contains("::") || self.error_type.contains('<') {
40 self.error_type.clone()
41 } else {
42 format!("{}::{}", self.core_import, self.error_type)
43 }
44 }
45
46 pub fn make_error(&self, msg_expr: &str) -> String {
48 self.error_constructor.replace("{msg}", msg_expr)
49 }
50
51 pub fn wrapper_name(&self) -> String {
53 format!("{}{}Bridge", self.wrapper_prefix, self.trait_def.name)
54 }
55
56 pub fn trait_snake(&self) -> String {
58 self.trait_def.name.to_snake_case()
59 }
60
61 pub fn trait_path(&self) -> String {
63 self.trait_def.rust_path.replace('-', "_")
64 }
65
66 pub fn required_methods(&self) -> Vec<&'a MethodDef> {
68 self.trait_def.methods.iter().filter(|m| !m.has_default_impl).collect()
69 }
70
71 pub fn optional_methods(&self) -> Vec<&'a MethodDef> {
73 self.trait_def.methods.iter().filter(|m| m.has_default_impl).collect()
74 }
75}
76
/// Language-specific hooks used by the generic bridge generators in this module.
///
/// Implementors supply the foreign object type and the generated method bodies; the
/// shared functions in this module assemble the wrapper struct, impls, and free functions.
pub trait TraitBridgeGenerator {
    /// Rust type of the foreign object held by the wrapper; passed as `foreign_type` to the
    /// `wrapper_struct.jinja` template.
    fn foreign_object_type(&self) -> &str;

    /// Import entries the generated code needs; returned unchanged in `BridgeOutput::imports`.
    fn bridge_imports(&self) -> Vec<String>;

    /// Body of a non-`async` trait method. The bridge helpers also call this for the
    /// synthetic `version`, `initialize`, and `shutdown` methods when a super-trait is configured.
    fn gen_sync_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Body of an `async` trait method.
    fn gen_async_method_body(&self, method: &MethodDef, spec: &TraitBridgeSpec) -> String;

    /// Constructor code for the wrapper struct; emitted verbatim by `gen_bridge_all`.
    fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String;

    /// Registration function code. The bridge helpers only call this when
    /// `bridge_config.register_fn` is set.
    fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String;

    /// Unregistration function code. The bridge helpers only call this when
    /// `bridge_config.unregister_fn` is set; an empty string (the default) emits nothing.
    fn gen_unregistration_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::new()
    }

    /// Clear-all function code. The bridge helpers only call this when
    /// `bridge_config.clear_fn` is set; an empty string (the default) emits nothing.
    fn gen_clear_fn(&self, _spec: &TraitBridgeSpec) -> String {
        String::new()
    }

    /// Value passed as `async_trait_is_send` to the `trait_impl.jinja` template; defaults to `true`.
    // NOTE(review): presumably selects `#[async_trait]` vs `#[async_trait(?Send)]` — confirm in the template.
    fn async_trait_is_send(&self) -> bool {
        true
    }
}
141
142pub fn gen_bridge_wrapper_struct(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
156 let wrapper = spec.wrapper_name();
157 let foreign_type = generator.foreign_object_type();
158
159 crate::template_env::render(
160 "generators/trait_bridge/wrapper_struct.jinja",
161 minijinja::context! {
162 wrapper_prefix => spec.wrapper_prefix,
163 trait_name => &spec.trait_def.name,
164 wrapper_name => wrapper,
165 foreign_type => foreign_type,
166 },
167 )
168}
169
170fn gen_bridge_debug_impl(spec: &TraitBridgeSpec) -> String {
176 let wrapper = spec.wrapper_name();
177 format!(
178 "impl std::fmt::Debug for {wrapper} {{\n fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{\n write!(f, \"{wrapper}\")\n }}\n}}"
179 )
180}
181
182pub fn gen_bridge_plugin_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
190 let super_trait_name = spec.bridge_config.super_trait.as_deref()?;
191
192 let wrapper = spec.wrapper_name();
193 let core_import = spec.core_import;
194
195 let super_trait_path = if super_trait_name.contains("::") {
197 super_trait_name.to_string()
198 } else {
199 format!("{core_import}::{super_trait_name}")
200 };
201
202 let error_path = spec.error_path();
206
207 let version_method = MethodDef {
209 name: "version".to_string(),
210 params: vec![],
211 return_type: alef_core::ir::TypeRef::String,
212 is_async: false,
213 is_static: false,
214 error_type: None,
215 doc: String::new(),
216 receiver: Some(alef_core::ir::ReceiverKind::Ref),
217 sanitized: false,
218 trait_source: None,
219 returns_ref: false,
220 returns_cow: false,
221 return_newtype_wrapper: None,
222 has_default_impl: false,
223 };
224 let version_body = generator.gen_sync_method_body(&version_method, spec);
225
226 let init_method = MethodDef {
228 name: "initialize".to_string(),
229 params: vec![],
230 return_type: alef_core::ir::TypeRef::Unit,
231 is_async: false,
232 is_static: false,
233 error_type: Some(error_path.clone()),
234 doc: String::new(),
235 receiver: Some(alef_core::ir::ReceiverKind::Ref),
236 sanitized: false,
237 trait_source: None,
238 returns_ref: false,
239 returns_cow: false,
240 return_newtype_wrapper: None,
241 has_default_impl: true,
242 };
243 let init_body = generator.gen_sync_method_body(&init_method, spec);
244
245 let shutdown_method = MethodDef {
247 name: "shutdown".to_string(),
248 params: vec![],
249 return_type: alef_core::ir::TypeRef::Unit,
250 is_async: false,
251 is_static: false,
252 error_type: Some(error_path.clone()),
253 doc: String::new(),
254 receiver: Some(alef_core::ir::ReceiverKind::Ref),
255 sanitized: false,
256 trait_source: None,
257 returns_ref: false,
258 returns_cow: false,
259 return_newtype_wrapper: None,
260 has_default_impl: true,
261 };
262 let shutdown_body = generator.gen_sync_method_body(&shutdown_method, spec);
263
264 let version_lines: Vec<&str> = version_body.lines().collect();
266 let init_lines: Vec<&str> = init_body.lines().collect();
267 let shutdown_lines: Vec<&str> = shutdown_body.lines().collect();
268
269 Some(crate::template_env::render(
270 "generators/trait_bridge/plugin_impl.jinja",
271 minijinja::context! {
272 super_trait_path => super_trait_path,
273 wrapper_name => wrapper,
274 error_path => error_path,
275 version_lines => version_lines,
276 init_lines => init_lines,
277 shutdown_lines => shutdown_lines,
278 },
279 ))
280}
281
282pub fn gen_bridge_trait_impl(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> String {
287 let wrapper = spec.wrapper_name();
288 let trait_path = spec.trait_path();
289
290 let has_async_methods = spec
292 .trait_def
293 .methods
294 .iter()
295 .any(|m| m.is_async && m.trait_source.is_none());
296 let async_trait_is_send = generator.async_trait_is_send();
297
298 let own_methods: Vec<_> = spec
300 .trait_def
301 .methods
302 .iter()
303 .filter(|m| m.trait_source.is_none())
304 .collect();
305
306 let mut methods_code = String::with_capacity(1024);
308 for (i, method) in own_methods.iter().enumerate() {
309 if i > 0 {
310 methods_code.push_str("\n\n");
311 }
312
313 let async_kw = if method.is_async { "async " } else { "" };
315 let receiver = match &method.receiver {
316 Some(alef_core::ir::ReceiverKind::Ref) => "&self",
317 Some(alef_core::ir::ReceiverKind::RefMut) => "&mut self",
318 Some(alef_core::ir::ReceiverKind::Owned) => "self",
319 None => "",
320 };
321
322 let params: Vec<String> = method
324 .params
325 .iter()
326 .map(|p| format!("{}: {}", p.name, format_param_type(p, &spec.type_paths)))
327 .collect();
328
329 let all_params = if receiver.is_empty() {
330 params.join(", ")
331 } else if params.is_empty() {
332 receiver.to_string()
333 } else {
334 format!("{}, {}", receiver, params.join(", "))
335 };
336
337 let error_override = method.error_type.as_ref().map(|_| spec.error_path());
341 let ret = format_return_type(&method.return_type, error_override.as_deref(), &spec.type_paths);
342
343 let body = if method.is_async {
345 generator.gen_async_method_body(method, spec)
346 } else {
347 generator.gen_sync_method_body(method, spec)
348 };
349
350 let indented_body = body
352 .lines()
353 .map(|line| format!(" {line}"))
354 .collect::<Vec<_>>()
355 .join("\n");
356
357 methods_code.push_str(&crate::template_env::render(
358 "generators/trait_bridge/trait_method.jinja",
359 minijinja::context! {
360 async_kw => async_kw,
361 method_name => &method.name,
362 all_params => all_params,
363 ret => ret,
364 indented_body => &indented_body,
365 },
366 ));
367 }
368
369 crate::template_env::render(
370 "generators/trait_bridge/trait_impl.jinja",
371 minijinja::context! {
372 has_async_methods => has_async_methods,
373 async_trait_is_send => async_trait_is_send,
374 trait_path => trait_path,
375 wrapper_name => wrapper,
376 methods_code => methods_code,
377 },
378 )
379}
380
381pub fn gen_bridge_registration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
388 spec.bridge_config.register_fn.as_deref()?;
389 Some(generator.gen_registration_fn(spec))
390}
391
392pub fn gen_bridge_unregistration_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
399 spec.bridge_config.unregister_fn.as_deref()?;
400 let body = generator.gen_unregistration_fn(spec);
401 if body.is_empty() { None } else { Some(body) }
402}
403
404pub fn gen_bridge_clear_fn(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> Option<String> {
411 spec.bridge_config.clear_fn.as_deref()?;
412 let body = generator.gen_clear_fn(spec);
413 if body.is_empty() { None } else { Some(body) }
414}
415
416pub fn host_function_path(spec: &TraitBridgeSpec, fn_name: &str) -> String {
434 if let Some(getter) = spec.bridge_config.registry_getter.as_deref() {
435 let last = getter.rsplit("::").next().unwrap_or("");
436 if let Some(sub) = last.strip_prefix("get_").and_then(|s| s.strip_suffix("_registry")) {
437 let prefix_end = getter.len() - last.len();
438 let prefix = &getter[..prefix_end];
439 let prefix = prefix.trim_end_matches("registry::");
440 return format!("{prefix}{sub}::{fn_name}");
441 }
442 }
443 format!("{}::plugins::{}", spec.core_import, fn_name)
444}
445
/// Generated bridge source plus the imports it needs, as returned by `gen_bridge_all`.
pub struct BridgeOutput {
    /// Import entries reported by the generator's `bridge_imports` hook.
    pub imports: Vec<String>,
    /// Rust source for the wrapper struct, its impls, and any configured free functions,
    /// separated by blank lines.
    pub code: String,
}
454
455pub fn gen_bridge_all(spec: &TraitBridgeSpec, generator: &dyn TraitBridgeGenerator) -> BridgeOutput {
461 let imports = generator.bridge_imports();
462 let mut out = String::with_capacity(4096);
463
464 out.push_str(&gen_bridge_wrapper_struct(spec, generator));
466 out.push_str("\n\n");
467
468 out.push_str(&gen_bridge_debug_impl(spec));
470 out.push_str("\n\n");
471
472 out.push_str(&generator.gen_constructor(spec));
474 out.push_str("\n\n");
475
476 if let Some(plugin_impl) = gen_bridge_plugin_impl(spec, generator) {
478 out.push_str(&plugin_impl);
479 out.push_str("\n\n");
480 }
481
482 out.push_str(&gen_bridge_trait_impl(spec, generator));
484
485 if let Some(reg_fn_code) = gen_bridge_registration_fn(spec, generator) {
487 out.push_str("\n\n");
488 out.push_str(®_fn_code);
489 }
490
491 if let Some(unreg_fn_code) = gen_bridge_unregistration_fn(spec, generator) {
494 out.push_str("\n\n");
495 out.push_str(&unreg_fn_code);
496 }
497
498 if let Some(clear_fn_code) = gen_bridge_clear_fn(spec, generator) {
501 out.push_str("\n\n");
502 out.push_str(&clear_fn_code);
503 }
504
505 BridgeOutput { imports, code: out }
506}
507
508pub fn format_type_ref(ty: &alef_core::ir::TypeRef, type_paths: &HashMap<String, String>) -> String {
517 use alef_core::ir::{PrimitiveType, TypeRef};
518 match ty {
519 TypeRef::Primitive(p) => match p {
520 PrimitiveType::Bool => "bool",
521 PrimitiveType::U8 => "u8",
522 PrimitiveType::U16 => "u16",
523 PrimitiveType::U32 => "u32",
524 PrimitiveType::U64 => "u64",
525 PrimitiveType::I8 => "i8",
526 PrimitiveType::I16 => "i16",
527 PrimitiveType::I32 => "i32",
528 PrimitiveType::I64 => "i64",
529 PrimitiveType::F32 => "f32",
530 PrimitiveType::F64 => "f64",
531 PrimitiveType::Usize => "usize",
532 PrimitiveType::Isize => "isize",
533 }
534 .to_string(),
535 TypeRef::String => "String".to_string(),
536 TypeRef::Char => "char".to_string(),
537 TypeRef::Bytes => "Vec<u8>".to_string(),
538 TypeRef::Optional(inner) => format!("Option<{}>", format_type_ref(inner, type_paths)),
539 TypeRef::Vec(inner) => format!("Vec<{}>", format_type_ref(inner, type_paths)),
540 TypeRef::Map(k, v) => format!(
541 "std::collections::HashMap<{}, {}>",
542 format_type_ref(k, type_paths),
543 format_type_ref(v, type_paths)
544 ),
545 TypeRef::Named(name) => type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone()),
546 TypeRef::Path => "std::path::PathBuf".to_string(),
547 TypeRef::Unit => "()".to_string(),
548 TypeRef::Json => "serde_json::Value".to_string(),
549 TypeRef::Duration => "std::time::Duration".to_string(),
550 }
551}
552
553pub fn format_return_type(
555 ty: &alef_core::ir::TypeRef,
556 error_type: Option<&str>,
557 type_paths: &HashMap<String, String>,
558) -> String {
559 let inner = format_type_ref(ty, type_paths);
560 match error_type {
561 Some(err) => format!("std::result::Result<{inner}, {err}>"),
562 None => inner,
563 }
564}
565
566pub fn format_param_type(param: &ParamDef, type_paths: &HashMap<String, String>) -> String {
578 use alef_core::ir::TypeRef;
579 let base = if param.is_ref {
580 let mutability = if param.is_mut { "mut " } else { "" };
581 match ¶m.ty {
582 TypeRef::String => format!("&{mutability}str"),
583 TypeRef::Bytes => format!("&{mutability}[u8]"),
584 TypeRef::Path => format!("&{mutability}std::path::Path"),
585 TypeRef::Vec(inner) => format!("&{mutability}[{}]", format_type_ref(inner, type_paths)),
586 TypeRef::Named(name) => {
587 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
588 format!("&{mutability}{qualified}")
589 }
590 TypeRef::Optional(inner) => {
591 let inner_type_str = match inner.as_ref() {
595 TypeRef::String => format!("&{mutability}str"),
596 TypeRef::Bytes => format!("&{mutability}[u8]"),
597 TypeRef::Path => format!("&{mutability}std::path::Path"),
598 TypeRef::Vec(v) => format!("&{mutability}[{}]", format_type_ref(v, type_paths)),
599 TypeRef::Named(name) => {
600 let qualified = type_paths.get(name.as_str()).cloned().unwrap_or_else(|| name.clone());
601 format!("&{mutability}{qualified}")
602 }
603 other => format_type_ref(other, type_paths),
605 };
606 return format!("Option<{inner_type_str}>");
608 }
609 other => format_type_ref(other, type_paths),
611 }
612 } else {
613 format_type_ref(¶m.ty, type_paths)
614 };
615
616 if param.optional {
620 format!("Option<{base}>")
621 } else {
622 base
623 }
624}
625
626pub fn prim(p: &PrimitiveType) -> &'static str {
632 use PrimitiveType::*;
633 match p {
634 Bool => "bool",
635 U8 => "u8",
636 U16 => "u16",
637 U32 => "u32",
638 U64 => "u64",
639 I8 => "i8",
640 I16 => "i16",
641 I32 => "i32",
642 I64 => "i64",
643 F32 => "f32",
644 F64 => "f64",
645 Usize => "usize",
646 Isize => "isize",
647 }
648}
649
650pub fn bridge_param_type(ty: &TypeRef, ci: &str, is_ref: bool, tp: &HashMap<String, String>) -> String {
654 match ty {
655 TypeRef::Bytes if is_ref => "&[u8]".into(),
656 TypeRef::Bytes => "Vec<u8>".into(),
657 TypeRef::String if is_ref => "&str".into(),
658 TypeRef::String => "String".into(),
659 TypeRef::Path if is_ref => "&std::path::Path".into(),
660 TypeRef::Path => "std::path::PathBuf".into(),
661 TypeRef::Named(n) => {
662 let qualified = tp.get(n).cloned().unwrap_or_else(|| format!("{ci}::{n}"));
663 if is_ref { format!("&{qualified}") } else { qualified }
664 }
665 TypeRef::Vec(inner) => format!("Vec<{}>", bridge_param_type(inner, ci, false, tp)),
666 TypeRef::Optional(inner) => format!("Option<{}>", bridge_param_type(inner, ci, false, tp)),
667 TypeRef::Primitive(p) => prim(p).into(),
668 TypeRef::Unit => "()".into(),
669 TypeRef::Char => "char".into(),
670 TypeRef::Map(k, v) => format!(
671 "std::collections::HashMap<{}, {}>",
672 bridge_param_type(k, ci, false, tp),
673 bridge_param_type(v, ci, false, tp)
674 ),
675 TypeRef::Json => "serde_json::Value".into(),
676 TypeRef::Duration => "std::time::Duration".into(),
677 }
678}
679
680pub fn visitor_param_type(ty: &TypeRef, is_ref: bool, optional: bool, tp: &HashMap<String, String>) -> String {
686 if optional && matches!(ty, TypeRef::String) && is_ref {
687 return "Option<&str>".to_string();
688 }
689 if is_ref {
690 if let TypeRef::Vec(inner) = ty {
691 let inner_str = bridge_param_type(inner, "", false, tp);
692 return format!("&[{inner_str}]");
693 }
694 }
695 bridge_param_type(ty, "", is_ref, tp)
696}
697
698pub fn find_bridge_param<'a>(
705 func: &FunctionDef,
706 bridges: &'a [TraitBridgeConfig],
707) -> Option<(usize, &'a TraitBridgeConfig)> {
708 for (idx, param) in func.params.iter().enumerate() {
709 let named = match ¶m.ty {
710 TypeRef::Named(n) => Some(n.as_str()),
711 TypeRef::Optional(inner) => {
712 if let TypeRef::Named(n) = inner.as_ref() {
713 Some(n.as_str())
714 } else {
715 None
716 }
717 }
718 _ => None,
719 };
720 for bridge in bridges {
721 if bridge.bind_via != BridgeBinding::FunctionParam {
722 continue;
723 }
724 if let Some(type_name) = named {
725 if bridge.type_alias.as_deref() == Some(type_name) {
726 return Some((idx, bridge));
727 }
728 }
729 if bridge.param_name.as_deref() == Some(param.name.as_str()) {
730 return Some((idx, bridge));
731 }
732 }
733 }
734 None
735}
736
/// Result of `find_bridge_field`: a function parameter whose options type carries a field
/// that a trait bridge binds through (`BridgeBinding::OptionsField`).
#[derive(Debug, Clone)]
pub struct BridgeFieldMatch<'a> {
    /// Index of the matched parameter in the function's parameter list.
    pub param_index: usize,
    /// Name of the matched parameter.
    pub param_name: String,
    /// Name of the options type the parameter refers to.
    pub options_type: String,
    /// `true` when the parameter's IR type is `Optional(Named(options_type))`.
    pub param_is_optional: bool,
    /// Name of the matched field on the options type.
    pub field_name: String,
    /// The matched field definition.
    pub field: &'a FieldDef,
    /// The bridge configuration that selected this field.
    pub bridge: &'a TraitBridgeConfig,
}
756
757pub fn find_bridge_field<'a>(
768 func: &FunctionDef,
769 types: &'a [TypeDef],
770 bridges: &'a [TraitBridgeConfig],
771) -> Option<BridgeFieldMatch<'a>> {
772 fn unwrap_named(ty: &TypeRef) -> Option<(&str, bool)> {
773 match ty {
774 TypeRef::Named(n) => Some((n.as_str(), false)),
775 TypeRef::Optional(inner) => {
776 if let TypeRef::Named(n) = inner.as_ref() {
777 Some((n.as_str(), true))
778 } else {
779 None
780 }
781 }
782 _ => None,
783 }
784 }
785
786 for (idx, param) in func.params.iter().enumerate() {
787 let Some((type_name, is_optional)) = unwrap_named(¶m.ty) else {
788 continue;
789 };
790 let Some(type_def) = types.iter().find(|t| t.name == type_name) else {
791 continue;
792 };
793 for bridge in bridges {
794 if bridge.bind_via != BridgeBinding::OptionsField {
795 continue;
796 }
797 if bridge.options_type.as_deref() != Some(type_name) {
798 continue;
799 }
800 let field_name = bridge.resolved_options_field();
801 for field in &type_def.fields {
802 let matches_name = field_name.is_some_and(|n| field.name == n);
803 let matches_alias = bridge
804 .type_alias
805 .as_deref()
806 .is_some_and(|alias| field_type_matches_alias(&field.ty, alias));
807 if matches_name || matches_alias {
808 return Some(BridgeFieldMatch {
809 param_index: idx,
810 param_name: param.name.clone(),
811 options_type: type_name.to_string(),
812 param_is_optional: is_optional,
813 field_name: field.name.clone(),
814 field,
815 bridge,
816 });
817 }
818 }
819 }
820 }
821 None
822}
823
824fn field_type_matches_alias(field_ty: &TypeRef, alias: &str) -> bool {
827 match field_ty {
828 TypeRef::Named(n) => n == alias,
829 TypeRef::Optional(inner) | TypeRef::Vec(inner) => field_type_matches_alias(inner, alias),
830 _ => false,
831 }
832}
833
/// Convert `snake_case` to `lowerCamelCase`.
///
/// Underscores are dropped and the character after a run of underscores is ASCII-uppercased;
/// all other characters are kept unchanged (so a leading `_x` becomes `X`).
pub fn to_camel_case(s: &str) -> String {
    let mut camel = String::with_capacity(s.len());
    let mut after_underscore = false;
    for ch in s.chars() {
        match ch {
            '_' => after_underscore = true,
            _ if after_underscore => {
                camel.push(ch.to_ascii_uppercase());
                after_underscore = false;
            }
            _ => camel.push(ch),
        }
    }
    camel
}
850
851#[cfg(test)]
852mod tests {
853 use super::*;
854 use alef_core::config::TraitBridgeConfig;
855 use alef_core::ir::{MethodDef, ParamDef, PrimitiveType, ReceiverKind, TypeDef, TypeRef};
856
857 fn make_trait_bridge_config(super_trait: Option<&str>, register_fn: Option<&str>) -> TraitBridgeConfig {
862 TraitBridgeConfig {
863 trait_name: "OcrBackend".to_string(),
864 super_trait: super_trait.map(str::to_string),
865 registry_getter: None,
866 register_fn: register_fn.map(str::to_string),
867 unregister_fn: None,
868 clear_fn: None,
869 type_alias: None,
870 param_name: None,
871 register_extra_args: None,
872 exclude_languages: Vec::new(),
873 bind_via: BridgeBinding::FunctionParam,
874 options_type: None,
875 options_field: None,
876 }
877 }
878
879 fn make_type_def(name: &str, rust_path: &str, methods: Vec<MethodDef>) -> TypeDef {
880 TypeDef {
881 name: name.to_string(),
882 rust_path: rust_path.to_string(),
883 original_rust_path: rust_path.to_string(),
884 fields: vec![],
885 methods,
886 is_opaque: true,
887 is_clone: false,
888 is_copy: false,
889 doc: String::new(),
890 cfg: None,
891 is_trait: true,
892 has_default: false,
893 has_stripped_cfg_fields: false,
894 is_return_type: false,
895 serde_rename_all: None,
896 has_serde: false,
897 super_traits: vec![],
898 }
899 }
900
901 fn make_method(
902 name: &str,
903 params: Vec<ParamDef>,
904 return_type: TypeRef,
905 is_async: bool,
906 has_default_impl: bool,
907 trait_source: Option<&str>,
908 error_type: Option<&str>,
909 ) -> MethodDef {
910 MethodDef {
911 name: name.to_string(),
912 params,
913 return_type,
914 is_async,
915 is_static: false,
916 error_type: error_type.map(str::to_string),
917 doc: String::new(),
918 receiver: Some(ReceiverKind::Ref),
919 sanitized: false,
920 trait_source: trait_source.map(str::to_string),
921 returns_ref: false,
922 returns_cow: false,
923 return_newtype_wrapper: None,
924 has_default_impl,
925 }
926 }
927
928 fn make_func(name: &str, params: Vec<ParamDef>) -> FunctionDef {
929 FunctionDef {
930 name: name.to_string(),
931 rust_path: format!("mylib::{name}"),
932 original_rust_path: String::new(),
933 params,
934 return_type: TypeRef::Unit,
935 is_async: false,
936 error_type: None,
937 doc: String::new(),
938 cfg: None,
939 sanitized: false,
940 return_sanitized: false,
941 returns_ref: false,
942 returns_cow: false,
943 return_newtype_wrapper: None,
944 }
945 }
946
947 fn make_field(name: &str, ty: TypeRef) -> FieldDef {
948 FieldDef {
949 name: name.to_string(),
950 ty,
951 optional: false,
952 default: None,
953 doc: String::new(),
954 sanitized: false,
955 is_boxed: false,
956 type_rust_path: None,
957 cfg: None,
958 typed_default: None,
959 core_wrapper: Default::default(),
960 vec_inner_core_wrapper: Default::default(),
961 newtype_wrapper: None,
962 }
963 }
964
965 fn make_param(name: &str, ty: TypeRef, is_ref: bool) -> ParamDef {
966 ParamDef {
967 name: name.to_string(),
968 ty,
969 optional: false,
970 default: None,
971 sanitized: false,
972 typed_default: None,
973 is_ref,
974 is_mut: false,
975 newtype_wrapper: None,
976 original_type: None,
977 }
978 }
979
980 fn make_spec<'a>(
981 trait_def: &'a TypeDef,
982 bridge_config: &'a TraitBridgeConfig,
983 wrapper_prefix: &'a str,
984 type_paths: HashMap<String, String>,
985 ) -> TraitBridgeSpec<'a> {
986 TraitBridgeSpec {
987 trait_def,
988 bridge_config,
989 core_import: "mylib",
990 wrapper_prefix,
991 type_paths,
992 error_type: "MyError".to_string(),
993 error_constructor: "MyError::from({msg})".to_string(),
994 }
995 }
996
997 struct MockBridgeGenerator;
1002
1003 impl TraitBridgeGenerator for MockBridgeGenerator {
1004 fn foreign_object_type(&self) -> &str {
1005 "Py<PyAny>"
1006 }
1007
1008 fn bridge_imports(&self) -> Vec<String> {
1009 vec!["pyo3::prelude::*".to_string(), "pyo3::types::PyString".to_string()]
1010 }
1011
1012 fn gen_sync_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1013 format!("// sync body for {}", method.name)
1014 }
1015
1016 fn gen_async_method_body(&self, method: &MethodDef, _spec: &TraitBridgeSpec) -> String {
1017 format!("// async body for {}", method.name)
1018 }
1019
1020 fn gen_constructor(&self, spec: &TraitBridgeSpec) -> String {
1021 format!(
1022 "impl {} {{\n pub fn new(obj: Py<PyAny>) -> Self {{ Self {{ inner: obj, cached_name: String::new() }} }}\n}}",
1023 spec.wrapper_name()
1024 )
1025 }
1026
1027 fn gen_registration_fn(&self, spec: &TraitBridgeSpec) -> String {
1028 let fn_name = spec.bridge_config.register_fn.as_deref().unwrap_or("register");
1029 format!("pub fn {fn_name}(obj: Py<PyAny>) {{ /* register */ }}")
1030 }
1031 }
1032
1033 #[test]
1038 fn test_wrapper_name() {
1039 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1040 let config = make_trait_bridge_config(None, None);
1041 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1042 assert_eq!(spec.wrapper_name(), "PyOcrBackendBridge");
1043 }
1044
1045 #[test]
1046 fn test_trait_snake() {
1047 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1048 let config = make_trait_bridge_config(None, None);
1049 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1050 assert_eq!(spec.trait_snake(), "ocr_backend");
1051 }
1052
1053 #[test]
1054 fn test_trait_path_replaces_hyphens() {
1055 let trait_def = make_type_def("OcrBackend", "my-lib::OcrBackend", vec![]);
1056 let config = make_trait_bridge_config(None, None);
1057 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1058 assert_eq!(spec.trait_path(), "my_lib::OcrBackend");
1059 }
1060
1061 #[test]
1062 fn test_required_methods_filters_no_default_impl() {
1063 let methods = vec![
1064 make_method("process", vec![], TypeRef::String, false, false, None, None),
1065 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1066 make_method("detect", vec![], TypeRef::String, false, false, None, None),
1067 ];
1068 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1069 let config = make_trait_bridge_config(None, None);
1070 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1071 let required = spec.required_methods();
1072 assert_eq!(required.len(), 2);
1073 assert!(required.iter().any(|m| m.name == "process"));
1074 assert!(required.iter().any(|m| m.name == "detect"));
1075 }
1076
1077 #[test]
1078 fn test_optional_methods_filters_has_default_impl() {
1079 let methods = vec![
1080 make_method("process", vec![], TypeRef::String, false, false, None, None),
1081 make_method("initialize", vec![], TypeRef::Unit, false, true, None, None),
1082 make_method("shutdown", vec![], TypeRef::Unit, false, true, None, None),
1083 ];
1084 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", methods);
1085 let config = make_trait_bridge_config(None, None);
1086 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1087 let optional = spec.optional_methods();
1088 assert_eq!(optional.len(), 2);
1089 assert!(optional.iter().any(|m| m.name == "initialize"));
1090 assert!(optional.iter().any(|m| m.name == "shutdown"));
1091 }
1092
1093 #[test]
1094 fn test_error_path() {
1095 let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
1096 let config = make_trait_bridge_config(None, None);
1097 let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
1098 assert_eq!(spec.error_path(), "mylib::MyError");
1099 }
1100
1101 #[test]
1106 fn test_format_type_ref_primitives() {
1107 let paths = HashMap::new();
1108 let cases: Vec<(TypeRef, &str)> = vec![
1109 (TypeRef::Primitive(PrimitiveType::Bool), "bool"),
1110 (TypeRef::Primitive(PrimitiveType::U8), "u8"),
1111 (TypeRef::Primitive(PrimitiveType::U16), "u16"),
1112 (TypeRef::Primitive(PrimitiveType::U32), "u32"),
1113 (TypeRef::Primitive(PrimitiveType::U64), "u64"),
1114 (TypeRef::Primitive(PrimitiveType::I8), "i8"),
1115 (TypeRef::Primitive(PrimitiveType::I16), "i16"),
1116 (TypeRef::Primitive(PrimitiveType::I32), "i32"),
1117 (TypeRef::Primitive(PrimitiveType::I64), "i64"),
1118 (TypeRef::Primitive(PrimitiveType::F32), "f32"),
1119 (TypeRef::Primitive(PrimitiveType::F64), "f64"),
1120 (TypeRef::Primitive(PrimitiveType::Usize), "usize"),
1121 (TypeRef::Primitive(PrimitiveType::Isize), "isize"),
1122 ];
1123 for (ty, expected) in cases {
1124 assert_eq!(format_type_ref(&ty, &paths), expected, "mismatch for {expected}");
1125 }
1126 }
1127
1128 #[test]
1129 fn test_format_type_ref_string() {
1130 assert_eq!(format_type_ref(&TypeRef::String, &HashMap::new()), "String");
1131 }
1132
1133 #[test]
1134 fn test_format_type_ref_char() {
1135 assert_eq!(format_type_ref(&TypeRef::Char, &HashMap::new()), "char");
1136 }
1137
1138 #[test]
1139 fn test_format_type_ref_bytes() {
1140 assert_eq!(format_type_ref(&TypeRef::Bytes, &HashMap::new()), "Vec<u8>");
1141 }
1142
1143 #[test]
1144 fn test_format_type_ref_path() {
1145 assert_eq!(format_type_ref(&TypeRef::Path, &HashMap::new()), "std::path::PathBuf");
1146 }
1147
1148 #[test]
1149 fn test_format_type_ref_unit() {
1150 assert_eq!(format_type_ref(&TypeRef::Unit, &HashMap::new()), "()");
1151 }
1152
1153 #[test]
1154 fn test_format_type_ref_json() {
1155 assert_eq!(format_type_ref(&TypeRef::Json, &HashMap::new()), "serde_json::Value");
1156 }
1157
1158 #[test]
1159 fn test_format_type_ref_duration() {
1160 assert_eq!(
1161 format_type_ref(&TypeRef::Duration, &HashMap::new()),
1162 "std::time::Duration"
1163 );
1164 }
1165
1166 #[test]
1167 fn test_format_type_ref_optional() {
1168 let ty = TypeRef::Optional(Box::new(TypeRef::String));
1169 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<String>");
1170 }
1171
1172 #[test]
1173 fn test_format_type_ref_optional_nested() {
1174 let ty = TypeRef::Optional(Box::new(TypeRef::Optional(Box::new(TypeRef::Primitive(
1175 PrimitiveType::U32,
1176 )))));
1177 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Option<Option<u32>>");
1178 }
1179
1180 #[test]
1181 fn test_format_type_ref_vec() {
1182 let ty = TypeRef::Vec(Box::new(TypeRef::Primitive(PrimitiveType::U8)));
1183 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<u8>");
1184 }
1185
1186 #[test]
1187 fn test_format_type_ref_vec_nested() {
1188 let ty = TypeRef::Vec(Box::new(TypeRef::Vec(Box::new(TypeRef::String))));
1189 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Vec<Vec<String>>");
1190 }
1191
1192 #[test]
1193 fn test_format_type_ref_map() {
1194 let ty = TypeRef::Map(
1195 Box::new(TypeRef::String),
1196 Box::new(TypeRef::Primitive(PrimitiveType::I64)),
1197 );
1198 assert_eq!(
1199 format_type_ref(&ty, &HashMap::new()),
1200 "std::collections::HashMap<String, i64>"
1201 );
1202 }
1203
1204 #[test]
1205 fn test_format_type_ref_map_nested_value() {
1206 let ty = TypeRef::Map(
1207 Box::new(TypeRef::String),
1208 Box::new(TypeRef::Vec(Box::new(TypeRef::String))),
1209 );
1210 assert_eq!(
1211 format_type_ref(&ty, &HashMap::new()),
1212 "std::collections::HashMap<String, Vec<String>>"
1213 );
1214 }
1215
1216 #[test]
1217 fn test_format_type_ref_named_without_type_paths() {
1218 let ty = TypeRef::Named("Config".to_string());
1219 assert_eq!(format_type_ref(&ty, &HashMap::new()), "Config");
1220 }
1221
1222 #[test]
1223 fn test_format_type_ref_named_with_type_paths() {
1224 let ty = TypeRef::Named("Config".to_string());
1225 let mut paths = HashMap::new();
1226 paths.insert("Config".to_string(), "mylib::Config".to_string());
1227 assert_eq!(format_type_ref(&ty, &paths), "mylib::Config");
1228 }
1229
1230 #[test]
1231 fn test_format_type_ref_named_not_in_type_paths_falls_back_to_name() {
1232 let ty = TypeRef::Named("Unknown".to_string());
1233 let mut paths = HashMap::new();
1234 paths.insert("Other".to_string(), "mylib::Other".to_string());
1235 assert_eq!(format_type_ref(&ty, &paths), "Unknown");
1236 }
1237
1238 #[test]
1243 fn test_format_param_type_string_ref() {
1244 let param = make_param("input", TypeRef::String, true);
1245 assert_eq!(format_param_type(¶m, &HashMap::new()), "&str");
1246 }
1247
1248 #[test]
1249 fn test_format_param_type_string_owned() {
1250 let param = make_param("input", TypeRef::String, false);
1251 assert_eq!(format_param_type(¶m, &HashMap::new()), "String");
1252 }
1253
1254 #[test]
1255 fn test_format_param_type_bytes_ref() {
1256 let param = make_param("data", TypeRef::Bytes, true);
1257 assert_eq!(format_param_type(¶m, &HashMap::new()), "&[u8]");
1258 }
1259
1260 #[test]
1261 fn test_format_param_type_bytes_owned() {
1262 let param = make_param("data", TypeRef::Bytes, false);
1263 assert_eq!(format_param_type(¶m, &HashMap::new()), "Vec<u8>");
1264 }
1265
1266 #[test]
1267 fn test_format_param_type_path_ref() {
1268 let param = make_param("path", TypeRef::Path, true);
1269 assert_eq!(format_param_type(¶m, &HashMap::new()), "&std::path::Path");
1270 }
1271
#[test]
fn test_format_param_type_path_owned() {
    // Owned paths are rendered as the fully-qualified `std::path::PathBuf`.
    let param = make_param("path", TypeRef::Path, false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "std::path::PathBuf");
}
1277
#[test]
fn test_format_param_type_vec_ref() {
    // A borrowed `Vec<T>` parameter is rendered as a slice `&[T]`.
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&[String]");
}
1283
#[test]
fn test_format_param_type_vec_owned() {
    // An owned `Vec<T>` parameter stays `Vec<T>`.
    let param = make_param("items", TypeRef::Vec(Box::new(TypeRef::String)), false);
    assert_eq!(format_param_type(&param, &HashMap::new()), "Vec<String>");
}
1289
#[test]
fn test_format_param_type_named_ref_with_type_paths() {
    // A borrowed named type uses its qualified path behind a `&`.
    let mut paths = HashMap::new();
    paths.insert("Config".to_string(), "mylib::Config".to_string());
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &paths), "&mylib::Config");
}
1297
#[test]
fn test_format_param_type_named_ref_without_type_paths() {
    // Without an override, a borrowed named type is `&` plus its bare name.
    let param = make_param("cfg", TypeRef::Named("Config".to_string()), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "&Config");
}
1303
#[test]
fn test_format_param_type_primitive_ref_passes_by_value() {
    // Primitives are passed by value even when the parameter is marked as a reference.
    let param = make_param("count", TypeRef::Primitive(PrimitiveType::U32), true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "u32");
}
1310
#[test]
fn test_format_param_type_unit_ref_passes_by_value() {
    // The unit type is never borrowed, even when the parameter is marked as a reference.
    let param = make_param("nothing", TypeRef::Unit, true);
    assert_eq!(format_param_type(&param, &HashMap::new()), "()");
}
1316
#[test]
fn test_format_return_type_without_error() {
    // With no error type, the return type is emitted as-is.
    let no_paths = HashMap::new();
    let rendered = format_return_type(&TypeRef::String, None, &no_paths);
    assert_eq!(rendered, "String");
}
1326
#[test]
fn test_format_return_type_with_error() {
    // An error type wraps the return in a fully-qualified `std::result::Result`.
    let expected = "std::result::Result<String, MyError>";
    let no_paths = HashMap::new();
    assert_eq!(format_return_type(&TypeRef::String, Some("MyError"), &no_paths), expected);
}
1332
#[test]
fn test_format_return_type_unit_with_error() {
    // Unit return plus a boxed dyn error still produces a `Result<(), E>`.
    let boxed_err = "Box<dyn std::error::Error>";
    let no_paths = HashMap::new();
    let rendered = format_return_type(&TypeRef::Unit, Some(boxed_err), &no_paths);
    assert_eq!(rendered, "std::result::Result<(), Box<dyn std::error::Error>>");
}
1338
#[test]
fn test_format_return_type_named_with_type_paths_and_error() {
    // The Ok type is resolved through the override map; the error type is used verbatim.
    let paths = HashMap::from([("Output".to_string(), "mylib::Output".to_string())]);
    let output_ty = TypeRef::Named(String::from("Output"));
    let rendered = format_return_type(&output_ty, Some("mylib::MyError"), &paths);
    assert_eq!(rendered, "std::result::Result<mylib::Output, mylib::MyError>");
}
1346
#[test]
fn test_gen_bridge_wrapper_struct_contains_struct_name() {
    // Expects the wrapper struct to be named `<prefix><Trait>Bridge`.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&bridge_spec, &MockBridgeGenerator);
    let has_decl = result.contains("pub struct PyOcrBackendBridge");
    assert!(has_decl, "missing struct declaration in:\n{result}");
}
1363
#[test]
fn test_gen_bridge_wrapper_struct_contains_inner_field() {
    // Expects an `inner` field holding the foreign object type (`Py<PyAny>` with the mock generator).
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&bridge_spec, &MockBridgeGenerator);
    let has_inner = result.contains("inner: Py<PyAny>");
    assert!(has_inner, "missing inner field in:\n{result}");
}
1373
#[test]
fn test_gen_bridge_wrapper_struct_contains_cached_name() {
    // The wrapper always carries a `cached_name: String` field, even without a super-trait.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_wrapper_struct(&bridge_spec, &MockBridgeGenerator);
    let has_cache = result.contains("cached_name: String");
    assert!(has_cache, "missing cached_name field in:\n{result}");
}
1386
#[test]
fn test_gen_bridge_plugin_impl_returns_none_when_no_super_trait() {
    // No super-trait configured, so no plugin impl should be produced.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_none());
}
1399
#[test]
fn test_gen_bridge_plugin_impl_returns_some_when_super_trait_configured() {
    // Configuring a super-trait must yield a plugin impl.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let plugin_impl = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator);
    assert!(plugin_impl.is_some());
}
1408
#[test]
fn test_gen_bridge_plugin_impl_uses_qualified_super_trait_path() {
    // A bare super-trait name is expected to be qualified with the core crate path.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let header = "impl mylib::Plugin for PyOcrBackendBridge";
    assert!(result.contains(header), "missing qualified super-trait path in:\n{result}");
}
1421
#[test]
fn test_gen_bridge_plugin_impl_uses_already_qualified_super_trait_path() {
    // An already-qualified super-trait path must be used verbatim, not re-prefixed.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("other_crate::Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let header = "impl other_crate::Plugin for PyOcrBackendBridge";
    assert!(result.contains(header), "wrong super-trait path in:\n{result}");
}
1434
#[test]
fn test_gen_bridge_plugin_impl_contains_name_fn() {
    // `name()` is expected to be served from the `cached_name` field.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let name_uses_cache = result.contains("fn name(") && result.contains("cached_name");
    assert!(name_uses_cache, "missing name() using cached_name in:\n{result}");
}
1447
#[test]
fn test_gen_bridge_plugin_impl_contains_version_fn() {
    // The plugin impl must define `version()`.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let has_fn = result.contains("fn version(");
    assert!(has_fn, "missing version() in:\n{result}");
}
1457
#[test]
fn test_gen_bridge_plugin_impl_contains_initialize_fn() {
    // The plugin impl must define `initialize()`.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let has_fn = result.contains("fn initialize(");
    assert!(has_fn, "missing initialize() in:\n{result}");
}
1467
#[test]
fn test_gen_bridge_plugin_impl_contains_shutdown_fn() {
    // The plugin impl must define `shutdown()`.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_plugin_impl(&bridge_spec, &MockBridgeGenerator).unwrap();
    let has_fn = result.contains("fn shutdown(");
    assert!(has_fn, "missing shutdown() in:\n{result}");
}
1477
#[test]
fn test_gen_bridge_trait_impl_includes_impl_header() {
    // The trait impl targets the trait's full rust path on the wrapper struct.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);
    let header = "impl mylib::OcrBackend for PyOcrBackendBridge";
    assert!(result.contains(header), "missing impl header in:\n{result}");
}
1494
#[test]
fn test_gen_bridge_trait_impl_includes_method_signatures() {
    // Each trait method should get a matching `fn` in the generated impl.
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);
    assert!(result.contains("fn process("), "missing method signature in:\n{result}");
}
1513
#[test]
fn test_gen_bridge_trait_impl_includes_method_body_from_generator() {
    // A non-async method's body comes from the generator's sync-body hook (mock emits a marker comment).
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);
    let has_body = result.contains("// sync body for process");
    assert!(has_body, "missing sync method body in:\n{result}");
}
1535
#[test]
fn test_gen_bridge_trait_impl_async_method_uses_async_body() {
    // The fourth `make_method` flag marks the method async: expect an `async fn`
    // signature whose body comes from the async-body hook.
    let process_async =
        make_method("process_async", vec![], TypeRef::String, true, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process_async]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);

    let has_async_body = result.contains("// async body for process_async");
    assert!(has_async_body, "missing async method body in:\n{result}");

    let has_async_sig = result.contains("async fn process_async(");
    assert!(has_async_sig, "missing async keyword in method signature in:\n{result}");
}
1561
#[test]
fn test_gen_bridge_trait_impl_filters_trait_source_methods() {
    // Methods tagged with a different source trait must not be emitted in this impl.
    let own = make_method("own_method", vec![], TypeRef::String, false, false, None, None);
    let inherited = make_method(
        "inherited_method",
        vec![],
        TypeRef::String,
        false,
        false,
        Some("other_crate::OtherTrait"),
        None,
    );
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![own, inherited]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);

    let keeps_own = result.contains("fn own_method(");
    assert!(keeps_own, "own method should be present in:\n{result}");

    let drops_inherited = !result.contains("fn inherited_method(");
    assert!(drops_inherited, "inherited method should be filtered out in:\n{result}");
}
1591
#[test]
fn test_gen_bridge_trait_impl_method_with_params() {
    // Parameters are rendered with `format_param_type` rules: borrowed String -> &str,
    // primitives by value.
    let input = make_param("input", TypeRef::String, true);
    let count = make_param("count", TypeRef::Primitive(PrimitiveType::U32), false);
    let process =
        make_method("process", vec![input, count], TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);
    assert!(result.contains("input: &str"), "missing &str param in:\n{result}");
    assert!(result.contains("count: u32"), "missing u32 param in:\n{result}");
}
1615
#[test]
fn test_gen_bridge_trait_impl_return_type_with_error() {
    // A bare method error type is expected to be qualified with the core crate path
    // inside a `std::result::Result`.
    let process = make_method(
        "process",
        vec![],
        TypeRef::String,
        false,
        false,
        None,
        Some("MyError"),
    );
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let result = gen_bridge_trait_impl(&bridge_spec, &MockBridgeGenerator);
    let expected_ret = "-> std::result::Result<String, mylib::MyError>";
    assert!(
        result.contains(expected_ret),
        "missing std::result::Result return type in:\n{result}"
    );
}
1637
#[test]
fn test_gen_bridge_registration_fn_returns_none_without_register_fn() {
    // No `register_fn` configured, so no registration function is generated.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let registration = gen_bridge_registration_fn(&bridge_spec, &MockBridgeGenerator);
    assert!(registration.is_none());
}
1650
#[test]
fn test_gen_bridge_registration_fn_returns_some_with_register_fn() {
    // With `register_fn` configured, a registration function named after it is generated.
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    // `expect` both asserts presence and unwraps, with a message explaining the failure.
    let code = gen_bridge_registration_fn(&spec, &generator)
        .expect("registration fn should be generated when register_fn is configured");
    assert!(
        code.contains("register_ocr_backend"),
        "missing register fn name in:\n{code}"
    );
}
1665
#[test]
fn test_gen_bridge_all_includes_imports() {
    // The combined output must carry the generator's bridge imports.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    for wanted in ["pyo3::prelude::*", "pyo3::types::PyString"] {
        assert!(output.imports.contains(&wanted.to_string()));
    }
}
1680
#[test]
fn test_gen_bridge_all_includes_wrapper_struct() {
    // The combined code must include the wrapper struct declaration.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("pub struct PyOcrBackendBridge"),
        "missing struct in:\n{}",
        code
    );
}
1694
#[test]
fn test_gen_bridge_all_includes_constructor() {
    // The combined code must include a `new` constructor for the wrapper.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(code.contains("pub fn new("), "missing constructor in:\n{}", code);
}
1708
#[test]
fn test_gen_bridge_all_includes_trait_impl() {
    // The combined code must include the trait impl for the wrapper.
    let process = make_method("process", vec![], TypeRef::String, false, false, None, None);
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![process]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::OcrBackend for PyOcrBackendBridge"),
        "missing trait impl in:\n{}",
        code
    );
}
1731
#[test]
fn test_gen_bridge_all_includes_plugin_impl_when_super_trait_set() {
    // With a super-trait configured, the plugin impl appears in the combined code.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(Some("Plugin"), None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("impl mylib::Plugin for PyOcrBackendBridge"),
        "missing plugin impl in:\n{}",
        code
    );
}
1745
#[test]
fn test_gen_bridge_all_no_plugin_impl_when_no_super_trait() {
    // Without a super-trait, no plugin impl (and therefore no `name()` method) may be emitted.
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    // The wrapper struct always declares a `cached_name` field, so searching for
    // "cached_name" cannot distinguish a plugin impl; check the impl's own markers instead.
    assert!(
        !output.code.contains("fn name("),
        "unexpected plugin impl present without super_trait:\n{}",
        output.code
    );
    assert!(
        !output.code.contains("Plugin for PyOcrBackendBridge"),
        "unexpected plugin impl present without super_trait:\n{}",
        output.code
    );
}
1758
#[test]
fn test_gen_bridge_all_includes_registration_fn_when_configured() {
    // A configured `register_fn` must show up in the combined code.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, Some("register_ocr_backend"));
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        code.contains("register_ocr_backend"),
        "missing registration fn in:\n{}",
        code
    );
}
1772
#[test]
fn test_gen_bridge_all_no_registration_fn_when_absent() {
    // Without a `register_fn`, nothing named after it may be emitted.
    let ocr_trait = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let bridge_cfg = make_trait_bridge_config(None, None);
    let bridge_spec = make_spec(&ocr_trait, &bridge_cfg, "Py", HashMap::new());
    let output = gen_bridge_all(&bridge_spec, &MockBridgeGenerator);
    let code = &output.code;
    assert!(
        !code.contains("register_ocr_backend"),
        "unexpected registration fn present:\n{}",
        code
    );
}
1786
#[test]
fn test_gen_bridge_all_ordering_struct_before_trait_impl() {
    // The wrapper struct must be declared before the trait impl that targets it.
    let trait_def = make_type_def("OcrBackend", "mylib::OcrBackend", vec![]);
    let config = make_trait_bridge_config(None, None);
    let spec = make_spec(&trait_def, &config, "Py", HashMap::new());
    let generator = MockBridgeGenerator;
    let output = gen_bridge_all(&spec, &generator);
    // `expect` names the missing fragment instead of a bare `unwrap` panic on `None`.
    let struct_pos = output
        .code
        .find("pub struct PyOcrBackendBridge")
        .expect("wrapper struct declaration missing from generated code");
    let impl_pos = output
        .code
        .find("impl mylib::OcrBackend for PyOcrBackendBridge")
        .expect("trait impl header missing from generated code");
    assert!(struct_pos < impl_pos, "struct should appear before trait impl");
}
1801
1802 fn make_bridge(
1807 type_alias: Option<&str>,
1808 param_name: Option<&str>,
1809 bind_via: BridgeBinding,
1810 options_type: Option<&str>,
1811 options_field: Option<&str>,
1812 ) -> TraitBridgeConfig {
1813 TraitBridgeConfig {
1814 trait_name: "HtmlVisitor".to_string(),
1815 super_trait: None,
1816 registry_getter: None,
1817 register_fn: None,
1818 unregister_fn: None,
1819 clear_fn: None,
1820 type_alias: type_alias.map(str::to_string),
1821 param_name: param_name.map(str::to_string),
1822 register_extra_args: None,
1823 exclude_languages: vec![],
1824 bind_via,
1825 options_type: options_type.map(str::to_string),
1826 options_field: options_field.map(str::to_string),
1827 }
1828 }
1829
#[test]
fn find_bridge_param_returns_first_param_match_in_function_param_mode() {
    // In function-param mode the `visitor` parameter (index 1) is the bridge match.
    let html = make_param("html", TypeRef::String, true);
    let visitor = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html, visitor]);
    let visitor_bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    );
    let bridges = vec![visitor_bridge];
    let found = find_bridge_param(&func, &bridges).expect("bridge match");
    assert_eq!(found.0, 1);
}
1849
#[test]
fn find_bridge_param_skips_options_field_bridges() {
    // Options-field bridges are resolved by `find_bridge_field`, never by `find_bridge_param`.
    let html = make_param("html", TypeRef::String, true);
    let visitor = make_param("visitor", TypeRef::Named("VisitorHandle".to_string()), false);
    let func = make_func("convert", vec![html, visitor]);
    let options_bridge = make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        Some("visitor"),
    );
    let bridges = vec![options_bridge];
    let matched = find_bridge_param(&func, &bridges);
    assert!(
        matched.is_none(),
        "bridges configured with bind_via=options_field must not be returned by find_bridge_param"
    );
}
1871
#[test]
fn find_bridge_field_detects_field_via_alias() {
    // `options_field` is left unset, so the bridge field has to be located through the
    // `VisitorHandle` type alias on a field of the (optional) options parameter.
    let visitor_ty = TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string())));
    let option_fields = vec![
        make_field("debug", TypeRef::Primitive(PrimitiveType::Bool)),
        make_field("visitor", visitor_ty),
    ];
    let known_types = [TypeDef {
        name: "ConversionOptions".to_string(),
        rust_path: "mylib::ConversionOptions".to_string(),
        original_rust_path: String::new(),
        doc: String::new(),
        cfg: None,
        fields: option_fields,
        methods: vec![],
        super_traits: vec![],
        is_opaque: false,
        is_clone: true,
        is_copy: false,
        is_trait: false,
        has_default: true,
        has_stripped_cfg_fields: false,
        is_return_type: false,
        serde_rename_all: None,
        has_serde: false,
    }];

    let options_ty = TypeRef::Optional(Box::new(TypeRef::Named("ConversionOptions".to_string())));
    let params = vec![
        make_param("html", TypeRef::String, true),
        make_param("options", options_ty, false),
    ];
    let func = make_func("convert", params);

    let bridges = vec![make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::OptionsField,
        Some("ConversionOptions"),
        None,
    )];

    let m = find_bridge_field(&func, known_types.as_slice(), &bridges).expect("bridge field match");
    assert_eq!(m.param_index, 1);
    assert_eq!(m.param_name, "options");
    assert_eq!(m.options_type, "ConversionOptions");
    assert!(m.param_is_optional);
    assert_eq!(m.field_name, "visitor");
}
1924
#[test]
fn find_bridge_field_returns_none_for_function_param_bridge() {
    // A bridge bound as a function parameter must be ignored by the options-field lookup,
    // even when an options type carries a matching field.
    let visitor_ty = TypeRef::Optional(Box::new(TypeRef::Named("VisitorHandle".to_string())));
    let known_types = [TypeDef {
        name: "ConversionOptions".to_string(),
        rust_path: "mylib::ConversionOptions".to_string(),
        original_rust_path: String::new(),
        doc: String::new(),
        cfg: None,
        fields: vec![make_field("visitor", visitor_ty)],
        methods: vec![],
        super_traits: vec![],
        is_opaque: false,
        is_clone: true,
        is_copy: false,
        is_trait: false,
        has_default: true,
        has_stripped_cfg_fields: false,
        is_return_type: false,
        serde_rename_all: None,
        has_serde: false,
    }];

    let options = make_param("options", TypeRef::Named("ConversionOptions".to_string()), false);
    let func = make_func("convert", vec![options]);

    let bridges = vec![make_bridge(
        Some("VisitorHandle"),
        Some("visitor"),
        BridgeBinding::FunctionParam,
        None,
        None,
    )];

    let matched = find_bridge_field(&func, known_types.as_slice(), &bridges);
    assert!(matched.is_none());
}
1966}