use std::collections::HashSet;
use std::sync::Arc;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride;
use crate::kv_router::protocols::{BlockExtraInfo, WorkerId};
use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType;
/// Optional routing hints attached to a request.
///
/// Every field is an `Option`. On the wire, a field that is `None` is omitted
/// (`skip_serializing_if = "Option::is_none"`) and a missing field deserializes
/// to `None` (`serde(default)`). The struct-level `#[builder(default)]` lets the
/// derived `RoutingHintsBuilder` leave any field unset.
///
/// NOTE(review): how each hint is interpreted is decided by the consumer of
/// this struct, which is not visible in this file — the per-field notes below
/// describe the data only.
#[derive(Serialize, Deserialize, Debug, Clone, Default, Builder)]
#[builder(default)]
pub struct RoutingHints {
/// Id of a specific backend instance, if one is requested.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub backend_instance_id: Option<u64>,
/// Id of a specific prefill worker, if one is requested.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_worker_id: Option<u64>,
/// Id of a specific decode worker, if one is requested.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub decode_worker_id: Option<u64>,
/// Presumably the data-parallel rank to target — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dp_rank: Option<u32>,
/// Expected number of output tokens for the request, if known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub expected_output_tokens: Option<u32>,
/// Name of the LoRA associated with the request, if any.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub lora_name: Option<String>,
/// Floating-point priority adjustment.
/// NOTE(review): units and sign convention are not defined here — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority_jump: Option<f64>,
/// Signed integer priority.
/// NOTE(review): whether higher or lower means "more urgent" is not defined here — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<i32>,
/// Set of worker ids; presumably restricts routing to these workers — TODO confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub allowed_worker_ids: Option<HashSet<WorkerId>>,
}
/// Bootstrap connection details: a host, a port and a room id.
///
/// All three fields are required when deserializing (none carries
/// `serde(default)`).
///
/// NOTE(review): what the "room" identifies is defined by the consumer of this
/// struct, not by this file — confirm before relying on it.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct BootstrapInfo {
/// Host name or address of the bootstrap endpoint.
pub bootstrap_host: String,
/// Port of the bootstrap endpoint.
pub bootstrap_port: u16,
/// Numeric room id.
pub bootstrap_room: u64,
}
/// Result of a prefill step that is carried along with a request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PrefillResult {
/// Free-form JSON payload; its schema is not constrained by this type.
pub disaggregated_params: serde_json::Value,
/// Prompt token details, if available. Omitted from the serialized form when
/// `None`; a missing field deserializes to `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_tokens_details: Option<dynamo_async_openai::types::PromptTokensDetails>,
}
/// Multimodal routing data: a token id sequence plus a parallel list of
/// optional [`BlockExtraInfo`] entries.
///
/// `PreprocessedRequest::block_mm_routing_info` treats an empty
/// `routing_token_ids` the same as this struct being absent.
#[derive(Serialize, Deserialize, Debug, Clone, Default, Builder)]
#[builder(default)]
pub struct MmRoutingInfo {
/// Token ids to use for routing instead of the request's own `token_ids`.
pub routing_token_ids: Vec<TokenIdType>,
/// Optional extra info entries; presumably one per block — TODO confirm.
pub block_mm_infos: Vec<Option<BlockExtraInfo>>,
}
/// A single multimodal item attached to a request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum MultimodalData {
/// A parsed URL.
Url(url::Url),
/// A URL kept as an unparsed string.
///
/// The rename applies to serialization only: this variant is written under
/// the tag `"Url"`, while deserialization still expects the tag `"RawUrl"`.
/// NOTE(review): presumably so a receiver reads the value back as `Url` —
/// confirm against the consumers.
#[serde(rename(serialize = "Url"))]
RawUrl(String),
/// Already-decoded media described by an [`RdmaMediaDataDescriptor`].
Decoded(RdmaMediaDataDescriptor),
}
/// Map from a string key to the list of [`MultimodalData`] items for that key.
/// NOTE(review): keys are presumably modality names (e.g. image/video) — confirm.
pub type MultimodalDataMap = std::collections::HashMap<String, Vec<MultimodalData>>;
/// A request after preprocessing: the tokenized prompt plus the options and
/// optional metadata that travel with it.
///
/// Builder: fields without `#[builder(default)]` (`model`, `token_ids`,
/// `stop_conditions`, `sampling_options`, `output_options`) must be set on the
/// derived `PreprocessedRequestBuilder` before `build()` can succeed.
///
/// Serde: fields marked `skip_serializing_if = "Option::is_none"` are omitted
/// when `None`; `tracker` is never serialized.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct PreprocessedRequest {
/// Name of the model the request is for.
pub model: String,
/// Token ids of the prompt.
pub token_ids: Vec<TokenIdType>,
/// Optional prompt embeddings carried as a string.
/// NOTE(review): the encoding of the string is not defined here — confirm.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_embeds: Option<String>,
/// Optional multimodal items, grouped by key.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub multi_modal_data: Option<MultimodalDataMap>,
/// Optional multimodal routing data; read by `block_mm_routing_info`.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mm_routing_info: Option<MmRoutingInfo>,
/// Stop conditions for generation.
pub stop_conditions: StopConditions,
/// Sampling options for generation.
pub sampling_options: SamplingOptions,
/// Output options for generation.
pub output_options: OutputOptions,
/// End-of-sequence token ids. Defaults to empty in the builder, but has no
/// `serde(default)`, so it must be present when deserializing.
#[builder(default)]
pub eos_token_ids: Vec<TokenIdType>,
/// Optional checksum string.
/// NOTE(review): presumably identifies the model deployment card — confirm.
#[builder(default)]
pub mdc_sum: Option<String>,
/// Free-form annotation strings. `has_annotation` matches an entry exactly;
/// `get_annotation_value` reads entries of the form `key:value`.
#[builder(default)]
pub annotations: Vec<String>,
/// Optional routing hints; `routing_mut` creates a default value on demand.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub routing: Option<RoutingHints>,
/// Optional per-request router configuration override.
#[builder(default)]
pub router_config_override: Option<RouterConfigOverride>,
/// Optional result of a prefill step.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prefill_result: Option<PrefillResult>,
/// Optional bootstrap connection details.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub bootstrap_info: Option<BootstrapInfo>,
/// Optional free-form JSON with extra arguments; schema is not constrained here.
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_args: Option<serde_json::Value>,
/// Optional shared request tracker. Skipped by serde in both directions, so
/// it is `None` after deserialization.
#[builder(default)]
#[serde(skip)]
pub tracker: Option<Arc<RequestTracker>>,
}
impl PreprocessedRequest {
    /// Returns `true` if `annotation` appears verbatim in `annotations`.
    ///
    /// The comparison is exact (case-sensitive, no trimming).
    pub fn has_annotation(&self, annotation: &str) -> bool {
        // Compare as `&str` so no temporary `String` is allocated per lookup.
        self.annotations
            .iter()
            .any(|entry| entry.as_str() == annotation)
    }

    /// Returns the value of the first annotation of the form `key:value`.
    ///
    /// The value is everything after the first `key:` prefix, so it may be
    /// empty and may itself contain `:`. Returns `None` when no annotation
    /// starts with `key:`.
    pub fn get_annotation_value(&self, key: &str) -> Option<String> {
        self.annotations.iter().find_map(|entry| {
            // Strip `key` and then the `:` separator; this matches exactly the
            // entries that start with `"{key}:"` without building that prefix.
            entry
                .strip_prefix(key)
                .and_then(|rest| rest.strip_prefix(':'))
                .map(str::to_owned)
        })
    }

    /// Creates an empty builder for [`PreprocessedRequest`].
    pub fn builder() -> PreprocessedRequestBuilder {
        PreprocessedRequestBuilder::default()
    }

    /// Returns a mutable reference to the routing hints, inserting
    /// `RoutingHints::default()` first if none are set.
    pub fn routing_mut(&mut self) -> &mut RoutingHints {
        self.routing.get_or_insert_with(RoutingHints::default)
    }

    /// Returns the token ids to use for routing, with the matching per-entry
    /// multimodal info when available.
    ///
    /// If `mm_routing_info` is set and its `routing_token_ids` is non-empty,
    /// returns those tokens together with `Some(block_mm_infos)`. Otherwise
    /// falls back to the request's own `token_ids` and `None`.
    pub fn block_mm_routing_info(&self) -> (&[TokenIdType], Option<&[Option<BlockExtraInfo>]>) {
        match self.mm_routing_info.as_ref() {
            Some(mm) if !mm.routing_token_ids.is_empty() => (
                mm.routing_token_ids.as_slice(),
                Some(mm.block_mm_infos.as_slice()),
            ),
            // Absent or empty routing info: route on the plain prompt tokens.
            _ => (self.token_ids.as_slice(), None),
        }
    }
}
/// An embedding request after preprocessing.
///
/// Builder: only `mdc_sum` and `annotations` carry `#[builder(default)]`; the
/// remaining fields (including the two `Option`s) must be set explicitly on
/// the derived `PreprocessedEmbeddingRequestBuilder`.
#[derive(Serialize, Deserialize, Debug, Clone, Builder)]
pub struct PreprocessedEmbeddingRequest {
/// Token id sequences; presumably one inner vector per input text — TODO confirm.
pub token_ids: Vec<Vec<TokenIdType>>,
/// Name of the model the request is for.
pub model: String,
/// Requested encoding format, if any.
/// NOTE(review): accepted values are not defined here — confirm.
pub encoding_format: Option<String>,
/// Requested number of embedding dimensions, if any.
pub dimensions: Option<u32>,
/// Optional checksum string.
/// NOTE(review): presumably identifies the model deployment card — confirm.
#[builder(default)]
pub mdc_sum: Option<String>,
/// Free-form annotation strings; `has_annotation` matches an entry exactly.
#[builder(default)]
pub annotations: Vec<String>,
}
impl PreprocessedEmbeddingRequest {
    /// Returns `true` if `annotation` appears verbatim in `annotations`.
    ///
    /// The comparison is exact (case-sensitive, no trimming).
    pub fn has_annotation(&self, annotation: &str) -> bool {
        // Compare as `&str` so no temporary `String` is allocated per lookup.
        self.annotations
            .iter()
            .any(|entry| entry.as_str() == annotation)
    }
}
impl PreprocessedEmbeddingRequest {
pub fn builder() -> PreprocessedEmbeddingRequestBuilder {
PreprocessedEmbeddingRequestBuilder::default()
}
}