oxirs_embed/api/
graphql.rs

1//! GraphQL API implementation
2//!
3//! This module provides GraphQL schema and resolvers for embedding services.
4
5#[cfg(feature = "graphql")]
6use super::ApiState;
7#[cfg(feature = "graphql")]
8use async_graphql::{
9    Context, EmptyMutation, EmptySubscription, Object, Result as GraphQLResult, Schema,
10    SimpleObject,
11};
12#[cfg(feature = "graphql")]
13use std::sync::Arc;
14#[cfg(feature = "graphql")]
15use uuid::Uuid;
16
/// GraphQL representation of model information
#[cfg(feature = "graphql")]
#[derive(Debug, Clone, SimpleObject)]
pub struct ModelInfo {
    /// Registry identifier of the model, rendered as a UUID string.
    pub model_id: String,
    /// Model name as recorded in the registry metadata.
    pub name: String,
    /// Model type as recorded in the registry metadata.
    pub model_type: String,
    /// Whether the model is currently present in the set of loaded models.
    pub is_loaded: bool,
    /// Whether the loaded model reports itself as trained; `false` when the model is not loaded.
    pub is_trained: bool,
    /// Entity count from the loaded model's statistics; the default value when not loaded.
    pub num_entities: i32,
    /// Relation count from the loaded model's statistics; the default value when not loaded.
    pub num_relations: i32,
    /// Triple count from the loaded model's statistics; the default value when not loaded.
    pub num_triples: i32,
    /// Dimension count from the loaded model's statistics; the default value when not loaded.
    pub dimensions: i32,
    /// Creation time from the registry metadata, formatted as an RFC 3339 (ISO 8601) timestamp.
    pub created_at: String,
}
32
/// GraphQL representation of system health
#[cfg(feature = "graphql")]
#[derive(Debug, Clone, SimpleObject)]
pub struct SystemHealth {
    /// Overall status: `healthy` when at least one model is loaded and the cache hit rate is
    /// above 0.5, `degraded` when models are loaded but the hit rate is not above 0.5, and
    /// `unhealthy` when no model is loaded.
    pub status: String,
    /// Number of models currently loaded.
    pub models_loaded: i32,
    /// Cache hits divided by total cache lookups (hits plus misses); 0.0 when there were none.
    pub cache_hit_rate: f64,
    /// Cache manager's estimated memory usage, in mebibytes (bytes / 1024 / 1024).
    pub memory_usage_mb: f64,
    /// Total number of cache lookups (hits plus misses).
    pub total_requests: i64,
}
43
/// GraphQL representation of cache statistics
#[cfg(feature = "graphql")]
#[derive(Debug, Clone, SimpleObject)]
pub struct CacheStatistics {
    /// Total number of cache hits reported by the cache manager.
    pub total_hits: i64,
    /// Total number of cache misses reported by the cache manager.
    pub total_misses: i64,
    /// Hit rate exactly as reported by the cache manager.
    pub hit_rate: f64,
    /// Cache memory usage in bytes as reported by the cache manager.
    pub memory_usage_bytes: i64,
    /// Total time saved by caching, in seconds, as reported by the cache manager.
    pub time_saved_seconds: f64,
}
54
/// GraphQL representation of prediction result
#[cfg(feature = "graphql")]
#[derive(Debug, Clone, SimpleObject)]
pub struct PredictionResult {
    /// Identifier of the predicted candidate, as returned by the model.
    // NOTE(review): the relation-prediction query also fills this field, so it presumably holds
    // a relation identifier in that case — confirm against the model API.
    pub entity: String,
    /// Score the model assigned to this candidate.
    pub score: f64,
}
62
/// GraphQL Query root
///
/// Unit type that carries no data of its own. Resolvers that need service state read the
/// shared `Arc<ApiState>` from the request context instead of from this value.
#[cfg(feature = "graphql")]
pub struct Query;
66
/// Number of predictions returned when a query omits `topK`.
#[cfg(feature = "graphql")]
const DEFAULT_TOP_K: i32 = 10;

/// Turns the optional `topK` argument into a result limit, rejecting negative values.
///
/// A plain `as usize` cast would silently turn a negative `i32` into an enormous limit instead
/// of reporting the bad input.
#[cfg(feature = "graphql")]
fn resolve_top_k(top_k: Option<i32>) -> GraphQLResult<usize> {
    usize::try_from(top_k.unwrap_or(DEFAULT_TOP_K))
        .map_err(|_| async_graphql::Error::new("topK must be non-negative"))
}

/// Converts a count to the GraphQL `Int` range, clamping to `i32::MAX` instead of wrapping.
#[cfg(feature = "graphql")]
fn saturating_i32<T: TryInto<i32>>(value: T) -> i32 {
    value.try_into().unwrap_or(i32::MAX)
}

/// Converts a count to `i64`, clamping to `i64::MAX` instead of wrapping to a negative value.
#[cfg(feature = "graphql")]
fn saturating_i64<T: TryInto<i64>>(value: T) -> i64 {
    value.try_into().unwrap_or(i64::MAX)
}

/// Wraps the `(candidate, score)` pairs produced by a model into GraphQL result objects.
#[cfg(feature = "graphql")]
fn into_prediction_results(
    scored: impl IntoIterator<Item = (String, f64)>,
) -> Vec<PredictionResult> {
    scored
        .into_iter()
        .map(|(entity, score)| PredictionResult { entity, score })
        .collect()
}

// Doc comments on the resolvers below are exposed as field descriptions in the GraphQL schema,
// so they are kept short; implementation notes live in regular comments.
#[cfg(feature = "graphql")]
#[Object]
impl Query {
    /// Get API version
    async fn version(&self) -> &str {
        "1.0.0"
    }

    /// Health check
    async fn health(&self) -> &str {
        "OK"
    }

    /// Get system health status
    async fn system_health(&self, ctx: &Context<'_>) -> GraphQLResult<SystemHealth> {
        let state = ctx.data::<Arc<ApiState>>()?;

        // Only the count is needed, so the read guard is released at the end of this statement
        // rather than being held for the rest of the resolver.
        let models_loaded = saturating_i32(state.models.read().await.len());

        let cache_stats = state.cache_manager.get_stats();
        let total_lookups = cache_stats.total_hits + cache_stats.total_misses;
        // Guard against 0/0 when the cache has not been queried yet.
        let cache_hit_rate = if total_lookups > 0 {
            cache_stats.total_hits as f64 / total_lookups as f64
        } else {
            0.0
        };

        let memory_usage_mb =
            state.cache_manager.estimate_memory_usage() as f64 / (1024.0 * 1024.0);

        let status = if models_loaded > 0 && cache_hit_rate > 0.5 {
            "healthy"
        } else if models_loaded > 0 {
            "degraded"
        } else {
            "unhealthy"
        };

        Ok(SystemHealth {
            status: status.to_string(),
            models_loaded,
            cache_hit_rate,
            memory_usage_mb,
            total_requests: saturating_i64(total_lookups),
        })
    }

    /// Get cache statistics
    async fn cache_stats(&self, ctx: &Context<'_>) -> GraphQLResult<CacheStatistics> {
        let state = ctx.data::<Arc<ApiState>>()?;
        let stats = state.cache_manager.get_stats();

        Ok(CacheStatistics {
            total_hits: saturating_i64(stats.total_hits),
            total_misses: saturating_i64(stats.total_misses),
            hit_rate: stats.hit_rate,
            memory_usage_bytes: saturating_i64(stats.memory_usage_bytes),
            time_saved_seconds: stats.total_time_saved_seconds,
        })
    }

    /// List all models
    async fn models(&self, ctx: &Context<'_>) -> GraphQLResult<Vec<ModelInfo>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let registry_models = state.registry.list_models().await;
        let loaded_models = state.models.read().await;

        let model_list = registry_models
            .into_iter()
            .map(|metadata| {
                // One lookup answers both "is it loaded?" and "what are its stats?".
                let loaded = loaded_models.get(&metadata.model_id);
                let (is_trained, stats) = match loaded {
                    Some(model) => {
                        let stats = model.get_stats();
                        (model.is_trained(), stats)
                    }
                    // Registered but not loaded: report default statistics.
                    None => (false, Default::default()),
                };

                ModelInfo {
                    model_id: metadata.model_id.to_string(),
                    name: metadata.name,
                    model_type: metadata.model_type,
                    is_loaded: loaded.is_some(),
                    is_trained,
                    num_entities: saturating_i32(stats.num_entities),
                    num_relations: saturating_i32(stats.num_relations),
                    num_triples: saturating_i32(stats.num_triples),
                    dimensions: saturating_i32(stats.dimensions),
                    created_at: metadata.created_at.to_rfc3339(),
                }
            })
            .collect();

        Ok(model_list)
    }

    /// Get specific model information
    async fn model(&self, ctx: &Context<'_>, model_id: String) -> GraphQLResult<Option<ModelInfo>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let model_uuid = Uuid::parse_str(&model_id)
            .map_err(|_| async_graphql::Error::new("Invalid model ID format"))?;

        // Any registry failure is reported to the client as "no such model" (null), not as an
        // error.
        let metadata = match state.registry.get_model(model_uuid).await {
            Ok(metadata) => metadata,
            Err(_) => return Ok(None),
        };

        let loaded_models = state.models.read().await;
        // One lookup answers both "is it loaded?" and "what are its stats?".
        let loaded = loaded_models.get(&model_uuid);
        let (is_trained, stats) = match loaded {
            Some(model) => {
                let stats = model.get_stats();
                (model.is_trained(), stats)
            }
            // Registered but not loaded: report default statistics.
            None => (false, Default::default()),
        };

        Ok(Some(ModelInfo {
            model_id: metadata.model_id.to_string(),
            name: metadata.name,
            model_type: metadata.model_type,
            is_loaded: loaded.is_some(),
            is_trained,
            num_entities: saturating_i32(stats.num_entities),
            num_relations: saturating_i32(stats.num_relations),
            num_triples: saturating_i32(stats.num_triples),
            dimensions: saturating_i32(stats.dimensions),
            created_at: metadata.created_at.to_rfc3339(),
        }))
    }

    /// Predict objects for given subject and predicate
    async fn predict_objects(
        &self,
        ctx: &Context<'_>,
        subject: String,
        predicate: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;
        // Validate the argument before touching the model lock.
        let k = resolve_top_k(top_k)?;

        // Get production model (simplified - would use the helper function): for now this is
        // simply whichever loaded model the collection yields first.
        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        let predictions = model
            .predict_objects(&subject, &predicate, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(into_prediction_results(predictions))
    }

    /// Predict subjects for given predicate and object
    async fn predict_subjects(
        &self,
        ctx: &Context<'_>,
        predicate: String,
        object: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;
        // Validate the argument before touching the model lock.
        let k = resolve_top_k(top_k)?;

        // Uses whichever loaded model the collection yields first.
        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        let predictions = model
            .predict_subjects(&predicate, &object, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(into_prediction_results(predictions))
    }

    /// Predict relations for given subject and object
    async fn predict_relations(
        &self,
        ctx: &Context<'_>,
        subject: String,
        object: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;
        // Validate the argument before touching the model lock.
        let k = resolve_top_k(top_k)?;

        // Uses whichever loaded model the collection yields first.
        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        let predictions = model
            .predict_relations(&subject, &object, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(into_prediction_results(predictions))
    }

    /// Score a triple (subject, predicate, object)
    async fn score_triple(
        &self,
        ctx: &Context<'_>,
        subject: String,
        predicate: String,
        object: String,
    ) -> GraphQLResult<f64> {
        let state = ctx.data::<Arc<ApiState>>()?;

        // Uses whichever loaded model the collection yields first.
        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        model
            .score_triple(&subject, &predicate, &object)
            .map_err(|e| async_graphql::Error::new(format!("Scoring failed: {}", e)))
    }
}
322
/// Build the GraphQL schema for the embedding API.
///
/// The schema uses [`Query`] as its root and has no mutations or subscriptions. No `ApiState`
/// is attached here: the resolvers read it from the request context, so it has to be supplied
/// as context data when a request is executed.
#[cfg(feature = "graphql")]
pub fn create_schema() -> Schema<Query, EmptyMutation, EmptySubscription> {
    let builder = Schema::build(Query, EmptyMutation, EmptySubscription);
    builder.finish()
}
328
/// Axum handler that executes a GraphQL request.
///
/// The schema and the shared `ApiState` are taken from request extensions and the GraphQL
/// request from the JSON body. The state is attached to the request as context data so the
/// resolvers can look it up, and the execution result is returned as JSON.
#[cfg(all(feature = "graphql", feature = "api-server"))]
pub async fn graphql_handler(
    schema: axum::extract::Extension<Schema<Query, EmptyMutation, EmptySubscription>>,
    state: axum::extract::Extension<Arc<ApiState>>,
    req: axum::extract::Json<async_graphql::Request>,
) -> axum::Json<async_graphql::Response> {
    // Unwrap the extractor newtypes up front so the rest reads in terms of the inner values.
    let axum::extract::Extension(schema) = schema;
    let axum::extract::Extension(state) = state;
    let axum::extract::Json(request) = req;

    let response = schema.execute(request.data(state)).await;
    axum::Json(response)
}
340
/// Serve the GraphiQL playground page.
///
/// The generated page is configured to send its queries to the `/graphql` endpoint.
#[cfg(all(feature = "graphql", feature = "api-server"))]
pub async fn graphiql() -> impl axum::response::IntoResponse {
    let page = async_graphql::http::GraphiQLSource::build()
        .endpoint("/graphql")
        .finish();
    axum::response::Html(page)
}