1#![expect(clippy::useless_format)]
6
7use std::collections::HashSet;
8use std::env;
9
10use proc_macro::Delimiter;
11use virtue::generate::FnSelfArg;
12use virtue::parse::{Attribute, AttributeLocation, EnumBody, StructBody};
13use virtue::prelude::*;
14use virtue::utils::{parse_tagged_attribute, ParsedAttribute};
15
/// Name of the environment variable that, when set while the derive macros
/// run, makes `encode_inner` / `decode_inner` call `Generator::export_to_file`
/// on the generated code (a debugging aid).
const ENV_SSHWIRE_DEBUG: &str = "SSHWIRE_DEBUG";
17
18#[proc_macro_derive(SSHEncode, attributes(sshwire))]
19pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
20 encode_inner(input).unwrap_or_else(|e| e.into_token_stream())
21}
22
23#[proc_macro_derive(SSHDecode, attributes(sshwire))]
24pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
25 decode_inner(input).unwrap_or_else(|e| e.into_token_stream())
26}
27
28fn encode_inner(input: TokenStream) -> Result<TokenStream> {
29 let parse = Parse::new(input)?;
30 let (mut gen, att, body) = parse.into_generator();
31 match body {
33 Body::Struct(body) => {
34 encode_struct(&mut gen, body)?;
35 }
36 Body::Enum(body) => {
37 encode_enum(&mut gen, &att, body)?;
38 }
39 }
40 if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
41 gen.export_to_file("sshwire", "SSHEncode");
42 }
43 gen.finish()
44}
45
46fn decode_inner(input: TokenStream) -> Result<TokenStream> {
47 let parse = Parse::new(input)?;
48 let (mut gen, att, body) = parse.into_generator();
49 match body {
51 Body::Struct(body) => {
52 decode_struct(&mut gen, body)?;
53 }
54 Body::Enum(body) => {
55 decode_enum(&mut gen, &att, body)?;
56 }
57 }
58 if env::var(ENV_SSHWIRE_DEBUG).is_ok() {
59 gen.export_to_file("sshwire", "SSHDecode");
60 }
61 gen.finish()
62}
63
/// Enum-level options, parsed by `take_cont_atts` from `#[sshwire(...)]`
/// attributes on the enum itself.
#[derive(Debug)]
enum ContainerAtt {
    /// `#[sshwire(no_variant_names)]`: the `SSHEncode` derive skips generating
    /// `SSHEncodeEnum`; the `SSHDecode` derive rejects this option.
    NoNames,

    /// `#[sshwire(variant_prefix)]`: `SSHEncode` writes the variant name before
    /// the variant's contents, and `SSHDecode` reads that name back and then
    /// dispatches through `SSHDecodeEnum::dec_enum`.
    VariantPrefix,
}
74
75#[derive(Debug)]
76enum FieldAtt {
77 VariantName(Ident),
80
81 CaptureUnknown,
85
86 Variant(TokenTree),
91}
92
93fn take_cont_atts(atts: &[Attribute]) -> Result<Vec<ContainerAtt>> {
94 let x = atts
95 .iter()
96 .filter_map(|a| parse_tagged_attribute(&a.tokens, "sshwire").transpose());
97
98 let mut ret = vec![];
99 for a in x {
101 for a in a? {
102 let l = match a {
103 ParsedAttribute::Tag(l) if l.to_string() == "no_variant_names" => {
104 Ok(ContainerAtt::NoNames)
105 }
106 ParsedAttribute::Tag(l) if l.to_string() == "variant_prefix" => {
107 Ok(ContainerAtt::VariantPrefix)
108 }
109 _ => Err(Error::Custom {
110 error: "Unknown sshwire atttribute".into(),
111 span: None,
112 }),
113 }?;
114 ret.push(l);
115 }
116 }
117 Ok(ret)
118}
119
120fn take_field_atts(atts: &[Attribute]) -> Result<Vec<FieldAtt>> {
122 atts.iter()
123 .filter_map(|a| {
124 match a.location {
125 AttributeLocation::Field | AttributeLocation::Variant => {
126 let mut s = a.tokens.stream().into_iter();
127 if &s.next().expect("missing attribute name").to_string()
128 != "sshwire"
129 {
130 return None;
132 }
133 Some(if let Some(TokenTree::Group(g)) = s.next() {
134 let mut g = g.stream().into_iter();
135 let f = match g.next() {
136 Some(TokenTree::Ident(l))
137 if l.to_string() == "variant_name" =>
138 {
139 match g.next() {
141 Some(TokenTree::Punct(p)) if p == '=' => (),
142 _ => {
143 return Some(Err(Error::Custom {
144 error: "Missing '='".into(),
145 span: Some(a.tokens.span()),
146 }))
147 }
148 }
149 match g.next() {
150 Some(TokenTree::Ident(i)) => {
151 Ok(FieldAtt::VariantName(i))
152 }
153 _ => Err(Error::ExpectedIdent(a.tokens.span())),
154 }
155 }
156
157 Some(TokenTree::Ident(l))
158 if l.to_string() == "unknown" =>
159 {
160 Ok(FieldAtt::CaptureUnknown)
161 }
162
163 Some(TokenTree::Ident(l))
164 if l.to_string() == "variant" =>
165 {
166 match g.next() {
168 Some(TokenTree::Punct(p)) if p == '=' => (),
169 _ => {
170 return Some(Err(Error::Custom {
171 error: "Missing '='".into(),
172 span: Some(a.tokens.span()),
173 }))
174 }
175 }
176 if let Some(t) = g.next() {
177 Ok(FieldAtt::Variant(t))
178 } else {
179 Err(Error::Custom {
180 error: "Missing expression".into(),
181 span: Some(a.tokens.span()),
182 })
183 }
184 }
185
186 _ => Err(Error::Custom {
187 error: "Unknown sshwire atttribute".into(),
188 span: Some(a.tokens.span()),
189 }),
190 };
191
192 if g.next().is_some() {
193 Err(Error::Custom {
194 error: "Extra unhandled parts".into(),
195 span: Some(a.tokens.span()),
196 })
197 } else {
198 f
199 }
200 } else {
201 Err(Error::Custom {
202 error: "#[sshwire(...)] attribute is missing (...) part"
203 .into(),
204 span: Some(a.tokens.span()),
205 })
206 })
207 }
208 _ => panic!("Non-field attribute for field: {a:#?}"),
209 }
210 })
211 .collect()
212}
213
214fn encode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
215 gen.impl_for("::sunset::sshwire::SSHEncode")
216 .generate_fn("enc")
217 .with_self_arg(FnSelfArg::RefSelf)
218 .with_arg("s", "&mut dyn ::sunset::sshwire::SSHSink")
219 .with_return_type("::sunset::sshwire::WireResult<()>")
220 .body(|fn_body| {
221 match &body.fields {
222 Some(Fields::Tuple(v)) => {
223 for (fname, f) in v.iter().enumerate() {
224 if !f.attributes.is_empty() {
226 return Err(Error::Custom { error: "Attributes aren't allowed for tuple structs".into(), span: Some(f.span()) })
227 }
228 fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
229 }
230 }
231 Some(Fields::Struct(v)) => {
232 for f in v {
233 let fname = &f.0;
234 let atts = take_field_atts(&f.1.attributes)?;
235 for a in atts {
236 if let FieldAtt::VariantName(enum_field) = a {
237 fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{enum_field}.variant_name()?, s)?;"))?;
239 }
240 }
241 fn_body.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(&self.{fname}, s)?;"))?;
242 }
243
244 }
245 None => {
246 }
249
250 }
251 fn_body.push_parsed("Ok(())")?;
252 Ok(())
253 })?;
254 Ok(())
255}
256
257fn encode_enum(
258 gen: &mut Generator,
259 atts: &[Attribute],
260 body: EnumBody,
261) -> Result<()> {
262 let cont_atts = take_cont_atts(atts)?;
263
264 gen.impl_for("::sunset::sshwire::SSHEncode")
265 .generate_fn("enc")
266 .with_self_arg(FnSelfArg::RefSelf)
267 .with_arg("s", "&mut dyn ::sunset::sshwire::SSHSink")
268 .with_return_type("::sunset::sshwire::WireResult<()>")
269 .body(|fn_body| {
270 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
271 fn_body.push_parsed("::sunset::sshwire::SSHEncode::enc(&self.variant_name()?, s)?;")?;
272 }
273
274 fn_body.ident_str("match");
275 fn_body.puncts("*");
276 fn_body.ident_str("self");
277 fn_body.group(Delimiter::Brace, |match_arm| {
278 for var in &body.variants {
279 match_arm.ident_str("Self");
280 match_arm.puncts("::");
281 match_arm.ident(var.name.clone());
282
283 let atts = take_field_atts(&var.attributes)?;
284
285 let mut rhs = StreamBuilder::new();
286 if let Some(val) = &var.value {
287 return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently does not encode enum discriminants.".into(), span: Some(val.span())})
290 }
291 match var.fields {
292 None => {
293 }
295 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
296 match_arm.group(Delimiter::Parenthesis, |item| {
297 item.ident_str("ref");
298 item.ident_str("i");
299 Ok(())
300 })?;
301 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
302 rhs.push_parsed("return Err(::sunset::sshwire::WireError::UnknownVariant)")?;
303 } else {
304 rhs.push_parsed(format!("::sunset::sshwire::SSHEncode::enc(i, s)?;"))?;
305 }
306
307 }
308 _ => return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
309 }
310
311 match_arm.puncts("=>");
312 match_arm.group(Delimiter::Brace, |var_body| {
313 var_body.append(rhs);
314 Ok(())
315 })?;
316 }
317 Ok(())
318 })?;
319 fn_body.push_parsed("#[allow(unreachable_code)]")?;
321 fn_body.push_parsed("Ok(())")?;
322 Ok(())
323 })?;
324
325 if !cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
326 encode_enum_names(gen, atts, body)?;
327 }
328 Ok(())
329}
330
331fn field_att_var_names(name: &Ident, mut atts: Vec<FieldAtt>) -> Result<TokenTree> {
332 let mut v = vec![];
333 while let Some(p) = atts.pop() {
334 if let FieldAtt::Variant(t) = p {
335 v.push(t);
336 }
337 }
338 if v.len() != 1 {
339 return Err(Error::Custom { error: format!("One #[sshwire(variant = ...)] attribute is required for each enum field, missing for {:?}", name), span: None});
340 }
341 Ok(v.pop().unwrap())
342}
343
344fn encode_enum_names(
345 gen: &mut Generator,
346 _atts: &[Attribute],
347 body: EnumBody,
348) -> Result<()> {
349 gen.impl_for("::sunset::sshwire::SSHEncodeEnum")
350 .generate_fn("variant_name")
351 .with_self_arg(FnSelfArg::RefSelf)
352 .with_return_type("::sunset::sshwire::WireResult<&'static str>")
353 .body(|fn_body| {
354 fn_body.push_parsed("let r = match self")?;
355 fn_body.group(Delimiter::Brace, |match_arm| {
356 for var in &body.variants {
357 match_arm.ident_str("Self");
358 match_arm.puncts("::");
359 match_arm.ident(var.name.clone());
360
361 let mut rhs = StreamBuilder::new();
362 let atts = take_field_atts(&var.attributes)?;
363 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
364 rhs.push_parsed("return Err(::sunset::sshwire::WireError::UnknownVariant)")?;
365 } else {
366 rhs.push(field_att_var_names(&var.name, atts)?);
367 }
368
369 match var.fields {
370 None => {
371 }
373 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
374 match_arm.group(Delimiter::Parenthesis, |item| {
375 item.ident_str("_");
376 Ok(())
377 })?;
378
379 }
380 _ => return Err(Error::Custom { error: "sunset_sshwire_derive::SSHEncode currently only implements Unit or single value enum variants.".into(), span: None})
381 }
382
383 match_arm.puncts("=>");
384 match_arm.group(Delimiter::Brace, |var_body| {
385 var_body.append(rhs);
386 Ok(())
387 })?;
388 }
389 Ok(())
390 })?;
391 fn_body.push_parsed(";")?;
392 fn_body.push_parsed("#[allow(unreachable_code)]")?;
394 fn_body.push_parsed("Ok(r)")?;
395
396 Ok(())
397 })?;
398
399 Ok(())
400}
401
402fn decode_struct(gen: &mut Generator, body: StructBody) -> Result<()> {
403 gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecode", ["de"])
404 .modify_generic_constraints(|generics, where_constraints| {
405 for lt in generics.iter_lifetimes() {
406 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
407 }
408 Ok(())
409 })?
410 .generate_fn("dec")
411 .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
412 .with_arg("s", "&mut S")
413 .with_return_type("::sunset::sshwire::WireResult<Self>")
414 .body(|fn_body| {
415 let mut named_enums = HashSet::new();
416 if let Some(Fields::Struct(v)) = &body.fields {
417 for f in v {
418 let atts = take_field_atts(&f.1.attributes)?;
419 for a in atts {
420 if let FieldAtt::VariantName(enum_field) = a {
421 named_enums.insert(enum_field.to_string());
423 fn_body.push_parsed(format!("let enum_name_{enum_field}: BinString = ::sunset::sshwire::SSHDecode::dec(s)?;"))?;
424 }
425 }
426 let fname = &f.0;
427 if named_enums.contains(&fname.to_string()) {
428 fn_body.push_parsed(format!("let field_{fname} = ::sunset::sshwire::SSHDecodeEnum::dec_enum(s, enum_name_{fname}.0)?;"))?;
429 } else {
430 fn_body.push_parsed(format!("let field_{fname} = ::sunset::sshwire::SSHDecode::dec(s)?;"))?;
431 }
432 }
433 }
434 fn_body.ident_str("Ok");
435 fn_body.group(Delimiter::Parenthesis, |fn_body| {
436 match &body.fields {
437 Some(Fields::Tuple(f)) => {
438 fn_body.ident_str("Self");
440 fn_body.group(Delimiter::Parenthesis, |args| {
441 for _ in f.iter() {
442 args.push_parsed(format!("::sunset::sshwire::SSHDecode::dec(s)?,"))?;
443 }
444 Ok(())
445 })?;
446 }
447 Some(Fields::Struct(v)) => {
448 fn_body.ident_str("Self");
449 fn_body.group(Delimiter::Brace, |args| {
450 for f in v {
451 let fname = &f.0;
452 args.push_parsed(format!("{fname}: field_{fname},"))?;
453 }
454 Ok(())
455 })?;
456 }
457 None => {
458 fn_body.ident_str("Self");
460 fn_body.group(Delimiter::Brace, |_| Ok(()))?;
461 }
462 }
463 Ok(())
464 })?;
465 Ok(())
466 })?;
467 Ok(())
468}
469
470fn decode_enum(
471 gen: &mut Generator,
472 atts: &[Attribute],
473 body: EnumBody,
474) -> Result<()> {
475 let cont_atts = take_cont_atts(atts)?;
476
477 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::NoNames)) {
478 return Err(Error::Custom {
479 error:
480 "SSHDecode derive can't be used with #[sshwire(no_variant_names)]"
481 .into(),
482 span: None,
483 });
484 }
485
486 if cont_atts.iter().any(|c| matches!(c, ContainerAtt::VariantPrefix)) {
488 decode_enum_variant_prefix(gen, atts, &body)?;
489 }
490
491 decode_enum_names(gen, atts, &body)?;
492 Ok(())
493}
494
495fn decode_enum_variant_prefix(
496 gen: &mut Generator,
497 _atts: &[Attribute],
498 _body: &EnumBody,
499) -> Result<()> {
500 gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecode", ["de"])
501 .modify_generic_constraints(|generics, where_constraints| {
502 for lt in generics.iter_lifetimes() {
503 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
504 }
505 Ok(())
506 })?
507 .generate_fn("dec")
508 .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
509 .with_arg("s", "&mut S")
510 .with_return_type("::sunset::sshwire::WireResult<Self>")
511 .body(|fn_body| {
512 fn_body
513 .push_parsed("let variant: ::sunset::sshwire::BinString = ::sunset::sshwire::SSHDecode::dec(s)?;")?;
514 fn_body.push_parsed(
515 "::sunset::sshwire::SSHDecodeEnum::dec_enum(s, variant.0)",
516 )?;
517 Ok(())
518 })
519}
520
521fn decode_enum_names(
522 gen: &mut Generator,
523 _atts: &[Attribute],
524 body: &EnumBody,
525) -> Result<()> {
526 gen.impl_for_with_lifetimes("::sunset::sshwire::SSHDecodeEnum", ["de"])
527 .modify_generic_constraints(|generics, where_constraints| {
528 for lt in generics.iter_lifetimes() {
529 where_constraints.push_parsed_constraint(format!("'de: '{}", lt.ident))?;
530 }
531 Ok(())
532 })?
533 .generate_fn("dec_enum")
534 .with_generic_deps("S", ["::sunset::sshwire::SSHSource<'de>"])
535 .with_arg("s", "&mut S")
536 .with_arg("variant", "&'de [u8]")
537 .with_return_type("::sunset::sshwire::WireResult<Self>")
538 .body(|fn_body| {
539 fn_body.push_parsed("let var_str = ::sunset::sshwire::try_as_ascii_str(variant).ok();")?;
541
542 fn_body.push_parsed("let r = match var_str")?;
543 fn_body.group(Delimiter::Brace, |match_arm| {
544 let mut unknown_arm = None;
545 for var in &body.variants {
546 let atts = take_field_atts(&var.attributes)?;
547 if atts.iter().any(|a| matches!(a, FieldAtt::CaptureUnknown)) {
548 let mut m = StreamBuilder::new();
550 m.push_parsed(format!("_ => {{ s.ctx().seen_unknown = true; Self::{}(Unknown::new(variant))}}", var.name))?;
551 if unknown_arm.replace(m).is_some() {
552 return Err(Error::Custom { error: "only one variant can have #[sshwire(unknown)]".into(), span: None})
553 }
554 } else {
555 let var_name = field_att_var_names(&var.name, atts)?;
556 match_arm.push_parsed(format!("Some({}) => ", var_name))?;
557 match_arm.group(Delimiter::Brace, |var_body| {
558 match var.fields {
559 None => {
560 var_body.push_parsed(format!("Self::{}", var.name))?;
561 }
562 Some(Fields::Tuple(ref f)) if f.len() == 1 => {
563 var_body.push_parsed(format!("Self::{}(::sunset::sshwire::SSHDecode::dec(s)?)", var.name))?;
564 }
565 _ => return Err(Error::Custom { error: "SSHDecode currently only implements Unit or single value enum variants. ".into(), span: None})
566 }
567 Ok(())
568 })?;
569
570 }
571 if let Some(unk) = unknown_arm.take() {
572 match_arm.append(unk);
573 }
574 }
575 Ok(())
576 })?;
577 fn_body.push_parsed("; Ok(r)")?;
578 Ok(())
579 })?;
580 Ok(())
581}