1use anyhow::Result;
17use derive_builder::Builder;
18use serde::{Deserialize, Serialize};
19
20use super::TokenIdType;
21
22pub mod llm_backend;
23pub mod postprocessor;
24pub mod preprocessor;
25
/// Implemented by types from which a [`SamplingOptions`] value can be derived.
pub trait SamplingOptionsProvider {
    /// Builds the [`SamplingOptions`] described by `self`.
    ///
    /// # Errors
    /// Returns an error when the implementor cannot produce valid options; the
    /// exact failure conditions are defined by each implementation.
    fn extract_sampling_options(&self) -> Result<SamplingOptions>;
}
31
/// Implemented by types from which a [`StopConditions`] value can be derived.
pub trait StopConditionsProvider {
    /// Builds the [`StopConditions`] described by `self`.
    ///
    /// # Errors
    /// Returns an error when the implementor cannot produce valid conditions; the
    /// exact failure conditions are defined by each implementation.
    fn extract_stop_conditions(&self) -> Result<StopConditions>;
}
35
/// Implemented by types from which an [`OutputOptions`] value can be derived.
pub trait OutputOptionsProvider {
    /// Builds the [`OutputOptions`] described by `self`.
    ///
    /// # Errors
    /// Returns an error when the implementor cannot produce valid options; the
    /// exact failure conditions are defined by each implementation.
    fn extract_output_options(&self) -> Result<OutputOptions>;
}
39
/// Reason attached to a finished generation.
///
/// Every variant has an explicit serde name, so the serialized form uses the
/// lowercase strings given in the `rename` attributes below rather than the Rust
/// variant identifiers.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum FinishReason {
    /// Serialized as `eos`. Presumably an end-of-sequence token was produced —
    /// confirm against the backend that emits it.
    #[serde(rename = "eos")]
    EoS,

    /// Serialized as `length`. Presumably a token limit was reached.
    #[serde(rename = "length")]
    Length,

    /// Serialized as `stop`. Presumably a stop condition matched.
    #[serde(rename = "stop")]
    Stop,

    /// Serialized under the name `error`; carries a free-form message.
    #[serde(rename = "error")]
    Error(String),

    /// Serialized as `cancelled`.
    #[serde(rename = "cancelled")]
    Cancelled,

    /// Serialized as `content_filter`.
    #[serde(rename = "content_filter")]
    ContentFilter,
}
60
61impl std::fmt::Display for FinishReason {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 match self {
64 FinishReason::EoS => write!(f, "eos"),
65 FinishReason::Length => write!(f, "length"),
66 FinishReason::Stop => write!(f, "stop"),
67 FinishReason::Error(msg) => write!(f, "error: {}", msg),
68 FinishReason::Cancelled => write!(f, "cancelled"),
69 FinishReason::ContentFilter => write!(f, "content_filter"),
70 }
71 }
72}
73
74impl std::str::FromStr for FinishReason {
75 type Err = anyhow::Error;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
78 match s {
79 "eos" => Ok(FinishReason::EoS),
80 "length" => Ok(FinishReason::Length),
81 "stop" => Ok(FinishReason::Stop),
82 "cancelled" => Ok(FinishReason::Cancelled),
83 s if s.starts_with("error: ") => Ok(FinishReason::Error(s[7..].to_string())),
84 _ => Err(anyhow::anyhow!("Invalid FinishReason variant: '{}'", s)),
85 }
86 }
87}
88
89impl From<FinishReason> for dynamo_async_openai::types::CompletionFinishReason {
90 fn from(reason: FinishReason) -> Self {
91 match reason {
92 FinishReason::EoS | FinishReason::Stop | FinishReason::Cancelled => {
93 dynamo_async_openai::types::CompletionFinishReason::Stop
94 }
95 FinishReason::ContentFilter => {
96 dynamo_async_openai::types::CompletionFinishReason::ContentFilter
97 }
98 FinishReason::Length => dynamo_async_openai::types::CompletionFinishReason::Length,
99 FinishReason::Error(_) => dynamo_async_openai::types::CompletionFinishReason::Stop,
100 }
101 }
102}
103
104impl From<dynamo_async_openai::types::CompletionFinishReason> for FinishReason {
105 fn from(reason: dynamo_async_openai::types::CompletionFinishReason) -> Self {
106 match reason {
107 dynamo_async_openai::types::CompletionFinishReason::Stop => FinishReason::Stop,
108 dynamo_async_openai::types::CompletionFinishReason::Length => FinishReason::Length,
109 dynamo_async_openai::types::CompletionFinishReason::ContentFilter => {
110 FinishReason::ContentFilter
111 }
112 }
113 }
114}
115
/// The different shapes a prompt can take in a [`CompletionRequest`].
///
/// Each variant is serialized under the snake_case name given in its `rename`
/// attribute.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub enum PromptType {
    /// A list of token ids (serialized as `token_ids`). Presumably already
    /// tokenized input — confirm with the consumer of this type.
    #[serde(rename = "token_ids")]
    TokenIds(Vec<TokenIdType>),

    /// A plain string (serialized as `raw`). Presumably passed through without
    /// templating — confirm with the preprocessor.
    #[serde(rename = "raw")]
    Raw(String),

    /// A [`CompletionContext`] (serialized as `completion`).
    #[serde(rename = "completion")]
    Completion(CompletionContext),

    /// A [`ChatContext`] (serialized as `chat_completion`).
    #[serde(rename = "chat_completion")]
    ChatCompletion(ChatContext),

    /// An arbitrary JSON value (serialized as `custom_json`).
    #[serde(rename = "custom_json")]
    CustomJson(serde_json::Value),
}
154
/// A completion request: the prompt plus the options that govern generation.
///
/// A `CompletionRequestBuilder` is generated by `derive_builder`. Fields marked
/// `#[builder(default)]` fall back to their `Default` value when not set on the
/// builder; the remaining fields must be provided.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct CompletionRequest {
    /// The prompt, in one of the supported [`PromptType`] forms.
    pub prompt: PromptType,

    /// Conditions under which generation stops.
    pub stop_conditions: StopConditions,

    /// Sampling parameters for generation.
    pub sampling_options: SamplingOptions,

    /// Output-related options; defaults to `OutputOptions::default()` in the builder.
    #[builder(default)]
    pub output_options: OutputOptions,

    /// Optional string; defaults to `None` in the builder.
    /// NOTE(review): presumably a checksum of the model deployment card — confirm.
    #[builder(default)]
    pub mdc_sum: Option<String>,

    /// Optional list of annotation names; defaults to `None` in the builder.
    #[builder(default)]
    pub annotations: Option<Vec<String>>,
}
185
186impl CompletionRequest {
187 pub fn builder() -> CompletionRequestBuilder {
188 CompletionRequestBuilder::default()
189 }
190}
191
/// A prompt together with an optional system prompt.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct CompletionContext {
    /// The prompt text.
    pub prompt: String,

    /// Optional system prompt; `None` when absent.
    pub system_prompt: Option<String>,
}
202
/// One exchange in a chat: a user message paired with an assistant message.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatTurn {
    /// The user side of the turn.
    pub user: String,

    /// The assistant side of the turn.
    pub assistant: String,
}
212
/// A [`CompletionContext`] extended with a list of chat turns.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
pub struct ChatContext {
    /// The prompt and optional system prompt. Because of `#[serde(flatten)]`, its
    /// fields appear at the same level as `context` in the serialized form.
    #[serde(flatten)]
    pub completion: CompletionContext,

    /// The chat turns. Presumably prior history in chronological order — confirm
    /// with the code that builds this value.
    pub context: Vec<ChatTurn>,
}
224
/// Conditions that bound or terminate generation.
///
/// All fields are optional; `Default` yields every field as `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct StopConditions {
    /// Optional maximum number of tokens.
    pub max_tokens: Option<u32>,

    /// Optional list of stop strings.
    pub stop: Option<Vec<String>>,

    /// Optional list of stop token ids.
    /// NOTE(review): the `_hidden` suffix suggests these tokens are not included
    /// in the output — confirm with the backend.
    pub stop_token_ids_hidden: Option<Vec<TokenIdType>>,

    /// Optional minimum number of tokens.
    pub min_tokens: Option<u32>,

    /// Optional flag; when `Some(true)`, `apply_ignore_eos` rewrites the other
    /// fields of this struct.
    pub ignore_eos: Option<bool>,

    /// Optional maximum number of "thinking" tokens.
    /// NOTE(review): exact semantics are not visible in this file — confirm.
    pub max_thinking_tokens: Option<u32>,
}
253
254impl StopConditions {
255 pub fn apply_ignore_eos(&mut self) {
256 if self.ignore_eos.unwrap_or(false) {
257 self.min_tokens = self.max_tokens;
258 self.stop = None;
259 self.stop_token_ids_hidden = None;
260 }
261 }
262}
263
/// Pair of bounds for the sampling temperature.
/// NOTE(review): presumably `(min, max)`; whether the bounds are inclusive and
/// where they are enforced is not visible here.
pub const TEMPERATURE_RANGE: (f32, f32) = (0.0, 1.0);

/// Pair of bounds for `top_p`, presumably `(min, max)` — confirm at the use site.
pub const TOP_P_RANGE: (f32, f32) = (0.0, 1.0);

/// Pair of bounds for the frequency penalty, presumably `(min, max)` — confirm at
/// the use site.
pub const FREQUENCY_PENALTY_RANGE: (f32, f32) = (-1.0, 1.0);
272
/// Sampling parameters for generation.
///
/// All fields are optional; `Default` yields every field as `None`. How a `None`
/// is interpreted is up to the consumer and is not defined in this file.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct SamplingOptions {
    /// Optional count `n`. Presumably the number of sequences to return — confirm.
    pub n: Option<u8>,

    /// Optional `best_of` count. Presumably the number of candidates generated
    /// before selecting the best — confirm.
    pub best_of: Option<u8>,

    /// Optional presence penalty.
    pub presence_penalty: Option<f32>,

    /// Optional frequency penalty.
    pub frequency_penalty: Option<f32>,

    /// Optional repetition penalty.
    pub repetition_penalty: Option<f32>,

    /// Optional sampling temperature.
    pub temperature: Option<f32>,

    /// Optional `top_p` value.
    pub top_p: Option<f32>,

    /// Optional `top_k` value (signed).
    pub top_k: Option<i32>,

    /// Optional `min_p` value.
    pub min_p: Option<f32>,

    /// Optional flag selecting beam search.
    pub use_beam_search: Option<bool>,

    /// Optional length penalty.
    pub length_penalty: Option<f32>,

    /// Optional random seed.
    pub seed: Option<i64>,

    /// Optional flag; presumably controls whether a matched stop string is kept
    /// in the output — confirm with the backend.
    pub include_stop_str_in_output: Option<bool>,

    /// Optional guided-decoding constraints.
    pub guided_decoding: Option<GuidedDecodingOptions>,
}
338
/// Constraints for guided decoding.
///
/// `validate` rejects values where more than one of `json`, `regex`, a non-empty
/// `choice`, or `grammar` is set. Every field is omitted from the serialized form
/// when it is `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct GuidedDecodingOptions {
    /// Optional JSON value. Presumably a JSON schema the output must follow — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json: Option<serde_json::Value>,

    /// Optional regular-expression pattern.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub regex: Option<String>,

    /// Optional list of choices; an empty list is treated as unset by `validate`
    /// and `from_optional`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub choice: Option<Vec<String>>,

    /// Optional grammar string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grammar: Option<String>,

    /// Optional backend name; not considered by `validate`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub backend: Option<String>,
}
364
365impl GuidedDecodingOptions {
366 pub fn new(
368 json: Option<serde_json::Value>,
369 regex: Option<String>,
370 choice: Option<Vec<String>>,
371 grammar: Option<String>,
372 backend: Option<String>,
373 ) -> Self {
374 Self {
375 json,
376 regex,
377 choice,
378 grammar,
379 backend,
380 }
381 }
382
383 pub fn validated(
385 json: Option<serde_json::Value>,
386 regex: Option<String>,
387 choice: Option<Vec<String>>,
388 grammar: Option<String>,
389 backend: Option<String>,
390 ) -> Result<Self> {
391 let instance = Self::new(json, regex, choice, grammar, backend);
392 instance.validate()?;
393 Ok(instance)
394 }
395
396 pub fn from_optional(
398 json: Option<serde_json::Value>,
399 regex: Option<String>,
400 choice: Option<Vec<String>>,
401 grammar: Option<String>,
402 backend: Option<String>,
403 ) -> Result<Option<Self>> {
404 let is_empty_choice = choice.as_ref().is_none_or(|v| v.is_empty());
405 if json.is_none() && regex.is_none() && is_empty_choice && grammar.is_none() {
406 return Ok(None);
407 }
408 let instance = Self::validated(json, regex, choice, grammar, backend)?;
409 Ok(Some(instance))
410 }
411
412 pub fn validate(&self) -> Result<()> {
414 let count = [
415 self.json.is_some(),
416 self.regex.is_some(),
417 self.choice.as_ref().is_some_and(|v| !v.is_empty()),
418 self.grammar.is_some(),
419 ]
420 .iter()
421 .filter(|&&v| v)
422 .count();
423
424 if count > 1 {
425 Err(anyhow::anyhow!(
426 "Only one of json, regex, choice, or grammar can be set, but multiple are specified: {:?}",
427 self
428 ))
429 } else {
430 Ok(())
431 }
432 }
433}
434
435impl SamplingOptions {
436 pub fn force_greedy(&mut self) {
437 self.presence_penalty = None;
438 self.frequency_penalty = None;
439 self.repetition_penalty = None;
440 self.temperature = None;
441 self.top_p = None;
442 self.top_k = None;
443 self.min_p = None;
444 }
445}
446
/// Options that shape what is returned alongside the generated text.
///
/// All fields are optional; `Default` yields every field as `None`.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct OutputOptions {
    /// Optional count for output log probabilities. Presumably the number of
    /// log probabilities returned per generated token — confirm.
    pub logprobs: Option<u32>,

    /// Optional count for prompt log probabilities. Presumably the number returned
    /// per prompt token — confirm.
    pub prompt_logprobs: Option<u32>,

    /// Optional flag; presumably controls whether special tokens are omitted from
    /// the output — confirm with the postprocessor.
    pub skip_special_tokens: Option<bool>,

    /// Optional flag; presumably requests that the formatted prompt be returned —
    /// confirm with the preprocessor.
    pub formatted_prompt: Option<bool>,
}
471
/// Log-probability information attached to a chat completion.
///
/// Both fields are omitted from the serialized form when they are `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionLogprobs {
    /// Optional per-token log probabilities for the message content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<Vec<ChatCompletionTokenLogprob>>,

    /// Optional per-token log probabilities for a refusal message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refusal: Option<Vec<ChatCompletionTokenLogprob>>,
}
483
/// Log-probability record for a single token, including its top alternatives.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionTokenLogprob {
    /// The token text.
    pub token: String,

    /// The log probability of this token.
    pub logprob: f64,

    /// Optional byte values for the token. Presumably its UTF-8 encoding — confirm.
    pub bytes: Option<Vec<u8>>,

    /// Alternative tokens with their log probabilities. Presumably the most likely
    /// candidates at this position — confirm.
    pub top_logprobs: Vec<TopLogprob>,
}
503
/// One alternative token and its log probability.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TopLogprob {
    /// The token text.
    pub token: String,

    /// The log probability of this token.
    pub logprob: f64,

    /// Optional byte values for the token. Presumably its UTF-8 encoding — confirm.
    pub bytes: Option<Vec<u8>>,
}
516
/// State of a response stream: still active, or finished with a reason.
///
/// No serde renaming is applied, so the variants serialize under their Rust names.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum StreamState {
    /// The stream has not finished.
    Active,
    /// The stream finished with the given [`FinishReason`].
    Finished(FinishReason),
}
522
/// Logit values in either dense or sparse form.
///
/// Variants serialize as `all` and `sparse` because of `rename_all = "snake_case"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum Logits {
    /// A dense list of values. Presumably one entry per vocabulary id — confirm.
    All(Vec<f32>),
    /// A list of `(u32, f32)` pairs. Presumably `(token id, value)` — confirm.
    Sparse(Vec<(u32, f32)>),
}
529
/// [`Logits`] tagged as either normalized or raw.
///
/// Variants serialize as `normalized` and `raw` because of
/// `rename_all = "snake_case"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum LogProbs {
    /// Values tagged as normalized. NOTE(review): the normalization applied is not
    /// specified in this file.
    Normalized(Logits),
    /// Values tagged as raw (un-normalized).
    Raw(Logits),
}
536
/// Data held for a single position in a sequence.
///
/// Unlike the other types in this file, this struct derives no traits.
pub struct SequencePositionData {
    /// The token id at this position.
    pub token_id: TokenIdType,

    /// Optional log-probability data for this position.
    pub logprobs: Option<LogProbs>,
}
544
/// One streamed response item: a [`Delta`] plus optional log-probability data.
#[derive(Debug)]
pub struct StreamingCompletionResponse {
    /// The incremental payload.
    pub delta: Delta,
    /// Optional log probabilities accompanying the delta.
    pub logprobs: Option<ChatCompletionLogprobs>,
}
550
/// Incremental piece of a streamed completion.
///
/// Most fields are optional, so a delta may carry only the parts that are
/// available.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Delta {
    /// Completion flag. Presumably `true` on the final delta — confirm with the
    /// producer.
    pub is_complete: bool,

    /// Optional reason the generation finished.
    pub finish_reason: Option<FinishReason>,

    /// Optional token ids carried by this delta.
    pub token_ids: Option<Vec<u32>>,

    /// Optional token strings carried by this delta.
    pub tokens: Option<Vec<String>>,

    /// Optional text carried by this delta.
    pub text: Option<String>,

    /// Optional sequence length. NOTE(review): whether this counts prompt tokens,
    /// output tokens, or both is not visible here.
    pub sequence_length: Option<usize>,

    /// Optional index. Presumably identifies the output sequence when several are
    /// generated — confirm.
    pub index: Option<usize>,

    /// Optional cumulative log probability.
    pub cum_log_probs: Option<f64>,

    /// Optional error message.
    pub err_msg: Option<String>,

    /// Optional token usage counts.
    pub usage: Option<Usage>,
}
587
/// Token counts for a request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Usage {
    /// Number of input tokens.
    pub input_tokens_count: usize,
    /// Number of output tokens.
    pub output_tokens_count: usize,
}
593
594impl CompletionContext {
595 pub fn new(prompt: String, system_prompt: Option<String>) -> Self {
597 Self {
598 prompt,
599 system_prompt,
600 }
601 }
602
603 pub fn from_prompt(prompt: String) -> Self {
605 Self {
606 prompt,
607 system_prompt: None,
608 }
609 }
610
611 pub fn with_system_prompt(prompt: String, system_prompt: String) -> Self {
613 Self {
614 prompt,
615 system_prompt: Some(system_prompt),
616 }
617 }
618}
619
620impl From<CompletionContext> for PromptType {
622 fn from(context: CompletionContext) -> Self {
623 PromptType::Completion(context)
624 }
625}
626
/// Unit tests for the `CompletionContext` constructors and for
/// `GuidedDecodingOptions` validation.
#[cfg(test)]
mod tests {

    use super::*;

    /// `new` stores both arguments unchanged.
    #[test]
    fn test_completion_context_new() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = Some("This is a system prompt.".to_string());
        let context = CompletionContext::new(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, system_prompt);
    }

    /// `from_prompt` leaves the system prompt unset.
    #[test]
    fn test_completion_context_from_prompt() {
        let prompt = "Hello, world!".to_string();
        let context = CompletionContext::from_prompt(prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, None);
    }

    /// `with_system_prompt` wraps the system prompt in `Some`.
    #[test]
    fn test_completion_context_with_system_prompt() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());

        assert_eq!(context.prompt, prompt);
        assert_eq!(context.system_prompt, Some(system_prompt));
    }

    /// Converting a context into `PromptType` yields the `Completion` variant.
    #[test]
    fn test_completion_context_into_prompt_type() {
        let prompt = "Hello, world!".to_string();
        let system_prompt = "This is a system prompt.".to_string();
        let context = CompletionContext::with_system_prompt(prompt.clone(), system_prompt.clone());
        let prompt_type: PromptType = context.into();

        if let PromptType::Completion(completion_context) = prompt_type {
            assert_eq!(completion_context.prompt, prompt);
            assert_eq!(completion_context.system_prompt, Some(system_prompt));
        } else {
            panic!("Expected a Completion variant");
        }
    }

    /// Each single constraint validates; any combination of two or more is rejected.
    #[test]

    fn test_guided_decoding_options_new_and_exclusive() {
        // JSON only (with a backend, which does not count as a constraint).
        let json_val = serde_json::json!({"type": "object"});
        let backend = Some("xgrammar".to_string());
        let opts = GuidedDecodingOptions::validated(
            Some(json_val.clone()),
            None,
            None,
            None,
            backend.clone(),
        );
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.json, Some(json_val));
        assert!(opts.regex.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.grammar.is_none());
        assert_eq!(opts.backend, backend);

        // Regex only.
        let regex = Some(r"\d+".to_string());
        let opts = GuidedDecodingOptions::validated(None, regex.clone(), None, None, None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.regex, regex);
        assert!(opts.json.is_none());
        assert!(opts.choice.is_none());
        assert!(opts.grammar.is_none());

        // Choice only.
        let choice = Some(vec!["A".to_string(), "B".to_string()]);
        let opts = GuidedDecodingOptions::validated(None, None, choice.clone(), None, None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.choice, choice);
        assert!(opts.json.is_none());
        assert!(opts.regex.is_none());
        assert!(opts.grammar.is_none());

        // Grammar only.
        let grammar = Some("root ::= 'yes' | 'no'".to_string());
        let opts = GuidedDecodingOptions::validated(None, None, None, grammar.clone(), None);
        assert!(opts.is_ok());
        let opts = opts.unwrap();
        assert_eq!(opts.grammar, grammar);
        assert!(opts.json.is_none());
        assert!(opts.regex.is_none());
        assert!(opts.choice.is_none());

        // JSON + regex: two constraints, so validation fails.
        let opts = GuidedDecodingOptions::validated(
            Some(serde_json::json!({})),
            Some(r"\d+".to_string()),
            None,
            None,
            None,
        );
        assert!(opts.is_err());

        // Regex + choice: two constraints, so validation fails.
        let opts = GuidedDecodingOptions::validated(
            None,
            Some(r"\d+".to_string()),
            Some(vec!["A".to_string()]),
            None,
            None,
        );
        assert!(opts.is_err());

        // JSON + choice + grammar: three constraints, so validation fails.
        let opts = GuidedDecodingOptions::validated(
            Some(serde_json::json!({})),
            None,
            Some(vec!["A".to_string()]),
            Some("root ::= 'yes'".to_string()),
            None,
        );
        assert!(opts.is_err());

        // Nothing set is accepted: `validate` only rejects more than one constraint.
        let opts = GuidedDecodingOptions::validated(None, None, None, None, None);
        assert!(opts.is_ok());
    }

    /// `from_optional` returns `None` without constraints and validates otherwise.
    #[test]
    fn test_guided_decoding_options_from_optional() {
        // No constraints at all yields `Ok(None)`.
        let opts = GuidedDecodingOptions::from_optional(None, None, None, None, None);
        assert!(opts.is_ok());
        assert!(opts.unwrap().is_none());

        // A single regex yields `Ok(Some(..))` carrying that regex.
        let regex = Some(r"\w+".to_string());
        let opts = GuidedDecodingOptions::from_optional(None, regex.clone(), None, None, None);
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_some());
        let val = val.unwrap();
        assert_eq!(val.regex, regex);

        // Two constraints propagate the validation error.
        let opts = GuidedDecodingOptions::from_optional(
            Some(serde_json::json!({})),
            Some(r"\d+".to_string()),
            None,
            None,
            None,
        );
        assert!(opts.is_err());

        // An empty choice list is treated the same as no choice.
        let opts = GuidedDecodingOptions::from_optional(None, None, Some(vec![]), None, None);
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_none());

        // A non-empty choice list is kept as-is.
        let opts = GuidedDecodingOptions::from_optional(
            None,
            None,
            Some(vec!["A".to_string()]),
            None,
            None,
        );
        assert!(opts.is_ok());
        let val = opts.unwrap();
        assert!(val.is_some());
        let val = val.unwrap();
        assert_eq!(val.choice, Some(vec!["A".to_string()]));
    }
}