use anyhow::Result;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use super::TokenIdType;
/// Maximum bracket nesting depth (`(`, `[`, `{` combined) accepted in `guided_grammar`.
const MAX_GRAMMAR_NESTING_DEPTH: usize = 500;
/// Maximum size of `guided_grammar`, in bytes (64 KiB).
const MAX_GRAMMAR_BYTE_LENGTH: usize = 64 * 1024;
/// Maximum size of `guided_regex`, in bytes (32 KiB).
const MAX_REGEX_BYTE_LENGTH: usize = 32 * 1024;
/// Maximum size of `guided_whitespace_pattern`, in bytes.
const MAX_WHITESPACE_PATTERN_BYTE_LENGTH: usize = 1024;
/// Maximum size of `guided_json` once serialized with `serde_json::to_writer`, in bytes (256 KiB).
const MAX_JSON_SCHEMA_BYTE_LENGTH: usize = 256 * 1024;
/// Maximum object/array nesting depth accepted in `guided_json`; the root container counts as depth 1.
const MAX_JSON_SCHEMA_NESTING_DEPTH: usize = 64;
/// A byte sink that stores nothing and only tallies how much was written.
///
/// Once the running total passes `limit`, every further `write` fails with
/// `ErrorKind::InvalidData`. This lets a serializer abort early instead of
/// materializing an oversized buffer.
struct CountingWriter {
    /// Total bytes offered so far, including the write that crossed the limit (saturating).
    count: usize,
    /// Largest running total that is still accepted.
    limit: usize,
}

impl std::io::Write for CountingWriter {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        self.count = self.count.saturating_add(written);
        if self.count <= self.limit {
            Ok(written)
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "guided_json byte length limit exceeded",
            ))
        }
    }

    /// Nothing is buffered, so flushing always succeeds.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Returns `true` if any object or array in `v` sits deeper than `max_depth`.
///
/// The root value is at depth 1. Scalars never count toward the depth, so a scalar
/// nested inside `max_depth` containers is still accepted. An explicit work list is
/// used instead of recursion, so the walk itself cannot exhaust the call stack.
fn json_exceeds_nesting_depth(v: &serde_json::Value, max_depth: usize) -> bool {
    let mut pending: Vec<(&serde_json::Value, usize)> = vec![(v, 1)];
    while let Some((node, level)) = pending.pop() {
        match node {
            serde_json::Value::Object(_) | serde_json::Value::Array(_) if level > max_depth => {
                return true;
            }
            serde_json::Value::Object(map) => {
                pending.extend(map.values().map(|child| (child, level + 1)));
            }
            serde_json::Value::Array(items) => {
                pending.extend(items.iter().map(|child| (child, level + 1)));
            }
            _ => {}
        }
    }
    false
}
/// Fails when `s` is longer than `max` bytes (UTF-8 length, not character count).
///
/// `field` names the offending input in the error message.
fn check_byte_len(field: &str, s: &str, max: usize) -> Result<()> {
    let len = s.len();
    if len <= max {
        return Ok(());
    }
    Err(anyhow::anyhow!(
        "{} exceeds maximum byte length of {} (got {})",
        field,
        max,
        len
    ))
}
pub mod llm_backend;
pub mod postprocessor;
pub mod preprocessor;
pub mod timing;
/// Implemented by request types that can produce [`SamplingOptions`].
pub trait SamplingOptionsProvider {
    /// Extracts the sampling options, returning an error if extraction fails.
    fn extract_sampling_options(&self) -> Result<SamplingOptions>;
}
/// Implemented by request types that can produce [`StopConditions`].
pub trait StopConditionsProvider {
    /// Extracts the stop conditions, returning an error if extraction fails.
    fn extract_stop_conditions(&self) -> Result<StopConditions>;
}
/// Implemented by request types that can produce [`OutputOptions`].
pub trait OutputOptionsProvider {
    /// Extracts the output options, returning an error if extraction fails.
    fn extract_output_options(&self) -> Result<OutputOptions>;
}
/// Why a generation stream stopped.
///
/// Serde uses the lowercase names given by the `rename` attributes. `Display` renders
/// the same names, except `Error(msg)`, which becomes `error: <msg>`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// End of sequence (presumably the model emitted its EOS token — confirm).
    #[serde(rename = "eos")]
    EoS,
    /// A length limit was reached (presumably `max_tokens`).
    #[serde(rename = "length")]
    Length,
    /// A stop condition matched.
    #[serde(rename = "stop")]
    Stop,
    /// Generation failed; carries a free-form message.
    #[serde(rename = "error")]
    Error(String),
    /// The request was cancelled.
    #[serde(rename = "cancelled")]
    Cancelled,
    /// Output was withheld by a content filter.
    #[serde(rename = "content_filter")]
    ContentFilter,
}
impl std::fmt::Display for FinishReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FinishReason::EoS => write!(f, "eos"),
FinishReason::Length => write!(f, "length"),
FinishReason::Stop => write!(f, "stop"),
FinishReason::Error(msg) => write!(f, "error: {}", msg),
FinishReason::Cancelled => write!(f, "cancelled"),
FinishReason::ContentFilter => write!(f, "content_filter"),
}
}
}
impl std::str::FromStr for FinishReason {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"eos" => Ok(FinishReason::EoS),
"length" => Ok(FinishReason::Length),
"stop" => Ok(FinishReason::Stop),
"cancelled" => Ok(FinishReason::Cancelled),
s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
_ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
}
}
}
impl From<FinishReason> for dynamo_async_openai::types::CompletionFinishReason {
fn from(reason: FinishReason) -> Self {
match reason {
FinishReason::EoS | FinishReason::Stop | FinishReason::Cancelled => {
dynamo_async_openai::types::CompletionFinishReason::Stop
}
FinishReason::ContentFilter => {
dynamo_async_openai::types::CompletionFinishReason::ContentFilter
}
FinishReason::Length => dynamo_async_openai::types::CompletionFinishReason::Length,
FinishReason::Error(_) => dynamo_async_openai::types::CompletionFinishReason::Stop,
}
}
}
impl From<dynamo_async_openai::types::CompletionFinishReason> for FinishReason {
    /// Maps each OpenAI completion finish reason to the matching Dynamo variant.
    fn from(reason: dynamo_async_openai::types::CompletionFinishReason) -> Self {
        use dynamo_async_openai::types::CompletionFinishReason as Oai;
        match reason {
            Oai::Stop => Self::Stop,
            Oai::Length => Self::Length,
            Oai::ContentFilter => Self::ContentFilter,
        }
    }
}
/// The prompt carried by a [`CompletionRequest`], in one of several representations.
///
/// Serde uses the snake_case names given by the `rename` attributes.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum PromptType {
    /// Already-tokenized input.
    #[serde(rename = "token_ids")]
    TokenIds(Vec<TokenIdType>),
    /// A plain prompt string (NOTE(review): presumably used without templating — confirm).
    #[serde(rename = "raw")]
    Raw(String),
    /// A single prompt with an optional system prompt.
    #[serde(rename = "completion")]
    Completion(CompletionContext),
    /// A prompt plus earlier user/assistant turns.
    #[serde(rename = "chat_completion")]
    ChatCompletion(ChatContext),
    /// Arbitrary JSON (presumably for backends with a custom prompt format — confirm).
    #[serde(rename = "custom_json")]
    CustomJson(serde_json::Value),
}
/// A backend-agnostic completion request.
///
/// Build with `CompletionRequest::builder()`. `prompt`, `stop_conditions`, and
/// `sampling_options` must be set; the fields marked `#[builder(default)]` fall back to
/// their `Default` values.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct CompletionRequest {
    /// The prompt to complete.
    pub prompt: PromptType,
    /// When generation should stop.
    pub stop_conditions: StopConditions,
    /// How tokens should be sampled.
    pub sampling_options: SamplingOptions,
    /// What extra information to return.
    #[builder(default)]
    pub output_options: OutputOptions,
    /// NOTE(review): presumably a checksum identifying the model deployment card — confirm.
    #[builder(default)]
    pub mdc_sum: Option<String>,
    /// Optional string annotations attached to the request (consumers not visible here).
    #[builder(default)]
    pub annotations: Option<Vec<String>>,
}
impl CompletionRequest {
pub fn builder() -> CompletionRequestBuilder {
CompletionRequestBuilder::default()
}
}
/// A single prompt with an optional system prompt.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct CompletionContext {
    /// The prompt text.
    pub prompt: String,
    /// Optional system prompt.
    pub system_prompt: Option<String>,
}
/// One exchange in a conversation: a user message and the assistant's reply.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatTurn {
    /// The user's message.
    pub user: String,
    /// The assistant's reply.
    pub assistant: String,
}
/// A chat prompt: the current prompt/system prompt plus earlier conversation turns.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatContext {
    /// Flattened by serde, so `prompt` and `system_prompt` appear at the same level as
    /// `context` rather than under a `completion` key.
    #[serde(flatten)]
    pub completion: CompletionContext,
    /// Earlier turns (NOTE(review): presumably oldest first — confirm).
    pub context: Vec<ChatTurn>,
}
/// Conditions under which generation stops.
///
/// Every field is optional. NOTE(review): `None` presumably defers to the backend
/// default — confirm.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StopConditions {
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Stop strings.
    pub stop: Option<Vec<String>>,
    /// Stop token ids. NOTE(review): the `_hidden` suffix suggests matches are not
    /// emitted in the output — confirm.
    pub stop_token_ids_hidden: Option<Vec<TokenIdType>>,
    /// Minimum number of tokens to generate (presumably before stop conditions apply).
    pub min_tokens: Option<u32>,
    /// When `Some(true)`, `apply_ignore_eos` clears `stop` and `stop_token_ids_hidden`.
    pub ignore_eos: Option<bool>,
    /// Budget for "thinking" tokens (NOTE(review): enforcement is not visible here — confirm).
    pub max_thinking_tokens: Option<u32>,
}
impl StopConditions {
    /// Clears `stop` and `stop_token_ids_hidden` when `ignore_eos` is `Some(true)`.
    ///
    /// When `ignore_eos` is `None` or `Some(false)`, nothing changes.
    pub fn apply_ignore_eos(&mut self) {
        if !matches!(self.ignore_eos, Some(true)) {
            return;
        }
        self.stop = None;
        self.stop_token_ids_hidden = None;
    }
}
/// `(min, max)` range for `temperature`.
/// NOTE(review): the OpenAI API documents temperature in `[0, 2]`; confirm that this
/// narrower range is intended.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 1.0);
/// `(min, max)` range for `top_p`.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);
/// `(min, max)` range for `frequency_penalty`.
/// NOTE(review): the OpenAI API documents `[-2, 2]`; confirm that this narrower range is intended.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
/// Sampling parameters for generation.
///
/// Field names follow the common OpenAI/vLLM sampling parameters. Every field is
/// optional. NOTE(review): `None` presumably means "use the backend default" — confirm
/// per backend.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
    /// Number of sequences to return (at most 255, since the type is `u8`).
    pub n: Option<u8>,
    /// Number of candidates to generate before choosing the returned ones (at most 255).
    pub best_of: Option<u8>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Repetition penalty.
    pub repetition_penalty: Option<f32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,
    /// Top-k cutoff (signed; NOTE(review): negative values may mean "disabled" — confirm).
    pub top_k: Option<i32>,
    /// Minimum-probability cutoff.
    pub min_p: Option<f32>,
    /// Whether to use beam search.
    pub use_beam_search: Option<bool>,
    /// Length penalty (presumably applied with beam search).
    pub length_penalty: Option<f32>,
    /// Random seed for reproducible sampling.
    pub seed: Option<i64>,
    /// Whether a matched stop string is kept in the output.
    pub include_stop_str_in_output: Option<bool>,
    /// Structured-output constraints; see [`GuidedDecodingOptions`].
    pub guided_decoding: Option<GuidedDecodingOptions>,
}
/// Structured-output constraints for guided decoding.
///
/// `None` fields are omitted when serialized. Build with
/// `GuidedDecodingOptions::validated` or `GuidedDecodingOptions::from_optional` to apply
/// the mutual-exclusion and size checks in `GuidedDecodingOptions::validate`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct GuidedDecodingOptions {
    /// JSON schema to constrain the output to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json: Option<serde_json::Value>,
    /// Regular expression the output must match.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,
    /// Output must be one of these strings; an empty list counts as "not set" in validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub choice: Option<Vec<String>>,
    /// Grammar the output must follow (NOTE(review): format depends on the backend — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grammar: Option<String>,
    /// Guided-decoding backend name, e.g. `"xgrammar"` as used in this file's tests.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub backend: Option<String>,
    /// Whitespace pattern for generated output (NOTE(review): typically meaningful alongside `json`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub whitespace_pattern: Option<String>,
}
impl GuidedDecodingOptions {
pub fn new(
json: Option<serde_json::Value>,
regex: Option<String>,
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Self {
Self {
json,
regex,
choice,
grammar,
backend,
whitespace_pattern,
}
}
pub fn validated(
json: Option<serde_json::Value>,
regex: Option<String>,
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Self> {
let instance = Self::new(json, regex, choice, grammar, backend, whitespace_pattern);
instance.validate()?;
Ok(instance)
}
pub fn from_optional(
json: Option<serde_json::Value>,
regex: Option<String>,
choice: Option<Vec<String>>,
grammar: Option<String>,
backend: Option<String>,
whitespace_pattern: Option<String>,
) -> Result<Option<Self>> {
let is_empty_choice = choice.as_ref().is_none_or(|v| v.is_empty());
if json.is_none()
&& regex.is_none()
&& is_empty_choice
&& grammar.is_none()
&& whitespace_pattern.is_none()
{
return Ok(None);
}
let instance = Self::validated(json, regex, choice, grammar, backend, whitespace_pattern)?;
Ok(Some(instance))
}
pub fn validate(&self) -> Result<()> {
let count = [
self.json.is_some(),
self.regex.is_some(),
self.choice.as_ref().is_some_and(|v| !v.is_empty()),
self.grammar.is_some(),
self.whitespace_pattern.is_some(),
]
.iter()
.filter(|&&v| v)
.count();
if count > 1 {
return Err(anyhow::anyhow!(
"Only one of json, regex, choice, or grammar can be set, but multiple are specified"
));
}
if let Some(ref grammar) = self.grammar {
check_byte_len("guided_grammar", grammar, MAX_GRAMMAR_BYTE_LENGTH)?;
let mut depth: usize = 0;
let mut max: usize = 0;
for ch in grammar.bytes() {
match ch {
b'(' | b'[' | b'{' => {
depth += 1;
if depth > max {
max = depth;
}
}
b')' | b']' | b'}' => {
depth = depth.saturating_sub(1);
}
_ => {}
}
}
if max > MAX_GRAMMAR_NESTING_DEPTH {
return Err(anyhow::anyhow!(
"guided_grammar exceeds maximum nesting depth of {} (got {})",
MAX_GRAMMAR_NESTING_DEPTH,
max
));
}
}
if let Some(ref regex) = self.regex {
check_byte_len("guided_regex", regex, MAX_REGEX_BYTE_LENGTH)?;
}
if let Some(ref ws) = self.whitespace_pattern {
check_byte_len(
"guided_whitespace_pattern",
ws,
MAX_WHITESPACE_PATTERN_BYTE_LENGTH,
)?;
}
if let Some(ref json) = self.json {
let mut counter = CountingWriter {
count: 0,
limit: MAX_JSON_SCHEMA_BYTE_LENGTH,
};
if serde_json::to_writer(&mut counter, json).is_err() {
return Err(anyhow::anyhow!(
"guided_json schema exceeds maximum byte length of {} (got at least {})",
MAX_JSON_SCHEMA_BYTE_LENGTH,
counter.count
));
}
if json_exceeds_nesting_depth(json, MAX_JSON_SCHEMA_NESTING_DEPTH) {
return Err(anyhow::anyhow!(
"guided_json schema exceeds maximum nesting depth of {}",
MAX_JSON_SCHEMA_NESTING_DEPTH
));
}
}
Ok(())
}
}
impl SamplingOptions {
    /// Resets the distribution-shaping parameters to `None`.
    ///
    /// Cleared: the presence/frequency/repetition penalties, `temperature`, `top_p`,
    /// `top_k`, and `min_p`. Every other field (`n`, `best_of`, `seed`, `use_beam_search`,
    /// `guided_decoding`, ...) is left untouched.
    ///
    /// NOTE(review): `None` defers to the backend default. Whether that actually yields
    /// greedy decoding depends on the backend; a default temperature of 1.0 would not.
    /// Confirm before relying on this.
    pub fn force_greedy(&mut self) {
        let Self {
            presence_penalty,
            frequency_penalty,
            repetition_penalty,
            temperature,
            top_p,
            top_k,
            min_p,
            ..
        } = self;
        *presence_penalty = None;
        *frequency_penalty = None;
        *repetition_penalty = None;
        *temperature = None;
        *top_p = None;
        *top_k = None;
        *min_p = None;
    }
}
/// Options controlling what extra information is returned with the output.
///
/// NOTE(review): the field meanings below are inferred from their names; confirm
/// against consumers.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OutputOptions {
    /// Number of log-probabilities to return per generated token.
    pub logprobs: Option<u32>,
    /// Number of log-probabilities to return per prompt token.
    pub prompt_logprobs: Option<u32>,
    /// Whether special tokens are dropped when detokenizing.
    pub skip_special_tokens: Option<bool>,
    /// Whether the formatted (templated) prompt is returned.
    pub formatted_prompt: Option<bool>,
}
/// Per-token log-probabilities for a chat completion.
///
/// NOTE(review): this appears to mirror the OpenAI `logprobs` object — confirm.
/// `None` fields are omitted when serialized.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionLogprobs {
    /// Log-probabilities of content tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ChatCompletionTokenLogprob>>,
    /// Log-probabilities of refusal tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}
/// Log-probability of one output token, plus its most likely alternatives.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
    /// The token text.
    pub token: String,
    /// Log-probability of this token.
    pub logprob: f64,
    /// Raw bytes of the token, if available (presumably UTF-8 — confirm).
    pub bytes: Option<Vec<u8>>,
    /// Alternative tokens at this position.
    pub top_logprobs: Vec<TopLogprob>,
}
/// One alternative token and its log-probability.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopLogprob {
    /// The token text.
    pub token: String,
    /// Log-probability of this token.
    pub logprob: f64,
    /// Raw bytes of the token, if available.
    pub bytes: Option<Vec<u8>>,
}
/// Whether a stream is still producing output or has finished, and if so why.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState {
    /// Still producing output.
    Active,
    /// Finished, with the reason.
    Finished(FinishReason),
}
/// Logit values, stored densely or sparsely. Variant names serialize in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Logits {
    /// One value per entry (presumably per vocabulary token — confirm).
    All(Vec<f32>),
    /// `(index, value)` pairs (presumably token id and logit — confirm).
    Sparse(Vec<(u32, f32)>),
}
/// Log-probabilities, tagged by whether they are normalized. Variant names serialize in snake_case.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LogProbs {
    /// Normalized values.
    Normalized(Logits),
    /// Unnormalized values.
    Raw(Logits),
}
/// The token at one sequence position, with optional log-probabilities.
///
/// NOTE(review): unlike its neighbours, this type derives no traits (not even `Debug`).
pub struct SequencePositionData {
    /// Token id at this position.
    pub token_id: TokenIdType,
    /// Log-probabilities for this position, if requested.
    pub logprobs: Option<LogProbs>,
}
/// One item of a streaming completion: an incremental [`Delta`] plus optional logprobs.
#[derive(Debug)]
pub struct StreamingCompletionResponse {
    /// The incremental update.
    pub delta: Delta,
    /// Log-probabilities for the tokens in this update, if requested.
    pub logprobs: Option<ChatCompletionLogprobs>,
}
/// An incremental update from a completion stream.
///
/// NOTE(review): the field meanings below are inferred from their names; confirm
/// against producers.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Delta {
    /// Whether this is the final update.
    pub is_complete: bool,
    /// Why generation stopped, once it has.
    pub finish_reason: Option<FinishReason>,
    /// New token ids (typed `u32` here rather than `TokenIdType`).
    pub token_ids: Option<Vec<u32>>,
    /// New tokens as strings.
    pub tokens: Option<Vec<String>>,
    /// New decoded text.
    pub text: Option<String>,
    /// Current sequence length.
    pub sequence_length: Option<usize>,
    /// Index of the sequence/choice this update belongs to.
    pub index: Option<usize>,
    /// Cumulative log-probability of the sequence so far.
    pub cum_log_probs: Option<f64>,
    /// Error message, if generation failed.
    pub err_msg: Option<String>,
    /// Token usage counts.
    pub usage: Option<Usage>,
}
/// Token counts for a request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Usage {
    /// Number of input (prompt) tokens.
    pub input_tokens_count: usize,
    /// Number of output (generated) tokens.
    pub output_tokens_count: usize,
}
impl CompletionContext {
    /// Creates a context from a prompt and an optional system prompt.
    pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
        Self {
            prompt,
            system_prompt,
        }
    }

    /// Creates a context with no system prompt.
    pub fn from_prompt(prompt: String) -> Self {
        Self::new(prompt, None)
    }

    /// Creates a context with the given system prompt.
    pub fn with_system_prompt(prompt: String, system_prompt: String) -> Self {
        Self::new(prompt, Some(system_prompt))
    }
}
impl From<CompletionContext> for PromptType {
fn from(context: CompletionContext) -> Self {
PromptType::Completion(context)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- CompletionContext constructors ---

    // `new` stores both arguments as given.
    #[test]
    fn test_completion_context_new() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = Some("This is a system prompt.".to_string());
        let context = CompletionContext::new(prompt.clone(), system_prompt.clone());
        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, system_prompt);
    }

    // `from_prompt` leaves the system prompt unset.
    #[test]
    fn test_completion_context_from_prompt() {
        let prompt = "Hello, world!".to_string();
        let context = CompletionContext::from_prompt(prompt.clone());
        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, None);
    }

    // `with_system_prompt` wraps the system prompt in `Some`.
    #[test]
    fn test_completion_context_with_system_prompt() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, Some(system_prompt));
    }

    // `From<CompletionContext>` produces `PromptType::Completion` with the fields intact.
    #[test]
    fn test_completion_context_into_prompt_type() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
        let prompt_type: PromptType = context.into();
        if let PromptType::Completion(completion_context) = prompt_type {
            assert_eq!(completion_context.prompt, prompt);
            assert_eq!(completion_context.system_prompt, Some(system_prompt));
        } else {
            panic!("Expected a Completion variant");
        }
    }

    // --- GuidedDecodingOptions: construction and mutual exclusion ---

    #[test]
    fn test_guided_decoding_options_new_and_exclusive() {
        // json alone (plus a backend) validates and leaves the other fields unset.
        let json_val = serde_json::json!({"type": "object"});
        let backend = Some("xgrammar".to_string());
        let opts = GuidedDecodingOptions::validated(
            Some(json_val.clone()),
            None,
            None,
            None,
            backend.clone(),
            None,
        );
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.json, Some(json_val));
        assert!(opts.regex.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.grammar.is_none());
        assert_eq!(opts.backend, backend);
        assert!(opts.whitespace_pattern.is_none());
        // regex alone.
        let regex = Some(r"\d+".to_string());
        let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None, None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.regex, regex);
        assert!(opts.json.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.grammar.is_none());
        assert!(opts.whitespace_pattern.is_none());
        // choice alone.
        let choice = Some(vec!["A".to_string(), "B".to_string()]);
        let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None, None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.choice, choice);
        assert!(opts.json.is_none());
        assert!(opts.regex.is_none());
        assert!(opts.grammar.is_none());
        assert!(opts.whitespace_pattern.is_none());
        // grammar alone.
        let grammar = Some("root ::= 'yes' | 'no'".to_string());
        let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None, None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.grammar, grammar);
        assert!(opts.json.is_none());
        assert!(opts.regex.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.whitespace_pattern.is_none());
        // whitespace_pattern alone.
        let whitespace_pattern = Some(r"\s+".to_string());
        let opts = GuidedDecodingOptions::validated(
            None,
            None,
            None,
            None,
            None,
            whitespace_pattern.clone(),
        );
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.whitespace_pattern, whitespace_pattern);
        assert!(opts.json.is_none());
        assert!(opts.regex.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.grammar.is_none());
        // json + regex is rejected.
        let opts = GuidedDecodingOptions::validated(
            Some(serde_json::json!({})),
            Some(r"\d+".to_string()),
            None,
            None,
            None,
            None,
        );
        assert!(opts.is_err());
        // regex + choice is rejected.
        let opts = GuidedDecodingOptions::validated(
            None,
            Some(r"\d+".to_string()),
            Some(vec!["A".to_string()]),
            None,
            None,
            None,
        );
        assert!(opts.is_err());
        // json + choice + grammar is rejected.
        let opts = GuidedDecodingOptions::validated(
            Some(serde_json::json!({})),
            None,
            Some(vec!["A".to_string()]),
            Some("root ::= 'yes'".to_string()),
            None,
            None,
        );
        assert!(opts.is_err());
        // Nothing set at all is still valid for `validated`.
        let opts = GuidedDecodingOptions::validated(None, None, None, None, None, None);
        assert!(opts.is_ok());
    }

    // `from_optional` returns `Ok(None)` when nothing is requested (an empty `choice`
    // counts as nothing) and otherwise behaves like `validated`.
    #[test]
    fn test_guided_decoding_options_from_optional() {
        let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None, None);
        assert!(opts.is_ok());
        assert!(opts.unwrap().is_none());
        let regex = Some(r"\w+".to_string());
        let opts =
            GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None, None);
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_some());
        let val = val.unwrap();
        assert_eq!(val.regex, regex);
        let opts = GuidedDecodingOptions::from_optional(
            Some(serde_json::json!({})),
            Some(r"\d+".to_string()),
            None,
            None,
            None,
            None,
        );
        assert!(opts.is_err());
        let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None, None);
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_none());
        let opts = GuidedDecodingOptions::from_optional(
            None,
            None,
            Some(vec!["A".to_string()]),
            None,
            None,
            None,
        );
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_some());
        let val = val.unwrap();
        assert_eq!(val.choice, Some(vec!["A".to_string()]));
    }

    // --- Size and nesting limits ---

    // 501 nested parentheses is one past MAX_GRAMMAR_NESTING_DEPTH (500).
    #[test]
    fn test_guided_grammar_deep_nesting_rejected() {
        let grammar = "(".repeat(501) + "a" + &")".repeat(501);
        let result = GuidedDecodingOptions::validated(None, None, None, Some(grammar), None, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("nesting depth"));
    }

    // Exactly MAX_GRAMMAR_NESTING_DEPTH (500) levels is accepted.
    #[test]
    fn test_guided_grammar_acceptable_nesting_ok() {
        let grammar = "(".repeat(500) + "a" + &")".repeat(500);
        let result = GuidedDecodingOptions::validated(None, None, None, Some(grammar), None, None);
        assert!(result.is_ok());
    }

    // For each string field: exactly `limit` bytes passes, `limit + 1` fails with a
    // "byte length" error.
    #[test]
    fn test_string_field_byte_length_bounds() {
        type Build = fn(String) -> Result<GuidedDecodingOptions>;
        let cases: &[(&str, usize, Build)] = &[
            ("guided_grammar", MAX_GRAMMAR_BYTE_LENGTH, |s| {
                GuidedDecodingOptions::validated(None, None, None, Some(s), None, None)
            }),
            ("guided_regex", MAX_REGEX_BYTE_LENGTH, |s| {
                GuidedDecodingOptions::validated(None, Some(s), None, None, None, None)
            }),
            (
                "guided_whitespace_pattern",
                MAX_WHITESPACE_PATTERN_BYTE_LENGTH,
                |s| GuidedDecodingOptions::validated(None, None, None, None, None, Some(s)),
            ),
        ];
        for (name, limit, build) in cases {
            assert!(
                build("a".repeat(*limit)).is_ok(),
                "{name} at limit should accept"
            );
            let err = build("a".repeat(limit + 1))
                .err()
                .unwrap_or_else(|| panic!("{name} over limit should reject"))
                .to_string();
            assert!(err.contains("byte length"), "{name}: {err}");
        }
    }

    // A JSON string whose serialized form exceeds MAX_JSON_SCHEMA_BYTE_LENGTH is rejected.
    #[test]
    fn test_guided_json_byte_length_rejected() {
        let big_string = "a".repeat(MAX_JSON_SCHEMA_BYTE_LENGTH + 1);
        let json = serde_json::Value::String(big_string);
        let result = GuidedDecodingOptions::validated(Some(json), None, None, None, None, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("byte length"));
    }

    // The two surrounding quotes make MAX - 2 characters serialize to exactly the limit.
    #[test]
    fn test_guided_json_byte_length_at_limit_accepted() {
        let s = "a".repeat(MAX_JSON_SCHEMA_BYTE_LENGTH - 2);
        let json = serde_json::Value::String(s);
        let result = GuidedDecodingOptions::validated(Some(json), None, None, None, None, None);
        assert!(result.is_ok());
    }

    // Starting from `{}` and wrapping MAX + 1 times yields MAX + 2 nested objects.
    #[test]
    fn test_guided_json_nesting_depth_rejected() {
        let mut value = serde_json::json!({});
        for _ in 0..(MAX_JSON_SCHEMA_NESTING_DEPTH + 1) {
            value = serde_json::json!({ "a": value });
        }
        let result = GuidedDecodingOptions::validated(Some(value), None, None, None, None, None);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("nesting depth"));
    }

    // A scalar leaf wrapped MAX times gives exactly MAX nested objects; the scalar
    // itself does not count toward the depth.
    #[test]
    fn test_guided_json_nesting_depth_at_limit_accepted() {
        let mut value = serde_json::json!("leaf");
        for _ in 0..MAX_JSON_SCHEMA_NESTING_DEPTH {
            value = serde_json::json!({ "a": value });
        }
        let result = GuidedDecodingOptions::validated(Some(value), None, None, None, None, None);
        assert!(result.is_ok());
    }
}