// fission_shell/async_host/native.rs

1use fission_core::{
2    ActionEnvelope, BoxFuture, CapabilityCtx, CapabilityType, JobCtx, JobRef, JobSpec,
3    OperationCapability, ResourceExecutionContext, ServiceCtx, ServiceRunner, ServiceSpec,
4    ServiceType,
5};
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::{mpsc, Arc};
9use tokio::runtime::{Builder as TokioRuntimeBuilder, Runtime as TokioRuntime};
10
/// Shareable, thread-safe callback invoked after a message has been pushed onto an
/// [`AsyncMessage`] channel.
pub type WakeFn = Arc<dyn Fn() + Sync + Send>;
12
13#[derive(Clone, Debug)]
14pub enum AsyncMessage {
15    JobOk {
16        job_name: String,
17        req_id: u64,
18        payload: Vec<u8>,
19        on_ok: Option<ActionEnvelope>,
20        resource: Option<ResourceExecutionContext>,
21    },
22    JobErr {
23        job_name: String,
24        req_id: u64,
25        payload: Option<Vec<u8>>,
26        on_err: Option<ActionEnvelope>,
27        message: Option<String>,
28        resource: Option<ResourceExecutionContext>,
29    },
30    ServiceStarted {
31        service_name: String,
32        slot_key: String,
33        instance_id: u64,
34        resource: Option<ResourceExecutionContext>,
35    },
36    ServiceStartFailed {
37        service_name: String,
38        slot_key: String,
39        instance_id: u64,
40        payload: Option<Vec<u8>>,
41        message: Option<String>,
42        resource: Option<ResourceExecutionContext>,
43    },
44    ServiceEvent {
45        service_name: String,
46        slot_key: String,
47        instance_id: u64,
48        payload: Vec<u8>,
49        resource: Option<ResourceExecutionContext>,
50    },
51    ServiceStopped {
52        service_name: String,
53        slot_key: String,
54        instance_id: u64,
55        resource: Option<ResourceExecutionContext>,
56    },
57    ServiceCommandOk {
58        service_name: String,
59        slot_key: String,
60        instance_id: u64,
61        req_id: u64,
62        payload: Option<Vec<u8>>,
63        on_ok: Option<ActionEnvelope>,
64        resource: Option<ResourceExecutionContext>,
65    },
66    ServiceCommandErr {
67        service_name: String,
68        slot_key: String,
69        instance_id: u64,
70        req_id: u64,
71        payload: Option<Vec<u8>>,
72        on_err: Option<ActionEnvelope>,
73        message: Option<String>,
74        resource: Option<ResourceExecutionContext>,
75    },
76    CapabilityOk {
77        capability_name: String,
78        req_id: u64,
79        payload: Vec<u8>,
80        on_ok: Option<ActionEnvelope>,
81        resource: Option<ResourceExecutionContext>,
82    },
83    CapabilityErr {
84        capability_name: String,
85        req_id: u64,
86        payload: Option<Vec<u8>>,
87        on_err: Option<ActionEnvelope>,
88        message: Option<String>,
89        resource: Option<ResourceExecutionContext>,
90    },
91}
92
93#[derive(Clone)]
94pub enum ServiceControlMessage {
95    Command {
96        req_id: u64,
97        payload: Vec<u8>,
98        on_ok: Option<ActionEnvelope>,
99        on_err: Option<ActionEnvelope>,
100    },
101    Stop,
102}
103
104#[derive(Clone)]
105pub struct RunningServiceHandle {
106    pub instance_id: u64,
107    pub control_tx: mpsc::Sender<ServiceControlMessage>,
108}
109
110#[derive(Clone)]
111struct JobLaunch {
112    req_id: u64,
113    payload: Vec<u8>,
114    on_ok: Option<ActionEnvelope>,
115    on_err: Option<ActionEnvelope>,
116    resource: Option<ResourceExecutionContext>,
117    tx: mpsc::Sender<AsyncMessage>,
118    wake: WakeFn,
119}
120
121#[derive(Clone)]
122struct CapabilityLaunch {
123    req_id: u64,
124    payload: Vec<u8>,
125    on_ok: Option<ActionEnvelope>,
126    on_err: Option<ActionEnvelope>,
127    resource: Option<ResourceExecutionContext>,
128    tx: mpsc::Sender<AsyncMessage>,
129    wake: WakeFn,
130}
131
132#[derive(Clone)]
133struct ServiceLaunch {
134    service_name: String,
135    slot_key: String,
136    instance_id: u64,
137    config: Vec<u8>,
138    resource: Option<ResourceExecutionContext>,
139    tx: mpsc::Sender<AsyncMessage>,
140    wake: WakeFn,
141}
142
143type JobHandler = dyn Fn(JobLaunch) + Send + Sync;
144type ServiceSpawner = dyn Fn(ServiceLaunch) -> RunningServiceHandle + Send + Sync;
145type CapabilitySpawner = dyn Fn(CapabilityLaunch) + Send + Sync;
146
147pub struct AsyncRegistry {
148    jobs: HashMap<String, Arc<JobHandler>>,
149    services: HashMap<String, Arc<ServiceSpawner>>,
150    operations: HashMap<String, Arc<CapabilitySpawner>>,
151}
152
153impl Default for AsyncRegistry {
154    fn default() -> Self {
155        Self {
156            jobs: HashMap::new(),
157            services: HashMap::new(),
158            operations: HashMap::new(),
159        }
160    }
161}
162
163impl AsyncRegistry {
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    pub fn register_operation_capability<C, F, Fut>(
169        &mut self,
170        capability: CapabilityType<C>,
171        handler: F,
172    ) where
173        C: OperationCapability,
174        F: Fn(C::Request, CapabilityCtx) -> Fut + Send + Sync + 'static,
175        Fut: Future<Output = Result<C::Ok, C::Err>> + Send + 'static,
176    {
177        let handler = Arc::new(handler);
178        self.operations.insert(
179            capability.name.to_string(),
180            Arc::new(move |launch: CapabilityLaunch| {
181                let handler = handler.clone();
182                let name = capability.name.to_string();
183                std::thread::spawn(move || {
184                    let runtime = match new_job_runtime() {
185                        Ok(runtime) => runtime,
186                        Err(err) => {
187                            let _ = launch.tx.send(AsyncMessage::CapabilityErr {
188                                capability_name: name,
189                                req_id: launch.req_id,
190                                payload: None,
191                                on_err: launch.on_err,
192                                message: Some(err),
193                                resource: launch.resource,
194                            });
195                            (launch.wake)();
196                            return;
197                        }
198                    };
199
200                    let request = match serde_json::from_slice::<C::Request>(&launch.payload) {
201                        Ok(request) => request,
202                        Err(err) => {
203                            let _ = launch.tx.send(AsyncMessage::CapabilityErr {
204                                capability_name: name,
205                                req_id: launch.req_id,
206                                payload: None,
207                                on_err: launch.on_err,
208                                message: Some(err.to_string()),
209                                resource: launch.resource,
210                            });
211                            (launch.wake)();
212                            return;
213                        }
214                    };
215
216                    match runtime.block_on(handler(
217                        request,
218                        CapabilityCtx {
219                            req_id: launch.req_id,
220                        },
221                    )) {
222                        Ok(ok) => match serde_json::to_vec(&ok) {
223                            Ok(payload) => {
224                                let _ = launch.tx.send(AsyncMessage::CapabilityOk {
225                                    capability_name: name,
226                                    req_id: launch.req_id,
227                                    payload,
228                                    on_ok: launch.on_ok,
229                                    resource: launch.resource,
230                                });
231                            }
232                            Err(err) => {
233                                let _ = launch.tx.send(AsyncMessage::CapabilityErr {
234                                    capability_name: name,
235                                    req_id: launch.req_id,
236                                    payload: None,
237                                    on_err: launch.on_err,
238                                    message: Some(err.to_string()),
239                                    resource: launch.resource,
240                                });
241                            }
242                        },
243                        Err(err) => {
244                            let (payload, message) = serde_json::to_vec(&err)
245                                .ok()
246                                .map(|payload| (Some(payload), None))
247                                .unwrap_or_else(|| {
248                                    (None, Some("capability error serialization failed".into()))
249                                });
250                            let _ = launch.tx.send(AsyncMessage::CapabilityErr {
251                                capability_name: name,
252                                req_id: launch.req_id,
253                                payload,
254                                on_err: launch.on_err,
255                                message,
256                                resource: launch.resource,
257                            });
258                        }
259                    }
260
261                    (launch.wake)();
262                });
263            }),
264        );
265    }
266
267    pub fn register_job<J, F, Fut>(&mut self, job: JobRef<J>, handler: F)
268    where
269        J: JobSpec,
270        F: Fn(J::Request, JobCtx) -> Fut + Send + Sync + 'static,
271        Fut: Future<Output = Result<J::Ok, J::Err>> + Send + 'static,
272    {
273        let handler = Arc::new(handler);
274        self.jobs.insert(
275            job.name.to_string(),
276            Arc::new(move |launch: JobLaunch| {
277                let handler = handler.clone();
278                std::thread::spawn(move || {
279                    let runtime = match new_job_runtime() {
280                        Ok(runtime) => runtime,
281                        Err(err) => {
282                            let _ = launch.tx.send(AsyncMessage::JobErr {
283                                job_name: J::NAME.to_string(),
284                                req_id: launch.req_id,
285                                payload: None,
286                                on_err: launch.on_err,
287                                message: Some(err),
288                                resource: launch.resource,
289                            });
290                            (launch.wake)();
291                            return;
292                        }
293                    };
294                    let request = match serde_json::from_slice::<J::Request>(&launch.payload) {
295                        Ok(request) => request,
296                        Err(err) => {
297                            let _ = launch.tx.send(AsyncMessage::JobErr {
298                                job_name: J::NAME.to_string(),
299                                req_id: launch.req_id,
300                                payload: None,
301                                on_err: launch.on_err,
302                                message: Some(err.to_string()),
303                                resource: launch.resource,
304                            });
305                            (launch.wake)();
306                            return;
307                        }
308                    };
309
310                    match runtime.block_on(handler(
311                        request,
312                        JobCtx {
313                            req_id: launch.req_id,
314                        },
315                    )) {
316                        Ok(ok) => match serde_json::to_vec(&ok) {
317                            Ok(payload) => {
318                                let _ = launch.tx.send(AsyncMessage::JobOk {
319                                    job_name: J::NAME.to_string(),
320                                    req_id: launch.req_id,
321                                    payload,
322                                    on_ok: launch.on_ok,
323                                    resource: launch.resource,
324                                });
325                            }
326                            Err(err) => {
327                                let _ = launch.tx.send(AsyncMessage::JobErr {
328                                    job_name: J::NAME.to_string(),
329                                    req_id: launch.req_id,
330                                    payload: None,
331                                    on_err: launch.on_err,
332                                    message: Some(err.to_string()),
333                                    resource: launch.resource,
334                                });
335                            }
336                        },
337                        Err(err) => {
338                            let (payload, message) = serde_json::to_vec(&err)
339                                .ok()
340                                .map(|payload| (Some(payload), None))
341                                .unwrap_or_else(|| {
342                                    (None, Some("job error serialization failed".into()))
343                                });
344                            let _ = launch.tx.send(AsyncMessage::JobErr {
345                                job_name: J::NAME.to_string(),
346                                req_id: launch.req_id,
347                                payload,
348                                on_err: launch.on_err,
349                                message,
350                                resource: launch.resource,
351                            });
352                        }
353                    }
354
355                    (launch.wake)();
356                });
357            }),
358        );
359    }
360
361    pub fn register_service<S, F, Fut>(&mut self, service: ServiceType<S>, starter: F)
362    where
363        S: ServiceSpec + 'static,
364        F: Fn(S::Config, ServiceCtx<S>) -> Fut + Send + Sync + 'static,
365        Fut: Future<Output = Result<Box<dyn ServiceRunner<S>>, S::StartErr>> + Send + 'static,
366    {
367        let starter = Arc::new(starter);
368        self.services.insert(
369            service.name.to_string(),
370            Arc::new(move |launch: ServiceLaunch| {
371                let (control_tx, control_rx) = mpsc::channel();
372                let starter = starter.clone();
373                let tx = launch.tx.clone();
374                let wake = launch.wake.clone();
375                let service_name = launch.service_name.clone();
376                let slot_key = launch.slot_key.clone();
377                let resource = launch.resource.clone();
378                let instance_id = launch.instance_id;
379                let config_bytes = launch.config.clone();
380
381                std::thread::spawn(move || {
382                    let runtime = match new_service_runtime() {
383                        Ok(runtime) => runtime,
384                        Err(err) => {
385                            let _ = tx.send(AsyncMessage::ServiceStartFailed {
386                                service_name,
387                                slot_key,
388                                instance_id,
389                                payload: None,
390                                message: Some(err),
391                                resource,
392                            });
393                            wake();
394                            return;
395                        }
396                    };
397                    let tx_for_emit = tx.clone();
398                    let wake_for_emit = wake.clone();
399                    let service_name_for_emit = service_name.clone();
400                    let slot_key_for_emit = slot_key.clone();
401                    let resource_for_emit = resource.clone();
402                    let emit = Arc::new(move |payload: Vec<u8>| -> BoxFuture<Result<(), String>> {
403                        let tx = tx_for_emit.clone();
404                        let wake = wake_for_emit.clone();
405                        let service_name = service_name_for_emit.clone();
406                        let slot_key = slot_key_for_emit.clone();
407                        let resource = resource_for_emit.clone();
408                        Box::pin(async move {
409                            tx.send(AsyncMessage::ServiceEvent {
410                                service_name,
411                                slot_key,
412                                instance_id,
413                                payload,
414                                resource,
415                            })
416                            .map_err(|err| err.to_string())?;
417                            wake();
418                            Ok(())
419                        })
420                    });
421
422                    let ctx = ServiceCtx::<S>::new_runtime(
423                        service_name.clone(),
424                        slot_key.clone(),
425                        instance_id,
426                        emit,
427                    );
428
429                    let config = match serde_json::from_slice::<S::Config>(&config_bytes) {
430                        Ok(config) => config,
431                        Err(err) => {
432                            let _ = tx.send(AsyncMessage::ServiceStartFailed {
433                                service_name,
434                                slot_key,
435                                instance_id,
436                                payload: None,
437                                message: Some(err.to_string()),
438                                resource,
439                            });
440                            wake();
441                            return;
442                        }
443                    };
444
445                    let mut runner = match runtime.block_on(starter(config, ctx.clone())) {
446                        Ok(runner) => {
447                            let _ = tx.send(AsyncMessage::ServiceStarted {
448                                service_name: service_name.clone(),
449                                slot_key: slot_key.clone(),
450                                instance_id,
451                                resource: resource.clone(),
452                            });
453                            wake();
454                            runner
455                        }
456                        Err(err) => {
457                            let payload = serde_json::to_vec(&err).ok();
458                            let _ = tx.send(AsyncMessage::ServiceStartFailed {
459                                service_name,
460                                slot_key,
461                                instance_id,
462                                payload,
463                                message: None,
464                                resource,
465                            });
466                            wake();
467                            return;
468                        }
469                    };
470
471                    while let Ok(message) = control_rx.recv() {
472                        match message {
473                            ServiceControlMessage::Command {
474                                req_id,
475                                payload,
476                                on_ok,
477                                on_err,
478                            } => {
479                                let command = match serde_json::from_slice::<S::Command>(&payload) {
480                                    Ok(command) => command,
481                                    Err(err) => {
482                                        let _ = tx.send(AsyncMessage::ServiceCommandErr {
483                                            service_name: service_name.clone(),
484                                            slot_key: slot_key.clone(),
485                                            instance_id,
486                                            req_id,
487                                            payload: None,
488                                            on_err,
489                                            message: Some(err.to_string()),
490                                            resource: resource.clone(),
491                                        });
492                                        wake();
493                                        continue;
494                                    }
495                                };
496
497                                match runtime.block_on(runner.on_command(command, ctx.clone())) {
498                                    Ok(ok) => {
499                                        let payload = serde_json::to_vec(&ok).ok();
500                                        let _ = tx.send(AsyncMessage::ServiceCommandOk {
501                                            service_name: service_name.clone(),
502                                            slot_key: slot_key.clone(),
503                                            instance_id,
504                                            req_id,
505                                            payload,
506                                            on_ok,
507                                            resource: resource.clone(),
508                                        });
509                                    }
510                                    Err(err) => {
511                                        let (payload, message) = serde_json::to_vec(&err)
512                                            .ok()
513                                            .map(|payload| (Some(payload), None))
514                                            .unwrap_or_else(|| {
515                                                (
516                                                    None,
517                                                    Some(
518                                                        "service command error serialization failed"
519                                                            .into(),
520                                                    ),
521                                                )
522                                            });
523                                        let _ = tx.send(AsyncMessage::ServiceCommandErr {
524                                            service_name: service_name.clone(),
525                                            slot_key: slot_key.clone(),
526                                            instance_id,
527                                            req_id,
528                                            payload,
529                                            on_err,
530                                            message,
531                                            resource: resource.clone(),
532                                        });
533                                    }
534                                }
535                                wake();
536                            }
537                            ServiceControlMessage::Stop => {
538                                runtime.block_on(runner.on_stop(ctx.clone()));
539                                let _ = tx.send(AsyncMessage::ServiceStopped {
540                                    service_name: service_name.clone(),
541                                    slot_key: slot_key.clone(),
542                                    instance_id,
543                                    resource: resource.clone(),
544                                });
545                                wake();
546                                return;
547                            }
548                        }
549                    }
550
551                    runtime.block_on(runner.on_stop(ctx));
552                    let _ = tx.send(AsyncMessage::ServiceStopped {
553                        service_name,
554                        slot_key,
555                        instance_id,
556                        resource,
557                    });
558                    wake();
559                });
560
561                RunningServiceHandle {
562                    instance_id: launch.instance_id,
563                    control_tx,
564                }
565            }),
566        );
567    }
568
569    pub fn spawn_job(
570        &self,
571        job_name: &str,
572        req_id: u64,
573        payload: Vec<u8>,
574        on_ok: Option<ActionEnvelope>,
575        on_err: Option<ActionEnvelope>,
576        resource: Option<ResourceExecutionContext>,
577        tx: &mpsc::Sender<AsyncMessage>,
578        wake: WakeFn,
579    ) -> bool {
580        let Some(handler) = self.jobs.get(job_name) else {
581            return false;
582        };
583        handler(JobLaunch {
584            req_id,
585            payload,
586            on_ok,
587            on_err,
588            resource,
589            tx: tx.clone(),
590            wake,
591        });
592        true
593    }
594
595    pub fn spawn_capability(
596        &self,
597        capability_name: &str,
598        req_id: u64,
599        payload: Vec<u8>,
600        on_ok: Option<ActionEnvelope>,
601        on_err: Option<ActionEnvelope>,
602        resource: Option<ResourceExecutionContext>,
603        tx: &mpsc::Sender<AsyncMessage>,
604        wake: WakeFn,
605    ) -> bool {
606        let Some(handler) = self.operations.get(capability_name) else {
607            return false;
608        };
609        handler(CapabilityLaunch {
610            req_id,
611            payload,
612            on_ok,
613            on_err,
614            resource,
615            tx: tx.clone(),
616            wake,
617        });
618        true
619    }
620
621    pub fn spawn_service(
622        &self,
623        service_name: &str,
624        slot_key: &str,
625        instance_id: u64,
626        config: Vec<u8>,
627        resource: Option<ResourceExecutionContext>,
628        tx: &mpsc::Sender<AsyncMessage>,
629        wake: WakeFn,
630    ) -> Option<RunningServiceHandle> {
631        let spawner = self.services.get(service_name)?;
632        Some(spawner(ServiceLaunch {
633            service_name: service_name.to_string(),
634            slot_key: slot_key.to_string(),
635            instance_id,
636            config,
637            resource,
638            tx: tx.clone(),
639            wake,
640        }))
641    }
642}
643
644fn new_job_runtime() -> Result<TokioRuntime, String> {
645    TokioRuntimeBuilder::new_current_thread()
646        .enable_all()
647        .build()
648        .map_err(|err| err.to_string())
649}
650
651fn new_service_runtime() -> Result<TokioRuntime, String> {
652    TokioRuntimeBuilder::new_multi_thread()
653        .worker_threads(1)
654        .enable_all()
655        .build()
656        .map_err(|err| err.to_string())
657}