use std::sync::Arc;

use dynamo_async_openai::types::{CompletionFinishReason, CreateCompletionRequest, Prompt};
use dynamo_runtime::{
    engine::AsyncEngineContext,
    pipeline::{AsyncEngineContextProvider, Context},
    protocols::annotated::AnnotationsProvider,
};
use futures::{Stream, StreamExt, stream};
use tonic::Status;

use crate::http::service::{
    disconnect::{ConnectionHandle, create_connection_monitor},
    metrics::{Endpoint, InflightGuard, process_response_and_observe_metrics},
};
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::completions::{
    NvCreateCompletionRequest, NvCreateCompletionResponse,
};
use crate::types::Annotated;

use super::kserve::{self, inference};
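
/// Annotation key whose value echoes the server-assigned request ID back to
/// the client.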
pub const ANNOTATION_REQUEST_ID: &str = "request_id";
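
/// Builds the annotated response stream for a gRPC completion request.
///
/// The request is stamped with a request ID, wired up for disconnect
/// monitoring, and routed to the completions engine registered for its model
/// name. Requested annotations are emitted ahead of the engine's responses,
/// and every item is observed for metrics. Returns the stream together with
/// the model's `ParsingOptions`, or a `Status` error (`NotFound` for an
/// unknown model, `Internal` if the engine fails to start generating).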
pub async fn completion_response_stream(
state: Arc<kserve::State>,
request: NvCreateCompletionRequest,
) -> Result<
(
impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
ParsingOptions,
),
Status,
> {
let request_id = get_or_create_request_id(request.inner.user.as_deref());
let request = Context::with_id(request, request_id.clone());
let context = request.context();
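    // Monitor the client connection so a disconnect can cancel the request.
    // `connection_handle` covers this setup phase; `stream_handle` is armed
    // once the response stream is actually being driven.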
let (mut connection_handle, stream_handle) =
create_connection_monitor(context.clone(), Some(state.metrics_clone())).await;
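    // Record whether the caller asked for streaming (used for metrics), then
    // force streaming on: the engine is always driven in streaming mode
    // internally.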
let streaming = request.inner.stream.unwrap_or(false);
let request = request.map(|mut req| {
req.inner.stream = Some(true);
req
});
let model = &request.inner.model;
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?;
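    // Guards and collectors that feed the service metrics: HTTP queue time,
    // in-flight request count, and per-response observations.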
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
let inflight_guard =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Completions, streaming);
let mut response_collector = state.metrics_clone().create_response_collector(model);
let annotations = request.annotations();
let stream = engine
.generate(request)
.await
.map_err(|e| Status::internal(format!("Failed to generate completions: {}", e)))?;
let ctx = stream.context();
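    // Reduce the requested annotations to those this endpoint can answer;
    // currently only the `request_id` annotation is echoed back.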
let annotations = annotations.map_or(Vec::new(), |annotations| {
annotations
.iter()
.filter_map(|annotation| {
if annotation == ANNOTATION_REQUEST_ID {
Annotated::<NvCreateCompletionResponse>::from_annotation(
ANNOTATION_REQUEST_ID,
&request_id,
)
.ok()
} else {
None
}
})
.collect::<Vec<_>>()
});
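    // Emit the annotation messages ahead of the engine's responses.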
let stream = stream::iter(annotations).chain(stream);
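    // Observe each item for per-response metrics. `http_queue_guard` is passed
    // as an `Option` so it can be released once responses begin flowing,
    // marking the end of the request's time in the queue.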
let mut http_queue_guard = Some(http_queue_guard);
let stream = stream.inspect(move |response| {
process_response_and_observe_metrics(
response,
&mut response_collector,
&mut http_queue_guard,
);
});
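    // Tie the stream to disconnect/cancellation monitoring; with the stream
    // handle armed, the setup-phase connection handle can be disarmed.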
let stream = grpc_monitor_for_disconnects(stream, ctx, inflight_guard, stream_handle);
connection_handle.disarm();
Ok((stream, parsing_options))
}
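
/// Forwards items from `stream` until it is exhausted or the engine context is
/// stopped (e.g. the client disconnected). On normal completion the in-flight
/// guard is marked OK and the stream handle disarmed; on early termination
/// neither happens, so their drop paths can account for the cancellation.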
pub fn grpc_monitor_for_disconnects<T>(
stream: impl Stream<Item = Annotated<T>>,
context: Arc<dyn AsyncEngineContext>,
mut inflight_guard: InflightGuard,
mut stream_handle: ConnectionHandle,
) -> impl Stream<Item = Annotated<T>> {
stream_handle.arm();
async_stream::stream! {
tokio::pin!(stream);
loop {
tokio::select! {
event = stream.next() => {
match event {
Some(response) => {
yield response;
}
None => {
inflight_guard.mark_ok();
stream_handle.disarm();
break;
}
}
}
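                // The engine context was stopped (typically a client
                // disconnect); stop pulling from the source. The guard and
                // handle are deliberately left unmarked.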
_ = context.stopped() => {
tracing::trace!("Context stopped; breaking stream");
break;
}
}
}
}
}
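
/// Reuses `primary` as the request ID when it parses as a UUID (normalizing
/// it to canonical hyphenated form); otherwise generates a fresh v4 UUID.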
fn get_or_create_request_id(primary: Option<&str>) -> String {
if let Some(primary) = primary
&& let Ok(uuid) = uuid::Uuid::parse_str(primary)
{
return uuid.to_string();
}
let uuid = uuid::Uuid::new_v4();
uuid.to_string()
}
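
/// Maps a KServe `ModelInferRequest` onto an OpenAI-style completion request.
/// Exactly two input tensors are understood: `text_input` (BYTES, shape `[1]`
/// or `[1, 1]`) and `stream`/`streaming` (BOOL, shape `[1]`); anything else is
/// rejected with `InvalidArgument`.
///
/// A minimal sketch of a valid request, assuming the prost-generated
/// `inference::model_infer_request::InferInputTensor` message from the KServe
/// proto:
///
/// ```ignore
/// let request = inference::ModelInferRequest {
///     model_name: "my-model".to_string(),
///     inputs: vec![inference::model_infer_request::InferInputTensor {
///         name: "text_input".to_string(),
///         datatype: "BYTES".to_string(),
///         shape: vec![1],
///         contents: Some(inference::InferTensorContents {
///             bytes_contents: vec![b"Hello".to_vec()],
///             ..Default::default()
///         }),
///         ..Default::default()
///     }],
///     ..Default::default()
/// };
/// let completion = NvCreateCompletionRequest::try_from(request)?;
/// ```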
impl TryFrom<inference::ModelInferRequest> for NvCreateCompletionRequest {
type Error = Status;
fn try_from(request: inference::ModelInferRequest) -> Result<Self, Self::Error> {
if !request.raw_input_contents.is_empty()
&& request.inputs.len() != request.raw_input_contents.len()
{
return Err(Status::invalid_argument(
"`raw_input_contents` must be used for all inputs",
));
}
let mut text_input = None;
let mut stream = false;
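        // Walk the declared inputs. Each value may arrive either inline in
        // `contents` or, when the binary tensor extension is used, in
        // `raw_input_contents` at the same index.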
for (idx, input) in request.inputs.iter().enumerate() {
match input.name.as_str() {
"text_input" => {
if input.datatype != "BYTES" {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to be of type BYTES for string input, got {:?}",
input.datatype
)));
}
if input.shape != vec![1] && input.shape != vec![1, 1] {
return Err(Status::invalid_argument(format!(
"Expected 'text_input' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
let bytes = content.bytes_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'text_input' must contain exactly one element",
)
})?;
text_input = Some(String::from_utf8_lossy(bytes).to_string());
}
None => {
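                            // Binary input path: each BYTES element carries a
                            // 4-byte length prefix, which is skipped here.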
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'text_input'")
})?;
if raw_input.len() < 4 {
return Err(Status::invalid_argument(
"'text_input' raw input must be length-prefixed (>= 4 bytes)",
));
}
text_input = Some(String::from_utf8_lossy(&raw_input[4..]).to_string());
}
}
}
"streaming" | "stream" => {
if input.datatype != "BOOL" {
return Err(Status::invalid_argument(format!(
"Expected '{}' to be of type BOOL, got {:?}",
input.name, input.datatype
)));
}
if input.shape != vec![1] {
return Err(Status::invalid_argument(format!(
"Expected 'stream' to have shape [1], got {:?}",
input.shape
)));
}
match &input.contents {
Some(content) => {
stream = *content.bool_contents.first().ok_or_else(|| {
Status::invalid_argument(
"'stream' must contain exactly one element",
)
})?;
}
None => {
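                            // Binary input path: read the boolean from the
                            // first byte of the raw payload.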
let raw_input =
request.raw_input_contents.get(idx).ok_or_else(|| {
Status::invalid_argument("Missing raw input for 'stream'")
})?;
if raw_input.is_empty() {
return Err(Status::invalid_argument(
"'stream' raw input must contain at least one byte",
));
}
stream = raw_input[0] != 0;
}
}
}
_ => {
return Err(Status::invalid_argument(format!(
"Invalid input name: {}, supported inputs are 'text_input', 'stream'",
input.name
)));
}
}
}
        let text_input = text_input.ok_or_else(|| {
            Status::invalid_argument("Missing required input: 'text_input'")
        })?;
Ok(NvCreateCompletionRequest {
inner: CreateCompletionRequest {
model: request.model_name,
prompt: Prompt::String(text_input),
stream: Some(stream),
user: if request.id.is_empty() {
None
} else {
Some(request.id.clone())
},
..Default::default()
},
common: Default::default(),
nvext: None,
metadata: None,
unsupported_fields: Default::default(),
})
}
}
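
/// Flattens a completion response into KServe output tensors: `text_output`
/// and `finish_reason`, both BYTES tensors with one element per choice.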
impl TryFrom<NvCreateCompletionResponse> for inference::ModelInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
let mut outputs = vec![];
let mut text_output = vec![];
let mut finish_reason = vec![];
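        // One entry per choice: the generated text plus a stringified finish
        // reason (empty while the choice is still streaming).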
for choice in &response.inner.choices {
text_output.push(choice.text.clone());
let reason_str = match choice.finish_reason.as_ref() {
Some(CompletionFinishReason::Stop) => "stop",
Some(CompletionFinishReason::Length) => "length",
Some(CompletionFinishReason::ContentFilter) => "content_filter",
None => "",
};
finish_reason.push(reason_str.to_string());
}
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "text_output".to_string(),
datatype: "BYTES".to_string(),
shape: vec![text_output.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: text_output
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
outputs.push(inference::model_infer_response::InferOutputTensor {
name: "finish_reason".to_string(),
datatype: "BYTES".to_string(),
shape: vec![finish_reason.len() as i64],
contents: Some(inference::InferTensorContents {
bytes_contents: finish_reason
.into_iter()
.map(|text| text.as_bytes().to_vec())
.collect(),
..Default::default()
}),
..Default::default()
});
Ok(inference::ModelInferResponse {
model_name: response.inner.model,
model_version: "1".to_string(),
id: response.inner.id,
outputs,
            parameters: std::collections::HashMap::new(),
raw_output_contents: vec![],
})
}
}
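
/// Wraps the unary conversion for the streaming RPC. A conversion failure is
/// reported in-band via `error_message` rather than failing the stream.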
impl TryFrom<NvCreateCompletionResponse> for inference::ModelStreamInferResponse {
type Error = anyhow::Error;
fn try_from(response: NvCreateCompletionResponse) -> Result<Self, Self::Error> {
match inference::ModelInferResponse::try_from(response) {
Ok(response) => Ok(inference::ModelStreamInferResponse {
infer_response: Some(response),
..Default::default()
}),
Err(e) => Ok(inference::ModelStreamInferResponse {
infer_response: None,
error_message: format!("Failed to convert response: {}", e),
}),
}
}
}