use std::pin::Pin;
use std::sync::Arc;

use crate::grpc::service::kserve::inference::DataType;
use crate::grpc::service::kserve::inference::ModelInput;
use crate::grpc::service::kserve::inference::ModelOutput;
use crate::http::service::Metrics;
use crate::http::service::metrics;

use crate::discovery::ModelManager;
use crate::protocols::tensor::{NvCreateTensorRequest, NvCreateTensorResponse};
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::transports::etcd;
use futures::pin_mut;
use tokio::task::JoinHandle;
use tokio_stream::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;

use crate::grpc::service::openai::completion_response_stream;
use crate::grpc::service::tensor::tensor_response_stream;
use std::convert::{TryFrom, TryInto};
use tonic::{Request, Response, Status, transport::Server};

use crate::protocols::openai::completions::{
    NvCreateCompletionRequest, NvCreateCompletionResponse,
};

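/// Generated gRPC bindings for the KServe `inference` protobuf package,
/// produced at build time by `tonic::include_proto!`.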
pub mod inference {
    tonic::include_proto!("inference");
}
use inference::grpc_inference_service_server::{GrpcInferenceService, GrpcInferenceServiceServer};
use inference::{
    ModelConfig, ModelConfigRequest, ModelConfigResponse, ModelInferRequest, ModelInferResponse,
    ModelMetadataRequest, ModelMetadataResponse, ModelStreamInferResponse,
};

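/// Shared state for the KServe gRPC service: service metrics, the model
/// manager used to resolve registered models, and an optional etcd client.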
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
    etcd_client: Option<etcd::Client>,
}

impl State {
    pub fn new(manager: Arc<ModelManager>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client: None,
        }
    }

    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client,
        }
    }

    pub fn metrics_clone(&self) -> Arc<Metrics> {
        self.metrics.clone()
    }

    pub fn manager(&self) -> &ModelManager {
        Arc::as_ref(&self.manager)
    }

    pub fn manager_clone(&self) -> Arc<ModelManager> {
        self.manager.clone()
    }

    pub fn etcd_client(&self) -> Option<&etcd::Client> {
        self.etcd_client.as_ref()
    }

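    /// Returns true if `model` is currently registered with the manager as a
    /// tensor model.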
    fn is_tensor_model(&self, model: &str) -> bool {
        self.manager.list_tensor_models().iter().any(|m| m == model)
    }
}

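/// KServe-compatible gRPC frontend. Wraps the shared [`State`] together with
/// the bind address and an optional [`RequestTemplate`] used to fill in
/// defaults for completion requests.
///
/// A minimal construction sketch (illustrative only, not taken from this
/// crate's tests; assumes the builder defaults of `0.0.0.0:8787` and an
/// ambient Tokio runtime):
///
/// ```ignore
/// let service = KserveService::builder().build()?;
/// let cancel = tokio_util::sync::CancellationToken::new();
/// let handle = service.spawn(cancel.clone()).await;
/// // ... later, shut the server down:
/// cancel.cancel();
/// handle.await??;
/// ```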
#[derive(Clone)]
pub struct KserveService {
    state: Arc<State>,

    port: u16,
    host: String,
    request_template: Option<RequestTemplate>,
}

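/// Builder-backed configuration for [`KserveService`]; the defaults bind to
/// `0.0.0.0:8787` with no request template and no etcd client.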
#[derive(Clone, Builder)]
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
pub struct KserveServiceConfig {
    #[builder(default = "8787")]
    port: u16,

    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,

    #[builder(default = "None")]
    etcd_client: Option<etcd::Client>,
}

impl KserveService {
    pub fn builder() -> KserveServiceConfigBuilder {
        KserveServiceConfigBuilder::default()
    }

    pub fn state_clone(&self) -> Arc<State> {
        self.state.clone()
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

    pub fn model_manager(&self) -> &ModelManager {
        self.state().manager()
    }

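    /// Spawns [`Self::run`] on a background Tokio task and returns its join
    /// handle.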
    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let this = self.clone();
        tokio::spawn(async move { this.run(cancel_token).await })
    }

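    /// Serves the gRPC endpoint until `cancel_token` is cancelled. If the
    /// server exits with an error, the token is cancelled so that sibling
    /// tasks observing it shut down as well.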
    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
        let address = format!("{}:{}", self.host, self.port);
        tracing::info!(address, "Starting KServe gRPC service on: {address}");

        let observer = cancel_token.child_token();
        Server::builder()
            .add_service(GrpcInferenceServiceServer::new(self.clone()))
            .serve_with_shutdown(address.parse()?, observer.cancelled_owned())
            .await
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
    }
}

impl KserveServiceConfigBuilder {
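    /// Consumes the builder: constructs a fresh [`ModelManager`], registers
    /// the service metrics against a new registry, and returns the configured
    /// [`KserveService`].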
    pub fn build(self) -> Result<KserveService, anyhow::Error> {
        let config: KserveServiceConfig = self.build_internal()?;

        let model_manager = Arc::new(ModelManager::new());
        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));

        let registry = metrics::Registry::new();
        state.metrics_clone().register(&registry)?;

        Ok(KserveService {
            state,
            port: config.port,
            host: config.host,
            request_template: config.request_template,
        })
    }

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }

    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
        self.etcd_client = Some(etcd_client);
        self
    }
}

#[tonic::async_trait]
impl GrpcInferenceService for KserveService {
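    /// Unary inference. Tensor models are folded from the tensor response
    /// stream; all other models are treated as OpenAI-style completion
    /// requests, and streaming requests are rejected on this endpoint.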
    async fn model_infer(
        &self,
        request: Request<ModelInferRequest>,
    ) -> Result<Response<ModelInferResponse>, Status> {
        let model = request.get_ref().model_name.clone();
        let request = request.into_inner();
        let request_id = request.id.clone();

        if self.state().is_tensor_model(&model) {
            let tensor_request: NvCreateTensorRequest = NvCreateTensorRequest::try_from(request)
                .map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;

            let stream = tensor_response_stream(self.state_clone(), tensor_request, false).await?;

            let tensor_response = NvCreateTensorResponse::from_annotated_stream(stream)
                .await
                .map_err(|e| {
                    tracing::error!("Failed to fold tensor response stream: {:?}", e);
                    Status::internal(format!("Failed to fold tensor response stream: {}", e))
                })?;

            let mut reply: ModelInferResponse = tensor_response.try_into().map_err(|e| {
                Status::invalid_argument(format!("Failed to parse response: {}", e))
            })?;
            reply.id = request_id;

            return Ok(Response::new(reply));
        }

        let mut completion_request: NvCreateCompletionRequest = request
            .try_into()
            .map_err(|e| Status::invalid_argument(format!("Failed to parse request: {}", e)))?;

        if completion_request.inner.stream.unwrap_or(false) {
            return Err(Status::invalid_argument(
                "Streaming is not supported for this endpoint",
            ));
        }

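        // Fill in defaults from the request template for any fields the
        // client left unset (model name, temperature, max_tokens).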
        if let Some(template) = self.request_template.as_ref() {
            if completion_request.inner.model.is_empty() {
                completion_request.inner.model = template.model.clone();
            }
            if completion_request.inner.temperature.unwrap_or(0.0) == 0.0 {
                completion_request.inner.temperature = Some(template.temperature);
            }
            if completion_request.inner.max_tokens.unwrap_or(0) == 0 {
                completion_request.inner.max_tokens = Some(template.max_completion_tokens);
            }
        }

        let model = completion_request.inner.model.clone();
        let parsing_options = self.state.manager.get_parsing_options(&model);

        let stream = completion_response_stream(self.state_clone(), completion_request).await?;

        let completion_response =
            NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
                .await
                .map_err(|e| {
                    tracing::error!("Failed to fold completions stream: {:?}", e);
                    Status::internal(format!("Failed to fold completions stream: {}", e))
                })?;

        let mut reply: ModelInferResponse = completion_response
            .try_into()
            .map_err(|e| Status::invalid_argument(format!("Failed to parse response: {}", e)))?;
        reply.id = request_id;

        Ok(Response::new(reply))
    }

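    /// Boxed response stream type required by the generated
    /// `GrpcInferenceService` trait for `ModelStreamInfer`.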
    type ModelStreamInferStream =
        Pin<Box<dyn Stream<Item = Result<ModelStreamInferResponse, Status>> + Send + 'static>>;

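    /// Bidirectional streaming inference. Each incoming request is routed by
    /// model type. Failures to read a request are reported in-band via
    /// `error_message` and the stream continues; parse and conversion
    /// failures end the stream with a gRPC error status.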
    async fn model_stream_infer(
        &self,
        request: Request<tonic::Streaming<ModelInferRequest>>,
    ) -> Result<Response<Self::ModelStreamInferStream>, Status> {
        let mut request_stream = request.into_inner();
        let state = self.state_clone();
        let template = self.request_template.clone();
        let output = async_stream::try_stream! {
            while let Some(request) = request_stream.next().await {
                let request = match request {
                    Err(e) => {
                        tracing::error!("Failed to read gRPC request from stream: {}", e);
                        yield ModelStreamInferResponse {
                            error_message: e.to_string(),
                            infer_response: None
                        };
                        continue;
                    }
                    Ok(request) => request,
                };

                let model = request.model_name.clone();

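                // Route by registered model type: tensor models use the
                // tensor protocol end-to-end, everything else goes through
                // the OpenAI-style completions path below.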
                if state.is_tensor_model(&model) {
                    let request_id = request.id.clone();
                    let tensor_request: NvCreateTensorRequest = request.try_into().map_err(|e| {
                        Status::invalid_argument(format!("Failed to parse request: {}", e))
                    })?;

                    let stream = tensor_response_stream(state.clone(), tensor_request, true).await?;

                    pin_mut!(stream);
                    while let Some(response) = stream.next().await {
                        match response.data {
                            Some(data) => {
                                let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
                                    Status::invalid_argument(format!("Failed to parse response: {}", e))
                                })?;
                                if let Some(infer_response) = reply.infer_response.as_mut() {
                                    infer_response.id = request_id.clone();
                                }
                                yield reply;
                            },
                            None => {
                                // No inference payload in this item; nothing to forward.
                            },
                        }
                    }
                    continue;
                }

                let request_id = request.id.clone();
                let mut completion_request: NvCreateCompletionRequest = request.try_into().map_err(|e| {
                    Status::invalid_argument(format!("Failed to parse request: {}", e))
                })?;

                if let Some(template) = &template {
                    if completion_request.inner.model.is_empty() {
                        completion_request.inner.model = template.model.clone();
                    }
                    if completion_request.inner.temperature.unwrap_or(0.0) == 0.0 {
                        completion_request.inner.temperature = Some(template.temperature);
                    }
                    if completion_request.inner.max_tokens.unwrap_or(0) == 0 {
                        completion_request.inner.max_tokens = Some(template.max_completion_tokens);
                    }
                }

                let model = completion_request.inner.model.clone();
                let parsing_options = state.manager.get_parsing_options(&model);

                let streaming = completion_request.inner.stream.unwrap_or(false);

                let stream = completion_response_stream(state.clone(), completion_request).await?;

                if streaming {
                    pin_mut!(stream);
                    while let Some(response) = stream.next().await {
                        match response.data {
                            Some(data) => {
                                let mut reply = ModelStreamInferResponse::try_from(data).map_err(|e| {
                                    Status::invalid_argument(format!("Failed to parse response: {}", e))
                                })?;
                                if let Some(infer_response) = reply.infer_response.as_mut() {
                                    infer_response.id = request_id.clone();
                                }
                                yield reply;
                            },
                            None => {
                                // No inference payload in this item; nothing to forward.
                            },
                        }
                    }
                } else {
                    let completion_response = NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
                        .await
                        .map_err(|e| {
                            tracing::error!("Failed to fold completions stream: {:?}", e);
                            Status::internal(format!("Failed to fold completions stream: {}", e))
                        })?;

                    let mut response: ModelStreamInferResponse = completion_response.try_into().map_err(|e| {
                        Status::invalid_argument(format!("Failed to parse response: {}", e))
                    })?;
                    if let Some(infer_response) = response.infer_response.as_mut() {
                        infer_response.id = request_id.clone();
                    }
                    yield response;
                }
            }
        };

        Ok(Response::new(
            Box::pin(output) as Self::ModelStreamInferStream
        ))
    }

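    /// Reports model metadata. Tensor models describe their inputs and
    /// outputs from the registered `tensor_model_config`; completion models
    /// report the fixed `text_input`/`streaming` inputs and
    /// `text_output`/`finish_reason` outputs.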
    async fn model_metadata(
        &self,
        request: Request<ModelMetadataRequest>,
    ) -> Result<Response<ModelMetadataResponse>, Status> {
        let entries = self.state.manager().get_model_entries();
        let request_model_name = &request.into_inner().name;
        if let Some(entry) = entries
            .into_iter()
            .find(|entry| request_model_name == &entry.name)
        {
            if entry.model_type.supports_tensor() {
                if let Some(config) = entry.runtime_config.as_ref()
                    && let Some(tensor_model_config) = config.tensor_model_config.as_ref()
                {
                    return Ok(Response::new(ModelMetadataResponse {
                        name: tensor_model_config.name.clone(),
                        versions: vec!["1".to_string()],
                        platform: "dynamo".to_string(),
                        inputs: tensor_model_config
                            .inputs
                            .iter()
                            .map(|input| inference::model_metadata_response::TensorMetadata {
                                name: input.name.clone(),
                                datatype: input.data_type.to_string(),
                                shape: input.shape.clone(),
                            })
                            .collect(),
                        outputs: tensor_model_config
                            .outputs
                            .iter()
                            .map(
                                |output| inference::model_metadata_response::TensorMetadata {
                                    name: output.name.clone(),
                                    datatype: output.data_type.to_string(),
                                    shape: output.shape.clone(),
                                },
                            )
                            .collect(),
                    }));
                }
                Err(Status::invalid_argument(format!(
                    "Model '{}' has type Tensor but no model config is provided",
                    request_model_name
                )))?
            } else if entry.model_type.supports_completions() {
                return Ok(Response::new(ModelMetadataResponse {
                    name: entry.name,
                    versions: vec!["1".to_string()],
                    platform: "dynamo".to_string(),
                    inputs: vec![
                        inference::model_metadata_response::TensorMetadata {
                            name: "text_input".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![1],
                        },
                        inference::model_metadata_response::TensorMetadata {
                            name: "streaming".to_string(),
                            datatype: "BOOL".to_string(),
                            shape: vec![1],
                        },
                    ],
                    outputs: vec![
                        inference::model_metadata_response::TensorMetadata {
                            name: "text_output".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![-1],
                        },
                        inference::model_metadata_response::TensorMetadata {
                            name: "finish_reason".to_string(),
                            datatype: "BYTES".to_string(),
                            shape: vec![-1],
                        },
                    ],
                }));
            }
        }
        Err(Status::not_found(format!(
            "Model '{}' not found",
            request_model_name
        )))
    }

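    /// Reports the model configuration. Mirrors [`Self::model_metadata`]:
    /// tensor models are described by their registered config, completion
    /// models by the fixed text input/output tensors.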
    async fn model_config(
        &self,
        request: Request<ModelConfigRequest>,
    ) -> Result<Response<ModelConfigResponse>, Status> {
        let entries = self.state.manager().get_model_entries();
        let request_model_name = &request.into_inner().name;
        if let Some(entry) = entries
            .into_iter()
            .find(|entry| request_model_name == &entry.name)
        {
            if entry.model_type.supports_tensor() {
                if let Some(config) = entry.runtime_config.as_ref()
                    && let Some(tensor_model_config) = config.tensor_model_config.as_ref()
                {
                    let model_config = ModelConfig {
                        name: tensor_model_config.name.clone(),
                        platform: "dynamo".to_string(),
                        backend: "dynamo".to_string(),
                        input: tensor_model_config
                            .inputs
                            .iter()
                            .map(|input| ModelInput {
                                name: input.name.clone(),
                                data_type: input.data_type.to_kserve(),
                                dims: input.shape.clone(),
                                ..Default::default()
                            })
                            .collect(),
                        output: tensor_model_config
                            .outputs
                            .iter()
                            .map(|output| ModelOutput {
                                name: output.name.clone(),
                                data_type: output.data_type.to_kserve(),
                                dims: output.shape.clone(),
                                ..Default::default()
                            })
                            .collect(),
                        ..Default::default()
                    };
                    return Ok(Response::new(ModelConfigResponse {
                        config: Some(model_config),
                    }));
                }
                Err(Status::invalid_argument(format!(
                    "Model '{}' has type Tensor but no model config is provided",
                    request_model_name
                )))?
            } else if entry.model_type.supports_completions() {
                let config = ModelConfig {
                    name: entry.name,
                    platform: "dynamo".to_string(),
                    backend: "dynamo".to_string(),
                    input: vec![
                        ModelInput {
                            name: "text_input".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![1],
                            ..Default::default()
                        },
                        ModelInput {
                            name: "streaming".to_string(),
                            data_type: DataType::TypeBool as i32,
                            dims: vec![1],
                            optional: true,
                            ..Default::default()
                        },
                    ],
                    output: vec![
                        ModelOutput {
                            name: "text_output".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![-1],
                            ..Default::default()
                        },
                        ModelOutput {
                            name: "finish_reason".to_string(),
                            data_type: DataType::TypeString as i32,
                            dims: vec![-1],
                            ..Default::default()
                        },
                    ],
                    ..Default::default()
                };
                return Ok(Response::new(ModelConfigResponse {
                    config: Some(config),
                }));
            }
        }
        Err(Status::not_found(format!(
            "Model '{}' not found",
            request_model_name
        )))
    }
}