1use std::{
2 any::{self, TypeId},
3 collections::HashMap,
4 future::Future,
5 ops::{Deref, DerefMut},
6 sync::Arc,
7};
8
9use tokio::sync;
10
11pub use pipeit_derive::node;
12
13pub extern crate pipeit_derive;
14
15pub mod cocurrency;
16pub mod ext;
17pub mod handler;
18pub mod tag;
19#[derive(Clone)]
23pub(crate) struct DendencyMap(
24 Arc<HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>>,
25);
26
27#[derive(Default, Clone)]
31pub struct Shared {
32 inner: HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>,
33}
34
35impl Shared {
36 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn insert<T: Send + Sync + 'static>(mut self, resource: T) -> Self {
44 self.inner.insert(
45 TypeId::of::<T>(),
46 Arc::new(sync::RwLock::new(
47 Box::new(resource) as Box<dyn any::Any + Send + Sync>
48 )),
49 );
50 self
51 }
52}
53
54pub struct Res<T>(
57 sync::OwnedRwLockReadGuard<Box<dyn any::Any + Send + Sync>>,
58 std::marker::PhantomData<T>,
59);
60
61pub struct ResMut<T>(
64 sync::OwnedRwLockWriteGuard<Box<dyn any::Any + Send + Sync>>,
65 std::marker::PhantomData<T>,
66);
67
68impl<T: 'static> Deref for Res<T> {
69 type Target = T;
70 fn deref(&self) -> &Self::Target {
71 (**self.0)
74 .downcast_ref::<T>()
75 .expect("Resource type mismatch during Res deref")
76 }
77}
78
79impl<T: 'static> Deref for ResMut<T> {
80 type Target = T;
81 fn deref(&self) -> &Self::Target {
82 (**self.0)
83 .downcast_ref::<T>()
84 .expect("Resource type mismatch during ResMut deref")
85 }
86}
87
88impl<T: 'static> DerefMut for ResMut<T> {
89 fn deref_mut(&mut self) -> &mut Self::Target {
90 (**self.0)
92 .downcast_mut::<T>()
93 .expect("Resource type mismatch during ResMut deref_mut")
94 }
95}
96
97pub struct Context<Input>
99where
100 Input: ?Sized,
101{
102 shared: DendencyMap,
103 input: Arc<Input>,
104}
105
106impl<I> Context<I> {
107 pub fn new(input: I, shared: Shared) -> Self {
109 Self {
110 shared: DendencyMap(Arc::new(shared.inner)),
111 input: Arc::new(input),
112 }
113 }
114 pub fn empty(input: I) -> Self {
115 Self {
116 shared: DendencyMap(Arc::new(Shared::new().inner)),
117 input: Arc::new(input),
118 }
119 }
120 pub(crate) fn replace<NewInput>(self, input: NewInput) -> Context<NewInput> {
123 Context {
124 shared: self.shared,
125 input: Arc::new(input),
126 }
127 }
128
129 pub(crate) fn input(&self) -> Arc<I> {
131 self.input.clone()
132 }
133
134 pub(crate) fn into_parts(self) -> (Arc<I>, DendencyMap) {
137 (self.input, self.shared)
138 }
139
140 pub(crate) fn from_parts(input: Arc<I>, shared: DendencyMap) -> Self {
142 Self { shared, input }
143 }
144 pub async fn invoke<H, Args, O>(self, handler: H) -> Context<O>
169 where
170 H: handler::Handler<I, O, Args>,
171 I: Clone + Send + Sync + 'static,
172 O: Send + Sync + 'static,
173 Args: Send + Sync + 'static,
174 {
175 let (input, shared) = self.into_parts();
176 let shared_cloned = shared.clone();
177 let new_ctx = Context {
179 input,
180 shared: shared_cloned,
181 };
182 let output = handler.call(new_ctx).await;
183 Context {
184 shared: shared,
185 input: Arc::new(output),
186 }
187 }
188}
189impl<Input> Clone for Context<Input>
190where
191 Input: ?Sized,
192{
193 fn clone(&self) -> Self {
194 Self {
195 shared: self.shared.clone(),
196 input: self.input.clone(),
197 }
198 }
199}
200
201pub trait FromContext<Input>: Send {
204 fn from(ctx: Context<Input>) -> impl Future<Output = Self> + Send;
205}
206
/// Handler argument that exposes the current pipeline input.
///
/// The value is shared behind an [`Arc`], so cloning an `Input` only bumps a reference count
/// and never clones `T` itself.
pub struct Input<T>(pub Arc<T>);

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone` bound
// even though cloning the `Arc` never requires it.
impl<T> Clone for Input<T> {
    fn clone(&self) -> Self {
        Input(Arc::clone(&self.0))
    }
}
210
211impl<T> Input<T> {
212 pub fn try_unwrap(self) -> Result<T, Arc<T>> {
216 Arc::try_unwrap(self.0)
217 }
218}
219impl<T> Deref for Input<T> {
220 type Target = T;
221 fn deref(&self) -> &Self::Target {
222 &self.0
223 }
224}
225
226impl<I: Send + Sync + 'static> FromContext<I> for Input<I> {
227 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
228 let input = ctx.input.clone();
229 async move { Input(input) }
230 }
231}
232
233impl<I, T> FromContext<I> for Res<T>
234where
235 I: Send + Sync + 'static,
236 T: Send + Sync + 'static,
237{
238 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
239 async move {
240 let shared = ctx.shared.0.clone();
241 let dep = shared
242 .get(&TypeId::of::<T>())
243 .expect("Dependency not found")
244 .clone();
245 let guard = dep.read_owned().await;
246 Res(guard, std::marker::PhantomData)
247 }
248 }
249}
250
251impl<I, T> FromContext<I> for ResMut<T>
252where
253 I: Send + Sync + 'static,
254 T: Send + Sync + 'static,
255{
256 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
257 async move {
258 let shared = ctx.shared.0.clone();
259 let dep = shared
260 .get(&TypeId::of::<T>())
261 .expect("Dependency not found")
262 .clone();
263 let guard = dep.write_owned().await;
264 ResMut(guard, std::marker::PhantomData)
265 }
266 }
267}
268
269pub trait Pipeline<I, O>: Send + Sync + 'static {
272 fn apply(&self, ctx: Context<I>) -> impl Future<Output = O> + Send;
273}
274
/// Error produced by fallible pipeline stages.
#[derive(Clone, Debug)]
pub enum PipelineError {
    /// Recoverable miss: the stage did not apply, and a choice tuple moves on to its next
    /// branch.
    ///
    /// `msg` explains what went wrong. `expected` describes what would have been required;
    /// for example, an unmet [`Cond`] predicate reports `"true"`.
    Failure { msg: String, expected: String },
    /// Unrecoverable error: a choice tuple stops and returns it immediately.
    Fatal { msg: String },
}
281
282pub type PResult<O, E = PipelineError> = Result<O, E>;
284
/// Pipeline that runs `next` only when `predicate` evaluates to `true`.
///
/// Usually built with [`cond`]. If the predicate is `false`, the result is a
/// [`PipelineError::Failure`], so an enclosing choice tuple can try another branch.
pub struct Cond<F, P, I, O> {
    /// Decides whether `next` runs.
    predicate: F,
    /// Stage executed when the predicate holds.
    next: P,
    // `Cond` stores no `I` or `O`. A function-pointer marker ties the type parameters without
    // affecting `Send`/`Sync`.
    _marker: std::marker::PhantomData<fn(I, O)>,
}
291
292impl<F, P, I, O> Pipeline<I, PResult<O>> for Cond<F, P, I, O>
293where
294 F: Pipeline<I, bool>,
295 P: Pipeline<I, O>,
296 I: Clone + Send + Sync + 'static,
297 O: Send + 'static,
298 F: Send + Sync + 'static,
299 P: Send + Sync + 'static,
300{
301 fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
302 async move {
303 let matched = self.predicate.apply(ctx.clone()).await;
304 if matched {
305 Ok(self.next.apply(ctx).await)
306 } else {
307 Err(PipelineError::Failure {
308 msg: "Condition not met".to_string(),
309 expected: "true".to_string(),
310 })
311 }
312 }
313 }
314}
315
316pub fn cond<I, O, F, P, ArgsF, ArgsP>(
342 predicate: F,
343 next: P,
344) -> Cond<crate::ext::Pipe<F, ArgsF>, crate::ext::Pipe<P, ArgsP>, I, O>
345where
346 F: crate::handler::Handler<I, bool, ArgsF>,
347 P: crate::handler::Handler<I, O, ArgsP>,
348 I: Clone + Send + Sync + 'static,
349 O: Send + 'static,
350 ArgsF: Send + Sync + 'static,
351 ArgsP: Send + Sync + 'static,
352{
353 use crate::ext::HandlerExt;
354 Cond {
355 predicate: predicate.pipe(),
356 next: next.pipe(),
357 _marker: std::marker::PhantomData,
358 }
359}
360
361pub struct Choice<T>(pub T);
363
364macro_rules! impl_pipeline_for_tuple {
365 ($($P:ident),+) => {
366 impl<I, O, $($P),+ > Pipeline<I, PResult<O>> for ($($P,)+)
367 where
368 I: Clone + Send + Sync + 'static,
369 O: Send + 'static,
370 $( $P: Pipeline<I, PResult<O>> ),+
371 {
372 fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
373 #[allow(non_snake_case)]
374 let ($($P,)+) = self;
375 async move {
376 $(
377 match $P.apply(ctx.clone()).await {
378 Ok(res) => return Ok(res),
379 Err(PipelineError::Fatal { msg }) => return Err(PipelineError::Fatal { msg }),
380 Err(PipelineError::Failure { .. }) => {}
381 }
382 )*
383 Err(PipelineError::Fatal { msg: "All pipeline branches failed".to_string() })
384 }
385 }
386 }
387 };
388}
389
390impl_pipeline_for_tuple!(P1);
391impl_pipeline_for_tuple!(P1, P2);
392impl_pipeline_for_tuple!(P1, P2, P3);
393impl_pipeline_for_tuple!(P1, P2, P3, P4);
394impl_pipeline_for_tuple!(P1, P2, P3, P4, P5);
395impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6);
396impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6, P7);
397impl_pipeline_for_tuple!(P1, P2, P3, P4, P5, P6, P7, P8);
398
399pub async fn identity<I: Clone>(input: Input<I>) -> I {
402 (*input).clone()
403}
404
405pub async fn identity_res<I: Clone>(input: Input<I>) -> PResult<I> {
408 Ok((*input).clone())
409}
410
#[cfg(test)]
mod tests {
    use std::{future::Future, marker::PhantomData};

    use crate::{Cond, Context, Input, PResult, Pipeline, PipelineError, Res, ResMut, Shared};

    #[derive(Debug, Clone)]
    struct Counter {
        c: i32,
    }

    /// Branch that always reports a recoverable failure.
    struct SoftFail;

    impl Pipeline<i32, PResult<i32>> for SoftFail {
        fn apply(&self, _ctx: Context<i32>) -> impl Future<Output = PResult<i32>> + Send {
            async {
                Err(PipelineError::Failure {
                    msg: "soft".to_string(),
                    expected: "a match".to_string(),
                })
            }
        }
    }

    /// Branch that always reports a fatal error.
    struct HardFail;

    impl Pipeline<i32, PResult<i32>> for HardFail {
        fn apply(&self, _ctx: Context<i32>) -> impl Future<Output = PResult<i32>> + Send {
            async {
                Err(PipelineError::Fatal {
                    msg: "hard".to_string(),
                })
            }
        }
    }

    /// Branch that succeeds with `input + self.0`.
    struct AddN(i32);

    impl Pipeline<i32, PResult<i32>> for AddN {
        fn apply(&self, ctx: Context<i32>) -> impl Future<Output = PResult<i32>> + Send {
            let n = self.0;
            async move { Ok(*ctx.input() + n) }
        }
    }

    /// Predicate that holds for strictly positive inputs.
    struct IsPositive;

    impl Pipeline<i32, bool> for IsPositive {
        fn apply(&self, ctx: Context<i32>) -> impl Future<Output = bool> + Send {
            async move { *ctx.input() > 0 }
        }
    }

    /// Infallible stage that doubles its input.
    struct Double;

    impl Pipeline<i32, i32> for Double {
        fn apply(&self, ctx: Context<i32>) -> impl Future<Output = i32> + Send {
            async move { *ctx.input() * 2 }
        }
    }

    /// Handlers chained with `invoke` see each other's outputs and shared resource updates.
    #[tokio::test]
    async fn test_chain_invoke() {
        let ctx = Context::new(3, Shared::new().insert(Counter { c: 1 }));
        ctx.invoke(async |x: Input<i32>, mut counter: ResMut<Counter>| {
            counter.c += 1;
            *x + 1
        })
        .await
        .invoke(async |x: Input<i32>, counter: Res<Counter>| *x + counter.c)
        .await
        .invoke(async |x: Input<i32>| assert_eq!(*x, 6))
        .await;
    }

    #[tokio::test]
    async fn test_choice_skips_soft_failures() {
        let out: PResult<i32> = (SoftFail, AddN(10), AddN(100))
            .apply(Context::empty(1_i32))
            .await;
        assert!(matches!(out, Ok(11)));
    }

    #[tokio::test]
    async fn test_choice_stops_at_fatal() {
        let out: PResult<i32> = (HardFail, AddN(10)).apply(Context::empty(1_i32)).await;
        match out {
            Err(PipelineError::Fatal { msg }) => assert_eq!(msg, "hard"),
            other => panic!("expected a fatal error, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_choice_exhausted_is_fatal() {
        let out: PResult<i32> = (SoftFail, SoftFail).apply(Context::empty(1_i32)).await;
        assert!(matches!(out, Err(PipelineError::Fatal { .. })));
    }

    #[tokio::test]
    async fn test_cond_gates_on_predicate() {
        let gated: Cond<IsPositive, Double, i32, i32> = Cond {
            predicate: IsPositive,
            next: Double,
            _marker: PhantomData,
        };
        assert!(matches!(gated.apply(Context::empty(5_i32)).await, Ok(10)));
        assert!(matches!(
            gated.apply(Context::empty(-5_i32)).await,
            Err(PipelineError::Failure { .. })
        ));
    }
}