use anyhow::Result;
use derive_builder::Builder;
use serde::ser::SerializeStruct;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::time::SystemTime;
use super::TokenIdType;
pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;
/// Implemented by types from which a [`SamplingOptions`] value can be derived.
pub trait SamplingOptionsProvider {
/// Produces the [`SamplingOptions`] for `self`.
///
/// Returns an `anyhow` error when the implementor cannot build the options;
/// the failure conditions are defined by each implementor, not by this trait.
fn extract_sampling_options(&self) -> Result<SamplingOptions>;
}
/// Implemented by types from which a [`StopConditions`] value can be derived.
pub trait StopConditionsProvider {
/// Produces the [`StopConditions`] for `self`.
///
/// Returns an `anyhow` error when the implementor cannot build the conditions;
/// the failure conditions are defined by each implementor, not by this trait.
fn extract_stop_conditions(&self) -> Result<StopConditions>;
}
/// Reason attached to a finished generation; stored in `Delta::finish_reason` and
/// `StreamState::Finished`.
///
/// Every variant is given an explicit lowercase serde name. The `Display` and `FromStr`
/// impls in this file use the same labels, except that `Error` is written as
/// `error: <message>`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
/// Serde name `eos`. Presumably "end of sequence" — confirm with the engine that emits it.
#[serde(rename = "eos")]
EoS,
/// Serde name `length`. Presumably a token limit was reached — TODO confirm.
#[serde(rename = "length")]
Length,
/// Serde name `stop`. Presumably a stop condition matched — TODO confirm.
#[serde(rename = "stop")]
Stop,
/// Serde name `error`; carries a message string.
#[serde(rename = "error")]
Error(String),
/// Serde name `cancelled`.
#[serde(rename = "cancelled")]
Cancelled,
}
impl std::fmt::Display for FinishReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FinishReason::EoS => write!(f, "eos"),
FinishReason::Length => write!(f, "length"),
FinishReason::Stop => write!(f, "stop"),
FinishReason::Error(msg) => write!(f, "error: {}", msg),
FinishReason::Cancelled => write!(f, "cancelled"),
}
}
}
impl std::str::FromStr for FinishReason {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eos" => Ok(FinishReason::EoS),
"length" => Ok(FinishReason::Length),
"stop" => Ok(FinishReason::Stop),
"cancelled" => Ok(FinishReason::Cancelled),
s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
_ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
}
}
}
/// The prompt payload of a [`CompletionRequest`], in one of several representations.
///
/// Each variant is given an explicit snake_case serde name.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum PromptType {
/// A list of token ids (serde name `token_ids`).
#[serde(rename = "token_ids")]
TokenIds(Vec<TokenIdType>),
/// A plain string (serde name `raw`).
#[serde(rename = "raw")]
Raw(String),
/// A [`CompletionContext`] (serde name `completion`).
#[serde(rename = "completion")]
Completion(CompletionContext),
/// A [`ChatContext`] (serde name `chat_completion`).
#[serde(rename = "chat_completion")]
ChatCompletion(ChatContext),
/// An arbitrary JSON value (serde name `custom_json`); its schema is not defined in this file.
#[serde(rename = "custom_json")]
CustomJson(serde_json::Value),
}
/// A completion request: a prompt together with its stop conditions and sampling options.
///
/// `derive_builder::Builder` generates `CompletionRequestBuilder`. The two optional fields
/// are marked `#[builder(default)]`, so the builder does not require them to be set.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct CompletionRequest {
/// The prompt, in any of the [`PromptType`] representations.
pub prompt: PromptType,
/// Conditions under which generation stops.
pub stop_conditions: StopConditions,
/// Sampling parameters for generation.
pub sampling_options: SamplingOptions,
/// NOTE(review): not explained in this file — presumably a checksum identifying a model
/// deployment card; confirm with the code that sets it.
#[builder(default)]
pub mdc_sum: Option<String>,
/// Optional list of annotation strings; how they are interpreted is not visible here.
#[builder(default)]
pub annotations: Option<Vec<String>>,
}
impl CompletionRequest {
pub fn builder() -> CompletionRequestBuilder {
CompletionRequestBuilder::default()
}
}
/// A prompt string with an optional system prompt.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct CompletionContext {
/// The prompt text.
pub prompt: String,
/// Optional system prompt accompanying `prompt`.
pub system_prompt: Option<String>,
}
/// A pair of user and assistant strings; used as the element type of `ChatContext::context`.
///
/// NOTE(review): presumably one earlier exchange in a conversation — confirm with callers.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatTurn {
/// The user's text for this turn.
pub user: String,
/// The assistant's text for this turn.
pub assistant: String,
}
/// A [`CompletionContext`] extended with a list of [`ChatTurn`]s.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatContext {
/// Flattened by serde: `prompt` and `system_prompt` appear at the same level as
/// `context` in the serialized form instead of under a `completion` key.
#[serde(flatten)]
pub completion: CompletionContext,
/// The chat turns; presumably prior conversation history in order — TODO confirm.
pub context: Vec<ChatTurn>,
}
/// Optional settings that bound or terminate generation.
///
/// Every field is an `Option`, and `Default` leaves them all `None`.
/// `apply_ignore_eos` rewrites several of these fields when `ignore_eos` is `Some(true)`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StopConditions {
/// Presumably the upper bound on generated tokens — TODO confirm with the backend.
pub max_tokens: Option<u32>,
/// Presumably strings that end generation when produced — TODO confirm.
pub stop: Option<Vec<String>>,
/// NOTE(review): the name suggests token ids that end generation without being shown
/// in the output; nothing in this file establishes that — confirm.
pub stop_token_ids_hidden: Option<Vec<TokenIdType>>,
/// Presumably the lower bound on generated tokens — TODO confirm.
pub min_tokens: Option<u32>,
/// Flag consumed by `apply_ignore_eos`.
pub ignore_eos: Option<bool>,
}
impl StopConditions {
    /// Rewrites the stop conditions in place when `ignore_eos` is `Some(true)`.
    ///
    /// In that case `min_tokens` is set to the current `max_tokens`, and both `stop` and
    /// `stop_token_ids_hidden` are cleared. When `ignore_eos` is `None` or `Some(false)`
    /// nothing is changed. `ignore_eos` and `max_tokens` themselves are never modified.
    pub fn apply_ignore_eos(&mut self) {
        if self.ignore_eos != Some(true) {
            return;
        }

        self.stop_token_ids_hidden = None;
        self.stop = None;
        self.min_tokens = self.max_tokens;
    }
}
// NOTE(review): nothing in this file reads the three constants below. They look like
// (min, max) bounds for the matching `SamplingOptions` fields — confirm where, and
// whether inclusively, they are enforced.
/// Pair of `f32` values `(0.0, 1.0)` associated with temperature.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 1.0);
/// Pair of `f32` values `(0.0, 1.0)` associated with top-p.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);
/// Pair of `f32` values `(-1.0, 1.0)` associated with the frequency penalty.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
/// Optional sampling parameters for a request.
///
/// Every field is an `Option`, and `Default` leaves them all `None`. The field names follow
/// common LLM sampling parameters; how each value is interpreted, and what an unset value
/// means, is decided by the consumer of this struct and is not defined in this file.
///
/// `force_greedy` clears `presence_penalty`, `frequency_penalty`, `repetition_penalty`,
/// `temperature`, `top_p`, `top_k` and `min_p`, and leaves the other fields alone.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
pub n: Option<i32>,
pub best_of: Option<i32>,
pub presence_penalty: Option<f32>,
pub frequency_penalty: Option<f32>,
pub repetition_penalty: Option<f32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub top_k: Option<i32>,
pub min_p: Option<f32>,
pub use_beam_search: Option<bool>,
pub length_penalty: Option<f32>,
pub seed: Option<i64>,
}
impl SamplingOptions {
    /// Clears the penalty, temperature, top-p, top-k and min-p settings in place.
    ///
    /// `n`, `best_of`, `use_beam_search`, `length_penalty` and `seed` are left untouched.
    ///
    /// NOTE(review): the name implies a consumer treats the cleared fields as "decode
    /// greedily"; that interpretation is not established in this file — confirm.
    pub fn force_greedy(&mut self) {
        // Borrow exactly the fields that get reset; `..` leaves the rest alone.
        let Self {
            presence_penalty,
            frequency_penalty,
            repetition_penalty,
            temperature,
            top_p,
            top_k,
            min_p,
            ..
        } = self;

        *presence_penalty = None;
        *frequency_penalty = None;
        *repetition_penalty = None;
        *temperature = None;
        *top_p = None;
        *top_k = None;
        *min_p = None;
    }
}
/// Optional settings describing what a response should include.
///
/// Every field is an `Option`, and `Default` leaves them all `None`. The consumer of these
/// values is not in this file, so the notes below are assumptions to confirm.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OutputOptions {
/// Presumably how many log probabilities to return per generated token — TODO confirm.
pub logprobs: Option<u32>,
/// Presumably how many log probabilities to return per prompt token — TODO confirm.
pub prompt_logprobs: Option<u32>,
/// Presumably whether special tokens are dropped from decoded text — TODO confirm.
pub skip_special_tokens: Option<bool>,
/// Presumably whether the formatted prompt is returned (compare `Prologue::formatted_prompt`) — TODO confirm.
pub formatted_prompt: Option<bool>,
}
/// Log-probability entries attached to a [`StreamingCompletionResponse`].
///
/// Both fields are left out of the serialized form when they are `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionLogprobs {
/// Per-token entries for the content, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<Vec<ChatCompletionTokenLogprob>>,
/// Per-token entries for a refusal, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}
/// One token with its log probability and a list of alternative candidates.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
/// The token as a string.
pub token: String,
/// Log probability associated with `token`.
pub logprob: f64,
/// Optional byte sequence; presumably the token's raw bytes — TODO confirm the encoding.
pub bytes: Option<Vec<u8>>,
/// Alternative candidates; presumably the most likely tokens at this position — TODO confirm.
pub top_logprobs: Vec<TopLogprob>,
}
/// A candidate token with its log probability; element type of
/// `ChatCompletionTokenLogprob::top_logprobs`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopLogprob {
/// The token as a string.
pub token: String,
/// Log probability associated with `token`.
pub logprob: f64,
/// Optional byte sequence; presumably the token's raw bytes — TODO confirm the encoding.
pub bytes: Option<Vec<u8>>,
}
/// One message of a streamed response.
///
/// NOTE(review): the variant names suggest the order `Initialize`, any number of `Step`s,
/// then `Finalize`; this type does not enforce any ordering — confirm with the producer.
#[derive(Serialize, Deserialize, Debug)]
pub enum StreamingResponse {
/// Carries an optional [`Prologue`].
Initialize(Option<Prologue>),
/// Carries one boxed [`StreamingCompletionResponse`].
Step(Box<StreamingCompletionResponse>),
/// Carries an optional [`Epilogue`].
Finalize(Option<Epilogue>),
}
/// Payload of `StreamingResponse::Initialize`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Prologue {
/// Optional prompt string; presumably the prompt after templating — TODO confirm.
pub formatted_prompt: Option<String>,
/// Optional token ids; presumably the tokenized input — TODO confirm.
pub input_token_ids: Option<Vec<TokenIdType>>,
}
/// Payload of `StreamingResponse::Finalize`; currently has no fields.
#[derive(Serialize, Deserialize, Debug)]
pub struct Epilogue {}
/// One step of a streamed completion: a [`Delta`] plus optional log probabilities.
///
/// `Serialize` and `Deserialize` are implemented by hand in this file rather than derived.
/// The type does not implement `PartialEq`, so comparisons go field by field.
#[derive(Debug)]
pub struct StreamingCompletionResponse {
/// The incremental output for this step.
pub delta: Delta,
/// Optional log-probability data for this step.
pub logprobs: Option<ChatCompletionLogprobs>,
}
/// State of a stream: either still active or finished with a [`FinishReason`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState {
/// The stream has not finished.
Active,
/// The stream finished for the given reason.
Finished(FinishReason),
}
/// A set of `f32` values in dense or sparse form.
///
/// Variant names are serialized in snake_case (`all`, `sparse`).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Logits {
/// Dense list of values; presumably one per vocabulary entry — TODO confirm.
All(Vec<f32>),
/// List of `(u32, f32)` pairs; presumably `(token id, value)` — TODO confirm.
Sparse(Vec<(u32, f32)>),
}
/// [`Logits`] tagged as either normalized or raw.
///
/// Variant names are serialized in snake_case (`normalized`, `raw`). What normalization
/// is applied is not defined in this file.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LogProbs {
/// Values tagged as normalized.
Normalized(Logits),
/// Values tagged as raw.
Raw(Logits),
}
/// A token id together with optional log-probability data.
///
/// Derives `Debug` and `Clone` like the types of its fields, so values can be logged and
/// copied. NOTE(review): presumably describes a single position of a generated sequence;
/// the type is not used anywhere else in this file — confirm with callers.
#[derive(Debug, Clone)]
pub struct SequencePositionData {
    /// The token id.
    pub token_id: TokenIdType,
    /// Optional log-probability data for this token.
    pub logprobs: Option<LogProbs>,
}
/// The incremental payload carried by a [`StreamingCompletionResponse`].
///
/// All fields except `is_complete` are optional. The producer of these values is not in
/// this file, so the per-field notes marked "presumably" are assumptions to confirm.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Delta {
/// Completion flag; presumably true on the last delta of a sequence — TODO confirm.
pub is_complete: bool,
/// Why generation finished, when it has.
pub finish_reason: Option<FinishReason>,
/// Token ids for this delta. Typed as `u32` directly rather than through the
/// `TokenIdType` alias used by other types in this file.
pub token_ids: Option<Vec<u32>>,
/// Tokens as strings.
pub tokens: Option<Vec<String>>,
/// Text for this delta.
pub text: Option<String>,
/// Presumably the sequence length so far — TODO confirm what is counted.
pub sequence_length: Option<usize>,
/// Presumably the index of the choice this delta belongs to — TODO confirm.
pub index: Option<usize>,
/// Presumably the cumulative log probability of the sequence — TODO confirm.
pub cum_log_probs: Option<f64>,
/// Optional error message.
pub err_msg: Option<String>,
/// Optional token usage counts.
pub usage: Option<Usage>,
}
/// Input and output token counts; carried in `Delta::usage`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Usage {
/// Number of input tokens.
pub input_tokens_count: usize,
/// Number of output tokens.
pub output_tokens_count: usize,
}
/// A snapshot of numeric counters, used as the value type of
/// `LoadgenResults::stats_by_iteration`.
///
/// NOTE(review): the field names suggest per-iteration engine statistics (request counts,
/// KV-cache block counts, memory usage). The producer, the units (for example whether
/// memory is in bytes) and the format of `timestamp` are not visible in this file —
/// confirm before relying on them. The `_us` suffix presumably means microseconds.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Stats {
pub time_since_last_forward_pass_us: Option<u64>,
pub request_active_count: u32,
pub request_context_count: u32,
pub request_generation_count: u32,
pub request_scheduled_count: u32,
pub request_max_count: u32,
pub kv_free_cache_blocks: u64,
pub kv_max_cache_blocks: u64,
pub kv_used_cache_blocks: u64,
pub kv_tokens_per_cache_block: u64,
pub runtime_cpu_memory_usage: u64,
pub runtime_gpu_memory_usage: u64,
pub runtime_pinned_memory_usage: u64,
pub iteration_counter: u64,
pub microbatch_id: u64,
pub total_context_tokens: u32,
pub timestamp: String,
}
/// Serializes the response as a struct with a `delta` field and, when present, a
/// `logprobs` field.
///
/// `logprobs` is omitted entirely (not written as `null`) when it is `None`, so responses
/// without log probabilities keep the `{"delta": ...}` shape. The `Deserialize` impl for
/// this type accepts both shapes, which makes the two impls round-trip.
impl Serialize for StreamingCompletionResponse {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // The length given to `serialize_struct` should equal the number of fields
        // actually written, so `logprobs` is counted only when it will be emitted.
        let field_count = if self.logprobs.is_some() { 2 } else { 1 };
        let mut state = serializer.serialize_struct("StreamingCompletionResponse", field_count)?;

        state.serialize_field("delta", &self.delta)?;
        match &self.logprobs {
            Some(logprobs) => state.serialize_field("logprobs", logprobs)?,
            // Tell the serializer the field exists but is intentionally left out.
            None => state.skip_field("logprobs")?,
        }

        state.end()
    }
}
/// Deserializes from a struct-shaped value with a required `delta` field and an optional
/// `logprobs` field (a missing `logprobs` becomes `None`).
impl<'de> Deserialize<'de> for StreamingCompletionResponse {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Private mirror of the public struct, so the derive does the field handling.
        #[derive(Deserialize)]
        struct TempResponse {
            delta: Delta,
            logprobs: Option<ChatCompletionLogprobs>,
        }

        TempResponse::deserialize(deserializer).map(|parsed| StreamingCompletionResponse {
            delta: parsed.delta,
            logprobs: parsed.logprobs,
        })
    }
}
/// Two vectors of the same element type, named `x` and `y`.
///
/// NOTE(review): presumably paired coordinates of a scatter plot (`x[i]` with `y[i]`);
/// nothing here enforces that the two vectors have equal length.
#[derive(Serialize, Deserialize, Debug)]
pub struct ScatterData<T> {
/// The `x` values.
pub x: Vec<T>,
/// The `y` values.
pub y: Vec<T>,
}
/// Timing record with start/complete timestamps, token counts and latency measurements.
///
/// NOTE(review): presumably one record per request. The time unit of the `u64` latency
/// fields and the meaning of the `*_iteration_count` fields are not defined in this file
/// — confirm with the code that fills them in.
#[derive(Serialize, Deserialize, Debug)]
pub struct Trace {
/// Presumably the latency until the first token — unit TODO confirm.
pub time_to_first_token: u64,
/// Presumably the latency between consecutive tokens — unit TODO confirm.
pub token_to_token: Vec<u64>,
/// Start time.
pub start: SystemTime,
/// Completion time.
pub complete: SystemTime,
pub initial_tokens: u32,
pub max_tokens: u32,
pub t2ft_iteration_count: u64,
pub t2t_iteration_count: Vec<u64>,
}
/// Intercept, slope and R² values for two fits, named `t2ft` and `t2tl`.
///
/// NOTE(review): presumably two linear fits — `t2ft` matching the "time to first token"
/// naming in `Trace`, `t2tl` not expanded anywhere in this file. What the fits are
/// regressed against is not visible here — confirm.
#[derive(Serialize, Deserialize, Debug)]
pub struct PerformanceModel {
pub t2ft_intercept: f64,
pub t2ft_slope: f64,
pub t2tl_intercept: f64,
pub t2tl_slope: f64,
pub t2ft_fit_r2: f64,
pub t2tl_fit_r2: f64,
}
/// Bundle of calibration output: scalar estimates, a [`PerformanceModel`], the
/// [`Trace`]s, and two [`ScatterData`] sets.
///
/// NOTE(review): units of `effective_flops` and `effective_memory_bandwidth`, and the
/// meaning of `max_q`, are not defined in this file — confirm with the producer.
#[derive(Serialize, Deserialize, Debug)]
pub struct CalibrationResults {
pub effective_flops: f64,
pub effective_memory_bandwidth: f64,
pub max_q: u32,
pub performance_model: PerformanceModel,
pub traces: Vec<Trace>,
pub t2ft_scatter_data: ScatterData<f64>,
pub t2tl_scatter_data: ScatterData<f64>,
}
/// Result bundle holding [`Stats`] snapshots and [`Trace`]s.
#[derive(Serialize, Deserialize, Debug)]
pub struct LoadgenResults {
/// `Stats` keyed by a `u64`; presumably the iteration counter — TODO confirm.
pub stats_by_iteration: HashMap<u64, Stats>,
/// The collected traces.
pub traces: Vec<Trace>,
}
impl CompletionContext {
pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
Self {
prompt,
system_prompt,
}
}
pub fn from_prompt(prompt: String) -> Self {
Self {
prompt,
system_prompt: None,
}
}
pub fn with_system_prompt(prompt: String, system_prompt: String) -> Self {
Self {
prompt,
system_prompt: Some(system_prompt),
}
}
}
impl From<CompletionContext> for PromptType {
fn from(context: CompletionContext) -> Self {
PromptType::Completion(context)
}
}
// Unit tests for the `CompletionContext` constructors, its conversion into `PromptType`,
// and the hand-written serde impls of `StreamingCompletionResponse`.
#[cfg(test)]
mod tests {
use serde_json;
use super::*;
// `new` stores the prompt and the optional system prompt as given.
#[test]
fn test_completion_context_new() {
let prompt = "Hello, world!".to_string();
let system_prompt = Some("This is a system prompt.".to_string());
let context = CompletionContext::new(prompt.clone(), system_prompt.clone());
assert_eq!(context.prompt, prompt);
assert_eq!(context.system_prompt, system_prompt);
}
// `from_prompt` leaves `system_prompt` as `None`.
#[test]
fn test_completion_context_from_prompt() {
let prompt = "Hello, world!".to_string();
let context = CompletionContext::from_prompt(prompt.clone());
assert_eq!(context.prompt, prompt);
assert_eq!(context.system_prompt, None);
}
// `with_system_prompt` wraps the system prompt in `Some`.
#[test]
fn test_completion_context_with_system_prompt() {
let prompt = "Hello, world!".to_string();
let system_prompt = "This is a system prompt.".to_string();
let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
assert_eq!(context.prompt, prompt);
assert_eq!(context.system_prompt, Some(system_prompt));
}
// `.into()` on a `CompletionContext` yields `PromptType::Completion` with the same fields.
#[test]
fn test_completion_context_into_prompt_type() {
let prompt = "Hello, world!".to_string();
let system_prompt = "This is a system prompt.".to_string();
let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
let prompt_type: PromptType = context.into();
if let PromptType::Completion(completion_context) = prompt_type {
assert_eq!(completion_context.prompt, prompt);
assert_eq!(completion_context.system_prompt, Some(system_prompt));
} else {
panic!("Expected a Completion variant");
}
}
// A response whose `Delta` is all-`None` and whose `logprobs` is `None` serializes to an
// object containing only `delta`.
// NOTE(review): the test name mentions `stats`, but the struct under test has no such
// field (only `delta` and `logprobs`) — the name looks stale.
#[test]
fn test_serialize_without_stats() {
let response = StreamingCompletionResponse {
delta: Delta {
is_complete: false,
finish_reason: None,
token_ids: None,
tokens: None,
text: None,
sequence_length: None,
index: None,
cum_log_probs: None,
err_msg: None,
usage: None,
},
logprobs: None,
};
let serialized = serde_json::to_string(&response).expect("Failed to serialize");
let expected = r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#;
// Compare as parsed JSON values so whitespace and key order do not matter.
assert_eq!(
serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
serde_json::from_str::<serde_json::Value>(expected).unwrap()
);
}
// JSON that has only a `delta` object deserializes successfully, with every optional
// `Delta` field `None`.
#[test]
fn test_deserialize_without_stats() {
let json_data = r#"{
"delta": {
"is_complete": false,
"finish_reason": null,
"token_ids": null,
"tokens": null,
"text": null,
"sequence_length": null,
"index": null,
"cum_log_probs": null,
"err_msg": null,
"usage": null
}
}"#;
let deserialized: StreamingCompletionResponse =
serde_json::from_str(json_data).expect("Failed to deserialize");
let expected = StreamingCompletionResponse {
delta: Delta {
is_complete: false,
finish_reason: None,
token_ids: None,
tokens: None,
text: None,
sequence_length: None,
index: None,
cum_log_probs: None,
err_msg: None,
usage: None,
},
logprobs: None,
};
// `StreamingCompletionResponse` does not implement `PartialEq`, so the `Delta` fields
// are compared one by one.
assert_eq!(deserialized.delta.is_complete, expected.delta.is_complete);
assert_eq!(
deserialized.delta.finish_reason,
expected.delta.finish_reason
);
assert_eq!(deserialized.delta.token_ids, expected.delta.token_ids);
assert_eq!(deserialized.delta.tokens, expected.delta.tokens);
assert_eq!(deserialized.delta.text, expected.delta.text);
assert_eq!(
deserialized.delta.sequence_length,
expected.delta.sequence_length
);
assert_eq!(deserialized.delta.index, expected.delta.index);
assert_eq!(
deserialized.delta.cum_log_probs,
expected.delta.cum_log_probs
);
assert_eq!(deserialized.delta.err_msg, expected.delta.err_msg);
assert_eq!(deserialized.delta.usage, expected.delta.usage);
}
// A populated `Delta` serializes field by field; `FinishReason::Length` is written as the
// string "length", and `logprobs: None` adds no key to the output.
#[test]
fn test_serialize_delta_and_none_stats() {
let response = StreamingCompletionResponse {
delta: Delta {
is_complete: true,
finish_reason: Some(FinishReason::Length),
token_ids: Some(vec![101, 102, 103]),
tokens: Some(vec!["token1".to_string(), "token2".to_string()]),
text: Some("example text".to_string()),
sequence_length: Some(3),
index: Some(0),
cum_log_probs: Some(-0.5),
err_msg: None,
usage: None,
},
logprobs: None,
};
let serialized = serde_json::to_string(&response).expect("Failed to serialize");
let expected_json = r#"{
"delta": {
"is_complete": true,
"finish_reason": "length",
"token_ids": [101, 102, 103],
"tokens": ["token1", "token2"],
"text": "example text",
"sequence_length": 3,
"index": 0,
"cum_log_probs": -0.5,
"err_msg": null,
"usage": null
}
}"#;
// Compare as parsed JSON values so whitespace and key order do not matter.
assert_eq!(
serde_json::from_str::<serde_json::Value>(&serialized).unwrap(),
serde_json::from_str::<serde_json::Value>(expected_json).unwrap()
);
}
}