use dynamo_runtime::{
    engine::AsyncEngineContext,
    pipeline::{AsyncEngineContextProvider, Context},
    protocols::annotated::AnnotationsProvider,
};
use futures::{Stream, StreamExt, stream};
use std::sync::Arc;

use crate::protocols::openai::completions::{
    NvCreateCompletionRequest, NvCreateCompletionResponse,
};
use crate::types::Annotated;

use super::kserve;
use super::kserve::inference;

use crate::http::service::{
    disconnect::{ConnectionHandle, create_connection_monitor},
    metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
};
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};

use tonic::Status;

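/// Annotation key used to echo the request ID back to the client when the
/// caller asks for it via request annotations.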
pub const ANNOTATION_REQUEST_ID: &str = "request_id";

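/// Builds the response stream for a KServe completion request: resolves the
/// completion engine for the requested model, wires up connection monitoring
/// and metrics guards, prepends any requested annotations, and returns the
/// annotated response stream.
///
/// A minimal usage sketch (assumes a `state: Arc<kserve::State>` and a
/// populated request; marked `ignore` since it needs a running service):
///
/// ```ignore
/// let stream = completion_response_stream(state, request).await?;
/// tokio::pin!(stream);
/// while let Some(response) = stream.next().await {
///     // forward each Annotated<NvCreateCompletionResponse> to the client
/// }
/// ```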
pub async fn completion_response_stream(
    state: Arc<kserve::State>,
    request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> {
    // Reuse the caller-supplied ID when it is a valid UUID so responses can be
    // correlated with the request; otherwise mint a fresh one.
    let request_id = get_or_create_request_id(request.inner.user.as_deref());
    let request = Context::with_id(request, request_id.clone());
    let context = request.context();

    // Monitor the connection so a client disconnect can cancel the request.
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;

    // Record whether the caller asked for streaming (used for metrics), then
    // force the engine into streaming mode.
    let streaming = request.inner.stream.unwrap_or(false);
    let request = request.map(|mut req| {
        req.inner.stream = Some(true);
        req
    });

    let model = &request.inner.model;

    let engine = state
        .manager()
        .get_completions_engine(model)
        .map_err(|_| Status::not_found("model not found"))?;

    let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);

    let inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Completions, streaming);

    let mut response_collector = state.metrics_clone().create_response_collector(model);

    let annotations = request.annotations();

    let stream = engine
        .generate(request)
        .await
        .map_err(|e| Status::internal(format!("Failed to generate completions: {}", e)))?;

    let ctx = stream.context();

    // Emit any requested annotations (currently only the request ID) ahead of
    // the engine's responses.
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::<NvCreateCompletionResponse>::from_annotation(
                        ANNOTATION_REQUEST_ID,
                        &request_id,
                    )
                    .ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    let stream = stream::iter(annotations).chain(stream);

    // Observe per-response metrics; the queue guard is wrapped in an Option so
    // the metrics hook can release it.
    let mut http_queue_guard = Some(http_queue_guard);
    let stream = stream.inspect(move |response| {
        process_response_and_observe_metrics(
            response,
            &mut response_collector,
            &mut http_queue_guard,
        );
    });

    let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);

    // The returned stream now owns disconnect handling; disarm the
    // request-scoped handle.
    connection_handle.disarm();

    Ok(stream)
}

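/// Wraps a response stream so it terminates early when the engine context is
/// stopped (e.g. the client disconnected). The inflight guard is marked OK
/// only when the upstream stream completes normally.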
pub fn grpc_monitor_for_disconnects<T>(
    stream: impl Stream<Item = Annotated<T>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
    stream_handle.arm();
    async_stream::stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(response) => {
                            yield response;
                        }
                        None => {
                            // The stream completed normally.
                            inflight_guard.mark_ok();
                            stream_handle.disarm();
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    break;
                }
            }
        }
    }
}

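/// Returns the caller-supplied ID when it parses as a UUID; otherwise
/// generates a fresh v4 UUID.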
fn get_or_create_request_id(primary: Option<&str>) -> String {
    // Prefer the caller-supplied ID, but only when it parses as a UUID.
    if let Some(primary) = primary
        && let Ok(uuid) = uuid::Uuid::parse_str(primary)
    {
        return uuid.to_string();
    }

    // Otherwise generate a fresh v4 UUID.
    let uuid = uuid::Uuid::new_v4();
    uuid.to_string()
}

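// Maps a KServe v2 `ModelInferRequest` onto an OpenAI-style completion request.
// Supported inputs are `text_input` (BYTES) and `stream`/`streaming` (BOOL),
// carried either inline in `contents` or via `raw_input_contents`, where BYTES
// elements carry a 4-byte length prefix.
//
// A minimal sketch of a conforming request (tensor type names follow the
// KServe v2 proto as generated into `inference`; adjust if they differ):
//
//     let request = inference::ModelInferRequest {
//         model_name: "my-model".to_string(),
//         inputs: vec![inference::model_infer_request::InferInputTensor {
//             name: "text_input".to_string(),
//             datatype: "BYTES".to_string(),
//             shape: vec![1],
//             contents: Some(inference::InferTensorContents {
//                 bytes_contents: vec![b"Hello, world".to_vec()],
//                 ..Default::default()
//             }),
//             ..Default::default()
//         }],
//         ..Default::default()
//     };
//     let completion = NvCreateCompletionRequest::try_from(request)?;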
impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
    type Error = Status;

    fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
        // When the raw (binary) representation is used, it must be used for
        // every input; mixing raw and typed contents is not supported.
        if !request.raw_input_contents.is_empty()
            && request.inputs.len() != request.raw_input_contents.len()
        {
            return Err(Status::invalid_argument(
                "`raw_input_contents` must be used for all inputs",
            ));
        }

        let mut text_input = None;
        let mut stream = false;
        for (idx, input) in request.inputs.iter().enumerate() {
            match input.name.as_str() {
                "text_input" => {
                    if input.datatype != "BYTES" {
                        return Err(Status::invalid_argument(format!(
                            "Expected 'text_input' to be of type BYTES for string input, got {:?}",
                            input.datatype
                        )));
                    }
                    if input.shape != vec![1] && input.shape != vec![1, 1] {
                        return Err(Status::invalid_argument(format!(
                            "Expected 'text_input' to have shape [1] or [1, 1], got {:?}",
                            input.shape
                        )));
                    }
                    match &input.contents {
                        Some(content) => {
                            let bytes = content.bytes_contents.first().ok_or_else(|| {
                                Status::invalid_argument(
                                    "'text_input' must contain exactly one element",
                                )
                            })?;
                            text_input = Some(String::from_utf8_lossy(bytes).to_string());
                        }
                        None => {
                            let raw_input =
                                request.raw_input_contents.get(idx).ok_or_else(|| {
                                    Status::invalid_argument("Missing raw input for 'text_input'")
                                })?;
                            if raw_input.len() < 4 {
                                return Err(Status::invalid_argument(
                                    "'text_input' raw input must be length-prefixed (>= 4 bytes)",
                                ));
                            }
                            // Skip the 4-byte length prefix on the single element.
                            text_input =
                                Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
                        }
                    }
                }
                "streaming" | "stream" => {
                    if input.datatype != "BOOL" {
                        return Err(Status::invalid_argument(format!(
                            "Expected '{}' to be of type BOOL, got {:?}",
                            input.name, input.datatype
                        )));
                    }
                    if input.shape != vec![1] {
                        return Err(Status::invalid_argument(format!(
                            "Expected 'stream' to have shape [1], got {:?}",
                            input.shape
                        )));
                    }
                    match &input.contents {
                        Some(content) => {
                            stream = *content.bool_contents.first().ok_or_else(|| {
                                Status::invalid_argument(
                                    "'stream' must contain exactly one element",
                                )
                            })?;
                        }
                        None => {
                            let raw_input =
                                request.raw_input_contents.get(idx).ok_or_else(|| {
                                    Status::invalid_argument("Missing raw input for 'stream'")
                                })?;
                            if raw_input.is_empty() {
                                return Err(Status::invalid_argument(
                                    "'stream' raw input must contain at least one byte",
                                ));
                            }
                            stream = raw_input[0] != 0;
                        }
                    }
                }
                _ => {
                    return Err(Status::invalid_argument(format!(
                        "Invalid input name: {}, supported inputs are 'text_input', 'stream' (or 'streaming')",
                        input.name
                    )));
                }
            }
        }

        // 'text_input' is the one required input.
        let text_input = match text_input {
            Some(input) => input,
            None => {
                return Err(Status::invalid_argument(
                    "Missing required input: 'text_input'",
                ));
            }
        };

        Ok(NvCreateCompletionRequest {
            inner: CreateCompletionRequest {
                model: request.model_name,
                prompt: Prompt::String(text_input),
                stream: Some(stream),
                user: if request.id.is_empty() {
                    None
                } else {
                    Some(request.id.clone())
                },
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        })
    }
}

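// Flattens an OpenAI-style completion response into a KServe
// `ModelInferResponse` with two parallel BYTES tensors, `text_output` and
// `finish_reason`, one element per choice.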
impl TryFrom<NvCreateCompletionResponse> for inference::ModelInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
        let mut outputs = vec![];
        let mut text_output = vec![];
        let mut finish_reason = vec![];
        for choice in &response.inner.choices {
            text_output.push(choice.text.clone());
            let reason_str = match choice.finish_reason.as_ref() {
                Some(CompletionFinishReason::Stop) => "stop",
                Some(CompletionFinishReason::Length) => "length",
                Some(CompletionFinishReason::ContentFilter) => "content_filter",
                None => "",
            };
            finish_reason.push(reason_str.to_string());
        }
        outputs.push(inference::model_infer_response::InferOutputTensor {
            name: "text_output".to_string(),
            datatype: "BYTES".to_string(),
            shape: vec![text_output.len() as i64],
            contents: Some(inference::InferTensorContents {
                bytes_contents: text_output
                    .into_iter()
                    .map(|text| text.into_bytes())
                    .collect(),
                ..Default::default()
            }),
            ..Default::default()
        });
        outputs.push(inference::model_infer_response::InferOutputTensor {
            name: "finish_reason".to_string(),
            datatype: "BYTES".to_string(),
            shape: vec![finish_reason.len() as i64],
            contents: Some(inference::InferTensorContents {
                bytes_contents: finish_reason
                    .into_iter()
                    .map(|text| text.into_bytes())
                    .collect(),
                ..Default::default()
            }),
            ..Default::default()
        });

        Ok(inference::ModelInferResponse {
            model_name: response.inner.model,
            model_version: "1".to_string(),
            id: response.inner.id,
            outputs,
            parameters: std::collections::HashMap::new(),
            raw_output_contents: vec![],
        })
    }
}

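// Streaming variant: conversion failures are reported in-band through
// `error_message` instead of failing the gRPC stream.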
impl TryFrom<NvCreateCompletionResponse> for inference::ModelStreamInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
        match inference::ModelInferResponse::try_from(response) {
            Ok(response) => Ok(inference::ModelStreamInferResponse {
                infer_response: Some(response),
                ..Default::default()
            }),
            Err(e) => Ok(inference::ModelStreamInferResponse {
                infer_response: None,
                error_message: format!("Failed to convert response: {}", e),
            }),
        }
    }
}