use crate::core::TokenUsage;
use crate::engine::{
EmbedOptions, EmbedRequest, EmbeddingResult, GenerationResult, QueryOptions, QueryRequest,
RequestStats, SamplingRuntimeOverride, DEFAULT_CONTEXT_KEY, DEFAULT_MAX_TOKENS,
};
use crate::client::{
EndpointRef, LocalEmbedOptions, LocalTextOptions, SippEmbeddingResponse, SippError,
SippQueryRequest, SippResponseMetadata, SippTextOptions, SippTextResponse,
};
// Unit tests for this module live in a separate file (see the `#[path]` below) and are
// compiled only for test builds.
#[cfg(test)]
#[path = "../tests/client/map_tests.rs"]
mod map_tests;
/// Converts a client-level `SippQueryRequest` into an engine `QueryRequest`.
///
/// The common and local option sets are merged by `local_query_options`; the prompt and
/// the `emit_tokens` flag are carried over unchanged.
///
/// # Errors
///
/// Returns whatever `SippError` the option merge reports.
pub(crate) fn local_query_request(request: SippQueryRequest) -> Result<QueryRequest, SippError> {
    let prompt = request.prompt;
    let emit_tokens = request.emit_tokens;
    let merged_options = local_query_options(request.options, request.local)?;

    Ok(QueryRequest::new(prompt)
        .options(merged_options)
        .emit_tokens(emit_tokens))
}
/// Builds the engine `QueryOptions` for a local chat call.
///
/// This delegates directly to `local_query_options`, so chat and plain queries share the
/// same option-merging rules.
///
/// # Errors
///
/// Returns whatever `SippError` the option merge reports.
pub(crate) fn local_chat_options(
    options: SippTextOptions,
    local: LocalTextOptions,
) -> Result<QueryOptions, SippError> {
    local_query_options(options, local)
}
/// Builds an engine `EmbedRequest` for `input` from the local embedding options.
///
/// Normalization is enabled unless `local.normalize` explicitly says otherwise; the
/// context key is passed through as given.
pub(crate) fn local_embed_request(input: String, local: LocalEmbedOptions) -> EmbedRequest {
    // An unset `normalize` means "normalize".
    let normalize = local.normalize.unwrap_or(true);
    let options = EmbedOptions {
        normalize,
        context_key: local.context_key,
    };

    EmbedRequest { input, options }
}
/// Builds a `SippTextResponse` from a local `GenerationResult`.
///
/// Token usage is derived from the result's statistics with `usage_from_stats`, and the
/// same statistics are attached unchanged as `local_stats`. The metadata comes from
/// `local_metadata`, so only `request_id` is populated.
pub(crate) fn text_response(
    endpoint: EndpointRef,
    request_id: Option<String>,
    result: GenerationResult,
) -> SippTextResponse {
    let stats = result.stats;
    let usage = usage_from_stats(stats);
    let metadata = local_metadata(request_id);

    SippTextResponse {
        endpoint,
        text: result.text,
        finish_reason: result.finish_reason,
        usage: Some(usage),
        local_stats: Some(stats),
        metadata,
    }
}
/// Builds a `SippEmbeddingResponse` from a local `EmbeddingResult`.
///
/// Token usage is derived from the result's statistics with `usage_from_stats`; the raw
/// statistics, pooling mode and normalization flag are all reported as `Some(..)`. The
/// metadata comes from `local_metadata`, so only `request_id` is populated.
pub(crate) fn embedding_response(
    endpoint: EndpointRef,
    request_id: Option<String>,
    result: EmbeddingResult,
) -> SippEmbeddingResponse {
    let stats = result.stats;
    let usage = usage_from_stats(stats);
    let metadata = local_metadata(request_id);

    SippEmbeddingResponse {
        endpoint,
        values: result.values,
        usage: Some(usage),
        local_stats: Some(stats),
        pooling: Some(result.pooling),
        normalized: Some(result.normalized),
        metadata,
    }
}
/// Maps a provider generate response onto a `SippTextResponse`.
///
/// Only compiled with the `providers` feature. The conversion itself is done by
/// `provider_text_output`.
#[cfg(feature = "providers")]
pub(crate) fn provider_text_response(
    endpoint: EndpointRef,
    request_id: Option<String>,
    response: crate::providers::ProviderGenerateResponse,
) -> SippTextResponse {
    provider_text_output(endpoint, request_id, response)
}
/// Maps a provider chat response onto a `SippTextResponse`.
///
/// Only compiled with the `providers` feature. The conversion itself is done by
/// `provider_text_output`.
#[cfg(feature = "providers")]
pub(crate) fn provider_chat_response(
    endpoint: EndpointRef,
    request_id: Option<String>,
    response: crate::providers::ProviderChatResponse,
) -> SippTextResponse {
    provider_text_output(endpoint, request_id, response)
}
/// Maps a provider embedding response onto a `SippEmbeddingResponse`.
///
/// Only compiled with the `providers` feature. The embedding values and usage are taken
/// from the provider response as-is. `local_stats`, `pooling` and `normalized` are always
/// `None` here. The provider's request and response ids are recorded as the upstream ids
/// next to the caller-supplied `request_id`.
#[cfg(feature = "providers")]
pub(crate) fn provider_embedding_response(
    endpoint: EndpointRef,
    request_id: Option<String>,
    response: crate::providers::ProviderEmbeddingResponse,
) -> SippEmbeddingResponse {
    let upstream = response.metadata;
    let metadata = SippResponseMetadata {
        request_id,
        upstream_request_id: upstream.request_id,
        upstream_response_id: upstream.response_id,
    };

    SippEmbeddingResponse {
        endpoint,
        values: response.result.values,
        usage: response.usage,
        local_stats: None,
        pooling: None,
        normalized: None,
        metadata,
    }
}
/// Copies the common text options into `ProviderGenerationOptions`.
///
/// Only compiled with the `providers` feature. The four fields `max_tokens`,
/// `temperature`, `top_p` and `stop` are transferred without modification.
#[cfg(feature = "providers")]
pub(crate) fn provider_generation_options(
    options: crate::client::SippTextOptions,
) -> crate::providers::ProviderGenerationOptions {
    let max_tokens = options.max_tokens;
    let temperature = options.temperature;
    let top_p = options.top_p;
    let stop = options.stop;

    crate::providers::ProviderGenerationOptions {
        max_tokens,
        temperature,
        top_p,
        stop,
    }
}
/// Shared conversion from a provider text output to a `SippTextResponse`.
///
/// Only compiled with the `providers` feature. Text, finish reason and usage are taken
/// from the provider response as-is, and `local_stats` is always `None`. The provider's
/// request and response ids are recorded as the upstream ids next to the caller-supplied
/// `request_id`.
#[cfg(feature = "providers")]
fn provider_text_output(
    endpoint: EndpointRef,
    request_id: Option<String>,
    response: crate::providers::ProviderResponse<crate::providers::ProviderTextOutput>,
) -> SippTextResponse {
    let upstream = response.metadata;
    let metadata = SippResponseMetadata {
        request_id,
        upstream_request_id: upstream.request_id,
        upstream_response_id: upstream.response_id,
    };

    SippTextResponse {
        endpoint,
        text: response.result.text,
        finish_reason: response.result.finish_reason,
        usage: response.usage,
        local_stats: None,
        metadata,
    }
}
/// Builds response metadata that carries only the caller's `request_id`.
///
/// Both upstream identifiers are left as `None`.
fn local_metadata(request_id: Option<String>) -> SippResponseMetadata {
    let upstream_request_id = None;
    let upstream_response_id = None;

    SippResponseMetadata {
        request_id,
        upstream_request_id,
        upstream_response_id,
    }
}
/// Derives a `TokenUsage` from engine `RequestStats`.
///
/// Each token count is converted with `nonnegative_i32_to_u32`, so a negative count
/// becomes `None`. The total is reported only when both counts are present and their sum
/// does not overflow `u32`.
pub(crate) fn usage_from_stats(stats: RequestStats) -> TokenUsage {
    let input_tokens = nonnegative_i32_to_u32(stats.input_tokens);
    let output_tokens = nonnegative_i32_to_u32(stats.output_tokens);

    // `zip` yields a pair only when both sides are `Some`; `checked_add` turns an
    // overflowing sum into `None` instead of wrapping or panicking.
    let total_tokens = input_tokens
        .zip(output_tokens)
        .and_then(|(input, output)| input.checked_add(output));

    TokenUsage {
        input_tokens,
        output_tokens,
        total_tokens,
    }
}
/// Merges the common text options with the local-only options into `QueryOptions`.
///
/// * `max_tokens` falls back to `DEFAULT_MAX_TOKENS` when unset.
/// * `temperature` and `top_p` are merged into the local sampling override by
///   `local_sampling`.
/// * `context_key` falls back to `DEFAULT_CONTEXT_KEY`; `grammar` and `json_schema` fall
///   back to their `Default` values.
///
/// # Errors
///
/// Returns `SippError::InvalidRequest` when `max_tokens` does not fit in an `i32`, and
/// propagates any error reported by `local_sampling`.
fn local_query_options(
    options: SippTextOptions,
    local: LocalTextOptions,
) -> Result<QueryOptions, SippError> {
    // Validate the token limit first so its error takes precedence over sampling errors.
    let max_tokens = if let Some(requested) = options.max_tokens {
        i32::try_from(requested).map_err(|_| {
            SippError::InvalidRequest("local max_tokens exceeds i32::MAX".to_string())
        })?
    } else {
        DEFAULT_MAX_TOKENS
    };

    let sampling = local_sampling(options.temperature, options.top_p, local.sampling)?;

    let context_key = local
        .context_key
        .unwrap_or_else(|| DEFAULT_CONTEXT_KEY.to_string());

    Ok(QueryOptions {
        context_key,
        max_tokens,
        grammar: local.grammar.unwrap_or_default(),
        json_schema: local.json_schema.unwrap_or_default(),
        stop: options.stop,
        sampling,
        media: local.media,
    })
}
/// Folds the common `temperature` and `top_p` values into the local sampling override.
///
/// Starts from `sampling`, or a default override when none was given, and merges each
/// common value with `merge_sampling_field`. Returns `Ok(None)` when the merged override
/// reports itself as empty.
///
/// # Errors
///
/// Propagates the `SippError` from `merge_sampling_field` when a common value disagrees
/// with the one already present in the override.
fn local_sampling(
    temperature: Option<f32>,
    top_p: Option<f32>,
    sampling: Option<SamplingRuntimeOverride>,
) -> Result<Option<SamplingRuntimeOverride>, SippError> {
    let mut merged = sampling.unwrap_or_default();

    merge_sampling_field("temperature", &mut merged.temperature, temperature)?;
    merge_sampling_field("top_p", &mut merged.top_p, top_p)?;

    // An override with nothing set is reported as "no override".
    Ok((!merged.is_empty()).then_some(merged))
}
/// Merges one common sampling value into the matching slot of the local override.
///
/// * `value` is `None`: `target` is left untouched.
/// * `target` is `None`: it is set to `value`.
/// * Both are set and equal: nothing changes.
///
/// `name` is only used to build the error message.
///
/// # Errors
///
/// Returns `SippError::InvalidRequest` when both are set and compare unequal.
fn merge_sampling_field(
    name: &'static str,
    target: &mut Option<f32>,
    value: Option<f32>,
) -> Result<(), SippError> {
    if let Some(requested) = value {
        match *target {
            None => *target = Some(requested),
            Some(existing) if existing != requested => {
                return Err(SippError::InvalidRequest(format!(
                    "common {name} conflicts with local sampling.{name}"
                )));
            }
            // Same value on both sides: nothing to do.
            Some(_) => {}
        }
    }
    Ok(())
}
/// Converts `value` to `u32`, returning `None` for negative inputs.
fn nonnegative_i32_to_u32(value: i32) -> Option<u32> {
    // For a non-negative `i32`, `unsigned_abs` is a lossless conversion to `u32`.
    (value >= 0).then(|| value.unsigned_abs())
}