use super::openai::*;
use super::{Metadata, RequestContext};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Provider-agnostic request envelope routed through the gateway.
///
/// Bundles the typed payload (`data`) with request metadata, caller context,
/// per-provider parameter overrides, and routing/caching preferences.
/// Not serde-derived itself (unlike the payload types it contains).
#[derive(Debug, Clone)]
pub struct GatewayRequest {
    /// Request metadata; populated via `Metadata::new()` in `GatewayRequest::new`.
    pub metadata: Metadata,
    /// Caller/request context supplied by the gateway front-end.
    pub context: RequestContext,
    /// Coarse operation classification (see `RequestType`).
    pub request_type: RequestType,
    /// The typed request payload.
    pub data: RequestData,
    /// Free-form, provider-specific parameter overrides keyed by name.
    pub provider_params: HashMap<String, serde_json::Value>,
    /// Provider-selection preferences consumed by the router.
    pub routing: RoutingPreferences,
    /// Response-caching preferences.
    pub caching: CachingPreferences,
}
/// Coarse classification of the operation a gateway request performs.
///
/// Serialized in snake_case (e.g. `chat_completion`, `vector_stores`).
/// Broader than `RequestData`: several variants here (`FineTuning`,
/// `Assistants`, `Threads`, ...) have no matching typed payload variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RequestType {
    ChatCompletion,
    Completion,
    Embedding,
    ImageGeneration,
    ImageEdit,
    ImageVariation,
    AudioTranscription,
    AudioTranslation,
    AudioSpeech,
    Moderation,
    FineTuning,
    Files,
    Assistants,
    Threads,
    Batches,
    VectorStores,
    Rerank,
    Realtime,
}
/// The typed payload of a gateway request.
///
/// Internally tagged for serde: the JSON carries a `"type"` discriminator
/// field alongside the payload fields (e.g. `{"type": "embedding", ...}`).
/// `ChatCompletionRequest` is boxed — presumably it is the largest variant
/// (defined in the `openai` module) and boxing keeps this enum small.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestData {
    #[serde(rename = "chat_completion")]
    ChatCompletion(Box<ChatCompletionRequest>),
    #[serde(rename = "completion")]
    Completion(CompletionRequest),
    #[serde(rename = "embedding")]
    Embedding(EmbeddingRequest),
    #[serde(rename = "image_generation")]
    ImageGeneration(ImageGenerationRequest),
    #[serde(rename = "audio_transcription")]
    AudioTranscription(AudioTranscriptionRequest),
    #[serde(rename = "moderation")]
    Moderation(ModerationRequest),
    #[serde(rename = "rerank")]
    Rerank(RerankRequest),
}
/// Legacy text-completion request; field names mirror the shape of
/// OpenAI's `/v1/completions` API.
///
/// NOTE(review): the `Option` fields serialize as explicit `null`s because
/// no `skip_serializing_if` attributes are present — confirm downstream
/// providers tolerate nulls before relying on serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Required model identifier.
    pub model: String,
    pub prompt: Option<String>,
    pub max_tokens: Option<u32>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    /// Number of completions to generate.
    pub n: Option<u32>,
    /// Whether to stream the response (consumed by `GatewayRequest::is_streaming`).
    pub stream: Option<bool>,
    pub stop: Option<Vec<String>>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    /// Token-id → bias map; keys are stringified token ids.
    pub logit_bias: Option<HashMap<String, f32>>,
    pub user: Option<String>,
    /// Text appended after the completion.
    pub suffix: Option<String>,
    /// Echo the prompt back alongside the completion.
    pub echo: Option<bool>,
    pub best_of: Option<u32>,
    /// Number of log-probabilities to include per token.
    pub logprobs: Option<u32>,
}
/// Embedding request: one or more inputs to embed with a given model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    pub model: String,
    /// Text or pre-tokenized input; see `EmbeddingInput` for accepted shapes.
    pub input: EmbeddingInput,
    /// Output encoding (stringly-typed; presumably "float" / "base64" — confirm).
    pub encoding_format: Option<String>,
    /// Requested output dimensionality, where the model supports truncation.
    pub dimensions: Option<u32>,
    pub user: Option<String>,
}
/// Accepted input shapes for an embedding request.
///
/// `untagged`: serde tries variants in declaration order on deserialize.
/// String vs numeric element types keep the variants unambiguous, except
/// that an empty JSON array `[]` matches `Array` (the first list variant).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// A single text input.
    String(String),
    /// A batch of text inputs.
    Array(Vec<String>),
    /// A single pre-tokenized input (token ids).
    Tokens(Vec<u32>),
    /// A batch of pre-tokenized inputs.
    TokenArrays(Vec<Vec<u32>>),
}
/// Image-generation request. `model` is optional here (unlike most other
/// payloads), so `GatewayRequest::model()` returns `None` when it is unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationRequest {
    pub model: Option<String>,
    /// Required text prompt describing the image.
    pub prompt: String,
    /// Number of images to generate.
    pub n: Option<u32>,
    /// Stringly-typed size spec (e.g. "1024x1024" — confirm provider formats).
    pub size: Option<String>,
    pub response_format: Option<String>,
    pub quality: Option<String>,
    pub style: Option<String>,
    pub user: Option<String>,
}
/// Audio-transcription request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionRequest {
    pub model: String,
    /// Audio payload as a string — presumably a path, URL, or base64 blob;
    /// TODO confirm against the transport layer that fills this in.
    pub file: String,
    /// Input language hint (ISO code, presumably — confirm).
    pub language: Option<String>,
    /// Optional priming prompt for the transcriber.
    pub prompt: Option<String>,
    pub response_format: Option<String>,
    pub temperature: Option<f32>,
    /// Granularities for timestamps (stringly-typed, e.g. word/segment — confirm).
    pub timestamp_granularities: Option<Vec<String>>,
}
/// Content-moderation request. `model` is optional, so `GatewayRequest::model()`
/// returns `None` when it is unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModerationRequest {
    pub model: Option<String>,
    /// One or many text inputs to classify; see `ModerationInput`.
    pub input: ModerationInput,
}
/// Single or batched moderation input.
///
/// `untagged`: a bare JSON string deserializes as `String`, a JSON array of
/// strings as `Array`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ModerationInput {
    String(String),
    Array(Vec<String>),
}
/// Document-reranking request: score `documents` against `query`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankRequest {
    pub model: String,
    /// Query the documents are ranked against.
    pub query: String,
    /// Candidate documents; plain strings or `{ "text": ... }` objects.
    pub documents: Vec<RerankDocument>,
    /// Return only the top-k results when set.
    pub top_k: Option<u32>,
    /// Whether to echo document bodies back in the response.
    pub return_documents: Option<bool>,
    pub max_chunks_per_doc: Option<u32>,
}
/// A rerank candidate: either a bare string or an object with a `text` field.
///
/// `untagged`: a JSON string deserializes as `String`; a JSON object with a
/// `text` key deserializes as `Object`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RerankDocument {
    String(String),
    Object {
        text: String,
    },
}
/// Provider-selection preferences for the router.
///
/// `Default` yields empty lists, `None`s, and both optimize flags false.
/// NOTE(review): `optimize_cost` and `optimize_latency` are independent
/// bools — behavior when both are set is decided by the router, not here.
#[derive(Debug, Clone, Default)]
pub struct RoutingPreferences {
    /// Providers to try first, in order of preference.
    pub preferred_providers: Vec<String>,
    /// Providers that must not be used.
    pub excluded_providers: Vec<String>,
    /// Overrides the router's default strategy when set (stringly-typed).
    pub strategy_override: Option<String>,
    pub tags: Vec<String>,
    pub region: Option<String>,
    pub optimize_cost: bool,
    pub optimize_latency: bool,
}
/// Response-caching preferences.
///
/// `Default` disables caching entirely (`enabled: false`); use
/// `GatewayRequest::enable_caching` to switch it on with an optional TTL.
#[derive(Debug, Clone, Default)]
pub struct CachingPreferences {
    /// Master switch; when false the other fields are presumably ignored.
    pub enabled: bool,
    /// Cache entry time-to-live; `None` presumably means the cache default.
    pub ttl_seconds: Option<u64>,
    /// Optional namespace prefix for cache keys.
    pub key_prefix: Option<String>,
    /// Enables similarity-based (semantic) cache lookup.
    pub semantic_cache: bool,
    /// Match threshold for semantic lookups — only meaningful when
    /// `semantic_cache` is true; TODO confirm value range with the cache impl.
    pub similarity_threshold: Option<f32>,
    pub tags: Vec<String>,
}
impl GatewayRequest {
pub fn new(request_type: RequestType, data: RequestData, context: RequestContext) -> Self {
Self {
metadata: Metadata::new(),
context,
request_type,
data,
provider_params: HashMap::new(),
routing: RoutingPreferences::default(),
caching: CachingPreferences::default(),
}
}
pub fn model(&self) -> Option<&str> {
match &self.data {
RequestData::ChatCompletion(req) => Some(&req.model),
RequestData::Completion(req) => Some(&req.model),
RequestData::Embedding(req) => Some(&req.model),
RequestData::ImageGeneration(req) => req.model.as_deref(),
RequestData::AudioTranscription(req) => Some(&req.model),
RequestData::Moderation(req) => req.model.as_deref(),
RequestData::Rerank(req) => Some(&req.model),
}
}
pub fn is_streaming(&self) -> bool {
match &self.data {
RequestData::ChatCompletion(req) => req.stream.unwrap_or(false),
RequestData::Completion(req) => req.stream.unwrap_or(false),
_ => false,
}
}
pub fn estimated_tokens(&self) -> Option<u32> {
None
}
pub fn set_provider_param<K: Into<String>, V: Into<serde_json::Value>>(
&mut self,
key: K,
value: V,
) {
self.provider_params.insert(key.into(), value.into());
}
pub fn get_provider_param(&self, key: &str) -> Option<&serde_json::Value> {
self.provider_params.get(key)
}
pub fn with_routing(mut self, routing: RoutingPreferences) -> Self {
self.routing = routing;
self
}
pub fn with_caching(mut self, caching: CachingPreferences) -> Self {
self.caching = caching;
self
}
pub fn add_preferred_provider<S: Into<String>>(mut self, provider: S) -> Self {
self.routing.preferred_providers.push(provider.into());
self
}
pub fn exclude_provider<S: Into<String>>(mut self, provider: S) -> Self {
self.routing.excluded_providers.push(provider.into());
self
}
pub fn enable_caching(mut self, ttl_seconds: Option<u64>) -> Self {
self.caching.enabled = true;
self.caching.ttl_seconds = ttl_seconds;
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a chat payload in a `GatewayRequest` with a fresh context.
    fn wrap_chat(payload: ChatCompletionRequest) -> GatewayRequest {
        GatewayRequest::new(
            RequestType::ChatCompletion,
            RequestData::ChatCompletion(Box::new(payload)),
            RequestContext::new(),
        )
    }

    #[test]
    fn test_gateway_request_creation() {
        let gw = wrap_chat(ChatCompletionRequest::default());
        assert!(matches!(gw.request_type, RequestType::ChatCompletion));
        assert!(matches!(gw.data, RequestData::ChatCompletion(_)));
    }

    #[test]
    fn test_model_extraction() {
        let gw = wrap_chat(ChatCompletionRequest {
            model: "gpt-4".to_string(),
            ..Default::default()
        });
        assert_eq!(gw.model(), Some("gpt-4"));
    }

    #[test]
    fn test_streaming_detection() {
        let gw = wrap_chat(ChatCompletionRequest {
            stream: Some(true),
            ..Default::default()
        });
        assert!(gw.is_streaming());
    }
}