use std::collections::HashMap;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use validator::Validate;
mod aggregator;
mod delta;
pub use aggregator::DeltaAggregator;
pub use delta::DeltaGenerator;
use super::{
common::{self, SamplingOptionsProvider, StopConditionsProvider},
nvext::{NvExt, NvExtProvider},
CompletionUsage, ContentProvider, OpenAISamplingOptionsProvider, OpenAIStopConditionsProvider,
};
use dynamo_runtime::protocols::annotated::AnnotationsProvider;
/// Completion request accepted by this module.
///
/// Wraps `async_openai`'s `CreateCompletionRequest` and adds the optional
/// `nvext` extension block next to it.
#[derive(Serialize, Deserialize, Validate, Debug, Clone)]
pub struct CompletionRequest {
/// The wrapped upstream request. `#[serde(flatten)]` puts its fields at the
/// same level as `nvext` instead of nesting them under an `inner` key.
#[serde(flatten)]
pub inner: async_openai::types::CreateCompletionRequest,
/// Optional extension options; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub nvext: Option<NvExt>,
}
/// Response body produced for a completion request.
///
/// `ResponseFactory::make_response` in this file builds values of this type
/// with exactly one entry in `choices`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
/// Identifier of this response.
pub id: String,
/// Generated choices.
pub choices: Vec<CompletionChoice>,
/// Creation timestamp as an unsigned integer.
pub created: u64,
/// Name of the model reported with the response.
pub model: String,
/// Object-type tag of the response.
pub object: String,
/// Optional usage accounting. There is no skip attribute here, so a `None`
/// value is still emitted when serializing.
pub usage: Option<CompletionUsage>,
/// Optional system fingerprint; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
/// A single generated completion inside a `CompletionResponse`.
///
/// The `Builder` derive generates `CompletionChoiceBuilder`; the
/// `#[builder(...)]` attributes below control each setter and its default.
#[derive(Clone, Debug, Deserialize, Serialize, Builder)]
pub struct CompletionChoice {
/// Generated text. The builder setter accepts anything convertible into `String`.
#[builder(setter(into))]
pub text: String,
/// Index of this choice; the builder defaults it to `0`.
#[builder(default = "0")]
pub index: u64,
/// Why generation finished, if it has. Defaults to `None` in the builder;
/// the setter takes the inner value directly (`strip_option`) and converts
/// it with `into`.
#[builder(default, setter(into, strip_option))]
pub finish_reason: Option<String>,
/// Optional log-probability details. Left out of serialized output when
/// `None`; defaults to `None` in the builder.
#[serde(skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub logprobs: Option<LogprobResult>,
}
impl ContentProvider for CompletionChoice {
    /// Returns an owned copy of this choice's `text`.
    fn content(&self) -> String {
        self.text.as_str().to_owned()
    }
}
impl CompletionChoice {
pub fn builder() -> CompletionChoiceBuilder {
CompletionChoiceBuilder::default()
}
}
/// Log-probability details that can be attached to a `CompletionChoice`.
///
/// NOTE(review): the four vectors look like parallel arrays with one entry
/// per generated token, but nothing in this file populates them, so that
/// reading is unverified - confirm against the producer.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct LogprobResult {
/// Token strings.
pub tokens: Vec<String>,
/// Log-probability values, presumably one per entry in `tokens`.
pub token_logprobs: Vec<f32>,
/// Maps from token string to log-probability, presumably the top
/// alternatives at each position.
pub top_logprobs: Vec<HashMap<String, f32>>,
/// Integer offsets, presumably the position of each token in the text.
pub text_offset: Vec<i32>,
}
/// Flattens a `Prompt` into a single string.
///
/// * `String` - returned as an owned copy.
/// * `StringArray` - entries joined with a single space.
/// * `IntegerArray` - each integer rendered with `to_string`, joined with a
///   single space.
/// * `ArrayOfIntegerArray` - each inner array rendered like `IntegerArray`,
///   and the resulting groups joined with `" | "`.
pub fn prompt_to_string(prompt: &async_openai::types::Prompt) -> String {
    // Renders every item with `to_string` and joins the pieces with one space.
    fn space_joined<T: ToString>(items: &[T]) -> String {
        items
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ")
    }

    match prompt {
        async_openai::types::Prompt::String(text) => text.clone(),
        async_openai::types::Prompt::StringArray(parts) => parts.join(" "),
        async_openai::types::Prompt::IntegerArray(ids) => space_joined(ids),
        async_openai::types::Prompt::ArrayOfIntegerArray(groups) => groups
            .iter()
            .map(|ids| space_joined(ids))
            .collect::<Vec<_>>()
            .join(" | "),
    }
}
impl NvExtProvider for CompletionRequest {
    /// Borrows the request's `nvext` block, if one was supplied.
    fn nvext(&self) -> Option<&NvExt> {
        self.nvext.as_ref()
    }

    /// Returns the prompt rendered by `prompt_to_string`, but only when
    /// `nvext` is present and its `use_raw_prompt` flag is `Some(true)`.
    /// Every other combination yields `None`.
    fn raw_prompt(&self) -> Option<String> {
        let wants_raw = self.nvext.as_ref().and_then(|ext| ext.use_raw_prompt);
        match wants_raw {
            Some(true) => Some(prompt_to_string(&self.inner.prompt)),
            Some(false) | None => None,
        }
    }
}
impl AnnotationsProvider for CompletionRequest {
    /// Returns an owned copy of the annotations listed in `nvext`, or `None`
    /// when there is no `nvext` block or it carries no annotation list.
    fn annotations(&self) -> Option<Vec<String>> {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.clone())
    }

    /// Reports whether `annotation` appears in the annotations listed in
    /// `nvext`.
    ///
    /// Returns `false` when `nvext` or its annotation list is absent.
    fn has_annotation(&self, annotation: &str) -> bool {
        self.nvext
            .as_ref()
            .and_then(|nvext| nvext.annotations.as_ref())
            // Compare each entry against the borrowed `&str` directly, so no
            // temporary `String` is allocated just to perform the lookup.
            .map(|annotations| annotations.iter().any(|entry| entry == annotation))
            .unwrap_or(false)
    }
}
/// Exposes the sampling parameters carried by the wrapped request.
///
/// Each getter returns the corresponding optional field of `inner` unchanged.
impl OpenAISamplingOptionsProvider for CompletionRequest {
/// Returns `inner.temperature`.
fn get_temperature(&self) -> Option<f32> {
self.inner.temperature
}
/// Returns `inner.top_p`.
fn get_top_p(&self) -> Option<f32> {
self.inner.top_p
}
/// Returns `inner.frequency_penalty`.
fn get_frequency_penalty(&self) -> Option<f32> {
self.inner.frequency_penalty
}
/// Returns `inner.presence_penalty`.
fn get_presence_penalty(&self) -> Option<f32> {
self.inner.presence_penalty
}
/// Borrows the request's `nvext` block, if one was supplied.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
/// Exposes the stop-condition parameters of the request.
impl OpenAIStopConditionsProvider for CompletionRequest {
/// Returns `inner.max_tokens`.
fn get_max_tokens(&self) -> Option<u32> {
self.inner.max_tokens
}
/// Always `None`: this implementation reports no minimum token count.
fn get_min_tokens(&self) -> Option<u32> {
None
}
/// Always `None`: this implementation reports no stop sequences.
///
/// NOTE(review): if the wrapped request can carry stop sequences, they are
/// not surfaced here - confirm whether that is intentional.
fn get_stop(&self) -> Option<Vec<String>> {
None
}
/// Borrows the request's `nvext` block, if one was supplied.
fn nvext(&self) -> Option<&NvExt> {
self.nvext.as_ref()
}
}
/// Holds the response-level fields that `make_response` copies into each
/// `CompletionResponse` it creates.
///
/// The `Builder` derive generates `ResponseFactoryBuilder`; only `model` has
/// no builder default.
#[derive(Builder)]
pub struct ResponseFactory {
/// Model name copied into every response. The builder setter accepts
/// anything convertible into `String`.
#[builder(setter(into))]
pub model: String,
/// Optional system fingerprint; defaults to `None` in the builder.
#[builder(default)]
pub system_fingerprint: Option<String>,
/// Response identifier. If not set on the builder it defaults to `cmpl-`
/// followed by a freshly generated v4 UUID.
#[builder(default = "format!(\"cmpl-{}\", uuid::Uuid::new_v4())")]
pub id: String,
/// Object-type tag; defaults to `"text_completion"`.
#[builder(default = "\"text_completion\".to_string()")]
pub object: String,
/// Creation timestamp. If not set on the builder it defaults to the value
/// of `chrono::Utc::now().timestamp()` cast to `u64`.
#[builder(default = "chrono::Utc::now().timestamp() as u64")]
pub created: u64,
}
impl ResponseFactory {
    /// Creates a `ResponseFactoryBuilder` in its default state.
    pub fn builder() -> ResponseFactoryBuilder {
        Default::default()
    }

    /// Builds a `CompletionResponse` that holds `choice` as its only choice.
    ///
    /// `id`, `object`, `created`, `model` and `system_fingerprint` are copied
    /// from the factory; `usage` is passed through unchanged.
    pub fn make_response(
        &self,
        choice: CompletionChoice,
        usage: Option<CompletionUsage>,
    ) -> CompletionResponse {
        // Destructure so that every factory field is visibly accounted for.
        let Self {
            model,
            system_fingerprint,
            id,
            object,
            created,
        } = self;

        CompletionResponse {
            id: id.clone(),
            choices: vec![choice],
            created: *created,
            model: model.clone(),
            object: object.clone(),
            usage,
            system_fingerprint: system_fingerprint.clone(),
        }
    }
}
impl TryFrom<CompletionRequest> for common::CompletionRequest {
    type Error = anyhow::Error;

    /// Converts the OpenAI-facing request into the common completion request.
    ///
    /// # Errors
    ///
    /// * `inner.suffix` is set - suffixes are rejected.
    /// * Extracting the stop conditions fails.
    /// * Extracting the sampling options fails.
    ///
    /// The prompt is flattened with `prompt_to_string`; `system_prompt`,
    /// `mdc_sum` and `annotations` are always set to `None`.
    fn try_from(request: CompletionRequest) -> Result<Self, Self::Error> {
        if request.inner.suffix.is_some() {
            return Err(anyhow::anyhow!("suffix is not supported"));
        }

        // Fields are written in evaluation order: stop conditions, then
        // sampling options, then the prompt.
        Ok(Self {
            stop_conditions: request
                .extract_stop_conditions()
                .map_err(|err| anyhow::anyhow!("Failed to extract stop conditions: {}", err))?,
            sampling_options: request
                .extract_sampling_options()
                .map_err(|err| anyhow::anyhow!("Failed to extract sampling options: {}", err))?,
            prompt: common::PromptType::Completion(common::CompletionContext {
                prompt: prompt_to_string(&request.inner.prompt),
                system_prompt: None,
            }),
            mdc_sum: None,
            annotations: None,
        })
    }
}
impl TryFrom<common::StreamingCompletionResponse> for CompletionChoice {
    type Error = anyhow::Error;

    /// Converts one streaming response into a `CompletionChoice`.
    ///
    /// `index` falls back to `0` when the delta has none, and `logprobs` is
    /// always `None`. The finish reason maps as follows: `EoS` and `Stop`
    /// become `"stop"`, `Length` becomes `"length"`, `Cancelled` becomes
    /// `"cancelled"`, and an absent reason stays `None`.
    ///
    /// # Errors
    ///
    /// * The delta carries no text. This is checked first, so it takes
    ///   precedence over an error finish reason.
    /// * The finish reason is `FinishReason::Error`; its message is included
    ///   in the returned error.
    fn try_from(response: common::StreamingCompletionResponse) -> Result<Self, Self::Error> {
        let choice = CompletionChoice {
            text: response
                .delta
                .text
                // Build the error lazily: `ok_or` would construct it on every
                // call, including the common case where text is present.
                .ok_or_else(|| anyhow::anyhow!("No text in response"))?,
            index: response.delta.index.unwrap_or(0) as u64,
            logprobs: None,
            finish_reason: match &response.delta.finish_reason {
                Some(common::FinishReason::EoS) | Some(common::FinishReason::Stop) => {
                    Some("stop".to_string())
                }
                Some(common::FinishReason::Length) => Some("length".to_string()),
                Some(common::FinishReason::Error(err_msg)) => {
                    return Err(anyhow::anyhow!("finish_reason::error = {}", err_msg));
                }
                Some(common::FinishReason::Cancelled) => Some("cancelled".to_string()),
                None => None,
            },
        };
        Ok(choice)
    }
}