use either::Either;
use indexmap::IndexMap;
use mistralrs_audio::AudioInput;
use mistralrs_quant::IsqType;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::VideoInput;
use crate::{
response::Response, sampler::SamplingParams, tools::ToolChoice, CustomLogitsProcessor,
DiffusionGenerationParams, Tool,
};
use std::{fmt::Debug, path::PathBuf, sync::Arc};
use tokio::sync::mpsc::Sender;
/// Top-level grammar type re-exported from the `llguidance` crate, used by
/// [`Constraint::Llguidance`] for grammar-constrained generation.
pub type LlguidanceGrammar = llguidance::api::TopLevelGrammar;
/// A constraint applied to token sampling (constrained / grammar-guided decoding).
#[derive(Clone, Serialize, Deserialize)]
pub enum Constraint {
    /// Output must match this regular expression.
    Regex(String),
    /// Output must conform to this Lark grammar.
    Lark(String),
    /// Output must conform to this JSON schema.
    JsonSchema(serde_json::Value),
    /// Output must conform to an llguidance top-level grammar.
    Llguidance(LlguidanceGrammar),
    /// No constraint: free-form sampling.
    None,
}
/// How a generated image is returned to the client
/// (mirrors the OpenAI image-generation `response_format` field).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub enum ImageGenerationResponseFormat {
    /// Return a URL referencing the image.
    Url,
    /// Return the image bytes base64-encoded in the JSON response.
    B64Json,
}
/// Content of a chat message: either a plain string, or a list of structured
/// parts where each part is an ordered map of field name -> JSON value.
pub type MessageContent = Either<String, Vec<IndexMap<String, Value>>>;
/// Requested reasoning effort level for models that support it.
/// Serialized in lowercase (`"low"` / `"medium"` / `"high"`).
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, Default)]
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    /// Default effort level.
    #[default]
    Medium,
    High,
}
impl ReasoningEffort {
pub fn as_str(&self) -> &'static str {
match self {
Self::Low => "low",
Self::Medium => "medium",
Self::High => "high",
}
}
}
/// The payload of a request: what the model is being asked to do.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum RequestMessage {
    /// Text-only chat completion.
    Chat {
        /// Conversation history; each message is a map of field -> content
        /// (e.g. "role", "content").
        messages: Vec<IndexMap<String, MessageContent>>,
        /// Toggle thinking/reasoning mode for models that support it.
        enable_thinking: Option<bool>,
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Raw text completion (non-chat).
    Completion {
        text: String,
        /// If true, echo the prompt back in the completion.
        echo_prompt: bool,
        best_of: Option<usize>,
    },
    /// Completion from a pre-tokenized prompt.
    CompletionTokens(Vec<u32>),
    /// Chat completion with image/audio/video inputs.
    /// Media fields are skipped during (de)serialization — they are populated
    /// in-process, not over the wire.
    MultimodalChat {
        #[serde(skip)] images: Vec<image::DynamicImage>,
        #[serde(skip)] audios: Vec<AudioInput>,
        #[serde(skip)]
        videos: Vec<VideoInput>,
        messages: Vec<IndexMap<String, MessageContent>>,
        enable_thinking: Option<bool>,
        reasoning_effort: Option<ReasoningEffort>,
    },
    /// Diffusion image generation.
    ImageGeneration {
        prompt: String,
        format: ImageGenerationResponseFormat,
        generation_params: DiffusionGenerationParams,
        /// Optional path to also save the generated image to.
        save_file: Option<PathBuf>,
    },
    /// Text-to-speech generation.
    SpeechGeneration {
        prompt: String,
    },
    /// Embedding of a text prompt.
    Embedding {
        prompt: String,
    },
    /// Embedding of a pre-tokenized prompt.
    EmbeddingTokens {
        prompt: Vec<u32>,
    },
}
/// Placeholder responder used by serde when deserializing request types whose
/// `response` channel is `#[serde(skip)]`-ed. The receiver half is dropped
/// immediately, so any send on this channel fails; callers are expected to
/// install a real channel before dispatching the request.
fn default_responder<T>() -> Sender<T> {
    tokio::sync::mpsc::channel::<T>(1).0
}
/// Amount of context retrieved for web search
/// (mirrors the OpenAI `search_context_size` option).
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq, eq_int))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Default)]
pub enum SearchContextSize {
    #[serde(rename = "low")]
    Low,
    /// Default context size.
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}
/// Approximate user location used to localize web search results.
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct ApproximateUserLocation {
    pub city: String,
    pub country: String,
    pub region: String,
    pub timezone: String,
}
/// User location for web search. Tagged with `"type"` on the wire; currently
/// only the `"approximate"` form exists (matching the OpenAI schema).
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum WebSearchUserLocation {
    #[serde(rename = "approximate")]
    Approximate {
        approximate: ApproximateUserLocation,
    },
}
/// Options controlling web-search-augmented generation. All fields optional;
/// `None` means "use the engine default".
#[cfg_attr(feature = "pyo3_macros", pyo3::pyclass(eq))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    pub search_context_size: Option<SearchContextSize>,
    pub user_location: Option<WebSearchUserLocation>,
    /// Override description of the search tool shown to the model.
    pub search_description: Option<String>,
    /// Override description of the page-extraction tool shown to the model.
    pub extract_description: Option<String>,
}
/// A fully-specified generation request as dispatched to the engine.
#[derive(Clone, Serialize, Deserialize)]
pub struct NormalRequest {
    /// What to generate (chat, completion, embedding, ...).
    pub messages: RequestMessage,
    pub sampling_params: SamplingParams,
    /// Channel on which responses are sent back. Skipped by serde; a dummy
    /// disconnected channel is substituted on deserialization.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<Response>,
    pub return_logprobs: bool,
    pub is_streaming: bool,
    /// Engine-assigned request id.
    pub id: usize,
    pub constraint: Constraint,
    /// Text appended after the generated completion, if any.
    pub suffix: Option<String>,
    pub tools: Option<Vec<Tool>>,
    pub tool_choice: Option<ToolChoice>,
    /// Custom logits processors; not serializable, so skipped by serde.
    #[serde(skip)]
    pub logits_processors: Option<Vec<Arc<dyn CustomLogitsProcessor>>>,
    pub return_raw_logits: bool,
    pub web_search_options: Option<WebSearchOptions>,
    /// Target model when the engine hosts multiple models.
    pub model_id: Option<String>,
    /// If true, truncate over-long prompts instead of erroring.
    #[serde(default)]
    pub truncate_sequence: bool,
}
impl NormalRequest {
    /// Build a request from the caller-supplied essentials, defaulting
    /// everything else: no logprobs, no streaming, no constraint, no suffix,
    /// no logits processors, no raw logits, no web search, no model id, and
    /// no sequence truncation.
    pub fn new_simple(
        messages: RequestMessage,
        sampling_params: SamplingParams,
        response: Sender<Response>,
        id: usize,
        tools: Option<Vec<Tool>>,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            messages,
            sampling_params,
            response,
            return_logprobs: false,
            is_streaming: false,
            id,
            constraint: Constraint::None,
            suffix: None,
            tools,
            tool_choice,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            model_id: None,
            truncate_sequence: false,
        }
    }
}
/// Request to apply the chat template (if any) and tokenize text.
#[derive(Clone, Serialize, Deserialize)]
pub struct TokenizationRequest {
    /// Either chat messages (templated before tokenization) or raw text.
    pub text: Either<Vec<IndexMap<String, MessageContent>>, String>,
    pub tools: Option<Vec<Tool>>,
    pub add_generation_prompt: bool,
    pub add_special_tokens: bool,
    pub enable_thinking: Option<bool>,
    pub reasoning_effort: Option<ReasoningEffort>,
    /// Channel for the resulting token ids. Skipped by serde; a dummy
    /// disconnected channel is substituted on deserialization.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<Vec<u32>>>,
}
/// Request to convert token ids back into text.
#[derive(Clone, Serialize, Deserialize)]
pub struct DetokenizationRequest {
    pub tokens: Vec<u32>,
    pub skip_special_tokens: bool,
    /// Channel for the decoded string. Skipped by serde; a dummy disconnected
    /// channel is substituted on deserialization.
    #[serde(default = "default_responder")]
    #[serde(skip)]
    pub response: Sender<anyhow::Result<String>>,
}
/// Top-level message sent to the engine's request loop.
#[derive(Clone, Serialize, Deserialize)]
pub enum Request {
    /// A generation request (boxed: large variant).
    Normal(Box<NormalRequest>),
    /// Re-quantize the model in-place to the given ISQ type.
    ReIsq(IsqType),
    Tokenize(TokenizationRequest),
    Detokenize(DetokenizationRequest),
    /// Shut the engine down.
    Terminate,
    /// Terminate all running sequences at the next scheduling step.
    TerminateAllSeqsNextStep,
}
impl Debug for Request {
    /// Human-readable summary of a request. Only the debuggable parts of a
    /// `NormalRequest` are shown; the response channel and logits processors
    /// are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Request::Normal(boxed_req) => {
                let NormalRequest {
                    messages,
                    sampling_params,
                    is_streaming,
                    id,
                    ..
                } = &**boxed_req;
                write!(
                    f,
                    "Request {id} {{ messages: `{messages:?}`, sampling_params: {sampling_params:?}, is_streaming: {is_streaming}}}",
                )
            }
            Request::ReIsq(tp) => {
                write!(f, "Re ISQ Request {tp:?}")
            }
            Request::Tokenize(req) => {
                write!(f, "Tokenization Request {:?}", req.text)
            }
            Request::Detokenize(req) => {
                // Bug fix: this arm previously printed "Tokenization Request"
                // (copy-paste from the Tokenize arm), mislabeling the variant.
                write!(f, "Detokenization Request {:?}", req.tokens)
            }
            Request::Terminate => write!(f, "Termination Request"),
            Request::TerminateAllSeqsNextStep => write!(f, "Terminate All Seqs Next Step"),
        }
    }
}