1use std::sync::Arc;
15
16use tokio::net::UnixListener;
17use tokio_stream::wrappers::UnixListenerStream;
18use tonic::{transport::Server, Request, Response, Status};
19
20use self::rpc::instance_server::{Instance, InstanceServer};
21use crate::apis::flame as rpc;
22
23use crate::apis::{CommonData, FlameError, TaskInput, TaskOutput};
24
/// Name of the environment variable read by [`run`]; when it is set, its value becomes a
/// directory component of the socket path under `/tmp/flame/shim/`.
const FLAME_EXECUTOR_ID: &str = "FLAME_EXECUTOR_ID";
26
/// Application description passed to a [`FlameService`] as part of a [`SessionContext`].
///
/// Built from the wire-level `rpc::ApplicationContext` via `From`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ApplicationContext {
    /// Application name.
    pub name: String,
    /// Image reference, if the request provided one.
    pub image: Option<String>,
    /// Command, if the request provided one.
    pub command: Option<String>,
}
32
/// Session information handed to [`FlameService::on_session_enter`].
///
/// Built from the wire-level `rpc::SessionContext` via `From`.
pub struct SessionContext {
    /// Identifier of the session, taken from the RPC request.
    pub session_id: String,
    /// Application the session was opened for.
    pub application: ApplicationContext,
    /// Common data attached to the session, if the request carried any.
    pub common_data: Option<CommonData>,
}
38
/// Task information handed to [`FlameService::on_task_invoke`].
///
/// Built from the wire-level `rpc::TaskContext` via `From`.
pub struct TaskContext {
    /// Identifier of the task, taken from the RPC request.
    pub task_id: String,
    /// Identifier of the session the task belongs to.
    pub session_id: String,
    /// Task input, if the request carried any.
    pub input: Option<TaskInput>,
}
44
45#[tonic::async_trait]
46pub trait FlameService: Send + Sync + 'static {
47 async fn on_session_enter(&self, _: SessionContext) -> Result<(), FlameError>;
48 async fn on_task_invoke(&self, _: TaskContext) -> Result<Option<TaskOutput>, FlameError>;
49 async fn on_session_leave(&self) -> Result<(), FlameError>;
50}
51
/// Shared, type-erased handle to a [`FlameService`] implementation.
pub type FlameServicePtr = Arc<dyn FlameService>;
53
/// Adapter that implements the generated gRPC `Instance` trait by delegating each RPC to
/// a user-provided [`FlameService`].
struct ShimService {
    // The user service every RPC is forwarded to.
    service: FlameServicePtr,
}
57
58#[tonic::async_trait]
59impl Instance for ShimService {
60 async fn on_session_enter(
61 &self,
62 req: Request<rpc::SessionContext>,
63 ) -> Result<Response<rpc::Result>, Status> {
64 tracing::debug!("ShimService::on_session_enter");
65
66 let req = req.into_inner();
67 let resp = self
68 .service
69 .on_session_enter(SessionContext::from(req))
70 .await;
71
72 match resp {
73 Ok(_) => Ok(Response::new(rpc::Result {
74 return_code: 0,
75 message: None,
76 })),
77 Err(e) => Ok(Response::new(rpc::Result {
78 return_code: -1,
79 message: Some(e.to_string()),
80 })),
81 }
82 }
83
84 async fn on_task_invoke(
85 &self,
86 req: Request<rpc::TaskContext>,
87 ) -> Result<Response<rpc::TaskResult>, Status> {
88 tracing::debug!("ShimService::on_task_invoke");
89 let req = req.into_inner();
90 let resp = self.service.on_task_invoke(TaskContext::from(req)).await;
91
92 match resp {
93 Ok(data) => Ok(Response::new(rpc::TaskResult {
94 return_code: 0,
95 output: data.map(|d| d.into()),
96 message: None,
97 })),
98 Err(e) => Ok(Response::new(rpc::TaskResult {
99 return_code: -1,
100 output: None,
101 message: Some(e.to_string()),
102 })),
103 }
104 }
105
106 async fn on_session_leave(
107 &self,
108 _: Request<rpc::EmptyRequest>,
109 ) -> Result<Response<rpc::Result>, Status> {
110 tracing::debug!("ShimService::on_session_leave");
111 let resp = self.service.on_session_leave().await;
112
113 match resp {
114 Ok(_) => Ok(Response::new(rpc::Result {
115 return_code: 0,
116 message: None,
117 })),
118 Err(e) => Ok(Response::new(rpc::Result {
119 return_code: -1,
120 message: Some(e.to_string()),
121 })),
122 }
123 }
124}
125
126pub async fn run(service: impl FlameService) -> Result<(), Box<dyn std::error::Error>> {
127 let shim_service = ShimService {
128 service: Arc::new(service),
129 };
130
131 let uds = match std::env::var(FLAME_EXECUTOR_ID) {
132 Ok(executor_id) => UnixListener::bind(format!("/tmp/flame/shim/{executor_id}/fsi.sock"))?,
133 Err(_) => UnixListener::bind("/tmp/flame/shim/fsi.sock")?,
134 };
135
136 let uds_stream = UnixListenerStream::new(uds);
137
138 Server::builder()
139 .add_service(InstanceServer::new(shim_service))
140 .serve_with_incoming(uds_stream)
141 .await?;
142
143 Ok(())
144}
145
146impl From<rpc::ApplicationContext> for ApplicationContext {
147 fn from(ctx: rpc::ApplicationContext) -> Self {
148 Self {
149 name: ctx.name.clone(),
150 image: ctx.image.clone(),
151 command: ctx.command.clone(),
152 }
153 }
154}
155
156impl From<rpc::SessionContext> for SessionContext {
157 fn from(ctx: rpc::SessionContext) -> Self {
158 SessionContext {
159 session_id: ctx.session_id.clone(),
160 application: ctx.application.map(ApplicationContext::from).unwrap(),
161 common_data: ctx.common_data.map(|data| data.into()),
162 }
163 }
164}
165
166impl From<rpc::TaskContext> for TaskContext {
167 fn from(ctx: rpc::TaskContext) -> Self {
168 TaskContext {
169 task_id: ctx.task_id.clone(),
170 session_id: ctx.session_id.clone(),
171 input: ctx.input.map(|data| data.into()),
172 }
173 }
174}