differential_equations_derive/lib.rs
//! Derive macros for the `differential-equations` crate.
2//!
//! [GitHub repository](https://github.com/Ryan-D-Gast/differential-equations)
//! [API documentation](https://docs.rs/differential-equations)
5//!
6//! # License
7//!
8//! ```text
9//! Copyright 2025 Ryan D. Gast
10//!
11//! Licensed under the Apache License, Version 2.0 (the "License");
12//! you may not use this file except in compliance with the License.
13//! You may obtain a copy of the License at
14//!
15//! http://www.apache.org/licenses/LICENSE-2.0
16//!
17//! Unless required by applicable law or agreed to in writing, software
18//! distributed under the License is distributed on an "AS IS" BASIS,
19//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20//! See the License for the specific language governing permissions and
21//! limitations under the License.
22//! ```
23
24use proc_macro::TokenStream;
25use syn::{parse_macro_input, DeriveInput, Data, Fields};
26
27mod field_analysis;
28mod code_generation;
29mod implementations;
30
31use field_analysis::analyze_field_type;
32use code_generation::*;
33use implementations::*;
34
35/// Derive macro for the `State` trait.
36///
37/// This macro generates implementations of:
38/// - The `State` trait for accessing state elements
39/// - Required arithmetic operations (Add, Sub, AddAssign, Mul<T>, Div<T>)
40/// - Clone, Copy, and Debug
41///
42/// It supports structs with fields of type T, [T; N], SMatrix<T, R, C>, or Complex<T> where T implements the required traits.
43///
44/// # Example
45/// ```rust
46/// use differential_equations_derive::State;
47/// use nalgebra::SMatrix;
48/// use num_complex::Complex;
49///
50/// #[derive(State)]
51/// struct MyState<T> {
52/// x: T,
53/// y: T,
54/// velocities: [T; 3],
55/// transformation: SMatrix<T, 2, 2>,
56/// impedance: Complex<T>,
57/// }
58/// ```
59///
60/// Fields can be:
61/// - Single values of type T
62/// - Arrays of type [T; N] where N is a compile-time constant
63/// - Matrices of type SMatrix<T, R, C> where R and C are compile-time constants
64/// - Complex numbers of type Complex<T>
65///
66#[proc_macro_derive(State)]
67pub fn derive_state(input: TokenStream) -> TokenStream {
68 // Parse the input tokens into a syntax tree
69 let input = parse_macro_input!(input as DeriveInput);
70
71 // Get the struct name
72 let name = input.ident;
73
74 // Extract field information
75 let fields = match input.data {
76 Data::Struct(data) => {
77 match data.fields {
78 Fields::Named(fields) => fields.named,
79 Fields::Unnamed(fields) => fields.unnamed,
80 Fields::Unit => panic!("Unit structs are not supported"),
81 }
82 },
83 _ => panic!("State can only be derived for structs"),
84 };
85
86 // Analyze field types and calculate total element count
87 let mut total_elements = 0usize;
88 let mut field_info = Vec::new();
89
90 for field in &fields {
91 let field_type_info = analyze_field_type(&field.ty);
92 total_elements += field_type_info.element_count();
93 field_info.push(field_type_info);
94 }
95
96 // Generate all the required components
97 let field_get_branches = generate_get_branches(&fields, &field_info);
98 let field_set_branches = generate_set_branches(&fields, &field_info);
99 let zeros_init = generate_zeros_init(&fields, &field_info);
100 let debug_fields = generate_debug_fields(&fields);
101
102 // Generate all trait implementations
103 let clone_impl = generate_clone_impl(&name);
104 let copy_impl = generate_copy_impl(&name);
105 let debug_impl = generate_debug_impl(&name, &debug_fields);
106 let state_impl = generate_state_impl(&name, total_elements, &field_get_branches, &field_set_branches, &zeros_init, &fields);
107 let add_impl = generate_add_impl(&name, &fields, &field_info);
108 let sub_impl = generate_sub_impl(&name, &fields, &field_info);
109 let add_assign_impl = generate_add_assign_impl(&name, &fields, &field_info);
110 let neg_impl = generate_neg_impl(&name, &fields, &field_info);
111 let mul_impl = generate_mul_impl(&name, &fields, &field_info);
112 let div_impl = generate_div_impl(&name, &fields, &field_info);
113
114 // Combine all implementations
115 let expanded = quote::quote! {
116 #clone_impl
117 #copy_impl
118 #debug_impl
119 #state_impl
120 #add_impl
121 #sub_impl
122 #add_assign_impl
123 #mul_impl
124 #div_impl
125 #neg_impl
126 };
127
128 // Return the generated implementation
129 TokenStream::from(expanded)
130}