use std::time::{Duration, Instant};

use futures::future;
use futures::stream::{self, StreamExt};
use tokio::time::sleep;

use crate::config::Config;
use crate::error::{Error, Result};
use crate::http::{api_paths, HttpClient};
use crate::models::{
    AnalysisResult, BatchOptions, DetectionModelResult, DetectionResult, DetectionResultList,
    FloatOrObject, FormattedDetectionResultList, GetResultOptions, GetResultsOptions,
    UploadOptions, UploadResult,
};
/// Asynchronous client for uploading media and retrieving detection results.
///
/// All network traffic goes through the wrapped [`HttpClient`], which is built
/// from a [`Config`] in [`Client::new`].
pub struct Client {
    /// Transport used by every API call made through this client.
    http_client: HttpClient,
}
impl Client {
pub fn new(config: Config) -> Result<Self> {
let http_client = HttpClient::new(config)?;
Ok(Self { http_client })
}
pub async fn upload(&self, options: UploadOptions) -> Result<UploadResult> {
self.http_client
.upload_file::<UploadResult>(&options.file_path)
.await
}
pub async fn upload_social_media(&self, social_media_link: &str) -> Result<UploadResult> {
self.http_client
.upload_social_media_link(social_media_link)
.await
}
pub async fn get_result(
&self,
request_id: &str,
options: Option<GetResultOptions>,
) -> Result<DetectionResult> {
let opts = options.unwrap_or_default();
let should_wait =
opts.max_attempts.unwrap_or(0) > 0 && opts.polling_interval.unwrap_or(0) > 0;
if should_wait {
self.wait_for_result(
request_id,
opts.max_attempts.unwrap(),
opts.polling_interval.unwrap(),
)
.await
} else {
self.fetch_result(request_id).await
}
}
async fn fetch_result(&self, request_id: &str) -> Result<DetectionResult> {
let endpoint = format!("{}/{}", api_paths::MEDIA_RESULT, request_id);
let result = self.http_client.get::<AnalysisResult>(&endpoint).await?;
Ok(self.normalize_scores(&result))
}
fn normalize_scores(&self, result: &AnalysisResult) -> DetectionResult {
let mut detection_result = DetectionResult {
status: if result.status == "FAKE" {
"MANIPULATED".to_string()
} else {
result.status.clone()
},
request_id: result.request_id.clone(),
score: result.final_score.map(|final_score| final_score / 100.0),
models: vec![],
};
if result.results_summary.is_some() {
if let Some(metadata) = &result.results_summary.as_ref().unwrap().metadata {
if let Some(final_score) = metadata.get("finalScore") {
if let Some(score_value) = final_score.as_f64() {
detection_result.score = Some(score_value / 100.0)
}
}
}
}
detection_result.models = result
.models
.iter()
.filter(|model| model.status != "NOT_APPLICABLE")
.map(|model| DetectionModelResult {
name: model.name.clone(),
status: if model.status == "FAKE" {
"MANIPULATED".to_string()
} else {
model.status.clone()
},
score: match model.prediction_number {
Some(FloatOrObject::Float(val)) => Some(val),
_ => None,
},
})
.collect();
detection_result
}
async fn wait_for_result(
&self,
request_id: &str,
max_attempts: u64,
polling_interval: u64,
) -> Result<DetectionResult> {
let start_time = Instant::now();
for _ in 0..max_attempts {
let result = self.fetch_result(request_id).await?;
match result.status.as_str() {
"ANALYZING" | "DOWNLOADING" => sleep(Duration::from_millis(polling_interval)).await,
_ => {
return Ok(result);
}
}
}
Err(Error::UnknownError(format!(
"Timed out waiting for result after {} seconds",
(Instant::now() - start_time).as_secs()
)))
}
pub async fn process_batch(
&self,
file_paths: Vec<&str>,
options: BatchOptions,
) -> Result<Vec<DetectionResult>> {
if file_paths.is_empty() {
return Ok(Vec::new());
}
let max_concurrency = options.max_concurrency.unwrap_or(5);
let should_wait =
options.max_attempts.unwrap_or(0) > 0 && options.polling_interval.unwrap_or(0) > 0;
let uploads = future::join_all(
file_paths
.chunks(max_concurrency)
.map(|chunk| {
let chunk_futures = chunk.iter().map(|&path| {
let upload_options = UploadOptions {
file_path: path.to_string(),
};
self.upload(upload_options)
});
future::join_all(chunk_futures)
})
.collect::<Vec<_>>(),
)
.await
.into_iter()
.flatten()
.collect::<Vec<Result<UploadResult>>>();
let request_ids: Vec<String> = uploads
.into_iter()
.filter_map(|upload_result| match upload_result {
Ok(result) => Some(result.request_id),
Err(_) => None,
})
.collect();
if should_wait {
let get_options = GetResultOptions {
max_attempts: options.max_attempts,
polling_interval: options.polling_interval,
};
let results = future::join_all(
request_ids
.chunks(max_concurrency)
.map(|chunk| {
let chunk_futures = chunk
.iter()
.map(|id| self.get_result(id, Some(get_options.clone())));
future::join_all(chunk_futures)
})
.collect::<Vec<_>>(),
)
.await
.into_iter()
.flatten()
.collect::<Vec<Result<DetectionResult>>>();
Ok(results.into_iter().filter_map(|r| r.ok()).collect())
} else {
Ok(request_ids
.into_iter()
.map(|id| DetectionResult {
request_id: id,
status: "PROCESSING".to_string(),
score: None,
models: Vec::new(),
})
.collect())
}
}
pub async fn get_results(
&self,
options: Option<GetResultsOptions>,
) -> Result<FormattedDetectionResultList> {
let opts = options.unwrap_or_default();
let should_wait =
opts.max_attempts.unwrap_or(0) > 0 && opts.polling_interval.unwrap_or(0) > 0;
if should_wait {
self.wait_for_results(opts).await
} else {
self.fetch_results(opts).await
}
}
async fn fetch_results(
&self,
options: GetResultsOptions,
) -> Result<FormattedDetectionResultList> {
let page_number = options.page_number.unwrap_or(0);
let endpoint = format!("{}/{}", api_paths::ALL_MEDIA_RESULTS, page_number);
let mut params = Vec::new();
if let Some(size) = options.size {
params.push(("size", size.to_string()));
}
if let Some(ref name) = options.name {
params.push(("name", name.to_string()));
}
if let Some(ref start_date) = options.start_date {
params.push(("startDate", start_date.to_string()));
}
if let Some(ref end_date) = options.end_date {
params.push(("endDate", end_date.to_string()));
}
let param_refs: Vec<(&str, &str)> = params.iter().map(|(k, v)| (*k, v.as_str())).collect();
let raw_result = self
.http_client
.get_with_params::<DetectionResultList>(&endpoint, ¶m_refs)
.await?;
Ok(self.format_results_list(&raw_result))
}
async fn wait_for_results(
&self,
options: GetResultsOptions,
) -> Result<FormattedDetectionResultList> {
let max_attempts = options.max_attempts.unwrap_or(5);
let polling_interval = options.polling_interval.unwrap_or(2000);
let start_time = Instant::now();
for _ in 0..max_attempts {
let result = self.fetch_results(options.clone()).await?;
let still_analyzing = result.items.iter().any(|item| item.status == "ANALYZING");
if !still_analyzing {
return Ok(result);
}
sleep(Duration::from_millis(polling_interval)).await;
}
Err(Error::UnknownError(format!(
"Timed out waiting for results after {} seconds",
(Instant::now() - start_time).as_secs()
)))
}
fn format_results_list(
&self,
raw_result: &DetectionResultList,
) -> FormattedDetectionResultList {
let formatted_items = raw_result
.items
.iter()
.map(|item| self.normalize_scores(item))
.collect();
FormattedDetectionResultList {
total_items: raw_result.total_items,
total_pages: raw_result.total_pages,
current_page: raw_result.current_page,
current_page_items_count: raw_result.current_page_items_count,
items: formatted_items,
}
}
pub async fn detect_file(&self, file_path: &str) -> Result<DetectionResult> {
let upload_result = self
.upload(UploadOptions {
file_path: file_path.to_string(),
})
.await?;
self.get_result(
&upload_result.request_id,
Some(GetResultOptions {
max_attempts: Some(150),
polling_interval: Some(2000),
}),
)
.await
}
}
#[cfg(test)]
mod tests {
    // Tests for `Client`. Network-facing tests point the client at a local
    // `mockito` server via `Config::base_url`.
    use crate::{BatchOptions, Client, Config, Error, GetResultOptions, UploadOptions};
    use serde_json::json;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    // A non-empty API key is enough to construct a client.
    #[tokio::test]
    async fn test_client_new() {
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            ..Default::default()
        });
        assert!(client.is_ok());
    }

    // An empty API key must be rejected at construction time.
    #[tokio::test]
    async fn test_client_new_empty_api_key() {
        let client = Client::new(Config {
            api_key: "".to_string(),
            ..Default::default()
        });
        assert!(client.is_err());
    }

    // Single fetch (no polling). Checks:
    // - status passthrough;
    // - summary `finalScore` scaling (85 -> 0.85);
    // - NOT_APPLICABLE models are dropped;
    // - an object-shaped `predictionNumber` yields a `None` model score.
    #[tokio::test]
    async fn test_get_result() {
        let mut server = mockito::Server::new_async().await;
        let request_id = "test_request_id";
        let mock = server
            .mock("GET", "/api/media/users/test_request_id")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "models": [
                        {
                            "name": "TestModel",
                            "status": "COMPLETED",
                            "predictionNumber": 0.27,
                            "normalizedPredictionNumber": 27,
                            "finalScore": null
                        },
                        {
                            "name": "TestModel2",
                            "status": "COMPLETED",
                            "predictionNumber": {
                                "reason": "relevance: no faces detected/faces too small",
                                "decision": "NOT_EVALUATED"
                            },
                            "normalizedPredictionNumber": null,
                            "rollingAvgNumber": null,
                            "finalScore": null
                        },
                        {
                            "name": "TestModel3",
                            "status": "NOT_APPLICABLE",
                            "predictionNumber": {
                                "reason": "relevance: no faces detected/faces too small",
                                "decision": "NOT_EVALUATED"
                            },
                            "normalizedPredictionNumber": null,
                            "rollingAvgNumber": null,
                            "finalScore": null
                        },
                    ],
                    "resultsSummary": {
                        "status": "COMPLETED",
                        "metadata": {
                            "finalScore": 85
                        }
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            base_url: Some(server.url()),
            ..Default::default()
        })
        .unwrap();
        let result = client.get_result(request_id, None).await.unwrap();
        assert_eq!(result.request_id, request_id);
        assert_eq!(result.status, "COMPLETED");
        assert_eq!(result.score, Some(0.85));
        // TestModel3 (NOT_APPLICABLE) is filtered out.
        assert_eq!(result.models.len(), 2);
        assert_eq!(result.models[0].name, "TestModel");
        assert_eq!(result.models[0].score, Some(0.27));
        assert_eq!(result.models[0].status, "COMPLETED");
        assert_eq!(result.models[1].name, "TestModel2");
        assert_eq!(result.models[1].score, None);
        assert_eq!(result.models[1].status, "COMPLETED");
        mock.assert_async().await;
    }

    // Exercises the score rules one response at a time against the same path:
    // - the overall score comes from the summary metadata;
    // - a model score comes only from a float `predictionNumber`;
    // - `normalizedPredictionNumber` and `finalScore` alone give no model score.
    // NOTE(review): each mock is registered after the previous one was hit; the
    // test assumes mockito routes the next request to the not-yet-hit mock —
    // confirm against the mockito version in use.
    #[tokio::test]
    async fn test_score_normalization() {
        let mut server = mockito::Server::new_async().await;
        let request_id = "test_normalize";
        let mock1 = server
            .mock("GET", "/api/media/users/test_normalize")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "models": [],
                    "resultsSummary": {
                        "status": "COMPLETED",
                        "metadata": {
                            "finalScore": 85
                        }
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            base_url: Some(server.url()),
            ..Default::default()
        })
        .unwrap();
        let result = client.get_result(request_id, None).await.unwrap();
        assert_eq!(result.score, Some(0.85));
        mock1.assert_async().await;
        let mock2 = server
            .mock("GET", "/api/media/users/test_normalize")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "models": [
                        {
                            "name": "Model1",
                            "status": "COMPLETED",
                            "predictionNumber": 0.92
                        }
                    ]
                })
                .to_string(),
            )
            .create_async()
            .await;
        let result = client.get_result(request_id, None).await.unwrap();
        // A float prediction is used as-is (not divided by 100).
        assert_eq!(result.models[0].score, Some(0.92));
        mock2.assert_async().await;
        let mock3 = server
            .mock("GET", "/api/media/users/test_normalize")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "models": [
                        {
                            "name": "Model2",
                            "status": "COMPLETED",
                            "normalizedPredictionNumber": 80.0
                        }
                    ]
                })
                .to_string(),
            )
            .create_async()
            .await;
        let result = client.get_result(request_id, None).await.unwrap();
        assert_eq!(result.models[0].score, None);
        mock3.assert_async().await;
        let mock4 = server
            .mock("GET", "/api/media/users/test_normalize")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "models": [
                        {
                            "name": "Model3",
                            "status": "COMPLETED",
                            "finalScore": 70.0
                        }
                    ]
                })
                .to_string(),
            )
            .create_async()
            .await;
        let result = client.get_result(request_id, None).await.unwrap();
        assert_eq!(result.models[0].score, None);
        mock4.assert_async().await;
    }

    // Uploading a path that does not exist must fail with `Error::InvalidFile`
    // (no server is needed because the failure happens before any request).
    #[tokio::test]
    async fn test_upload_with_invalid_file() {
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            ..Default::default()
        })
        .unwrap();
        let result = client
            .upload(UploadOptions {
                file_path: "non_existent_file.jpg".to_string(),
            })
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            Error::InvalidFile(_) => {}
            err => panic!("Unexpected error: {:?}", err),
        }
    }

    // Polling: the first response is ANALYZING, the second COMPLETED. The
    // 1000 ms polling interval means this test sleeps about one second.
    #[tokio::test]
    async fn test_wait_for_result() {
        let mut server = mockito::Server::new_async().await;
        let request_id = "test-wait-request";
        let mock1 = server
            .mock("GET", format!("/api/media/users/{}", request_id).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "ANALYZING",
                    "models": []
                })
                .to_string(),
            )
            .create_async()
            .await;
        let mock2 = server
            .mock("GET", format!("/api/media/users/{}", request_id).as_str())
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": request_id,
                    "overallStatus": "COMPLETED",
                    "finalScore": 75,
                    "models": []
                })
                .to_string(),
            )
            .create_async()
            .await;
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            base_url: Some(server.url()),
            ..Default::default()
        })
        .unwrap();
        let result = client
            .get_result(
                request_id,
                Some(GetResultOptions {
                    max_attempts: Some(5),
                    polling_interval: Some(1000),
                }),
            )
            .await;
        assert!(result.is_ok());
        let analysis_result = result.unwrap();
        assert_eq!(analysis_result.request_id, request_id);
        assert_eq!(analysis_result.status, "COMPLETED");
        // Top-level finalScore 75 is scaled to 0.75 when no summary is present.
        assert_eq!(analysis_result.score, Some(0.75));
        mock1.assert_async().await;
        mock2.assert_async().await;
    }

    // An empty batch short-circuits to an empty result without any requests.
    #[tokio::test]
    async fn test_process_batch_empty() {
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            ..Default::default()
        })
        .unwrap();
        let result = client
            .process_batch(
                vec![],
                BatchOptions {
                    max_concurrency: Some(2),
                    max_attempts: Some(10),
                    polling_interval: Some(1000),
                },
            )
            .await;
        assert!(result.is_ok());
        let batch_results = result.unwrap();
        assert!(batch_results.is_empty());
    }

    // Without polling options the batch returns `PROCESSING` placeholders.
    // NOTE(review): both presign mocks match every POST, so the id-to-position
    // mapping asserted below depends on the order in which the two uploads
    // reach the server — potentially racy.
    #[tokio::test]
    async fn test_process_batch_without_waiting() {
        let mut server = mockito::Server::new_async().await;
        let dir = tempdir().unwrap();
        let file_path1 = dir.path().join("test1.jpg");
        let mut file1 = File::create(&file_path1).unwrap();
        file1.write_all(b"test image data 1").unwrap();
        let file_path2 = dir.path().join("test2.jpg");
        let mut file2 = File::create(&file_path2).unwrap();
        file2.write_all(b"test image data 2").unwrap();
        // Presign step: returns a request id and a signed URL on this server.
        let mock1 = server
            .mock("POST", "/api/files/aws-presigned")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "code": "success",
                    "errno": 0,
                    "requestId": "test-request-id-1",
                    "mediaId": "test-media-id-1",
                    "response": {
                        "signedUrl": format!("{}/upload1", server.url())
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;
        let mock2 = server
            .mock("POST", "/api/files/aws-presigned")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "code": "success",
                    "errno": 0,
                    "requestId": "test-request-id-2",
                    "mediaId": "test-media-id-2",
                    "response": {
                        "signedUrl": format!("{}/upload2", server.url())
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;
        // Upload step: PUT of the file bytes to each signed URL.
        let mock_upload1 = server
            .mock("PUT", "/upload1")
            .with_status(200)
            .create_async()
            .await;
        let mock_upload2 = server
            .mock("PUT", "/upload2")
            .with_status(200)
            .create_async()
            .await;
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            base_url: Some(server.url()),
            ..Default::default()
        })
        .unwrap();
        let file_paths = vec![file_path1.to_str().unwrap(), file_path2.to_str().unwrap()];
        let result = client
            .process_batch(
                file_paths,
                BatchOptions {
                    max_concurrency: Some(2),
                    max_attempts: None,
                    polling_interval: None,
                },
            )
            .await;
        assert!(result.is_ok());
        let batch_results = result.unwrap();
        assert_eq!(batch_results.len(), 2);
        assert_eq!(batch_results[0].request_id, "test-request-id-1");
        assert_eq!(batch_results[0].status, "PROCESSING");
        assert_eq!(batch_results[1].request_id, "test-request-id-2");
        assert_eq!(batch_results[1].status, "PROCESSING");
        mock1.assert_async().await;
        mock2.assert_async().await;
        mock_upload1.assert_async().await;
        mock_upload2.assert_async().await;
    }

    // End-to-end: presign, upload, then a result that is already COMPLETED on
    // the first poll (so `detect_file` returns without sleeping).
    #[tokio::test]
    async fn test_detect_file() {
        let mut server = mockito::Server::new_async().await;
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.jpg");
        let mut file = File::create(&file_path).unwrap();
        file.write_all(b"test image data").unwrap();
        let mock1 = server
            .mock("POST", "/api/files/aws-presigned")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "code": "success",
                    "errno": 0,
                    "requestId": "test-request-id",
                    "mediaId": "test-media-id",
                    "response": {
                        "signedUrl": format!("{}/upload", server.url())
                    }
                })
                .to_string(),
            )
            .create_async()
            .await;
        let mock2 = server
            .mock("PUT", "/upload")
            .with_status(200)
            .create_async()
            .await;
        let mock3 = server
            .mock("GET", "/api/media/users/test-request-id")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                json!({
                    "requestId": "test-request-id",
                    "overallStatus": "COMPLETED",
                    "finalScore": 75,
                    "models": []
                })
                .to_string(),
            )
            .create_async()
            .await;
        let client = Client::new(Config {
            api_key: "test_api_key".to_string(),
            base_url: Some(server.url()),
            ..Default::default()
        })
        .unwrap();
        let result = client.detect_file(file_path.to_str().unwrap()).await;
        assert!(result.is_ok());
        let analysis_result = result.unwrap();
        assert_eq!(analysis_result.request_id, "test-request-id");
        assert_eq!(analysis_result.status, "COMPLETED");
        assert_eq!(analysis_result.score, Some(0.75));
        mock1.assert_async().await;
        mock2.assert_async().await;
        mock3.assert_async().await;
    }
}