1use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use std::collections::HashMap;
53use std::sync::Arc;
54use thiserror::Error;
55use tokio::sync::RwLock;
56use voirs_sdk::{AudioBuffer, VoirsError};
57
58use crate::quality::QualityEvaluator;
59use crate::traits::{
60 QualityEvaluationConfig, QualityEvaluator as QualityEvaluatorTrait, QualityScore,
61};
62
/// Errors produced by the GraphQL evaluation service.
///
/// Covers request-level failures (bad input, syntax, authorization, missing
/// resources) as well as errors bubbled up from the VoiRS SDK, the evaluation
/// crate, and `serde_json`, each of which converts automatically via `#[from]`.
#[derive(Error, Debug)]
pub enum GraphQLError {
    /// A request failed while executing, e.g. because its payload was invalid.
    #[error("Query execution error: {message}")]
    QueryError {
        /// Human-readable description of the failure.
        message: String,
    },

    /// The query text is syntactically invalid.
    #[error("Invalid query syntax: {message}")]
    SyntaxError {
        /// Human-readable description of the syntax problem.
        message: String,
    },

    /// The caller is not authorized to perform the request.
    #[error("Authorization error: {message}")]
    AuthError {
        /// Human-readable description of the authorization failure.
        message: String,
    },

    /// No resource of `resource_type` exists with the given `id`.
    #[error("Resource not found: {resource_type} with id '{id}'")]
    NotFound {
        /// Kind of resource that was looked up (e.g. `"Evaluation"`).
        resource_type: String,
        /// Identifier that was not found.
        id: String,
    },

    /// Error reported by the VoiRS SDK.
    #[error("VoiRS error: {0}")]
    VoirsError(#[from] VoirsError),

    /// Error reported by the evaluation crate.
    #[error("Evaluation error: {0}")]
    EvaluationError(#[from] crate::EvaluationError),

    /// JSON (de)serialization failure.
    #[error("Serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),
}
108
/// A single stored quality evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationResult {
    /// Unique identifier of the evaluation.
    pub id: String,
    /// Overall quality score reported by the evaluator.
    pub quality_score: f32,
    /// Lifecycle state of the evaluation.
    pub status: EvaluationStatus,
    /// Creation time, in seconds since the Unix epoch.
    pub timestamp: u64,
    /// Detailed per-metric results, when available.
    pub metrics: Option<MetricsData>,
    /// Dataset the evaluated audio belongs to, when known.
    pub dataset: Option<DatasetInfo>,
    /// Model that produced the evaluated audio, when known.
    pub model: Option<ModelInfo>,
}
127
128#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
130pub enum EvaluationStatus {
131 Pending,
133 Running,
135 Completed,
137 Failed,
139 Cancelled,
141}
142
/// Individual metric values attached to an evaluation; `None` means "not computed".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsData {
    /// PESQ (Perceptual Evaluation of Speech Quality) score.
    pub pesq: Option<f32>,
    /// STOI (Short-Time Objective Intelligibility) score.
    pub stoi: Option<f32>,
    /// MCD (Mel-Cepstral Distortion) value.
    pub mcd: Option<f32>,
    /// Additional metrics keyed by name.
    pub custom: HashMap<String, f32>,
}
155
/// Descriptive information about the dataset an evaluation belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Dataset identifier.
    pub id: String,
    /// Human-readable dataset name.
    pub name: String,
    /// Number of samples in the dataset.
    pub sample_count: usize,
    /// Language tag of the dataset (e.g. `"en-US"`).
    pub language: String,
}
168
/// Descriptive information about the model whose output was evaluated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Model version string (e.g. `"1.0.0"`).
    pub version: String,
    /// Model architecture name (e.g. `"VITS"`).
    pub architecture: String,
}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct EvaluationFilter {
185 pub min_quality: Option<f32>,
187 pub max_quality: Option<f32>,
189 pub language: Option<String>,
191 pub status: Option<EvaluationStatus>,
193 pub dataset_id: Option<String>,
195 pub model_id: Option<String>,
197 pub from_timestamp: Option<u64>,
199 pub to_timestamp: Option<u64>,
201}
202
/// Paging and sorting options for list queries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pagination {
    /// Maximum number of items to return (100 when unset).
    pub limit: Option<usize>,
    /// Number of matching items to skip (0 when unset).
    pub offset: Option<usize>,
    /// Sort key; `"quality_score"` and `"timestamp"` are recognized.
    pub sort_by: Option<String>,
    /// Sort in descending order when `Some(true)`.
    pub sort_desc: Option<bool>,
}
215
216impl Default for Pagination {
217 fn default() -> Self {
218 Self {
219 limit: Some(100),
220 offset: Some(0),
221 sort_by: None,
222 sort_desc: Some(false),
223 }
224 }
225}
226
/// One page of results from a list query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginatedResults<T> {
    /// Items on this page.
    pub items: Vec<T>,
    /// Number of items matching the query before pagination.
    pub total_count: usize,
    /// Whether more matching items exist after this page.
    pub has_next_page: bool,
    /// Offset/limit/total describing this page.
    pub page_info: PageInfo,
}
239
/// Position of a page within the full result set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PageInfo {
    /// Number of matching items skipped before this page.
    pub offset: usize,
    /// Maximum number of items requested for this page.
    pub limit: usize,
    /// Number of matching items before pagination.
    pub total: usize,
}
250
/// Request to evaluate a single audio clip.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvaluationInput {
    /// Base64-encoded audio payload.
    pub audio_data: String,
    /// Optional reference sample id (not read by `GraphQLService::evaluate_audio`).
    pub reference_id: Option<String>,
    /// Language tag recorded on the dataset info, if a dataset id is given.
    pub language: Option<String>,
    /// Dataset the audio belongs to.
    pub dataset_id: Option<String>,
    /// Model that produced the audio.
    pub model_id: Option<String>,
    /// Free-form extra parameters (not read by `GraphQLService::evaluate_audio`).
    pub parameters: Option<HashMap<String, serde_json::Value>>,
}
267
/// Request to evaluate several audio clips in one call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchEvaluationInput {
    /// Individual evaluation requests, processed in order.
    pub evaluations: Vec<EvaluationInput>,
    /// Requested degree of parallelism.
    /// NOTE(review): `GraphQLService::evaluate_batch` currently ignores this and runs
    /// the inputs sequentially.
    pub parallel: Option<usize>,
}
276
/// Summary statistics over the quality scores of a set of evaluations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregationResult {
    /// Mean quality score.
    pub avg_quality: f32,
    /// Lowest quality score.
    pub min_quality: f32,
    /// Highest quality score.
    pub max_quality: f32,
    /// Population standard deviation of the quality scores.
    pub std_dev: f32,
    /// Number of evaluations aggregated.
    pub count: usize,
    /// Percentile values keyed `"p25"`, `"p50"`, `"p75"`, `"p90"`, `"p95"`, `"p99"`.
    pub percentiles: HashMap<String, f32>,
}
293
/// Operations exposed by the evaluation GraphQL API.
#[async_trait]
pub trait GraphQLSchema: Send + Sync {
    /// Lists evaluations matching `filter` (all when `None`), sorted and paginated
    /// according to `pagination`.
    async fn query_evaluations(
        &self,
        filter: Option<EvaluationFilter>,
        pagination: Option<Pagination>,
    ) -> Result<PaginatedResults<EvaluationResult>, GraphQLError>;

    /// Fetches a single evaluation by id.
    /// `GraphQLService` returns `GraphQLError::NotFound` when the id is unknown.
    async fn get_evaluation(&self, id: &str) -> Result<EvaluationResult, GraphQLError>;

    /// Evaluates one audio input and returns the resulting evaluation.
    async fn evaluate_audio(
        &self,
        input: EvaluationInput,
    ) -> Result<EvaluationResult, GraphQLError>;

    /// Evaluates several inputs and returns their results in input order.
    async fn evaluate_batch(
        &self,
        input: BatchEvaluationInput,
    ) -> Result<Vec<EvaluationResult>, GraphQLError>;

    /// Computes summary statistics over evaluations matching `filter` (all when `None`).
    async fn get_aggregations(
        &self,
        filter: Option<EvaluationFilter>,
    ) -> Result<AggregationResult, GraphQLError>;

    /// Attempts to cancel an evaluation; `Ok(true)` if it was cancelled, `Ok(false)`
    /// if it was not in a cancellable state.
    async fn cancel_evaluation(&self, id: &str) -> Result<bool, GraphQLError>;

    /// Deletes an evaluation; `Ok(true)` if it existed and was removed.
    async fn delete_evaluation(&self, id: &str) -> Result<bool, GraphQLError>;
}
331
/// In-memory implementation of [`GraphQLSchema`] backed by a [`QualityEvaluator`].
pub struct GraphQLService {
    /// Shared quality evaluator used to score submitted audio.
    evaluator: Arc<RwLock<QualityEvaluator>>,
    /// Stored evaluations keyed by evaluation id (kept in memory only).
    evaluations: Arc<RwLock<HashMap<String, EvaluationResult>>>,
}
337
338impl GraphQLService {
339 pub async fn new() -> Result<Self, GraphQLError> {
341 let evaluator = QualityEvaluator::new().await?;
342
343 Ok(Self {
344 evaluator: Arc::new(RwLock::new(evaluator)),
345 evaluations: Arc::new(RwLock::new(HashMap::new())),
346 })
347 }
348
349 fn apply_filter(
351 &self,
352 results: &[EvaluationResult],
353 filter: &EvaluationFilter,
354 ) -> Vec<EvaluationResult> {
355 results
356 .iter()
357 .filter(|eval| {
358 if let Some(min_q) = filter.min_quality {
360 if eval.quality_score < min_q {
361 return false;
362 }
363 }
364
365 if let Some(max_q) = filter.max_quality {
367 if eval.quality_score > max_q {
368 return false;
369 }
370 }
371
372 if let Some(ref status) = filter.status {
374 if &eval.status != status {
375 return false;
376 }
377 }
378
379 if let Some(ref lang) = filter.language {
381 if let Some(ref dataset) = eval.dataset {
382 if &dataset.language != lang {
383 return false;
384 }
385 } else {
386 return false;
387 }
388 }
389
390 if let Some(ref dataset_id) = filter.dataset_id {
392 if let Some(ref dataset) = eval.dataset {
393 if &dataset.id != dataset_id {
394 return false;
395 }
396 } else {
397 return false;
398 }
399 }
400
401 if let Some(from_ts) = filter.from_timestamp {
403 if eval.timestamp < from_ts {
404 return false;
405 }
406 }
407
408 if let Some(to_ts) = filter.to_timestamp {
409 if eval.timestamp > to_ts {
410 return false;
411 }
412 }
413
414 true
415 })
416 .cloned()
417 .collect()
418 }
419
420 fn apply_pagination(
422 &self,
423 mut results: Vec<EvaluationResult>,
424 pagination: &Pagination,
425 ) -> (Vec<EvaluationResult>, PageInfo) {
426 let total = results.len();
427
428 if let Some(ref sort_by) = pagination.sort_by {
430 let sort_desc = pagination.sort_desc.unwrap_or(false);
431 match sort_by.as_str() {
432 "quality_score" => {
433 results.sort_by(|a, b| {
434 if sort_desc {
435 b.quality_score.partial_cmp(&a.quality_score).unwrap()
436 } else {
437 a.quality_score.partial_cmp(&b.quality_score).unwrap()
438 }
439 });
440 }
441 "timestamp" => {
442 results.sort_by(|a, b| {
443 if sort_desc {
444 b.timestamp.cmp(&a.timestamp)
445 } else {
446 a.timestamp.cmp(&b.timestamp)
447 }
448 });
449 }
450 _ => {}
451 }
452 }
453
454 let offset = pagination.offset.unwrap_or(0);
456 let limit = pagination.limit.unwrap_or(100);
457
458 let items: Vec<EvaluationResult> = results.into_iter().skip(offset).take(limit).collect();
459
460 let page_info = PageInfo {
461 offset,
462 limit,
463 total,
464 };
465
466 (items, page_info)
467 }
468
469 fn calculate_aggregations(
471 &self,
472 results: &[EvaluationResult],
473 ) -> Result<AggregationResult, GraphQLError> {
474 if results.is_empty() {
475 return Ok(AggregationResult {
476 avg_quality: 0.0,
477 min_quality: 0.0,
478 max_quality: 0.0,
479 std_dev: 0.0,
480 count: 0,
481 percentiles: HashMap::new(),
482 });
483 }
484
485 let scores: Vec<f32> = results.iter().map(|r| r.quality_score).collect();
486
487 let avg_quality = scores.iter().sum::<f32>() / scores.len() as f32;
488 let min_quality = scores.iter().fold(f32::INFINITY, |a, &b| a.min(b));
489 let max_quality = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
490
491 let variance = scores
492 .iter()
493 .map(|s| (s - avg_quality).powi(2))
494 .sum::<f32>()
495 / scores.len() as f32;
496 let std_dev = variance.sqrt();
497
498 let mut sorted_scores = scores.clone();
500 sorted_scores.sort_by(|a, b| a.partial_cmp(b).unwrap());
501
502 let mut percentiles = HashMap::new();
503 for p in [25, 50, 75, 90, 95, 99] {
504 let idx = ((p as f32 / 100.0) * sorted_scores.len() as f32) as usize;
505 let idx = idx.min(sorted_scores.len() - 1);
506 percentiles.insert(format!("p{}", p), sorted_scores[idx]);
507 }
508
509 Ok(AggregationResult {
510 avg_quality,
511 min_quality,
512 max_quality,
513 std_dev,
514 count: results.len(),
515 percentiles,
516 })
517 }
518}
519
520#[async_trait]
521impl GraphQLSchema for GraphQLService {
522 async fn query_evaluations(
523 &self,
524 filter: Option<EvaluationFilter>,
525 pagination: Option<Pagination>,
526 ) -> Result<PaginatedResults<EvaluationResult>, GraphQLError> {
527 let evaluations = self.evaluations.read().await;
528 let all_results: Vec<EvaluationResult> = evaluations.values().cloned().collect();
529
530 let filtered_results = if let Some(f) = filter {
532 self.apply_filter(&all_results, &f)
533 } else {
534 all_results
535 };
536
537 let pagination = pagination.unwrap_or_default();
539 let (items, page_info) = self.apply_pagination(filtered_results, &pagination);
540
541 let has_next_page = page_info.offset + items.len() < page_info.total;
542
543 Ok(PaginatedResults {
544 items,
545 total_count: page_info.total,
546 has_next_page,
547 page_info,
548 })
549 }
550
551 async fn get_evaluation(&self, id: &str) -> Result<EvaluationResult, GraphQLError> {
552 let evaluations = self.evaluations.read().await;
553 evaluations
554 .get(id)
555 .cloned()
556 .ok_or_else(|| GraphQLError::NotFound {
557 resource_type: "Evaluation".to_string(),
558 id: id.to_string(),
559 })
560 }
561
562 async fn evaluate_audio(
563 &self,
564 input: EvaluationInput,
565 ) -> Result<EvaluationResult, GraphQLError> {
566 let audio_bytes =
568 base64::decode(&input.audio_data).map_err(|e| GraphQLError::QueryError {
569 message: format!("Invalid base64 audio data: {}", e),
570 })?;
571
572 let audio = AudioBuffer::new(vec![0.1; 16000], 16000, 1);
574
575 let evaluator = self.evaluator.read().await;
577 let config = QualityEvaluationConfig::default();
578 let quality = evaluator
579 .evaluate_quality(&audio, None, Some(&config))
580 .await?;
581
582 let id = uuid::Uuid::new_v4().to_string();
584 let timestamp = std::time::SystemTime::now()
585 .duration_since(std::time::UNIX_EPOCH)
586 .unwrap()
587 .as_secs();
588
589 let result = EvaluationResult {
590 id: id.clone(),
591 quality_score: quality.overall_score,
592 status: EvaluationStatus::Completed,
593 timestamp,
594 metrics: Some(MetricsData {
595 pesq: None,
596 stoi: None,
597 mcd: None,
598 custom: HashMap::new(),
599 }),
600 dataset: input.dataset_id.map(|dataset_id| DatasetInfo {
601 id: dataset_id,
602 name: "Sample Dataset".to_string(),
603 sample_count: 100,
604 language: input.language.unwrap_or_else(|| "en-US".to_string()),
605 }),
606 model: input.model_id.map(|model_id| ModelInfo {
607 id: model_id,
608 name: "Sample Model".to_string(),
609 version: "1.0.0".to_string(),
610 architecture: "VITS".to_string(),
611 }),
612 };
613
614 let mut evaluations = self.evaluations.write().await;
616 evaluations.insert(id, result.clone());
617
618 Ok(result)
619 }
620
621 async fn evaluate_batch(
622 &self,
623 input: BatchEvaluationInput,
624 ) -> Result<Vec<EvaluationResult>, GraphQLError> {
625 let mut results = Vec::new();
626
627 for eval_input in input.evaluations {
628 let result = self.evaluate_audio(eval_input).await?;
629 results.push(result);
630 }
631
632 Ok(results)
633 }
634
635 async fn get_aggregations(
636 &self,
637 filter: Option<EvaluationFilter>,
638 ) -> Result<AggregationResult, GraphQLError> {
639 let evaluations = self.evaluations.read().await;
640 let all_results: Vec<EvaluationResult> = evaluations.values().cloned().collect();
641
642 let filtered_results = if let Some(f) = filter {
644 self.apply_filter(&all_results, &f)
645 } else {
646 all_results
647 };
648
649 self.calculate_aggregations(&filtered_results)
650 }
651
652 async fn cancel_evaluation(&self, id: &str) -> Result<bool, GraphQLError> {
653 let mut evaluations = self.evaluations.write().await;
654 if let Some(eval) = evaluations.get_mut(id) {
655 if eval.status == EvaluationStatus::Pending || eval.status == EvaluationStatus::Running
656 {
657 eval.status = EvaluationStatus::Cancelled;
658 Ok(true)
659 } else {
660 Ok(false)
661 }
662 } else {
663 Err(GraphQLError::NotFound {
664 resource_type: "Evaluation".to_string(),
665 id: id.to_string(),
666 })
667 }
668 }
669
670 async fn delete_evaluation(&self, id: &str) -> Result<bool, GraphQLError> {
671 let mut evaluations = self.evaluations.write().await;
672 Ok(evaluations.remove(id).is_some())
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`EvaluationInput`] whose audio is `audio` base64-encoded, with the
    /// given optional language/dataset/model ids and no extra parameters.
    fn make_input(
        audio: &str,
        language: Option<&str>,
        dataset_id: Option<&str>,
        model_id: Option<&str>,
    ) -> EvaluationInput {
        EvaluationInput {
            audio_data: base64::encode(audio),
            reference_id: None,
            language: language.map(str::to_string),
            dataset_id: dataset_id.map(str::to_string),
            model_id: model_id.map(str::to_string),
            parameters: None,
        }
    }

    /// Input with audio `"test"` and no optional metadata.
    fn plain_input() -> EvaluationInput {
        make_input("test", None, None, None)
    }

    /// Creates a service and stores `count` plain evaluations in it.
    async fn service_with(count: usize) -> GraphQLService {
        let service = GraphQLService::new().await.unwrap();
        for _ in 0..count {
            service.evaluate_audio(plain_input()).await.unwrap();
        }
        service
    }

    #[tokio::test]
    async fn test_graphql_service_creation() {
        assert!(GraphQLService::new().await.is_ok());
    }

    #[tokio::test]
    async fn test_query_evaluations_empty() {
        let service = service_with(0).await;
        let page = service.query_evaluations(None, None).await.unwrap();
        assert_eq!(page.total_count, 0);
        assert!(page.items.is_empty());
    }

    #[tokio::test]
    async fn test_evaluate_audio() {
        let service = service_with(0).await;
        let request = make_input(
            "test_audio_data",
            Some("en-US"),
            Some("dataset1"),
            Some("model1"),
        );

        let outcome = service.evaluate_audio(request).await;
        assert!(outcome.is_ok());

        let evaluation = outcome.unwrap();
        assert_eq!(evaluation.status, EvaluationStatus::Completed);
        assert!(evaluation.dataset.is_some());
        assert!(evaluation.model.is_some());
    }

    #[tokio::test]
    async fn test_get_evaluation() {
        let service = service_with(0).await;
        let request = make_input("test", Some("en-US"), None, None);

        let stored = service.evaluate_audio(request).await.unwrap();
        let fetched = service.get_evaluation(&stored.id).await.unwrap();

        assert_eq!(stored.id, fetched.id);
        assert_eq!(stored.quality_score, fetched.quality_score);
    }

    #[tokio::test]
    async fn test_pagination() {
        let service = service_with(5).await;
        let first_page = Pagination {
            limit: Some(2),
            offset: Some(0),
            sort_by: None,
            sort_desc: None,
        };

        let page = service
            .query_evaluations(None, Some(first_page))
            .await
            .unwrap();
        assert_eq!(page.items.len(), 2);
        assert_eq!(page.total_count, 5);
        assert!(page.has_next_page);
    }

    #[tokio::test]
    async fn test_filter_by_quality() {
        let service = service_with(3).await;
        let quality_window = EvaluationFilter {
            min_quality: Some(0.0),
            max_quality: Some(5.0),
            language: None,
            status: None,
            dataset_id: None,
            model_id: None,
            from_timestamp: None,
            to_timestamp: None,
        };

        let page = service
            .query_evaluations(Some(quality_window), None)
            .await
            .unwrap();
        assert_eq!(page.total_count, 3);
    }

    #[tokio::test]
    async fn test_aggregations() {
        let service = service_with(10).await;

        let stats = service.get_aggregations(None).await.unwrap();
        assert_eq!(stats.count, 10);
        assert!(stats.avg_quality >= 0.0);
        assert!(stats.std_dev >= 0.0);
        assert!(!stats.percentiles.is_empty());
    }

    #[tokio::test]
    async fn test_cancel_evaluation() {
        let service = service_with(0).await;
        let evaluation = service.evaluate_audio(plain_input()).await.unwrap();

        // Completed evaluations are not cancellable.
        let was_cancelled = service.cancel_evaluation(&evaluation.id).await.unwrap();
        assert!(!was_cancelled);
    }

    #[tokio::test]
    async fn test_delete_evaluation() {
        let service = service_with(0).await;
        let evaluation = service.evaluate_audio(plain_input()).await.unwrap();

        let removed = service.delete_evaluation(&evaluation.id).await.unwrap();
        assert!(removed);

        let lookup = service.get_evaluation(&evaluation.id).await;
        assert!(lookup.is_err());
    }
}