use std::{future::Future, pin::Pin, sync::Arc};
use openai_protocol::{
chat::ChatCompletionRequest,
common::{ResponseFormat, StringOrArray},
completion::CompletionRequest,
generate::GenerateRequest,
messages::CreateMessageRequest,
responses::ResponsesRequest,
sampling_params::SamplingParams as GenerateSamplingParams,
};
use tonic::{transport::Channel, Request};
use tracing::{debug, warn};
use crate::{AbortOnDropClient, BoxedTraceInjector, NoopTraceInjector};
/// Generated gRPC bindings for the `tokenspeed.grpc.scheduler` protobuf package.
///
/// Everything in this module is pulled in by `tonic::include_proto!`; no item
/// here is written by hand in this file.
// NOTE(review): the `#[expect]` presumably exists because the inner
// `#![allow(...)]` below would otherwise be reported by
// `clippy::allow_attributes` — confirm against the crate's lint configuration.
#[expect(clippy::allow_attributes)]
pub mod tokenspeed_proto {
// Lints are relaxed for the included code only; the rest of the file is unaffected.
#![allow(clippy::all, clippy::absolute_paths, unused_qualifications)]
tonic::include_proto!("tokenspeed.grpc.scheduler");
}
/// `crate::AbortOnDropStream` specialised to TokenSpeed: items are
/// `tokenspeed_proto::GenerateResponse` and the associated client type is
/// `TokenSpeedSchedulerClient`.
///
/// NOTE(review): judging by the name and by `AbortOnDropClient::abort_for_drop`,
/// dropping the stream presumably aborts the request — confirm against the
/// definition of `crate::AbortOnDropStream`.
pub type AbortOnDropStream =
crate::AbortOnDropStream<tokenspeed_proto::GenerateResponse, TokenSpeedSchedulerClient>;
/// gRPC client for the TokenSpeed scheduler service.
///
/// Pairs the tonic-generated client with the trace injector that `generate`
/// applies to outgoing request metadata. The RPC methods on this type clone
/// the inner tonic client before each call.
#[derive(Clone)]
pub struct TokenSpeedSchedulerClient {
/// Tonic-generated scheduler client over a `Channel` transport.
client: tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient<Channel>,
/// Injector invoked on request metadata in `generate`.
trace_injector: BoxedTraceInjector,
}
impl AbortOnDropClient for TokenSpeedSchedulerClient {
    /// Returns a boxed future that sends an abort for `request_id` with the
    /// reason `"Stream dropped"`.
    ///
    /// The client is consumed and moved into the future; the result is
    /// whatever [`TokenSpeedSchedulerClient::abort_request`] returns.
    fn abort_for_drop(
        self,
        request_id: String,
    ) -> Pin<Box<dyn Future<Output = Result<(), tonic::Status>> + Send>> {
        Box::pin(async move {
            let reason = String::from("Stream dropped");
            self.abort_request(request_id, reason).await
        })
    }
}
impl TokenSpeedSchedulerClient {
pub async fn connect(endpoint: &str) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
Self::connect_with_trace_injector(endpoint, Arc::new(NoopTraceInjector)).await
}
/// Connects to the scheduler at `endpoint` and stores `trace_injector` on the
/// resulting client.
///
/// The channel is obtained from `crate::channel::connect_channel` and wrapped
/// in the tonic-generated scheduler client.
///
/// # Errors
/// Returns the error from `connect_channel` if the channel cannot be created.
pub async fn connect_with_trace_injector(
    endpoint: &str,
    trace_injector: BoxedTraceInjector,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    debug!("Connecting to TokenSpeed scheduler at {}", endpoint);
    let transport = crate::channel::connect_channel(endpoint).await?;
    Ok(Self {
        client: tokenspeed_proto::token_speed_scheduler_client::TokenSpeedSchedulerClient::new(
            transport,
        ),
        trace_injector,
    })
}
#[must_use]
pub fn with_trace_injector(mut self, trace_injector: BoxedTraceInjector) -> Self {
self.trace_injector = trace_injector;
self
}
/// Sends `req` to the scheduler's `generate` RPC and returns the response
/// stream wrapped in an [`AbortOnDropStream`].
///
/// The wrapper is given the request id and a clone of this client. Trace
/// injection is best-effort: if the injector fails, a warning is logged and
/// the request is still sent.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn generate(
    &self,
    req: tokenspeed_proto::GenerateRequest,
) -> Result<AbortOnDropStream, tonic::Status> {
    // Keep a copy of the id: `req` is moved into the tonic request below.
    let request_id = req.request_id.clone();
    let mut request = Request::new(req);
    if let Err(err) = self.trace_injector.inject(request.metadata_mut()) {
        warn!("Failed to inject trace context: {}", err);
    }
    let stream = self.client.clone().generate(request).await?.into_inner();
    Ok(AbortOnDropStream::new(stream, request_id, self.clone()))
}
/// Calls the scheduler's `health_check` RPC and returns the response message.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn health_check(
    &self,
) -> Result<tokenspeed_proto::HealthCheckResponse, tonic::Status> {
    debug!("Sending TokenSpeed health check request");
    let mut client = self.client.clone();
    client
        .health_check(Request::new(tokenspeed_proto::HealthCheckRequest {}))
        .await
        .map(tonic::Response::into_inner)
}
/// Asks the scheduler to abort `request_id`, passing `reason` along.
///
/// The `success` and `message` fields of the reply are only logged at debug
/// level; they do not influence the return value.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn abort_request(
    &self,
    request_id: String,
    reason: String,
) -> Result<(), tonic::Status> {
    debug!(
        "Sending TokenSpeed abort for {} (reason: {})",
        request_id, reason
    );
    let payload = tokenspeed_proto::AbortRequest {
        // Cloned because the id is logged again once the reply arrives.
        request_id: request_id.clone(),
        reason,
    };
    let reply = self
        .client
        .clone()
        .abort(Request::new(payload))
        .await?
        .into_inner();
    debug!(
        "TokenSpeed abort response for {}: success={}, message={}",
        request_id, reply.success, reply.message
    );
    Ok(())
}
/// Calls the scheduler's `get_model_info` RPC and returns the response message.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn get_model_info(
    &self,
) -> Result<tokenspeed_proto::GetModelInfoResponse, tonic::Status> {
    let mut client = self.client.clone();
    client
        .get_model_info(Request::new(tokenspeed_proto::GetModelInfoRequest {}))
        .await
        .map(tonic::Response::into_inner)
}
/// Calls the scheduler's `get_server_info` RPC and returns the response message.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn get_server_info(
    &self,
) -> Result<tokenspeed_proto::GetServerInfoResponse, tonic::Status> {
    let mut client = self.client.clone();
    client
        .get_server_info(Request::new(tokenspeed_proto::GetServerInfoRequest {}))
        .await
        .map(tonic::Response::into_inner)
}
/// Calls the scheduler's `get_loads` RPC and returns the response message.
///
/// `include` is forwarded unchanged in the request; `dp_rank` is always sent
/// as `None`.
///
/// # Errors
/// Returns the `tonic::Status` produced by the RPC call.
pub async fn get_loads(
    &self,
    include: Vec<String>,
) -> Result<tokenspeed_proto::GetLoadsResponse, tonic::Status> {
    let payload = tokenspeed_proto::GetLoadsRequest {
        dp_rank: None,
        include,
    };
    let mut client = self.client.clone();
    client
        .get_loads(Request::new(payload))
        .await
        .map(tonic::Response::into_inner)
}
// Expands the crate-level `impl_admin_ops!` macro inside this `impl` block.
// NOTE(review): presumably this adds the admin operations shared with the
// other engine clients — confirm against the macro definition.
crate::impl_admin_ops!();
/// Builds a TokenSpeed `GenerateRequest` from a chat completion request.
///
/// `processed_text` and `token_ids` become the tokenized input, and
/// `multimodal_inputs` is forwarded as `mm_inputs`. Sampling parameters come
/// from `build_sampling_params_from_chat`. `logprob_start_len` is always
/// `Some(-1)`; `top_logprobs_num` defaults to 0 when the body does not set it.
///
/// # Errors
/// Returns the error string from `build_sampling_params_from_chat`.
#[expect(
    clippy::unused_self,
    reason = "receiver kept for API parity with the other engine clients"
)]
pub fn build_generate_request_from_chat(
    &self,
    request_id: String,
    body: &ChatCompletionRequest,
    processed_text: String,
    token_ids: Vec<u32>,
    multimodal_inputs: Option<tokenspeed_proto::MultimodalInputs>,
    tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
    let tokenized = tokenspeed_proto::TokenizedInput {
        original_text: processed_text,
        input_ids: token_ids,
    };
    Self::build_sampling_params_from_chat(body, tool_call_constraint).map(|sampling_params| {
        tokenspeed_proto::GenerateRequest {
            request_id,
            tokenized: Some(tokenized),
            sampling_params: Some(sampling_params),
            return_logprob: body.logprobs,
            logprob_start_len: Some(-1),
            // NOTE(review): `as i32` wraps if the source value exceeds `i32::MAX`.
            top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
            stream: body.stream,
            mm_inputs: multimodal_inputs,
            ..Default::default()
        }
    })
}
/// Builds a TokenSpeed `GenerateRequest` from a plain generate request.
///
/// A missing `original_text` becomes an empty string. Unset body fields fall
/// back to: `return_logprob = false`, `logprob_start_len = -1`,
/// `top_logprobs_num = 0` and an empty `token_ids_logprob`. No multimodal
/// inputs are attached.
///
/// # Errors
/// Returns the error string from `build_sampling_params_from_plain`.
#[expect(
    clippy::unused_self,
    reason = "receiver kept for API parity with the other engine clients"
)]
pub fn build_plain_generate_request(
    &self,
    request_id: String,
    body: &GenerateRequest,
    original_text: Option<String>,
    token_ids: Vec<u32>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
    let tokenized = tokenspeed_proto::TokenizedInput {
        original_text: original_text.unwrap_or_default(),
        input_ids: token_ids,
    };
    Self::build_sampling_params_from_plain(body.sampling_params.as_ref()).map(
        |sampling_params| tokenspeed_proto::GenerateRequest {
            request_id,
            tokenized: Some(tokenized),
            sampling_params: Some(sampling_params),
            return_logprob: body.return_logprob.unwrap_or(false),
            logprob_start_len: body.logprob_start_len.or(Some(-1)),
            top_logprobs_num: body.top_logprobs_num.unwrap_or(0),
            token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(),
            stream: body.stream,
            mm_inputs: None,
        },
    )
}
/// Builds a TokenSpeed `GenerateRequest` from a Responses API request.
///
/// Only the tokenized input, sampling parameters and `stream` flag (default
/// `false`) are set explicitly; every other field keeps its `Default` value.
///
/// # Errors
/// Returns the error string from `build_sampling_params_from_responses`.
#[expect(
    clippy::unused_self,
    reason = "receiver kept for API parity with the other engine clients"
)]
pub fn build_generate_request_from_responses(
    &self,
    request_id: String,
    body: &ResponsesRequest,
    processed_text: String,
    token_ids: Vec<u32>,
    constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
    let tokenized = tokenspeed_proto::TokenizedInput {
        original_text: processed_text,
        input_ids: token_ids,
    };
    Self::build_sampling_params_from_responses(body, constraint).map(|sampling_params| {
        tokenspeed_proto::GenerateRequest {
            request_id,
            tokenized: Some(tokenized),
            sampling_params: Some(sampling_params),
            stream: body.stream.unwrap_or(false),
            ..Default::default()
        }
    })
}
/// Builds a TokenSpeed `GenerateRequest` from a Messages API request.
///
/// Sets the tokenized input, sampling parameters, `stream` flag (default
/// `false`) and `mm_inputs`; every other field keeps its `Default` value.
///
/// # Errors
/// Returns the error string from `build_sampling_params_from_messages`.
#[expect(
    clippy::unused_self,
    reason = "receiver kept for API parity with the other engine clients"
)]
pub fn build_generate_request_from_messages(
    &self,
    request_id: String,
    body: &CreateMessageRequest,
    processed_text: String,
    token_ids: Vec<u32>,
    multimodal_inputs: Option<tokenspeed_proto::MultimodalInputs>,
    tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
    let tokenized = tokenspeed_proto::TokenizedInput {
        original_text: processed_text,
        input_ids: token_ids,
    };
    Self::build_sampling_params_from_messages(body, tool_call_constraint).map(
        |sampling_params| tokenspeed_proto::GenerateRequest {
            request_id,
            tokenized: Some(tokenized),
            sampling_params: Some(sampling_params),
            stream: body.stream.unwrap_or(false),
            mm_inputs: multimodal_inputs,
            ..Default::default()
        },
    )
}
/// Builds a TokenSpeed `GenerateRequest` from a legacy completion request.
///
/// `return_logprob` is true exactly when `body.logprobs` is set, and that same
/// value (0 when unset) is used for `top_logprobs_num`. `logprob_start_len` is
/// always `Some(-1)`.
///
/// # Errors
/// Returns the error string from `build_sampling_params_from_completion`.
#[expect(
    clippy::unused_self,
    reason = "receiver kept for API parity with the other engine clients"
)]
pub fn build_generate_request_from_completion(
    &self,
    request_id: String,
    body: &CompletionRequest,
    original_text: String,
    token_ids: Vec<u32>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
    let tokenized = tokenspeed_proto::TokenizedInput {
        original_text,
        input_ids: token_ids,
    };
    Self::build_sampling_params_from_completion(body).map(|sampling_params| {
        tokenspeed_proto::GenerateRequest {
            request_id,
            tokenized: Some(tokenized),
            sampling_params: Some(sampling_params),
            return_logprob: body.logprobs.is_some(),
            logprob_start_len: Some(-1),
            // NOTE(review): `as i32` wraps if the source value exceeds `i32::MAX`.
            top_logprobs_num: body.logprobs.unwrap_or(0) as i32,
            stream: body.stream,
            ..Default::default()
        }
    })
}
/// Maps a chat completion request onto TokenSpeed sampling parameters.
///
/// Unset numeric knobs fall back to: `temperature = 1.0`, `top_p = 1.0`,
/// `top_k = -1`, `min_p = 0.0`, both penalties `0.0`,
/// `repetition_penalty = 1.0`, `n = 1`. `skip_special_tokens` and
/// `spaces_between_special_tokens` are always `true`.
///
/// # Errors
/// Returns the error string from `build_constraint_for_chat`.
fn build_sampling_params_from_chat(
    request: &ChatCompletionRequest,
    tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::SamplingParams, String> {
    let constraint = Self::build_constraint_for_chat(request, tool_call_constraint)?;
    Ok(tokenspeed_proto::SamplingParams {
        temperature: request.temperature.or(Some(1.0)),
        top_p: request.top_p.or(Some(1.0)),
        top_k: request.top_k.or(Some(-1)),
        min_p: request.min_p.or(Some(0.0)),
        frequency_penalty: request.frequency_penalty.or(Some(0.0)),
        presence_penalty: request.presence_penalty.or(Some(0.0)),
        repetition_penalty: request.repetition_penalty.or(Some(1.0)),
        max_new_tokens: request.max_completion_tokens,
        stop: Self::extract_stop_strings(request.stop.as_ref()),
        stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
        skip_special_tokens: true,
        spaces_between_special_tokens: true,
        ignore_eos: request.ignore_eos,
        no_stop_trim: request.no_stop_trim,
        n: request.n.unwrap_or(1),
        constraint,
        ..Default::default()
    })
}
/// Maps a Responses API request onto TokenSpeed sampling parameters.
///
/// `temperature` and `top_p` default to `1.0` and both penalties to `0.0`;
/// `top_k`, `min_p` and `repetition_penalty` are taken from the request as-is.
/// No stop strings or stop token ids are sent, `n` is fixed at 1, and
/// `skip_special_tokens` is `false`.
///
/// NOTE(review): `skip_special_tokens = false` differs from the chat and
/// messages builders in this file, which use `true` — presumably intentional;
/// confirm.
///
/// # Errors
/// Returns the error string from `build_constraint_from_pair`.
fn build_sampling_params_from_responses(
    request: &ResponsesRequest,
    constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::SamplingParams, String> {
    let constraint = Self::build_constraint_from_pair(constraint)?;
    Ok(tokenspeed_proto::SamplingParams {
        temperature: request.temperature.or(Some(1.0)),
        top_p: request.top_p.or(Some(1.0)),
        top_k: Some(request.top_k),
        min_p: Some(request.min_p),
        frequency_penalty: request.frequency_penalty.or(Some(0.0)),
        presence_penalty: request.presence_penalty.or(Some(0.0)),
        repetition_penalty: Some(request.repetition_penalty),
        max_new_tokens: request.max_output_tokens,
        stop: Vec::new(),
        stop_token_ids: Vec::new(),
        skip_special_tokens: false,
        spaces_between_special_tokens: true,
        ignore_eos: false,
        no_stop_trim: false,
        n: 1,
        constraint,
        ..Default::default()
    })
}
/// Maps a Messages API request onto TokenSpeed sampling parameters.
///
/// `temperature` and `top_p` default to `1.0` and `top_k` to `-1`, each cast
/// to the proto's numeric type. `min_p`, the penalties and `n` are fixed
/// constants here; nothing is read from the request for them.
/// `max_new_tokens` is always set from `request.max_tokens`.
///
/// # Errors
/// Returns the error string from `build_constraint_from_pair`.
fn build_sampling_params_from_messages(
    request: &CreateMessageRequest,
    tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::SamplingParams, String> {
    let constraint = Self::build_constraint_from_pair(tool_call_constraint)?;
    Ok(tokenspeed_proto::SamplingParams {
        temperature: Some(request.temperature.map_or(1.0, |t| t as f32)),
        top_p: Some(request.top_p.map_or(1.0, |p| p as f32)),
        top_k: Some(request.top_k.map_or(-1, |k| k as i32)),
        min_p: Some(0.0),
        frequency_penalty: Some(0.0),
        presence_penalty: Some(0.0),
        repetition_penalty: Some(1.0),
        max_new_tokens: Some(request.max_tokens),
        stop: request.stop_sequences.clone().unwrap_or_default(),
        stop_token_ids: Vec::new(),
        skip_special_tokens: true,
        spaces_between_special_tokens: true,
        ignore_eos: false,
        no_stop_trim: false,
        n: 1,
        constraint,
        ..Default::default()
    })
}
/// Maps a legacy completion request onto TokenSpeed sampling parameters.
///
/// Unset numeric knobs fall back to: `temperature = 1.0`, `top_p = 1.0`,
/// `top_k = -1`, `min_p = 0.0`, both penalties `0.0`,
/// `repetition_penalty = 1.0`, `min_new_tokens = 0`, `n = 1`.
/// `skip_special_tokens`, `ignore_eos` and `no_stop_trim` are taken from the
/// request as-is.
///
/// # Errors
/// Returns the error string from `build_constraint_from_completion` (more than
/// one structured constraint on the request).
fn build_sampling_params_from_completion(
    request: &CompletionRequest,
) -> Result<tokenspeed_proto::SamplingParams, String> {
    Ok(tokenspeed_proto::SamplingParams {
        temperature: Some(request.temperature.unwrap_or(1.0)),
        top_p: Some(request.top_p.unwrap_or(1.0)),
        top_k: Some(request.top_k.unwrap_or(-1)),
        min_p: Some(request.min_p.unwrap_or(0.0)),
        frequency_penalty: Some(request.frequency_penalty.unwrap_or(0.0)),
        presence_penalty: Some(request.presence_penalty.unwrap_or(0.0)),
        repetition_penalty: Some(request.repetition_penalty.unwrap_or(1.0)),
        max_new_tokens: request.max_tokens,
        min_new_tokens: request.min_tokens.unwrap_or(0),
        // Same helper the chat builder uses, so stop handling stays in one place.
        stop: Self::extract_stop_strings(request.stop.as_ref()),
        stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
        skip_special_tokens: request.skip_special_tokens,
        spaces_between_special_tokens: true,
        ignore_eos: request.ignore_eos,
        no_stop_trim: request.no_stop_trim,
        n: request.n.unwrap_or(1),
        constraint: Self::build_constraint_from_completion(request)?,
        ..Default::default()
    })
}
/// Maps optional plain-generate sampling parameters onto TokenSpeed sampling
/// parameters.
///
/// With `params == None` the baseline below is returned unchanged. Otherwise
/// each field set on `params` overrides the baseline, with one exception:
/// `max_new_tokens` is copied unconditionally, so an unset value on `params`
/// yields `None`.
///
/// # Errors
/// Returns the error string from `build_constraint_from_plain`.
fn build_sampling_params_from_plain(
    params: Option<&GenerateSamplingParams>,
) -> Result<tokenspeed_proto::SamplingParams, String> {
    // Baseline values used for anything the caller leaves unset.
    let mut out = tokenspeed_proto::SamplingParams {
        temperature: Some(1.0),
        top_p: Some(1.0),
        top_k: Some(-1),
        repetition_penalty: Some(1.0),
        n: 1,
        skip_special_tokens: true,
        spaces_between_special_tokens: true,
        ..Default::default()
    };
    let Some(p) = params else {
        return Ok(out);
    };

    // Optional numeric knobs: a caller value wins, otherwise keep the baseline.
    out.temperature = p.temperature.or(out.temperature);
    out.top_p = p.top_p.or(out.top_p);
    out.top_k = p.top_k.or(out.top_k);
    out.frequency_penalty = p.frequency_penalty.or(out.frequency_penalty);
    out.presence_penalty = p.presence_penalty.or(out.presence_penalty);
    out.repetition_penalty = p.repetition_penalty.or(out.repetition_penalty);
    out.min_p = p.min_p.or(out.min_p);

    // Plain (non-optional) proto fields.
    out.ignore_eos = p.ignore_eos.unwrap_or(out.ignore_eos);
    out.skip_special_tokens = p.skip_special_tokens.unwrap_or(out.skip_special_tokens);
    out.no_stop_trim = p.no_stop_trim.unwrap_or(out.no_stop_trim);
    out.min_new_tokens = p.min_new_tokens.unwrap_or(out.min_new_tokens);
    out.n = p.n.unwrap_or(out.n);
    out.max_new_tokens = p.max_new_tokens;

    // Stop conditions.
    out.stop.extend(Self::extract_stop_strings(p.stop.as_ref()));
    if let Some(ids) = p.stop_token_ids.as_ref() {
        out.stop_token_ids.clone_from(ids);
    }

    out.constraint = Self::build_constraint_from_plain(p)?;
    Ok(out)
}
/// Flattens an optional `StringOrArray` into a list of stop strings.
///
/// A single string becomes a one-element vector, an array is cloned as-is, and
/// `None` yields an empty vector.
fn extract_stop_strings(stop: Option<&StringOrArray>) -> Vec<String> {
    stop.map_or_else(Vec::new, |value| match value {
        StringOrArray::String(single) => vec![single.clone()],
        StringOrArray::Array(many) => many.clone(),
    })
}
/// Derives the single structured-output constraint for a chat request.
///
/// Candidates are collected from `response_format` (as a JSON schema string),
/// `ebnf` and `regex`. The tool-call constraint is used only when none of
/// those is present; otherwise it is dropped with a warning.
///
/// # Errors
/// Returns an error string when a JSON schema cannot be serialized, when the
/// tool-call constraint type is unknown, or when more than one constraint was
/// collected.
fn build_constraint_for_chat(
    request: &ChatCompletionRequest,
    tool_call_constraint: Option<(String, String)>,
) -> Result<Option<tokenspeed_proto::sampling_params::Constraint>, String> {
    // Schema implied by `response_format`, already serialized to a string.
    let format_schema = match &request.response_format {
        Some(ResponseFormat::JsonObject) => Some(
            serde_json::to_string(&serde_json::json!({"type": "object"}))
                .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?,
        ),
        Some(ResponseFormat::JsonSchema { json_schema }) => Some(
            serde_json::to_string(&json_schema.schema)
                .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?,
        ),
        Some(ResponseFormat::Text) | None => None,
    };

    let mut candidates: Vec<tokenspeed_proto::sampling_params::Constraint> = format_schema
        .map(tokenspeed_proto::sampling_params::Constraint::JsonSchema)
        .into_iter()
        .chain(
            request
                .ebnf
                .clone()
                .map(tokenspeed_proto::sampling_params::Constraint::EbnfGrammar),
        )
        .chain(
            request
                .regex
                .clone()
                .map(tokenspeed_proto::sampling_params::Constraint::Regex),
        )
        .collect();

    if let Some((kind, value)) = tool_call_constraint {
        if candidates.is_empty() {
            candidates.push(Self::constraint_from_pair(kind, value)?);
        } else {
            warn!(
                "Constrained decoding is not compatible with tool calls, dropping tool constraint"
            );
        }
    }

    if candidates.len() > 1 {
        return Err("Multiple constraints are not allowed.".to_string());
    }
    // Zero or one candidate left: `pop` yields `None` or that single constraint.
    Ok(candidates.pop())
}
/// Converts an optional `(type, value)` pair into an optional proto constraint.
///
/// `None` stays `None`; `Some` is delegated to `constraint_from_pair`.
///
/// # Errors
/// Returns the error string from `constraint_from_pair` for an unknown type.
fn build_constraint_from_pair(
    constraint: Option<(String, String)>,
) -> Result<Option<tokenspeed_proto::sampling_params::Constraint>, String> {
    constraint
        .map(|(kind, value)| Self::constraint_from_pair(kind, value))
        .transpose()
}
/// Wraps `constraint_value` in the proto constraint variant named by
/// `constraint_type`.
///
/// Recognised types are `"structural_tag"`, `"json_schema"`, `"ebnf"` and
/// `"regex"`.
///
/// # Errors
/// Returns `"Unknown constraint type: …"` for any other type string.
fn constraint_from_pair(
    constraint_type: String,
    constraint_value: String,
) -> Result<tokenspeed_proto::sampling_params::Constraint, String> {
    // Pick the variant constructor first, then apply it to the value once.
    let wrap: fn(String) -> tokenspeed_proto::sampling_params::Constraint =
        match constraint_type.as_str() {
            "structural_tag" => tokenspeed_proto::sampling_params::Constraint::StructuralTag,
            "json_schema" => tokenspeed_proto::sampling_params::Constraint::JsonSchema,
            "ebnf" => tokenspeed_proto::sampling_params::Constraint::EbnfGrammar,
            "regex" => tokenspeed_proto::sampling_params::Constraint::Regex,
            _ => return Err(format!("Unknown constraint type: {constraint_type}")),
        };
    Ok(wrap(constraint_value))
}
/// Derives the single structured-output constraint for a completion request.
///
/// Looks at `json_schema`, `regex` and `ebnf` on the request; at most one of
/// them may be set.
///
/// # Errors
/// Returns an error string when more than one of the three fields is set.
fn build_constraint_from_completion(
    request: &CompletionRequest,
) -> Result<Option<tokenspeed_proto::sampling_params::Constraint>, String> {
    let mut candidates: Vec<tokenspeed_proto::sampling_params::Constraint> = Vec::new();
    candidates.extend(
        request
            .json_schema
            .clone()
            .map(tokenspeed_proto::sampling_params::Constraint::JsonSchema),
    );
    candidates.extend(
        request
            .regex
            .clone()
            .map(tokenspeed_proto::sampling_params::Constraint::Regex),
    );
    candidates.extend(
        request
            .ebnf
            .clone()
            .map(tokenspeed_proto::sampling_params::Constraint::EbnfGrammar),
    );

    if candidates.len() > 1 {
        return Err("Multiple structured constraints are not allowed".to_string());
    }
    // Zero or one candidate left: `pop` yields `None` or that single constraint.
    Ok(candidates.pop())
}
/// Derives the single structured-output constraint for plain-generate
/// sampling parameters.
///
/// Looks at `json_schema`, `regex` and `ebnf` on `params`; at most one of them
/// may be set.
///
/// # Errors
/// Returns an error string when more than one of the three fields is set.
fn build_constraint_from_plain(
    params: &GenerateSamplingParams,
) -> Result<Option<tokenspeed_proto::sampling_params::Constraint>, String> {
    let mut constraints = Vec::new();
    if let Some(json_schema) = &params.json_schema {
        constraints.push(tokenspeed_proto::sampling_params::Constraint::JsonSchema(
            json_schema.clone(),
        ));
    }
    if let Some(regex) = &params.regex {
        constraints.push(tokenspeed_proto::sampling_params::Constraint::Regex(
            regex.clone(),
        ));
    }
    if let Some(ebnf) = &params.ebnf {
        constraints.push(tokenspeed_proto::sampling_params::Constraint::EbnfGrammar(
            ebnf.clone(),
        ));
    }
    match constraints.len() {
        0 => Ok(None),
        1 => Ok(constraints.pop()),
        _ => Err("Multiple structured constraints are not allowed".to_string()),
    }
}
}
/// Converts a TokenSpeed `SchedulerLoad` proto message into the
/// engine-agnostic `SchedulerLoadSnapshot`.
impl From<tokenspeed_proto::SchedulerLoad> for openai_protocol::worker::SchedulerLoadSnapshot {
/// Copies each load field across under the same name.
/// `num_waiting_uncached_tokens` is the one exception: it is always 0.
fn from(load: tokenspeed_proto::SchedulerLoad) -> Self {
Self {
dp_rank: load.dp_rank,
num_running_reqs: load.num_running_reqs,
num_waiting_reqs: load.num_waiting_reqs,
// Not read from the proto message; fixed at zero.
num_waiting_uncached_tokens: 0,
num_total_reqs: load.num_total_reqs,
num_used_tokens: load.num_used_tokens,
max_total_num_tokens: load.max_total_num_tokens,
token_usage: load.token_usage,
gen_throughput: load.gen_throughput,
cache_hit_rate: load.cache_hit_rate,
utilization: load.utilization,
max_running_requests: load.max_running_requests,
}
}
}
/// Converts a TokenSpeed `GetLoadsResponse` proto message into the
/// engine-agnostic `WorkerLoadResponse`.
impl From<tokenspeed_proto::GetLoadsResponse> for openai_protocol::worker::WorkerLoadResponse {
    /// Carries `timestamp` and `dp_rank_count` over unchanged and converts
    /// every entry of `loads` via `Into`.
    fn from(resp: tokenspeed_proto::GetLoadsResponse) -> Self {
        let loads = resp.loads.into_iter().map(Into::into).collect();
        Self {
            timestamp: resp.timestamp,
            dp_rank_count: resp.dp_rank_count,
            loads,
        }
    }
}