dynamo_llm/grpc/service/
kserve.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::pin::Pin;
use std::sync::Arc;

use crate::grpc::service::kserve::inference::DataType;
use crate::grpc::service::kserve::inference::ModelInput;
use crate::grpc::service::kserve::inference::ModelOutput;
use crate::http::service::Metrics;
use crate::http::service::metrics;

use crate::discovery::ModelManager;
use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::transports::etcd;
use futures::pin_mut;
use tokio::task::JoinHandle;
use tokio_stream::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;

use crate::grpc::service::openai::completion_response_stream;
use crate::grpc::service::tensor::tensor_response_stream;
use std::convert::{TryFrom, TryInto};
use tonic::{Request, Response, Status, transport::Server};

use crate::protocols::openai::completions::{
    NvCreateCompletionRequest, NvCreateCompletionResponse,
};

pub mod inference {
    tonic::include_proto!("inference");
}
use inference::grpc_inference_service_server::{GrpcInferenceService, GrpcInferenceServiceServer};
use inference::{
    ModelConfig, ModelConfigRequest, ModelConfigResponse, ModelInferRequest, ModelInferResponse,
    ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};

/// Shared state for the KServe gRPC service.
///
/// [gluo TODO] `metrics` belong to the HTTP service, which already exposes an
/// HTTP endpoint for them. Should we always start the HTTP service for the
/// non-inference endpoints?
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
    etcd_client: Option<etcd::Client>,
}

impl State {
    pub fn new(manager: Arc<ModelManager>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client: None,
        }
    }

    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client,
        }
    }

    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        self.metrics.clone()
    }

    pub fn manager(&self) -> &ModelManager {
        Arc::as_ref(&self.manager)
    }

    pub fn manager_clone(&self) -> Arc<ModelManager> {
        self.manager.clone()
    }

    pub fn etcd_client(&self) -> Option<&etcd::Client> {
        self.etcd_client.as_ref()
    }
    fn is_tensor_model(&self, model: &str) -> bool {
        self.manager.list_tensor_models().iter().any(|name| name == model)
    }
}

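/// KServe-compatible gRPC frontend implementing the `GrpcInferenceService` RPCs.
///
/// The service is cheap to clone; every clone shares the same [`State`]
/// (model manager, metrics, optional etcd client) through an `Arc`.
///
/// A minimal usage sketch, assuming a tokio runtime and a caller-supplied
/// `CancellationToken` (error handling elided):
///
/// ```ignore
/// let service = KserveService::builder()
///     .port(8787)
///     .host("0.0.0.0")
///     .build()?;
/// let handle = service.spawn(cancel_token.clone()).await;
/// handle.await??;
/// ```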
#[derive(Clone)]
pub struct KserveService {
    // The state we share with every request handler
    state: Arc<State>,

    port: u16,
    host: String,
    request_template: Option<RequestTemplate>,
}

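/// Builder-backed configuration for [`KserveService`].
///
/// Defaults bind the server to `0.0.0.0:8787` with no request template and no
/// etcd client. `KserveServiceConfigBuilder::build` creates a fresh
/// [`ModelManager`] and registers the Prometheus metrics before returning the
/// service.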
#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct KserveServiceConfig {
    #[builder(default = "8787")]
    port: u16,

    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,

    #[builder(default = "None")]
    etcd_client: Option<etcd::Client>,
}

impl KserveService {
    pub fn builder() -> KserveServiceConfigBuilder {
        KserveServiceConfigBuilder::default()
    }

    pub fn state_clone(&self) -> Arc<State> {
        self.state.clone()
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

    pub fn model_manager(&self) -> &ModelManager {
        self.state().manager()
    }

    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let this = self.clone();
        tokio::spawn(async move { this.run(cancel_token).await })
    }

    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
        let address = format!("{}:{}", self.host, self.port);
        tracing::info!(address, "Starting KServe gRPC service");

        let observer = cancel_token.child_token();
        Server::builder()
            .add_service(GrpcInferenceServiceServer::new(self.clone()))
            .serve_with_shutdown(address.parse()?, observer.cancelled_owned())
            .await
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
    }
}

impl KserveServiceConfigBuilder {
    pub fn build(self) -> Result<KserveService, anyhow::Error> {
        let config: KserveServiceConfig = self.build_internal()?;

        let model_manager = Arc::new(ModelManager::new());
        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));

        // enable prometheus metrics
        let registry = metrics::Registry::new();
        state.metrics_clone().register(&registry)?;

        Ok(KserveService {
            state,
            port: config.port,
            host: config.host,
            request_template: config.request_template,
        })
    }

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }

    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
        self.etcd_client = Some(etcd_client);
        self
    }
}

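// KServe v2 inference protocol implementation.
//
// `model_infer` handles unary requests, `model_stream_infer` handles the
// bidirectional streaming RPC, and `model_metadata` / `model_config` report
// model signatures. Requests for registered tensor models are converted to
// `NvCreateTensorRequest`; anything else falls back to the OpenAI Completions
// path via `NvCreateCompletionRequest`.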
#[tonic::async_trait]
impl GrpcInferenceService for KserveService {
    async fn model_infer(
        &self,
        request: Request<ModelInferRequest>,
    ) -> Result<Response<ModelInferResponse>, Status> {
        let model = request.get_ref().model_name.clone();
        let request = request.into_inner();
        let request_id = request.id.clone();

        // [gluo TODO] refactor to reuse code, inference logic is largely the same
        if self.state().is_tensor_model(&model) {
            // Handle the request as a Tensor model.
            let tensor_request: NvCreateTensorRequest = NvCreateTensorRequest::try_from(request)
                .map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;

            let stream = tensor_response_stream(self.state_clone(), tensor_request, false).await?;

            let tensor_response = NvCreateTensorResponse::from_annotated_stream(stream)
                .await
                .map_err(|e| {
                    tracing::error!("Failed to fold tensor response stream: {:?}", e);
                    Status::internal(format!("Failed to fold tensor response stream: {}", e))
                })?;

            let mut reply: ModelInferResponse = tensor_response.try_into().map_err(|e| {
                Status::invalid_argument(format!("Failed to parse response: {}", e))
            })?;
            reply.id = request_id;

            return Ok(Response::new(reply));
        }

        // Fall back to treating the model as an OpenAI Completions model.
        let mut completion_request: NvCreateCompletionRequest = request
            .try_into()
            .map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;

        if completion_request.inner.stream.unwrap_or(false) {
            // Streaming responses are not supported on this unary endpoint.
            return Err(Status::invalid_argument(
                "Streaming is not supported for this endpoint",
            ));
        }

        // Apply template values if present
        if let Some(template) = self.request_template.as_ref() {
            if completion_request.inner.model.is_empty() {
                completion_request.inner.model = template.model.clone();
            }
            if completion_request.inner.temperature.unwrap_or(0.0) == 0.0 {
                completion_request.inner.temperature = Some(template.temperature);
            }
            if completion_request.inner.max_tokens.unwrap_or(0) == 0 {
                completion_request.inner.max_tokens = Some(template.max_completion_tokens);
            }
        }

        let model = completion_request.inner.model.clone();
        let parsing_options = self.state.manager.get_parsing_options(&model);

        let stream = completion_response_stream(self.state_clone(), completion_request).await?;

        let completion_response =
            NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
                .await
                .map_err(|e| {
                    tracing::error!("Failed to fold completions stream: {:?}", e);
                    Status::internal(format!("Failed to fold completions stream: {}", e))
                })?;

        let mut reply: ModelInferResponse = completion_response
            .try_into()
            .map_err(|e| Status::invalid_argument(format!("Failed to parse response: {}", e)))?;
        reply.id = request_id;

        Ok(Response::new(reply))
    }

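    // Bidirectional streaming: requests are pulled off the inbound stream one
    // at a time and each is answered with one or more `ModelStreamInferResponse`
    // messages on the outbound stream. Read errors are reported through
    // `error_message` so the stream can keep serving subsequent requests.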
    type ModelStreamInferStream =
        Pin<Box<dyn Stream<Item = Result<ModelStreamInferResponse, Status>> + Send + 'static>>;

    async fn model_stream_infer(
        &self,
        request: Request<tonic::Streaming<ModelInferRequest>>,
    ) -> Result<Response<Self::ModelStreamInferStream>, Status> {
        let mut request_stream = request.into_inner();
        let state = self.state_clone();
        let template = self.request_template.clone();
        let output = async_stream::try_stream! {
            // [gluo FIXME] should be able to demux request / response streaming:
            // await requests in a separate task until cancellation / completion,
            // and pass the AsyncEngineStream for each request to the response
            // stream, which polls them collectively.
            while let Some(request) = request_stream.next().await {
                let request = match request {
                    Err(e) => {
                        tracing::error!("Unexpected gRPC error while reading request: {}", e);
                        yield ModelStreamInferResponse {
                            error_message: e.to_string(),
                            infer_response: None
                        };
                        continue;
                    }
                    Ok(request) => request,
                };

                let model = request.model_name.clone();

                // [gluo TODO] refactor to reuse code, inference logic is largely the same
                if state.is_tensor_model(&model) {
                    // Keep track of the request 'id'; it must be echoed in the corresponding response.
                    let request_id = request.id.clone();
                    let tensor_request: NvCreateTensorRequest = request.try_into().map_err(|e| {
                        Status::invalid_argument(format!("Failed to parse request: {}", e))
                    })?;

                    let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;

                    pin_mut!(stream);
                    while let Some(response) = stream.next().await {
                        match response.data {
                            Some(data) => {
                                let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
                                    Status::invalid_argument(format!("Failed to parse response: {}", e))
                                })?;
                                if let Some(infer_response) = reply.infer_response.as_mut() {
                                    infer_response.id = request_id.clone();
                                }
                                yield reply;
                            },
                            None => {
                                // No data; the response only carries annotations, so skip it.
                            },
                        }
                    }
                    continue;
                }

                // Fall back to treating the model as an OpenAI Completions model.
                // Keep track of the request 'id'; it must be echoed in the corresponding response.
                let request_id = request.id.clone();
                let mut completion_request: NvCreateCompletionRequest = request.try_into().map_err(|e| {
                    Status::invalid_argument(format!("Failed to parse request: {}", e))
                })?;

                // Apply template values if present
                if let Some(template) = &template {
                    if completion_request.inner.model.is_empty() {
                        completion_request.inner.model = template.model.clone();
                    }
                    if completion_request.inner.temperature.unwrap_or(0.0) == 0.0 {
                        completion_request.inner.temperature = Some(template.temperature);
                    }
                    if completion_request.inner.max_tokens.unwrap_or(0) == 0 {
                        completion_request.inner.max_tokens = Some(template.max_completion_tokens);
                    }
                }

                let model = completion_request.inner.model.clone();
                let parsing_options = state.manager.get_parsing_options(&model);

                let streaming = completion_request.inner.stream.unwrap_or(false);

                let stream = completion_response_stream(state.clone(), completion_request).await?;

                if streaming {
                    pin_mut!(stream);
                    while let Some(response) = stream.next().await {
                        match response.data {
                            Some(data) => {
                                let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
                                    Status::invalid_argument(format!("Failed to parse response: {}", e))
                                })?;
                                if let Some(infer_response) = reply.infer_response.as_mut() {
                                    infer_response.id = request_id.clone();
                                }
                                yield reply;
                            },
                            None => {
                                // No data; the response only carries annotations, so skip it.
                            },
                        }
                    }
                } else {
                    let completion_response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
                        .await
                        .map_err(|e| {
                            tracing::error!("Failed to fold completions stream: {:?}", e);
                            Status::internal(format!("Failed to fold completions stream: {}", e))
                        })?;

                    let mut response: ModelStreamInferResponse = completion_response.try_into().map_err(|e| {
                        Status::invalid_argument(format!("Failed to parse response: {}", e))
                    })?;
                    if let Some(infer_response) = response.infer_response.as_mut() {
                        infer_response.id = request_id.clone();
                    }
                    yield response;
                }
            }
        };

        Ok(Response::new(
            Box::pin(output) as Self::ModelStreamInferStream
        ))
    }

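    // Reports the model's I/O signature: tensor models return their registered
    // tensor config; Completions models return a fixed text_input / text_output
    // signature.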
    async fn model_metadata(
        &self,
        request: Request<ModelMetadataRequest>,
    ) -> Result<Response<ModelMetadataResponse>, Status> {
        let entries = self.state.manager().get_model_entries();
        let request_model_name = &request.into_inner().name;
        if let Some(entry) = entries
            .into_iter()
            .find(|entry| request_model_name == &entry.name)
        {
            if entry.model_type.supports_tensor() {
                if let Some(config) = entry.runtime_config.as_ref()
                    && let Some(tensor_model_config) = config.tensor_model_config.as_ref()
                {
                    return Ok(Response::new(ModelMetadataResponse {
                        name: tensor_model_config.name.clone(),
                        versions: vec!["1".to_string()],
                        platform: "dynamo".to_string(),
                        inputs: tensor_model_config
                            .inputs
                            .iter()
                            .map(|input| inference::model_metadata_response::TensorMetadata {
                                name: input.name.clone(),
                                datatype: input.data_type.to_string(),
                                shape: input.shape.clone(),
                            })
                            .collect(),
                        outputs: tensor_model_config
                            .outputs
                            .iter()
                            .map(
                                |output| inference::model_metadata_response::TensorMetadata {
                                    name: output.name.clone(),
                                    datatype: output.data_type.to_string(),
                                    shape: output.shape.clone(),
                                },
                            )
                            .collect(),
                    }));
                }
                return Err(Status::invalid_argument(format!(
                    "Model '{}' has type Tensor but no model config is provided",
                    request_model_name
                )));
            } else if entry.model_type.supports_completions() {
                return Ok(Response::new(ModelMetadataResponse {
                    name: entry.name,
                    versions: vec!["1".to_string()],
                    platform: "dynamo".to_string(),
                    inputs: vec![
                        inference::model_metadata_response::TensorMetadata {
                            name: "text_input".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![1],
                        },
                        inference::model_metadata_response::TensorMetadata {
                            name: "streaming".to_string(),
                            datatype: "BOOL".to_string(),
                            shape: vec![1],
                        },
                    ],
                    outputs: vec![
                        inference::model_metadata_response::TensorMetadata {
                            name: "text_output".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![-1],
                        },
                        inference::model_metadata_response::TensorMetadata {
                            name: "finish_reason".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![-1],
                        },
                    ],
                }));
            }
        }
        Err(Status::not_found(format!(
            "Model '{}' not found",
            request_model_name
        )))
    }

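    // Same lookup as `model_metadata`, but returns the signature as a
    // `ModelConfig` (data types and dims per input/output).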
    async fn model_config(
        &self,
        request: Request<ModelConfigRequest>,
    ) -> Result<Response<ModelConfigResponse>, Status> {
        let entries = self.state.manager().get_model_entries();
        let request_model_name = &request.into_inner().name;
        if let Some(entry) = entries
            .into_iter()
            .find(|entry| request_model_name == &entry.name)
        {
            if entry.model_type.supports_tensor() {
                if let Some(config) = entry.runtime_config.as_ref()
                    && let Some(tensor_model_config) = config.tensor_model_config.as_ref()
                {
                    let model_config = ModelConfig {
                        name: tensor_model_config.name.clone(),
                        platform: "dynamo".to_string(),
                        backend: "dynamo".to_string(),
                        input: tensor_model_config
                            .inputs
                            .iter()
                            .map(|input| ModelInput {
                                name: input.name.clone(),
                                data_type: input.data_type.to_kserve(),
                                dims: input.shape.clone(),
                                ..Default::default()
                            })
                            .collect(),
                        output: tensor_model_config
                            .outputs
                            .iter()
                            .map(|output| ModelOutput {
                                name: output.name.clone(),
                                data_type: output.data_type.to_kserve(),
                                dims: output.shape.clone(),
                                ..Default::default()
                            })
                            .collect(),
                        ..Default::default()
                    };
                    return Ok(Response::new(ModelConfigResponse {
                        config: Some(model_config),
                    }));
                }
                return Err(Status::invalid_argument(format!(
                    "Model '{}' has type Tensor but no model config is provided",
                    request_model_name
                )));
            } else if entry.model_type.supports_completions() {
                let config = ModelConfig {
                    name: entry.name,
                    platform: "dynamo".to_string(),
                    backend: "dynamo".to_string(),
                    input: vec![
                        ModelInput {
                            name: "text_input".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![1],
                            ..Default::default()
                        },
                        ModelInput {
                            name: "streaming".to_string(),
                            data_type: DataType::TypeBool as i32,
                            dims: vec![1],
                            optional: true,
                            ..Default::default()
                        },
                    ],
                    output: vec![
                        ModelOutput {
                            name: "text_output".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![-1],
                            ..Default::default()
                        },
                        ModelOutput {
                            name: "finish_reason".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![-1],
                            ..Default::default()
                        },
                    ],
                    ..Default::default()
                };
                return Ok(Response::new(ModelConfigResponse {
                    config: Some(config),
                }));
            }
        }
        Err(Status::not_found(format!(
            "Model '{}' not found",
            request_model_name
        )))
    }
}