1use std::{
2 any::{self, TypeId},
3 collections::HashMap,
4 future::Future,
5 ops::{Deref, DerefMut},
6 sync::Arc,
7};
8
9use tokio::sync;
10
11pub use pipeit_derive::node;
12
13pub extern crate pipeit_derive;
14
15pub mod concurrency;
16pub mod ext;
17pub mod handler;
18pub mod tag;
19pub mod sink;
20pub mod dep;
21#[cfg(feature = "tower")]
22pub mod service;
/// Reference-counted map from a resource's `TypeId` to its type-erased value.
///
/// Each value is boxed as `dyn Any + Send + Sync` and guarded by its own
/// `tokio::sync::RwLock`. Cloning the map only bumps the outer `Arc` count, so
/// every clone refers to the same resources and locks.
// NOTE(review): "Dendency" looks like a typo of "Dependency"; renaming needs a
// crate-wide change because other modules may refer to this type.
#[derive(Clone)]
pub(crate) struct DendencyMap(
    pub(crate) Arc<HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>>,
);
30
/// Builder for the set of resources made available to handlers.
///
/// Resources are keyed by their concrete type's `TypeId`, so at most one value
/// per type is stored. Cloning a `Shared` copies the map, but every entry is an
/// `Arc`, so both clones refer to the same underlying locked values.
#[derive(Default, Clone)]
pub struct Shared {
    // Type-erased resources, each behind its own async `RwLock`.
    pub(crate) inner: HashMap<TypeId, Arc<sync::RwLock<Box<dyn any::Any + Send + Sync>>>>,
}
38
39impl Shared {
40 pub fn new() -> Self {
42 Self::default()
43 }
44
45 pub fn insert<T: Send + Sync + 'static>(mut self, resource: T) -> Self {
48 self.inner.insert(
49 TypeId::of::<T>(),
50 Arc::new(sync::RwLock::new(
51 Box::new(resource) as Box<dyn any::Any + Send + Sync>
52 )),
53 );
54 self
55 }
56}
57
/// Extractor granting shared (read) access to the resource of type `T`.
///
/// It owns a read guard on the resource's lock for as long as it is alive and
/// dereferences to `&T`.
pub struct Res<T>(
    // Owned read guard over the type-erased resource.
    sync::OwnedRwLockReadGuard<Box<dyn any::Any + Send + Sync>>,
    // Records the concrete type that the guarded value is downcast to.
    std::marker::PhantomData<T>,
);
64
/// Extractor granting exclusive (write) access to the resource of type `T`.
///
/// It owns a write guard on the resource's lock for as long as it is alive and
/// dereferences to `&T` / `&mut T`.
pub struct ResMut<T>(
    // Owned write guard over the type-erased resource.
    sync::OwnedRwLockWriteGuard<Box<dyn any::Any + Send + Sync>>,
    // Records the concrete type that the guarded value is downcast to.
    std::marker::PhantomData<T>,
);
71
72impl<T: 'static> Deref for Res<T> {
73 type Target = T;
74 fn deref(&self) -> &Self::Target {
75 (**self.0)
78 .downcast_ref::<T>()
79 .expect("Resource type mismatch during Res deref")
80 }
81}
82
83impl<T: 'static> Deref for ResMut<T> {
84 type Target = T;
85 fn deref(&self) -> &Self::Target {
86 (**self.0)
87 .downcast_ref::<T>()
88 .expect("Resource type mismatch during ResMut deref")
89 }
90}
91
92impl<T: 'static> DerefMut for ResMut<T> {
93 fn deref_mut(&mut self) -> &mut Self::Target {
94 (**self.0)
96 .downcast_mut::<T>()
97 .expect("Resource type mismatch during ResMut deref_mut")
98 }
99}
100
101pub struct Context<Input>
110where
111 Input: ?Sized,
112{
113 shared: DendencyMap,
114 input: Arc<Input>,
115}
116
117impl<I> Context<I> {
118 pub fn new(input: I, shared: Shared) -> Self {
120 Self {
121 shared: DendencyMap(Arc::new(shared.inner)),
122 input: Arc::new(input),
123 }
124 }
125 pub fn empty(input: I) -> Self {
126 Self {
127 shared: DendencyMap(Arc::new(Shared::new().inner)),
128 input: Arc::new(input),
129 }
130 }
131 pub(crate) fn replace<NewInput>(self, input: NewInput) -> Context<NewInput> {
134 Context {
135 shared: self.shared,
136 input: Arc::new(input),
137 }
138 }
139
140 pub(crate) fn input(&self) -> Arc<I> {
142 self.input.clone()
143 }
144
145 pub(crate) fn into_parts(self) -> (Arc<I>, DendencyMap) {
148 (self.input, self.shared)
149 }
150
151 pub(crate) fn from_parts(input: Arc<I>, shared: DendencyMap) -> Self {
153 Self { shared, input }
154 }
155 pub async fn invoke<H, Args, O>(self, handler: H) -> Context<O>
180 where
181 H: handler::Handler<I, O, Args>,
182 I: Clone + Send + Sync + 'static,
183 O: Send + Sync + 'static,
184 Args: Send + Sync + 'static,
185 {
186 let (input, shared) = self.into_parts();
187 let shared_cloned = shared.clone();
188 let new_ctx = Context::from_parts(input, shared_cloned);
190 let output = handler.call(new_ctx).await;
191 Context::from_parts(Arc::new(output), shared)
192 }
193}
194impl<Input> Clone for Context<Input>
195where
196 Input: ?Sized,
197{
198 fn clone(&self) -> Self {
199 Self {
200 shared: self.shared.clone(),
201 input: self.input.clone(),
202 }
203 }
204}
205
/// Values that can be extracted from a [`Context`], such as the [`Input`],
/// [`Res`] and [`ResMut`] handler arguments implemented in this module.
pub trait FromContext<Input>: Send {
    /// Builds `Self` from `ctx`, which is taken by value. The returned future
    /// must be `Send`.
    fn from(ctx: Context<Input>) -> impl Future<Output = Self> + Send;
}
211
/// Extractor that hands a handler the current pipeline input.
///
/// The value is shared behind an [`Arc`], so extracting or cloning an `Input`
/// never copies the underlying data.
pub struct Input<T>(pub Arc<T>);

// Implemented by hand instead of derived: `#[derive(Clone)]` would add a
// `T: Clone` bound even though cloning only bumps the `Arc` reference count.
impl<T> Clone for Input<T> {
    fn clone(&self) -> Self {
        Input(Arc::clone(&self.0))
    }
}

impl<T> Input<T> {
    /// Takes ownership of the inner value if this is the only handle to it.
    ///
    /// # Errors
    /// Returns the `Arc` unchanged when other handles still exist (for example,
    /// other clones of this `Input`).
    pub fn try_unwrap(self) -> Result<T, Arc<T>> {
        Arc::try_unwrap(self.0)
    }
}

impl<T> Deref for Input<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
230
231impl<I: Send + Sync + 'static> FromContext<I> for Input<I> {
232 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
233 let input = ctx.input.clone();
234 async move { Input(input) }
235 }
236}
237
238impl<I, T> FromContext<I> for Res<T>
239where
240 I: Send + Sync + 'static,
241 T: Send + Sync + 'static,
242{
243 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
244 async move {
245 let shared = ctx.shared.0.clone();
246 let dep = shared
247 .get(&TypeId::of::<T>())
248 .expect("Dependency not found")
249 .clone();
250 let guard = dep.read_owned().await;
251 Res(guard, std::marker::PhantomData)
252 }
253 }
254}
255
256impl<I, T> FromContext<I> for ResMut<T>
257where
258 I: Send + Sync + 'static,
259 T: Send + Sync + 'static,
260{
261 fn from(ctx: Context<I>) -> impl Future<Output = Self> + Send {
262 async move {
263 let shared = ctx.shared.0.clone();
264 let dep = shared
265 .get(&TypeId::of::<T>())
266 .expect("Dependency not found")
267 .clone();
268 let guard = dep.write_owned().await;
269 ResMut(guard, std::marker::PhantomData)
270 }
271 }
272}
273
/// A composable processing step that turns a `Context<I>` into an `O`.
///
/// Implementors must be `Send + Sync + 'static`, and the future returned by
/// [`Pipeline::apply`] must be `Send`.
pub trait Pipeline<I, O>: Send + Sync + 'static {
    /// Runs this step against `ctx`.
    fn apply(&self, ctx: Context<I>) -> impl Future<Output = O> + Send;
}
279
/// Error produced by fallible pipelines.
///
/// Combinators such as `Alt` treat the variants differently: a
/// [`PipelineError::Failure`] lets the next alternative be tried, while a
/// [`PipelineError::Fatal`] is propagated immediately.
#[derive(Debug, Clone)]
pub enum PipelineError {
    /// Recoverable mismatch: this branch did not apply.
    Failure {
        /// What went wrong.
        msg: String,
        /// Description of what was expected instead.
        expected: String,
    },
    /// Unrecoverable error that stops alternative matching.
    Fatal {
        /// What went wrong.
        msg: String,
    },
}

// `Display` + `Error` let callers propagate pipeline errors with `?` into
// `Box<dyn Error>` and similar error containers.
impl std::fmt::Display for PipelineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Failure { msg, expected } => write!(f, "{msg} (expected {expected})"),
            Self::Fatal { msg } => write!(f, "fatal: {msg}"),
        }
    }
}

impl std::error::Error for PipelineError {}
286
/// Result type returned by fallible pipelines; the error type defaults to [`PipelineError`].
pub type PResult<O, E = PipelineError> = Result<O, E>;
289
/// Pipeline that runs `next` only when `predicate` evaluates to `true`.
///
/// Usually built with [`cond`].
pub struct Cond<F, P, I, O> {
    // Pipeline producing the `bool` gate.
    predicate: F,
    // Pipeline run when the gate is open.
    next: P,
    // Function-pointer marker: records `I`/`O` without storing them. Function
    // pointers are always `Send + Sync`, so the marker never affects those traits.
    _marker: std::marker::PhantomData<fn(I, O)>,
}
296
297impl<F, P, I, O> Pipeline<I, PResult<O>> for Cond<F, P, I, O>
298where
299 F: Pipeline<I, bool>,
300 P: Pipeline<I, O>,
301 I: Clone + Send + Sync + 'static,
302 O: Send + 'static,
303 F: Send + Sync + 'static,
304 P: Send + Sync + 'static,
305{
306 fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
307 async move {
308 let matched = self.predicate.apply(ctx.clone()).await;
309 if matched {
310 Ok(self.next.apply(ctx).await)
311 } else {
312 Err(PipelineError::Failure {
313 msg: "Condition not met".to_string(),
314 expected: "true".to_string(),
315 })
316 }
317 }
318 }
319}
320
321pub fn cond<I, O, F, P, ArgsF, ArgsP>(
347 predicate: F,
348 next: P,
349) -> Cond<crate::ext::Pipe<F, ArgsF>, crate::ext::Pipe<P, ArgsP>, I, O>
350where
351 F: crate::handler::Handler<I, bool, ArgsF>,
352 P: crate::handler::Handler<I, O, ArgsP>,
353 I: Clone + Send + Sync + 'static,
354 O: Send + 'static,
355 ArgsF: Send + Sync + 'static,
356 ArgsP: Send + Sync + 'static,
357{
358 use crate::ext::HandlerExt;
359 Cond {
360 predicate: predicate.pipe(),
361 next: next.pipe(),
362 _marker: std::marker::PhantomData,
363 }
364}
365
/// Pipeline that tries a tuple of alternative pipelines in order.
///
/// `Alt<(P1, ..., Pn)>` implements [`Pipeline`] for tuples of one to nine
/// branches. The first `Ok` wins, a `PipelineError::Fatal` aborts immediately,
/// and a `PipelineError::Failure` moves on to the next branch.
pub struct Alt<P> {
    /// Tuple of branch pipelines, tried left to right.
    pub pipelines: P,
}

/// Wraps a tuple of pipelines into an [`Alt`] combinator.
#[must_use = "an `Alt` does nothing until it is applied to a context"]
pub fn alt<P>(pipelines: P) -> Alt<P> {
    Alt { pipelines }
}
377
// Implements `Pipeline<I, PResult<O>>` for `Alt` over a tuple of the given
// branch type parameters. The generated `apply` runs each branch in order on a
// clone of the context:
// - the first `Ok` is returned immediately;
// - a `PipelineError::Fatal` is returned immediately, skipping the remaining
//   branches;
// - a `PipelineError::Failure` moves on to the next branch.
// If every branch fails, a generic `PipelineError::Failure` is returned and the
// individual branch errors are discarded.
macro_rules! impl_pipeline_for_alt {
    ($($P:ident),+) => {
        impl<I, O, $($P),+ > Pipeline<I, PResult<O>> for Alt<($($P,)+)>
        where
            I: Clone + Send + Sync + 'static,
            O: Send + 'static,
            $( $P: Pipeline<I, PResult<O>> ),+
        {
            fn apply(&self, ctx: Context<I>) -> impl Future<Output = PResult<O>> + Send {
                // Bind each tuple element to a variable named after its type
                // parameter (P1, P2, ...), hence the lint allowance.
                #[allow(non_snake_case)]
                let ($($P,)+) = &self.pipelines;
                async move {
                    $(
                        match $P.apply(ctx.clone()).await {
                            Ok(res) => return Ok(res),
                            Err(PipelineError::Fatal { msg }) => return Err(PipelineError::Fatal { msg }),
                            // Recoverable: fall through to the next branch.
                            Err(PipelineError::Failure { .. }) => {}
                        }
                    )*
                    Err(PipelineError::Failure {
                        msg: "All pipeline branches failed".to_string(),
                        expected: "at least one branch success".to_string()
                    })
                }
            }
        }
    };
}
406
// `Pipeline` impls for `Alt` over tuples of one through nine branches.
impl_pipeline_for_alt!(P1);
impl_pipeline_for_alt!(P1, P2);
impl_pipeline_for_alt!(P1, P2, P3);
impl_pipeline_for_alt!(P1, P2, P3, P4);
impl_pipeline_for_alt!(P1, P2, P3, P4, P5);
impl_pipeline_for_alt!(P1, P2, P3, P4, P5, P6);
impl_pipeline_for_alt!(P1, P2, P3, P4, P5, P6, P7);
impl_pipeline_for_alt!(P1, P2, P3, P4, P5, P6, P7, P8);
impl_pipeline_for_alt!(P1, P2, P3, P4, P5, P6, P7, P8, P9);
416
417pub async fn identity<I: Clone>(input: Input<I>) -> I {
420 (*input).clone()
421}
422
423pub async fn identity_res<I: Clone>(input: Input<I>) -> PResult<I> {
426 Ok((*input).clone())
427}
428
#[cfg(test)]
mod tests {
    use crate::{Context, Input, Res, ResMut, Shared};

    /// Small mutable resource used to check that writes made through `ResMut`
    /// are visible to later `Res` readers.
    #[derive(Debug, Clone)]
    struct Counter {
        c: i32,
    }

    /// Chains three handlers. The first bumps both the input and the counter
    /// (3 -> 4, c: 1 -> 2), the second adds the counter to the input (4 + 2),
    /// and the third checks the result is 6.
    #[tokio::test]
    async fn test_chain_invoke() {
        let resources = Shared::new().insert(Counter { c: 1 });
        let start = Context::new(3, resources);

        let bumped = start
            .invoke(async |x: Input<i32>, mut counter: ResMut<Counter>| {
                counter.c += 1;
                *x + 1
            })
            .await;

        let summed = bumped
            .invoke(async |x: Input<i32>, counter: Res<Counter>| *x + counter.c)
            .await;

        summed
            .invoke(async |x: Input<i32>| assert_eq!(*x, 6))
            .await;
    }
}