use super::{AzureProcessedResponse, AzureResponseMetadata, ResponseMetrics};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone)]
pub struct ResponseProcessingConfig {
pub process_content_filters: bool,
pub calculate_metrics: bool,
pub validate_structure: bool,
pub max_response_size: usize,
}
impl Default for ResponseProcessingConfig {
fn default() -> Self {
Self {
process_content_filters: true,
calculate_metrics: true,
validate_structure: true,
max_response_size: 10 * 1024 * 1024, }
}
}
pub struct AzureResponseProcessor {
config: ResponseProcessingConfig,
}
impl AzureResponseProcessor {
pub fn new() -> Self {
Self {
config: ResponseProcessingConfig::default(),
}
}
pub fn with_config(config: ResponseProcessingConfig) -> Self {
Self { config }
}
pub fn process_response<T: Serialize + for<'de> Deserialize<'de>>(
&self,
response: T,
) -> Result<AzureProcessedResponse<T>, String> {
let start_time = std::time::Instant::now();
let json_response = serde_json::to_value(&response)
.map_err(|e| format!("Failed to serialize response: {}", e))?;
let response_size = serde_json::to_vec(&response).map_or(0, |v| v.len());
if response_size > self.config.max_response_size {
return Err(format!(
"Response size {} exceeds limit of {}",
response_size, self.config.max_response_size
));
}
if self.config.validate_structure {
self.validate_response_structure(&json_response)?;
}
let metadata = self.extract_metadata(&json_response);
let content_filtered = if self.config.process_content_filters {
self.check_content_filtering(&json_response)
} else {
false
};
let warnings = self.collect_warnings(&json_response);
let metrics = if self.config.calculate_metrics {
self.calculate_metrics(&json_response, start_time, response_size)
} else {
ResponseMetrics::default()
};
Ok(AzureProcessedResponse {
data: response,
metadata,
content_filtered,
warnings,
metrics,
})
}
pub fn process_streaming_chunk<T: Serialize>(
&self,
chunk: T,
is_final: bool,
) -> Result<StreamingChunk, String> {
let json_chunk = serde_json::to_value(&chunk)
.map_err(|e| format!("Failed to serialize chunk: {}", e))?;
let content_filtered = self.check_content_filtering_chunk(&json_chunk);
let warnings = self.collect_chunk_warnings(&json_chunk);
Ok(StreamingChunk {
data: json_chunk,
is_final,
content_filtered,
warnings,
})
}
fn validate_response_structure(&self, response: &serde_json::Value) -> Result<(), String> {
if response.get("choices").is_some() {
self.validate_chat_completion_structure(response)?;
}
else if response.get("data").is_some() {
self.validate_embedding_structure(response)?;
}
else if response.get("created").is_some() && response.get("data").is_some() {
self.validate_image_generation_structure(response)?;
}
Ok(())
}
fn validate_chat_completion_structure(
&self,
response: &serde_json::Value,
) -> Result<(), String> {
let choices = response
.get("choices")
.and_then(|c| c.as_array())
.ok_or("Invalid choices array")?;
if choices.is_empty() {
return Err("Empty choices array".to_string());
}
let first_choice = &choices[0];
if first_choice.get("message").is_none() && first_choice.get("text").is_none() {
return Err("Choice missing message or text content".to_string());
}
if first_choice.get("finish_reason").is_none() {
return Err("Choice missing finish_reason".to_string());
}
Ok(())
}
fn validate_embedding_structure(&self, response: &serde_json::Value) -> Result<(), String> {
let data = response
.get("data")
.and_then(|d| d.as_array())
.ok_or("Invalid embedding data array")?;
if data.is_empty() {
return Err("Empty embedding data array".to_string());
}
let first_embedding = &data[0];
if first_embedding.get("embedding").is_none() {
return Err("Embedding entry missing embedding field".to_string());
}
Ok(())
}
fn validate_image_generation_structure(
&self,
response: &serde_json::Value,
) -> Result<(), String> {
let data = response
.get("data")
.and_then(|d| d.as_array())
.ok_or("Invalid image data array")?;
if data.is_empty() {
return Err("Empty image data array".to_string());
}
Ok(())
}
fn extract_metadata(&self, response: &serde_json::Value) -> AzureResponseMetadata {
let mut metadata = AzureResponseMetadata::default();
if let Some(model) = response.get("model").and_then(|m| m.as_str()) {
metadata.deployment_id = Some(model.to_string());
}
if let Some(id) = response.get("id").and_then(|i| i.as_str()) {
metadata.request_id = Some(id.to_string());
}
metadata.content_filter_results = self.extract_content_filters(response);
metadata.prompt_filter_results = self.extract_prompt_filters(response);
metadata
}
fn extract_content_filters(
&self,
response: &serde_json::Value,
) -> Option<super::ContentFilterResults> {
if let Some(choices) = response.get("choices").and_then(|c| c.as_array())
&& let Some(first_choice) = choices.first()
&& let Some(filters) = first_choice.get("content_filter_results")
&& let Ok(filter_results) = serde_json::from_value(filters.clone())
{
return Some(filter_results);
}
if let Some(filters) = response.get("content_filter_results")
&& let Ok(filter_results) = serde_json::from_value(filters.clone())
{
return Some(filter_results);
}
None
}
fn extract_prompt_filters(
&self,
response: &serde_json::Value,
) -> Option<Vec<super::PromptFilterResult>> {
if let Some(filters) = response.get("prompt_filter_results")
&& let Ok(filter_results) = serde_json::from_value(filters.clone())
{
return Some(filter_results);
}
None
}
fn check_content_filtering(&self, response: &serde_json::Value) -> bool {
if let Some(choices) = response.get("choices").and_then(|c| c.as_array()) {
for choice in choices {
if let Some(finish_reason) = choice.get("finish_reason").and_then(|r| r.as_str())
&& finish_reason == "content_filter"
{
return true;
}
}
}
if let Some(filters) = self.extract_content_filters(response) {
return self.is_any_content_filtered(&filters);
}
false
}
fn check_content_filtering_chunk(&self, chunk: &serde_json::Value) -> bool {
self.check_content_filtering(chunk)
}
fn is_any_content_filtered(&self, filters: &super::ContentFilterResults) -> bool {
filters.hate.as_ref().is_some_and(|f| f.filtered)
|| filters.self_harm.as_ref().is_some_and(|f| f.filtered)
|| filters.sexual.as_ref().is_some_and(|f| f.filtered)
|| filters.violence.as_ref().is_some_and(|f| f.filtered)
}
fn collect_warnings(&self, response: &serde_json::Value) -> Vec<String> {
let mut warnings = Vec::new();
if response
.get("choices")
.and_then(|c| c.as_array())
.is_some_and(|arr| arr.is_empty())
{
warnings.push("Response contains empty choices array".to_string());
}
if response.get("choices").is_some() && response.get("usage").is_none() {
warnings.push("Response missing usage information".to_string());
}
if self.check_content_filtering(response) {
warnings.push("Content was filtered by Azure content filters".to_string());
}
warnings
}
fn collect_chunk_warnings(&self, chunk: &serde_json::Value) -> Vec<String> {
let mut warnings = Vec::new();
if self.check_content_filtering_chunk(chunk) {
warnings.push("Streaming chunk was filtered".to_string());
}
warnings
}
fn calculate_metrics(
&self,
_response: &serde_json::Value,
start_time: std::time::Instant,
response_size: usize,
) -> ResponseMetrics {
let total_time = start_time.elapsed().as_millis() as u64;
ResponseMetrics {
total_time_ms: total_time,
transformation_time_ms: total_time / 4, filtering_time_ms: total_time / 8, response_size_bytes: response_size,
}
}
}
impl Default for AzureResponseProcessor {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct StreamingChunk {
pub data: serde_json::Value,
pub is_final: bool,
pub content_filtered: bool,
pub warnings: Vec<String>,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_response_processing_config_default() {
let config = ResponseProcessingConfig::default();
assert!(config.process_content_filters);
assert!(config.calculate_metrics);
assert!(config.validate_structure);
assert_eq!(config.max_response_size, 10 * 1024 * 1024); }
#[test]
fn test_response_processing_config_custom() {
let config = ResponseProcessingConfig {
process_content_filters: false,
calculate_metrics: false,
validate_structure: false,
max_response_size: 1024,
};
assert!(!config.process_content_filters);
assert!(!config.calculate_metrics);
assert!(!config.validate_structure);
assert_eq!(config.max_response_size, 1024);
}
#[test]
fn test_response_processing_config_clone() {
let config = ResponseProcessingConfig::default();
let cloned = config.clone();
assert_eq!(
config.process_content_filters,
cloned.process_content_filters
);
assert_eq!(config.max_response_size, cloned.max_response_size);
}
#[test]
fn test_response_processing_config_debug() {
let config = ResponseProcessingConfig::default();
let debug = format!("{:?}", config);
assert!(debug.contains("ResponseProcessingConfig"));
}
#[test]
fn test_azure_response_processor_new() {
let processor = AzureResponseProcessor::new();
assert!(processor.config.process_content_filters);
assert!(processor.config.calculate_metrics);
}
#[test]
fn test_azure_response_processor_default() {
let processor = AzureResponseProcessor::default();
assert!(processor.config.validate_structure);
}
#[test]
fn test_azure_response_processor_with_config() {
let config = ResponseProcessingConfig {
process_content_filters: false,
calculate_metrics: true,
validate_structure: false,
max_response_size: 5000,
};
let processor = AzureResponseProcessor::with_config(config);
assert!(!processor.config.process_content_filters);
assert!(processor.config.calculate_metrics);
assert!(!processor.config.validate_structure);
assert_eq!(processor.config.max_response_size, 5000);
}
#[test]
fn test_process_response() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}],
"usage": {"total_tokens": 10}
});
let result = processor.process_response(response).unwrap();
assert!(!result.content_filtered);
}
#[test]
fn test_process_response_with_id() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"id": "chatcmpl-123456",
"model": "gpt-4",
"choices": [{"message": {"content": "Hello"}, "finish_reason": "stop"}],
"usage": {"total_tokens": 15}
});
let result = processor.process_response(response).unwrap();
assert!(result.metadata.request_id.is_some());
assert_eq!(result.metadata.request_id.unwrap(), "chatcmpl-123456");
assert!(result.metadata.deployment_id.is_some());
assert_eq!(result.metadata.deployment_id.unwrap(), "gpt-4");
}
#[test]
fn test_process_response_without_usage_warning() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}]
});
let result = processor.process_response(response).unwrap();
assert!(result.warnings.iter().any(|w| w.contains("missing usage")));
}
#[test]
fn test_process_response_content_filtered() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": ""}, "finish_reason": "content_filter"}],
"usage": {"total_tokens": 5}
});
let result = processor.process_response(response).unwrap();
assert!(result.content_filtered);
assert!(result.warnings.iter().any(|w| w.contains("filtered")));
}
#[test]
fn test_process_response_exceeds_size_limit() {
let config = ResponseProcessingConfig {
max_response_size: 10, ..Default::default()
};
let processor = AzureResponseProcessor::with_config(config);
let response = serde_json::json!({
"choices": [{"message": {"content": "This is a long response"}, "finish_reason": "stop"}]
});
let result = processor.process_response(response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("exceeds limit"));
}
#[test]
fn test_process_response_with_metrics_disabled() {
let config = ResponseProcessingConfig {
calculate_metrics: false,
..Default::default()
};
let processor = AzureResponseProcessor::with_config(config);
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}],
"usage": {"total_tokens": 10}
});
let result = processor.process_response(response).unwrap();
assert_eq!(result.metrics.total_time_ms, 0);
}
#[test]
fn test_process_response_with_validation_disabled() {
let config = ResponseProcessingConfig {
validate_structure: false,
..Default::default()
};
let processor = AzureResponseProcessor::with_config(config);
let response = serde_json::json!({
"choices": []
});
let result = processor.process_response(response);
assert!(result.is_ok());
}
#[test]
fn test_validate_chat_structure() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}]
});
assert!(processor.validate_response_structure(&response).is_ok());
}
#[test]
fn test_validate_chat_structure_with_text() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"text": "completion text", "finish_reason": "stop"}]
});
assert!(processor.validate_response_structure(&response).is_ok());
}
#[test]
fn test_validate_chat_structure_empty_choices() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": []
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Empty choices"));
}
#[test]
fn test_validate_chat_structure_missing_content() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"finish_reason": "stop"}]
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing message or text"));
}
#[test]
fn test_validate_chat_structure_missing_finish_reason() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}}]
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing finish_reason"));
}
#[test]
fn test_validate_embedding_structure() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}],
"model": "text-embedding-ada-002"
});
assert!(processor.validate_response_structure(&response).is_ok());
}
#[test]
fn test_validate_embedding_structure_empty_data() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"data": []
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Empty embedding"));
}
#[test]
fn test_validate_embedding_structure_missing_embedding() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"data": [{"index": 0}]
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("missing embedding field"));
}
#[test]
fn test_validate_image_generation_structure() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"created": 1700000000,
"data": [{"embedding": [0.1], "url": "https://example.com/image.png"}]
});
assert!(processor.validate_response_structure(&response).is_ok());
}
#[test]
fn test_validate_image_generation_structure_empty_data() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"created": 1700000000,
"data": []
});
let result = processor.validate_response_structure(&response);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Empty embedding"));
}
#[test]
fn test_validate_unknown_structure() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"unknown": "data"
});
assert!(processor.validate_response_structure(&response).is_ok());
}
#[test]
fn test_check_content_filtering_none() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "Hello"}, "finish_reason": "stop"}]
});
assert!(!processor.check_content_filtering(&response));
}
#[test]
fn test_check_content_filtering_by_finish_reason() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": ""}, "finish_reason": "content_filter"}]
});
assert!(processor.check_content_filtering(&response));
}
#[test]
fn test_check_content_filtering_multiple_choices() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [
{"message": {"content": "Hello"}, "finish_reason": "stop"},
{"message": {"content": ""}, "finish_reason": "content_filter"}
]
});
assert!(processor.check_content_filtering(&response));
}
#[test]
fn test_process_streaming_chunk_normal() {
let processor = AzureResponseProcessor::new();
let chunk = serde_json::json!({
"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]
});
let result = processor.process_streaming_chunk(chunk, false).unwrap();
assert!(!result.is_final);
assert!(!result.content_filtered);
assert!(result.warnings.is_empty());
}
#[test]
fn test_process_streaming_chunk_final() {
let processor = AzureResponseProcessor::new();
let chunk = serde_json::json!({
"choices": [{"delta": {}, "finish_reason": "stop"}]
});
let result = processor.process_streaming_chunk(chunk, true).unwrap();
assert!(result.is_final);
}
#[test]
fn test_process_streaming_chunk_filtered() {
let processor = AzureResponseProcessor::new();
let chunk = serde_json::json!({
"choices": [{"delta": {}, "finish_reason": "content_filter"}]
});
let result = processor.process_streaming_chunk(chunk, true).unwrap();
assert!(result.content_filtered);
assert!(result.warnings.iter().any(|w| w.contains("filtered")));
}
#[test]
fn test_streaming_chunk_creation() {
let chunk = StreamingChunk {
data: serde_json::json!({"test": "data"}),
is_final: false,
content_filtered: false,
warnings: vec![],
};
assert!(!chunk.is_final);
assert!(!chunk.content_filtered);
assert!(chunk.warnings.is_empty());
}
#[test]
fn test_streaming_chunk_with_warnings() {
let chunk = StreamingChunk {
data: serde_json::json!({}),
is_final: true,
content_filtered: true,
warnings: vec!["Warning 1".to_string(), "Warning 2".to_string()],
};
assert!(chunk.is_final);
assert!(chunk.content_filtered);
assert_eq!(chunk.warnings.len(), 2);
}
#[test]
fn test_streaming_chunk_clone() {
let chunk = StreamingChunk {
data: serde_json::json!({"content": "test"}),
is_final: false,
content_filtered: false,
warnings: vec!["test warning".to_string()],
};
let cloned = chunk.clone();
assert_eq!(chunk.is_final, cloned.is_final);
assert_eq!(chunk.warnings, cloned.warnings);
}
#[test]
fn test_streaming_chunk_debug() {
let chunk = StreamingChunk {
data: serde_json::json!({}),
is_final: false,
content_filtered: false,
warnings: vec![],
};
let debug = format!("{:?}", chunk);
assert!(debug.contains("StreamingChunk"));
}
#[test]
fn test_calculate_metrics() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}],
"usage": {"total_tokens": 10}
});
let result = processor.process_response(response).unwrap();
assert!(result.metrics.response_size_bytes > 0);
}
#[test]
fn test_collect_warnings_empty_choices() {
let config = ResponseProcessingConfig {
validate_structure: false,
..Default::default()
};
let processor = AzureResponseProcessor::with_config(config);
let response = serde_json::json!({
"choices": []
});
let warnings = processor.collect_warnings(&response);
assert!(warnings.iter().any(|w| w.contains("empty choices")));
}
#[test]
fn test_collect_warnings_no_issues() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}],
"usage": {"total_tokens": 10}
});
let warnings = processor.collect_warnings(&response);
assert!(warnings.is_empty());
}
#[test]
fn test_extract_metadata_with_model() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({
"model": "gpt-4-turbo",
"id": "resp-123",
"choices": [{"message": {"content": "test"}, "finish_reason": "stop"}]
});
let metadata = processor.extract_metadata(&response);
assert_eq!(metadata.deployment_id, Some("gpt-4-turbo".to_string()));
assert_eq!(metadata.request_id, Some("resp-123".to_string()));
}
#[test]
fn test_extract_metadata_empty() {
let processor = AzureResponseProcessor::new();
let response = serde_json::json!({});
let metadata = processor.extract_metadata(&response);
assert!(metadata.deployment_id.is_none());
assert!(metadata.request_id.is_none());
}
}