use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use validator::Validate;
/// Request body for scanning a single prompt.
///
/// Fields are (de)serialized in camelCase, so `cache_enabled` travels as
/// `cacheEnabled`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ScanPromptRequest {
    /// Prompt text to scan. Must be 1 to 100000 characters long.
    #[validate(length(
        min = 1,
        max = 100000,
        message = "Prompt must be between 1 and 100000 characters"
    ))]
    pub prompt: String,

    /// Scanner identifiers for this request. At most 20 entries; an omitted
    /// field deserializes to an empty list.
    #[serde(default)]
    #[validate(length(max = 20, message = "Maximum 20 scanners allowed"))]
    pub scanners: Vec<String>,

    /// Cache flag. When omitted, the value comes from `default_cache_enabled`.
    #[serde(default = "default_cache_enabled")]
    pub cache_enabled: bool,
}
/// Request body carrying a prompt and an output to scan.
///
/// Fields are (de)serialized in camelCase, so `cache_enabled` travels as
/// `cacheEnabled`. Validation rules mirror `ScanPromptRequest`, including
/// human-readable failure messages.
#[derive(Debug, Clone, Serialize, Deserialize, Validate, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ScanOutputRequest {
    /// Prompt text. Must be 1 to 100000 characters long.
    #[validate(length(
        min = 1,
        max = 100000,
        message = "Prompt must be between 1 and 100000 characters"
    ))]
    pub prompt: String,

    /// Output text to scan. Must be 1 to 100000 characters long.
    #[validate(length(
        min = 1,
        max = 100000,
        message = "Output must be between 1 and 100000 characters"
    ))]
    pub output: String,

    /// Scanner identifiers for this request. At most 20 entries; an omitted
    /// field deserializes to an empty list.
    #[validate(length(max = 20, message = "Maximum 20 scanners allowed"))]
    #[serde(default)]
    pub scanners: Vec<String>,

    /// Cache flag. When omitted, the value comes from `default_cache_enabled`.
    #[serde(default = "default_cache_enabled")]
    pub cache_enabled: bool,
}
/// Request body bundling several prompt scan requests.
///
/// Fields are (de)serialized in camelCase, so `max_concurrent` travels as
/// `maxConcurrent`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct BatchScanRequest {
    /// Prompt requests in the batch. Must contain 1 to 100 entries.
    // NOTE(review): only the number of items is checked here. No
    // nested-validation attribute is present, so the per-item rules declared on
    // `ScanPromptRequest` do not appear to run through
    // `BatchScanRequest::validate()` -- confirm against the validator version
    // in use and against the callers.
    #[validate(length(min = 1, max = 100, message = "Batch size must be between 1 and 100"))]
    pub items: Vec<ScanPromptRequest>,

    /// Requested concurrency bound, 1 to 10. When omitted, the value comes
    /// from `default_max_concurrent`.
    #[validate(range(min = 1, max = 10, message = "Max concurrent must be between 1 and 10"))]
    #[serde(default = "default_max_concurrent")]
    pub max_concurrent: usize,
}
/// Request body for anonymizing a piece of text.
///
/// Fields are (de)serialized in camelCase, so `entity_types` travels as
/// `entityTypes`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct AnonymizeRequest {
    /// Text to anonymize. Must be 1 to 100000 characters long.
    #[validate(length(
        min = 1,
        max = 100000,
        message = "Text must be between 1 and 100000 characters"
    ))]
    pub text: String,

    /// Entity type identifiers; an omitted field deserializes to an empty list.
    // NOTE(review): unlike `scanners` on the scan requests, this list has no
    // length bound -- confirm whether that is intentional.
    #[serde(default)]
    pub entity_types: Vec<String>,
}
/// Request body for de-anonymizing a piece of text.
///
/// Fields are (de)serialized in camelCase, so `session_id` travels as
/// `sessionId`.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct DeanonymizeRequest {
    /// Text to de-anonymize. Must be 1 to 100000 characters long.
    #[validate(length(
        min = 1,
        max = 100000,
        message = "Text must be between 1 and 100000 characters"
    ))]
    pub text: String,

    /// Session identifier. Must be 1 to 100 characters long.
    #[validate(length(
        min = 1,
        max = 100,
        message = "Session ID must be between 1 and 100 characters"
    ))]
    pub session_id: String,
}
/// Supplies the value serde uses for `cache_enabled` when the field is absent
/// from the input.
fn default_cache_enabled() -> bool {
    const ENABLED_WHEN_OMITTED: bool = true;
    ENABLED_WHEN_OMITTED
}
/// Supplies the value serde uses for `max_concurrent` when the field is absent
/// from the input.
fn default_max_concurrent() -> usize {
    const FALLBACK_MAX_CONCURRENT: usize = 5;
    FALLBACK_MAX_CONCURRENT
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ScanPromptRequest` with caching switched on, which is the
    /// setting every test in this module uses.
    fn prompt_request(prompt: impl Into<String>, scanners: Vec<String>) -> ScanPromptRequest {
        ScanPromptRequest {
            prompt: prompt.into(),
            scanners,
            cache_enabled: true,
        }
    }

    /// Builds a `BatchScanRequest` from ready-made items and a concurrency value.
    fn batch_request(items: Vec<ScanPromptRequest>, max_concurrent: usize) -> BatchScanRequest {
        BatchScanRequest {
            items,
            max_concurrent,
        }
    }

    // --- ScanPromptRequest validation ---

    #[test]
    fn test_scan_prompt_request_valid() {
        let request = prompt_request("Test prompt", vec![String::from("toxicity")]);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_scan_prompt_request_empty_prompt() {
        let request = prompt_request("", Vec::new());
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_scan_prompt_request_too_long() {
        // One character past the 100000 limit.
        let oversized = "a".repeat(100001);
        let request = prompt_request(oversized, Vec::new());
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_scan_prompt_request_too_many_scanners() {
        // One entry past the 20-scanner limit.
        let scanners: Vec<String> = (0..21).map(|_| "scanner".to_string()).collect();
        let request = prompt_request("Test", scanners);
        assert!(request.validate().is_err());
    }

    // --- ScanOutputRequest validation ---

    #[test]
    fn test_scan_output_request_valid() {
        let request = ScanOutputRequest {
            prompt: String::from("What is AI?"),
            output: String::from("AI is artificial intelligence"),
            scanners: Vec::new(),
            cache_enabled: true,
        };
        assert!(request.validate().is_ok());
    }

    // --- BatchScanRequest validation ---

    #[test]
    fn test_batch_scan_request_valid() {
        let items = vec![
            prompt_request("Test 1", Vec::new()),
            prompt_request("Test 2", Vec::new()),
        ];
        let request = batch_request(items, 5);
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_batch_scan_request_empty() {
        let request = batch_request(Vec::new(), 5);
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_batch_scan_request_too_large() {
        // 101 items: one past the batch-size limit.
        let items: Vec<ScanPromptRequest> = (0..101)
            .map(|i| prompt_request(format!("Test {}", i), Vec::new()))
            .collect();
        let request = batch_request(items, 5);
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_batch_scan_request_invalid_concurrent() {
        // Zero is below the minimum allowed for `max_concurrent`.
        let request = batch_request(vec![prompt_request("Test", Vec::new())], 0);
        assert!(request.validate().is_err());
    }

    // --- Anonymize / deanonymize validation ---

    #[test]
    fn test_anonymize_request_valid() {
        let request = AnonymizeRequest {
            text: String::from("My email is john@example.com"),
            entity_types: vec![String::from("EMAIL")],
        };
        assert!(request.validate().is_ok());
    }

    #[test]
    fn test_deanonymize_request_valid() {
        let request = DeanonymizeRequest {
            text: String::from("My email is [EMAIL_1]"),
            session_id: String::from("session123"),
        };
        assert!(request.validate().is_ok());
    }

    // --- serde round-trips ---

    #[test]
    fn test_serialization() {
        let request = prompt_request("Test", vec![String::from("toxicity")]);
        let json = serde_json::to_string(&request).unwrap();
        // Keys must be present under their camelCase wire names.
        for key in ["\"prompt\"", "\"scanners\"", "\"cacheEnabled\""] {
            assert!(json.contains(key));
        }
    }

    #[test]
    fn test_deserialization() {
        let json = r#"{"prompt":"Test","scanners":["toxicity"],"cacheEnabled":true}"#;
        let parsed: ScanPromptRequest = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.prompt, "Test");
        assert_eq!(parsed.scanners.len(), 1);
        assert!(parsed.cache_enabled);
    }

    #[test]
    fn test_default_values() {
        // Only the required field is supplied; the rest come from serde defaults.
        let json = r#"{"prompt":"Test"}"#;
        let parsed: ScanPromptRequest = serde_json::from_str(json).unwrap();
        assert!(parsed.scanners.is_empty());
        assert!(parsed.cache_enabled);
    }
}