1use proc_macro::TokenStream;
46use proc_macro2::TokenTree;
47use quote::{format_ident, quote};
48use std::collections::hash_map::DefaultHasher;
49use std::fmt::Write as FmtWrite;
50use std::hash::{Hash, Hasher};
51use std::str::FromStr;
52
/// Wire encoding of one scalar value exchanged with the generated Java program.
///
/// A primitive and its boxed counterpart (e.g. `int` / `Integer`) map onto the
/// same encoding; `Str` covers `String`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ScalarEnc {
    Byte,
    Short,
    Int,
    Long,
    Float,
    Double,
    Boolean,
    Char,
    Str,
}

impl ScalarEnc {
    /// Name of the `java.io.DataOutputStream` method that writes this scalar.
    ///
    /// Returns `None` for `Str`: strings have no single writer method, and the
    /// code generators serialize them as UTF-8 bytes themselves.
    fn dos_write_method(self) -> Option<&'static str> {
        match self {
            Self::Byte => Some("writeByte"),
            Self::Short => Some("writeShort"),
            Self::Int => Some("writeInt"),
            Self::Long => Some("writeLong"),
            Self::Float => Some("writeFloat"),
            Self::Double => Some("writeDouble"),
            Self::Boolean => Some("writeBoolean"),
            Self::Char => Some("writeChar"),
            Self::Str => None,
        }
    }
}
85
/// A Java primitive type: `byte`, `short`, `int`, `long`, `float`, `double`,
/// `boolean` or `char`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PrimitiveType {
    Byte,
    Short,
    Int,
    Long,
    Float,
    Double,
    Boolean,
    Char,
}
101
102impl PrimitiveType {
103 fn from_name(s: &str) -> Option<Self> {
104 match s {
105 "byte" => Some(Self::Byte),
106 "short" => Some(Self::Short),
107 "int" => Some(Self::Int),
108 "long" => Some(Self::Long),
109 "float" => Some(Self::Float),
110 "double" => Some(Self::Double),
111 "boolean" => Some(Self::Boolean),
112 "char" => Some(Self::Char),
113 _ => None,
114 }
115 }
116
117 fn java_name(self) -> &'static str {
118 match self {
119 Self::Byte => "byte",
120 Self::Short => "short",
121 Self::Int => "int",
122 Self::Long => "long",
123 Self::Float => "float",
124 Self::Double => "double",
125 Self::Boolean => "boolean",
126 Self::Char => "char",
127 }
128 }
129
130 fn enc(self) -> ScalarEnc {
131 match self {
132 Self::Byte => ScalarEnc::Byte,
133 Self::Short => ScalarEnc::Short,
134 Self::Int => ScalarEnc::Int,
135 Self::Long => ScalarEnc::Long,
136 Self::Float => ScalarEnc::Float,
137 Self::Double => ScalarEnc::Double,
138 Self::Boolean => ScalarEnc::Boolean,
139 Self::Char => ScalarEnc::Char,
140 }
141 }
142
143 fn dos_write_method(self) -> &'static str {
144 self.enc().dos_write_method().unwrap()
145 }
146
147 fn rust_type_ts(self) -> proc_macro2::TokenStream {
148 match self {
149 Self::Byte => quote! { i8 },
150 Self::Short => quote! { i16 },
151 Self::Int => quote! { i32 },
152 Self::Long => quote! { i64 },
153 Self::Float => quote! { f32 },
154 Self::Double => quote! { f64 },
155 Self::Boolean => quote! { bool },
156 Self::Char => quote! { char },
157 }
158 }
159
160 fn java_dis_read(self, param_name: &str) -> String {
161 match self {
162 Self::Byte => format!("byte {param_name} = _dis.readByte();"),
163 Self::Short => format!("short {param_name} = _dis.readShort();"),
164 Self::Int => format!("int {param_name} = _dis.readInt();"),
165 Self::Long => format!("long {param_name} = _dis.readLong();"),
166 Self::Float => format!("float {param_name} = _dis.readFloat();"),
167 Self::Double => format!("double {param_name} = _dis.readDouble();"),
168 Self::Boolean => format!("boolean {param_name} = _dis.readBoolean();"),
169 Self::Char => format!("char {param_name} = _dis.readChar();"),
170 }
171 }
172}
173
/// A Java reference type that can be used as a generic argument: the boxed
/// wrappers of the primitives, plus `String`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum BoxedType {
    Byte,
    Short,
    Integer,
    Long,
    Float,
    Double,
    Boolean,
    Character,
    String,
}
190
191impl BoxedType {
192 fn from_name(s: &str) -> Option<Self> {
194 match s {
195 "Byte" | "byte" => Some(Self::Byte),
196 "Short" | "short" => Some(Self::Short),
197 "Integer" | "int" => Some(Self::Integer),
198 "Long" | "long" => Some(Self::Long),
199 "Float" | "float" => Some(Self::Float),
200 "Double" | "double" => Some(Self::Double),
201 "Boolean" | "boolean" => Some(Self::Boolean),
202 "Character" | "char" => Some(Self::Character),
203 "String" => Some(Self::String),
204 _ => None,
205 }
206 }
207
208 fn java_name(self) -> &'static str {
209 match self {
210 Self::Byte => "Byte",
211 Self::Short => "Short",
212 Self::Integer => "Integer",
213 Self::Long => "Long",
214 Self::Float => "Float",
215 Self::Double => "Double",
216 Self::Boolean => "Boolean",
217 Self::Character => "Character",
218 Self::String => "String",
219 }
220 }
221
222 fn enc(self) -> ScalarEnc {
223 match self {
224 Self::Byte => ScalarEnc::Byte,
225 Self::Short => ScalarEnc::Short,
226 Self::Integer => ScalarEnc::Int,
227 Self::Long => ScalarEnc::Long,
228 Self::Float => ScalarEnc::Float,
229 Self::Double => ScalarEnc::Double,
230 Self::Boolean => ScalarEnc::Boolean,
231 Self::Character => ScalarEnc::Char,
232 Self::String => ScalarEnc::Str,
233 }
234 }
235
236 fn dos_write_method(self) -> Option<&'static str> {
237 self.enc().dos_write_method()
238 }
239
240 fn rust_type_ts(self) -> proc_macro2::TokenStream {
241 match self {
242 Self::Byte => quote! { i8 },
243 Self::Short => quote! { i16 },
244 Self::Integer => quote! { i32 },
245 Self::Long => quote! { i64 },
246 Self::Float => quote! { f32 },
247 Self::Double => quote! { f64 },
248 Self::Boolean => quote! { bool },
249 Self::Character => quote! { char },
250 Self::String => quote! { ::std::string::String },
251 }
252 }
253
254 fn java_dis_read(self, param_name: &str) -> String {
256 match self {
257 Self::Byte => format!("byte {param_name} = _dis.readByte();"),
258 Self::Short => format!("short {param_name} = _dis.readShort();"),
259 Self::Integer => format!("int {param_name} = _dis.readInt();"),
260 Self::Long => format!("long {param_name} = _dis.readLong();"),
261 Self::Float => format!("float {param_name} = _dis.readFloat();"),
262 Self::Double => format!("double {param_name} = _dis.readDouble();"),
263 Self::Boolean => format!("boolean {param_name} = _dis.readBoolean();"),
264 Self::Character => format!("char {param_name} = _dis.readChar();"),
265 Self::String => format!(
266 "int _len_{param_name} = _dis.readInt();\n\
267 \t\tbyte[] _b_{param_name} = new byte[_len_{param_name}];\n\
268 \t\t_dis.readFully(_b_{param_name});\n\
269 \t\tString {param_name} = new String(_b_{param_name}, java.nio.charset.StandardCharsets.UTF_8);"
270 ),
271 }
272 }
273}
274
/// A Java type that the code generators know how to move between Rust and Java.
///
/// On the Rust side, `Array` and `List` both map to `Vec<_>` and `Optional`
/// maps to `Option<_>` (see `rust_return_type_ts`).
#[derive(Clone, PartialEq)]
enum JavaType {
    /// A primitive scalar such as `int` or `char`.
    Primitive(PrimitiveType),
    /// A boxed scalar (`Integer`, `Character`, ...) or `String`.
    Boxed(BoxedType),
    /// `T[]`.
    Array(Box<JavaType>),
    /// `java.util.List<T>`.
    List(Box<JavaType>),
    /// `java.util.Optional<T>`.
    Optional(Box<JavaType>),
}
291
292impl JavaType {
293 fn java_type_name(&self) -> String {
298 match self {
299 Self::Primitive(p) => p.java_name().to_string(),
300 Self::Boxed(b) => b.java_name().to_string(),
301 Self::Array(inner) => format!("{}[]", inner.java_type_name()),
302 Self::List(inner) => format!("java.util.List<{}>", inner.java_type_name()),
303 Self::Optional(inner) => format!("java.util.Optional<{}>", inner.java_type_name()),
304 }
305 }
306
307 fn rust_return_type_ts(&self) -> proc_macro2::TokenStream {
309 match self {
310 Self::Primitive(p) => p.rust_type_ts(),
311 Self::Boxed(b) => b.rust_type_ts(),
312 Self::Array(inner) | Self::List(inner) => {
313 let inner_ts = inner.rust_return_type_ts();
314 quote! { ::std::vec::Vec<#inner_ts> }
315 }
316 Self::Optional(inner) => {
317 let inner_ts = inner.rust_return_type_ts();
318 quote! { ::std::option::Option<#inner_ts> }
319 }
320 }
321 }
322
323 fn rust_param_type_ts(&self) -> proc_macro2::TokenStream {
326 match self {
327 Self::Primitive(p) => p.rust_type_ts(),
328 Self::Boxed(BoxedType::String) => quote! { &str },
329 Self::Boxed(b) => b.rust_type_ts(),
330 Self::Array(inner) | Self::List(inner) => {
331 let inner_ts = inner.rust_param_type_ts();
332 quote! { &[#inner_ts] }
333 }
334 Self::Optional(inner) => {
335 let inner_ts = inner.rust_param_type_ts();
336 quote! { ::std::option::Option<#inner_ts> }
337 }
338 }
339 }
340
341 fn rust_ser_ts(
345 &self,
346 param_ident: &proc_macro2::TokenStream,
347 depth: usize,
348 ) -> proc_macro2::TokenStream {
349 match self {
350 Self::Primitive(p) => scalar_enc_ser_ts(p.enc(), param_ident),
351 Self::Boxed(b) => scalar_enc_ser_ts(b.enc(), param_ident),
352 Self::Array(inner) | Self::List(inner) => {
353 let item_var = format_ident!("_item{}", depth);
354 let item_expr = quote! { #item_var };
355 let inner_ser = inner.rust_ser_ts(&item_expr, depth + 1);
356 quote! {
357 {
358 _stdin_bytes.extend_from_slice(&(#param_ident.len() as i32).to_be_bytes());
359 for &#item_var in #param_ident {
360 #inner_ser
361 }
362 }
363 }
364 }
365 Self::Optional(inner) => {
366 let inner_var = format_ident!("_inner{}", depth);
367 let inner_expr = quote! { #inner_var };
368 let inner_ser = inner.rust_ser_ts(&inner_expr, depth + 1);
369 quote! {
370 match #param_ident {
371 ::std::option::Option::None => _stdin_bytes.push(0u8),
372 ::std::option::Option::Some(#inner_var) => {
373 _stdin_bytes.push(1u8);
374 #inner_ser
375 }
376 }
377 }
378 }
379 }
380 }
381
382 fn java_dis_read(&self, param_name: &str, depth: usize) -> String {
386 match self {
387 Self::Primitive(p) => p.java_dis_read(param_name),
388 Self::Boxed(b) => b.java_dis_read(param_name),
389 Self::Array(inner) => {
390 let count_var = format!("_count_{param_name}_{depth}");
391 let i_var = format!("_i_{param_name}_{depth}");
392 let elem_var = format!("_elem_{param_name}_{depth}");
393 let inner_java_type = inner.java_type_name();
394 let new_expr = inner.java_new_outer_array(&count_var);
395 let inner_read = inner.java_dis_read_for_elem(&elem_var, depth + 1);
396 format!(
397 "int {count_var} = _dis.readInt();\n\
398 \t\t{inner_java_type}[] {param_name} = {new_expr};\n\
399 \t\tfor (int {i_var} = 0; {i_var} < {count_var}; {i_var}++) {{\n\
400 \t\t\t{inner_read}\n\
401 \t\t\t{param_name}[{i_var}] = {elem_var};\n\
402 \t\t}}"
403 )
404 }
405 Self::List(inner) => {
406 let count_var = format!("_count_{param_name}_{depth}");
407 let i_var = format!("_i_{param_name}_{depth}");
408 let elem_var = format!("_elem_{param_name}_{depth}");
409 let inner_java_type = inner.java_type_name();
410 let inner_read = inner.java_dis_read_for_elem(&elem_var, depth + 1);
411 format!(
412 "int {count_var} = _dis.readInt();\n\
413 \t\tjava.util.List<{inner_java_type}> {param_name} = new java.util.ArrayList<>();\n\
414 \t\tfor (int {i_var} = 0; {i_var} < {count_var}; {i_var}++) {{\n\
415 \t\t\t{inner_read}\n\
416 \t\t\t{param_name}.add(({inner_java_type}) {elem_var});\n\
417 \t\t}}"
418 )
419 }
420 Self::Optional(inner) => {
421 let tag_var = format!("_tag_{param_name}_{depth}");
422 let inner_var = format!("_inner_{param_name}_{depth}");
423 let inner_java = inner.java_type_name();
424 let inner_read = inner.java_dis_read_for_elem(&inner_var, depth + 1);
425 format!(
426 "int {tag_var} = _dis.readUnsignedByte();\n\
427 \t\tjava.util.Optional<{inner_java}> {param_name};\n\
428 \t\tif ({tag_var} != 0) {{\n\
429 \t\t\t{inner_read}\n\
430 \t\t\t{param_name} = java.util.Optional.of({inner_var});\n\
431 \t\t}} else {{\n\
432 \t\t\t{param_name} = java.util.Optional.empty();\n\
433 \t\t}}"
434 )
435 }
436 }
437 }
438
439 fn java_new_outer_array(&self, count_var: &str) -> String {
445 let mut ty = self;
446 let mut extra_dims = 0usize;
447 while let JavaType::Array(inner) = ty {
448 extra_dims += 1;
449 ty = inner;
450 }
451 let base_name = ty.java_type_name(); let trailing = "[]".repeat(extra_dims);
453 format!("new {base_name}[{count_var}]{trailing}")
454 }
455
456 fn java_dis_read_for_elem(&self, elem_name: &str, depth: usize) -> String {
458 match self {
459 Self::Primitive(p) => p.java_dis_read(elem_name),
460 Self::Boxed(b) => b.java_dis_read(elem_name),
461 _ => self.java_dis_read(elem_name, depth),
462 }
463 }
464
465 #[allow(clippy::too_many_lines)]
470 fn java_main(&self, params: &[(JavaType, String)]) -> String {
471 let param_reads = if params.is_empty() {
473 String::new()
474 } else {
475 let mut s = String::from(
476 "\t\tjava.io.DataInputStream _dis = new java.io.DataInputStream(System.in);\n",
477 );
478 for (ty, name) in params {
479 writeln!(s, "\t\t{}", ty.java_dis_read(name, 0)).unwrap();
480 }
481 s
482 };
483
484 let run_args: String = params
486 .iter()
487 .map(|(_, name)| name.as_str())
488 .collect::<Vec<_>>()
489 .join(", ");
490
491 match self {
492 Self::Primitive(p) => {
493 let method = p.dos_write_method();
494 let serialize = format!(
495 "java.io.DataOutputStream _dos = \
496 new java.io.DataOutputStream(System.out);\n\
497 \t\t_dos.{method}(run({run_args}));\n\
498 \t\t_dos.flush();"
499 );
500 format!(
501 "\tpublic static void main(String[] args) throws Exception {{\n\
502 {param_reads}\t\t{serialize}\n\
503 \t}}"
504 )
505 }
506 Self::Boxed(BoxedType::String) => {
507 let serialize = format!(
508 "byte[] _b = run({run_args}).getBytes(java.nio.charset.StandardCharsets.UTF_8);\n\
509 \t\tSystem.out.write(_b);\n\
510 \t\tSystem.out.flush();"
511 );
512 format!(
513 "\tpublic static void main(String[] args) throws Exception {{\n\
514 {param_reads}\t\t{serialize}\n\
515 \t}}"
516 )
517 }
518 Self::Boxed(b) => {
519 let method = b.dos_write_method().unwrap();
520 let serialize = format!(
521 "java.io.DataOutputStream _dos = \
522 new java.io.DataOutputStream(System.out);\n\
523 \t\t_dos.{method}(run({run_args}));\n\
524 \t\t_dos.flush();"
525 );
526 format!(
527 "\tpublic static void main(String[] args) throws Exception {{\n\
528 {param_reads}\t\t{serialize}\n\
529 \t}}"
530 )
531 }
532 Self::Array(inner) => {
533 let elem_java_type = inner.java_type_name();
534 let ser_body = java_ser_element(inner, "_e0", "_dos", 1);
535 format!(
536 "\tpublic static void main(String[] args) throws Exception {{\n\
537 {param_reads}\t\t{elem_java_type}[] _arr = run({run_args});\n\
538 \t\tjava.io.DataOutputStream _dos = new java.io.DataOutputStream(System.out);\n\
539 \t\t_dos.writeInt(_arr.length);\n\
540 \t\tfor ({elem_java_type} _e0 : _arr) {{\n\
541 \t\t\t{ser_body}\n\
542 \t\t}}\n\
543 \t\t_dos.flush();\n\
544 \t}}"
545 )
546 }
547 Self::List(inner) => {
548 let elem_java_type = inner.java_type_name();
549 let ser_body = java_ser_element(inner, "_e0", "_dos", 1);
550 format!(
551 "\tpublic static void main(String[] args) throws Exception {{\n\
552 {param_reads}\t\tjava.util.List<{elem_java_type}> _arr = run({run_args});\n\
553 \t\tjava.io.DataOutputStream _dos = new java.io.DataOutputStream(System.out);\n\
554 \t\t_dos.writeInt(_arr.size());\n\
555 \t\tfor ({elem_java_type} _e0 : _arr) {{\n\
556 \t\t\t{ser_body}\n\
557 \t\t}}\n\
558 \t\t_dos.flush();\n\
559 \t}}"
560 )
561 }
562 Self::Optional(inner) => {
563 let inner_java_type = inner.java_type_name();
564 let present_body = java_ser_element(inner, "_opt.get()", "_dos", 1);
565 format!(
566 "\tpublic static void main(String[] args) throws Exception {{\n\
567 {param_reads}\t\tjava.util.Optional<{inner_java_type}> _opt = run({run_args});\n\
568 \t\tjava.io.DataOutputStream _dos = new java.io.DataOutputStream(System.out);\n\
569 \t\tif (_opt.isPresent()) {{\n\
570 \t\t\t_dos.writeByte(1);\n\
571 \t\t\t{present_body}\n\
572 \t\t}} else {{\n\
573 \t\t\t_dos.writeByte(0);\n\
574 \t\t}}\n\
575 \t\t_dos.flush();\n\
576 \t}}"
577 )
578 }
579 }
580 }
581
582 fn rust_deser(&self) -> proc_macro2::TokenStream {
586 match self {
587 Self::Primitive(p) => scalar_enc_top_deser(p.enc()),
588 Self::Boxed(BoxedType::String) => {
589 quote! { ::std::string::String::from_utf8(_raw)? }
591 }
592 Self::Boxed(b) => scalar_enc_top_deser(b.enc()),
593 _ => {
594 let rust_type = self.rust_return_type_ts();
596 let read_expr = rust_read_element(self, 0);
597 quote! {
598 {
599 let mut _cur = 0usize;
600 let _result: #rust_type = #read_expr;
601 _result
602 }
603 }
604 }
605 }
606 }
607
608 fn ct_java_tokens(&self, bytes: Vec<u8>) -> Result<proc_macro2::TokenStream, String> {
611 match self {
612 Self::Primitive(p) => {
613 let (lit, _) = scalar_enc_ct_lit(p.enc(), &bytes)?;
614 proc_macro2::TokenStream::from_str(&lit)
615 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
616 }
617 Self::Boxed(BoxedType::String) => {
618 let s = String::from_utf8(bytes)
620 .map_err(|_| "ct_java: Java String is not valid UTF-8".to_string())?;
621 let lit = format!("{s:?}");
622 proc_macro2::TokenStream::from_str(&lit)
623 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
624 }
625 Self::Boxed(b) => {
626 let (lit, _) = scalar_enc_ct_lit(b.enc(), &bytes)?;
627 proc_macro2::TokenStream::from_str(&lit)
628 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
629 }
630 _ => {
631 let mut cur = 0usize;
632 let ts = ct_java_tokens_recursive(self, &bytes, &mut cur)?;
633 Ok(ts)
634 }
635 }
636 }
637}
638
639fn scalar_enc_ser_ts(
643 enc: ScalarEnc,
644 param_ident: &proc_macro2::TokenStream,
645) -> proc_macro2::TokenStream {
646 match enc {
647 ScalarEnc::Byte => quote! {
648 _stdin_bytes.extend_from_slice(&(#param_ident as i8).to_be_bytes());
649 },
650 ScalarEnc::Short => quote! {
651 _stdin_bytes.extend_from_slice(&(#param_ident as i16).to_be_bytes());
652 },
653 ScalarEnc::Int => quote! {
654 _stdin_bytes.extend_from_slice(&(#param_ident as i32).to_be_bytes());
655 },
656 ScalarEnc::Long => quote! {
657 _stdin_bytes.extend_from_slice(&(#param_ident as i64).to_be_bytes());
658 },
659 ScalarEnc::Float => quote! {
660 _stdin_bytes.extend_from_slice(&(#param_ident as f32).to_bits().to_be_bytes());
661 },
662 ScalarEnc::Double => quote! {
663 _stdin_bytes.extend_from_slice(&(#param_ident as f64).to_bits().to_be_bytes());
664 },
665 ScalarEnc::Boolean => quote! {
666 _stdin_bytes.push(#param_ident as u8);
667 },
668 ScalarEnc::Char => quote! {
669 {
670 let _c = #param_ident as u32;
671 assert!(_c <= 0xFFFF, "inline_java: char value exceeds u16 range");
672 _stdin_bytes.extend_from_slice(&(_c as u16).to_be_bytes());
673 }
674 },
675 ScalarEnc::Str => quote! {
676 {
677 let _b = #param_ident.as_bytes();
678 let _len = _b.len() as i32;
679 _stdin_bytes.extend_from_slice(&_len.to_be_bytes());
680 _stdin_bytes.extend_from_slice(_b);
681 }
682 },
683 }
684}
685
686fn scalar_enc_top_deser(enc: ScalarEnc) -> proc_macro2::TokenStream {
688 match enc {
689 ScalarEnc::Byte => quote! { i8::from_be_bytes([_raw[0]]) },
690 ScalarEnc::Short => quote! { i16::from_be_bytes([_raw[0], _raw[1]]) },
691 ScalarEnc::Int => {
692 quote! { i32::from_be_bytes([_raw[0], _raw[1], _raw[2], _raw[3]]) }
693 }
694 ScalarEnc::Long => {
695 quote! {
696 i64::from_be_bytes([
697 _raw[0], _raw[1], _raw[2], _raw[3],
698 _raw[4], _raw[5], _raw[6], _raw[7],
699 ])
700 }
701 }
702 ScalarEnc::Float => {
703 quote! { f32::from_bits(u32::from_be_bytes([_raw[0], _raw[1], _raw[2], _raw[3]])) }
704 }
705 ScalarEnc::Double => {
706 quote! {
707 f64::from_bits(u64::from_be_bytes([
708 _raw[0], _raw[1], _raw[2], _raw[3],
709 _raw[4], _raw[5], _raw[6], _raw[7],
710 ]))
711 }
712 }
713 ScalarEnc::Boolean => quote! { _raw[0] != 0 },
714 ScalarEnc::Char => {
715 quote! {
716 ::std::char::from_u32(u16::from_be_bytes([_raw[0], _raw[1]]) as u32)
717 .ok_or(::inline_java::JavaError::InvalidChar)?
718 }
719 }
720 ScalarEnc::Str => {
721 quote! { ::std::string::String::from_utf8(_raw)? }
723 }
724 }
725}
726
727fn scalar_enc_ct_lit(enc: ScalarEnc, bytes: &[u8]) -> Result<(String, usize), String> {
732 match enc {
733 ScalarEnc::Byte => {
734 if bytes.is_empty() {
735 return Err("ct_java: truncated byte element".to_string());
736 }
737 Ok((format!("{}", i8::from_be_bytes([bytes[0]])), 1))
738 }
739 ScalarEnc::Short => {
740 if bytes.len() < 2 {
741 return Err("ct_java: truncated short element".to_string());
742 }
743 Ok((format!("{}", i16::from_be_bytes([bytes[0], bytes[1]])), 2))
744 }
745 ScalarEnc::Int => {
746 let arr: [u8; 4] = bytes[..4]
747 .try_into()
748 .map_err(|_| "ct_java: truncated int element")?;
749 Ok((format!("{}", i32::from_be_bytes(arr)), 4))
750 }
751 ScalarEnc::Long => {
752 let arr: [u8; 8] = bytes[..8]
753 .try_into()
754 .map_err(|_| "ct_java: truncated long element")?;
755 Ok((format!("{}", i64::from_be_bytes(arr)), 8))
756 }
757 ScalarEnc::Float => {
758 let arr: [u8; 4] = bytes[..4]
759 .try_into()
760 .map_err(|_| "ct_java: truncated float element")?;
761 let bits = u32::from_be_bytes(arr);
762 Ok((format!("f32::from_bits(0x{bits:08x}_u32)"), 4))
763 }
764 ScalarEnc::Double => {
765 let arr: [u8; 8] = bytes[..8]
766 .try_into()
767 .map_err(|_| "ct_java: truncated double element")?;
768 let bits = u64::from_be_bytes(arr);
769 Ok((format!("f64::from_bits(0x{bits:016x}_u64)"), 8))
770 }
771 ScalarEnc::Boolean => {
772 if bytes.is_empty() {
773 return Err("ct_java: truncated boolean element".to_string());
774 }
775 Ok((
776 if bytes[0] != 0 {
777 "true".to_string()
778 } else {
779 "false".to_string()
780 },
781 1,
782 ))
783 }
784 ScalarEnc::Char => {
785 if bytes.len() < 2 {
786 return Err("ct_java: truncated char element".to_string());
787 }
788 let code_unit = u16::from_be_bytes([bytes[0], bytes[1]]);
789 let c = char::from_u32(u32::from(code_unit))
790 .ok_or("ct_java: Java char is not a valid Unicode scalar value")?;
791 Ok((format!("{c:?}"), 2))
792 }
793 ScalarEnc::Str => {
794 if bytes.len() < 4 {
795 return Err("ct_java: truncated String length prefix".to_string());
796 }
797 #[allow(clippy::cast_sign_loss)]
798 let len = i32::from_be_bytes(bytes[..4].try_into().unwrap()) as usize;
799 if bytes.len() < 4 + len {
800 return Err(format!(
801 "ct_java: truncated String element (expected {len} bytes)"
802 ));
803 }
804 let s = String::from_utf8(bytes[4..4 + len].to_vec())
805 .map_err(|_| "ct_java: String element is not valid UTF-8".to_string())?;
806 Ok((format!("{s:?}"), 4 + len))
807 }
808 }
809}
810
811fn ct_java_tokens_recursive(
816 ty: &JavaType,
817 bytes: &[u8],
818 cur: &mut usize,
819) -> Result<proc_macro2::TokenStream, String> {
820 match ty {
821 JavaType::Primitive(p) => {
822 let (lit, consumed) = scalar_enc_ct_lit(p.enc(), &bytes[*cur..])?;
823 *cur += consumed;
824 proc_macro2::TokenStream::from_str(&lit)
825 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
826 }
827 JavaType::Boxed(b) => {
828 let (lit, consumed) = scalar_enc_ct_lit(b.enc(), &bytes[*cur..])?;
829 *cur += consumed;
830 proc_macro2::TokenStream::from_str(&lit)
831 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
832 }
833 JavaType::Array(inner) | JavaType::List(inner) => {
834 if bytes[*cur..].len() < 4 {
835 return Err("ct_java: array/list output too short (missing length)".to_string());
836 }
837 #[allow(clippy::cast_sign_loss)]
838 let n = i32::from_be_bytes(bytes[*cur..*cur + 4].try_into().unwrap()) as usize;
839 *cur += 4;
840 let mut lits: Vec<proc_macro2::TokenStream> = Vec::with_capacity(n);
841 for _ in 0..n {
842 lits.push(ct_java_tokens_recursive(inner, bytes, cur)?);
843 }
844 let array_ts = quote! { [#(#lits),*] };
845 Ok(array_ts)
846 }
847 JavaType::Optional(inner) => {
848 if bytes[*cur..].is_empty() {
849 return Err("ct_java: optional output is empty".to_string());
850 }
851 let tag = bytes[*cur];
852 *cur += 1;
853 if tag == 0 {
854 proc_macro2::TokenStream::from_str("::std::option::Option::None")
855 .map_err(|e| format!("ct_java: produced invalid Rust token: {e}"))
856 } else {
857 let inner_ts = ct_java_tokens_recursive(inner, bytes, cur)?;
858 let result = quote! { ::std::option::Option::Some(#inner_ts) };
859 Ok(result)
860 }
861 }
862 }
863}
864
865fn java_ser_element(ty: &JavaType, var: &str, dos: &str, depth: usize) -> String {
870 match ty {
871 JavaType::Primitive(p) => {
872 let method = p.dos_write_method();
873 format!("{dos}.{method}({var});")
874 }
875 JavaType::Boxed(BoxedType::String) => {
876 format!(
877 "{{ byte[] _b{depth} = ({var}).getBytes(java.nio.charset.StandardCharsets.UTF_8);\n\
878 \t\t\t{dos}.writeInt(_b{depth}.length);\n\
879 \t\t\t{dos}.write(_b{depth}, 0, _b{depth}.length); }}"
880 )
881 }
882 JavaType::Boxed(b) => {
883 let method = b.dos_write_method().unwrap();
884 format!("{dos}.{method}({var});")
885 }
886 JavaType::Array(inner) => {
887 let elem_java_type = inner.java_type_name();
888 let elem_var = format!("_e{depth}");
889 let inner_ser = java_ser_element(inner, &elem_var, dos, depth + 1);
890 format!(
891 "{dos}.writeInt(({var}).length);\n\
892 \t\t\tfor ({elem_java_type} {elem_var} : ({var})) {{\n\
893 \t\t\t\t{inner_ser}\n\
894 \t\t\t}}"
895 )
896 }
897 JavaType::List(inner) => {
898 let elem_java_type = inner.java_type_name();
899 let elem_var = format!("_e{depth}");
900 let inner_ser = java_ser_element(inner, &elem_var, dos, depth + 1);
901 format!(
902 "{dos}.writeInt(({var}).size());\n\
903 \t\t\tfor ({elem_java_type} {elem_var} : ({var})) {{\n\
904 \t\t\t\t{inner_ser}\n\
905 \t\t\t}}"
906 )
907 }
908 JavaType::Optional(inner) => {
909 let inner_java_type = inner.java_type_name();
910 let inner_var = format!("_opt_inner{depth}");
911 let inner_ser = java_ser_element(inner, &inner_var, dos, depth + 1);
912 format!(
913 "if (({var}).isPresent()) {{\n\
914 \t\t\t\t{dos}.writeByte(1);\n\
915 \t\t\t\t{inner_java_type} {inner_var} = ({var}).get();\n\
916 \t\t\t\t{inner_ser}\n\
917 \t\t\t}} else {{\n\
918 \t\t\t\t{dos}.writeByte(0);\n\
919 \t\t\t}}"
920 )
921 }
922 }
923}
924
925fn rust_read_element(ty: &JavaType, depth: usize) -> proc_macro2::TokenStream {
932 match ty {
933 JavaType::Primitive(p) => scalar_enc_read_element(p.enc()),
934 JavaType::Boxed(b) => scalar_enc_read_element(b.enc()),
935 JavaType::Array(inner) | JavaType::List(inner) => {
936 let n_var = format_ident!("_n{}", depth);
937 let v_var = format_ident!("_v{}", depth);
938 let inner_rust_type = inner.rust_return_type_ts();
939 let inner_read = rust_read_element(inner, depth + 1);
940 quote! {{
941 let #n_var = i32::from_be_bytes([_raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3]]) as usize;
942 _cur += 4;
943 let mut #v_var: ::std::vec::Vec<#inner_rust_type> = ::std::vec::Vec::with_capacity(#n_var);
944 for _ in 0..#n_var {
945 let _item = #inner_read;
946 #v_var.push(_item);
947 }
948 #v_var
949 }}
950 }
951 JavaType::Optional(inner) => {
952 let inner_rust_type = inner.rust_return_type_ts();
953 let inner_read = rust_read_element(inner, depth + 1);
954 quote! {{
955 let _tag = _raw[_cur];
956 _cur += 1;
957 if _tag == 0 {
958 ::std::option::Option::None::<#inner_rust_type>
959 } else {
960 ::std::option::Option::Some(#inner_read)
961 }
962 }}
963 }
964 }
965}
966
967fn scalar_enc_read_element(enc: ScalarEnc) -> proc_macro2::TokenStream {
969 match enc {
970 ScalarEnc::Byte => quote! {{
971 let _val = i8::from_be_bytes([_raw[_cur]]);
972 _cur += 1;
973 _val
974 }},
975 ScalarEnc::Short => quote! {{
976 let _val = i16::from_be_bytes([_raw[_cur], _raw[_cur + 1]]);
977 _cur += 2;
978 _val
979 }},
980 ScalarEnc::Int => quote! {{
981 let _val = i32::from_be_bytes([_raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3]]);
982 _cur += 4;
983 _val
984 }},
985 ScalarEnc::Long => quote! {{
986 let _val = i64::from_be_bytes([
987 _raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3],
988 _raw[_cur + 4], _raw[_cur + 5], _raw[_cur + 6], _raw[_cur + 7],
989 ]);
990 _cur += 8;
991 _val
992 }},
993 ScalarEnc::Float => quote! {{
994 let _val = f32::from_bits(u32::from_be_bytes([_raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3]]));
995 _cur += 4;
996 _val
997 }},
998 ScalarEnc::Double => quote! {{
999 let _val = f64::from_bits(u64::from_be_bytes([
1000 _raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3],
1001 _raw[_cur + 4], _raw[_cur + 5], _raw[_cur + 6], _raw[_cur + 7],
1002 ]));
1003 _cur += 8;
1004 _val
1005 }},
1006 ScalarEnc::Boolean => quote! {{
1007 let _val = _raw[_cur] != 0;
1008 _cur += 1;
1009 _val
1010 }},
1011 ScalarEnc::Char => quote! {{
1012 let _val = ::std::char::from_u32(u16::from_be_bytes([_raw[_cur], _raw[_cur + 1]]) as u32)
1013 .ok_or(::inline_java::JavaError::InvalidChar)?;
1014 _cur += 2;
1015 _val
1016 }},
1017 ScalarEnc::Str => quote! {{
1019 let _slen = i32::from_be_bytes([_raw[_cur], _raw[_cur + 1], _raw[_cur + 2], _raw[_cur + 3]]) as usize;
1020 _cur += 4;
1021 let _val = ::std::string::String::from_utf8(_raw[_cur.._cur + _slen].to_vec())?;
1022 _cur += _slen;
1023 _val
1024 }},
1025 }
1026}
1027
/// Pieces extracted from an `inline_java` macro body.
///
/// NOTE(review): the construction site is not visible in this section, so
/// the field descriptions below are partly inferred and should be confirmed.
struct ParsedJava {
    /// Presumably the Java `import` statements to emit ahead of the class; confirm.
    imports: String,
    /// Presumably Java source that lives outside the `run` method; confirm.
    outer: String,
    /// Presumably the user's Java body containing `static <type> run(...)`; confirm.
    body: String,
    /// `run` parameters as `(type, name)` pairs, the same shape that
    /// `parse_run_params` returns and `JavaType::java_main` consumes.
    params: Vec<(JavaType, String)>,
    /// Presumably `run`'s return type, as found by `parse_run_return_type`.
    java_type: JavaType,
}
1045
1046fn parse_java_type(tts: &[TokenTree]) -> Result<(JavaType, usize), String> {
1058 if tts.is_empty() {
1059 return Err("inline_java: unexpected end of tokens while parsing Java type".to_string());
1060 }
1061
1062 let (tts, offset) = if matches!(&tts[0], TokenTree::Ident(id) if id == "java")
1064 && matches!(tts.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '.')
1065 && matches!(tts.get(2), Some(TokenTree::Ident(id)) if id == "util")
1066 && matches!(tts.get(3), Some(TokenTree::Punct(p)) if p.as_char() == '.')
1067 {
1068 (&tts[4..], 4usize)
1069 } else {
1070 (tts, 0usize)
1071 };
1072
1073 match tts.first() {
1074 Some(TokenTree::Ident(id)) => {
1075 let name = id.to_string();
1076 match name.as_str() {
1077 "List" | "Optional" => {
1078 if !matches!(tts.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '<') {
1080 return Err(format!("inline_java: expected `<` after `{name}`"));
1081 }
1082 let (inner_ty, inner_consumed) = parse_java_type_inner(&tts[2..])?;
1084 let close_idx = 2 + inner_consumed;
1086 if !matches!(tts.get(close_idx), Some(TokenTree::Punct(p)) if p.as_char() == '>')
1087 {
1088 return Err(format!("inline_java: expected `>` to close `{name}<...>`"));
1089 }
1090 let total_consumed = offset + close_idx + 1;
1091 if name == "List" {
1092 Ok((JavaType::List(Box::new(inner_ty)), total_consumed))
1093 } else {
1094 Ok((JavaType::Optional(Box::new(inner_ty)), total_consumed))
1095 }
1096 }
1097 _ => {
1098 let mut consumed = offset + 1;
1100 let base_ty = if let Some(p) = PrimitiveType::from_name(&name) {
1101 JavaType::Primitive(p)
1102 } else if name == "String" {
1103 JavaType::Boxed(BoxedType::String)
1104 } else {
1105 return Err(format!(
1106 "inline_java: `{name}` is not a supported Java type; \
1107 scalar types: byte short int long float double boolean char String"
1108 ));
1109 };
1110
1111 let mut ty = base_ty;
1113 while matches!(
1114 tts.get(consumed - offset),
1115 Some(TokenTree::Group(g))
1116 if g.delimiter() == proc_macro2::Delimiter::Bracket
1117 && g.stream().is_empty()
1118 ) {
1119 ty = JavaType::Array(Box::new(ty));
1120 consumed += 1;
1121 }
1122 Ok((ty, consumed))
1123 }
1124 }
1125 }
1126 _ => Err("inline_java: expected a Java type name".to_string()),
1127 }
1128}
1129
1130fn parse_java_type_inner(tts: &[TokenTree]) -> Result<(JavaType, usize), String> {
1134 if tts.is_empty() {
1135 return Err(
1136 "inline_java: unexpected end of tokens while parsing Java type argument".to_string(),
1137 );
1138 }
1139
1140 let (tts, offset) = if matches!(&tts[0], TokenTree::Ident(id) if id == "java")
1142 && matches!(tts.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '.')
1143 && matches!(tts.get(2), Some(TokenTree::Ident(id)) if id == "util")
1144 && matches!(tts.get(3), Some(TokenTree::Punct(p)) if p.as_char() == '.')
1145 {
1146 (&tts[4..], 4usize)
1147 } else {
1148 (tts, 0usize)
1149 };
1150
1151 match tts.first() {
1152 Some(TokenTree::Ident(id)) => {
1153 let name = id.to_string();
1154 match name.as_str() {
1155 "List" | "Optional" => {
1156 if !matches!(tts.get(1), Some(TokenTree::Punct(p)) if p.as_char() == '<') {
1158 return Err(format!("inline_java: expected `<` after `{name}`"));
1159 }
1160 let (inner_ty, inner_consumed) = parse_java_type_inner(&tts[2..])?;
1161 let close_idx = 2 + inner_consumed;
1162 if !matches!(tts.get(close_idx), Some(TokenTree::Punct(p)) if p.as_char() == '>')
1163 {
1164 return Err(format!("inline_java: expected `>` to close `{name}<...>`"));
1165 }
1166 let total_consumed = offset + close_idx + 1;
1167 if name == "List" {
1168 Ok((JavaType::List(Box::new(inner_ty)), total_consumed))
1169 } else {
1170 Ok((JavaType::Optional(Box::new(inner_ty)), total_consumed))
1171 }
1172 }
1173 _ => {
1174 let b = BoxedType::from_name(&name).ok_or_else(|| {
1176 format!(
1177 "inline_java: `{name}` is not a supported Java type argument; \
1178 supported: Byte Short Integer Long Float Double Boolean Character String \
1179 (or primitive names)"
1180 )
1181 })?;
1182 let mut consumed = offset + 1;
1183 let base_ty = JavaType::Boxed(b);
1184
1185 while matches!(
1187 tts.get(consumed - offset),
1188 Some(TokenTree::Group(g))
1189 if g.delimiter() == proc_macro2::Delimiter::Bracket
1190 && g.stream().is_empty()
1191 ) {
1192 consumed += 1;
1193 }
1194 let array_depth = consumed - offset - 1;
1196 let mut ty = base_ty;
1197 for _ in 0..array_depth {
1198 ty = JavaType::Array(Box::new(ty));
1199 }
1200 Ok((ty, consumed))
1201 }
1202 }
1203 }
1204 _ => Err("inline_java: expected a Java type name inside `<>`".to_string()),
1205 }
1206}
1207
1208fn parse_run_return_type(tts: &[TokenTree]) -> Result<(JavaType, usize, usize), String> {
1218 for i in 0..tts.len().saturating_sub(2) {
1219 if !matches!(&tts[i], TokenTree::Ident(id) if id == "static") {
1220 continue;
1221 }
1222
1223 let start = if i > 0
1225 && matches!(&tts[i - 1], TokenTree::Ident(id)
1226 if matches!(id.to_string().as_str(), "public" | "private" | "protected"))
1227 {
1228 i - 1
1229 } else {
1230 i
1231 };
1232
1233 let type_start = i + 1;
1235 if type_start >= tts.len() {
1236 continue;
1237 }
1238
1239 if let Ok((java_type, consumed)) = parse_java_type(&tts[type_start..]) {
1240 let run_idx = type_start + consumed;
1241 if matches!(tts.get(run_idx), Some(TokenTree::Ident(id)) if id == "run") {
1242 return Ok((java_type, start, run_idx));
1243 }
1244 }
1246 }
1248 Err("inline_java: could not find `static <type> run()` in Java body".to_string())
1249}
1250
1251fn parse_run_params(tts: &[TokenTree]) -> Result<Vec<(JavaType, String)>, String> {
1257 let group = match tts.first() {
1259 Some(TokenTree::Group(g)) if g.delimiter() == proc_macro2::Delimiter::Parenthesis => g,
1260 _ => return Ok(vec![]),
1261 };
1262
1263 let inner: Vec<TokenTree> = group.stream().into_iter().collect();
1264 if inner.is_empty() {
1265 return Ok(vec![]);
1266 }
1267
1268 let mut params = Vec::new();
1272 let mut segments: Vec<Vec<TokenTree>> = Vec::new();
1273 let mut current: Vec<TokenTree> = Vec::new();
1274 let mut angle_depth = 0i32;
1275 for tt in inner {
1276 if matches!(&tt, TokenTree::Punct(p) if p.as_char() == '<') {
1277 angle_depth += 1;
1278 current.push(tt);
1279 } else if matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') {
1280 angle_depth -= 1;
1281 current.push(tt);
1282 } else if matches!(&tt, TokenTree::Punct(p) if p.as_char() == ',') && angle_depth == 0 {
1283 segments.push(std::mem::take(&mut current));
1284 } else {
1285 current.push(tt);
1286 }
1287 }
1288 if !current.is_empty() {
1289 segments.push(current);
1290 }
1291
1292 for seg in segments {
1293 if seg.is_empty() {
1294 continue;
1295 }
1296
1297 let param_name = match seg.last() {
1300 Some(TokenTree::Ident(id)) => id.to_string(),
1301 _ => {
1302 return Err(
1303 "inline_java: unexpected token in run() parameter list: expected a parameter name"
1304 .to_string(),
1305 );
1306 }
1307 };
1308
1309 let type_tts = &seg[..seg.len() - 1];
1311 if type_tts.is_empty() {
1312 return Err(format!(
1313 "inline_java: missing type for parameter `{param_name}`"
1314 ));
1315 }
1316
1317 let (java_type, consumed) = parse_java_type(type_tts).map_err(|e| {
1318 format!("inline_java: error parsing type of parameter `{param_name}`: {e}")
1319 })?;
1320
1321 if consumed != type_tts.len() {
1323 return Err(format!(
1324 "inline_java: unexpected tokens after type of parameter `{param_name}`"
1325 ));
1326 }
1327
1328 params.push((java_type, param_name));
1329 }
1330
1331 Ok(params)
1332}
1333
1334fn parse_java_source(stream: proc_macro2::TokenStream) -> Result<ParsedJava, String> {
1342 let tts: Vec<TokenTree> = stream.into_iter().collect();
1343
1344 let mut first_import_idx: Option<usize> = None;
1346 let mut last_import_end_idx: Option<usize> = None; let mut first_body_idx: Option<usize> = None;
1348 let mut in_imports = true;
1349 let mut i = 0usize;
1350
1351 while i < tts.len() && in_imports {
1352 match &tts[i] {
1353 TokenTree::Ident(id) if id == "import" || id == "package" => {
1354 first_import_idx.get_or_insert(i);
1355 let semi = tts[i + 1..]
1357 .iter()
1358 .position(|t| matches!(t, TokenTree::Punct(p) if p.as_char() == ';'))
1359 .map(|rel| i + 1 + rel);
1360 if let Some(semi_idx) = semi {
1361 last_import_end_idx = Some(semi_idx);
1362 i = semi_idx + 1;
1363 } else {
1364 in_imports = false;
1366 first_body_idx = Some(i);
1367 }
1368 }
1369 _ => {
1370 in_imports = false;
1371 first_body_idx = Some(i);
1372 }
1373 }
1374 }
1375 if first_body_idx.is_none() && i < tts.len() {
1377 first_body_idx = Some(i);
1378 }
1379 let body_start = first_body_idx.unwrap_or(tts.len());
1380
1381 let (java_type, run_rel_idx, run_rel_run_idx) = parse_run_return_type(&tts[body_start..])?;
1383 let run_abs_idx = body_start + run_rel_idx;
1384 let run_token_abs_idx = body_start + run_rel_run_idx;
1385
1386 let params = parse_run_params(&tts[run_token_abs_idx + 1..])?;
1388
1389 let slice_text = |lo: usize, hi: usize| -> String {
1393 if lo >= hi {
1394 return String::new();
1395 }
1396 tts[lo]
1397 .span()
1398 .join(tts[hi - 1].span())
1399 .and_then(|s| s.source_text())
1400 .unwrap_or_else(|| {
1401 tts[lo..hi]
1402 .iter()
1403 .map(std::string::ToString::to_string)
1404 .collect::<Vec<_>>()
1405 .join(" ")
1406 })
1407 };
1408
1409 let imports = match (first_import_idx, last_import_end_idx) {
1411 (Some(fi), Some(le)) => slice_text(fi, le + 1),
1412 _ => String::new(),
1413 };
1414
1415 let outer = slice_text(body_start, run_abs_idx);
1417
1418 let body = if run_abs_idx < tts.len() {
1420 let start_span = tts[run_abs_idx].span();
1421 let end_span = tts.last().unwrap().span();
1422
1423 match start_span.join(end_span).and_then(|s| s.source_text()) {
1424 Some(raw) => raw,
1425 None => tts[run_abs_idx..]
1426 .iter()
1427 .map(std::string::ToString::to_string)
1428 .collect::<Vec<_>>()
1429 .join(" "),
1430 }
1431 } else {
1432 String::new()
1433 };
1434
1435 Ok(ParsedJava {
1436 imports,
1437 outer,
1438 body,
1439 params,
1440 java_type,
1441 })
1442}
1443
1444#[allow(clippy::similar_names)]
1451fn make_runner_fn(parsed: ParsedJava, opts: JavaOpts, prefix: &str) -> proc_macro2::TokenStream {
1452 let ParsedJava {
1453 imports,
1454 outer,
1455 body,
1456 params,
1457 java_type,
1458 } = parsed;
1459
1460 let class_name = make_class_name(prefix, &imports, &outer, &body, &opts);
1461 let filename = format!("{class_name}.java");
1462 let full_class_name = qualify_class_name(&class_name, &imports);
1463
1464 let main_method = java_type.java_main(¶ms);
1465 let java_class = format_java_class(&imports, &outer, &class_name, &body, &main_method);
1466
1467 let javac_raw = opts.javac_args.unwrap_or_default();
1468 let java_raw = opts.java_args.unwrap_or_default();
1469 let deser = java_type.rust_deser();
1470 let ret_ty = java_type.rust_return_type_ts();
1471
1472 let fn_params: Vec<proc_macro2::TokenStream> = params
1475 .iter()
1476 .map(|(ty, name)| {
1477 let ident = proc_macro2::Ident::new(name, proc_macro2::Span::call_site());
1478 let param_ty = ty.rust_param_type_ts();
1479 quote! { #ident: #param_ty }
1480 })
1481 .collect();
1482
1483 let ser_stmts: Vec<proc_macro2::TokenStream> = params
1485 .iter()
1486 .map(|(ty, name)| {
1487 let ident = proc_macro2::Ident::new(name, proc_macro2::Span::call_site());
1488 let ident_ts = quote! { #ident };
1489 ty.rust_ser_ts(&ident_ts, 0)
1490 })
1491 .collect();
1492
1493 quote! {
1494 fn __java_runner(#(#fn_params),*) -> ::std::result::Result<#ret_ty, ::inline_java::JavaError> {
1495 let mut _stdin_bytes: ::std::vec::Vec<u8> = ::std::vec::Vec::new();
1496 #(#ser_stmts)*
1497 let _raw = ::inline_java::run_java(
1498 #class_name,
1499 #filename,
1500 #java_class,
1501 #full_class_name,
1502 #javac_raw,
1503 #java_raw,
1504 &_stdin_bytes,
1505 )?;
1506 ::std::result::Result::Ok(#deser)
1507 }
1508 }
1509}
1510
1511#[proc_macro]
1565#[allow(clippy::similar_names)]
1566pub fn java(input: TokenStream) -> TokenStream {
1567 let input2 = proc_macro2::TokenStream::from(input);
1568
1569 let (opts, input2) = extract_opts(input2);
1571
1572 let parsed = match parse_java_source(input2) {
1573 Ok(p) => p,
1574 Err(msg) => return quote! { compile_error!(#msg) }.into(),
1575 };
1576
1577 let runner_fn = make_runner_fn(parsed, opts, "InlineJava");
1578
1579 let generated = quote! {
1580 {
1581 #runner_fn
1582 __java_runner()
1583 }
1584 };
1585
1586 generated.into()
1587}
1588
1589#[proc_macro]
1643#[allow(clippy::similar_names)]
1644pub fn java_fn(input: TokenStream) -> TokenStream {
1645 let input2 = proc_macro2::TokenStream::from(input);
1646
1647 let (opts, input2) = extract_opts(input2);
1649
1650 let parsed = match parse_java_source(input2) {
1651 Ok(p) => p,
1652 Err(msg) => return quote! { compile_error!(#msg) }.into(),
1653 };
1654
1655 let runner_fn = make_runner_fn(parsed, opts, "InlineJava");
1656
1657 let generated = quote! {
1658 {
1659 #runner_fn
1660 __java_runner
1661 }
1662 };
1663
1664 generated.into()
1665}
1666
1667#[proc_macro]
1705pub fn ct_java(input: TokenStream) -> TokenStream {
1706 match ct_java_impl(proc_macro2::TokenStream::from(input)) {
1707 Ok(ts) => ts.into(),
1708 Err(msg) => quote! { compile_error!(#msg) }.into(),
1709 }
1710}
1711
1712fn ct_java_impl(input: proc_macro2::TokenStream) -> Result<proc_macro2::TokenStream, String> {
1713 let (opts, input) = extract_opts(input);
1714
1715 let ParsedJava {
1716 imports,
1717 outer,
1718 body,
1719 java_type,
1720 ..
1721 } = parse_java_source(input)?;
1722
1723 let class_name = make_class_name("CtJava", &imports, &outer, &body, &opts);
1724 let filename = format!("{class_name}.java");
1725 let full_class_name = qualify_class_name(&class_name, &imports);
1726
1727 let main_method = java_type.java_main(&[]);
1728 let java_class = format_java_class(&imports, &outer, &class_name, &body, &main_method);
1729
1730 let bytes = compile_run_java_now(
1731 &class_name,
1732 &filename,
1733 &java_class,
1734 &full_class_name,
1735 opts.javac_args.as_deref(),
1736 opts.java_args.as_deref(),
1737 )?;
1738 java_type.ct_java_tokens(bytes)
1739}
1740
/// Options taken from the leading `key = "value"` pairs of a macro invocation
/// (filled in by `extract_opts`).
struct JavaOpts {
    /// Value of `javac = "..."`, if given; passed along as the raw `javac` argument string.
    javac_args: Option<String>,
    /// Value of `java = "..."`, if given; passed along as the raw `java` argument string.
    java_args: Option<String>,
}
1749
1750fn extract_opts(input: proc_macro2::TokenStream) -> (JavaOpts, proc_macro2::TokenStream) {
1754 let mut tts: Vec<TokenTree> = input.into_iter().collect();
1755 let mut opts = JavaOpts {
1756 javac_args: None,
1757 java_args: None,
1758 };
1759 let mut cursor = 0;
1760
1761 loop {
1762 match try_parse_opt(&tts[cursor..]) {
1763 None => break,
1764 Some((key, val, consumed)) => {
1765 match key.as_str() {
1766 "javac" => opts.javac_args = Some(val),
1767 "java" => opts.java_args = Some(val),
1768 _ => break,
1769 }
1770 cursor += consumed;
1771 if let Some(TokenTree::Punct(p)) = tts.get(cursor)
1772 && p.as_char() == ','
1773 {
1774 cursor += 1;
1775 }
1776 }
1777 }
1778 }
1779
1780 let rest = tts.drain(cursor..).collect();
1781 (opts, rest)
1782}
1783
1784fn try_parse_opt(tts: &[TokenTree]) -> Option<(String, String, usize)> {
1788 let key = match tts.first() {
1789 Some(TokenTree::Ident(id)) => id.to_string(),
1790 _ => return None,
1791 };
1792 let Some(TokenTree::Punct(eq)) = tts.get(1) else {
1793 return None;
1794 };
1795 if eq.as_char() != '=' {
1796 return None;
1797 }
1798 let Some(TokenTree::Literal(lit)) = tts.get(2) else {
1799 return None;
1800 };
1801 let value = litrs::StringLit::try_from(lit).ok()?.value().to_owned();
1802 Some((key, value, 3))
1803}
1804
1805fn make_class_name(
1810 prefix: &str,
1811 imports: &str,
1812 outer: &str,
1813 body: &str,
1814 opts: &JavaOpts,
1815) -> String {
1816 let mut h = DefaultHasher::new();
1817 imports.hash(&mut h);
1818 outer.hash(&mut h);
1819 body.hash(&mut h);
1820 opts.javac_args.hash(&mut h);
1821 opts.java_args.hash(&mut h);
1822 format!("{prefix}_{:016x}", h.finish())
1823}
1824
1825fn qualify_class_name(class_name: &str, imports: &str) -> String {
1828 match parse_package_name(imports) {
1829 Some(pkg) => format!("{pkg}.{class_name}"),
1830 None => class_name.to_owned(),
1831 }
1832}
1833
1834#[allow(clippy::similar_names)]
1838fn compile_run_java_now(
1839 class_name: &str,
1840 filename: &str,
1841 java_class: &str,
1842 full_class_name: &str,
1843 javac_raw: Option<&str>,
1844 java_raw: Option<&str>,
1845) -> Result<Vec<u8>, String> {
1846 inline_java_core::run_java(
1847 class_name,
1848 filename,
1849 java_class,
1850 full_class_name,
1851 javac_raw.unwrap_or(""),
1852 java_raw.unwrap_or(""),
1853 &[],
1854 )
1855 .map_err(|e| e.to_string())
1856}
1857
/// Assembles the complete `.java` source text.
///
/// Layout, in order: the imports, the outer declarations, then a `public class` named
/// `class_name` that wraps `body` followed by `main_method`.
fn format_java_class(
    imports: &str,
    outer: &str,
    class_name: &str,
    body: &str,
    main_method: &str,
) -> String {
    [
        imports,
        "\n",
        outer,
        "\n",
        "public class ",
        class_name,
        " {\n\n",
        body,
        "\n\n",
        main_method,
        "\n}\n",
    ]
    .concat()
}
1868
/// Extracts the package name from a leading `package <name>;` declaration in `imports`.
///
/// Java requires the package declaration to come before any import. The `imports` text
/// also starts at the first `package`/`import` token, so only a declaration at the very
/// start is recognized. This avoids mistaking a comment between imports (e.g.
/// `// see package notes;`) for a declaration. Any whitespace may follow the keyword.
/// Whitespace inside the name is removed, so the space-joined fallback text
/// `package com . foo ;` yields `com.foo`.
///
/// Returns `None` if there is no declaration, no terminating `;`, or an empty name.
fn parse_package_name(imports: &str) -> Option<String> {
    let after_kw = imports.trim_start().strip_prefix("package")?;
    // Reject identifiers that merely start with the keyword, e.g. `packaged`.
    if !after_kw.starts_with(char::is_whitespace) {
        return None;
    }
    let (decl, _) = after_kw.split_once(';')?;
    let pkg: String = decl.chars().filter(|c| !c.is_whitespace()).collect();
    (!pkg.is_empty()).then_some(pkg)
}