Skip to main content

pipe_it/
lib.rs

1use std::{
2    any::{self, TypeId},
3    collections::HashMap,
4    future::Future,
5    ops::{Deref, DerefMut},
6    sync::Arc,
7};
8
9use tokio::sync;
10
11pub use pipeit_derive::node;
12
13pub extern crate pipeit_derive;
14
15pub mod cocurrency;
16pub mod ext;
17pub mod handler;
18pub mod tag;
19// store dynamic data
20/// A map storing shared dependencies, keyed by their TypeId.
21/// Wrapped in an Arc for thread-safe sharing and using Arc<RwLock> to support owned guards.
22#[derive(Clone)]
23pub(crate) struct DendencyMap(
24    Arc<HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>>,
25);
26
27/// A collection of shared resources that can be injected into pipeline functions.
28///
29/// This acts as a builder to conveniently register resources before creating a Context.
30#[derive(Default, Clone)]
31pub struct Shared {
32    inner: HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>,
33}
34
35impl Shared {
36    /// Creates a new, empty shared resource collection.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Adds a new resource to the collection.
42    /// The resource will be wrapped in an Arc<RwLock> to allow concurrent access.
43    pub fn insert<T: Send + Sync + 'static>(mut self, resource: T) -> Self {
44        self.inner.insert(
45            TypeId::of::<T>(),
46            Arc::new(sync::RwLock::new(
47                Box::new(resource) as Box<dyn any::Any + Send + Sync>
48            )),
49        );
50        self
51    }
52}
53
54/// A read-only resource container that mirrors Bevy's Res.
55/// It holds an owned read lock on a shared dependency and provides strong typing via Deref.
56pub struct Res<T>(
57    sync::OwnedRwLockReadGuard<Box<dyn any::Any + Send + Sync>>,
58    std::marker::PhantomData<T>,
59);
60
61/// A mutable resource container that mirrors Bevy's ResMut.
62/// It holds an owned write lock on a shared dependency and provides strong typing via Deref/DerefMut.
63pub struct ResMut<T>(
64    sync::OwnedRwLockWriteGuard<Box<dyn any::Any + Send + Sync>>,
65    std::marker::PhantomData<T>,
66);
67
68impl<T: 'static> Deref for Res<T> {
69    type Target = T;
70    fn deref(&self) -> &Self::Target {
71        // Here we perform the downcast from dyn Any to the strong type T.
72        // This is exactly how Bevy provides strong typing from a generic container.
73        (**self.0)
74            .downcast_ref::<T>()
75            .expect("Resource type mismatch during Res deref")
76    }
77}
78
79impl<T: 'static> Deref for ResMut<T> {
80    type Target = T;
81    fn deref(&self) -> &Self::Target {
82        (**self.0)
83            .downcast_ref::<T>()
84            .expect("Resource type mismatch during ResMut deref")
85    }
86}
87
88impl<T: 'static> DerefMut for ResMut<T> {
89    fn deref_mut(&mut self) -> &mut Self::Target {
90        // Performs mutable downcast to provide &mut T.
91        (**self.0)
92            .downcast_mut::<T>()
93            .expect("Resource type mismatch during ResMut deref_mut")
94    }
95}
96
97/// The execution context for a pipeline, carrying both shared dependencies and the specific input.
98pub struct Context<Input>
99where
100    Input: ?Sized,
101{
102    shared: DendencyMap,
103    input: Arc<Input>,
104}
105
106impl<I> Context<I> {
107    /// Creates a new context with the provided input and shared dependency collection.
108    pub fn new(input: I, shared: Shared) -> Self {
109        Self {
110            shared: DendencyMap(Arc::new(shared.inner)),
111            input: Arc::new(input),
112        }
113    }
114    pub fn empty(input: I) -> Self {
115        Self {
116            shared: DendencyMap(Arc::new(Shared::new().inner)),
117            input: Arc::new(input),
118        }
119    }
120    /// Replaces the input of the current context while keeping the same shared dependencies.
121    /// Returns a new Context with the updated input type.
122    pub(crate) fn replace<NewInput>(self, input: NewInput) -> Context<NewInput> {
123        Context {
124            shared: self.shared,
125            input: Arc::new(input),
126        }
127    }
128
129    /// Returns a reference to the current input wrapped in an Arc.
130    pub(crate) fn input(&self) -> Arc<I> {
131        self.input.clone()
132    }
133
134    /// Consumes the Context and returns the input Arc and the shared dependencies.
135    /// This allows avoiding the clone of the input if the Context is no longer needed.
136    pub(crate) fn into_parts(self) -> (Arc<I>, DendencyMap) {
137        (self.input, self.shared)
138    }
139
140    /// Reconstructs a Context from its parts.
141    pub(crate) fn from_parts(input: Arc<I>, shared: DendencyMap) -> Self {
142        Self { shared, input }
143    }
144    /// # Example
145    ///
146    /// ```rust
147    /// use pipe_it::{Context, Input, Res, ResMut, Shared};
148    ///
149    /// #[derive(Debug, Clone)]
150    /// struct Counter {
151    ///     c: i32,
152    /// }
153    ///
154    /// #[tokio::main]
155    /// async fn main() {
156    ///     let ctx = Context::new(3, Shared::new().insert(Counter { c: 1 }));
157    ///     ctx.invoke(async |x: Input<i32>, mut counter: ResMut<Counter>| {
158    ///         counter.c += 1;
159    ///         *x + 1
160    ///     })
161    ///     .await
162    ///     .invoke(async |x: Input<i32>, counter: Res<Counter>| *x + counter.c)
163    ///     .await
164    ///     .invoke(async |x: Input<i32>| assert_eq!(*x, 6))
165    ///     .await;
166    /// }
167    /// ```
168    pub async fn invoke<H, Args, O>(self, handler: H) -> Context<O>
169    where
170        H: handler::Handler<I, O, Args>,
171        I: Clone + Send + Sync + 'static,
172        O: Send + Sync + 'static,
173        Args: Send + Sync + 'static,
174    {
175        let (input, shared) = self.into_parts();
176        let shared_cloned = shared.clone();
177        // move input to avoid increasing ref counter
178        let new_ctx = Context {
179            input,
180            shared: shared_cloned,
181        };
182        let output = handler.call(new_ctx).await;
183        Context {
184            shared: shared,
185            input: Arc::new(output),
186        }
187    }
188}
189impl<Input> Clone for Context<Input>
190where
191    Input: ?Sized,
192{
193    fn clone(&self) -> Self {
194        Self {
195            shared: self.shared.clone(),
196            input: self.input.clone(),
197        }
198    }
199}
200
201/// A trait for types that can be extracted from a Context.
202/// Similar to Axum's FromRequest or Bevy's SystemParam.
203pub trait FromContext<Input>: Send {
204    fn from(ctx: Context<Input>) -> impl Future<Output = Self> + Send;
205}
206
/// A wrapper for the primary input data of the pipeline.
///
/// The value is shared behind an `Arc`, so extracting it from a context does not copy
/// the data. `Input<T>` dereferences to `T`.
pub struct Input<T>(pub Arc<T>);

// Implemented by hand rather than derived: `#[derive(Clone)]` would require
// `T: Clone`, although cloning an `Input` only bumps the `Arc` reference count.
impl<T> Clone for Input<T> {
    fn clone(&self) -> Self {
        Input(Arc::clone(&self.0))
    }
}

impl<T> Input<T> {
    /// Returns the inner value if this `Input` holds the only strong reference to it.
    ///
    /// Moving the value out this way avoids cloning it.
    ///
    /// # Errors
    ///
    /// Returns the original `Arc` unchanged when other strong references are still
    /// alive (for example other clones of this `Input`, or the context it came from).
    pub fn try_unwrap(self) -> Result<T, Arc<T>> {
        Arc::try_unwrap(self.0)
    }
}

impl<T> Deref for Input<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
225
226impl<I: Send + Sync + 'static> FromContext<I> for Input<I> {
227    fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
228        let input = ctx.input.clone();
229        async move { Input(input) }
230    }
231}
232
233impl<I, T> FromContext<I> for Res<T>
234where
235    I: Send + Sync + 'static,
236    T: Send + Sync + 'static,
237{
238    fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
239        async move {
240            let shared = ctx.shared.0.clone();
241            let dep = shared
242                .get(&TypeId::of::<T>())
243                .expect("Dependency not found")
244                .clone();
245            let guard = dep.read_owned().await;
246            Res(guard, std::marker::PhantomData)
247        }
248    }
249}
250
251impl<I, T> FromContext<I> for ResMut<T>
252where
253    I: Send + Sync + 'static,
254    T: Send + Sync + 'static,
255{
256    fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
257        async move {
258            let shared = ctx.shared.0.clone();
259            let dep = shared
260                .get(&TypeId::of::<T>())
261                .expect("Dependency not found")
262                .clone();
263            let guard = dep.write_owned().await;
264            ResMut(guard, std::marker::PhantomData)
265        }
266    }
267}
268
269/// Represents a pipeline unit that can be applied to a Context.
270/// Implementations are automatically provided for functions that match the signature.
271pub trait Pipeline<I, O>: Send + Sync + 'static {
272    fn apply(&self, ctx: Context<I>) -> impl Future<Output = O> + Send;
273}
274
/// Errors occurring during pipeline execution.
///
/// The variants differ in how choice pipelines (tuples) react to them: a `Failure`
/// makes the choice fall through to its next branch, while a `Fatal` error aborts the
/// choice immediately.
#[derive(Debug, Clone)]
pub enum PipelineError {
    /// A recoverable mismatch, e.g. a `cond` predicate that returned `false`.
    Failure {
        /// Human-readable description of what went wrong.
        msg: String,
        /// Description of what was expected instead.
        expected: String,
    },
    /// An unrecoverable error that stops the surrounding choice.
    Fatal {
        /// Human-readable description of what went wrong.
        msg: String,
    },
}

impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Failure { msg, expected } => {
                write!(f, "pipeline failure: {msg} (expected {expected})")
            }
            Self::Fatal { msg } => write!(f, "fatal pipeline error: {msg}"),
        }
    }
}

// Lets `PipelineError` be used with `?` into `Box<dyn Error>` and error-handling crates.
impl std::error::Error for PipelineError {}

/// Standard Result type for pipeline operations.
pub type PResult<O, E = PipelineError> = Result<O, E>;
284
/// A pipeline that runs `next` only when `predicate` holds for the current context.
///
/// Built by `cond`. Applying it yields `Ok` with `next`'s output when the predicate
/// returns `true`, and `Err(PipelineError::Failure { .. })` otherwise.
pub struct Cond<F, P, I, O> {
    /// Pipeline deciding whether `next` runs.
    predicate: F,
    /// Pipeline executed when the predicate returns `true`.
    next: P,
    // `fn(I, O)` ties `I` and `O` to the type without storing them; fn pointers are
    // always `Send + Sync`, so the marker adds no auto-trait requirements.
    _marker: std::marker::PhantomData<fn(I, O)>,
}
291
292impl<F, P, I, O> Pipeline<I, PResult<O>> for Cond<F, P, I, O>
293where
294    F: Pipeline<I, bool>,
295    P: Pipeline<I, O>,
296    I: Clone + Send + Sync + 'static,
297    O: Send + 'static,
298    F: Send + Sync + 'static,
299    P: Send + Sync + 'static,
300{
301    fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
302        async move {
303            let matched = self.predicate.apply(ctx.clone()).await;
304            if matched {
305                Ok(self.next.apply(ctx).await)
306            } else {
307                Err(PipelineError::Failure {
308                    msg: "Condition not met".to_string(),
309                    expected: "true".to_string(),
310                })
311            }
312        }
313    }
314}
315
316/// Creates a conditional pipeline.
317/// If the predicate returns true, the next pipeline is executed.
318/// If false, it returns a PipelineError::Failure.
319///
320/// # Example
321///
322/// ```rust
323/// use pipe_it::{cond, Context, Pipeline, Input, ext::HandlerExt};
324///
325/// async fn is_even(n: Input<i32>) -> bool { *n % 2 == 0 }
326/// async fn process(n: Input<i32>) -> String { "Even".to_string() }
327///
328/// # #[tokio::main]
329/// # async fn main() {
330/// let pipe = cond(is_even, process);
331///
332/// // Success case
333/// let result = pipe.apply(Context::empty(2)).await;
334/// assert_eq!(result.unwrap(), "Even");
335///
336/// // Failure case
337/// let result = pipe.apply(Context::empty(1)).await;
338/// assert!(result.is_err());
339/// # }
340/// ```
341pub fn cond<I, O, F, P, ArgsF, ArgsP>(
342    predicate: F,
343    next: P,
344) -> Cond<crate::ext::Pipe<F, ArgsF>, crate::ext::Pipe<P, ArgsP>, I, O>
345where
346    F: crate::handler::Handler<I, bool, ArgsF>,
347    P: crate::handler::Handler<I, O, ArgsP>,
348    I: Clone + Send + Sync + 'static,
349    O: Send + 'static,
350    ArgsF: Send + Sync + 'static,
351    ArgsP: Send + Sync + 'static,
352{
353    use crate::ext::HandlerExt;
354    Cond {
355        predicate: predicate.pipe(),
356        next: next.pipe(),
357        _marker: std::marker::PhantomData,
358    }
359}
360
/// Marker wrapper for choice-based pipeline execution.
///
/// NOTE(review): `Choice` is not referenced elsewhere in this file; the choice
/// semantics visible here are implemented directly on tuples of pipelines. Its
/// intended payload (presumably such a tuple) should be confirmed against its users.
pub struct Choice<T>(pub T);
363
364macro_rules! impl_pipeline_for_tuple {
365    ($($P:ident),+) => {
366        impl<I, O, $($P),+ > Pipeline<I, PResult<O>> for ($($P,)+)
367        where
368            I: Clone + Send + Sync + 'static,
369            O: Send + 'static,
370            $( $P: Pipeline<I, PResult<O>> ),+
371        {
372            fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
373                #[allow(non_snake_case)]
374                let ($($P,)+) = self;
375                async move {
376                    $(
377                        match $P.apply(ctx.clone()).await {
378                            Ok(res) => return Ok(res),
379                            Err(PipelineError::Fatal { msg }) => return Err(PipelineError::Fatal { msg }),
380                            Err(PipelineError::Failure { .. }) => {}
381                        }
382                    )*
383                    Err(PipelineError::Fatal { msg: "All pipeline branches failed".to_string() })
384                }
385            }
386        }
387    };
388}
389
390impl_pipeline_for_tuple!(P1);
391impl_pipeline_for_tuple!(P1, P2);
392impl_pipeline_for_tuple!(P1, P2, P3);
393impl_pipeline_for_tuple!(P1, P2, P3, P4);
394impl_pipeline_for_tuple!(P1, P2, P3, P4, P5);
395impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6);
396impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6, P7);
397impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6, P7, P8);
398
399/// An identity pipeline that simply returns the current input.
400/// This acts as a neutral element in pipeline composition.
401pub async fn identity<I: Clone>(input: Input<I>) -> I {
402    (*input).clone()
403}
404
405/// An identity pipeline that returns the current input wrapped in a successful Result.
406/// Useful as a terminal fallback in a choice-based pipeline (tuple).
407pub async fn identity_res<I: Clone>(input: Input<I>) -> PResult<I> {
408    Ok((*input).clone())
409}
410
#[cfg(test)]
mod tests {
    use crate::{Context, Input, Res, ResMut, Shared};

    /// Resource shared between the handlers of the test pipeline.
    #[derive(Debug, Clone)]
    struct Counter {
        c: i32,
    }

    /// Exercises `Context::invoke` chaining. The first handler bumps the shared
    /// counter through `ResMut` and increments the input (3 -> 4). The second adds the
    /// counter read through `Res` (4 + 2). The last checks the result is 6.
    #[tokio::test]
    async fn test_chain_invoke() {
        let resources = Shared::new().insert(Counter { c: 1 });
        let incremented = Context::new(3, resources)
            .invoke(async |value: Input<i32>, mut counter: ResMut<Counter>| {
                counter.c += 1;
                *value + 1
            })
            .await;
        let summed = incremented
            .invoke(async |value: Input<i32>, counter: Res<Counter>| *value + counter.c)
            .await;
        summed
            .invoke(async |value: Input<i32>| assert_eq!(*value, 6))
            .await;
    }
}