Skip to main content

solti_api/
grpc.rs

1//! # gRPC transport.
2//!
3//! [`SoltiApiService`] implements the generated `SoltiApi` trait from `proto/solti/v1/api.proto`,
4//! delegating to an [`ApiHandler`](crate::ApiHandler).
5
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Instant;
9
10use tokio_stream::StreamExt;
11use tonic::{Request, Response, Status};
12use tracing::debug;
13
14use solti_model::TaskQuery;
15
16use crate::convert::{output_event_to_proto, proto_to_domain_status, tasks_page_to_proto};
17use crate::error::ApiError;
18use crate::handler::ApiHandler;
19use crate::metrics::{ApiMetricsHandle, Transport, noop_api_metrics};
20use crate::proto_api::{self, solti_api_server::SoltiApi, solti_api_server::SoltiApiServer};
21use crate::validate::{clamp_list_limit, non_empty_id};
22
23/// gRPC service wrapping an [`ApiHandler`](crate::ApiHandler).
24///
25/// ## Also
26///
27/// - `SoltiApiServer` generated tonic server wrapper.
28/// - [`ApiError`](crate::ApiError) mapped to `tonic::Status`.
29pub struct SoltiApiService<H> {
30    handler: Arc<H>,
31    metrics: ApiMetricsHandle,
32}
33
34impl<H> SoltiApiService<H>
35where
36    H: ApiHandler,
37{
38    /// Create a new gRPC service with the given handler and no-op metrics.
39    pub fn new(handler: Arc<H>) -> Self {
40        Self::new_with_metrics(handler, noop_api_metrics())
41    }
42
43    /// Create a new gRPC service with an explicit metrics backend.
44    pub fn new_with_metrics(handler: Arc<H>, metrics: ApiMetricsHandle) -> Self {
45        Self { handler, metrics }
46    }
47
48    async fn instrument<F, T>(&self, method: &'static str, fut: F) -> Result<Response<T>, Status>
49    where
50        F: Future<Output = Result<Response<T>, Status>>,
51    {
52        self.metrics.record_in_flight_delta(Transport::Grpc, 1);
53        let start = Instant::now();
54        let result = fut.await;
55        let duration_ms = start.elapsed().as_millis() as u64;
56        let status = match &result {
57            Ok(_) => 0u16,
58            Err(s) => s.code() as u16,
59        };
60        let path = format!("/solti.v1.SoltiApi/{}", method);
61        self.metrics
62            .record_request(Transport::Grpc, method, &path, status, duration_ms);
63        self.metrics.record_in_flight_delta(Transport::Grpc, -1);
64        result
65    }
66}
67
68/// Build a configured `SoltiApiServer` with no-op metrics.
69///
70/// ## Example
71///
72/// ```rust,no_run
73/// # use std::sync::Arc;
74/// # use solti_api::{build_grpc_server, SupervisorApiAdapter};
75/// # async fn example(adapter: Arc<SupervisorApiAdapter>) -> Result<(), Box<dyn std::error::Error>> {
76/// let svc = build_grpc_server(adapter);
77/// tonic::transport::Server::builder()
78///     .add_service(svc)
79///     .serve("0.0.0.0:50052".parse()?)
80///     .await?;
81/// # Ok(()) }
82/// ```
83pub fn build_grpc_server<H>(handler: Arc<H>) -> SoltiApiServer<SoltiApiService<H>>
84where
85    H: ApiHandler,
86{
87    build_grpc_server_with_metrics(handler, noop_api_metrics())
88}
89
90/// Build a configured `SoltiApiServer` with an explicit metrics backend.
91pub fn build_grpc_server_with_metrics<H>(
92    handler: Arc<H>,
93    metrics: ApiMetricsHandle,
94) -> SoltiApiServer<SoltiApiService<H>>
95where
96    H: ApiHandler,
97{
98    SoltiApiServer::new(SoltiApiService::new_with_metrics(handler, metrics))
99        .max_decoding_message_size(crate::MAX_REQUEST_BYTES)
100        .max_encoding_message_size(crate::MAX_REQUEST_BYTES)
101}
102
103#[tonic::async_trait]
104impl<H> SoltiApi for SoltiApiService<H>
105where
106    H: ApiHandler,
107{
108    async fn submit_task(
109        &self,
110        request: Request<proto_api::SubmitTaskRequest>,
111    ) -> Result<Response<proto_api::SubmitTaskResponse>, Status> {
112        self.instrument("SubmitTask", async move {
113            let req = request.into_inner();
114
115            let spec = req
116                .spec
117                .ok_or_else(|| Status::invalid_argument("missing spec"))?;
118
119            let spec =
120                crate::convert::convert_create_spec(spec).map_err(|e: ApiError| Status::from(e))?;
121
122            debug!(slot = %spec.slot(), kind = ?spec.kind(), "grpc: submitting task");
123            let task_id = self.handler.submit_task(spec).await.map_err(Status::from)?;
124
125            Ok(Response::new(proto_api::SubmitTaskResponse {
126                task_id: task_id.to_string(),
127            }))
128        })
129        .await
130    }
131
132    async fn get_task_status(
133        &self,
134        request: Request<proto_api::GetTaskStatusRequest>,
135    ) -> Result<Response<proto_api::GetTaskStatusResponse>, Status> {
136        self.instrument("GetTaskStatus", async move {
137            let req = request.into_inner();
138
139            non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
140
141            let task_id = solti_model::TaskId::from(req.task_id);
142            debug!(%task_id, "grpc: getting task status");
143
144            let info = self
145                .handler
146                .get_task_status(&task_id)
147                .await
148                .map_err(Status::from)?;
149
150            let task = info
151                .map(proto_api::TaskData::try_from)
152                .transpose()
153                .map_err(Status::from)?;
154
155            Ok(Response::new(proto_api::GetTaskStatusResponse { task }))
156        })
157        .await
158    }
159
160    async fn list_tasks(
161        &self,
162        request: Request<proto_api::ListTasksRequest>,
163    ) -> Result<Response<proto_api::ListTasksResponse>, Status> {
164        self.instrument("ListTasks", async move {
165            let req = request.into_inner();
166
167            let mut query = TaskQuery::new();
168
169            if let Some(slot) = req.slot {
170                non_empty_id("slot", &slot).map_err(Status::from)?;
171                query = query.with_slot(slot);
172            }
173
174            if let Some(status_raw) = req.status {
175                let status = proto_to_domain_status(status_raw).map_err(Status::from)?;
176                query = query.with_status(status);
177            }
178
179            query = query.with_limit(clamp_list_limit(req.limit));
180            if req.offset > 0 {
181                query = query.with_offset(req.offset as usize);
182            }
183
184            let page = self
185                .handler
186                .query_tasks(query)
187                .await
188                .map_err(Status::from)?;
189
190            debug!(
191                count = page.items.len(),
192                total = page.total,
193                "grpc: tasks listed"
194            );
195
196            let response = tasks_page_to_proto(page).map_err(Status::from)?;
197            Ok(Response::new(response))
198        })
199        .await
200    }
201
202    async fn list_task_runs(
203        &self,
204        request: Request<proto_api::ListTaskRunsRequest>,
205    ) -> Result<Response<proto_api::ListTaskRunsResponse>, Status> {
206        self.instrument("ListTaskRuns", async move {
207            let req = request.into_inner();
208
209            non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
210
211            let task_id = solti_model::TaskId::from(req.task_id);
212            debug!(%task_id, "grpc: listing task runs");
213
214            let runs = self
215                .handler
216                .list_task_runs(&task_id)
217                .await
218                .map_err(Status::from)?;
219
220            let runs = runs.into_iter().map(proto_api::TaskRunInfo::from).collect();
221
222            Ok(Response::new(proto_api::ListTaskRunsResponse { runs }))
223        })
224        .await
225    }
226
227    async fn delete_task(
228        &self,
229        request: Request<proto_api::DeleteTaskRequest>,
230    ) -> Result<Response<proto_api::DeleteTaskResponse>, Status> {
231        self.instrument("DeleteTask", async move {
232            let req = request.into_inner();
233
234            non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
235
236            let task_id = solti_model::TaskId::from(req.task_id);
237            debug!(%task_id, "grpc: deleting task");
238
239            self.handler
240                .delete_task(&task_id)
241                .await
242                .map_err(Status::from)?;
243
244            debug!(%task_id, "grpc: task deleted");
245            Ok(Response::new(proto_api::DeleteTaskResponse {}))
246        })
247        .await
248    }
249
250    /// Server-streaming RPC.
251    type StreamTaskLogsStream = Pin<
252        Box<
253            dyn tokio_stream::Stream<Item = Result<proto_api::OutputEventProto, Status>>
254                + Send
255                + 'static,
256        >,
257    >;
258
259    async fn stream_task_logs(
260        &self,
261        request: Request<proto_api::StreamTaskLogsRequest>,
262    ) -> Result<Response<Self::StreamTaskLogsStream>, Status> {
263        let req = request.into_inner();
264        non_empty_id("task_id", &req.task_id).map_err(Status::from)?;
265
266        let task_id = solti_model::TaskId::from(req.task_id);
267        debug!(%task_id, "grpc: subscribing to task log stream");
268
269        let domain_stream = self
270            .handler
271            .stream_task_logs(&task_id)
272            .await
273            .map_err(Status::from)?;
274
275        let proto_stream = domain_stream.map(|ev| Ok(output_event_to_proto(ev)));
276        Ok(Response::new(Box::pin(proto_stream)))
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    use std::time::{Duration, UNIX_EPOCH};

    use async_trait::async_trait;
    use bytes::Bytes;
    use solti_model::{
        OutputChunk, OutputEvent, StreamKind as ModelStreamKind, Task, TaskId, TaskPage, TaskQuery,
        TaskRun, TaskSpec,
    };

    use crate::error::ApiError;
    use crate::handler::{ApiHandler, OutputEventStream};
    use crate::proto_api::output_event_proto::Kind;

    /// Timestamp `ms` milliseconds after the Unix epoch.
    fn at_ms(ms: u64) -> std::time::SystemTime {
        UNIX_EPOCH + Duration::from_millis(ms)
    }

    /// Handler that only supports log streaming: the id `"missing"` is unknown; any other id
    /// yields a fixed RunStarted -> Chunk -> RunFinished sequence.
    struct StreamMock;

    #[async_trait]
    impl ApiHandler for StreamMock {
        async fn submit_task(&self, _spec: TaskSpec) -> Result<TaskId, ApiError> {
            unreachable!()
        }
        async fn get_task_status(&self, _id: &TaskId) -> Result<Option<Task>, ApiError> {
            unreachable!()
        }
        async fn query_tasks(&self, _q: TaskQuery) -> Result<TaskPage<Task>, ApiError> {
            unreachable!()
        }
        async fn list_task_runs(&self, _id: &TaskId) -> Result<Vec<TaskRun>, ApiError> {
            unreachable!()
        }
        async fn delete_task(&self, _id: &TaskId) -> Result<(), ApiError> {
            unreachable!()
        }
        async fn stream_task_logs(&self, id: &TaskId) -> Result<OutputEventStream, ApiError> {
            if id.as_str() == "missing" {
                return Err(ApiError::TaskNotFound(id.to_string()));
            }

            let started = OutputEvent::RunStarted {
                attempt: 1,
                started_at: at_ms(1000),
            };
            let chunk = OutputEvent::Chunk(OutputChunk {
                attempt: 1,
                stream: ModelStreamKind::Stdout,
                seq: 0,
                ts: at_ms(1100),
                line: Bytes::from_static(b"hello-grpc"),
            });
            let finished = OutputEvent::RunFinished {
                attempt: 1,
                exit_code: Some(0),
                finished_at: at_ms(1500),
            };

            Ok(Box::pin(tokio_stream::iter(vec![started, chunk, finished])))
        }
    }

    fn service() -> SoltiApiService<StreamMock> {
        SoltiApiService::new(Arc::new(StreamMock))
    }

    /// Build a log-stream request for `task_id`.
    fn request(task_id: &str) -> Request<proto_api::StreamTaskLogsRequest> {
        Request::new(proto_api::StreamTaskLogsRequest {
            task_id: task_id.into(),
        })
    }

    /// Pull the next event off `stream`, asserting it is present, `Ok`, and has a kind.
    async fn next_kind<S>(stream: &mut S) -> Kind
    where
        S: tokio_stream::Stream<Item = Result<proto_api::OutputEventProto, Status>> + Unpin,
    {
        stream.next().await.unwrap().unwrap().kind.unwrap()
    }

    /// Call `stream_task_logs` for `task_id`, expecting it to fail, and return the status.
    async fn rejected_status(task_id: &str) -> Status {
        match service().stream_task_logs(request(task_id)).await {
            Ok(_) => panic!("expected error status"),
            Err(status) => status,
        }
    }

    #[tokio::test]
    async fn stream_task_logs_returns_three_proto_events_in_order() {
        let svc = service();
        let mut events = svc
            .stream_task_logs(request("tsk_1"))
            .await
            .expect("stream Ok")
            .into_inner();

        match next_kind(&mut events).await {
            Kind::RunStarted(r) => {
                assert_eq!(r.attempt, 1);
                assert_eq!(r.started_at, 1000);
            }
            other => panic!("expected RunStarted, got {other:?}"),
        }

        match next_kind(&mut events).await {
            Kind::Chunk(c) => {
                assert_eq!(c.attempt, 1);
                assert_eq!(c.stream, proto_api::OutputStreamKind::Stdout as i32);
                assert_eq!(c.seq, 0);
                assert_eq!(&c.line[..], b"hello-grpc");
            }
            other => panic!("expected Chunk, got {other:?}"),
        }

        match next_kind(&mut events).await {
            Kind::RunFinished(r) => {
                assert_eq!(r.attempt, 1);
                assert_eq!(r.exit_code, Some(0));
                assert_eq!(r.finished_at, 1500);
            }
            other => panic!("expected RunFinished, got {other:?}"),
        }

        assert!(events.next().await.is_none(), "stream must terminate");
    }

    #[tokio::test]
    async fn stream_task_logs_rejects_empty_task_id() {
        let status = rejected_status("  ").await;
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[tokio::test]
    async fn stream_task_logs_maps_task_not_found_to_not_found_status() {
        let status = rejected_status("missing").await;
        assert_eq!(status.code(), tonic::Code::NotFound);
    }
}