1use std::ops::{Deref, DerefMut};
5use std::sync::{Arc, Mutex};
6
7use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
8use crate::engine::AsyncEngineController;
9use async_trait::async_trait;
10
11use super::registry::Registry;
12
13pub struct Context<T: Data> {
14 current: T,
15 controller: Arc<Controller>, registry: Registry,
17 stages: Vec<String>,
18}
19
20impl<T: Send + Sync + 'static> Context<T> {
21 pub fn new(current: T) -> Self {
23 Context {
24 current,
25 controller: Arc::new(Controller::default()),
26 registry: Registry::new(),
27 stages: Vec::new(),
28 }
29 }
30
31 pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
32 Context {
33 current,
34 controller: context.controller,
35 registry: context.registry,
36 stages: context.stages,
37 }
38 }
39
40 pub fn with_controller(current: T, controller: Controller) -> Self {
41 Context {
42 current,
43 controller: Arc::new(controller),
44 registry: Registry::new(),
45 stages: Vec::new(),
46 }
47 }
48
49 pub fn with_id(current: T, id: String) -> Self {
50 Context {
51 current,
52 controller: Arc::new(Controller::new(id)),
53 registry: Registry::new(),
54 stages: Vec::new(),
55 }
56 }
57
58 pub fn id(&self) -> &str {
60 self.controller.id()
61 }
62
63 pub fn content(&self) -> &T {
65 &self.current
66 }
67
68 pub fn controller(&self) -> &Controller {
69 &self.controller
70 }
71
72 pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
74 self.registry.insert_shared(key, value);
75 }
76
77 pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
79 self.registry.insert_unique(key, value);
80 }
81
82 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
84 self.registry.get_shared(key)
85 }
86
87 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
89 self.registry.clone_unique(key)
90 }
91
92 pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
94 self.registry.take_unique(key)
95 }
96
97 pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
100 (
101 self.current,
102 Context {
103 current: new_current,
104 controller: self.controller,
105 registry: self.registry,
106 stages: self.stages,
107 },
108 )
109 }
110
111 pub fn into_parts(self) -> (T, Context<()>) {
113 self.transfer(())
114 }
115
116 pub fn stages(&self) -> &Vec<String> {
117 &self.stages
118 }
119
120 pub fn add_stage(&mut self, stage: &str) {
121 self.stages.push(stage.to_string());
122 }
123
124 pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
126 where
127 F: FnOnce(T) -> U,
128 {
129 let (current, temp_context) = self.transfer(());
131
132 let new_current = f(current);
134
135 temp_context.transfer(new_current).1
137 }
138
139 pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
140 where
141 F: FnOnce(T) -> Result<U, E>,
142 U: Send + Sync + 'static,
143 {
144 let (current, temp_context) = self.transfer(());
146
147 let new_current = f(current)?;
149
150 Ok(temp_context.transfer(new_current).1)
152 }
153}
154
155impl<T: Data> std::fmt::Debug for Context<T> {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("Context")
158 .field("id", &self.controller.id())
159 .finish()
160 }
161}
162
163impl<T: Data> Deref for Context<T> {
165 type Target = T;
166
167 fn deref(&self) -> &Self::Target {
168 &self.current
169 }
170}
171
172impl<T: Data> DerefMut for Context<T> {
174 fn deref_mut(&mut self) -> &mut Self::Target {
175 &mut self.current
176 }
177}
178
179impl<T> From<T> for Context<T>
181where
182 T: Send + Sync + 'static,
183{
184 fn from(current: T) -> Self {
185 Context::new(current)
186 }
187}
188
189pub trait IntoContext<U: Data> {
191 fn into_context(self) -> Context<U>;
192}
193
194impl<T, U> IntoContext<U> for Context<T>
196where
197 T: Send + Sync + 'static + Into<U>,
198 U: Send + Sync + 'static,
199{
200 fn into_context(self) -> Context<U> {
201 self.map(|current| current.into())
202 }
203}
204
205impl<T: Data> AsyncEngineContextProvider for Context<T> {
206 fn context(&self) -> Arc<dyn AsyncEngineContext> {
207 self.controller.clone()
208 }
209}
210
211#[derive(Debug, Clone)]
212pub struct StreamContext {
213 controller: Arc<Controller>,
214 registry: Arc<Registry>,
215 stages: Vec<String>,
216}
217
218impl StreamContext {
219 fn new(controller: Arc<Controller>, registry: Registry) -> Self {
220 StreamContext {
221 controller,
222 registry: Arc::new(registry),
223 stages: Vec::new(),
224 }
225 }
226
227 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
229 self.registry.get_shared(key)
230 }
231
232 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
234 self.registry.clone_unique(key)
235 }
236
237 pub fn registry(&self) -> Arc<Registry> {
238 self.registry.clone()
239 }
240
241 pub fn stages(&self) -> &Vec<String> {
242 &self.stages
243 }
244
245 pub fn add_stage(&mut self, stage: &str) {
246 self.stages.push(stage.to_string());
247 }
248}
249
250#[async_trait]
251impl AsyncEngineContext for StreamContext {
252 fn id(&self) -> &str {
253 self.controller.id()
254 }
255
256 fn stop(&self) {
257 self.controller.stop();
258 }
259
260 fn kill(&self) {
261 self.controller.kill();
262 }
263
264 fn stop_generating(&self) {
265 self.controller.stop_generating();
266 }
267
268 fn is_stopped(&self) -> bool {
269 self.controller.is_stopped()
270 }
271
272 fn is_killed(&self) -> bool {
273 self.controller.is_killed()
274 }
275
276 async fn stopped(&self) {
277 self.controller.stopped().await
278 }
279
280 async fn killed(&self) {
281 self.controller.killed().await
282 }
283
284 fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
285 self.controller.link_child(child);
286 }
287}
288
289impl AsyncEngineContextProvider for StreamContext {
290 fn context(&self) -> Arc<dyn AsyncEngineContext> {
291 self.controller.clone()
292 }
293}
294
295impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
296 fn from(value: Context<T>) -> Self {
297 StreamContext::new(value.controller, value.registry)
298 }
299}
300
301use tokio::sync::watch::{Receiver, Sender, channel};
304
/// Lifecycle state published by a [`Controller`] over its watch channel.
#[derive(Debug, Eq, PartialEq)]
enum State {
    /// Initial state of a newly created controller.
    Live,
    /// Published by `stop` and `stop_generating`.
    Stopped,
    /// Published by `kill`; also counts as stopped for `is_stopped`.
    Killed,
}
311
312#[derive(Debug)]
314pub struct Controller {
315 id: String,
316 tx: Sender<State>,
317 rx: Receiver<State>,
318 child_context: Mutex<Vec<Arc<dyn AsyncEngineContext>>>,
319}
320
321impl Controller {
322 pub fn new(id: String) -> Self {
323 let (tx, rx) = channel(State::Live);
324 Self {
325 id,
326 tx,
327 rx,
328 child_context: Mutex::new(Vec::new()),
329 }
330 }
331
332 pub fn id(&self) -> &str {
333 &self.id
334 }
335}
336
337impl Default for Controller {
338 fn default() -> Self {
339 Self::new(uuid::Uuid::new_v4().to_string())
340 }
341}
342
// `Controller` implements `AsyncEngineController` with an empty impl body,
// so it overrides nothing the trait may provide.
impl AsyncEngineController for Controller {}
344
345#[async_trait]
346impl AsyncEngineContext for Controller {
347 fn id(&self) -> &str {
348 &self.id
349 }
350
351 fn is_stopped(&self) -> bool {
352 *self.rx.borrow() != State::Live
353 }
354
355 fn is_killed(&self) -> bool {
356 *self.rx.borrow() == State::Killed
357 }
358
359 async fn stopped(&self) {
360 let mut rx = self.rx.clone();
361 loop {
362 if *rx.borrow_and_update() != State::Live || rx.changed().await.is_err() {
363 return;
364 }
365 }
366 }
367
368 async fn killed(&self) {
369 let mut rx = self.rx.clone();
370 loop {
371 if *rx.borrow_and_update() == State::Killed || rx.changed().await.is_err() {
372 return;
373 }
374 }
375 }
376
377 fn stop_generating(&self) {
378 let children = self
380 .child_context
381 .lock()
382 .expect("Failed to lock child context")
383 .iter()
384 .cloned()
385 .collect::<Vec<_>>();
386 for child in children {
387 child.stop_generating();
388 }
389
390 let _ = self.tx.send(State::Stopped);
391 }
392
393 fn stop(&self) {
394 let children = self
396 .child_context
397 .lock()
398 .expect("Failed to lock child context")
399 .iter()
400 .cloned()
401 .collect::<Vec<_>>();
402 for child in children {
403 child.stop();
404 }
405
406 let _ = self.tx.send(State::Stopped);
407 }
408
409 fn kill(&self) {
410 let children = self
412 .child_context
413 .lock()
414 .expect("Failed to lock child context")
415 .iter()
416 .cloned()
417 .collect::<Vec<_>>();
418 for child in children {
419 child.kill();
420 }
421
422 let _ = self.tx.send(State::Killed);
423 }
424
425 fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
426 self.child_context
427 .lock()
428 .expect("Failed to lock child context")
429 .push(child);
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Starting payload for the conversion chain used in these tests.
    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    /// Intermediate payload: the length of an `Input`'s value.
    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    /// Final payload: a message describing a `Processed` value.
    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            let length = input.value.len();
            Processed { length }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            let message = format!("Processed length: {}", processed.length);
            Final { message }
        }
    }

    /// Fixture shared by every test: a context wrapping `Input { value: "Hello" }`.
    fn hello_context() -> Context<Input> {
        Context::new(Input {
            value: "Hello".to_string(),
        })
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = hello_context();

        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        let number = ctx.get::<i32>("key1").unwrap();
        assert_eq!(*number, 42);

        let text = ctx.get::<String>("key2").unwrap();
        assert_eq!(*text, "some data");

        // Asking for a type other than the stored one must fail.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_transfer() {
        let (input, ctx) = hello_context().transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let processed: Context<Processed> = hello_context().map(|input| input.into());
        let finished: Context<Final> = processed.map(|processed| processed.into());

        assert_eq!(finished.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let processed: Context<Processed> = hello_context().into_context();
        let finished: Context<Final> = processed.into_context();

        assert_eq!(finished.current.message, "Processed length: 5");
    }
}