use serde::{Deserialize, Serialize};
pub use super::FinishReason;
pub use super::preprocessor::PreprocessedRequest;
use crate::protocols::TokenIdType;
use dynamo_async_openai::types::CompletionUsage;
use dynamo_async_openai::types::StopReason;
use dynamo_runtime::error::DynamoError;
use dynamo_runtime::protocols::maybe_error::MaybeError;
/// Text form of a single token, as an optional string.
///
/// NOTE(review): `None` presumably means the token has no printable text
/// representation — confirm with the producers of this value.
pub type TokenType = Option<String>;
/// A flat list of log-probability values.
///
/// NOTE(review): presumably one entry per generated token — confirm with callers.
pub type LogProbs = Vec<f64>;
/// Modality tag carried by [`LLMEngineOutput::output_type`].
///
/// Variants are (de)serialized in lowercase (`"text"`, `"image"`, `"video"`,
/// `"audio"`) because of `rename_all = "lowercase"`. The `Default` value is
/// [`OutputType::Text`].
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
/// Textual output; this is the default variant.
#[default]
Text,
/// Image output.
Image,
/// Video output.
Video,
/// Audio output.
Audio,
}
/// Payload of [`ContentPart::ImageUrl`]: a single URL string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct ImageUrlData {
/// Location of the image. Stored as an unvalidated string.
// NOTE(review): accepted URL schemes (http, data:, ...) are not visible here — confirm.
pub url: String,
}
/// Payload of [`ContentPart::VideoUrl`]: a single URL string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct VideoUrlData {
/// Location of the video. Stored as an unvalidated string.
// NOTE(review): accepted URL schemes (http, data:, ...) are not visible here — confirm.
pub url: String,
}
/// Payload of [`ContentPart::AudioUrl`]: a single URL string.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct AudioUrlData {
/// Location of the audio. Stored as an unvalidated string.
// NOTE(review): accepted URL schemes (http, data:, ...) are not visible here — confirm.
pub url: String,
}
/// One piece of structured output content.
///
/// Serialized as an internally tagged object: the variant name is written in
/// snake_case under the `"type"` key, alongside the variant's own field, e.g.
/// `{"type": "text", "text": "..."}` or
/// `{"type": "image_url", "image_url": {"url": "..."}}`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
/// Plain text content (`"type": "text"`).
Text { text: String },
/// Reference to an image (`"type": "image_url"`).
ImageUrl { image_url: ImageUrlData },
/// Reference to a video (`"type": "video_url"`).
VideoUrl { video_url: VideoUrlData },
/// Reference to an audio clip (`"type": "audio_url"`).
AudioUrl { audio_url: AudioUrlData },
}
/// A single candidate token together with its log probability.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct TopLogprob {
/// Rank of this candidate among the alternatives.
// NOTE(review): whether ranks start at 0 or 1 is not visible here — confirm with the engine.
pub rank: u32,
/// Numeric id of the candidate token.
pub token_id: TokenIdType,
/// Text of the candidate token, if one is attached.
pub token: TokenType,
/// Log probability assigned to this candidate.
pub logprob: f64,
/// Optional byte sequence for the token. Omitted from the serialized form
/// when `None`, and read as `None` when the key is absent.
// NOTE(review): presumably the token's raw (UTF-8) bytes — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bytes: Option<Vec<u8>>,
}
/// Nested list of [`TopLogprob`] candidates.
///
/// NOTE(review): presumably the outer index is the generated-token position
/// and the inner list holds that position's alternatives — confirm with callers.
pub type TopLogprobs = Vec<Vec<TopLogprob>>;
/// Output record produced by the backend stage.
///
/// Mirrors most fields of [`LLMEngineOutput`], except that `tokens` is a
/// plain `Vec` here rather than an `Option`. Fields annotated with
/// `skip_serializing_if = "Option::is_none"` are left out of the serialized
/// form when unset and deserialize to `None` when absent.
// NOTE(review): the exact pipeline position of this type (e.g. after
// detokenization) is not visible in this file — confirm against its users.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct BackendOutput {
/// Ids of the tokens carried by this output.
pub token_ids: Vec<TokenIdType>,
/// Per-token text; individual entries may be `None`.
// NOTE(review): presumably parallel to `token_ids` — confirm.
pub tokens: Vec<TokenType>,
/// Text for this output, if any.
pub text: Option<String>,
/// Cumulative log probability, if provided.
pub cum_log_probs: Option<f64>,
/// Log-probability values, if provided.
pub log_probs: Option<LogProbs>,
/// Candidate alternatives with their log probabilities, if provided.
pub top_logprobs: Option<TopLogprobs>,
/// Why generation finished; `None` when no finish reason is set.
pub finish_reason: Option<FinishReason>,
/// Optional stop reason (type from `dynamo_async_openai`); omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Optional index of this output.
// NOTE(review): presumably the choice index within a multi-choice response — confirm.
pub index: Option<u32>,
/// Optional token usage statistics; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
/// Free-form JSON value passed through unchanged; omitted when `None`.
// NOTE(review): presumably parameters for disaggregated serving — schema not visible here.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
}
/// Output record emitted by an LLM engine.
///
/// The derived `Default` yields an empty `token_ids`, `OutputType::Text` for
/// `output_type`, and `None` for every optional field. Fields annotated with
/// `skip_serializing_if = "Option::is_none"` are left out of the serialized
/// form when unset and deserialize to `None` when absent.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct LLMEngineOutput {
/// Ids of the tokens carried by this output.
pub token_ids: Vec<TokenIdType>,
/// Optional per-token text; individual entries may themselves be `None`.
pub tokens: Option<Vec<TokenType>>,
/// Text for this output, if any.
pub text: Option<String>,
/// Modality of this output; falls back to `OutputType::Text` when the key
/// is missing on deserialization.
#[serde(default)]
pub output_type: OutputType,
/// Optional structured content parts (text / image / video / audio); omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub content_parts: Option<Vec<ContentPart>>,
/// Cumulative log probability, if provided.
pub cum_log_probs: Option<f64>,
/// Log-probability values, if provided.
pub log_probs: Option<LogProbs>,
/// Candidate alternatives with their log probabilities, if provided.
pub top_logprobs: Option<TopLogprobs>,
/// Why generation finished; `None` when no finish reason is set.
/// A `FinishReason::Error` here is what the `MaybeError` impl reports as an error.
pub finish_reason: Option<FinishReason>,
/// Optional stop reason (type from `dynamo_async_openai`); omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub stop_reason: Option<StopReason>,
/// Optional index of this output.
// NOTE(review): presumably the choice index within a multi-choice response — confirm.
pub index: Option<u32>,
/// Free-form JSON value passed through unchanged; omitted when `None`.
// NOTE(review): presumably parameters for disaggregated serving — schema not visible here.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disaggregated_params: Option<serde_json::Value>,
/// Free-form JSON value for additional arguments; omitted when `None`.
// NOTE(review): schema and consumers are not visible in this file.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
/// Optional token usage statistics; omitted when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completion_usage: Option<CompletionUsage>,
}
impl LLMEngineOutput {
    /// Creates an output that carries no tokens or text and whose finish
    /// reason is [`FinishReason::Cancelled`].
    ///
    /// Every other field takes its `Default` value (empty vector, `None`,
    /// `OutputType::Text`).
    pub fn cancelled() -> Self {
        Self {
            finish_reason: Some(FinishReason::Cancelled),
            ..Self::default()
        }
    }

    /// Creates an output that carries no tokens or text and whose finish
    /// reason is [`FinishReason::Stop`].
    ///
    /// Every other field takes its `Default` value.
    pub fn stop() -> Self {
        Self {
            finish_reason: Some(FinishReason::Stop),
            ..Self::default()
        }
    }

    /// Creates an output that carries no tokens or text and whose finish
    /// reason is [`FinishReason::Length`].
    ///
    /// Every other field takes its `Default` value.
    pub fn length() -> Self {
        Self {
            finish_reason: Some(FinishReason::Length),
            ..Self::default()
        }
    }

    /// Creates an output that carries no tokens or text and whose finish
    /// reason is [`FinishReason::Error`] wrapping `err_msg`.
    ///
    /// Every other field takes its `Default` value.
    pub fn error(err_msg: String) -> Self {
        Self {
            finish_reason: Some(FinishReason::Error(err_msg)),
            ..Self::default()
        }
    }
}
impl MaybeError for LLMEngineOutput {
fn from_err(err: impl std::error::Error + 'static) -> Self {
LLMEngineOutput::error(err.to_string())
}
fn err(&self) -> Option<DynamoError> {
if let Some(FinishReason::Error(err_msg)) = &self.finish_reason {
Some(DynamoError::msg(err_msg.clone()))
} else {
None
}
}
}
/// Output record produced by an embeddings engine.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct EmbeddingsEngineOutput {
/// Embedding vectors, each a list of `f64` components.
// NOTE(review): presumably one vector per input item — confirm with the engine.
pub embeddings: Vec<Vec<f64>>,
/// Number of prompt tokens reported for the request.
pub prompt_tokens: u32,
/// Total number of tokens reported for the request.
pub total_tokens: u32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `stop()` output must report no error, while an `error(..)` output
    /// must surface its message through the `MaybeError` accessors.
    #[test]
    fn test_maybe_error() {
        // Non-error finish reason: no error is exposed.
        let stopped = LLMEngineOutput::stop();
        assert!(stopped.err().is_none());
        assert!(stopped.is_ok());
        assert!(!stopped.is_err());

        // Error finish reason: the message is visible in the rendered error.
        let failed = LLMEngineOutput::error("Test error".to_string());
        let rendered = failed.err().unwrap().to_string();
        assert!(rendered.contains("Test error"));
        assert!(!failed.is_ok());
        assert!(failed.is_err());
    }
}