flame_rs/service/
mod.rs

1/*
2Copyright 2025 The Flame Authors.
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6    http://www.apache.org/licenses/LICENSE-2.0
7Unless required by applicable law or agreed to in writing, software
8distributed under the License is distributed on an "AS IS" BASIS,
9WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10See the License for the specific language governing permissions and
11limitations under the License.
12*/
13
14use std::sync::Arc;
15
16use tokio::net::UnixListener;
17use tokio_stream::wrappers::UnixListenerStream;
18use tonic::{transport::Server, Request, Response, Status};
19
20use self::rpc::instance_server::{Instance, InstanceServer};
21use crate::apis::flame as rpc;
22
23use crate::apis::{CommonData, FlameError, TaskInput, TaskOutput};
24
/// Name of the environment variable that holds the executor ID.
///
/// When it is set, [`run`] binds its socket at
/// `/tmp/flame/shim/<executor id>/fsi.sock` instead of `/tmp/flame/shim/fsi.sock`.
const FLAME_EXECUTOR_ID: &str = "FLAME_EXECUTOR_ID";
26
/// Application metadata delivered to a [`FlameService`] with each session.
///
/// Built from the wire-level `rpc::ApplicationContext`. Every field is a plain
/// owned value, so the standard traits are derived to make the type
/// debuggable, cloneable and comparable.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ApplicationContext {
    /// Application name.
    pub name: String,
    /// Optional image, as received from the executor.
    pub image: Option<String>,
    /// Optional command, as received from the executor.
    pub command: Option<String>,
}
32
/// Session information handed to [`FlameService::on_session_enter`].
///
/// Produced from the RPC's `rpc::SessionContext` by the `From` conversion in
/// this module.
pub struct SessionContext {
    /// Session identifier.
    pub session_id: String,
    /// Application metadata carried with the session.
    pub application: ApplicationContext,
    /// Optional common data carried with the session, converted from the
    /// RPC's `common_data` field.
    pub common_data: Option<CommonData>,
}
38
/// Task information handed to [`FlameService::on_task_invoke`].
///
/// Produced from the RPC's `rpc::TaskContext` by the `From` conversion in this
/// module.
pub struct TaskContext {
    /// Task identifier.
    pub task_id: String,
    /// Session identifier carried with the task.
    pub session_id: String,
    /// Optional task input, converted from the RPC's `input` field.
    pub input: Option<TaskInput>,
}
44
45#[tonic::async_trait]
46pub trait FlameService: Send + Sync + 'static {
47    async fn on_session_enter(&self, _: SessionContext) -> Result<(), FlameError>;
48    async fn on_task_invoke(&self, _: TaskContext) -> Result<Option<TaskOutput>, FlameError>;
49    async fn on_session_leave(&self) -> Result<(), FlameError>;
50}
51
52pub type FlameServicePtr = Arc<dyn FlameService>;
53
54struct ShimService {
55    service: FlameServicePtr,
56}
57
58#[tonic::async_trait]
59impl Instance for ShimService {
60    async fn on_session_enter(
61        &self,
62        req: Request<rpc::SessionContext>,
63    ) -> Result<Response<rpc::Result>, Status> {
64        tracing::debug!("ShimService::on_session_enter");
65
66        let req = req.into_inner();
67        let resp = self
68            .service
69            .on_session_enter(SessionContext::from(req))
70            .await;
71
72        match resp {
73            Ok(_) => Ok(Response::new(rpc::Result {
74                return_code: 0,
75                message: None,
76            })),
77            Err(e) => Ok(Response::new(rpc::Result {
78                return_code: -1,
79                message: Some(e.to_string()),
80            })),
81        }
82    }
83
84    async fn on_task_invoke(
85        &self,
86        req: Request<rpc::TaskContext>,
87    ) -> Result<Response<rpc::TaskResult>, Status> {
88        tracing::debug!("ShimService::on_task_invoke");
89        let req = req.into_inner();
90        let resp = self.service.on_task_invoke(TaskContext::from(req)).await;
91
92        match resp {
93            Ok(data) => Ok(Response::new(rpc::TaskResult {
94                return_code: 0,
95                output: data.map(|d| d.into()),
96                message: None,
97            })),
98            Err(e) => Ok(Response::new(rpc::TaskResult {
99                return_code: -1,
100                output: None,
101                message: Some(e.to_string()),
102            })),
103        }
104    }
105
106    async fn on_session_leave(
107        &self,
108        _: Request<rpc::EmptyRequest>,
109    ) -> Result<Response<rpc::Result>, Status> {
110        tracing::debug!("ShimService::on_session_leave");
111        let resp = self.service.on_session_leave().await;
112
113        match resp {
114            Ok(_) => Ok(Response::new(rpc::Result {
115                return_code: 0,
116                message: None,
117            })),
118            Err(e) => Ok(Response::new(rpc::Result {
119                return_code: -1,
120                message: Some(e.to_string()),
121            })),
122        }
123    }
124}
125
126pub async fn run(service: impl FlameService) -> Result<(), Box<dyn std::error::Error>> {
127    let shim_service = ShimService {
128        service: Arc::new(service),
129    };
130
131    let uds = match std::env::var(FLAME_EXECUTOR_ID) {
132        Ok(executor_id) => UnixListener::bind(format!("/tmp/flame/shim/{executor_id}/fsi.sock"))?,
133        Err(_) => UnixListener::bind("/tmp/flame/shim/fsi.sock")?,
134    };
135
136    let uds_stream = UnixListenerStream::new(uds);
137
138    Server::builder()
139        .add_service(InstanceServer::new(shim_service))
140        .serve_with_incoming(uds_stream)
141        .await?;
142
143    Ok(())
144}
145
146impl From<rpc::ApplicationContext> for ApplicationContext {
147    fn from(ctx: rpc::ApplicationContext) -> Self {
148        Self {
149            name: ctx.name.clone(),
150            image: ctx.image.clone(),
151            command: ctx.command.clone(),
152        }
153    }
154}
155
156impl From<rpc::SessionContext> for SessionContext {
157    fn from(ctx: rpc::SessionContext) -> Self {
158        SessionContext {
159            session_id: ctx.session_id.clone(),
160            application: ctx.application.map(ApplicationContext::from).unwrap(),
161            common_data: ctx.common_data.map(|data| data.into()),
162        }
163    }
164}
165
166impl From<rpc::TaskContext> for TaskContext {
167    fn from(ctx: rpc::TaskContext) -> Self {
168        TaskContext {
169            task_id: ctx.task_id.clone(),
170            session_id: ctx.session_id.clone(),
171            input: ctx.input.map(|data| data.into()),
172        }
173    }
174}