dynamo_llm/grpc/service/openai.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use dynamo_runtime::{
    engine::AsyncEngineContext,
    pipeline::{AsyncEngineContextProvider, Context},
    protocols::annotated::AnnotationsProvider,
};
use futures::{Stream, StreamExt, stream};
use std::sync::Arc;

use crate::protocols::openai::completions::{
    NvCreateCompletionRequest, NvCreateCompletionResponse,
};
use crate::types::Annotated;

use super::kserve;
use super::kserve::inference;

// [gluo NOTE] These are common utilities that should be shared between frontends
use crate::http::service::{
    disconnect::{ConnectionHandle, create_connection_monitor},
    metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
};
use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};

use tonic::Status;

/// Dynamo Annotation for the request ID
pub const ANNOTATION_REQUEST_ID: &str = "request_id";

// [gluo NOTE] stripped-down version of lib/llm/src/http/service/openai.rs;
// duplicating it here as the original file is coupled to HTTP objects.

/// OpenAI Completions Request Handler
///
/// This method handles incoming requests for the `/v1/completions` endpoint. The endpoint is a "source"
/// for an [`super::OpenAICompletionsStreamingEngine`] and returns a stream of
/// responses which are forwarded to the client.
///
/// Note: For all requests, streaming or non-streaming, we always call the engine with streaming enabled. For
/// non-streaming requests, we fold the stream into a single response as part of this handler.
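///
/// A minimal call sketch (illustrative only; assumes a `state: Arc<kserve::State>` created by the
/// gRPC server and an `infer_request` already received from the client):
///
/// ```ignore
/// let request = NvCreateCompletionRequest::try_from(infer_request)?;
/// let stream = completion_response_stream(state.clone(), request).await?;
/// tokio::pin!(stream);
/// while let Some(annotated) = stream.next().await {
///     // forward each item to the client, e.g. converted into a ModelStreamInferResponse
/// }
/// ```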
pub async fn completion_response_stream(
    state: Arc<kserve::State>,
    request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> {
    // create the context for the request
    // [WIP] from request id.
    let request_id = get_or_create_request_id(request.inner.user.as_deref());
    let request = Context::with_id(request, request_id.clone());
    let context = request.context();

    // create the connection handles
    let (mut connection_handle, stream_handle) =
        create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;

    let streaming = request.inner.stream.unwrap_or(false);
    // update the request to always stream
    let request = request.map(|mut req| {
        req.inner.stream = Some(true);
        req
    });

    // todo - make the model name optional in the protocols
    // todo - when optional, if none, apply a default
    let model = &request.inner.model;

    // todo - error handling should be more robust
    let engine = state
        .manager()
        .get_completions_engine(model)
        .map_err(|_| Status::not_found("model not found"))?;

    let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);

    let inflight_guard =
        state
            .metrics_clone()
            .create_inflight_guard(model, Endpoint::Completions, streaming);

    let mut response_collector = state.metrics_clone().create_response_collector(model);

    // prepare to process any annotations
    let annotations = request.annotations();

    // issue the generate call on the engine
    let stream = engine
        .generate(request)
        .await
        .map_err(|e| Status::internal(format!("Failed to generate completions: {}", e)))?;

    // capture the context to cancel the stream if the client disconnects
    let ctx = stream.context();

    // prepare any requested annotations
    let annotations = annotations.map_or(Vec::new(), |annotations| {
        annotations
            .iter()
            .filter_map(|annotation| {
                if annotation == ANNOTATION_REQUEST_ID {
                    Annotated::<NvCreateCompletionResponse>::from_annotation(
                        ANNOTATION_REQUEST_ID,
                        &request_id,
                    )
                    .ok()
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
    });

    // apply any annotations to the front of the stream
    let stream = stream::iter(annotations).chain(stream);

    // Tap on the stream to collect response metrics and handle http_queue_guard
    let mut http_queue_guard = Some(http_queue_guard);
    let stream = stream.inspect(move |response| {
        // Calls observe_response() on each token - drops http_queue_guard on first token
        process_response_and_observe_metrics(
            response,
            &mut response_collector,
            &mut http_queue_guard,
        );
    });

    let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);

    // If we got here, we will return a response and the potentially long-running task has completed
    // successfully, without needing to be cancelled.
    connection_handle.disarm();

    Ok(stream)
}

/// This method consumes an AsyncEngineStream and monitors for disconnects or context cancellation.
/// This is the gRPC variant of `monitor_for_disconnects`, as that implementation has SSE-specific handling.
/// TODO: decouple and reuse `monitor_for_disconnects`.
///
/// Uses `tokio::select!` to choose between receiving responses from the source stream or detecting when
/// the context is stopped. If the context is stopped, we break the stream. If the source stream ends
/// naturally, we mark the request as successful and disarm the connection handle.
pub fn grpc_monitor_for_disconnects<T>(
    stream: impl Stream<Item = Annotated<T>>,
    context: Arc<dyn AsyncEngineContext>,
    mut inflight_guard: InflightGuard,
    mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
    stream_handle.arm();
    async_stream::stream! {
        tokio::pin!(stream);
        loop {
            tokio::select! {
                event = stream.next() => {
                    match event {
                        Some(response) => {
                            yield response;
                        }
                        None => {
                            // Stream ended normally
                            inflight_guard.mark_ok();
                            stream_handle.disarm();
                            break;
                        }
                    }
                }
                _ = context.stopped() => {
                    tracing::trace!("Context stopped; breaking stream");
                    break;
                }
            }
        }
    }
}

/// Get the request ID from a primary source, or create a new one if it is missing or invalid.
// TODO: A similar function exists in lib/llm/src/http/service/openai.rs but with a different signature and more complex logic (distributed tracing, headers)
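///
/// Illustrative behavior (example values are hypothetical):
/// - `Some("550e8400-e29b-41d4-a716-446655440000")` -> the same UUID string is echoed back.
/// - `Some("not-a-uuid")` or `None` -> a freshly generated v4 UUID.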
fn get_or_create_request_id(primary: Option<&str>) -> String {
    // Try to get the request ID from the primary source
    if let Some(primary) = primary
        && let Ok(uuid) = uuid::Uuid::parse_str(primary)
    {
        return uuid.to_string();
    }

    // Otherwise generate a new request ID
    let uuid = uuid::Uuid::new_v4();
    uuid.to_string()
}

impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
    type Error = Status;

    fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
        // The protocol requires that if `raw_input_contents` is used to hold input data,
        // it must be used for all inputs.
        if !request.raw_input_contents.is_empty()
            && request.inputs.len() != request.raw_input_contents.len()
        {
            return Err(Status::invalid_argument(
                "`raw_input_contents` must be used for all inputs",
            ));
        }

        // iterate through inputs
        let mut text_input = None;
        let mut stream = false;
        for (idx, input) in request.inputs.iter().enumerate() {
            match input.name.as_str() {
                "text_input" => {
                    if input.datatype != "BYTES" {
                        return Err(Status::invalid_argument(format!(
                            "Expected 'text_input' to be of type BYTES for string input, got {:?}",
                            input.datatype
                        )));
                    }
                    if input.shape != vec![1] && input.shape != vec![1, 1] {
                        return Err(Status::invalid_argument(format!(
                            "Expected 'text_input' to have shape [1] or [1, 1], got {:?}",
                            input.shape
                        )));
                    }
                    match &input.contents {
                        Some(content) => {
                            let bytes = content.bytes_contents.first().ok_or_else(|| {
                                Status::invalid_argument(
                                    "'text_input' must contain exactly one element",
                                )
                            })?;
                            text_input = Some(String::from_utf8_lossy(bytes).to_string());
                        }
                        None => {
                            let raw_input =
                                request.raw_input_contents.get(idx).ok_or_else(|| {
                                    Status::invalid_argument("Missing raw input for 'text_input'")
                                })?;
                            if raw_input.len() < 4 {
                                return Err(Status::invalid_argument(
                                    "'text_input' raw input must be length-prefixed (>= 4 bytes)",
                                ));
                            }
                            // We restrict 'text_input' to a single element, so we only need to
                            // parse the first element. Skip the first four bytes, which store the
                            // length of the element.
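                            // Illustrative raw layout (assuming the little-endian u32 length prefix
                            // used by KServe raw tensor contents for BYTES elements):
                            // b"\x05\x00\x00\x00hello" decodes to the prompt "hello".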
                            text_input = Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
                        }
                    }
                }
                "streaming" | "stream" => {
                    if input.datatype != "BOOL" {
                        return Err(Status::invalid_argument(format!(
                            "Expected '{}' to be of type BOOL, got {:?}",
                            input.name, input.datatype
                        )));
                    }
                    if input.shape != vec![1] {
                        return Err(Status::invalid_argument(format!(
                            "Expected '{}' to have shape [1], got {:?}",
                            input.name, input.shape
                        )));
                    }
                    match &input.contents {
                        Some(content) => {
                            stream = *content.bool_contents.first().ok_or_else(|| {
                                Status::invalid_argument(
                                    "'stream' must contain exactly one element",
                                )
                            })?;
                        }
                        None => {
                            let raw_input =
                                request.raw_input_contents.get(idx).ok_or_else(|| {
                                    Status::invalid_argument("Missing raw input for 'stream'")
                                })?;
                            if raw_input.is_empty() {
                                return Err(Status::invalid_argument(
                                    "'stream' raw input must contain at least one byte",
                                ));
                            }
                            stream = raw_input[0] != 0;
                        }
                    }
                }
                _ => {
                    return Err(Status::invalid_argument(format!(
                        "Invalid input name: {}, supported inputs are 'text_input', 'stream' (or 'streaming')",
                        input.name
                    )));
                }
            }
        }

        // return error if text_input is None
        let text_input = match text_input {
            Some(input) => input,
            None => {
                return Err(Status::invalid_argument(
                    "Missing required input: 'text_input'",
                ));
            }
        };

        Ok(NvCreateCompletionRequest {
            inner: CreateCompletionRequest {
                model: request.model_name,
                prompt: Prompt::String(text_input),
                stream: Some(stream),
                user: if request.id.is_empty() {
                    None
                } else {
                    Some(request.id.clone())
                },
                ..Default::default()
            },
            common: Default::default(),
            nvext: None,
        })
    }
}
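
// Illustrative request shape handled by the conversion above (a sketch; field values and any
// omitted fields are hypothetical):
//
//   ModelInferRequest {
//       model_name: "my-model",
//       id: "550e8400-e29b-41d4-a716-446655440000",
//       inputs: [
//           InferInputTensor { name: "text_input", datatype: "BYTES", shape: [1],
//               contents: Some(InferTensorContents { bytes_contents: [b"Hello"], .. }) },
//           InferInputTensor { name: "stream", datatype: "BOOL", shape: [1],
//               contents: Some(InferTensorContents { bool_contents: [true], .. }) },
//       ],
//       ..
//   }
//
// converts to an NvCreateCompletionRequest with prompt "Hello", stream = Some(true), and
// user = Some(request.id).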

impl TryFrom<NvCreateCompletionResponse> for inference::ModelInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
        let mut outputs = vec![];
        let mut text_output = vec![];
        let mut finish_reason = vec![];
        for choice in &response.inner.choices {
            text_output.push(choice.text.clone());
            let reason_str = match choice.finish_reason.as_ref() {
                Some(CompletionFinishReason::Stop) => "stop",
                Some(CompletionFinishReason::Length) => "length",
                Some(CompletionFinishReason::ContentFilter) => "content_filter",
                None => "",
            };
            finish_reason.push(reason_str.to_string());
        }
        outputs.push(inference::model_infer_response::InferOutputTensor {
            name: "text_output".to_string(),
            datatype: "BYTES".to_string(),
            shape: vec![text_output.len() as i64],
            contents: Some(inference::InferTensorContents {
                bytes_contents: text_output
                    .into_iter()
                    .map(|text| text.as_bytes().to_vec())
                    .collect(),
                ..Default::default()
            }),
            ..Default::default()
        });
        outputs.push(inference::model_infer_response::InferOutputTensor {
            name: "finish_reason".to_string(),
            datatype: "BYTES".to_string(),
            shape: vec![finish_reason.len() as i64],
            contents: Some(inference::InferTensorContents {
                bytes_contents: finish_reason
                    .into_iter()
                    .map(|text| text.as_bytes().to_vec())
                    .collect(),
                ..Default::default()
            }),
            ..Default::default()
        });

        Ok(inference::ModelInferResponse {
            model_name: response.inner.model,
            model_version: "1".to_string(),
            id: response.inner.id,
            outputs,
            parameters: ::std::collections::HashMap::<String, inference::InferParameter>::new(),
            raw_output_contents: vec![],
        })
    }
}

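// Resulting output layout (illustrative): for a response with N choices, the converted
// ModelInferResponse carries two BYTES tensors, "text_output" and "finish_reason", each with
// shape [N], where element i corresponds to choices[i].
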
impl TryFrom<NvCreateCompletionResponse> for inference::ModelStreamInferResponse {
    type Error = anyhow::Error;

    fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
        match inference::ModelInferResponse::try_from(response) {
            Ok(response) => Ok(inference::ModelStreamInferResponse {
                infer_response: Some(response),
                ..Default::default()
            }),
            Err(e) => Ok(inference::ModelStreamInferResponse {
                infer_response: None,
                error_message: format!("Failed to convert response: {}", e),
            }),
        }
    }
}