use super::Metadata;
use super::openai::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Envelope the gateway hands back for one request: the payload plus gateway-side bookkeeping.
///
/// `response_type` and `data` are stored independently; nothing in this type ties
/// them to the same variant, so whoever builds the value must keep them consistent.
#[derive(Debug, Clone)]
pub struct GatewayResponse {
/// Metadata record; `GatewayResponse::new` fills it with `Metadata::new()`.
pub metadata: Metadata,
/// Kind tag for the payload; this is the only field `is_error` consults.
pub response_type: ResponseType,
/// The response payload itself.
pub data: ResponseData,
/// Provider details; `GatewayResponse::new` starts this at `ProviderInfo::default()`.
pub provider_info: ProviderInfo,
/// Timing and retry counters; `GatewayResponse::new` starts this at `ResponseMetrics::default()`.
pub metrics: ResponseMetrics,
/// Cache bookkeeping; `GatewayResponse::new` starts this at `CacheInfo::default()`.
pub cache_info: CacheInfo,
}
/// Tag describing which kind of payload a gateway response carries.
///
/// Serde writes each variant as a snake_case string (for example
/// `ChatCompletion` becomes `"chat_completion"`).
///
/// The enum has no fields, so it is `Copy` and can be compared with `==` or
/// used as a hash-map key without resorting to `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseType {
    /// Chat completion payload.
    ChatCompletion,
    /// Text completion payload.
    Completion,
    /// Embedding payload.
    Embedding,
    /// Image generation payload.
    ImageGeneration,
    /// Audio transcription payload.
    AudioTranscription,
    /// Moderation payload.
    Moderation,
    /// Rerank payload.
    Rerank,
    /// Error payload; `GatewayResponse::is_error` tests for this variant.
    Error,
}
/// Payload of a gateway response, one variant per supported response kind.
///
/// The serde representation is internally tagged: the serialized object carries a
/// `type` field holding the variant's renamed tag (e.g. `"chat_completion"`)
/// alongside the fields of the wrapped response struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseData {
/// Chat completion payload. `ChatCompletionResponse` is not defined in this file
/// (presumably it comes from the `super::openai` glob import).
#[serde(rename = "chat_completion")]
ChatCompletion(ChatCompletionResponse),
/// Text completion payload.
#[serde(rename = "completion")]
Completion(CompletionResponse),
/// Embedding payload.
#[serde(rename = "embedding")]
Embedding(EmbeddingResponse),
/// Image generation payload.
#[serde(rename = "image_generation")]
ImageGeneration(ImageGenerationResponse),
/// Audio transcription payload.
#[serde(rename = "audio_transcription")]
AudioTranscription(AudioTranscriptionResponse),
/// Moderation payload.
#[serde(rename = "moderation")]
Moderation(ModerationResponse),
/// Rerank payload.
#[serde(rename = "rerank")]
Rerank(RerankResponse),
/// Error payload.
#[serde(rename = "error")]
Error(ErrorResponse),
}
/// Body of a text-completion response (wrapped by `ResponseData::Completion`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
/// Identifier string of the response.
pub id: String,
/// Object-type label carried in the payload.
pub object: String,
/// Creation time; presumably Unix seconds, confirm against the provider.
pub created: u64,
/// Name of the model reported in the payload.
pub model: String,
/// The returned completion choices.
pub choices: Vec<CompletionChoice>,
/// Token usage when present; this is what `GatewayResponse::usage` returns
/// for completion payloads.
pub usage: Option<Usage>,
}
/// One entry of `CompletionResponse::choices`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChoice {
/// Index of this choice as reported in the payload.
pub index: u32,
/// The completion text.
pub text: String,
/// Log-probability details, when present.
pub logprobs: Option<CompletionLogprobs>,
/// Why generation stopped, when reported. Kept as a free-form string; the
/// set of possible values is not constrained here.
pub finish_reason: Option<String>,
}
/// Log-probability details attached to a `CompletionChoice`.
///
/// The four vectors are presumably parallel (one entry per token); nothing in
/// this type enforces equal lengths.
///
/// NOTE(review): every element type is non-optional, so a payload containing
/// `null` entries in any of these arrays would fail to deserialize. Confirm
/// that the providers in use never emit nulls here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionLogprobs {
/// Token strings.
pub tokens: Vec<String>,
/// Log probability per token.
pub token_logprobs: Vec<f64>,
/// Per position, a map from candidate token to its log probability.
pub top_logprobs: Vec<HashMap<String, f64>>,
/// Offset of each token in the text; unit (bytes vs. characters) not visible here.
pub text_offset: Vec<u32>,
}
/// Body of an embedding response (wrapped by `ResponseData::Embedding`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
/// Object-type label carried in the payload.
pub object: String,
/// The returned embedding entries.
pub data: Vec<EmbeddingData>,
/// Name of the model reported in the payload.
pub model: String,
/// Token usage when present. This is an `EmbeddingUsage`, not a `Usage`, so
/// `GatewayResponse::usage` does not surface it.
pub usage: Option<EmbeddingUsage>,
}
/// One entry of `EmbeddingResponse::data`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
/// Object-type label carried in the payload.
pub object: String,
/// Index of this entry as reported in the payload.
pub index: u32,
/// The embedding vector, stored as single-precision floats.
pub embedding: Vec<f32>,
}
/// Token accounting attached to an `EmbeddingResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
/// Token count attributed to the input.
pub prompt_tokens: u32,
/// Total token count reported by the provider.
pub total_tokens: u32,
}
/// Body of an image-generation response (wrapped by `ResponseData::ImageGeneration`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationResponse {
/// Creation time; presumably Unix seconds, confirm against the provider.
pub created: u64,
/// The returned images.
pub data: Vec<ImageData>,
}
/// One entry of `ImageGenerationResponse::data`.
///
/// All fields are optional; the type does not require that either `url` or
/// `b64_json` is set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageData {
/// Location of the image, when delivered by URL.
pub url: Option<String>,
/// Image content as a string; presumably base64, going by the field name.
pub b64_json: Option<String>,
/// Prompt text as revised by the provider, when reported.
pub revised_prompt: Option<String>,
}
/// Body of an audio-transcription response (wrapped by `ResponseData::AudioTranscription`).
///
/// Only `text` is required; every other field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionResponse {
/// The transcribed text.
pub text: String,
/// Language label, when reported.
pub language: Option<String>,
/// Audio duration, when reported; presumably seconds, confirm.
pub duration: Option<f64>,
/// Segment-level breakdown, when reported.
pub segments: Option<Vec<TranscriptionSegment>>,
/// Word-level breakdown, when reported.
pub words: Option<Vec<TranscriptionWord>>,
}
/// One entry of `AudioTranscriptionResponse::segments`.
///
/// NOTE(review): the field descriptions below are inferred from the field
/// names only; confirm units and semantics against the provider's
/// transcription format. All fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionSegment {
/// Segment identifier.
pub id: u32,
/// Seek position associated with the segment.
pub seek: u32,
/// Segment start time.
pub start: f64,
/// Segment end time.
pub end: f64,
/// Text of the segment.
pub text: String,
/// Token ids of the segment text.
pub tokens: Vec<u32>,
/// Sampling temperature reported for the segment.
pub temperature: f64,
/// Average log probability reported for the segment.
pub avg_logprob: f64,
/// Compression ratio reported for the segment.
pub compression_ratio: f64,
/// No-speech probability reported for the segment.
pub no_speech_prob: f64,
}
/// One entry of `AudioTranscriptionResponse::words`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptionWord {
/// The word text.
pub word: String,
/// Start time of the word; presumably seconds, confirm.
pub start: f64,
/// End time of the word; presumably seconds, confirm.
pub end: f64,
}
/// Body of a moderation response (wrapped by `ResponseData::Moderation`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModerationResponse {
/// Identifier string of the response.
pub id: String,
/// Name of the model reported in the payload.
pub model: String,
/// The returned moderation results.
pub results: Vec<ModerationResult>,
}
/// One entry of `ModerationResponse::results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModerationResult {
/// Overall flag reported for the input.
pub flagged: bool,
/// Per-category boolean verdicts.
pub categories: ModerationCategories,
/// Per-category numeric scores, using the same category keys as `categories`.
pub category_scores: ModerationCategoryScores,
}
/// Per-category boolean verdicts of a `ModerationResult`.
///
/// Several wire keys contain `-` or `/`, which are not valid in Rust
/// identifiers, so those fields are mapped with `#[serde(rename = ...)]`.
/// Every field is required: a payload missing any of these keys fails to
/// deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModerationCategories {
/// Wire key `sexual`.
pub sexual: bool,
/// Wire key `hate`.
pub hate: bool,
/// Wire key `harassment`.
pub harassment: bool,
/// Wire key `self-harm`.
#[serde(rename = "self-harm")]
pub self_harm: bool,
/// Wire key `sexual/minors`.
#[serde(rename = "sexual/minors")]
pub sexual_minors: bool,
/// Wire key `hate/threatening`.
#[serde(rename = "hate/threatening")]
pub hate_threatening: bool,
/// Wire key `harassment/threatening`.
#[serde(rename = "harassment/threatening")]
pub harassment_threatening: bool,
/// Wire key `self-harm/instructions`.
#[serde(rename = "self-harm/instructions")]
pub self_harm_instructions: bool,
/// Wire key `self-harm/intent`.
#[serde(rename = "self-harm/intent")]
pub self_harm_intent: bool,
/// Wire key `violence`.
pub violence: bool,
/// Wire key `violence/graphic`.
#[serde(rename = "violence/graphic")]
pub violence_graphic: bool,
}
/// Per-category numeric scores of a `ModerationResult`.
///
/// Uses the same wire keys as `ModerationCategories`, with `f64` values
/// instead of booleans. The value range is not constrained by this type.
/// Every field is required: a payload missing any of these keys fails to
/// deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModerationCategoryScores {
/// Wire key `sexual`.
pub sexual: f64,
/// Wire key `hate`.
pub hate: f64,
/// Wire key `harassment`.
pub harassment: f64,
/// Wire key `self-harm`.
#[serde(rename = "self-harm")]
pub self_harm: f64,
/// Wire key `sexual/minors`.
#[serde(rename = "sexual/minors")]
pub sexual_minors: f64,
/// Wire key `hate/threatening`.
#[serde(rename = "hate/threatening")]
pub hate_threatening: f64,
/// Wire key `harassment/threatening`.
#[serde(rename = "harassment/threatening")]
pub harassment_threatening: f64,
/// Wire key `self-harm/instructions`.
#[serde(rename = "self-harm/instructions")]
pub self_harm_instructions: f64,
/// Wire key `self-harm/intent`.
#[serde(rename = "self-harm/intent")]
pub self_harm_intent: f64,
/// Wire key `violence`.
pub violence: f64,
/// Wire key `violence/graphic`.
#[serde(rename = "violence/graphic")]
pub violence_graphic: f64,
}
/// Body of a rerank response (wrapped by `ResponseData::Rerank`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResponse {
/// Identifier string of the response.
pub id: String,
/// Name of the model reported in the payload.
pub model: String,
/// The returned rerank results.
pub results: Vec<RerankResult>,
/// Token usage when present. This is a `RerankUsage`, not a `Usage`, so
/// `GatewayResponse::usage` does not surface it.
pub usage: Option<RerankUsage>,
}
/// One entry of `RerankResponse::results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankResult {
/// Index reported for this result; presumably the position of the document
/// in the request's input list, confirm against the caller.
pub index: u32,
/// Relevance score reported by the provider; range not constrained here.
pub relevance_score: f64,
/// Document text, when the provider returns it.
pub document: Option<String>,
}
/// Token accounting attached to a `RerankResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankUsage {
/// Total token count reported by the provider.
pub total_tokens: u32,
}
/// Body of an error response (wrapped by `ResponseData::Error`).
///
/// The details are nested one level down, under the `error` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
/// The error details.
pub error: ErrorDetail,
}
/// Error details carried inside an `ErrorResponse`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetail {
/// Human-readable error message.
pub message: String,
/// Error category string. `type` is a Rust keyword, so the field is named
/// `error_type` and mapped to the wire key `type` via serde.
#[serde(rename = "type")]
pub error_type: String,
/// Error code as a string, when present (the tests in this file use `"400"`).
pub code: Option<String>,
/// Name of the offending request parameter, when present.
pub param: Option<String>,
}
/// Provider details attached to a `GatewayResponse`.
///
/// Not serde-enabled. The derived `Default` yields empty strings and `None`s,
/// which is what `GatewayResponse::new` starts with.
#[derive(Debug, Clone, Default)]
pub struct ProviderInfo {
/// Provider name.
pub name: String,
/// Provider kind, kept as a free-form string.
pub provider_type: String,
/// Model name.
pub model: String,
/// API version, when known.
pub api_version: Option<String>,
/// Region, when known.
pub region: Option<String>,
/// Deployment identifier, when known.
pub deployment_id: Option<String>,
}
/// Timing and retry counters attached to a `GatewayResponse`.
///
/// Not serde-enabled. The derived `Default` yields zeros, `false` and `None`.
/// The `_ms` suffix suggests the durations are milliseconds; nothing in this
/// file sets them, so confirm against the code that fills them in.
///
/// NOTE(review): `from_cache` and `cache_type` overlap with `CacheInfo::hit`
/// and `CacheInfo::cache_type`; nothing here keeps the two in sync.
#[derive(Debug, Clone, Default)]
pub struct ResponseMetrics {
/// Total time for the request.
pub total_time_ms: u64,
/// Time attributed to the provider.
pub provider_time_ms: u64,
/// Time attributed to queueing.
pub queue_time_ms: u64,
/// Time attributed to processing.
pub processing_time_ms: u64,
/// Number of retries.
pub retry_count: u32,
/// Whether the response came from a cache.
pub from_cache: bool,
/// Cache kind, as a free-form string, when applicable.
pub cache_type: Option<String>,
}
/// Cache bookkeeping attached to a `GatewayResponse`.
///
/// Not serde-enabled. The derived `Default` yields `false` for both flags and
/// `None` for every optional field.
///
/// NOTE(review): the distinction between `cached` and `hit` is not visible in
/// this file; presumably "was stored in the cache" vs. "was served from the
/// cache". Confirm against the code that fills them in.
#[derive(Debug, Clone, Default)]
pub struct CacheInfo {
/// Cached flag (see the note on the struct).
pub cached: bool,
/// Cache key, when known.
pub cache_key: Option<String>,
/// Time-to-live in seconds, when known.
pub ttl_seconds: Option<u64>,
/// Hit flag (see the note on the struct).
pub hit: bool,
/// Cache kind, as a free-form string, when applicable.
pub cache_type: Option<String>,
/// Similarity score, when applicable; range not constrained here.
pub similarity_score: Option<f32>,
}
impl GatewayResponse {
pub fn new(response_type: ResponseType, data: ResponseData) -> Self {
Self {
metadata: Metadata::new(),
response_type,
data,
provider_info: ProviderInfo::default(),
metrics: ResponseMetrics::default(),
cache_info: CacheInfo::default(),
}
}
pub fn with_provider_info(mut self, provider_info: ProviderInfo) -> Self {
self.provider_info = provider_info;
self
}
pub fn with_metrics(mut self, metrics: ResponseMetrics) -> Self {
self.metrics = metrics;
self
}
pub fn with_cache_info(mut self, cache_info: CacheInfo) -> Self {
self.cache_info = cache_info;
self
}
pub fn is_error(&self) -> bool {
matches!(self.response_type, ResponseType::Error)
}
pub fn usage(&self) -> Option<&Usage> {
match &self.data {
ResponseData::ChatCompletion(resp) => resp.usage.as_ref(),
ResponseData::Completion(resp) => resp.usage.as_ref(),
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A chat-completion payload keeps its tag and is not reported as an error.
    #[test]
    fn test_gateway_response_creation() {
        let payload = ResponseData::ChatCompletion(ChatCompletionResponse {
            id: String::from("test-id"),
            object: String::from("chat.completion"),
            created: 1234567890,
            model: String::from("gpt-4"),
            system_fingerprint: None,
            choices: Vec::new(),
            usage: None,
        });

        let gateway_response = GatewayResponse::new(ResponseType::ChatCompletion, payload);

        let tagged_as_chat = matches!(
            gateway_response.response_type,
            ResponseType::ChatCompletion
        );
        assert!(tagged_as_chat);
        assert!(!gateway_response.is_error());
    }

    /// A response tagged as an error is reported as one.
    #[test]
    fn test_error_response() {
        let detail = ErrorDetail {
            message: String::from("Test error"),
            error_type: String::from("invalid_request"),
            code: Some(String::from("400")),
            param: None,
        };
        let payload = ResponseData::Error(ErrorResponse { error: detail });

        let gateway_response = GatewayResponse::new(ResponseType::Error, payload);

        assert!(gateway_response.is_error());
    }
}