1use std::ops::{Deref, DerefMut};
25use std::sync::Arc;
26
27use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
28use crate::engine::AsyncEngineController;
29use async_trait::async_trait;
30
31use super::registry::Registry;
32
33pub struct Context<T: Data> {
34 current: T,
35 controller: Arc<Controller>, registry: Registry,
37 stages: Vec<String>,
38}
39
40impl<T: Send + Sync + 'static> Context<T> {
41 pub fn new(current: T) -> Self {
43 Context {
44 current,
45 controller: Arc::new(Controller::default()),
46 registry: Registry::new(),
47 stages: Vec::new(),
48 }
49 }
50
51 pub fn rejoin<U: Send + Sync + 'static>(current: T, context: Context<U>) -> Self {
52 Context {
53 current,
54 controller: context.controller,
55 registry: context.registry,
56 stages: context.stages,
57 }
58 }
59
60 pub fn with_controller(current: T, controller: Controller) -> Self {
61 Context {
62 current,
63 controller: Arc::new(controller),
64 registry: Registry::new(),
65 stages: Vec::new(),
66 }
67 }
68
69 pub fn with_id(current: T, id: String) -> Self {
70 Context {
71 current,
72 controller: Arc::new(Controller::new(id)),
73 registry: Registry::new(),
74 stages: Vec::new(),
75 }
76 }
77
78 pub fn id(&self) -> &str {
80 self.controller.id()
81 }
82
83 pub fn content(&self) -> &T {
85 &self.current
86 }
87
88 pub fn controller(&self) -> &Controller {
89 &self.controller
90 }
91
92 pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
94 self.registry.insert_shared(key, value);
95 }
96
97 pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
99 self.registry.insert_unique(key, value);
100 }
101
102 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
104 self.registry.get_shared(key)
105 }
106
107 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
109 self.registry.clone_unique(key)
110 }
111
112 pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
114 self.registry.take_unique(key)
115 }
116
117 pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
120 (
121 self.current,
122 Context {
123 current: new_current,
124 controller: self.controller,
125 registry: self.registry,
126 stages: self.stages,
127 },
128 )
129 }
130
131 pub fn into_parts(self) -> (T, Context<()>) {
133 self.transfer(())
134 }
135
136 pub fn stages(&self) -> &Vec<String> {
137 &self.stages
138 }
139
140 pub fn add_stage(&mut self, stage: &str) {
141 self.stages.push(stage.to_string());
142 }
143
144 pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
146 where
147 F: FnOnce(T) -> U,
148 {
149 let (current, temp_context) = self.transfer(());
151
152 let new_current = f(current);
154
155 temp_context.transfer(new_current).1
157 }
158
159 pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
160 where
161 F: FnOnce(T) -> Result<U, E>,
162 U: Send + Sync + 'static,
163 {
164 let (current, temp_context) = self.transfer(());
166
167 let new_current = f(current)?;
169
170 Ok(temp_context.transfer(new_current).1)
172 }
173}
174
175impl<T: Data> std::fmt::Debug for Context<T> {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("Context")
178 .field("id", &self.controller.id())
179 .finish()
180 }
181}
182
183impl<T: Data> Deref for Context<T> {
185 type Target = T;
186
187 fn deref(&self) -> &Self::Target {
188 &self.current
189 }
190}
191
192impl<T: Data> DerefMut for Context<T> {
194 fn deref_mut(&mut self) -> &mut Self::Target {
195 &mut self.current
196 }
197}
198
199impl<T> From<T> for Context<T>
201where
202 T: Send + Sync + 'static,
203{
204 fn from(current: T) -> Self {
205 Context::new(current)
206 }
207}
208
209pub trait IntoContext<U: Data> {
211 fn into_context(self) -> Context<U>;
212}
213
214impl<T, U> IntoContext<U> for Context<T>
216where
217 T: Send + Sync + 'static + Into<U>,
218 U: Send + Sync + 'static,
219{
220 fn into_context(self) -> Context<U> {
221 self.map(|current| current.into())
222 }
223}
224
225impl<T: Data> AsyncEngineContextProvider for Context<T> {
226 fn context(&self) -> Arc<dyn AsyncEngineContext> {
227 self.controller.clone()
228 }
229}
230
231#[derive(Debug, Clone)]
232pub struct StreamContext {
233 controller: Arc<Controller>,
234 registry: Arc<Registry>,
235 stages: Vec<String>,
236}
237
238impl StreamContext {
239 fn new(controller: Arc<Controller>, registry: Registry) -> Self {
240 StreamContext {
241 controller,
242 registry: Arc::new(registry),
243 stages: Vec::new(),
244 }
245 }
246
247 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
249 self.registry.get_shared(key)
250 }
251
252 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
254 self.registry.clone_unique(key)
255 }
256
257 pub fn registry(&self) -> Arc<Registry> {
258 self.registry.clone()
259 }
260
261 pub fn stages(&self) -> &Vec<String> {
262 &self.stages
263 }
264
265 pub fn add_stage(&mut self, stage: &str) {
266 self.stages.push(stage.to_string());
267 }
268}
269
270#[async_trait]
271impl AsyncEngineContext for StreamContext {
272 fn id(&self) -> &str {
273 self.controller.id()
274 }
275
276 fn stop(&self) {
277 self.controller.stop();
278 }
279
280 fn kill(&self) {
281 self.controller.kill();
282 }
283
284 fn stop_generating(&self) {
285 self.controller.stop_generating();
286 }
287
288 fn is_stopped(&self) -> bool {
289 self.controller.is_stopped()
290 }
291
292 fn is_killed(&self) -> bool {
293 self.controller.is_killed()
294 }
295
296 async fn stopped(&self) {
297 self.controller.stopped().await
298 }
299
300 async fn killed(&self) {
301 self.controller.killed().await
302 }
303}
304
305impl AsyncEngineContextProvider for StreamContext {
306 fn context(&self) -> Arc<dyn AsyncEngineContext> {
307 self.controller.clone()
308 }
309}
310
311impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
312 fn from(value: Context<T>) -> Self {
313 StreamContext::new(value.controller, value.registry)
314 }
315}
316
317use tokio::sync::watch::{channel, Receiver, Sender};
320
/// States a [`Controller`] publishes on its watch channel.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Initial state set by `Controller::new`.
    Live,
    /// Entered through `stop` / `stop_generating`.
    Stopped,
    /// Entered through `kill`.
    Killed,
}
327
328#[derive(Debug)]
330pub struct Controller {
331 id: String,
332 tx: Sender<State>,
333 rx: Receiver<State>,
334}
335
336impl Controller {
337 pub fn new(id: String) -> Self {
338 let (tx, rx) = channel(State::Live);
339 Self { id, tx, rx }
340 }
341
342 pub fn id(&self) -> &str {
343 &self.id
344 }
345}
346
347impl Default for Controller {
348 fn default() -> Self {
349 Self::new(uuid::Uuid::new_v4().to_string())
350 }
351}
352
353impl AsyncEngineController for Controller {}
354
355#[async_trait]
356impl AsyncEngineContext for Controller {
357 fn id(&self) -> &str {
358 &self.id
359 }
360
361 fn is_stopped(&self) -> bool {
362 *self.rx.borrow() != State::Live
363 }
364
365 fn is_killed(&self) -> bool {
366 *self.rx.borrow() == State::Killed
367 }
368
369 async fn stopped(&self) {
370 let mut rx = self.rx.clone();
371 if *rx.borrow_and_update() != State::Live {
372 return;
373 }
374 let _ = rx.changed().await;
375 }
376
377 async fn killed(&self) {
378 let mut rx = self.rx.clone();
379 if *rx.borrow_and_update() == State::Killed {
380 return;
381 }
382 let _ = rx.changed().await;
383 }
384
385 fn stop_generating(&self) {
386 self.stop();
387 }
388
389 fn stop(&self) {
390 let _ = self.tx.send(State::Stopped);
391 }
392
393 fn kill(&self) {
394 let _ = self.tx.send(State::Killed);
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            let length = input.value.len();
            Processed { length }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            let message = format!("Processed length: {}", processed.length);
            Final { message }
        }
    }

    /// Shared fixture: a fresh context whose payload holds the string "Hello".
    fn hello() -> Context<Input> {
        Context::new(Input {
            value: "Hello".to_string(),
        })
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = hello();
        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        assert_eq!(*ctx.get::<i32>("key1").unwrap(), 42);
        assert_eq!(*ctx.get::<String>("key2").unwrap(), "some data");
        // Same key, wrong type: the lookup must fail.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_transfer() {
        let (input, ctx) = hello().transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let ctx: Context<Final> = hello().map(Processed::from).map(Final::from);

        assert_eq!(ctx.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let processed: Context<Processed> = hello().into_context();
        let finished: Context<Final> = processed.into_context();

        assert_eq!(finished.message, "Processed length: 5");
    }
}
482}