Skip to main content

pharmsol_macros/
lib.rs

1//! Procedural macros for [`pharmsol`](https://crates.io/crates/pharmsol).
2//!
3//! This crate is not intended to be used directly. Use the re-exports from the
4//! `pharmsol` crate instead.
5
6use pharmsol_dsl::{
7    AnalyticalKernel as ResolverAnalyticalKernel, AnalyticalStructureInputKind,
8    AnalyticalStructureInputPlan, AnalyticalStructureInputSource,
9};
10use proc_macro::TokenStream;
11use proc_macro2::{Span, TokenStream as TokenStream2};
12use quote::{quote, ToTokens};
13use std::collections::{HashMap, HashSet};
14use syn::{
15    parse::{Parse, ParseStream, Parser},
16    punctuated::Punctuated,
17    token,
18    visit::Visit,
19    visit_mut::VisitMut,
20    Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token,
21};
22
23// ---------------------------------------------------------------------------
24// Macro input parsing
25// ---------------------------------------------------------------------------
26
27struct OdeInput {
28    name: LitStr,
29    params: Vec<Ident>,
30    covariates: Vec<Ident>,
31    states: Vec<Ident>,
32    outputs: Vec<SymbolicIndex>,
33    routes: Vec<OdeRouteDecl>,
34    diffeq: ExprClosure,
35    lag: Option<ExprClosure>,
36    fa: Option<ExprClosure>,
37    init: Option<ExprClosure>,
38    out: ExprClosure,
39}
40
41struct AnalyticalInput {
42    name: LitStr,
43    params: Vec<Ident>,
44    derived: Vec<Ident>,
45    covariates: Vec<Ident>,
46    states: Vec<Ident>,
47    outputs: Vec<SymbolicIndex>,
48    routes: Vec<OdeRouteDecl>,
49    structure: Ident,
50    derive: Option<ExprClosure>,
51    lag: Option<ExprClosure>,
52    fa: Option<ExprClosure>,
53    init: Option<ExprClosure>,
54    out: ExprClosure,
55}
56
57struct SdeInput {
58    name: LitStr,
59    params: Vec<Ident>,
60    covariates: Vec<Ident>,
61    states: Vec<Ident>,
62    outputs: Vec<SymbolicIndex>,
63    routes: Vec<OdeRouteDecl>,
64    particles: Expr,
65    drift: ExprClosure,
66    diffusion: ExprClosure,
67    lag: Option<ExprClosure>,
68    fa: Option<ExprClosure>,
69    init: Option<ExprClosure>,
70    out: ExprClosure,
71}
72
73struct OdeRouteDecl {
74    kind: OdeRouteKind,
75    input: SymbolicIndex,
76    destination: Ident,
77}
78
/// Route kind keyword accepted in a route declaration.
#[derive(Copy, Clone)]
enum OdeRouteKind {
    /// Declared as `bolus(...)`.
    Bolus,
    /// Declared as `infusion(...)`.
    Infusion,
}
84
85struct AnalyticalKernelSpec {
86    kernel: ResolverAnalyticalKernel,
87    runtime_path: TokenStream2,
88    metadata_kernel: TokenStream2,
89    state_count: usize,
90}
91
92struct RoutePropertyEntry {
93    route: SymbolicIndex,
94    value: Expr,
95}
96
97#[derive(Clone)]
98enum SymbolicIndex {
99    Ident(Ident),
100    Int(LitInt),
101}
102
103impl SymbolicIndex {
104    fn name(&self) -> String {
105        match self {
106            Self::Ident(ident) => ident.to_string(),
107            Self::Int(lit) => lit.base10_digits().to_string(),
108        }
109    }
110
111    fn ident(&self) -> Option<&Ident> {
112        match self {
113            Self::Ident(ident) => Some(ident),
114            Self::Int(_) => None,
115        }
116    }
117
118    fn numeric_value(&self) -> Option<usize> {
119        match self {
120            Self::Ident(_) => None,
121            Self::Int(lit) => Some(
122                lit.base10_parse::<usize>()
123                    .expect("validated numeric label should fit usize"),
124            ),
125        }
126    }
127
128    fn numeric(value: usize) -> Self {
129        Self::Int(LitInt::new(&value.to_string(), Span::call_site()))
130    }
131}
132
133impl Parse for SymbolicIndex {
134    fn parse(input: ParseStream) -> syn::Result<Self> {
135        if input.peek(LitInt) {
136            let lit: LitInt = input.parse()?;
137            lit.base10_parse::<usize>().map_err(|_| {
138                syn::Error::new_spanned(
139                    &lit,
140                    "numeric declaration-first labels must be non-negative base-10 integers that fit in usize",
141                )
142            })?;
143            Ok(Self::Int(lit))
144        } else {
145            Ok(Self::Ident(input.parse()?))
146        }
147    }
148}
149
150impl ToTokens for SymbolicIndex {
151    fn to_tokens(&self, tokens: &mut TokenStream2) {
152        match self {
153            Self::Ident(ident) => ident.to_tokens(tokens),
154            Self::Int(lit) => lit.to_tokens(tokens),
155        }
156    }
157}
158
159impl std::fmt::Display for SymbolicIndex {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        f.write_str(&self.name())
162    }
163}
164
165impl Parse for OdeRouteDecl {
166    fn parse(input: ParseStream) -> syn::Result<Self> {
167        let kind_ident: Ident = input.parse()?;
168        let kind = match kind_ident.to_string().as_str() {
169            "bolus" => OdeRouteKind::Bolus,
170            "infusion" => OdeRouteKind::Infusion,
171            other => {
172                return Err(syn::Error::new_spanned(
173                    &kind_ident,
174                    format!("unknown route kind `{other}`, expected `bolus` or `infusion`"),
175                ));
176            }
177        };
178
179        let content;
180        syn::parenthesized!(content in input);
181        let route_input: SymbolicIndex = content.parse()?;
182        if !content.is_empty() {
183            return Err(content.error("expected a single route input name inside `(...)`"));
184        }
185
186        if !input.peek(Token![->]) {
187            return Err(
188                input.error("expected `->` followed by a destination state in route declaration")
189            );
190        }
191        input.parse::<Token![->]>()?;
192        let destination: Ident = input.parse()?;
193
194        if input.peek(token::Brace) {
195            return Err(
196                input.error("route properties are not supported in declaration-first `ode!` yet")
197            );
198        }
199
200        Ok(Self {
201            kind,
202            input: route_input,
203            destination,
204        })
205    }
206}
207
208impl Parse for OdeInput {
209    fn parse(input: ParseStream) -> syn::Result<Self> {
210        let mut name = None;
211        let mut params = None;
212        let mut covariates = None;
213        let mut states = None;
214        let mut outputs = None;
215        let mut routes = None;
216        let mut diffeq = None;
217        let mut lag = None;
218        let mut fa = None;
219        let mut init = None;
220        let mut out = None;
221
222        while !input.is_empty() {
223            let key: Ident = input.parse()?;
224            input.parse::<Token![:]>()?;
225
226            match key.to_string().as_str() {
227                "name" => set_once_ode(&mut name, input.parse()?, &key, "name")?,
228                "params" => set_once_ode(&mut params, parse_ident_list(input)?, &key, "params")?,
229                "covariates" => set_once_ode(
230                    &mut covariates,
231                    parse_ident_list(input)?,
232                    &key,
233                    "covariates",
234                )?,
235                "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?,
236                "outputs" => set_once_ode(
237                    &mut outputs,
238                    parse_symbolic_index_list(input)?,
239                    &key,
240                    "outputs",
241                )?,
242                "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?,
243                "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?,
244                "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?,
245                "fa" => set_once_ode(&mut fa, input.parse()?, &key, "fa")?,
246                "init" => set_once_ode(&mut init, input.parse()?, &key, "init")?,
247                "out" => set_once_ode(&mut out, input.parse()?, &key, "out")?,
248                other => {
249                    return Err(syn::Error::new_spanned(
250                        &key,
251                        format!(
252                            "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, diffeq, lag, fa, init, out"
253                        ),
254                    ));
255                }
256            }
257
258            if !input.is_empty() {
259                input.parse::<Token![,]>()?;
260            }
261        }
262
263        let name = name.ok_or_else(|| {
264            syn::Error::new(
265                Span::call_site(),
266                "declaration-first `ode!` requires `name`, `params`, `states`, `outputs`, and `routes`; the old inferred-dimensions form has been removed",
267            )
268        })?;
269        let params = params.ok_or_else(|| missing_required_ode_field("params"))?;
270        let covariates = covariates.unwrap_or_default();
271        let states = states.ok_or_else(|| missing_required_ode_field("states"))?;
272        let outputs = outputs.ok_or_else(|| missing_required_ode_field("outputs"))?;
273        let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?;
274        let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?;
275        let out = out.ok_or_else(|| missing_required_ode_field("out"))?;
276        validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?;
277
278        validate_unique_idents("parameter", &params, "ode!")?;
279        validate_unique_idents("covariate", &covariates, "ode!")?;
280        validate_unique_idents("state", &states, "ode!")?;
281        let output_idents = symbolic_index_idents(&outputs);
282
283        validate_unique_symbolic_indices("output", &outputs, "ode!")?;
284        validate_routes(&routes, &states, "ode!")?;
285        validate_named_binding_compatibility(
286            NamedBindingSets {
287                params: &params,
288                derived: &[],
289                covariates: &covariates,
290                states: &states,
291                outputs: &output_idents,
292                routes: &routes,
293            },
294            OdeBindingClosures {
295                diffeq: &diffeq,
296                common: CommonBindingClosures {
297                    lag: lag.as_ref(),
298                    fa: fa.as_ref(),
299                    init: init.as_ref(),
300                    out: &out,
301                },
302            },
303        )?;
304
305        Ok(Self {
306            name,
307            params,
308            covariates,
309            states,
310            outputs,
311            routes,
312            diffeq,
313            lag,
314            fa,
315            init,
316            out,
317        })
318    }
319}
320
321impl Parse for RoutePropertyEntry {
322    fn parse(input: ParseStream) -> syn::Result<Self> {
323        let route: SymbolicIndex = input.parse()?;
324        input.parse::<Token![=>]>()?;
325        let value: Expr = input.parse()?;
326        Ok(Self { route, value })
327    }
328}
329
330impl Parse for AnalyticalInput {
331    fn parse(input: ParseStream) -> syn::Result<Self> {
332        let mut name = None;
333        let mut params = None;
334        let mut derived = None;
335        let mut covariates = None;
336        let mut states = None;
337        let mut outputs = None;
338        let mut routes = None;
339        let mut structure = None;
340        let mut derive = None;
341        let mut lag = None;
342        let mut fa = None;
343        let mut init = None;
344        let mut out = None;
345
346        while !input.is_empty() {
347            let key: Ident = input.parse()?;
348            input.parse::<Token![:]>()?;
349
350            match key.to_string().as_str() {
351                "name" => set_once_analytical(&mut name, input.parse()?, &key, "name")?,
352                "params" => {
353                    set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")?
354                }
355                "derived" => {
356                    set_once_analytical(&mut derived, parse_ident_list(input)?, &key, "derived")?
357                }
358                "covariates" => set_once_analytical(
359                    &mut covariates,
360                    parse_ident_list(input)?,
361                    &key,
362                    "covariates",
363                )?,
364                "states" => {
365                    set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")?
366                }
367                "outputs" => set_once_analytical(
368                    &mut outputs,
369                    parse_symbolic_index_list(input)?,
370                    &key,
371                    "outputs",
372                )?,
373                "routes" => {
374                    set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")?
375                }
376                "structure" => {
377                    set_once_analytical(&mut structure, input.parse()?, &key, "structure")?
378                }
379                "derive" => set_once_analytical(&mut derive, input.parse()?, &key, "derive")?,
380                "sec" => {
381                    return Err(syn::Error::new_spanned(
382                        &key,
383                        "built-in `analytical!` no longer supports `sec`; use `derived: [...]` plus `derive: ...`",
384                    ));
385                }
386                "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?,
387                "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?,
388                "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?,
389                "out" => set_once_analytical(&mut out, input.parse()?, &key, "out")?,
390                other => {
391                    return Err(syn::Error::new_spanned(
392                        &key,
393                        format!(
394                            "unknown field `{other}`, expected one of: name, params, derived, covariates, states, outputs, routes, structure, derive, lag, fa, init, out"
395                        ),
396                    ));
397                }
398            }
399
400            if !input.is_empty() {
401                input.parse::<Token![,]>()?;
402            }
403        }
404
405        let name = name.ok_or_else(|| missing_required_analytical_field("name"))?;
406        let params = params.ok_or_else(|| missing_required_analytical_field("params"))?;
407        let derived = derived.unwrap_or_default();
408        let covariates = covariates.unwrap_or_default();
409        let states = states.ok_or_else(|| missing_required_analytical_field("states"))?;
410        let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?;
411        let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?;
412        let structure = structure.ok_or_else(|| missing_required_analytical_field("structure"))?;
413        let out = out.ok_or_else(|| missing_required_analytical_field("out"))?;
414
415        validate_unique_idents("covariate", &covariates, "analytical!")?;
416        validate_unique_idents("state", &states, "analytical!")?;
417        let output_idents = symbolic_index_idents(&outputs);
418
419        validate_unique_symbolic_indices("output", &outputs, "analytical!")?;
420        validate_routes(&routes, &states, "analytical!")?;
421
422        let kernel_spec = resolve_analytical_structure(&structure)?;
423        validate_analytical_structure_inputs(&structure, kernel_spec.kernel, &params, &derived)?;
424        if states.len() != kernel_spec.state_count {
425            return Err(syn::Error::new_spanned(
426                &structure,
427                format!(
428                    "analytical structure `{}` expects {} state value(s), but `states` declares {}",
429                    structure,
430                    kernel_spec.state_count,
431                    states.len()
432                ),
433            ));
434        }
435
436        validate_analytical_named_binding_compatibility(
437            NamedBindingSets {
438                params: &params,
439                derived: &derived,
440                covariates: &covariates,
441                states: &states,
442                outputs: &output_idents,
443                routes: &routes,
444            },
445            AnalyticalBindingClosures {
446                derive: derive.as_ref(),
447                common: CommonBindingClosures {
448                    lag: lag.as_ref(),
449                    fa: fa.as_ref(),
450                    init: init.as_ref(),
451                    out: &out,
452                },
453            },
454        )?;
455
456        validate_analytical_derive_contract(
457            kernel_spec.kernel,
458            &params,
459            &derived,
460            &covariates,
461            derive.as_ref(),
462        )?;
463
464        if let Some(lag) = lag.as_ref() {
465            let lag_routes =
466                extract_route_property_routes("built-in `analytical!`", "lag", lag, &routes)?;
467            validate_route_property_kinds("built-in `analytical!`", "lag", &routes, &lag_routes)?;
468        }
469
470        if let Some(fa) = fa.as_ref() {
471            let fa_routes =
472                extract_route_property_routes("built-in `analytical!`", "fa", fa, &routes)?;
473            validate_route_property_kinds("built-in `analytical!`", "fa", &routes, &fa_routes)?;
474        }
475
476        Ok(Self {
477            name,
478            params,
479            derived,
480            covariates,
481            states,
482            outputs,
483            routes,
484            structure,
485            derive,
486            lag,
487            fa,
488            init,
489            out,
490        })
491    }
492}
493
494impl Parse for SdeInput {
495    fn parse(input: ParseStream) -> syn::Result<Self> {
496        let mut name = None;
497        let mut params = None;
498        let mut covariates = None;
499        let mut states = None;
500        let mut outputs = None;
501        let mut routes = None;
502        let mut particles = None;
503        let mut drift = None;
504        let mut diffusion = None;
505        let mut lag = None;
506        let mut fa = None;
507        let mut init = None;
508        let mut out = None;
509
510        while !input.is_empty() {
511            let key: Ident = input.parse()?;
512            input.parse::<Token![:]>()?;
513
514            match key.to_string().as_str() {
515                "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?,
516                "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?,
517                "covariates" => set_once_sde(
518                    &mut covariates,
519                    parse_ident_list(input)?,
520                    &key,
521                    "covariates",
522                )?,
523                "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?,
524                "outputs" => set_once_sde(
525                    &mut outputs,
526                    parse_symbolic_index_list(input)?,
527                    &key,
528                    "outputs",
529                )?,
530                "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?,
531                "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?,
532                "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?,
533                "diffusion" => set_once_sde(&mut diffusion, input.parse()?, &key, "diffusion")?,
534                "lag" => set_once_sde(&mut lag, input.parse()?, &key, "lag")?,
535                "fa" => set_once_sde(&mut fa, input.parse()?, &key, "fa")?,
536                "init" => set_once_sde(&mut init, input.parse()?, &key, "init")?,
537                "out" => set_once_sde(&mut out, input.parse()?, &key, "out")?,
538                other => {
539                    return Err(syn::Error::new_spanned(
540                        &key,
541                        format!(
542                            "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out"
543                        ),
544                    ));
545                }
546            }
547
548            if !input.is_empty() {
549                input.parse::<Token![,]>()?;
550            }
551        }
552
553        let name = name.ok_or_else(|| missing_required_sde_field("name"))?;
554        let params = params.ok_or_else(|| missing_required_sde_field("params"))?;
555        let covariates = covariates.unwrap_or_default();
556        let states = states.ok_or_else(|| missing_required_sde_field("states"))?;
557        let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?;
558        let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?;
559        let particles = particles.ok_or_else(|| missing_required_sde_field("particles"))?;
560        let drift = drift.ok_or_else(|| missing_required_sde_field("drift"))?;
561        let diffusion = diffusion.ok_or_else(|| missing_required_sde_field("diffusion"))?;
562        let out = out.ok_or_else(|| missing_required_sde_field("out"))?;
563
564        validate_unique_idents("parameter", &params, "sde!")?;
565        validate_unique_idents("covariate", &covariates, "sde!")?;
566        validate_unique_idents("state", &states, "sde!")?;
567        let output_idents = symbolic_index_idents(&outputs);
568
569        validate_unique_symbolic_indices("output", &outputs, "sde!")?;
570        validate_routes(&routes, &states, "sde!")?;
571        validate_sde_named_binding_compatibility(
572            NamedBindingSets {
573                params: &params,
574                derived: &[],
575                covariates: &covariates,
576                states: &states,
577                outputs: &output_idents,
578                routes: &routes,
579            },
580            SdeBindingClosures {
581                drift: &drift,
582                diffusion: &diffusion,
583                common: CommonBindingClosures {
584                    lag: lag.as_ref(),
585                    fa: fa.as_ref(),
586                    init: init.as_ref(),
587                    out: &out,
588                },
589            },
590        )?;
591
592        if let Some(lag) = lag.as_ref() {
593            let lag_routes =
594                extract_route_property_routes("declaration-first `sde!`", "lag", lag, &routes)?;
595            validate_route_property_kinds("declaration-first `sde!`", "lag", &routes, &lag_routes)?;
596        }
597
598        if let Some(fa) = fa.as_ref() {
599            let fa_routes =
600                extract_route_property_routes("declaration-first `sde!`", "fa", fa, &routes)?;
601            validate_route_property_kinds("declaration-first `sde!`", "fa", &routes, &fa_routes)?;
602        }
603
604        Ok(Self {
605            name,
606            params,
607            covariates,
608            states,
609            outputs,
610            routes,
611            particles,
612            drift,
613            diffusion,
614            lag,
615            fa,
616            init,
617            out,
618        })
619    }
620}
621
622// ---------------------------------------------------------------------------
623// Helpers
624// ---------------------------------------------------------------------------
625
626fn missing_required_ode_field(name: &str) -> syn::Error {
627    syn::Error::new(
628        Span::call_site(),
629        format!("missing required field `{name}` in declaration-first `ode!`"),
630    )
631}
632
633fn missing_required_analytical_field(name: &str) -> syn::Error {
634    syn::Error::new(
635        Span::call_site(),
636        format!("missing required field `{name}` in built-in `analytical!`"),
637    )
638}
639
640fn missing_required_sde_field(name: &str) -> syn::Error {
641    syn::Error::new(
642        Span::call_site(),
643        format!("missing required field `{name}` in declaration-first `sde!`"),
644    )
645}
646
647fn set_once_ode<T>(slot: &mut Option<T>, value: T, key: &Ident, name: &str) -> syn::Result<()> {
648    if slot.is_some() {
649        Err(syn::Error::new_spanned(
650            key,
651            format!("duplicate field `{name}` in `ode!`"),
652        ))
653    } else {
654        *slot = Some(value);
655        Ok(())
656    }
657}
658
659fn set_once_analytical<T>(
660    slot: &mut Option<T>,
661    value: T,
662    key: &Ident,
663    name: &str,
664) -> syn::Result<()> {
665    if slot.is_some() {
666        Err(syn::Error::new_spanned(
667            key,
668            format!("duplicate field `{name}` in `analytical!`"),
669        ))
670    } else {
671        *slot = Some(value);
672        Ok(())
673    }
674}
675
676fn set_once_sde<T>(slot: &mut Option<T>, value: T, key: &Ident, name: &str) -> syn::Result<()> {
677    if slot.is_some() {
678        Err(syn::Error::new_spanned(
679            key,
680            format!("duplicate field `{name}` in `sde!`"),
681        ))
682    } else {
683        *slot = Some(value);
684        Ok(())
685    }
686}
687
688fn parse_ident_list(input: ParseStream) -> syn::Result<Vec<Ident>> {
689    let content;
690    syn::bracketed!(content in input);
691    Ok(Punctuated::<Ident, Token![,]>::parse_terminated(&content)?
692        .into_iter()
693        .collect())
694}
695
696fn parse_symbolic_index_list(input: ParseStream) -> syn::Result<Vec<SymbolicIndex>> {
697    let content;
698    syn::bracketed!(content in input);
699    Ok(
700        Punctuated::<SymbolicIndex, Token![,]>::parse_terminated(&content)?
701            .into_iter()
702            .collect(),
703    )
704}
705
706fn parse_route_list(input: ParseStream) -> syn::Result<Vec<OdeRouteDecl>> {
707    if input.peek(token::Brace) {
708        return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`"));
709    }
710
711    if !input.peek(token::Bracket) {
712        return Err(
713            input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`")
714        );
715    }
716
717    let content;
718    syn::bracketed!(content in input);
719    Ok(
720        Punctuated::<OdeRouteDecl, Token![,]>::parse_terminated(&content)?
721            .into_iter()
722            .collect(),
723    )
724}
725
726fn param_name(pat: &Pat) -> String {
727    match pat {
728        Pat::Ident(p) => p.ident.to_string(),
729        _ => String::new(),
730    }
731}
732
733fn closure_param_names(c: &ExprClosure) -> Vec<String> {
734    c.inputs.iter().map(param_name).collect()
735}
736
737fn closure_param_ident(c: &ExprClosure, index: usize) -> Option<Ident> {
738    c.inputs.get(index).and_then(|pat| match pat {
739        Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()),
740        _ => None,
741    })
742}
743
744fn generated_ident(name: &str) -> Ident {
745    Ident::new(name, Span::call_site())
746}
747
748fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec<Ident> {
749    labels
750        .iter()
751        .filter_map(|label| label.ident().cloned())
752        .collect()
753}
754
755fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> {
756    labels
757        .iter()
758        .cloned()
759        .enumerate()
760        .map(|(index, label)| (label, index))
761        .collect()
762}
763
764fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap<usize, usize> {
765    bindings
766        .iter()
767        .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index)))
768        .collect()
769}
770
/// Names referenced by a closure body, collected by walking it with `syn::visit`.
#[derive(Default)]
struct ClosureBodyUsage {
    /// Every bare single-segment path expression (`x`) in the body.
    idents: HashSet<String>,
    /// Bare identifiers used as the base of an index expression (`x[..]`).
    indexed_idents: HashSet<String>,
    /// Bare identifiers whose indexed element is the target of an assignment.
    assigned_indexed_idents: HashSet<String>,
    /// Set when the body contains any macro invocation. Macro tokens are not
    /// inspected, so usage inside them cannot be ruled out.
    contains_macro: bool,
}
778
779impl ClosureBodyUsage {
780    fn analyze(expr: &Expr) -> Self {
781        let mut usage = Self::default();
782        usage.visit_expr(expr);
783        usage
784    }
785
786    fn uses(&self, ident: &Ident) -> bool {
787        self.contains_macro || self.idents.contains(&ident.to_string())
788    }
789
790    fn mentions(&self, ident: &Ident) -> bool {
791        self.idents.contains(&ident.to_string())
792    }
793
794    fn indexes(&self, ident: &Ident) -> bool {
795        self.indexed_idents.contains(&ident.to_string())
796    }
797
798    fn assigns_index(&self, ident: &Ident) -> bool {
799        self.assigned_indexed_idents.contains(&ident.to_string())
800    }
801}
802
803impl<'ast> Visit<'ast> for ClosureBodyUsage {
804    fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) {
805        if expr_path.qself.is_none()
806            && expr_path.path.leading_colon.is_none()
807            && expr_path.path.segments.len() == 1
808        {
809            self.idents
810                .insert(expr_path.path.segments[0].ident.to_string());
811        }
812
813        syn::visit::visit_expr_path(self, expr_path);
814    }
815
816    fn visit_expr_macro(&mut self, expr_macro: &'ast syn::ExprMacro) {
817        self.contains_macro = true;
818        syn::visit::visit_expr_macro(self, expr_macro);
819    }
820
821    fn visit_stmt_macro(&mut self, stmt_macro: &'ast syn::StmtMacro) {
822        self.contains_macro = true;
823        syn::visit::visit_stmt_macro(self, stmt_macro);
824    }
825
826    fn visit_expr_index(&mut self, expr_index: &'ast syn::ExprIndex) {
827        if let Expr::Path(expr_path) = expr_index.expr.as_ref() {
828            if expr_path.qself.is_none()
829                && expr_path.path.leading_colon.is_none()
830                && expr_path.path.segments.len() == 1
831            {
832                self.indexed_idents
833                    .insert(expr_path.path.segments[0].ident.to_string());
834            }
835        }
836
837        syn::visit::visit_expr_index(self, expr_index);
838    }
839
840    fn visit_expr_assign(&mut self, expr_assign: &'ast syn::ExprAssign) {
841        if let Expr::Index(expr_index) = expr_assign.left.as_ref() {
842            if let Expr::Path(expr_path) = expr_index.expr.as_ref() {
843                if expr_path.qself.is_none()
844                    && expr_path.path.leading_colon.is_none()
845                    && expr_path.path.segments.len() == 1
846                {
847                    self.assigned_indexed_idents
848                        .insert(expr_path.path.segments[0].ident.to_string());
849                }
850            }
851        }
852
853        syn::visit::visit_expr_assign(self, expr_assign);
854    }
855}
856
857struct IndexRewriteTarget {
858    container: Ident,
859    labels: HashMap<usize, usize>,
860}
861
862impl IndexRewriteTarget {
863    fn new(container: Ident, labels: HashMap<usize, usize>) -> Self {
864        Self { container, labels }
865    }
866}
867
868struct NumericLabelRewriter {
869    index_targets: Vec<IndexRewriteTarget>,
870    route_labels: Option<HashMap<usize, usize>>,
871}
872
873impl NumericLabelRewriter {
874    fn rewrite(
875        expr: &Expr,
876        index_targets: Vec<IndexRewriteTarget>,
877        route_labels: Option<HashMap<usize, usize>>,
878    ) -> Expr {
879        let mut rewritten = expr.clone();
880        let mut rewriter = Self {
881            index_targets,
882            route_labels,
883        };
884        rewriter.visit_expr_mut(&mut rewritten);
885        rewritten
886    }
887
888    fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap<usize, usize>> {
889        if path.qself.is_some()
890            || path.path.leading_colon.is_some()
891            || path.path.segments.len() != 1
892        {
893            return None;
894        }
895
896        let ident = &path.path.segments[0].ident;
897        self.index_targets
898            .iter()
899            .find(|target| target.container == *ident)
900            .map(|target| &target.labels)
901    }
902
903    fn rewrite_route_macro(&self, mac: &mut syn::Macro) {
904        let Some(route_labels) = self.route_labels.as_ref() else {
905            return;
906        };
907        if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) {
908            return;
909        }
910
911        let Ok(entries) = Punctuated::<RoutePropertyEntry, Token![,]>::parse_terminated
912            .parse2(mac.tokens.clone())
913        else {
914            return;
915        };
916
917        let entries = entries.into_iter().map(|mut entry| {
918            if let Some(value) = entry.route.numeric_value() {
919                if let Some(internal_index) = route_labels.get(&value) {
920                    entry.route = SymbolicIndex::numeric(*internal_index);
921                }
922            }
923            entry
924        });
925
926        let tokens = entries.map(|entry| {
927            let route = entry.route;
928            let value = entry.value;
929            quote! { #route => #value }
930        });
931        mac.tokens = quote! { #(#tokens),* };
932    }
933}
934
935impl VisitMut for NumericLabelRewriter {
936    fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) {
937        syn::visit_mut::visit_expr_index_mut(self, expr_index);
938
939        let Expr::Path(expr_path) = expr_index.expr.as_ref() else {
940            return;
941        };
942        let Some(labels) = self.target_labels(expr_path) else {
943            return;
944        };
945        let Expr::Lit(expr_lit) = expr_index.index.as_ref() else {
946            return;
947        };
948        let Lit::Int(lit) = &expr_lit.lit else {
949            return;
950        };
951        let Ok(external_index) = lit.base10_parse::<usize>() else {
952            return;
953        };
954        let Some(internal_index) = labels.get(&external_index) else {
955            return;
956        };
957
958        *expr_index.index = Expr::Lit(syn::ExprLit {
959            attrs: Vec::new(),
960            lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())),
961        });
962    }
963
964    fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) {
965        self.rewrite_route_macro(&mut expr_macro.mac);
966        syn::visit_mut::visit_expr_macro_mut(self, expr_macro);
967    }
968
969    fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) {
970        self.rewrite_route_macro(&mut stmt_macro.mac);
971        syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro);
972    }
973}
974
975fn generate_closure_input_aliases(
976    closure: &ExprClosure,
977    internal_names: &[Ident],
978) -> syn::Result<TokenStream2> {
979    if closure.inputs.len() != internal_names.len() {
980        return Err(syn::Error::new_spanned(
981            closure,
982            "internal named binding generation error: closure arity mismatch",
983        ));
984    }
985
986    let aliases =
987        closure
988            .inputs
989            .iter()
990            .zip(internal_names.iter())
991            .map(|(pattern, internal_name)| {
992                quote! {
993                    let #pattern = #internal_name;
994                }
995            });
996
997    Ok(quote! {
998        #(#aliases)*
999    })
1000}
1001
1002fn generate_supported_input_aliases(
1003    closure: &ExprClosure,
1004    supported_internal_names: &[&[Ident]],
1005    error_message: &str,
1006) -> syn::Result<TokenStream2> {
1007    for internal_names in supported_internal_names {
1008        if closure.inputs.len() == internal_names.len() {
1009            return generate_closure_input_aliases(closure, internal_names);
1010        }
1011    }
1012
1013    Err(syn::Error::new_spanned(closure, error_message))
1014}
1015
1016fn generate_parameter_bindings(
1017    params: &[Ident],
1018    closure: &ExprClosure,
1019    parameter_vector: &Ident,
1020) -> TokenStream2 {
1021    let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1022    let bindings = params
1023        .iter()
1024        .enumerate()
1025        .filter(|(_, ident)| usage.uses(ident))
1026        .map(|(index, ident)| {
1027            quote! {
1028                #[allow(unused_variables)]
1029                let #ident = #parameter_vector[#index];
1030            }
1031        });
1032
1033    quote! {
1034        #(#bindings)*
1035    }
1036}
1037
1038fn generate_derived_bindings(
1039    derived: &[Ident],
1040    closure: &ExprClosure,
1041    derived_values: &Ident,
1042) -> TokenStream2 {
1043    let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1044    let bindings = derived
1045        .iter()
1046        .enumerate()
1047        .filter(|(_, ident)| usage.uses(ident))
1048        .map(|(index, ident)| {
1049            quote! {
1050                #[allow(unused_variables)]
1051                let #ident = #derived_values[#index];
1052            }
1053        });
1054
1055    quote! {
1056        #(#bindings)*
1057    }
1058}
1059
1060fn generate_covariate_bindings(
1061    covariates: &[Ident],
1062    closure: &ExprClosure,
1063    covariate_map: &Ident,
1064    time: &Ident,
1065) -> TokenStream2 {
1066    let usage = ClosureBodyUsage::analyze(closure.body.as_ref());
1067    let used_covariates = covariates
1068        .iter()
1069        .filter(|ident| usage.uses(ident))
1070        .collect::<Vec<_>>();
1071
1072    if used_covariates.is_empty() {
1073        quote! {}
1074    } else {
1075        quote! {
1076            ::pharmsol::fetch_cov!(#covariate_map, #time, #(#used_covariates),*);
1077        }
1078    }
1079}
1080
1081fn analytical_error_span<'a>(names: &'a [Ident], target: &str) -> Option<&'a Ident> {
1082    names.iter().find(|ident| *ident == target)
1083}
1084
1085fn validate_analytical_structure_inputs(
1086    structure: &Ident,
1087    kernel: ResolverAnalyticalKernel,
1088    params: &[Ident],
1089    derived: &[Ident],
1090) -> syn::Result<AnalyticalStructureInputPlan> {
1091    let primary_names = params.iter().map(Ident::to_string).collect::<Vec<_>>();
1092    let derived_names = derived.iter().map(Ident::to_string).collect::<Vec<_>>();
1093    AnalyticalStructureInputPlan::for_kernel(kernel, &primary_names, &derived_names).map_err(
1094        |error| match error {
1095            pharmsol_dsl::AnalyticalStructureInputError::DuplicatePrimary { name } => {
1096                let span = analytical_error_span(params, &name).unwrap_or(structure);
1097                syn::Error::new_spanned(span, format!("duplicate primary parameter `{name}`"))
1098            }
1099            pharmsol_dsl::AnalyticalStructureInputError::DuplicateDerived { name } => {
1100                let span = analytical_error_span(derived, &name).unwrap_or(structure);
1101                syn::Error::new_spanned(span, format!("duplicate derived parameter `{name}`"))
1102            }
1103            pharmsol_dsl::AnalyticalStructureInputError::ConflictingName { name } => {
1104                let span = analytical_error_span(derived, &name)
1105                    .or_else(|| analytical_error_span(params, &name))
1106                    .unwrap_or(structure);
1107                syn::Error::new_spanned(
1108                    span,
1109                    format!("`{name}` is declared in both `params` and `derived`"),
1110                )
1111            }
1112            pharmsol_dsl::AnalyticalStructureInputError::MissingRequiredName {
1113                structure,
1114                name,
1115                suggestion,
1116            } => {
1117                let message = if let Some(candidate) = suggestion {
1118                    format!(
1119                        "analytical structure `{structure}` requires `{name}`; did you mean `{candidate}`? declare it in `params: [...]` or `derived: [...]`"
1120                    )
1121                } else {
1122                    format!(
1123                        "analytical structure `{structure}` requires `{name}`; declare it in `params: [...]` or `derived: [...]`"
1124                    )
1125                };
1126                syn::Error::new_spanned(structure, message)
1127            }
1128        },
1129    )
1130}
1131
1132#[derive(Clone)]
1133struct DeriveValidationContext {
1134    params: HashSet<String>,
1135    covariates: HashSet<String>,
1136    derived: HashSet<String>,
1137}
1138
1139impl DeriveValidationContext {
1140    fn new(params: &[Ident], covariates: &[Ident], derived: &[Ident]) -> Self {
1141        Self {
1142            params: params.iter().map(Ident::to_string).collect(),
1143            covariates: covariates.iter().map(Ident::to_string).collect(),
1144            derived: derived.iter().map(Ident::to_string).collect(),
1145        }
1146    }
1147
1148    fn invalid_target_error(&self, ident: &Ident) -> syn::Error {
1149        let name = ident.to_string();
1150        let message = if self.params.contains(&name) {
1151            format!(
1152                "`derive` cannot assign to `{name}`; only names declared in `derived: [...]` are valid derive targets"
1153            )
1154        } else if self.covariates.contains(&name) {
1155            format!(
1156                "`derive` cannot assign to covariate `{name}`; only names declared in `derived: [...]` are valid derive targets"
1157            )
1158        } else {
1159            format!(
1160                "`derive` cannot assign to `{name}`; declare it in `derived: [...]` before assigning to it"
1161            )
1162        };
1163        syn::Error::new_spanned(ident, message)
1164    }
1165}
1166
1167fn bound_local_names(pat: &Pat) -> Vec<String> {
1168    struct BoundNames {
1169        names: Vec<String>,
1170    }
1171
1172    impl<'ast> Visit<'ast> for BoundNames {
1173        fn visit_pat_ident(&mut self, pat_ident: &'ast syn::PatIdent) {
1174            self.names.push(pat_ident.ident.to_string());
1175        }
1176    }
1177
1178    let mut bound = BoundNames { names: Vec::new() };
1179    bound.visit_pat(pat);
1180    bound.names
1181}
1182
1183fn analyze_derive_block(
1184    block: &syn::Block,
1185    context: &DeriveValidationContext,
1186    locals: &mut HashSet<String>,
1187    assigned: &HashSet<String>,
1188) -> syn::Result<HashSet<String>> {
1189    let mut assigned_now = assigned.clone();
1190    for stmt in &block.stmts {
1191        assigned_now = analyze_derive_stmt(stmt, context, locals, &assigned_now)?;
1192    }
1193    Ok(assigned_now)
1194}
1195
1196fn analyze_derive_stmt(
1197    stmt: &Stmt,
1198    context: &DeriveValidationContext,
1199    locals: &mut HashSet<String>,
1200    assigned: &HashSet<String>,
1201) -> syn::Result<HashSet<String>> {
1202    match stmt {
1203        Stmt::Local(local) => {
1204            if let Some(init) = &local.init {
1205                let _ = analyze_derive_expr(&init.expr, context, &mut locals.clone(), assigned)?;
1206            }
1207            for name in bound_local_names(&local.pat) {
1208                locals.insert(name);
1209            }
1210            Ok(assigned.clone())
1211        }
1212        Stmt::Expr(expr, _) => analyze_derive_expr(expr, context, locals, assigned),
1213        Stmt::Macro(stmt_macro) => Err(syn::Error::new_spanned(
1214            stmt_macro,
1215            "`derive` only supports assignments, `if`, `if` / `else`, `for`, and local `let` bindings",
1216        )),
1217        _ => Ok(assigned.clone()),
1218    }
1219}
1220
1221fn analyze_derive_expr(
1222    expr: &Expr,
1223    context: &DeriveValidationContext,
1224    locals: &mut HashSet<String>,
1225    assigned: &HashSet<String>,
1226) -> syn::Result<HashSet<String>> {
1227    match expr {
1228        Expr::Assign(assign) => {
1229            if let Expr::Path(path) = assign.left.as_ref() {
1230                if path.qself.is_none()
1231                    && path.path.leading_colon.is_none()
1232                    && path.path.segments.len() == 1
1233                {
1234                    let ident = &path.path.segments[0].ident;
1235                    let name = ident.to_string();
1236                    if context.derived.contains(&name) {
1237                        let mut next = assigned.clone();
1238                        next.insert(name);
1239                        return Ok(next);
1240                    }
1241                    if locals.contains(&name) {
1242                        return Ok(assigned.clone());
1243                    }
1244                    return Err(context.invalid_target_error(ident));
1245                }
1246            }
1247            Err(syn::Error::new_spanned(
1248                &assign.left,
1249                "`derive` assignments must target a name declared in `derived: [...]`",
1250            ))
1251        }
1252        Expr::If(expr_if) => {
1253            let mut then_locals = locals.clone();
1254            let then_assigned = analyze_derive_block(
1255                &expr_if.then_branch,
1256                context,
1257                &mut then_locals,
1258                assigned,
1259            )?;
1260
1261            if let Some((_, else_branch)) = &expr_if.else_branch {
1262                let mut else_locals = locals.clone();
1263                let else_assigned = analyze_derive_expr(
1264                    else_branch,
1265                    context,
1266                    &mut else_locals,
1267                    assigned,
1268                )?;
1269                Ok(then_assigned
1270                    .intersection(&else_assigned)
1271                    .cloned()
1272                    .collect::<HashSet<_>>())
1273            } else {
1274                Ok(assigned.clone())
1275            }
1276        }
1277        Expr::ForLoop(expr_for) => {
1278            let mut loop_locals = locals.clone();
1279            for name in bound_local_names(&expr_for.pat) {
1280                loop_locals.insert(name);
1281            }
1282            let _ = analyze_derive_block(&expr_for.body, context, &mut loop_locals, assigned)?;
1283            Ok(assigned.clone())
1284        }
1285        Expr::Block(expr_block) => analyze_derive_block(&expr_block.block, context, locals, assigned),
1286        Expr::While(expr_while) => Err(syn::Error::new_spanned(
1287            expr_while,
1288            "`derive` does not support `while`; use straight-line code, `if`, `if` / `else`, or `for`",
1289        )),
1290        Expr::Loop(expr_loop) => Err(syn::Error::new_spanned(
1291            expr_loop,
1292            "`derive` does not support `loop`; use straight-line code, `if`, `if` / `else`, or `for`",
1293        )),
1294        Expr::Match(expr_match) => Err(syn::Error::new_spanned(
1295            expr_match,
1296            "`derive` does not support `match`; use straight-line code, `if`, `if` / `else`, or `for`",
1297        )),
1298        _ => Ok(assigned.clone()),
1299    }
1300}
1301
1302fn validate_analytical_derive_contract(
1303    kernel: ResolverAnalyticalKernel,
1304    params: &[Ident],
1305    derived: &[Ident],
1306    covariates: &[Ident],
1307    derive: Option<&ExprClosure>,
1308) -> syn::Result<()> {
1309    if derived.is_empty() {
1310        if let Some(derive) = derive {
1311            return Err(syn::Error::new_spanned(
1312                derive,
1313                "built-in `analytical!` `derive` requires `derived: [...]`",
1314            ));
1315        }
1316        return Ok(());
1317    }
1318
1319    let derive = derive.ok_or_else(|| {
1320        syn::Error::new_spanned(
1321            &derived[0],
1322            "built-in `analytical!` declares `derived: [...]` but is missing `derive: ...`",
1323        )
1324    })?;
1325
1326    let p = generated_ident("__pharmsol_p");
1327    let t = generated_ident("__pharmsol_t");
1328    let cov = generated_ident("__pharmsol_cov");
1329    let full_inputs = [p, t.clone(), cov];
1330    let reduced_inputs = [t];
1331    generate_supported_input_aliases(
1332        derive,
1333        &[&full_inputs, &reduced_inputs],
1334        "built-in `analytical!` requires `derive` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|",
1335    )?;
1336
1337    let context = DeriveValidationContext::new(params, covariates, derived);
1338    let mut locals = HashSet::new();
1339    let assigned = match derive.body.as_ref() {
1340        Expr::Block(expr_block) => {
1341            analyze_derive_block(&expr_block.block, &context, &mut locals, &HashSet::new())?
1342        }
1343        expr => analyze_derive_expr(expr, &context, &mut locals, &HashSet::new())?,
1344    };
1345
1346    let required_derived = match validate_analytical_structure_inputs(
1347        &Ident::new(kernel.name(), Span::call_site()),
1348        kernel,
1349        params,
1350        derived,
1351    ) {
1352        Ok(plan) => match plan.kind() {
1353            AnalyticalStructureInputKind::AllPrimary { .. } => HashSet::new(),
1354            AnalyticalStructureInputKind::AllDerived { indices, .. } => indices
1355                .iter()
1356                .map(|index| derived[*index].to_string())
1357                .collect::<HashSet<_>>(),
1358            AnalyticalStructureInputKind::Mixed { bindings } => bindings
1359                .iter()
1360                .filter_map(|binding| match binding.source {
1361                    AnalyticalStructureInputSource::Primary => None,
1362                    AnalyticalStructureInputSource::Derived => {
1363                        Some(derived[binding.index].to_string())
1364                    }
1365                })
1366                .collect::<HashSet<_>>(),
1367        },
1368        Err(_) => HashSet::new(),
1369    };
1370
1371    for ident in derived {
1372        let name = ident.to_string();
1373        if !assigned.contains(&name) {
1374            let message = if required_derived.contains(&name) {
1375                format!(
1376                    "derived parameter `{name}` is not definitely assigned on every path before analytical structure `{}` uses it",
1377                    kernel.name()
1378                )
1379            } else {
1380                format!(
1381                    "derived parameter `{name}` is declared in `derived: [...]` but is not definitely assigned in `derive`"
1382                )
1383            };
1384            return Err(syn::Error::new_spanned(ident, message));
1385        }
1386    }
1387
1388    Ok(())
1389}
1390
1391fn validate_ode_diffeq_uses_automatic_injection(
1392    diffeq: &ExprClosure,
1393    routes: &[OdeRouteDecl],
1394) -> syn::Result<()> {
1395    match closure_param_names(diffeq).len() {
1396        3 => Ok(()),
1397        5 => {
1398            let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref());
1399            let route_inputs = route_input_idents(routes);
1400            let fourth_param = closure_param_ident(diffeq, 3);
1401            let fifth_param = closure_param_ident(diffeq, 4);
1402            let mentions_route_inputs = route_inputs.iter().any(|route| usage.mentions(route));
1403            let indexes_fifth_param = fifth_param.as_ref().is_some_and(|ident| usage.indexes(ident));
1404            let reads_fourth_param_as_input = fourth_param
1405                .as_ref()
1406                .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident));
1407
1408            if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input {
1409                Err(syn::Error::new_spanned(
1410                    diffeq,
1411                    "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms",
1412                ))
1413            } else {
1414                Ok(())
1415            }
1416        }
1417        _ => Err(syn::Error::new_spanned(
1418            diffeq,
1419            "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
1420        )),
1421    }
1422}
1423
1424fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec<Ident> {
1425    routes
1426        .iter()
1427        .filter_map(|route| route.input.ident().cloned())
1428        .collect()
1429}
1430
1431fn route_input_names(routes: &[OdeRouteDecl]) -> Vec<String> {
1432    routes.iter().map(|route| route.input.name()).collect()
1433}
1434
1435fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> {
1436    let mut next_bolus_index = 0usize;
1437    let mut next_infusion_index = 0usize;
1438
1439    routes
1440        .iter()
1441        .map(|route| {
1442            let index = match route.kind {
1443                OdeRouteKind::Bolus => {
1444                    let index = next_bolus_index;
1445                    next_bolus_index += 1;
1446                    index
1447                }
1448                OdeRouteKind::Infusion => {
1449                    let index = next_infusion_index;
1450                    next_infusion_index += 1;
1451                    index
1452                }
1453            };
1454            (route.input.clone(), index)
1455        })
1456        .collect()
1457}
1458
1459fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize {
1460    bindings
1461        .iter()
1462        .map(|(_, index)| index + 1)
1463        .max()
1464        .unwrap_or(0)
1465}
1466
1467fn validate_binding_conflicts(
1468    left_label: &str,
1469    left: &[Ident],
1470    right_label: &str,
1471    right: &[Ident],
1472    context: &str,
1473) -> syn::Result<()> {
1474    let right_names = right.iter().map(Ident::to_string).collect::<HashSet<_>>();
1475
1476    for ident in left {
1477        let name = ident.to_string();
1478        if right_names.contains(&name) {
1479            return Err(syn::Error::new_spanned(
1480                ident,
1481                format!(
1482                    "named {left_label} binding `{name}` conflicts with named {right_label} binding in {context}"
1483                ),
1484            ));
1485        }
1486    }
1487
1488    Ok(())
1489}
1490
1491fn validate_closure_param_conflicts(
1492    closure_label: &str,
1493    closure: &ExprClosure,
1494    bindings: &[Ident],
1495    binding_label: &str,
1496) -> syn::Result<()> {
1497    let parameter_names = closure_param_names(closure)
1498        .into_iter()
1499        .filter(|name| !name.is_empty())
1500        .collect::<HashSet<_>>();
1501
1502    for ident in bindings {
1503        let name = ident.to_string();
1504        if parameter_names.contains(&name) {
1505            return Err(syn::Error::new_spanned(
1506                ident,
1507                format!(
1508                    "named {binding_label} binding `{name}` conflicts with `{closure_label}` closure parameter `{name}`"
1509                ),
1510            ));
1511        }
1512    }
1513
1514    Ok(())
1515}
1516
/// Borrowed views of every named declaration list of a model macro, grouped
/// for the binding-conflict validators.
///
/// `derived` is only checked by the `analytical!` validator; the ODE and SDE
/// validators destructure it as `derived: _` and ignore it.
#[derive(Clone, Copy)]
struct NamedBindingSets<'a> {
    params: &'a [Ident],
    derived: &'a [Ident],
    covariates: &'a [Ident],
    states: &'a [Ident],
    outputs: &'a [Ident],
    routes: &'a [OdeRouteDecl],
}

/// Closures shared by every model macro; `lag`, `fa`, and `init` are optional.
#[derive(Clone, Copy)]
struct CommonBindingClosures<'a> {
    lag: Option<&'a ExprClosure>,
    fa: Option<&'a ExprClosure>,
    init: Option<&'a ExprClosure>,
    out: &'a ExprClosure,
}

/// `analytical!` closures: the optional `derive` closure plus the common set.
#[derive(Clone, Copy)]
struct AnalyticalBindingClosures<'a> {
    derive: Option<&'a ExprClosure>,
    common: CommonBindingClosures<'a>,
}

/// `ode!` closures: the required `diffeq` closure plus the common set.
#[derive(Clone, Copy)]
struct OdeBindingClosures<'a> {
    diffeq: &'a ExprClosure,
    common: CommonBindingClosures<'a>,
}

/// `sde!` closures: the required `drift` and `diffusion` closures plus the common set.
#[derive(Clone, Copy)]
struct SdeBindingClosures<'a> {
    drift: &'a ExprClosure,
    diffusion: &'a ExprClosure,
    common: CommonBindingClosures<'a>,
}
1553
1554fn validate_named_binding_compatibility(
1555    bindings: NamedBindingSets<'_>,
1556    closures: OdeBindingClosures<'_>,
1557) -> syn::Result<()> {
1558    let NamedBindingSets {
1559        params,
1560        derived: _,
1561        covariates,
1562        states,
1563        outputs,
1564        routes,
1565    } = bindings;
1566    let OdeBindingClosures {
1567        diffeq,
1568        common: CommonBindingClosures { lag, fa, init, out },
1569    } = closures;
1570    let route_inputs = route_input_idents(routes);
1571
1572    validate_binding_conflicts(
1573        "parameter",
1574        params,
1575        "covariate",
1576        covariates,
1577        "declaration-first `ode!` named binding generation",
1578    )?;
1579    validate_binding_conflicts(
1580        "parameter",
1581        params,
1582        "state",
1583        states,
1584        "`diffeq` and `out` named binding generation",
1585    )?;
1586    validate_binding_conflicts(
1587        "parameter",
1588        params,
1589        "output",
1590        outputs,
1591        "`out` named binding generation",
1592    )?;
1593    validate_binding_conflicts(
1594        "state",
1595        states,
1596        "output",
1597        outputs,
1598        "`out` named binding generation",
1599    )?;
1600    validate_binding_conflicts(
1601        "covariate",
1602        covariates,
1603        "state",
1604        states,
1605        "declaration-first `ode!` named binding generation",
1606    )?;
1607    validate_binding_conflicts(
1608        "covariate",
1609        covariates,
1610        "output",
1611        outputs,
1612        "declaration-first `ode!` named binding generation",
1613    )?;
1614
1615    validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?;
1616    validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?;
1617    validate_closure_param_conflicts("diffeq", diffeq, states, "state")?;
1618
1619    if let Some(lag) = lag {
1620        validate_binding_conflicts(
1621            "covariate",
1622            covariates,
1623            "route",
1624            &route_inputs,
1625            "`lag` named binding generation",
1626        )?;
1627        validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1628        validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1629        validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1630    }
1631
1632    if let Some(fa) = fa {
1633        validate_binding_conflicts(
1634            "covariate",
1635            covariates,
1636            "route",
1637            &route_inputs,
1638            "`fa` named binding generation",
1639        )?;
1640        validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1641        validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1642        validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1643    }
1644
1645    if let Some(init) = init {
1646        validate_closure_param_conflicts("init", init, params, "parameter")?;
1647        validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1648        validate_closure_param_conflicts("init", init, states, "state")?;
1649    }
1650
1651    validate_closure_param_conflicts("out", out, params, "parameter")?;
1652    validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1653    validate_closure_param_conflicts("out", out, states, "state")?;
1654    validate_closure_param_conflicts("out", out, outputs, "output")?;
1655
1656    Ok(())
1657}
1658
1659fn validate_analytical_named_binding_compatibility(
1660    bindings: NamedBindingSets<'_>,
1661    closures: AnalyticalBindingClosures<'_>,
1662) -> syn::Result<()> {
1663    let NamedBindingSets {
1664        params,
1665        derived,
1666        covariates,
1667        states,
1668        outputs,
1669        routes,
1670    } = bindings;
1671    let AnalyticalBindingClosures {
1672        derive,
1673        common: CommonBindingClosures { lag, fa, init, out },
1674    } = closures;
1675    let route_inputs = route_input_idents(routes);
1676
1677    validate_binding_conflicts(
1678        "parameter",
1679        params,
1680        "covariate",
1681        covariates,
1682        "`analytical!` named binding generation",
1683    )?;
1684    validate_binding_conflicts(
1685        "derived parameter",
1686        derived,
1687        "covariate",
1688        covariates,
1689        "`analytical!` named binding generation",
1690    )?;
1691    validate_binding_conflicts(
1692        "parameter",
1693        params,
1694        "state",
1695        states,
1696        "`analytical!` named binding generation",
1697    )?;
1698    validate_binding_conflicts(
1699        "derived parameter",
1700        derived,
1701        "state",
1702        states,
1703        "`analytical!` named binding generation",
1704    )?;
1705    validate_binding_conflicts(
1706        "parameter",
1707        params,
1708        "output",
1709        outputs,
1710        "`analytical!` named binding generation",
1711    )?;
1712    validate_binding_conflicts(
1713        "derived parameter",
1714        derived,
1715        "output",
1716        outputs,
1717        "`analytical!` named binding generation",
1718    )?;
1719    validate_binding_conflicts(
1720        "covariate",
1721        covariates,
1722        "state",
1723        states,
1724        "`analytical!` named binding generation",
1725    )?;
1726    validate_binding_conflicts(
1727        "covariate",
1728        covariates,
1729        "output",
1730        outputs,
1731        "`analytical!` named binding generation",
1732    )?;
1733    validate_binding_conflicts(
1734        "covariate",
1735        covariates,
1736        "route",
1737        &route_inputs,
1738        "`analytical!` named binding generation",
1739    )?;
1740    validate_binding_conflicts(
1741        "parameter",
1742        params,
1743        "route",
1744        &route_inputs,
1745        "`analytical!` named binding generation",
1746    )?;
1747    validate_binding_conflicts(
1748        "derived parameter",
1749        derived,
1750        "route",
1751        &route_inputs,
1752        "`analytical!` named binding generation",
1753    )?;
1754    validate_binding_conflicts(
1755        "state",
1756        states,
1757        "output",
1758        outputs,
1759        "`analytical!` named binding generation",
1760    )?;
1761    validate_binding_conflicts(
1762        "state",
1763        states,
1764        "route",
1765        &route_inputs,
1766        "`analytical!` named binding generation",
1767    )?;
1768    validate_binding_conflicts(
1769        "output",
1770        outputs,
1771        "route",
1772        &route_inputs,
1773        "`analytical!` named binding generation",
1774    )?;
1775
1776    if let Some(derive) = derive {
1777        validate_closure_param_conflicts("derive", derive, params, "parameter")?;
1778        validate_closure_param_conflicts("derive", derive, derived, "derived parameter")?;
1779        validate_closure_param_conflicts("derive", derive, covariates, "covariate")?;
1780    }
1781
1782    if let Some(lag) = lag {
1783        validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1784        validate_closure_param_conflicts("lag", lag, derived, "derived parameter")?;
1785        validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1786        validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1787    }
1788
1789    if let Some(fa) = fa {
1790        validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1791        validate_closure_param_conflicts("fa", fa, derived, "derived parameter")?;
1792        validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1793        validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1794    }
1795
1796    if let Some(init) = init {
1797        validate_closure_param_conflicts("init", init, params, "parameter")?;
1798        validate_closure_param_conflicts("init", init, derived, "derived parameter")?;
1799        validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1800        validate_closure_param_conflicts("init", init, states, "state")?;
1801    }
1802
1803    validate_closure_param_conflicts("out", out, params, "parameter")?;
1804    validate_closure_param_conflicts("out", out, derived, "derived parameter")?;
1805    validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1806    validate_closure_param_conflicts("out", out, states, "state")?;
1807    validate_closure_param_conflicts("out", out, outputs, "output")?;
1808
1809    Ok(())
1810}
1811
1812fn validate_sde_named_binding_compatibility(
1813    bindings: NamedBindingSets<'_>,
1814    closures: SdeBindingClosures<'_>,
1815) -> syn::Result<()> {
1816    let NamedBindingSets {
1817        params,
1818        derived: _,
1819        covariates,
1820        states,
1821        outputs,
1822        routes,
1823    } = bindings;
1824    let SdeBindingClosures {
1825        drift,
1826        diffusion,
1827        common: CommonBindingClosures { lag, fa, init, out },
1828    } = closures;
1829    let route_inputs = route_input_idents(routes);
1830
1831    validate_binding_conflicts(
1832        "parameter",
1833        params,
1834        "covariate",
1835        covariates,
1836        "`sde!` named binding generation",
1837    )?;
1838    validate_binding_conflicts(
1839        "parameter",
1840        params,
1841        "state",
1842        states,
1843        "`sde!` named binding generation",
1844    )?;
1845    validate_binding_conflicts(
1846        "parameter",
1847        params,
1848        "output",
1849        outputs,
1850        "`sde!` named binding generation",
1851    )?;
1852    validate_binding_conflicts(
1853        "covariate",
1854        covariates,
1855        "state",
1856        states,
1857        "`sde!` named binding generation",
1858    )?;
1859    validate_binding_conflicts(
1860        "covariate",
1861        covariates,
1862        "output",
1863        outputs,
1864        "`sde!` named binding generation",
1865    )?;
1866    validate_binding_conflicts(
1867        "covariate",
1868        covariates,
1869        "route",
1870        &route_inputs,
1871        "`sde!` named binding generation",
1872    )?;
1873    validate_binding_conflicts(
1874        "parameter",
1875        params,
1876        "route",
1877        &route_inputs,
1878        "`sde!` named binding generation",
1879    )?;
1880    validate_binding_conflicts(
1881        "state",
1882        states,
1883        "output",
1884        outputs,
1885        "`sde!` named binding generation",
1886    )?;
1887    validate_binding_conflicts(
1888        "state",
1889        states,
1890        "route",
1891        &route_inputs,
1892        "`sde!` named binding generation",
1893    )?;
1894    validate_binding_conflicts(
1895        "output",
1896        outputs,
1897        "route",
1898        &route_inputs,
1899        "`sde!` named binding generation",
1900    )?;
1901
1902    validate_closure_param_conflicts("drift", drift, params, "parameter")?;
1903    validate_closure_param_conflicts("drift", drift, covariates, "covariate")?;
1904    validate_closure_param_conflicts("drift", drift, states, "state")?;
1905    validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?;
1906    validate_closure_param_conflicts("diffusion", diffusion, states, "state")?;
1907
1908    if let Some(lag) = lag {
1909        validate_closure_param_conflicts("lag", lag, params, "parameter")?;
1910        validate_closure_param_conflicts("lag", lag, covariates, "covariate")?;
1911        validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?;
1912    }
1913
1914    if let Some(fa) = fa {
1915        validate_closure_param_conflicts("fa", fa, params, "parameter")?;
1916        validate_closure_param_conflicts("fa", fa, covariates, "covariate")?;
1917        validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?;
1918    }
1919
1920    if let Some(init) = init {
1921        validate_closure_param_conflicts("init", init, params, "parameter")?;
1922        validate_closure_param_conflicts("init", init, covariates, "covariate")?;
1923        validate_closure_param_conflicts("init", init, states, "state")?;
1924    }
1925
1926    validate_closure_param_conflicts("out", out, params, "parameter")?;
1927    validate_closure_param_conflicts("out", out, covariates, "covariate")?;
1928    validate_closure_param_conflicts("out", out, states, "state")?;
1929    validate_closure_param_conflicts("out", out, outputs, "output")?;
1930
1931    Ok(())
1932}
1933
1934fn generate_index_consts(idents: &[Ident]) -> TokenStream2 {
1935    let bindings = idents.iter().enumerate().map(|(index, ident)| {
1936        quote! {
1937            #[allow(non_upper_case_globals, dead_code)]
1938            const #ident: usize = #index;
1939        }
1940    });
1941
1942    quote! {
1943        #(#bindings)*
1944    }
1945}
1946
1947fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 {
1948    let bindings = bindings.iter().filter_map(|(label, index)| {
1949        label.ident().map(|ident| {
1950            quote! {
1951                #[allow(non_upper_case_globals, dead_code)]
1952                const #ident: usize = #index;
1953            }
1954        })
1955    });
1956
1957    quote! {
1958        #(#bindings)*
1959    }
1960}
1961
1962fn expand_out(
1963    out: &ExprClosure,
1964    params: &[Ident],
1965    covariates: &[Ident],
1966    states: &[Ident],
1967    outputs: &[SymbolicIndex],
1968) -> syn::Result<TokenStream2> {
1969    let state_consts = generate_index_consts(states);
1970    let output_bindings = symbolic_index_bindings(outputs);
1971    let output_consts = generate_mapped_index_consts(&output_bindings);
1972    let x = generated_ident("__pharmsol_x");
1973    let p = generated_ident("__pharmsol_p");
1974    let t = generated_ident("__pharmsol_t");
1975    let cov = generated_ident("__pharmsol_cov");
1976    let y = generated_ident("__pharmsol_y");
1977    let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
1978    let reduced_inputs = [x.clone(), t.clone(), y.clone()];
1979    let input_aliases = generate_supported_input_aliases(
1980        out,
1981        &[&full_inputs, &reduced_inputs],
1982        "declaration-first `ode!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
1983    )?;
1984    let parameter_bindings = generate_parameter_bindings(params, out, &p);
1985    let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
1986    let y_binding = if out.inputs.len() == full_inputs.len() {
1987        closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
1988    } else {
1989        closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
1990    };
1991    let body = NumericLabelRewriter::rewrite(
1992        out.body.as_ref(),
1993        vec![IndexRewriteTarget::new(
1994            y_binding,
1995            symbolic_numeric_binding_map(&output_bindings),
1996        )],
1997        None,
1998    );
1999
2000    Ok(quote! {{
2001        let __pharmsol_out: fn(
2002            &::pharmsol::simulator::V,
2003            &::pharmsol::simulator::V,
2004            f64,
2005            &::pharmsol::data::Covariates,
2006            &mut ::pharmsol::simulator::V,
2007        ) = |#x: &::pharmsol::simulator::V,
2008             #p: &::pharmsol::simulator::V,
2009             #t: f64,
2010             #cov: &::pharmsol::data::Covariates,
2011             #y: &mut ::pharmsol::simulator::V| {
2012            #input_aliases
2013            #state_consts
2014            #output_consts
2015            #parameter_bindings
2016            #covariate_bindings
2017            #body
2018        };
2019        __pharmsol_out
2020    }})
2021}
2022
2023fn route_property_error<T: ToTokens>(macro_name: &str, label: &str, node: T) -> syn::Error {
2024    syn::Error::new_spanned(
2025        node,
2026        format!(
2027            "{macro_name} requires `{label}` to return `{label}! {{ ... }}` so route-property metadata can be synthesized"
2028        ),
2029    )
2030}
2031
2032fn find_terminal_macro_invocation(
2033    macro_name: &str,
2034    label: &str,
2035    closure: &ExprClosure,
2036) -> syn::Result<syn::Macro> {
2037    match closure.body.as_ref() {
2038        Expr::Macro(expr_macro) if expr_macro.mac.path.is_ident(label) => {
2039            Ok(expr_macro.mac.clone())
2040        }
2041        Expr::Macro(expr_macro) => Err(route_property_error(macro_name, label, expr_macro)),
2042        Expr::Block(expr_block) => {
2043            for stmt in expr_block.block.stmts.iter().rev() {
2044                match stmt {
2045                    Stmt::Expr(Expr::Macro(expr_macro), _)
2046                        if expr_macro.mac.path.is_ident(label) =>
2047                    {
2048                        return Ok(expr_macro.mac.clone());
2049                    }
2050                    Stmt::Expr(Expr::Macro(expr_macro), _) => {
2051                        return Err(route_property_error(macro_name, label, expr_macro));
2052                    }
2053                    Stmt::Expr(other, _) => {
2054                        return Err(route_property_error(macro_name, label, other));
2055                    }
2056                    Stmt::Macro(stmt_macro) if stmt_macro.mac.path.is_ident(label) => {
2057                        return Ok(stmt_macro.mac.clone());
2058                    }
2059                    Stmt::Macro(stmt_macro) => {
2060                        return Err(route_property_error(macro_name, label, stmt_macro));
2061                    }
2062                    _ => continue,
2063                }
2064            }
2065
2066            Err(route_property_error(macro_name, label, expr_block))
2067        }
2068        other => Err(route_property_error(macro_name, label, other)),
2069    }
2070}
2071
2072fn extract_route_property_routes(
2073    macro_name: &str,
2074    label: &str,
2075    closure: &ExprClosure,
2076    routes: &[OdeRouteDecl],
2077) -> syn::Result<HashSet<String>> {
2078    let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?;
2079    let entries = Punctuated::<RoutePropertyEntry, Token![,]>::parse_terminated
2080        .parse2(macro_expr.tokens.clone())?;
2081    let known_routes = route_input_names(routes)
2082        .into_iter()
2083        .collect::<HashSet<_>>();
2084    let mut seen = HashSet::new();
2085
2086    for entry in entries {
2087        let route_name = entry.route.name();
2088        if !known_routes.contains(&route_name) {
2089            return Err(syn::Error::new_spanned(
2090                &entry.route,
2091                format!(
2092                    "route `{route_name}` in `{label}!` is not declared in the `routes` section"
2093                ),
2094            ));
2095        }
2096        if !seen.insert(route_name.clone()) {
2097            return Err(syn::Error::new_spanned(
2098                &entry.route,
2099                format!("duplicate route `{route_name}` in `{label}!`"),
2100            ));
2101        }
2102        let _ = entry.value;
2103    }
2104
2105    Ok(seen)
2106}
2107
2108fn validate_route_property_kinds(
2109    macro_name: &str,
2110    label: &str,
2111    routes: &[OdeRouteDecl],
2112    property_routes: &HashSet<String>,
2113) -> syn::Result<()> {
2114    for route in routes {
2115        if property_routes.contains(&route.input.name())
2116            && matches!(route.kind, OdeRouteKind::Infusion)
2117        {
2118            return Err(syn::Error::new_spanned(
2119                &route.input,
2120                format!(
2121                    "{macro_name} does not allow `{label}` on infusion route `{}`",
2122                    route.input
2123                ),
2124            ));
2125        }
2126    }
2127
2128    Ok(())
2129}
2130
2131fn expand_ode_route_map(
2132    label: &str,
2133    closure: &ExprClosure,
2134    params: &[Ident],
2135    covariates: &[Ident],
2136    route_bindings: &[(SymbolicIndex, usize)],
2137) -> syn::Result<TokenStream2> {
2138    let route_consts = generate_mapped_index_consts(route_bindings);
2139    let p = generated_ident("__pharmsol_p");
2140    let t = generated_ident("__pharmsol_t");
2141    let cov = generated_ident("__pharmsol_cov");
2142    let full_inputs = [p.clone(), t.clone(), cov.clone()];
2143    let reduced_inputs = [t.clone()];
2144    let input_aliases = generate_supported_input_aliases(
2145        closure,
2146        &[&full_inputs, &reduced_inputs],
2147        &format!(
2148            "declaration-first `ode!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
2149        ),
2150    )?;
2151    let parameter_bindings = generate_parameter_bindings(params, closure, &p);
2152    let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
2153    let body = NumericLabelRewriter::rewrite(
2154        closure.body.as_ref(),
2155        Vec::new(),
2156        Some(symbolic_numeric_binding_map(route_bindings)),
2157    );
2158
2159    Ok(quote! {{
2160        let __pharmsol_route_map: fn(
2161            &::pharmsol::simulator::V,
2162            f64,
2163            &::pharmsol::data::Covariates,
2164        ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
2165             #t: f64,
2166             #cov: &::pharmsol::data::Covariates| {
2167            #input_aliases
2168            #route_consts
2169            #parameter_bindings
2170            #covariate_bindings
2171            #body
2172        };
2173        __pharmsol_route_map
2174    }})
2175}
2176
2177fn expand_ode_init(
2178    init: &ExprClosure,
2179    params: &[Ident],
2180    covariates: &[Ident],
2181    states: &[Ident],
2182) -> syn::Result<TokenStream2> {
2183    let state_consts = generate_index_consts(states);
2184    let p = generated_ident("__pharmsol_p");
2185    let t = generated_ident("__pharmsol_t");
2186    let cov = generated_ident("__pharmsol_cov");
2187    let x = generated_ident("__pharmsol_x");
2188    let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
2189    let reduced_inputs = [t.clone(), x.clone()];
2190    let input_aliases = generate_supported_input_aliases(
2191        init,
2192        &[&full_inputs, &reduced_inputs],
2193        "declaration-first `ode!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
2194    )?;
2195    let parameter_bindings = generate_parameter_bindings(params, init, &p);
2196    let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
2197    let body = &init.body;
2198
2199    Ok(quote! {{
2200        let __pharmsol_init: fn(
2201            &::pharmsol::simulator::V,
2202            f64,
2203            &::pharmsol::data::Covariates,
2204            &mut ::pharmsol::simulator::V,
2205        ) = |#p: &::pharmsol::simulator::V,
2206             #t: f64,
2207             #cov: &::pharmsol::data::Covariates,
2208             #x: &mut ::pharmsol::simulator::V| {
2209            #input_aliases
2210            #state_consts
2211            #parameter_bindings
2212            #covariate_bindings
2213            #body
2214        };
2215        __pharmsol_init
2216    }})
2217}
2218
2219fn expand_route_metadata(
2220    routes: &[OdeRouteDecl],
2221    lag_routes: &HashSet<String>,
2222    fa_routes: &HashSet<String>,
2223) -> Vec<TokenStream2> {
2224    routes
2225        .iter()
2226        .map(|route| {
2227            let input = &route.input;
2228            let destination = &route.destination;
2229            let route_name = route.input.name();
2230            let route_builder = match route.kind {
2231                OdeRouteKind::Bolus => {
2232                    quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2233                }
2234                OdeRouteKind::Infusion => {
2235                    quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2236                }
2237            };
2238            let lag_flag = if lag_routes.contains(&route_name) {
2239                quote! { .with_lag() }
2240            } else {
2241                quote! {}
2242            };
2243            let fa_flag = if fa_routes.contains(&route_name) {
2244                quote! { .with_bioavailability() }
2245            } else {
2246                quote! {}
2247            };
2248
2249            quote! {
2250                #route_builder
2251                    .to_state(stringify!(#destination))
2252                    #lag_flag
2253                    #fa_flag
2254                    .inject_input_to_destination()
2255            }
2256        })
2257        .collect()
2258}
2259
2260fn expand_analytical_route_metadata(
2261    routes: &[OdeRouteDecl],
2262    lag_routes: &HashSet<String>,
2263    fa_routes: &HashSet<String>,
2264) -> Vec<TokenStream2> {
2265    routes
2266        .iter()
2267        .map(|route| {
2268            let input = &route.input;
2269            let destination = &route.destination;
2270            let route_name = route.input.name();
2271            let route_builder = match route.kind {
2272                OdeRouteKind::Bolus => {
2273                    quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2274                }
2275                OdeRouteKind::Infusion => {
2276                    quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2277                }
2278            };
2279            let lag_flag = if lag_routes.contains(&route_name) {
2280                quote! { .with_lag() }
2281            } else {
2282                quote! {}
2283            };
2284            let fa_flag = if fa_routes.contains(&route_name) {
2285                quote! { .with_bioavailability() }
2286            } else {
2287                quote! {}
2288            };
2289
2290            quote! {
2291                #route_builder
2292                    .to_state(stringify!(#destination))
2293                    #lag_flag
2294                    #fa_flag
2295            }
2296        })
2297        .collect()
2298}
2299
2300fn expand_sde_route_metadata(
2301    routes: &[OdeRouteDecl],
2302    lag_routes: &HashSet<String>,
2303    fa_routes: &HashSet<String>,
2304) -> Vec<TokenStream2> {
2305    routes
2306        .iter()
2307        .map(|route| {
2308            let input = &route.input;
2309            let destination = &route.destination;
2310            let route_name = route.input.name();
2311            let route_builder = match route.kind {
2312                OdeRouteKind::Bolus => {
2313                    quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) }
2314                }
2315                OdeRouteKind::Infusion => {
2316                    quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) }
2317                }
2318            };
2319            let lag_flag = if lag_routes.contains(&route_name) {
2320                quote! { .with_lag() }
2321            } else {
2322                quote! {}
2323            };
2324            let fa_flag = if fa_routes.contains(&route_name) {
2325                quote! { .with_bioavailability() }
2326            } else {
2327                quote! {}
2328            };
2329
2330            quote! {
2331                #route_builder
2332                    .to_state(stringify!(#destination))
2333                    .inject_input_to_destination()
2334                    #lag_flag
2335                    #fa_flag
2336            }
2337        })
2338        .collect()
2339}
2340
2341fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize {
2342    states
2343        .iter()
2344        .position(|state| state == &route.destination)
2345        .expect("validated route destination should exist")
2346}
2347
2348fn expand_injected_ode_route_terms(
2349    routes: &[OdeRouteDecl],
2350    states: &[Ident],
2351    route_bindings: &[(SymbolicIndex, usize)],
2352    dx: &Ident,
2353    bolus: &Ident,
2354    rateiv: &Ident,
2355) -> TokenStream2 {
2356    let terms = routes
2357        .iter()
2358        .zip(route_bindings.iter())
2359        .map(|(route, (_, input_index))| {
2360            let destination = route_destination_index(route, states);
2361            match route.kind {
2362                OdeRouteKind::Bolus => quote! {
2363                    #dx[#destination] += #bolus[#input_index];
2364                },
2365                OdeRouteKind::Infusion => quote! {
2366                    #dx[#destination] += #rateiv[#input_index];
2367                },
2368            }
2369        });
2370
2371    quote! {
2372        #(#terms)*
2373    }
2374}
2375
2376fn expand_injected_sde_rate_terms(
2377    routes: &[OdeRouteDecl],
2378    states: &[Ident],
2379    route_bindings: &[(SymbolicIndex, usize)],
2380    dx: &Ident,
2381    rateiv: &Ident,
2382) -> TokenStream2 {
2383    let terms = routes
2384        .iter()
2385        .zip(route_bindings.iter())
2386        .filter_map(|(route, (_, input_index))| match route.kind {
2387            OdeRouteKind::Bolus => None,
2388            OdeRouteKind::Infusion => {
2389                let destination = route_destination_index(route, states);
2390                Some(quote! {
2391                    #dx[#destination] += #rateiv[#input_index];
2392                })
2393            }
2394        });
2395
2396    quote! {
2397        #(#terms)*
2398    }
2399}
2400
2401fn expand_injected_sde_bolus_mappings(
2402    routes: &[OdeRouteDecl],
2403    states: &[Ident],
2404    route_bindings: &[(SymbolicIndex, usize)],
2405) -> TokenStream2 {
2406    let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)];
2407
2408    for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) {
2409        if let OdeRouteKind::Bolus = route.kind {
2410            let destination = route_destination_index(route, states);
2411            destinations[*input_index] = quote! { Some(#destination) };
2412        }
2413    }
2414
2415    quote! {
2416        .with_injected_bolus_inputs(&[#(#destinations),*])
2417    }
2418}
2419
2420fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn::Result<()> {
2421    let mut seen = HashSet::new();
2422    for ident in idents {
2423        let name = ident.to_string();
2424        if !seen.insert(name.clone()) {
2425            return Err(syn::Error::new_spanned(
2426                ident,
2427                format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"),
2428            ));
2429        }
2430    }
2431    Ok(())
2432}
2433
2434fn validate_unique_symbolic_indices(
2435    kind: &str,
2436    labels: &[SymbolicIndex],
2437    macro_name: &str,
2438) -> syn::Result<()> {
2439    let mut seen = HashSet::new();
2440    for label in labels {
2441        let name = label.name();
2442        if !seen.insert(name.clone()) {
2443            return Err(syn::Error::new_spanned(
2444                label,
2445                format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"),
2446            ));
2447        }
2448    }
2449    Ok(())
2450}
2451
2452fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> {
2453    let known_states = states.iter().map(Ident::to_string).collect::<HashSet<_>>();
2454    let mut seen_routes = HashSet::new();
2455
2456    for route in routes {
2457        let route_name = route.input.name();
2458        if !seen_routes.insert(route_name.clone()) {
2459            return Err(syn::Error::new_spanned(
2460                &route.input,
2461                format!("duplicate route `{route_name}` in declaration-first `{macro_name}`"),
2462            ));
2463        }
2464
2465        if !known_states.contains(&route.destination.to_string()) {
2466            return Err(syn::Error::new_spanned(
2467                &route.destination,
2468                format!(
2469                    "route destination `{}` is not declared in the `states` section",
2470                    route.destination
2471                ),
2472            ));
2473        }
2474    }
2475
2476    Ok(())
2477}
2478
2479fn expand_diffeq(
2480    diffeq: &ExprClosure,
2481    params: &[Ident],
2482    covariates: &[Ident],
2483    states: &[Ident],
2484    routes: &[OdeRouteDecl],
2485    route_bindings: &[(SymbolicIndex, usize)],
2486) -> syn::Result<TokenStream2> {
2487    let state_consts = generate_index_consts(states);
2488    let x = generated_ident("__pharmsol_x");
2489    let p = generated_ident("__pharmsol_p");
2490    let t = generated_ident("__pharmsol_t");
2491    let dx = generated_ident("__pharmsol_dx");
2492    let bolus = generated_ident("__pharmsol_bolus");
2493    let rateiv = generated_ident("__pharmsol_rateiv");
2494    let cov = generated_ident("__pharmsol_cov");
2495    let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()];
2496    let reduced_inputs = [x.clone(), t.clone(), dx.clone()];
2497    let input_aliases = generate_supported_input_aliases(
2498        diffeq,
2499        &[&full_inputs, &reduced_inputs],
2500        "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
2501    )?;
2502    let parameter_bindings = generate_parameter_bindings(params, diffeq, &p);
2503    let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t);
2504    let body = &diffeq.body;
2505    let dx_binding = if diffeq.inputs.len() == full_inputs.len() {
2506        closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone())
2507    } else {
2508        closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone())
2509    };
2510    let route_terms = expand_injected_ode_route_terms(
2511        routes,
2512        states,
2513        route_bindings,
2514        &dx_binding,
2515        &bolus,
2516        &rateiv,
2517    );
2518
2519    Ok(quote! {{
2520        let __pharmsol_diffeq: fn(
2521            &::pharmsol::simulator::V,
2522            &::pharmsol::simulator::V,
2523            f64,
2524            &mut ::pharmsol::simulator::V,
2525            &::pharmsol::simulator::V,
2526            &::pharmsol::simulator::V,
2527            &::pharmsol::data::Covariates,
2528        ) = |#x: &::pharmsol::simulator::V,
2529             #p: &::pharmsol::simulator::V,
2530             #t: f64,
2531             #dx: &mut ::pharmsol::simulator::V,
2532             #bolus: &::pharmsol::simulator::V,
2533             #rateiv: &::pharmsol::simulator::V,
2534             #cov: &::pharmsol::data::Covariates| {
2535            #input_aliases
2536            #state_consts
2537            #parameter_bindings
2538            #covariate_bindings
2539            #body
2540            #route_terms
2541        };
2542        __pharmsol_diffeq
2543    }})
2544}
2545
2546fn resolve_analytical_structure(structure: &Ident) -> syn::Result<AnalyticalKernelSpec> {
2547    let structure_name = structure.to_string();
2548    let (kernel, runtime_path, metadata_kernel, state_count) = match structure_name.as_str() {
2549        "one_compartment" => (
2550            ResolverAnalyticalKernel::OneCompartment,
2551            quote! { ::pharmsol::equation::one_compartment },
2552            quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment },
2553            1,
2554        ),
2555        "one_compartment_cl" => (
2556            ResolverAnalyticalKernel::OneCompartmentCl,
2557            quote! { ::pharmsol::equation::one_compartment_cl },
2558            quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl },
2559            1,
2560        ),
2561        "one_compartment_cl_with_absorption" => (
2562            ResolverAnalyticalKernel::OneCompartmentClWithAbsorption,
2563            quote! { ::pharmsol::equation::one_compartment_cl_with_absorption },
2564            quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption },
2565            2,
2566        ),
2567        "one_compartment_with_absorption" => (
2568            ResolverAnalyticalKernel::OneCompartmentWithAbsorption,
2569            quote! { ::pharmsol::equation::one_compartment_with_absorption },
2570            quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption },
2571            2,
2572        ),
2573        "two_compartments" => (
2574            ResolverAnalyticalKernel::TwoCompartments,
2575            quote! { ::pharmsol::equation::two_compartments },
2576            quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments },
2577            2,
2578        ),
2579        "two_compartments_cl" => (
2580            ResolverAnalyticalKernel::TwoCompartmentsCl,
2581            quote! { ::pharmsol::equation::two_compartments_cl },
2582            quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl },
2583            2,
2584        ),
2585        "two_compartments_cl_with_absorption" => (
2586            ResolverAnalyticalKernel::TwoCompartmentsClWithAbsorption,
2587            quote! { ::pharmsol::equation::two_compartments_cl_with_absorption },
2588            quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption },
2589            3,
2590        ),
2591        "two_compartments_with_absorption" => (
2592            ResolverAnalyticalKernel::TwoCompartmentsWithAbsorption,
2593            quote! { ::pharmsol::equation::two_compartments_with_absorption },
2594            quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption },
2595            3,
2596        ),
2597        "three_compartments" => (
2598            ResolverAnalyticalKernel::ThreeCompartments,
2599            quote! { ::pharmsol::equation::three_compartments },
2600            quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments },
2601            3,
2602        ),
2603        "three_compartments_cl" => (
2604            ResolverAnalyticalKernel::ThreeCompartmentsCl,
2605            quote! { ::pharmsol::equation::three_compartments_cl },
2606            quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl },
2607            3,
2608        ),
2609        "three_compartments_cl_with_absorption" => (
2610            ResolverAnalyticalKernel::ThreeCompartmentsClWithAbsorption,
2611            quote! { ::pharmsol::equation::three_compartments_cl_with_absorption },
2612            quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption },
2613            4,
2614        ),
2615        "three_compartments_with_absorption" => (
2616            ResolverAnalyticalKernel::ThreeCompartmentsWithAbsorption,
2617            quote! { ::pharmsol::equation::three_compartments_with_absorption },
2618            quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption },
2619            4,
2620        ),
2621        _ => {
2622            return Err(syn::Error::new_spanned(
2623                structure,
2624                format!("unknown analytical structure `{structure_name}`"),
2625            ));
2626        }
2627    };
2628
2629    Ok(AnalyticalKernelSpec {
2630        kernel,
2631        runtime_path,
2632        metadata_kernel,
2633        state_count,
2634    })
2635}
2636
2637fn expand_analytical_route_map(
2638    label: &str,
2639    closure: &ExprClosure,
2640    params: &[Ident],
2641    derived: &[Ident],
2642    covariates: &[Ident],
2643    route_bindings: &[(SymbolicIndex, usize)],
2644) -> syn::Result<TokenStream2> {
2645    let route_consts = generate_mapped_index_consts(route_bindings);
2646    let p = generated_ident("__pharmsol_p");
2647    let t = generated_ident("__pharmsol_t");
2648    let cov = generated_ident("__pharmsol_cov");
2649    let full_inputs = [p.clone(), t.clone(), cov.clone()];
2650    let reduced_inputs = [t.clone()];
2651    let input_aliases = generate_supported_input_aliases(
2652        closure,
2653        &[&full_inputs, &reduced_inputs],
2654        &format!(
2655            "built-in `analytical!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
2656        ),
2657    )?;
2658    let parameter_bindings = generate_parameter_bindings(params, closure, &p);
2659    let derived_values = generated_ident("__pharmsol_derived");
2660    let derived_bindings = generate_derived_bindings(derived, closure, &derived_values);
2661    let derive_values = if derived_bindings.is_empty() {
2662        quote! {}
2663    } else {
2664        quote! {
2665            let #derived_values = __pharmsol_derive(#p, #t, #cov);
2666        }
2667    };
2668    let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
2669    let body = NumericLabelRewriter::rewrite(
2670        closure.body.as_ref(),
2671        Vec::new(),
2672        Some(symbolic_numeric_binding_map(route_bindings)),
2673    );
2674
2675    Ok(quote! {{
2676        let __pharmsol_route_map: fn(
2677            &::pharmsol::simulator::V,
2678            f64,
2679            &::pharmsol::data::Covariates,
2680        ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
2681             #t: f64,
2682             #cov: &::pharmsol::data::Covariates| {
2683            #input_aliases
2684            #route_consts
2685            #parameter_bindings
2686            #derive_values
2687            #derived_bindings
2688            #covariate_bindings
2689            #body
2690        };
2691        __pharmsol_route_map
2692    }})
2693}
2694
2695fn expand_analytical_derive(
2696    derive: Option<&ExprClosure>,
2697    params: &[Ident],
2698    covariates: &[Ident],
2699    derived: &[Ident],
2700) -> syn::Result<TokenStream2> {
2701    let p = generated_ident("__pharmsol_p");
2702    let t = generated_ident("__pharmsol_t");
2703    let cov = generated_ident("__pharmsol_cov");
2704    let derived_len = syn::LitInt::new(&derived.len().to_string(), Span::call_site());
2705
2706    if let Some(derive) = derive {
2707        let full_inputs = [p.clone(), t.clone(), cov.clone()];
2708        let reduced_inputs = [t.clone()];
2709        let input_aliases = generate_supported_input_aliases(
2710            derive,
2711            &[&full_inputs, &reduced_inputs],
2712            "built-in `analytical!` requires `derive` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|",
2713        )?;
2714        let parameter_bindings = generate_parameter_bindings(params, derive, &p);
2715        let covariate_bindings = generate_covariate_bindings(covariates, derive, &cov, &t);
2716        let derived_decls = derived.iter().map(|ident| {
2717            quote! {
2718                #[allow(unused_mut)]
2719                let mut #ident: f64;
2720            }
2721        });
2722        let body = &derive.body;
2723
2724        Ok(quote! {
2725            fn __pharmsol_derive(
2726                #p: &::pharmsol::simulator::V,
2727                #t: f64,
2728                #cov: &::pharmsol::data::Covariates,
2729            ) -> [f64; #derived_len] {
2730                #input_aliases
2731                #parameter_bindings
2732                #covariate_bindings
2733                #(#derived_decls)*
2734                #body
2735                [#(#derived),*]
2736            }
2737        })
2738    } else {
2739        let zeros = derived.iter().map(|_| quote! { 0.0 });
2740        Ok(quote! {
2741            fn __pharmsol_derive(
2742                _: &::pharmsol::simulator::V,
2743                _: f64,
2744                _: &::pharmsol::data::Covariates,
2745            ) -> [f64; #derived_len] {
2746                [#(#zeros),*]
2747            }
2748        })
2749    }
2750}
2751
2752fn expand_analytical_runtime(
2753    runtime_path: &TokenStream2,
2754    projection: &AnalyticalStructureInputKind,
2755) -> TokenStream2 {
2756    match projection {
2757        AnalyticalStructureInputKind::AllPrimary { identity: true, .. } => runtime_path.clone(),
2758        AnalyticalStructureInputKind::AllPrimary { indices, .. } => {
2759            let projected = indices.iter().map(|index| quote! { __pharmsol_p[#index] });
2760            quote! {{
2761                let __pharmsol_eq: fn(
2762                    &::pharmsol::simulator::V,
2763                    &::pharmsol::simulator::V,
2764                    f64,
2765                    &::pharmsol::simulator::V,
2766                    &::pharmsol::data::Covariates,
2767                ) -> ::pharmsol::simulator::V = |
2768                    __pharmsol_x: &::pharmsol::simulator::V,
2769                    __pharmsol_p: &::pharmsol::simulator::V,
2770                    __pharmsol_t: f64,
2771                    __pharmsol_rateiv: &::pharmsol::simulator::V,
2772                    __pharmsol_cov: &::pharmsol::data::Covariates,
2773                | {
2774                    let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2775                    #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2776                };
2777                __pharmsol_eq
2778            }}
2779        }
2780        AnalyticalStructureInputKind::AllDerived { indices, .. } => {
2781            let projected = indices
2782                .iter()
2783                .map(|index| quote! { __pharmsol_derived[#index] });
2784            quote! {{
2785                let __pharmsol_eq: fn(
2786                    &::pharmsol::simulator::V,
2787                    &::pharmsol::simulator::V,
2788                    f64,
2789                    &::pharmsol::simulator::V,
2790                    &::pharmsol::data::Covariates,
2791                ) -> ::pharmsol::simulator::V = |
2792                    __pharmsol_x: &::pharmsol::simulator::V,
2793                    __pharmsol_p: &::pharmsol::simulator::V,
2794                    __pharmsol_t: f64,
2795                    __pharmsol_rateiv: &::pharmsol::simulator::V,
2796                    __pharmsol_cov: &::pharmsol::data::Covariates,
2797                | {
2798                    let __pharmsol_derived = __pharmsol_derive(__pharmsol_p, __pharmsol_t, __pharmsol_cov);
2799                    let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2800                    #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2801                };
2802                __pharmsol_eq
2803            }}
2804        }
2805        AnalyticalStructureInputKind::Mixed { bindings } => {
2806            let projected = bindings.iter().map(|binding| match binding.source {
2807                AnalyticalStructureInputSource::Primary => {
2808                    let index = binding.index;
2809                    quote! { __pharmsol_p[#index] }
2810                }
2811                AnalyticalStructureInputSource::Derived => {
2812                    let index = binding.index;
2813                    quote! { __pharmsol_derived[#index] }
2814                }
2815            });
2816            quote! {{
2817                let __pharmsol_eq: fn(
2818                    &::pharmsol::simulator::V,
2819                    &::pharmsol::simulator::V,
2820                    f64,
2821                    &::pharmsol::simulator::V,
2822                    &::pharmsol::data::Covariates,
2823                ) -> ::pharmsol::simulator::V = |
2824                    __pharmsol_x: &::pharmsol::simulator::V,
2825                    __pharmsol_p: &::pharmsol::simulator::V,
2826                    __pharmsol_t: f64,
2827                    __pharmsol_rateiv: &::pharmsol::simulator::V,
2828                    __pharmsol_cov: &::pharmsol::data::Covariates,
2829                | {
2830                    let __pharmsol_derived = __pharmsol_derive(__pharmsol_p, __pharmsol_t, __pharmsol_cov);
2831                    let __pharmsol_projected = ::pharmsol::__macro_support::vector_from_values(vec![#(#projected),*]);
2832                    #runtime_path(__pharmsol_x, &__pharmsol_projected, __pharmsol_t, __pharmsol_rateiv, __pharmsol_cov)
2833                };
2834                __pharmsol_eq
2835            }}
2836        }
2837    }
2838}
2839
2840fn expand_analytical_init(
2841    init: &ExprClosure,
2842    params: &[Ident],
2843    derived: &[Ident],
2844    covariates: &[Ident],
2845    states: &[Ident],
2846) -> syn::Result<TokenStream2> {
2847    let state_consts = generate_index_consts(states);
2848    let p = generated_ident("__pharmsol_p");
2849    let t = generated_ident("__pharmsol_t");
2850    let cov = generated_ident("__pharmsol_cov");
2851    let x = generated_ident("__pharmsol_x");
2852    let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
2853    let reduced_inputs = [t.clone(), x.clone()];
2854    let input_aliases = generate_supported_input_aliases(
2855        init,
2856        &[&full_inputs, &reduced_inputs],
2857        "built-in `analytical!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
2858    )?;
2859    let parameter_bindings = generate_parameter_bindings(params, init, &p);
2860    let derived_values = generated_ident("__pharmsol_derived");
2861    let derived_bindings = generate_derived_bindings(derived, init, &derived_values);
2862    let derive_values = if derived_bindings.is_empty() {
2863        quote! {}
2864    } else {
2865        quote! {
2866            let #derived_values = __pharmsol_derive(#p, #t, #cov);
2867        }
2868    };
2869    let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
2870    let body = &init.body;
2871
2872    Ok(quote! {{
2873        let __pharmsol_init: fn(
2874            &::pharmsol::simulator::V,
2875            f64,
2876            &::pharmsol::data::Covariates,
2877            &mut ::pharmsol::simulator::V,
2878        ) = |#p: &::pharmsol::simulator::V,
2879             #t: f64,
2880             #cov: &::pharmsol::data::Covariates,
2881             #x: &mut ::pharmsol::simulator::V| {
2882            #input_aliases
2883            #state_consts
2884            #parameter_bindings
2885            #derive_values
2886            #derived_bindings
2887            #covariate_bindings
2888            #body
2889        };
2890        __pharmsol_init
2891    }})
2892}
2893
2894fn expand_analytical_out(
2895    out: &ExprClosure,
2896    params: &[Ident],
2897    derived: &[Ident],
2898    covariates: &[Ident],
2899    states: &[Ident],
2900    outputs: &[SymbolicIndex],
2901) -> syn::Result<TokenStream2> {
2902    let state_consts = generate_index_consts(states);
2903    let output_bindings = symbolic_index_bindings(outputs);
2904    let output_consts = generate_mapped_index_consts(&output_bindings);
2905    let x = generated_ident("__pharmsol_x");
2906    let p = generated_ident("__pharmsol_p");
2907    let t = generated_ident("__pharmsol_t");
2908    let cov = generated_ident("__pharmsol_cov");
2909    let y = generated_ident("__pharmsol_y");
2910    let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
2911    let reduced_inputs = [x.clone(), t.clone(), y.clone()];
2912    let input_aliases = generate_supported_input_aliases(
2913        out,
2914        &[&full_inputs, &reduced_inputs],
2915        "built-in `analytical!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
2916    )?;
2917    let parameter_bindings = generate_parameter_bindings(params, out, &p);
2918    let derived_values = generated_ident("__pharmsol_derived");
2919    let derived_bindings = generate_derived_bindings(derived, out, &derived_values);
2920    let derive_values = if derived_bindings.is_empty() {
2921        quote! {}
2922    } else {
2923        quote! {
2924            let #derived_values = __pharmsol_derive(#p, #t, #cov);
2925        }
2926    };
2927    let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
2928    let y_binding = if out.inputs.len() == full_inputs.len() {
2929        closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
2930    } else {
2931        closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
2932    };
2933    let body = NumericLabelRewriter::rewrite(
2934        out.body.as_ref(),
2935        vec![IndexRewriteTarget::new(
2936            y_binding,
2937            symbolic_numeric_binding_map(&output_bindings),
2938        )],
2939        None,
2940    );
2941
2942    Ok(quote! {{
2943        let __pharmsol_out: fn(
2944            &::pharmsol::simulator::V,
2945            &::pharmsol::simulator::V,
2946            f64,
2947            &::pharmsol::data::Covariates,
2948            &mut ::pharmsol::simulator::V,
2949        ) = |#x: &::pharmsol::simulator::V,
2950             #p: &::pharmsol::simulator::V,
2951             #t: f64,
2952             #cov: &::pharmsol::data::Covariates,
2953             #y: &mut ::pharmsol::simulator::V| {
2954            #input_aliases
2955            #state_consts
2956            #output_consts
2957            #parameter_bindings
2958            #derive_values
2959            #derived_bindings
2960            #covariate_bindings
2961            #body
2962        };
2963        __pharmsol_out
2964    }})
2965}
2966
2967fn expand_sde_drift(
2968    drift: &ExprClosure,
2969    params: &[Ident],
2970    covariates: &[Ident],
2971    states: &[Ident],
2972    routes: &[OdeRouteDecl],
2973    route_bindings: &[(SymbolicIndex, usize)],
2974) -> syn::Result<TokenStream2> {
2975    let state_consts = generate_index_consts(states);
2976    let x = generated_ident("__pharmsol_x");
2977    let p = generated_ident("__pharmsol_p");
2978    let t = generated_ident("__pharmsol_t");
2979    let dx = generated_ident("__pharmsol_dx");
2980    let rateiv = generated_ident("__pharmsol_rateiv");
2981    let cov = generated_ident("__pharmsol_cov");
2982    let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()];
2983    let reduced_inputs = [x.clone(), t.clone(), dx.clone()];
2984    let input_aliases = generate_supported_input_aliases(
2985        drift,
2986        &[&full_inputs, &reduced_inputs],
2987        "declaration-first `sde!` requires `drift` to have either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|",
2988    )?;
2989    let parameter_bindings = generate_parameter_bindings(params, drift, &p);
2990    let covariate_bindings = generate_covariate_bindings(covariates, drift, &cov, &t);
2991    let body = &drift.body;
2992    let dx_binding = if drift.inputs.len() == full_inputs.len() {
2993        closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone())
2994    } else {
2995        closure_param_ident(drift, 2).unwrap_or_else(|| dx.clone())
2996    };
2997    let rate_terms =
2998        expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv);
2999
3000    Ok(quote! {{
3001        let __pharmsol_drift: fn(
3002            &::pharmsol::simulator::V,
3003            &::pharmsol::simulator::V,
3004            f64,
3005            &mut ::pharmsol::simulator::V,
3006            &::pharmsol::simulator::V,
3007            &::pharmsol::data::Covariates,
3008        ) = |#x: &::pharmsol::simulator::V,
3009             #p: &::pharmsol::simulator::V,
3010             #t: f64,
3011             #dx: &mut ::pharmsol::simulator::V,
3012             #rateiv: &::pharmsol::simulator::V,
3013             #cov: &::pharmsol::data::Covariates| {
3014            #input_aliases
3015            #state_consts
3016            #parameter_bindings
3017            #covariate_bindings
3018            #body
3019            #rate_terms
3020        };
3021        __pharmsol_drift
3022    }})
3023}
3024
3025fn expand_sde_diffusion(
3026    diffusion: &ExprClosure,
3027    params: &[Ident],
3028    states: &[Ident],
3029) -> syn::Result<TokenStream2> {
3030    let state_consts = generate_index_consts(states);
3031    let p = generated_ident("__pharmsol_p");
3032    let sigma = generated_ident("__pharmsol_sigma");
3033    let full_inputs = [p.clone(), sigma.clone()];
3034    let reduced_inputs = [sigma.clone()];
3035    let input_aliases = generate_supported_input_aliases(
3036        diffusion,
3037        &[&full_inputs, &reduced_inputs],
3038        "declaration-first `sde!` requires `diffusion` to have either 2 parameters: |p, sigma| or 1 parameter: |sigma|",
3039    )?;
3040    let parameter_bindings = generate_parameter_bindings(params, diffusion, &p);
3041    let body = &diffusion.body;
3042
3043    Ok(quote! {{
3044        let __pharmsol_diffusion: fn(
3045            &::pharmsol::simulator::V,
3046            &mut ::pharmsol::simulator::V,
3047        ) = |#p: &::pharmsol::simulator::V,
3048             #sigma: &mut ::pharmsol::simulator::V| {
3049            #input_aliases
3050            #state_consts
3051            #parameter_bindings
3052            #body
3053        };
3054        __pharmsol_diffusion
3055    }})
3056}
3057
3058fn expand_sde_route_map(
3059    label: &str,
3060    closure: &ExprClosure,
3061    params: &[Ident],
3062    covariates: &[Ident],
3063    route_bindings: &[(SymbolicIndex, usize)],
3064) -> syn::Result<TokenStream2> {
3065    let route_consts = generate_mapped_index_consts(route_bindings);
3066    let p = generated_ident("__pharmsol_p");
3067    let t = generated_ident("__pharmsol_t");
3068    let cov = generated_ident("__pharmsol_cov");
3069    let full_inputs = [p.clone(), t.clone(), cov.clone()];
3070    let reduced_inputs = [t.clone()];
3071    let input_aliases = generate_supported_input_aliases(
3072        closure,
3073        &[&full_inputs, &reduced_inputs],
3074        &format!(
3075            "declaration-first `sde!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|"
3076        ),
3077    )?;
3078    let parameter_bindings = generate_parameter_bindings(params, closure, &p);
3079    let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t);
3080    let body = NumericLabelRewriter::rewrite(
3081        closure.body.as_ref(),
3082        Vec::new(),
3083        Some(symbolic_numeric_binding_map(route_bindings)),
3084    );
3085
3086    Ok(quote! {{
3087        let __pharmsol_route_map: fn(
3088            &::pharmsol::simulator::V,
3089            f64,
3090            &::pharmsol::data::Covariates,
3091        ) -> ::std::collections::HashMap<usize, f64> = |#p: &::pharmsol::simulator::V,
3092             #t: f64,
3093             #cov: &::pharmsol::data::Covariates| {
3094            #input_aliases
3095            #route_consts
3096            #parameter_bindings
3097            #covariate_bindings
3098            #body
3099        };
3100        __pharmsol_route_map
3101    }})
3102}
3103
3104fn expand_sde_init(
3105    init: &ExprClosure,
3106    params: &[Ident],
3107    covariates: &[Ident],
3108    states: &[Ident],
3109) -> syn::Result<TokenStream2> {
3110    let state_consts = generate_index_consts(states);
3111    let p = generated_ident("__pharmsol_p");
3112    let t = generated_ident("__pharmsol_t");
3113    let cov = generated_ident("__pharmsol_cov");
3114    let x = generated_ident("__pharmsol_x");
3115    let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()];
3116    let reduced_inputs = [t.clone(), x.clone()];
3117    let input_aliases = generate_supported_input_aliases(
3118        init,
3119        &[&full_inputs, &reduced_inputs],
3120        "declaration-first `sde!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|",
3121    )?;
3122    let parameter_bindings = generate_parameter_bindings(params, init, &p);
3123    let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t);
3124    let body = &init.body;
3125
3126    Ok(quote! {{
3127        let __pharmsol_init: fn(
3128            &::pharmsol::simulator::V,
3129            f64,
3130            &::pharmsol::data::Covariates,
3131            &mut ::pharmsol::simulator::V,
3132        ) = |#p: &::pharmsol::simulator::V,
3133             #t: f64,
3134             #cov: &::pharmsol::data::Covariates,
3135             #x: &mut ::pharmsol::simulator::V| {
3136            #input_aliases
3137            #state_consts
3138            #parameter_bindings
3139            #covariate_bindings
3140            #body
3141        };
3142        __pharmsol_init
3143    }})
3144}
3145
3146fn expand_sde_out(
3147    out: &ExprClosure,
3148    params: &[Ident],
3149    covariates: &[Ident],
3150    states: &[Ident],
3151    outputs: &[SymbolicIndex],
3152) -> syn::Result<TokenStream2> {
3153    let state_consts = generate_index_consts(states);
3154    let output_bindings = symbolic_index_bindings(outputs);
3155    let output_consts = generate_mapped_index_consts(&output_bindings);
3156    let x = generated_ident("__pharmsol_x");
3157    let p = generated_ident("__pharmsol_p");
3158    let t = generated_ident("__pharmsol_t");
3159    let cov = generated_ident("__pharmsol_cov");
3160    let y = generated_ident("__pharmsol_y");
3161    let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()];
3162    let reduced_inputs = [x.clone(), t.clone(), y.clone()];
3163    let input_aliases = generate_supported_input_aliases(
3164        out,
3165        &[&full_inputs, &reduced_inputs],
3166        "declaration-first `sde!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|",
3167    )?;
3168    let parameter_bindings = generate_parameter_bindings(params, out, &p);
3169    let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t);
3170    let y_binding = if out.inputs.len() == full_inputs.len() {
3171        closure_param_ident(out, 4).unwrap_or_else(|| y.clone())
3172    } else {
3173        closure_param_ident(out, 2).unwrap_or_else(|| y.clone())
3174    };
3175    let body = NumericLabelRewriter::rewrite(
3176        out.body.as_ref(),
3177        vec![IndexRewriteTarget::new(
3178            y_binding,
3179            symbolic_numeric_binding_map(&output_bindings),
3180        )],
3181        None,
3182    );
3183
3184    Ok(quote! {{
3185        let __pharmsol_out: fn(
3186            &::pharmsol::simulator::V,
3187            &::pharmsol::simulator::V,
3188            f64,
3189            &::pharmsol::data::Covariates,
3190            &mut ::pharmsol::simulator::V,
3191        ) = |#x: &::pharmsol::simulator::V,
3192             #p: &::pharmsol::simulator::V,
3193             #t: f64,
3194             #cov: &::pharmsol::data::Covariates,
3195             #y: &mut ::pharmsol::simulator::V| {
3196            #input_aliases
3197            #state_consts
3198            #output_consts
3199            #parameter_bindings
3200            #covariate_bindings
3201            #body
3202        };
3203        __pharmsol_out
3204    }})
3205}
3206
3207// ---------------------------------------------------------------------------
3208// Proc macros
3209// ---------------------------------------------------------------------------
3210
3211#[proc_macro]
3212pub fn ode(input: TokenStream) -> TokenStream {
3213    let input = syn::parse_macro_input!(input as OdeInput);
3214
3215    let route_bindings = ode_route_input_bindings(&input.routes);
3216
3217    let lag_routes = match input.lag.as_ref() {
3218        Some(closure) => match extract_route_property_routes(
3219            "declaration-first `ode!`",
3220            "lag",
3221            closure,
3222            &input.routes,
3223        ) {
3224            Ok(routes) => {
3225                if let Err(error) = validate_route_property_kinds(
3226                    "declaration-first `ode!`",
3227                    "lag",
3228                    &input.routes,
3229                    &routes,
3230                ) {
3231                    return error.to_compile_error().into();
3232                }
3233                routes
3234            }
3235            Err(error) => return error.to_compile_error().into(),
3236        },
3237        None => HashSet::new(),
3238    };
3239
3240    let fa_routes = match input.fa.as_ref() {
3241        Some(closure) => match extract_route_property_routes(
3242            "declaration-first `ode!`",
3243            "fa",
3244            closure,
3245            &input.routes,
3246        ) {
3247            Ok(routes) => {
3248                if let Err(error) = validate_route_property_kinds(
3249                    "declaration-first `ode!`",
3250                    "fa",
3251                    &input.routes,
3252                    &routes,
3253                ) {
3254                    return error.to_compile_error().into();
3255                }
3256                routes
3257            }
3258            Err(error) => return error.to_compile_error().into(),
3259        },
3260        None => HashSet::new(),
3261    };
3262
3263    let diffeq = match expand_diffeq(
3264        &input.diffeq,
3265        &input.params,
3266        &input.covariates,
3267        &input.states,
3268        &input.routes,
3269        &route_bindings,
3270    ) {
3271        Ok(diffeq) => diffeq,
3272        Err(error) => return error.to_compile_error().into(),
3273    };
3274
3275    let out = match expand_out(
3276        &input.out,
3277        &input.params,
3278        &input.covariates,
3279        &input.states,
3280        &input.outputs,
3281    ) {
3282        Ok(out) => out,
3283        Err(error) => return error.to_compile_error().into(),
3284    };
3285
3286    let nstates = input.states.len();
3287    let ndrugs = dense_index_len(&route_bindings);
3288    let nout = input.outputs.len();
3289
3290    let name = &input.name;
3291    let params = &input.params;
3292    let covariates = &input.covariates;
3293    let states = &input.states;
3294    let outputs = &input.outputs;
3295    let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes);
3296    let covariate_metadata = if covariates.is_empty() {
3297        quote! {}
3298    } else {
3299        quote! {
3300            .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3301        }
3302    };
3303
3304    let lag = match input.lag.as_ref() {
3305        Some(closure) => match expand_ode_route_map(
3306            "lag",
3307            closure,
3308            &input.params,
3309            &input.covariates,
3310            &route_bindings,
3311        ) {
3312            Ok(lag) => lag,
3313            Err(error) => return error.to_compile_error().into(),
3314        },
3315        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3316    };
3317
3318    let fa = match input.fa.as_ref() {
3319        Some(closure) => {
3320            match expand_ode_route_map(
3321                "fa",
3322                closure,
3323                &input.params,
3324                &input.covariates,
3325                &route_bindings,
3326            ) {
3327                Ok(fa) => fa,
3328                Err(error) => return error.to_compile_error().into(),
3329            }
3330        }
3331        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3332    };
3333
3334    let init = match input.init.as_ref() {
3335        Some(closure) => {
3336            match expand_ode_init(closure, &input.params, &input.covariates, &input.states) {
3337                Ok(init) => init,
3338                Err(error) => return error.to_compile_error().into(),
3339            }
3340        }
3341        None => quote! { |_, _, _, _| {} },
3342    };
3343
3344    quote! {{
3345        let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3346            .parameters([#(stringify!(#params)),*])
3347            #covariate_metadata
3348            .states([#(stringify!(#states)),*])
3349            .outputs([#(stringify!(#outputs)),*])
3350            #(.route(#routes))*;
3351
3352        ::pharmsol::equation::ODE::new(
3353            #diffeq,
3354            #lag,
3355            #fa,
3356            #init,
3357            #out,
3358        )
3359        .with_nstates(#nstates)
3360        .with_ndrugs(#ndrugs)
3361        .with_nout(#nout)
3362        .with_metadata(__pharmsol_metadata)
3363        .expect("declaration-first `ode!` generated invalid metadata")
3364    }}
3365    .into()
3366}
3367
3368#[proc_macro]
3369pub fn analytical(input: TokenStream) -> TokenStream {
3370    let input = syn::parse_macro_input!(input as AnalyticalInput);
3371    let route_bindings = ode_route_input_bindings(&input.routes);
3372
3373    let kernel_spec = match resolve_analytical_structure(&input.structure) {
3374        Ok(spec) => spec,
3375        Err(error) => return error.to_compile_error().into(),
3376    };
3377    let projection = match validate_analytical_structure_inputs(
3378        &input.structure,
3379        kernel_spec.kernel,
3380        &input.params,
3381        &input.derived,
3382    ) {
3383        Ok(plan) => plan,
3384        Err(error) => return error.to_compile_error().into(),
3385    };
3386
3387    let lag_routes = match input.lag.as_ref() {
3388        Some(closure) => match extract_route_property_routes(
3389            "built-in `analytical!`",
3390            "lag",
3391            closure,
3392            &input.routes,
3393        ) {
3394            Ok(routes) => {
3395                if let Err(error) = validate_route_property_kinds(
3396                    "built-in `analytical!`",
3397                    "lag",
3398                    &input.routes,
3399                    &routes,
3400                ) {
3401                    return error.to_compile_error().into();
3402                }
3403                routes
3404            }
3405            Err(error) => return error.to_compile_error().into(),
3406        },
3407        None => HashSet::new(),
3408    };
3409
3410    let fa_routes = match input.fa.as_ref() {
3411        Some(closure) => match extract_route_property_routes(
3412            "built-in `analytical!`",
3413            "fa",
3414            closure,
3415            &input.routes,
3416        ) {
3417            Ok(routes) => {
3418                if let Err(error) = validate_route_property_kinds(
3419                    "built-in `analytical!`",
3420                    "fa",
3421                    &input.routes,
3422                    &routes,
3423                ) {
3424                    return error.to_compile_error().into();
3425                }
3426                routes
3427            }
3428            Err(error) => return error.to_compile_error().into(),
3429        },
3430        None => HashSet::new(),
3431    };
3432
3433    let derive = match expand_analytical_derive(
3434        input.derive.as_ref(),
3435        &input.params,
3436        &input.covariates,
3437        &input.derived,
3438    ) {
3439        Ok(derive) => derive,
3440        Err(error) => return error.to_compile_error().into(),
3441    };
3442    let eq = expand_analytical_runtime(&kernel_spec.runtime_path, projection.kind());
3443
3444    let out = match expand_analytical_out(
3445        &input.out,
3446        &input.params,
3447        &input.derived,
3448        &input.covariates,
3449        &input.states,
3450        &input.outputs,
3451    ) {
3452        Ok(out) => out,
3453        Err(error) => return error.to_compile_error().into(),
3454    };
3455
3456    let lag = match input.lag.as_ref() {
3457        Some(closure) => {
3458            match expand_analytical_route_map(
3459                "lag",
3460                closure,
3461                &input.params,
3462                &input.derived,
3463                &input.covariates,
3464                &route_bindings,
3465            ) {
3466                Ok(lag) => lag,
3467                Err(error) => return error.to_compile_error().into(),
3468            }
3469        }
3470        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3471    };
3472
3473    let fa = match input.fa.as_ref() {
3474        Some(closure) => {
3475            match expand_analytical_route_map(
3476                "fa",
3477                closure,
3478                &input.params,
3479                &input.derived,
3480                &input.covariates,
3481                &route_bindings,
3482            ) {
3483                Ok(fa) => fa,
3484                Err(error) => return error.to_compile_error().into(),
3485            }
3486        }
3487        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3488    };
3489
3490    let init = match input.init.as_ref() {
3491        Some(closure) => {
3492            match expand_analytical_init(
3493                closure,
3494                &input.params,
3495                &input.derived,
3496                &input.covariates,
3497                &input.states,
3498            ) {
3499                Ok(init) => init,
3500                Err(error) => return error.to_compile_error().into(),
3501            }
3502        }
3503        None => quote! { |_, _, _, _| {} },
3504    };
3505
3506    let nstates = input.states.len();
3507    let ndrugs = dense_index_len(&route_bindings);
3508    let nout = input.outputs.len();
3509
3510    let name = &input.name;
3511    let params = &input.params;
3512    let covariates = &input.covariates;
3513    let states = &input.states;
3514    let outputs = &input.outputs;
3515    let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes);
3516    let metadata_kernel = kernel_spec.metadata_kernel;
3517    let covariate_metadata = if covariates.is_empty() {
3518        quote! {}
3519    } else {
3520        quote! {
3521            .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3522        }
3523    };
3524
3525    quote! {{
3526        #derive
3527        let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3528            .kind(::pharmsol::equation::ModelKind::Analytical)
3529            .parameters([#(stringify!(#params)),*])
3530            #covariate_metadata
3531            .states([#(stringify!(#states)),*])
3532            .outputs([#(stringify!(#outputs)),*])
3533            #(.route(#routes))*
3534            .analytical_kernel(#metadata_kernel);
3535
3536        ::pharmsol::equation::Analytical::new(
3537            #eq,
3538            |_, _, _| {},
3539            #lag,
3540            #fa,
3541            #init,
3542            #out,
3543        )
3544        .with_nstates(#nstates)
3545        .with_ndrugs(#ndrugs)
3546        .with_nout(#nout)
3547        .with_metadata(__pharmsol_metadata)
3548        .expect("built-in `analytical!` generated invalid metadata")
3549    }}
3550    .into()
3551}
3552
3553#[proc_macro]
3554pub fn sde(input: TokenStream) -> TokenStream {
3555    let input = syn::parse_macro_input!(input as SdeInput);
3556    let route_bindings = ode_route_input_bindings(&input.routes);
3557
3558    let lag_routes = match input.lag.as_ref() {
3559        Some(closure) => match extract_route_property_routes(
3560            "declaration-first `sde!`",
3561            "lag",
3562            closure,
3563            &input.routes,
3564        ) {
3565            Ok(routes) => {
3566                if let Err(error) = validate_route_property_kinds(
3567                    "declaration-first `sde!`",
3568                    "lag",
3569                    &input.routes,
3570                    &routes,
3571                ) {
3572                    return error.to_compile_error().into();
3573                }
3574                routes
3575            }
3576            Err(error) => return error.to_compile_error().into(),
3577        },
3578        None => HashSet::new(),
3579    };
3580
3581    let fa_routes = match input.fa.as_ref() {
3582        Some(closure) => match extract_route_property_routes(
3583            "declaration-first `sde!`",
3584            "fa",
3585            closure,
3586            &input.routes,
3587        ) {
3588            Ok(routes) => {
3589                if let Err(error) = validate_route_property_kinds(
3590                    "declaration-first `sde!`",
3591                    "fa",
3592                    &input.routes,
3593                    &routes,
3594                ) {
3595                    return error.to_compile_error().into();
3596                }
3597                routes
3598            }
3599            Err(error) => return error.to_compile_error().into(),
3600        },
3601        None => HashSet::new(),
3602    };
3603
3604    let drift = match expand_sde_drift(
3605        &input.drift,
3606        &input.params,
3607        &input.covariates,
3608        &input.states,
3609        &input.routes,
3610        &route_bindings,
3611    ) {
3612        Ok(drift) => drift,
3613        Err(error) => return error.to_compile_error().into(),
3614    };
3615
3616    let diffusion = match expand_sde_diffusion(&input.diffusion, &input.params, &input.states) {
3617        Ok(diffusion) => diffusion,
3618        Err(error) => return error.to_compile_error().into(),
3619    };
3620
3621    let lag = match input.lag.as_ref() {
3622        Some(closure) => match expand_sde_route_map(
3623            "lag",
3624            closure,
3625            &input.params,
3626            &input.covariates,
3627            &route_bindings,
3628        ) {
3629            Ok(lag) => lag,
3630            Err(error) => return error.to_compile_error().into(),
3631        },
3632        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3633    };
3634
3635    let fa = match input.fa.as_ref() {
3636        Some(closure) => {
3637            match expand_sde_route_map(
3638                "fa",
3639                closure,
3640                &input.params,
3641                &input.covariates,
3642                &route_bindings,
3643            ) {
3644                Ok(fa) => fa,
3645                Err(error) => return error.to_compile_error().into(),
3646            }
3647        }
3648        None => quote! { |_, _, _| ::std::collections::HashMap::new() },
3649    };
3650
3651    let init = match input.init.as_ref() {
3652        Some(closure) => {
3653            match expand_sde_init(closure, &input.params, &input.covariates, &input.states) {
3654                Ok(init) => init,
3655                Err(error) => return error.to_compile_error().into(),
3656            }
3657        }
3658        None => quote! { |_, _, _, _| {} },
3659    };
3660
3661    let out = match expand_sde_out(
3662        &input.out,
3663        &input.params,
3664        &input.covariates,
3665        &input.states,
3666        &input.outputs,
3667    ) {
3668        Ok(out) => out,
3669        Err(error) => return error.to_compile_error().into(),
3670    };
3671
3672    let nstates = input.states.len();
3673    let ndrugs = dense_index_len(&route_bindings);
3674    let nout = input.outputs.len();
3675
3676    let name = &input.name;
3677    let params = &input.params;
3678    let covariates = &input.covariates;
3679    let states = &input.states;
3680    let outputs = &input.outputs;
3681    let particles = &input.particles;
3682    let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes);
3683    let bolus_mappings =
3684        expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings);
3685    let covariate_metadata = if covariates.is_empty() {
3686        quote! {}
3687    } else {
3688        quote! {
3689            .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*])
3690        }
3691    };
3692
3693    quote! {{
3694        let __pharmsol_particles: usize = #particles;
3695        let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name)
3696            .kind(::pharmsol::equation::ModelKind::Sde)
3697            .parameters([#(stringify!(#params)),*])
3698            #covariate_metadata
3699            .states([#(stringify!(#states)),*])
3700            .outputs([#(stringify!(#outputs)),*])
3701            #(.route(#routes))*
3702            .particles(__pharmsol_particles);
3703
3704        ::pharmsol::equation::SDE::new(
3705            #drift,
3706            #diffusion,
3707            #lag,
3708            #fa,
3709            #init,
3710            #out,
3711            __pharmsol_particles,
3712        )
3713        .with_nstates(#nstates)
3714        .with_ndrugs(#ndrugs)
3715        .with_nout(#nout)
3716        #bolus_mappings
3717        .with_metadata(__pharmsol_metadata)
3718        .expect("declaration-first `sde!` generated invalid metadata")
3719    }}
3720    .into()
3721}
3722
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as `T`, requiring the parse to fail, and returns the
    /// rendered error message.
    ///
    /// # Panics
    ///
    /// Panics with `reason` if `source` parses successfully.
    fn parse_error<T: syn::parse::Parse>(source: &str, reason: &str) -> String {
        syn::parse_str::<T>(source).err().expect(reason).to_string()
    }

    /// Asserts that `message` contains `needle`.
    ///
    /// On failure, both strings are reported, so a reworded diagnostic is
    /// easy to spot.
    #[track_caller]
    fn assert_contains(message: &str, needle: &str) {
        assert!(
            message.contains(needle),
            "expected text containing {:?}, got: {}",
            needle,
            message
        );
    }

    #[test]
    fn rejects_removed_legacy_form() {
        let error = parse_error::<OdeInput>(
            "diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}",
            "legacy macro form must fail",
        );

        assert_contains(
            &error,
            "requires `name`, `params`, `states`, `outputs`, and `routes`",
        );
        assert_contains(&error, "old inferred-dimensions form has been removed");
    }

    #[test]
    fn validates_route_destinations() {
        let error = parse_error::<OdeInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
            "unknown route destination must fail",
        );

        assert_contains(
            &error,
            "route destination `peripheral` is not declared in the `states` section",
        );
    }

    #[test]
    fn rejects_named_binding_collisions() {
        let error = parse_error::<OdeInput>(
            "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
            "parameter/state binding collisions must fail",
        );

        assert_contains(
            &error,
            "named parameter binding `central` conflicts with named state binding",
        );
    }

    #[test]
    fn ode_route_bindings_share_inputs_by_kind_local_ordinal() {
        let input = syn::parse_str::<OdeInput>(
            "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}",
        )
        .expect("declaration-first ode input should parse");

        let bindings = ode_route_input_bindings(&input.routes);

        // Bolus and infusion routes are numbered independently, so `oral` and
        // `iv` both get input 0 and `sc` gets bolus input 1.
        assert_eq!(dense_index_len(&bindings), 2);
        assert_eq!(bindings[0].0.name(), "oral");
        assert_eq!(bindings[0].1, 0);
        assert_eq!(bindings[1].0.name(), "iv");
        assert_eq!(bindings[1].1, 0);
        assert_eq!(bindings[2].0.name(), "sc");
        assert_eq!(bindings[2].1, 1);
    }

    #[test]
    fn generated_parameter_bindings_only_include_referenced_locals_in_hot_closures() {
        let params = vec![generated_ident("ke"), generated_ident("v")];
        let closure = syn::parse_str::<ExprClosure>(
            "|x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }",
        )
        .expect("closure should parse");

        let bindings =
            generate_parameter_bindings(&params, &closure, &generated_ident("__pharmsol_p"))
                .to_string();

        // Accept both token-stream renderings of the index literal.
        assert!(
            bindings.contains("let ke = __pharmsol_p [0usize] ;")
                || bindings.contains("let ke = __pharmsol_p [ 0 ] ;"),
            "missing `ke` binding in: {}",
            bindings
        );
        assert!(
            !bindings.contains("let v ="),
            "unexpected `v` binding in: {}",
            bindings
        );
    }

    #[test]
    fn generated_parameter_bindings_fall_back_to_all_params_for_stmt_macros() {
        let params = vec![generated_ident("ka"), generated_ident("tlag")];
        let closure = syn::parse_str::<ExprClosure>("|_p, _t, _cov| { lag! { oral => tlag } }")
            .expect("closure should parse");

        let bindings =
            generate_parameter_bindings(&params, &closure, &generated_ident("__pharmsol_p"))
                .to_string();

        assert_contains(&bindings, "let ka =");
        assert_contains(&bindings, "let tlag =");
    }

    #[test]
    fn analytical_accepts_extra_parameters_beyond_kernel_arity() {
        let input = syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v, tlag, tvke], derived: [ke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}",
        )
        .expect("extra declared parameters should be allowed");

        assert_eq!(input.params.len(), 5);
        assert_eq!(input.derived.len(), 1);
        assert_eq!(input.covariates.len(), 2);
        assert!(input.derive.is_some());
        assert_eq!(input.states.len(), 2);
    }

    #[test]
    fn analytical_rejects_legacy_sec_with_migration_message() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = 1.0; }, out: |x, p, t, cov, y| {}",
            "legacy sec must fail",
        );

        assert_contains(
            &error,
            "no longer supports `sec`; use `derived: [...]` plus `derive: ...`",
        );
    }

    #[test]
    fn analytical_rejects_unknown_structure() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}",
            "unknown analytical structure must fail",
        );

        assert_contains(&error, "unknown analytical structure `mystery`");
    }

    #[test]
    fn analytical_rejects_missing_required_structure_name() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}",
            "missing required structure name must fail",
        );

        assert_contains(&error, "requires `ka`");
    }

    #[test]
    fn analytical_rejects_overlap_between_params_and_derived() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = 1.0; }, out: |x, p, t, cov, y| {}",
            "overlap must fail",
        );

        assert_contains(&error, "`ke` is declared in both `params` and `derived`");
    }

    #[test]
    fn analytical_rejects_invalid_derive_target() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke0 = 1.0; ke = 0.1; }, out: |x, p, t, cov, y| {}",
            "invalid derive target must fail",
        );

        assert_contains(&error, "`derive` cannot assign to `ke0`");
    }

    #[test]
    fn analytical_rejects_if_only_assignment_for_required_derived_name() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], covariates: [wt], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { if wt > 70.0 { ke = ke0; } }, out: |x, p, t, cov, y| {}",
            "bare if must fail",
        );

        assert_contains(&error, "not definitely assigned on every path");
    }

    #[test]
    fn analytical_accepts_if_else_assignment_for_required_derived_name() {
        syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], covariates: [wt], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { if wt > 70.0 { ke = ke0; } else { ke = ke0 * 0.5; } }, out: |x, p, t, cov, y| {}",
        )
        .expect("if / else should establish derived assignment");
    }

    #[test]
    fn analytical_rejects_loop_only_assignment_for_required_derived_name() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { for i in 0..1 { let _ = i; ke = ke0; } }, out: |x, p, t, cov, y| {}",
            "loop-only assignment must fail",
        );

        assert_contains(&error, "not definitely assigned on every path");
    }

    #[test]
    fn analytical_accepts_initial_assignment_followed_by_loop_updates() {
        syn::parse_str::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke0, v], derived: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, derive: |_t| { ke = ke0; for i in 0..2 { let _ = i; ke = ke + 1.0; } }, out: |x, p, t, cov, y| {}",
        )
        .expect("initial assignment plus loop updates should pass");
    }

    #[test]
    fn analytical_rejects_unknown_route_property_binding() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}",
            "unknown lag route must fail",
        );

        assert_contains(
            &error,
            "route `iv` in `lag!` is not declared in the `routes` section",
        );
    }

    #[test]
    fn analytical_rejects_infusion_lag_binding() {
        let error = parse_error::<AnalyticalInput>(
            "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}",
            "infusion lag must fail",
        );

        assert_contains(
            &error,
            "built-in `analytical!` does not allow `lag` on infusion route `iv`",
        );
    }

    #[test]
    fn sde_requires_particles() {
        let error = parse_error::<SdeInput>(
            "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}",
            "missing particles must fail",
        );

        assert_contains(
            &error,
            "missing required field `particles` in declaration-first `sde!`",
        );
    }

    #[test]
    fn sde_rejects_unknown_route_property_binding() {
        let error = parse_error::<SdeInput>(
            "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}",
            "unknown lag route must fail",
        );

        assert_contains(
            &error,
            "route `oral` in `lag!` is not declared in the `routes` section",
        );
    }

    #[test]
    fn sde_rejects_infusion_lag_binding() {
        let error = parse_error::<SdeInput>(
            "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}",
            "infusion lag must fail",
        );

        assert_contains(
            &error,
            "declaration-first `sde!` does not allow `lag` on infusion route `iv`",
        );
    }

    #[test]
    fn rejects_braced_route_lists() {
        let error = parse_error::<OdeInput>(
            "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}",
            "braced route lists must fail",
        );

        assert_contains(
            &error,
            "declaration-first macro `routes` must use `[...]`, not `{...}`",
        );
    }
}