Skip to main content

asupersync_macros/
lib.rs

1//! Proc macros for asupersync structured concurrency runtime.
2//!
3//! This crate provides procedural macros that simplify working with the asupersync
4//! async runtime's structured concurrency primitives. The macros handle the boilerplate
5//! for creating scopes, spawning tasks, joining results, and racing computations.
6//!
7//! # Available Macros
8//!
9//! - [`scope!`] - Create a structured concurrency scope
10//! - [`spawn!`] - Spawn a task within the current scope
11//! - [`join!`] - Join multiple futures, waiting for all to complete
12//! - [`join_all!`] - Join multiple futures into an array
13//! - [`race!`] - Race multiple futures, returning the first to complete
14//! - [`session_protocol!`] - Generate typestate session protocols
//! - [`conformance`] - Annotate conformance tests
//! - [`instrument`] - Wrap a function or method in a tracing span
16//!
17//! # Contract With `asupersync`
18//!
19//! The root `asupersync` crate re-exports only the supported runtime DSL:
20//! `scope!`, `spawn!`, `join!`, `join_all!`, and `race!`, and only when the
21//! `proc-macros` feature is enabled.
22//!
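//! This crate also defines `session_protocol!`, `#[conformance]`, and
//! `#[instrument]`, none of which are part of the default root macro contract:
//! `session_protocol!` and `#[conformance]` remain explicit-path macros on
//! `asupersync_macros`, and `#[instrument]` is used via
//! `asupersync::tracing_compat`.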
26//!
27//! # Example
28//!
29//! ```ignore
30//! use asupersync_macros::{scope, spawn, join, race};
31//!
32//! async fn example(cx: &Cx, state: &mut RuntimeState) {
33//!     scope!(cx, state: state, {
34//!         let handle1 = spawn!(async { compute_a().await });
35//!         let handle2 = spawn!(async { compute_b().await });
36//!
37//!         // Wait for both
38//!         let (result_a, result_b) = join!(handle1, handle2);
39//!     });
40//! }
41//! ```
42
43mod instrument;
44mod join;
45mod race;
46mod scope;
47mod session;
48mod spawn;
49mod util;
50
51use proc_macro::TokenStream;
52
53/// Creates a structured concurrency scope.
54///
55/// The `scope!` macro creates a [`Scope`](asupersync::Scope) binding for the
56/// current `Cx` region and makes it available as `scope` inside the body.
57///
58/// Today this is an ergonomic binding helper, not a fresh child-region
59/// boundary. For actual child-region ownership and quiescence, call
60/// [`Scope::region`](asupersync::Scope::region) explicitly.
61///
62/// # Syntax
63///
64/// ```ignore
65/// scope!(cx, {
66///     // body with spawned tasks
67/// })
68/// scope!(cx, state: &mut state, {
69///     let _child = spawn!(async { work().await });
70/// })
71/// ```
72///
73/// # Arguments
74///
75/// - `cx` - The capability context (`&Cx`)
76/// - `body` - A block containing the scope's work
77/// - `state` - Optional runtime state binding used by nested `spawn!` calls
78///
79/// # Returns
80///
81/// The result of the scope body.
82///
83/// # Example
84///
85/// ```ignore
86/// scope!(cx, state: &mut state, {
87///     spawn!(async { work_a().await });
88///     spawn!(async { work_b().await });
89///     // Both tasks are awaited before scope exits
90/// })
91/// ```
92#[proc_macro]
93pub fn scope(input: TokenStream) -> TokenStream {
94    scope::scope_impl(input)
95}
96
97/// Spawns a task within the current scope.
98///
99/// The `spawn!` macro expands to [`Scope::spawn_registered`], so it requires
100/// ambient `__state` and `__cx` bindings in addition to the target `Scope`.
101///
102/// The easiest supported path is to use it inside `scope!(..., state: ..., { ... })`.
103///
104/// # Syntax
105///
106/// ```ignore
107/// spawn!(async { /* work */ })
108/// spawn!(async move { /* work with captured values */ })
109/// ```
110///
111/// # Returns
112///
113/// A `TaskHandle` that can be awaited to get the task's result.
114///
115/// # Example
116///
117/// ```ignore
118/// let handle = spawn!(async {
119///     expensive_computation().await
120/// });
121/// let result = handle.await;
122/// ```
123#[proc_macro]
124pub fn spawn(input: TokenStream) -> TokenStream {
125    spawn::spawn_impl(input)
126}
127
128/// Joins multiple futures, waiting for all to complete.
129///
130/// The `join!` macro is a supported proc-macro convenience surface, but the
131/// current implementation still awaits branches sequentially. It preserves
132/// left-to-right evaluation and tuple ordering today; parallel polling remains
133/// future work.
134///
135/// # Syntax
136///
137/// ```ignore
138/// join!(future1, future2, ...)
139/// ```
140///
141/// # Returns
142///
143/// A tuple of all the futures' results in the order they were specified.
144///
145/// # Outcome Semantics
146///
147/// The combined outcome follows the severity lattice:
148/// - If all succeed: `Outcome::Ok((r1, r2, ...))`
149/// - If any fails: the most severe outcome is propagated
150///
151/// # Example
152///
153/// ```ignore
154/// let (a, b, c) = join!(
155///     fetch_user().await,
156///     fetch_profile().await,
157///     fetch_settings().await
158/// );
159/// ```
160#[proc_macro]
161pub fn join(input: TokenStream) -> TokenStream {
162    join::join_impl(input)
163}
164
165/// Joins multiple futures into an array, waiting for all to complete.
166///
167/// The `join_all!` macro is like `join!` but returns an array instead of a
168/// tuple. Like `join!`, the current implementation still awaits branches
169/// sequentially.
170///
171/// # Syntax
172///
173/// ```ignore
174/// join_all!(future1, future2, ...)
175/// ```
176///
177/// # Returns
178///
179/// An array of all the futures' results in the order they were specified.
180/// Since all results must be the same type, this enables easier iteration.
181///
182/// # Example
183///
184/// ```ignore
185/// let results: [i32; 3] = join_all!(
186///     fetch_value(1).await,
187///     fetch_value(2).await,
188///     fetch_value(3).await
189/// );
190/// for result in results {
191///     println!("{}", result);
192/// }
193/// ```
194#[proc_macro]
195pub fn join_all(input: TokenStream) -> TokenStream {
196    join::join_all_impl(input)
197}
198
199/// Races multiple futures, returning the first to complete.
200///
201/// The `race!` macro expands to the inline [`Cx::race*`](asupersync::Cx::race)
202/// family. The losing futures are cancelled by drop, but they are not drained.
203///
204/// If you need the stronger "losers are drained" invariant, race spawned tasks
205/// with [`Scope::race`](asupersync::Scope::race) instead.
206///
207/// # Syntax
208///
209/// ```ignore
210/// race!(cx, { future1, future2, ... })
211/// race!(cx, { "name" => future1, "other" => future2, ... })
212/// race!(cx, timeout: Duration::from_secs(5), { future1, future2, ... })
213/// ```
214///
215/// # Returns
216///
217/// The result of the winning future.
218///
219/// # Loser Cleanup
220///
221/// All non-winning futures are dropped, which requests cancellation for inline
222/// futures but does not await their cleanup path.
223///
224/// # Example
225///
226/// ```ignore
227/// let result = race!(cx, {
228///     primary_service.fetch().await,
229///     backup_service.fetch().await,
230/// });
231/// // One completed; the loser was cancelled by drop but not drained.
232/// ```
233#[proc_macro]
234pub fn race(input: TokenStream) -> TokenStream {
235    race::race_impl(input)
236}
237
238/// Instruments a function or impl method with a tracing span.
239///
240/// The generated wrapper uses `asupersync::tracing_compat`, so it creates real
241/// spans when `tracing-integration` is enabled and becomes a no-op when tracing
242/// is disabled.
243///
244/// Supported arguments:
245///
246/// - `name = "custom_name"` overrides the span name
247/// - `level = "trace" | "debug" | "info" | "warn" | "error"` sets span level
248/// - `skip(arg1, arg2, ...)` excludes arguments from captured fields
249///
250/// # Examples
251///
252/// ```ignore
253/// use asupersync::tracing_compat::instrument;
254///
255/// #[instrument]
256/// async fn load_user(user_id: u64) -> Result<(), Error> {
257///     Ok(())
258/// }
259///
260/// #[instrument(name = "cache_refresh", level = "debug", skip(secret))]
261/// fn refresh(secret: &Secret, key: &str) {}
262/// ```
263#[proc_macro_attribute]
264pub fn instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
265    instrument::instrument_impl(attr, item)
266}
267
268/// Marks a test with the specification section and requirement it validates.
269///
270/// # Syntax
271///
272/// ```ignore
273/// #[conformance(spec = "3.2.1", requirement = "Region close waits for all children")]
274/// #[test]
275/// fn test_region_close_waits() { /* ... */ }
276/// ```
277///
278/// The macro is validation-only: it checks that `spec` and `requirement` are
279/// present and string literals, then leaves the item unchanged.
280#[proc_macro_attribute]
281pub fn conformance(attr: TokenStream, item: TokenStream) -> TokenStream {
282    match parse_conformance_args(&attr) {
283        Ok(_) => item,
284        Err(message) => util::compile_error(&message).into(),
285    }
286}
287
288/// Generates typestate-encoded session types from a protocol DSL.
289///
290/// The macro takes a protocol specification and generates a module containing
291/// message structs, paired session type aliases (initiator + responder), and
292/// constructor functions. The responder type is the dual of the initiator:
293/// `Send`↔`Recv`, `Select`↔`Offer`.
294///
295/// # Syntax
296///
297/// ```ignore
298/// session_protocol! {
299///     module_name<T> for ObligationVariant {
300///         msg MessageName;
301///         msg MessageWithFields { field: Type };
302///
303///         send MessageName => select {
304///             send T => end,
305///             send OtherMsg => end,
306///         }
307///     }
308/// }
309/// ```
310///
311/// # Body Actions
312///
313/// - `send Type => body` — send a value, then continue
314/// - `recv Type => body` — receive a value, then continue
315/// - `select { a, b }` — local choice (becomes `Offer` for responder)
316/// - `offer { a, b }` — remote choice (becomes `Select` for responder)
317/// - `loop { body }` — recursion point (generates `renew_loop` constructor)
318/// - `continue` — jump back to enclosing `loop`
319/// - `end` — protocol termination
320///
321/// # Generated Items
322///
323/// - `pub mod <name>` containing:
324///   - Message structs with `Debug, Clone` (+ `Copy` for unit structs)
325///   - `InitiatorSession` type alias
326///   - `ResponderSession` type alias
327///   - `new_session(channel_id) -> (Chan<Initiator, ...>, Chan<Responder, ...>)`
328///   - (if `loop` used) `InitiatorLoop`, `ResponderLoop` type aliases
329///   - (if `loop` used) `renew_loop(channel_id)` constructor
330///
331/// # Example
332///
333/// ```ignore
334/// session_protocol! {
335///     lease for Lease {
336///         msg AcquireMsg;
337///         msg RenewMsg;
338///         msg ReleaseMsg;
339///
340///         send AcquireMsg => loop {
341///             select {
342///                 send RenewMsg => continue,
343///                 send ReleaseMsg => end,
344///             }
345///         }
346///     }
347/// }
348/// ```
349#[proc_macro]
350pub fn session_protocol(input: TokenStream) -> TokenStream {
351    session::session_protocol_impl(input)
352}
353
/// Validated arguments of a `#[conformance(...)]` attribute.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ConformanceArgs {
    /// Unescaped contents of the `spec = "..."` literal.
    spec: String,
    /// Unescaped contents of the `requirement = "..."` literal.
    requirement: String,
}
359
360fn parse_conformance_args(attr: &TokenStream) -> Result<ConformanceArgs, String> {
361    parse_conformance_args_str(&attr.to_string())
362}
363
364fn parse_conformance_args_str(input: &str) -> Result<ConformanceArgs, String> {
365    let raw = input.trim();
366    if raw.is_empty() {
367        return Err("conformance attribute requires arguments".to_string());
368    }
369
370    let mut spec = None;
371    let mut requirement = None;
372
373    for part in split_args(raw) {
374        let part = part.trim();
375        if part.is_empty() {
376            continue;
377        }
378        let (key, value) = split_key_value(part)?;
379        let value = parse_string_literal(value)?;
380        match key {
381            "spec" => spec = Some(value),
382            "requirement" => requirement = Some(value),
383            other => {
384                return Err(format!(
385                    "conformance attribute has unknown key '{other}', expected 'spec' or 'requirement'"
386                ));
387            }
388        }
389    }
390
391    let spec = spec.ok_or_else(|| "conformance attribute missing 'spec'".to_string())?;
392    let requirement =
393        requirement.ok_or_else(|| "conformance attribute missing 'requirement'".to_string())?;
394
395    Ok(ConformanceArgs { spec, requirement })
396}
397
/// Splits an attribute argument list on top-level commas.
///
/// Commas inside double-quoted string literals do not split. This includes
/// literals containing backslash-escaped quotes. Parts are returned
/// untrimmed. Empty parts between consecutive commas are kept, but a final
/// part made only of whitespace is dropped.
fn split_args(input: &str) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut start = 0;
    let mut quoted = false;
    let mut escaped = false;

    for (idx, c) in input.char_indices() {
        if !quoted {
            match c {
                ',' => {
                    pieces.push(input[start..idx].to_string());
                    // `,` is a single byte, so the next piece starts right after it.
                    start = idx + 1;
                }
                '"' => quoted = true,
                _ => {}
            }
        } else if escaped {
            escaped = false;
        } else if c == '\\' {
            escaped = true;
        } else if c == '"' {
            quoted = false;
        }
    }

    let tail = &input[start..];
    if !tail.trim().is_empty() {
        pieces.push(tail.to_string());
    }
    pieces
}
438
/// Splits one `key = value` pair at its first `=`, trimming both sides.
///
/// # Errors
///
/// Fails if the key is empty. Also fails if there is no `=` or nothing
/// non-blank follows it.
fn split_key_value(input: &str) -> Result<(&str, &str), String> {
    let (raw_key, raw_value) = match input.split_once('=') {
        Some((k, v)) => (k, Some(v)),
        None => (input, None),
    };

    let key = raw_key.trim();
    if key.is_empty() {
        return Err("conformance attribute expects key = \"value\" pairs".to_string());
    }

    let value = raw_value.map_or("", str::trim);
    if value.is_empty() {
        return Err(format!("conformance attribute missing value for '{key}'"));
    }

    Ok((key, value))
}
453
/// Parses a double-quoted Rust string literal into its unescaped contents.
///
/// `input` is the literal as written in source; surrounding whitespace is
/// ignored. The escapes Rust allows in ordinary string literals are supported:
/// - `\\`, `\"`, `\'`;
/// - `\n`, `\r`, `\t`, `\0`;
/// - `\xHH` (ASCII range only);
/// - `\u{H..}` (1–6 hex digits, underscores allowed after the first);
/// - line continuation: a backslash before a newline drops the newline and
///   any following ASCII whitespace.
///
/// Raw strings and byte strings are not accepted.
///
/// # Errors
///
/// Returns a human-readable message when `input` is not a double-quoted
/// literal. It also fails on an unescaped `"` before the closing quote, a
/// dangling `\`, or an unsupported or malformed escape.
fn parse_string_literal(input: &str) -> Result<String, String> {
    let trimmed = input.trim();
    let not_a_literal =
        || format!("conformance attribute values must be string literals, got: {trimmed}");

    // Stripping each quote separately rejects a lone `"` instead of slicing
    // with an inverted (panicking) range.
    let inner = trimmed
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .ok_or_else(not_a_literal)?;

    let mut out = String::with_capacity(inner.len());
    let mut chars = inner.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            // An unescaped quote means the input was several tokens
            // (e.g. `"a" "b"`), not a single literal.
            '"' => return Err(not_a_literal()),
            '\\' => {
                let next = chars.next().ok_or_else(|| {
                    "conformance attribute contains dangling escape sequence".to_string()
                })?;
                match next {
                    '\\' => out.push('\\'),
                    '"' => out.push('"'),
                    '\'' => out.push('\''),
                    'n' => out.push('\n'),
                    'r' => out.push('\r'),
                    't' => out.push('\t'),
                    '0' => out.push('\0'),
                    'x' => out.push(parse_hex_escape(&mut chars)?),
                    'u' => out.push(parse_unicode_escape(&mut chars)?),
                    // Line continuation: skip the newline and following indentation.
                    '\n' => {
                        while chars
                            .next_if(|&c| matches!(c, ' ' | '\t' | '\n' | '\r'))
                            .is_some()
                        {}
                    }
                    other => {
                        return Err(format!(
                            "conformance attribute contains unsupported escape: \\{other}"
                        ));
                    }
                }
            }
            _ => out.push(ch),
        }
    }
    Ok(out)
}

/// Decodes the two hex digits following `\x`; in string literals Rust limits
/// this escape to the ASCII range (`\x00`..=`\x7F`).
fn parse_hex_escape(chars: &mut impl Iterator<Item = char>) -> Result<char, String> {
    let digits: String = chars.take(2).collect();
    let value = if digits.len() == 2 && digits.chars().all(|c| c.is_ascii_hexdigit()) {
        u8::from_str_radix(&digits, 16).ok()
    } else {
        None
    };
    match value {
        Some(byte) if byte.is_ascii() => Ok(char::from(byte)),
        _ => Err(format!(
            "conformance attribute contains invalid escape: \\x{digits}"
        )),
    }
}

/// Decodes the `{H..}` part of a `\u{...}` escape into a Unicode scalar value.
fn parse_unicode_escape(chars: &mut impl Iterator<Item = char>) -> Result<char, String> {
    let invalid = || "conformance attribute contains invalid unicode escape".to_string();
    if chars.next() != Some('{') {
        return Err(invalid());
    }
    let mut digits = String::new();
    loop {
        match chars.next() {
            Some('}') => break,
            Some(c) => digits.push(c),
            None => return Err(invalid()),
        }
    }
    // Rust permits `_` separators, but not as the first character.
    let hex: String = digits.chars().filter(|&c| c != '_').collect();
    if digits.starts_with('_')
        || hex.is_empty()
        || hex.len() > 6
        || !hex.chars().all(|c| c.is_ascii_hexdigit())
    {
        return Err(invalid());
    }
    u32::from_str_radix(&hex, 16)
        .ok()
        .and_then(char::from_u32)
        .ok_or_else(invalid)
}
487
#[cfg(test)]
mod tests {
    use super::{parse_conformance_args_str, ConformanceArgs};

    #[test]
    fn parse_conformance_args_ok() {
        let args =
            parse_conformance_args_str(r#"spec = "3.2.1", requirement = "Region close waits""#)
                .unwrap();
        assert_eq!(args.spec, "3.2.1");
        assert_eq!(args.requirement, "Region close waits");
    }

    #[test]
    fn parse_conformance_args_missing_spec() {
        let err = parse_conformance_args_str(r#"requirement = "Region close waits""#).unwrap_err();
        assert!(err.contains("missing 'spec'"));
    }

    #[test]
    fn parse_conformance_args_missing_requirement() {
        let err = parse_conformance_args_str(r#"spec = "3.2.1""#).unwrap_err();
        assert!(err.contains("missing 'requirement'"));
    }

    #[test]
    fn parse_conformance_args_empty_input() {
        let err = parse_conformance_args_str("   ").unwrap_err();
        assert!(err.contains("requires arguments"));
    }

    #[test]
    fn parse_conformance_args_any_order_and_trailing_comma() {
        let args = parse_conformance_args_str(r#"requirement = "r", spec = "s","#).unwrap();
        assert_eq!(
            args,
            ConformanceArgs {
                spec: "s".to_string(),
                requirement: "r".to_string(),
            }
        );
    }

    #[test]
    fn parse_conformance_args_comma_inside_value() {
        let args = parse_conformance_args_str(r#"spec = "1, 2", requirement = "a, b""#).unwrap();
        assert_eq!(args.spec, "1, 2");
        assert_eq!(args.requirement, "a, b");
    }

    #[test]
    fn parse_conformance_args_unescapes_value() {
        let args =
            parse_conformance_args_str(r#"spec = "3.2.1", requirement = "say \"hi\"\tnow""#)
                .unwrap();
        assert_eq!(args.requirement, "say \"hi\"\tnow");
    }

    #[test]
    fn parse_conformance_args_unknown_key() {
        let err =
            parse_conformance_args_str(r#"spec = "1", section = "2", requirement = "r""#)
                .unwrap_err();
        assert!(err.contains("unknown key 'section'"));
    }

    #[test]
    fn parse_conformance_args_rejects_non_literal_value() {
        let err = parse_conformance_args_str(r#"spec = 321, requirement = "r""#).unwrap_err();
        assert!(err.contains("must be string literals"));
    }

    #[test]
    fn parse_conformance_args_missing_value() {
        let err = parse_conformance_args_str(r#"spec =, requirement = "r""#).unwrap_err();
        assert!(err.contains("missing value for 'spec'"));
    }
}