1use crate::{CacheManager, ModelRegistry};
8use async_graphql::{
9 Context, Enum, FieldResult, InputObject, Object, Schema, SimpleObject, Subscription, Union, ID,
10};
11use chrono::Utc;
12use futures_util::Stream;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::pin::Pin;
16use std::sync::Arc;
17use tokio::sync::RwLock;
18use tokio_stream::{wrappers::BroadcastStream, StreamExt};
19use uuid::Uuid;
20
21pub type EmbeddingSchema = Schema<QueryRoot, MutationRoot, SubscriptionRoot>;
23
24pub struct QueryRoot;
26
27pub struct MutationRoot;
29
30pub struct SubscriptionRoot;
32
33pub struct GraphQLContext {
35 pub model_registry: Arc<ModelRegistry>,
36 pub cache_manager: Arc<CacheManager>,
37 pub event_broadcaster: Arc<RwLock<tokio::sync::broadcast::Sender<EmbeddingEvent>>>,
38}
39
40impl GraphQLContext {
41 pub fn new(model_registry: Arc<ModelRegistry>, cache_manager: Arc<CacheManager>) -> Self {
42 let (tx, _) = tokio::sync::broadcast::channel(1000);
43 Self {
44 model_registry,
45 cache_manager,
46 event_broadcaster: Arc::new(RwLock::new(tx)),
47 }
48 }
49}
50
51#[derive(SimpleObject)]
53pub struct EmbeddingResult {
54 pub entity_id: String,
55 pub embedding: Vec<f32>,
56 pub dimensions: i32,
57 pub model_name: String,
58 pub confidence: Option<f64>,
59 pub metadata: Option<HashMap<String, String>>,
60 pub timestamp: String, }
62
63#[derive(SimpleObject)]
65pub struct SimilarityResult {
66 pub entity_id: String,
67 pub similarity_score: f64,
68 pub embedding: Option<Vec<f32>>,
69 pub metadata: Option<HashMap<String, String>>,
70 pub distance_metric: String,
71}
72
73#[derive(SimpleObject)]
75pub struct BatchEmbeddingResult {
76 pub job_id: ID,
77 pub status: BatchStatus,
78 pub progress: f64,
79 pub total_entities: i32,
80 pub processed_entities: i32,
81 pub estimated_completion: Option<String>,
82 pub results: Vec<EmbeddingResult>,
83 pub errors: Vec<String>,
84}
85
86#[derive(SimpleObject)]
88pub struct ModelInfo {
89 pub id: ID,
90 pub name: String,
91 pub version: String,
92 pub model_type: ModelType,
93 pub dimensions: i32,
94 pub parameters: HashMap<String, String>,
95 pub performance_metrics: Option<PerformanceMetrics>,
96 pub created_at: String,
97 pub updated_at: String,
98}
99
100#[derive(SimpleObject)]
102pub struct PerformanceMetrics {
103 pub inference_latency_ms: f64,
104 pub throughput_embeddings_per_sec: f64,
105 pub memory_usage_mb: f64,
106 pub accuracy_score: Option<f64>,
107 pub quality_metrics: HashMap<String, f64>,
108}
109
110#[derive(SimpleObject)]
112pub struct AggregationResult {
113 pub field: String,
114 pub aggregation_type: AggregationType,
115 pub value: f64,
116 pub count: i32,
117 pub metadata: HashMap<String, String>,
118}
119
120#[derive(SimpleObject)]
122pub struct ClusteringResult {
123 pub cluster_id: i32,
124 pub centroid: Vec<f32>,
125 pub entities: Vec<String>,
126 pub cohesion_score: f64,
127 pub metadata: HashMap<String, String>,
128}
129
130#[derive(SimpleObject)]
132pub struct EmbeddingAnalytics {
133 pub total_embeddings: i32,
134 pub dimensions_distribution: Vec<DimensionStat>,
135 pub model_usage: Vec<ModelUsageStat>,
136 pub quality_trends: Vec<QualityTrend>,
137 pub performance_summary: PerformanceMetrics,
138 pub cache_statistics: CacheStats,
139}
140
141#[derive(SimpleObject)]
143pub struct DimensionStat {
144 pub dimensions: i32,
145 pub count: i32,
146 pub percentage: f64,
147}
148
149#[derive(SimpleObject)]
151pub struct ModelUsageStat {
152 pub model_name: String,
153 pub usage_count: i32,
154 pub success_rate: f64,
155 pub average_latency_ms: f64,
156}
157
158#[derive(SimpleObject)]
160pub struct QualityTrend {
161 pub timestamp: String,
162 pub quality_score: f64,
163 pub metric_name: String,
164}
165
166#[derive(SimpleObject)]
168pub struct CacheStats {
169 pub hit_rate: f64,
170 pub total_requests: i32,
171 pub cache_size_mb: f64,
172 pub evictions: i32,
173}
174
175#[derive(InputObject)]
178pub struct EmbeddingQueryInput {
179 pub entity_ids: Option<Vec<String>>,
180 pub model_name: Option<String>,
181 pub include_metadata: Option<bool>,
182 pub format: Option<EmbeddingFormat>,
183 pub filters: Option<EmbeddingFilters>,
184}
185
186#[derive(InputObject)]
188pub struct SimilaritySearchInput {
189 pub query_embedding: Option<Vec<f32>>,
190 pub query_entity_id: Option<String>,
191 pub model_name: String,
192 pub top_k: Option<i32>,
193 pub threshold: Option<f64>,
194 pub distance_metric: Option<DistanceMetric>,
195 pub filters: Option<SimilarityFilters>,
196}
197
198#[derive(InputObject)]
200pub struct BatchEmbeddingInput {
201 pub entity_ids: Vec<String>,
202 pub model_name: String,
203 pub chunk_size: Option<i32>,
204 pub priority: Option<BatchPriority>,
205 pub callback_url: Option<String>,
206 pub metadata: Option<HashMap<String, String>>,
207}
208
209#[derive(InputObject)]
211pub struct EmbeddingFilters {
212 pub dimensions: Option<IntRange>,
213 pub confidence: Option<FloatRange>,
214 pub created_after: Option<String>,
215 pub created_before: Option<String>,
216 pub has_metadata: Option<bool>,
217 pub metadata_filters: Option<HashMap<String, String>>,
218}
219
220#[derive(InputObject)]
222pub struct SimilarityFilters {
223 pub entity_types: Option<Vec<String>>,
224 pub exclude_entities: Option<Vec<String>>,
225 pub metadata_filters: Option<HashMap<String, String>>,
226 pub confidence_threshold: Option<f64>,
227}
228
229#[derive(InputObject)]
231pub struct AggregationInput {
232 pub field: String,
233 pub aggregation_type: AggregationType,
234 pub group_by: Option<Vec<String>>,
235 pub filters: Option<EmbeddingFilters>,
236}
237
238#[derive(InputObject)]
240pub struct ClusteringInput {
241 pub entity_ids: Option<Vec<String>>,
242 pub model_name: String,
243 pub num_clusters: Option<i32>,
244 pub algorithm: Option<ClusteringAlgorithm>,
245 pub distance_metric: Option<DistanceMetric>,
246}
247
248#[derive(InputObject)]
250pub struct TimeRange {
251 pub start: String,
252 pub end: String,
253}
254
255#[derive(InputObject)]
257pub struct IntRange {
258 pub min: Option<i32>,
259 pub max: Option<i32>,
260}
261
262#[derive(InputObject)]
263pub struct FloatRange {
264 pub min: Option<f64>,
265 pub max: Option<f64>,
266}
267
268#[derive(Enum, Copy, Clone, Eq, PartialEq)]
271pub enum ModelType {
272 Transformer,
273 TransE,
274 DistMult,
275 ComplEx,
276 RotatE,
277 QuatE,
278 GNN,
279 Custom,
280}
281
282#[derive(Enum, Copy, Clone, Eq, PartialEq)]
283pub enum EmbeddingFormat {
284 Dense,
285 Sparse,
286 Compressed,
287 Quantized,
288}
289
290#[derive(Enum, Copy, Clone, Eq, PartialEq)]
291pub enum DistanceMetric {
292 Cosine,
293 Euclidean,
294 Manhattan,
295 Jaccard,
296 Hamming,
297}
298
299#[derive(Enum, Copy, Clone, Eq, PartialEq, Serialize, Deserialize)]
300pub enum BatchStatus {
301 Pending,
302 Running,
303 Completed,
304 Failed,
305 Cancelled,
306}
307
308#[derive(Enum, Copy, Clone, Eq, PartialEq)]
309pub enum BatchPriority {
310 Low,
311 Normal,
312 High,
313 Critical,
314}
315
316#[derive(Enum, Copy, Clone, Eq, PartialEq)]
317pub enum AggregationType {
318 Count,
319 Sum,
320 Average,
321 Min,
322 Max,
323 StdDev,
324 Percentile,
325}
326
327#[derive(Enum, Copy, Clone, Eq, PartialEq)]
328pub enum ClusteringAlgorithm {
329 KMeans,
330 DBSCAN,
331 Hierarchical,
332 SpectralClustering,
333}
334
335#[derive(Clone, Serialize, Deserialize, Union)]
337pub enum EmbeddingEvent {
338 EmbeddingGenerated(EmbeddingGeneratedEvent),
339 BatchCompleted(BatchCompletedEvent),
340 ModelUpdated(ModelUpdatedEvent),
341 QualityAlert(QualityAlertEvent),
342}
343
344#[derive(Clone, Serialize, Deserialize, SimpleObject)]
345pub struct EmbeddingGeneratedEvent {
346 pub entity_id: String,
347 pub model_name: String,
348 pub timestamp: String,
349 pub quality_score: Option<f64>,
350}
351
352#[derive(Clone, Serialize, Deserialize, SimpleObject)]
353pub struct BatchCompletedEvent {
354 pub job_id: String,
355 pub status: BatchStatus,
356 pub processed_count: i32,
357 pub error_count: i32,
358 pub completion_time: String,
359}
360
361#[derive(Clone, Serialize, Deserialize, SimpleObject)]
362pub struct ModelUpdatedEvent {
363 pub model_name: String,
364 pub version: String,
365 pub update_type: String,
366 pub timestamp: String,
367}
368
369#[derive(Clone, Serialize, Deserialize, SimpleObject)]
370pub struct QualityAlertEvent {
371 pub alert_type: String,
372 pub severity: String,
373 pub message: String,
374 pub affected_entities: Vec<String>,
375 pub timestamp: String,
376}
377
378#[Object]
381impl QueryRoot {
382 async fn embeddings(
384 &self,
385 ctx: &Context<'_>,
386 input: EmbeddingQueryInput,
387 ) -> FieldResult<Vec<EmbeddingResult>> {
388 let _context = ctx.data::<GraphQLContext>()?;
389
390 let mut results = Vec::new();
392
393 if let Some(entity_ids) = input.entity_ids {
394 for entity_id in entity_ids {
395 results.push(EmbeddingResult {
397 entity_id: entity_id.clone(),
398 embedding: vec![0.1, 0.2, 0.3], dimensions: 3,
400 model_name: input
401 .model_name
402 .clone()
403 .unwrap_or_else(|| "default".to_string()),
404 confidence: Some(0.95),
405 metadata: None,
406 timestamp: Utc::now().to_rfc3339(),
407 });
408 }
409 }
410
411 Ok(results)
412 }
413
414 async fn similarity_search(
416 &self,
417 ctx: &Context<'_>,
418 _input: SimilaritySearchInput,
419 ) -> FieldResult<Vec<SimilarityResult>> {
420 let _context = ctx.data::<GraphQLContext>()?;
421
422 let results = vec![SimilarityResult {
424 entity_id: "similar_entity_1".to_string(),
425 similarity_score: 0.92,
426 embedding: Some(vec![0.1, 0.2, 0.3]),
427 metadata: None,
428 distance_metric: "cosine".to_string(),
429 }];
430
431 Ok(results)
432 }
433
434 async fn models(
436 &self,
437 ctx: &Context<'_>,
438 _names: Option<Vec<String>>,
439 ) -> FieldResult<Vec<ModelInfo>> {
440 let _context = ctx.data::<GraphQLContext>()?;
441
442 let models = vec![ModelInfo {
444 id: ID::from("model_1"),
445 name: "TransE".to_string(),
446 version: "1.0.0".to_string(),
447 model_type: ModelType::TransE,
448 dimensions: 128,
449 parameters: HashMap::new(),
450 performance_metrics: None,
451 created_at: Utc::now().to_rfc3339(),
452 updated_at: Utc::now().to_rfc3339(),
453 }];
454
455 Ok(models)
456 }
457
458 async fn aggregation(
460 &self,
461 ctx: &Context<'_>,
462 input: AggregationInput,
463 ) -> FieldResult<AggregationResult> {
464 let _context = ctx.data::<GraphQLContext>()?;
465
466 Ok(AggregationResult {
468 field: input.field,
469 aggregation_type: input.aggregation_type,
470 value: 42.0,
471 count: 100,
472 metadata: HashMap::new(),
473 })
474 }
475
476 async fn clustering(
478 &self,
479 ctx: &Context<'_>,
480 _input: ClusteringInput,
481 ) -> FieldResult<Vec<ClusteringResult>> {
482 let _context = ctx.data::<GraphQLContext>()?;
483
484 let results = vec![ClusteringResult {
486 cluster_id: 0,
487 centroid: vec![0.1, 0.2, 0.3],
488 entities: vec!["entity1".to_string(), "entity2".to_string()],
489 cohesion_score: 0.85,
490 metadata: HashMap::new(),
491 }];
492
493 Ok(results)
494 }
495
496 async fn analytics(
498 &self,
499 ctx: &Context<'_>,
500 _time_range: Option<TimeRange>,
501 ) -> FieldResult<EmbeddingAnalytics> {
502 let _context = ctx.data::<GraphQLContext>()?;
503
504 Ok(EmbeddingAnalytics {
506 total_embeddings: 10000,
507 dimensions_distribution: vec![
508 DimensionStat {
509 dimensions: 128,
510 count: 7000,
511 percentage: 70.0,
512 },
513 DimensionStat {
514 dimensions: 256,
515 count: 3000,
516 percentage: 30.0,
517 },
518 ],
519 model_usage: vec![],
520 quality_trends: vec![],
521 performance_summary: PerformanceMetrics {
522 inference_latency_ms: 25.5,
523 throughput_embeddings_per_sec: 1000.0,
524 memory_usage_mb: 512.0,
525 accuracy_score: Some(0.95),
526 quality_metrics: HashMap::new(),
527 },
528 cache_statistics: CacheStats {
529 hit_rate: 0.85,
530 total_requests: 50000,
531 cache_size_mb: 256.0,
532 evictions: 100,
533 },
534 })
535 }
536}
537
538#[Object]
539impl MutationRoot {
540 async fn start_batch_embedding(
542 &self,
543 ctx: &Context<'_>,
544 input: BatchEmbeddingInput,
545 ) -> FieldResult<BatchEmbeddingResult> {
546 let _context = ctx.data::<GraphQLContext>()?;
547
548 let job_id = Uuid::new_v4();
549
550 Ok(BatchEmbeddingResult {
552 job_id: ID::from(job_id.to_string()),
553 status: BatchStatus::Pending,
554 progress: 0.0,
555 total_entities: input.entity_ids.len() as i32,
556 processed_entities: 0,
557 estimated_completion: Some((Utc::now() + chrono::Duration::minutes(10)).to_rfc3339()),
558 results: vec![],
559 errors: vec![],
560 })
561 }
562
563 async fn cancel_batch_job(&self, ctx: &Context<'_>, _job_id: ID) -> FieldResult<bool> {
565 let _context = ctx.data::<GraphQLContext>()?;
566
567 Ok(true)
569 }
570
571 async fn update_model(
573 &self,
574 ctx: &Context<'_>,
575 model_name: String,
576 parameters: HashMap<String, String>,
577 ) -> FieldResult<ModelInfo> {
578 let _context = ctx.data::<GraphQLContext>()?;
579
580 Ok(ModelInfo {
582 id: ID::from("model_1"),
583 name: model_name,
584 version: "1.1.0".to_string(),
585 model_type: ModelType::TransE,
586 dimensions: 128,
587 parameters,
588 performance_metrics: None,
589 created_at: Utc::now().to_rfc3339(),
590 updated_at: Utc::now().to_rfc3339(),
591 })
592 }
593}
594
595#[Subscription]
596impl SubscriptionRoot {
597 async fn embedding_events(
599 &self,
600 ctx: &Context<'_>,
601 _entity_filter: Option<Vec<String>>,
602 ) -> Pin<Box<dyn Stream<Item = EmbeddingEvent> + Send>> {
603 let context = ctx.data::<GraphQLContext>().unwrap();
604 let rx = context.event_broadcaster.read().await.subscribe();
605
606 let stream = BroadcastStream::new(rx).filter_map(|result| result.ok());
607
608 Box::pin(stream)
609 }
610
611 async fn batch_updates(
613 &self,
614 ctx: &Context<'_>,
615 _job_id: Option<ID>,
616 ) -> Pin<Box<dyn Stream<Item = BatchCompletedEvent> + Send>> {
617 let context = ctx.data::<GraphQLContext>().unwrap();
618 let rx = context.event_broadcaster.read().await.subscribe();
619
620 let stream = BroadcastStream::new(rx).filter_map(|result| match result {
621 Ok(EmbeddingEvent::BatchCompleted(event)) => Some(event),
622 _ => None,
623 });
624
625 Box::pin(stream)
626 }
627
628 async fn quality_alerts(
630 &self,
631 ctx: &Context<'_>,
632 _severity_filter: Option<Vec<String>>,
633 ) -> Pin<Box<dyn Stream<Item = QualityAlertEvent> + Send>> {
634 let context = ctx.data::<GraphQLContext>().unwrap();
635 let rx = context.event_broadcaster.read().await.subscribe();
636
637 let stream = BroadcastStream::new(rx).filter_map(|result| match result {
638 Ok(EmbeddingEvent::QualityAlert(event)) => Some(event),
639 _ => None,
640 });
641
642 Box::pin(stream)
643 }
644}
645
646pub fn create_schema(context: GraphQLContext) -> EmbeddingSchema {
648 Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
649 .data(context)
650 .finish()
651}
652
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ModelRegistry;

    /// Builds a context backed by a fresh temporary registry directory.
    ///
    /// The `TempDir` guard is returned so the caller keeps the directory alive;
    /// dropping the guard deletes the directory.
    fn test_context() -> (tempfile::TempDir, GraphQLContext) {
        let temp_dir = tempfile::tempdir().expect("failed to create temp dir");
        let model_registry = Arc::new(ModelRegistry::new(temp_dir.path().to_path_buf()));
        let cache_config = crate::caching::CacheConfig::default();
        let cache_manager = Arc::new(CacheManager::new(cache_config));
        let context = GraphQLContext::new(model_registry, cache_manager);
        (temp_dir, context)
    }

    #[tokio::test]
    async fn test_graphql_context_creation() {
        // Named binding (not `_`) so the directory outlives the assertions.
        let (_temp_dir, context) = test_context();
        assert_eq!(context.event_broadcaster.read().await.receiver_count(), 0);
    }

    #[tokio::test]
    async fn test_schema_creation() {
        let (_temp_dir, context) = test_context();

        let schema = create_schema(context);
        assert!(!schema.sdl().is_empty());
    }

    #[tokio::test]
    async fn test_analytics_query_executes() {
        let (_temp_dir, context) = test_context();
        let schema = create_schema(context);

        let response = schema.execute("{ analytics { totalEmbeddings } }").await;
        assert!(
            response.errors.is_empty(),
            "unexpected errors: {:?}",
            response.errors
        );
    }
}