1use std::ops::{Deref, DerefMut};
25use std::sync::Arc;
26
27use super::{AsyncEngineContext, AsyncEngineContextProvider, Data};
28use crate::engine::AsyncEngineController;
29use async_trait::async_trait;
30
31use super::registry::Registry;
32
33pub struct Context<T: Data> {
34 current: T,
35 controller: Arc<Controller>, registry: Registry,
37 stages: Vec<String>,
38}
39
40impl<T: Send + Sync + 'static> Context<T> {
41 pub fn new(current: T) -> Self {
43 Context {
44 current,
45 controller: Arc::new(Controller::default()),
46 registry: Registry::new(),
47 stages: Vec::new(),
48 }
49 }
50
51 pub fn with_controller(current: T, controller: Controller) -> Self {
52 Context {
53 current,
54 controller: Arc::new(controller),
55 registry: Registry::new(),
56 stages: Vec::new(),
57 }
58 }
59
60 pub fn with_id(current: T, id: String) -> Self {
61 Context {
62 current,
63 controller: Arc::new(Controller::new(id)),
64 registry: Registry::new(),
65 stages: Vec::new(),
66 }
67 }
68
69 pub fn id(&self) -> &str {
70 self.controller.id()
71 }
72
73 pub fn controller(&self) -> &Controller {
74 &self.controller
75 }
76
77 pub fn insert<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
79 self.registry.insert_shared(key, value);
80 }
81
82 pub fn insert_unique<K: ToString, U: Send + Sync + 'static>(&mut self, key: K, value: U) {
84 self.registry.insert_unique(key, value);
85 }
86
87 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
89 self.registry.get_shared(key)
90 }
91
92 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
94 self.registry.clone_unique(key)
95 }
96
97 pub fn take_unique<V: Send + Sync + 'static>(&mut self, key: &str) -> Result<V, String> {
99 self.registry.take_unique(key)
100 }
101
102 pub fn transfer<U: Send + Sync + 'static>(self, new_current: U) -> (T, Context<U>) {
105 (
106 self.current,
107 Context {
108 current: new_current,
109 controller: self.controller,
110 registry: self.registry,
111 stages: self.stages,
112 },
113 )
114 }
115
116 pub fn into_parts(self) -> (T, Context<()>) {
118 self.transfer(())
119 }
120
121 pub fn stages(&self) -> &Vec<String> {
122 &self.stages
123 }
124
125 pub fn add_stage(&mut self, stage: &str) {
126 self.stages.push(stage.to_string());
127 }
128
129 pub fn map<U: Send + Sync + 'static, F>(self, f: F) -> Context<U>
131 where
132 F: FnOnce(T) -> U,
133 {
134 let (current, temp_context) = self.transfer(());
136
137 let new_current = f(current);
139
140 temp_context.transfer(new_current).1
142 }
143
144 pub fn try_map<U, F, E>(self, f: F) -> Result<Context<U>, E>
145 where
146 F: FnOnce(T) -> Result<U, E>,
147 U: Send + Sync + 'static,
148 {
149 let (current, temp_context) = self.transfer(());
151
152 let new_current = f(current)?;
154
155 Ok(temp_context.transfer(new_current).1)
157 }
158}
159
160impl<T: Data> std::fmt::Debug for Context<T> {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct("Context")
163 .field("id", &self.controller.id())
164 .finish()
165 }
166}
167
168impl<T: Data> Deref for Context<T> {
170 type Target = T;
171
172 fn deref(&self) -> &Self::Target {
173 &self.current
174 }
175}
176
177impl<T: Data> DerefMut for Context<T> {
179 fn deref_mut(&mut self) -> &mut Self::Target {
180 &mut self.current
181 }
182}
183
184impl<T> From<T> for Context<T>
186where
187 T: Send + Sync + 'static,
188{
189 fn from(current: T) -> Self {
190 Context::new(current)
191 }
192}
193
194pub trait IntoContext<U: Data> {
196 fn into_context(self) -> Context<U>;
197}
198
199impl<T, U> IntoContext<U> for Context<T>
201where
202 T: Send + Sync + 'static + Into<U>,
203 U: Send + Sync + 'static,
204{
205 fn into_context(self) -> Context<U> {
206 self.map(|current| current.into())
207 }
208}
209
210impl<T: Data> AsyncEngineContextProvider for Context<T> {
211 fn context(&self) -> Arc<dyn AsyncEngineContext> {
212 self.controller.clone()
213 }
214}
215
/// Context carried alongside a response stream.
///
/// Holds the controller and the registry behind `Arc`s, so clones share them; this type
/// itself exposes no insert methods for the registry.
#[derive(Debug, Clone)]
pub struct StreamContext {
    // Controller handle; when built from a `Context`, it is that context's own controller.
    controller: Arc<Controller>,
    // Shared registry handle (see `registry()`).
    registry: Arc<Registry>,
    // Stage names appended with `add_stage`, in call order.
    stages: Vec<String>,
}
222
223impl StreamContext {
224 fn new(controller: Arc<Controller>, registry: Registry) -> Self {
225 StreamContext {
226 controller,
227 registry: Arc::new(registry),
228 stages: Vec::new(),
229 }
230 }
231
232 pub fn get<V: Send + Sync + 'static>(&self, key: &str) -> Result<Arc<V>, String> {
234 self.registry.get_shared(key)
235 }
236
237 pub fn clone_unique<V: Clone + Send + Sync + 'static>(&self, key: &str) -> Result<V, String> {
239 self.registry.clone_unique(key)
240 }
241
242 pub fn registry(&self) -> Arc<Registry> {
243 self.registry.clone()
244 }
245
246 pub fn stages(&self) -> &Vec<String> {
247 &self.stages
248 }
249
250 pub fn add_stage(&mut self, stage: &str) {
251 self.stages.push(stage.to_string());
252 }
253}
254
255#[async_trait]
256impl AsyncEngineContext for StreamContext {
257 fn id(&self) -> &str {
258 self.controller.id()
259 }
260
261 fn stop(&self) {
262 self.controller.stop();
263 }
264
265 fn kill(&self) {
266 self.controller.kill();
267 }
268
269 fn stop_generating(&self) {
270 self.controller.stop_generating();
271 }
272
273 fn is_stopped(&self) -> bool {
274 self.controller.is_stopped()
275 }
276
277 fn is_killed(&self) -> bool {
278 self.controller.is_killed()
279 }
280
281 async fn stopped(&self) {
282 self.controller.stopped().await
283 }
284
285 async fn killed(&self) {
286 self.controller.killed().await
287 }
288}
289
290impl AsyncEngineContextProvider for StreamContext {
291 fn context(&self) -> Arc<dyn AsyncEngineContext> {
292 self.controller.clone()
293 }
294}
295
296impl<T: Send + Sync + 'static> From<Context<T>> for StreamContext {
297 fn from(value: Context<T>) -> Self {
298 StreamContext::new(value.controller, value.registry)
299 }
300}
301
302use tokio::sync::watch::{channel, Receiver, Sender};
305
/// Lifecycle value published on a `Controller`'s watch channel; a controller starts `Live`.
#[derive(Debug, PartialEq, Eq)]
enum State {
    /// Initial state; neither `stop` nor `kill` has taken effect.
    Live,
    /// Set by `stop`.
    Stopped,
    /// Set by `kill`.
    Killed,
}
312
/// Stop/kill controller for a request: an identifier plus a tokio `watch` channel that
/// publishes the current `State`.
#[derive(Debug)]
pub struct Controller {
    // Identifier returned by `id()`.
    id: String,
    // Publishes state transitions (written by `stop` / `kill`).
    tx: Sender<State>,
    // Read for state queries and cloned by waiters; holding it also keeps at least one
    // receiver alive, which `watch::Sender::send` requires to succeed.
    rx: Receiver<State>,
}
320
321impl Controller {
322 pub fn new(id: String) -> Self {
323 let (tx, rx) = channel(State::Live);
324 Self { id, tx, rx }
325 }
326
327 pub fn id(&self) -> &str {
328 &self.id
329 }
330}
331
332impl Default for Controller {
333 fn default() -> Self {
334 Self::new(uuid::Uuid::new_v4().to_string())
335 }
336}
337
// `Controller` opts into `AsyncEngineController` without overriding any items.
impl AsyncEngineController for Controller {}
339
340#[async_trait]
341impl AsyncEngineContext for Controller {
342 fn id(&self) -> &str {
343 &self.id
344 }
345
346 fn is_stopped(&self) -> bool {
347 *self.rx.borrow() != State::Live
348 }
349
350 fn is_killed(&self) -> bool {
351 *self.rx.borrow() == State::Killed
352 }
353
354 async fn stopped(&self) {
355 let mut rx = self.rx.clone();
356 if *rx.borrow_and_update() != State::Live {
357 return;
358 }
359 let _ = rx.changed().await;
360 }
361
362 async fn killed(&self) {
363 let mut rx = self.rx.clone();
364 if *rx.borrow_and_update() == State::Killed {
365 return;
366 }
367 let _ = rx.changed().await;
368 }
369
370 fn stop_generating(&self) {
371 self.stop();
372 }
373
374 fn stop(&self) {
375 let _ = self.tx.send(State::Stopped);
376 }
377
378 fn kill(&self) {
379 let _ = self.tx.send(State::Killed);
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone)]
    struct Input {
        value: String,
    }

    #[derive(Debug, Clone)]
    struct Processed {
        length: usize,
    }

    #[derive(Debug, Clone)]
    struct Final {
        message: String,
    }

    impl From<Input> for Processed {
        fn from(input: Input) -> Self {
            Processed {
                length: input.value.len(),
            }
        }
    }

    impl From<Processed> for Final {
        fn from(processed: Processed) -> Self {
            Final {
                message: format!("Processed length: {}", processed.length),
            }
        }
    }

    /// Shared fixture: the `"Hello"` input used throughout these tests.
    fn hello() -> Input {
        Input {
            value: "Hello".to_string(),
        }
    }

    #[test]
    fn test_insert_and_get() {
        let mut ctx = Context::new(hello());

        ctx.insert("key1", 42);
        ctx.insert("key2", "some data".to_string());

        assert_eq!(*ctx.get::<i32>("key1").unwrap(), 42);
        assert_eq!(*ctx.get::<String>("key2").unwrap(), "some data");
        // Reading a key back as the wrong type must fail rather than reinterpret the value.
        assert!(ctx.get::<f64>("key1").is_err());
    }

    #[test]
    fn test_transfer() {
        let ctx = Context::new(hello());

        let (input, ctx) = ctx.transfer(Processed { length: 5 });

        assert_eq!(input.value, "Hello");
        assert_eq!(ctx.length, 5);
    }

    #[test]
    fn test_map() {
        let ctx = Context::new(hello());

        let ctx: Context<Processed> = ctx.map(|input| input.into());
        let ctx: Context<Final> = ctx.map(|processed| processed.into());

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_into_context() {
        let ctx = Context::new(hello());

        let ctx: Context<Processed> = ctx.into_context();
        let ctx: Context<Final> = ctx.into_context();

        assert_eq!(ctx.current.message, "Processed length: 5");
    }

    #[test]
    fn test_try_map_success_and_error() {
        let ctx = Context::with_id(hello(), "req-1".to_string());

        let processed: Context<Processed> = ctx
            .try_map(|input| Ok::<_, String>(Processed::from(input)))
            .expect("closure returned Ok");
        assert_eq!(processed.length, 5);
        assert_eq!(processed.id(), "req-1");

        match processed.try_map(|_| Err::<Final, _>("boom".to_string())) {
            Ok(_) => panic!("try_map must propagate the closure's error"),
            Err(err) => assert_eq!(err, "boom"),
        }
    }

    #[test]
    fn test_map_keeps_id_stages_and_registry() {
        let mut ctx = Context::with_id(hello(), "req-2".to_string());
        ctx.add_stage("ingress");
        ctx.insert("count", 7u32);

        let ctx: Context<Processed> = ctx.map(Processed::from);

        assert_eq!(ctx.id(), "req-2");
        assert_eq!(ctx.stages(), &vec!["ingress".to_string()]);
        assert_eq!(*ctx.get::<u32>("count").unwrap(), 7);
    }

    #[test]
    fn test_into_parts_keeps_controller() {
        let ctx = Context::with_id(hello(), "req-3".to_string());

        let (input, rest) = ctx.into_parts();

        assert_eq!(input.value, "Hello");
        assert_eq!(rest.id(), "req-3");
    }

    #[test]
    fn test_controller_stop_and_kill_flags() {
        let graceful = Controller::default();
        assert!(!graceful.is_stopped());
        assert!(!graceful.is_killed());
        graceful.stop();
        assert!(graceful.is_stopped());
        assert!(!graceful.is_killed());

        let hard = Controller::default();
        hard.kill();
        assert!(hard.is_stopped());
        assert!(hard.is_killed());
    }
}