1use std::collections::BTreeMap;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
8use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
9use crate::engine::AsyncEngineController;
10use async_trait::async_trait;
11
12use super::registry::Registry;
13
/// A pipeline payload bundled with the state that travels alongside it.
///
/// Besides the payload (`current`, also reachable through `Deref`/`DerefMut`),
/// a `Context` carries the `Controller` used for stop/kill signalling, a
/// `Registry` of typed values, the stage names recorded via `add_stage`, and
/// free-form string metadata. All of that state is carried over when the
/// payload is swapped via `transfer`, `map`, `try_map` or `rejoin`.
pub struct Context<T: Data> {
    /// The payload carried by this context.
    current: T,
    /// Stop/kill controller; reference-counted so it can be handed out by
    /// `AsyncEngineContextProvider::context` and moved into a `StreamContext`.
    controller: Arc<Controller>,
    /// Typed key/value storage holding shared and unique entries.
    registry: Registry,
    /// Stage names appended by `add_stage`, in insertion order.
    stages: Vec<String>,
    /// Free-form string metadata propagated with the context.
    metadata: BTreeMap<String, String>,
}
21
22impl<T: Send + Sync + 'static> Context<T> {
23 pub fn new(current: T) -> Self {
25 Context {
26 current,
27 controller: Arc::new(Controller::default()),
28 registry: Registry::new(),
29 stages: Vec::new(),
30 metadata: BTreeMap::new(),
31 }
32 }
33
34 pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
35 Context {
36 current,
37 controller: context.controller,
38 registry: context.registry,
39 stages: context.stages,
40 metadata: context.metadata,
41 }
42 }
43
44 pub fn with_controller(current: T, controller: Controller) -> Self {
45 Context {
46 current,
47 controller: Arc::new(controller),
48 registry: Registry::new(),
49 stages: Vec::new(),
50 metadata: BTreeMap::new(),
51 }
52 }
53
54 #[deprecated(
55 since = "1.1.2",
56 note = "Use `Context::with_id_and_metadata` instead; pass `Default::default()` \
57 when you have no metadata to propagate. `with_id` will be removed once \
58 all call sites have been migrated."
59 )]
60 pub fn with_id(current: T, id: String) -> Self {
61 Context {
62 current,
63 controller: Arc::new(Controller::new(id)),
64 registry: Registry::new(),
65 stages: Vec::new(),
66 metadata: BTreeMap::new(),
67 }
68 }
69
70 pub fn with_id_and_metadata(
71 current: T,
72 id: String,
73 metadata: BTreeMap<String, String>,
74 ) -> Self {
75 Context {
76 current,
77 controller: Arc::new(Controller::new(id)),
78 registry: Registry::new(),
79 stages: Vec::new(),
80 metadata,
81 }
82 }
83
84 pub fn id(&self) -> &str {
86 self.controller.id()
87 }
88
89 pub fn content(&self) -> &T {
91 &self.current
92 }
93
94 pub fn controller(&self) -> &Controller {
95 &self.controller
96 }
97
98 pub fn metadata(&self) -> &BTreeMap<String, String> {
99 &self.metadata
100 }
101
102 pub fn metadata_mut(&mut self) -> &mut BTreeMap<String, String> {
103 &mut self.metadata
104 }
105
106 pub fn set_metadata(&mut self, metadata: BTreeMap<String, String>) {
107 self.metadata = metadata;
108 }
109
110 pub fn insert_metadata<K: Into<String>, V: Into<String>>(&mut self, key: K, value: V) {
111 self.metadata.insert(key.into(), value.into());
112 }
113
114 pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
116 self.registry.insert_shared(key, value);
117 }
118
119 pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
121 self.registry.insert_unique(key, value);
122 }
123
124 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
126 self.registry.get_shared(key)
127 }
128
129 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
131 self.registry.clone_unique(key)
132 }
133
134 pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
136 self.registry.take_unique(key)
137 }
138
139 pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
142 (
143 self.current,
144 Context {
145 current: new_current,
146 controller: self.controller,
147 registry: self.registry,
148 stages: self.stages,
149 metadata: self.metadata,
150 },
151 )
152 }
153
154 pub fn into_parts(self) -> (T, Context<()>) {
156 self.transfer(())
157 }
158
159 pub fn stages(&self) -> &Vec<String> {
160 &self.stages
161 }
162
163 pub fn add_stage(&mut self, stage: &str) {
164 self.stages.push(stage.to_string());
165 }
166
167 pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
169 where
170 F: FnOnce(T) -> U,
171 {
172 let (current, temp_context) = self.transfer(());
174
175 let new_current = f(current);
177
178 temp_context.transfer(new_current).1
180 }
181
182 pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
183 where
184 F: FnOnce(T) -> Result<U, E>,
185 U: Send + Sync + 'static,
186 {
187 let (current, temp_context) = self.transfer(());
189
190 let new_current = f(current)?;
192
193 Ok(temp_context.transfer(new_current).1)
195 }
196}
197
198impl<T: Data> std::fmt::Debug for Context<T> {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("Context")
201 .field("id", &self.controller.id())
202 .finish()
203 }
204}
205
206impl<T: Data> Deref for Context<T> {
208 type Target = T;
209
210 fn deref(&self) -> &Self::Target {
211 &self.current
212 }
213}
214
215impl<T: Data> DerefMut for Context<T> {
217 fn deref_mut(&mut self) -> &mut Self::Target {
218 &mut self.current
219 }
220}
221
222impl<T> From<T> for Context<T>
224where
225 T: Send + Sync + 'static,
226{
227 fn from(current: T) -> Self {
228 Context::new(current)
229 }
230}
231
232pub trait IntoContext<U: Data> {
234 fn into_context(self) -> Context<U>;
235}
236
237impl<T, U> IntoContext<U> for Context<T>
239where
240 T: Send + Sync + 'static + Into<U>,
241 U: Send + Sync + 'static,
242{
243 fn into_context(self) -> Context<U> {
244 self.map(|current| current.into())
245 }
246}
247
248impl<T: Data> AsyncEngineContextProvider for Context<T> {
249 fn context(&self) -> Arc<dyn AsyncEngineContext> {
250 self.controller.clone()
251 }
252}
253
/// A payload-free context, built from a `Context` via `From<Context<T>>`.
///
/// Cloning shares the controller and registry, which sit behind `Arc`s.
/// `stages` and `metadata` are copied per clone.
#[derive(Debug, Clone)]
pub struct StreamContext {
    /// Controller taken over from the originating `Context` (the same `Arc`).
    controller: Arc<Controller>,
    /// Registry taken over from the originating `Context`, shared between clones.
    registry: Arc<Registry>,
    /// Stage names recorded on this stream context via `add_stage`.
    stages: Vec<String>,
    /// String metadata carried over from the originating `Context`.
    metadata: BTreeMap<String, String>,
}
261
262impl StreamContext {
263 fn new(
264 controller: Arc<Controller>,
265 registry: Registry,
266 metadata: BTreeMap<String, String>,
267 ) -> Self {
268 StreamContext {
269 controller,
270 registry: Arc::new(registry),
271 stages: Vec::new(),
272 metadata,
273 }
274 }
275
276 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
278 self.registry.get_shared(key)
279 }
280
281 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
283 self.registry.clone_unique(key)
284 }
285
286 pub fn registry(&self) -> Arc<Registry> {
287 self.registry.clone()
288 }
289
290 pub fn stages(&self) -> &Vec<String> {
291 &self.stages
292 }
293
294 pub fn add_stage(&mut self, stage: &str) {
295 self.stages.push(stage.to_string());
296 }
297
298 pub fn metadata(&self) -> &BTreeMap<String, String> {
299 &self.metadata
300 }
301}
302
303#[async_trait]
304impl AsyncEngineContext for StreamContext {
305 fn id(&self) -> &str {
306 self.controller.id()
307 }
308
309 fn stop(&self) {
310 self.controller.stop();
311 }
312
313 fn kill(&self) {
314 self.controller.kill();
315 }
316
317 fn stop_generating(&self) {
318 self.controller.stop_generating();
319 }
320
321 fn is_stopped(&self) -> bool {
322 self.controller.is_stopped()
323 }
324
325 fn is_killed(&self) -> bool {
326 self.controller.is_killed()
327 }
328
329 async fn stopped(&self) {
330 self.controller.stopped().await
331 }
332
333 async fn killed(&self) {
334 self.controller.killed().await
335 }
336
337 fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
338 self.controller.link_child(child);
339 }
340}
341
342impl AsyncEngineContextProvider for StreamContext {
343 fn context(&self) -> Arc<dyn AsyncEngineContext> {
344 self.controller.clone()
345 }
346}
347
348impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
349 fn from(value: Context<T>) -> Self {
350 StreamContext::new(value.controller, value.registry, value.metadata)
351 }
352}
353
354use tokio::sync::watch::{Receiver, Sender, channel};
357
/// Lifecycle of a `Controller`, published over its watch channel.
///
/// `Live` is the initial state. `is_stopped` treats both `Stopped` and
/// `Killed` as stopped, while `is_killed` only matches `Killed`.
#[derive(PartialEq, Eq, Debug)]
enum State {
    /// Running; no stop or kill has been requested.
    Live,
    /// A stop (or stop-generating) request has been issued.
    Stopped,
    /// A kill request has been issued.
    Killed,
}
364
/// Stop/kill signalling for a request, identified by a string id.
///
/// The current state is published through a `tokio::sync::watch` channel.
/// Child contexts registered via `link_child` are signalled when this
/// controller is stopped, stopped-generating, or killed.
#[derive(Debug)]
pub struct Controller {
    /// Identifier returned by `id()`.
    id: String,
    /// Sending half of the state channel; written by `stop`, `stop_generating` and `kill`.
    tx: Sender<State>,
    /// Receiving half, read synchronously by `is_stopped`/`is_killed` and
    /// cloned by the async `stopped`/`killed` waits.
    rx: Receiver<State>,
    /// Children registered via `link_child`.
    child_context: Mutex<Vec<Arc<dyn AsyncEngineContext>>>,
}
373
374impl Controller {
375 pub fn new(id: String) -> Self {
376 let (tx, rx) = channel(State::Live);
377 Self {
378 id,
379 tx,
380 rx,
381 child_context: Mutex::new(Vec::new()),
382 }
383 }
384
385 pub fn id(&self) -> &str {
386 &self.id
387 }
388}
389
390impl Default for Controller {
391 fn default() -> Self {
392 Self::new(uuid::Uuid::new_v4().to_string())
393 }
394}
395
/// `Controller` is an `AsyncEngineController`; this impl supplies no methods of its own.
impl AsyncEngineController for Controller {}
397
398#[async_trait]
399impl AsyncEngineContext for Controller {
400 fn id(&self) -> &str {
401 &self.id
402 }
403
404 fn is_stopped(&self) -> bool {
405 *self.rx.borrow() != State::Live
406 }
407
408 fn is_killed(&self) -> bool {
409 *self.rx.borrow() == State::Killed
410 }
411
412 async fn stopped(&self) {
413 let mut rx = self.rx.clone();
414 loop {
415 if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
416 return;
417 }
418 }
419 }
420
421 async fn killed(&self) {
422 let mut rx = self.rx.clone();
423 loop {
424 if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
425 return;
426 }
427 }
428 }
429
430 fn stop_generating(&self) {
431 let children = self
433 .child_context
434 .lock()
435 .expect("Failed to lock child context")
436 .iter()
437 .cloned()
438 .collect::<Vec<_>>();
439 for child in children {
440 child.stop_generating();
441 }
442
443 let _ = self.tx.send(State::Stopped);
444 }
445
446 fn stop(&self) {
447 let children = self
449 .child_context
450 .lock()
451 .expect("Failed to lock child context")
452 .iter()
453 .cloned()
454 .collect::<Vec<_>>();
455 for child in children {
456 child.stop();
457 }
458
459 let _ = self.tx.send(State::Stopped);
460 }
461
462 fn kill(&self) {
463 let children = self
465 .child_context
466 .lock()
467 .expect("Failed to lock child context")
468 .iter()
469 .cloned()
470 .collect::<Vec<_>>();
471 for child in children {
472 child.kill();
473 }
474
475 let _ = self.tx.send(State::Killed);
476 }
477
478 fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
479 self.child_context
480 .lock()
481 .expect("Failed to lock child context")
482 .push(child);
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            Self {
                length: input.value.len(),
            }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            let message = format!("Processed length: {}", processed.length);
            Self { message }
        }
    }

    /// The payload every test starts from.
    fn hello() -> Input {
        Input {
            value: String::from("Hello"),
        }
    }

    /// Looks up the `tenant` metadata entry as a `&str`.
    fn tenant(metadata: &BTreeMap<String, String>) -> Option<&str> {
        metadata.get("tenant").map(String::as_str)
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new(hello());
        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        assert_eq!(*ctx.get::<i32>("key1").unwrap(), 42);
        assert_eq!(*ctx.get::<String>("key2").unwrap(), "some data");
        // Asking for the wrong type must fail rather than succeed.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_metadata_preserved_across_transfers() {
        let mut ctx = Context::new(hello());
        ctx.insert_metadata("tenant", "alpha");

        let (_, transferred) = ctx.transfer(Processed { length: 5 });
        assert_eq!(tenant(transferred.metadata()), Some("alpha"));
    }

    #[test]
    fn test_with_id_and_metadata_constructor() {
        let metadata = BTreeMap::from([("tenant".to_string(), "alpha".to_string())]);
        let ctx = Context::with_id_and_metadata(hello(), "request-123".to_string(), metadata);

        assert_eq!(ctx.id(), "request-123");
        assert_eq!(tenant(ctx.metadata()), Some("alpha"));
    }

    #[test]
    fn test_metadata_preserved_across_rejoin() {
        let mut ctx = Context::new(hello());
        ctx.insert_metadata("tenant", "alpha");

        let (input, empty_ctx) = ctx.into_parts();
        let rejoined = Context::rejoin(input, empty_ctx);
        assert_eq!(tenant(rejoined.metadata()), Some("alpha"));
    }

    #[test]
    fn test_metadata_preserved_in_stream_context() {
        let mut ctx = Context::new(hello());
        ctx.insert_metadata("tenant", "alpha");

        let stream_ctx = StreamContext::from(ctx);
        assert_eq!(tenant(stream_ctx.metadata()), Some("alpha"));
    }

    #[test]
    fn test_transfer() {
        let (input, ctx) = Context::new(hello()).transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let ctx: Context<Final> = Context::new(hello())
            .map(Processed::from)
            .map(Final::from);

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let ctx: Context<Processed> = Context::new(hello()).into_context();
        let ctx: Context<Final> = ctx.into_context();

        assert_eq!(ctx.current.message, "Processed length: 5");
    }
}