1#[cfg(feature = "graphql")]
6use super::ApiState;
7#[cfg(feature = "graphql")]
8use async_graphql::{
9 Context, EmptyMutation, EmptySubscription, Object, Result as GraphQLResult, Schema,
10 SimpleObject,
11};
12#[cfg(feature = "graphql")]
13use std::sync::Arc;
14#[cfg(feature = "graphql")]
15use uuid::Uuid;
16
#[cfg(feature = "graphql")]
// GraphQL object describing one registered model: registry metadata plus
// runtime statistics (stats come from the loaded model, or Default when the
// model is not loaded — see the `models`/`model` resolvers).
// NOTE(review): using `///` doc comments here would surface as GraphQL schema
// descriptions via the SimpleObject derive, changing introspection output —
// hence plain `//` comments.
#[derive(SimpleObject)]
pub struct ModelInfo {
    // Registry UUID rendered as a string.
    pub model_id: String,
    pub name: String,
    pub model_type: String,
    // True when the model is present in the in-memory model map.
    pub is_loaded: bool,
    pub is_trained: bool,
    pub num_entities: i32,
    pub num_relations: i32,
    pub num_triples: i32,
    // Embedding dimensionality.
    pub dimensions: i32,
    // RFC 3339 timestamp string.
    pub created_at: String, }
32
#[cfg(feature = "graphql")]
// Aggregated health report returned by the `system_health` resolver.
// Plain `//` comments are used deliberately: `///` would become GraphQL
// field descriptions via the SimpleObject derive.
#[derive(SimpleObject)]
pub struct SystemHealth {
    // "healthy" | "degraded" | "unhealthy" — see system_health() thresholds.
    pub status: String,
    pub models_loaded: i32,
    // Fraction in [0, 1]; 0.0 when no requests have been served yet.
    pub cache_hit_rate: f64,
    pub memory_usage_mb: f64,
    // Total cache lookups (hits + misses).
    pub total_requests: i64,
}
43
#[cfg(feature = "graphql")]
// Raw cache-manager counters exposed by the `cache_stats` resolver.
// Plain `//` comments are used deliberately: `///` would become GraphQL
// field descriptions via the SimpleObject derive.
#[derive(SimpleObject)]
pub struct CacheStatistics {
    pub total_hits: i64,
    pub total_misses: i64,
    // Precomputed by the cache manager (unlike SystemHealth, which
    // recomputes the rate from the raw counters).
    pub hit_rate: f64,
    pub memory_usage_bytes: i64,
    pub time_saved_seconds: f64,
}
54
#[cfg(feature = "graphql")]
// One (entity, score) pair returned by the predict_* resolvers.
#[derive(SimpleObject)]
pub struct PredictionResult {
    pub entity: String,
    pub score: f64,
}
62
#[cfg(feature = "graphql")]
// Root query type of the GraphQL schema; all resolvers live in the
// `#[Object]` impl below. Unit struct — all state comes from the Context.
pub struct Query;
66
#[cfg(feature = "graphql")]
#[Object]
impl Query {
    // Static API version string.
    // NOTE(review): hard-coded; consider env!("CARGO_PKG_VERSION") so this
    // cannot drift from the crate version.
    async fn version(&self) -> &str {
        "1.0.0"
    }

    // Minimal liveness probe: any response at all means the server is up.
    async fn health(&self) -> &str {
        "OK"
    }

    // Aggregated health snapshot: loaded-model count, cache hit rate, and
    // estimated cache memory usage. Status thresholds: "healthy" requires at
    // least one loaded model AND a hit rate above 0.5; "degraded" has models
    // but a low hit rate; "unhealthy" means no models are loaded.
    async fn system_health(&self, ctx: &Context<'_>) -> GraphQLResult<SystemHealth> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let models = state.models.read().await;
        let model_count = models.len() as i32;

        let cache_stats = state.cache_manager.get_stats();
        // Recomputed from the raw hit/miss counters with a divide-by-zero
        // guard. (cache_stats() below reports the manager's precomputed
        // hit_rate field instead — presumably equivalent; TODO confirm.)
        let cache_hit_rate = if cache_stats.total_hits + cache_stats.total_misses > 0 {
            cache_stats.total_hits as f64
                / (cache_stats.total_hits + cache_stats.total_misses) as f64
        } else {
            0.0
        };

        let memory_usage_mb =
            state.cache_manager.estimate_memory_usage() as f64 / (1024.0 * 1024.0);

        let status = if model_count > 0 && cache_hit_rate > 0.5 {
            "healthy"
        } else if model_count > 0 {
            "degraded"
        } else {
            "unhealthy"
        };

        Ok(SystemHealth {
            status: status.to_string(),
            models_loaded: model_count,
            cache_hit_rate,
            memory_usage_mb,
            total_requests: (cache_stats.total_hits + cache_stats.total_misses) as i64,
        })
    }

    // Raw cache statistics passed straight through from the cache manager.
    async fn cache_stats(&self, ctx: &Context<'_>) -> GraphQLResult<CacheStatistics> {
        let state = ctx.data::<Arc<ApiState>>()?;
        let cache_stats = state.cache_manager.get_stats();

        Ok(CacheStatistics {
            total_hits: cache_stats.total_hits as i64,
            total_misses: cache_stats.total_misses as i64,
            hit_rate: cache_stats.hit_rate,
            memory_usage_bytes: cache_stats.memory_usage_bytes as i64,
            time_saved_seconds: cache_stats.total_time_saved_seconds,
        })
    }

    // Lists every model known to the registry, merging in live statistics
    // for the ones currently loaded. Unloaded models report is_trained=false
    // and Default (zeroed) stats.
    async fn models(&self, ctx: &Context<'_>) -> GraphQLResult<Vec<ModelInfo>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let registry_models = state.registry.list_models().await;
        let loaded_models = state.models.read().await;

        let mut model_list = Vec::new();
        for model_metadata in registry_models {
            let is_loaded = loaded_models.contains_key(&model_metadata.model_id);

            let (is_trained, stats) =
                if let Some(model) = loaded_models.get(&model_metadata.model_id) {
                    let stats = model.get_stats();
                    (model.is_trained(), stats)
                } else {
                    (false, Default::default())
                };

            model_list.push(ModelInfo {
                model_id: model_metadata.model_id.to_string(),
                name: model_metadata.name,
                model_type: model_metadata.model_type,
                is_loaded,
                is_trained,
                // NOTE(review): `as i32` silently truncates counts above
                // i32::MAX — acceptable only if such sizes are impossible.
                num_entities: stats.num_entities as i32,
                num_relations: stats.num_relations as i32,
                num_triples: stats.num_triples as i32,
                dimensions: stats.dimensions as i32,
                created_at: model_metadata.created_at.to_rfc3339(),
            });
        }

        Ok(model_list)
    }

    // Looks up a single model by UUID string. Returns Ok(None) when the
    // registry has no such model; errors only on a malformed UUID.
    async fn model(&self, ctx: &Context<'_>, model_id: String) -> GraphQLResult<Option<ModelInfo>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let model_uuid = Uuid::parse_str(&model_id)
            .map_err(|_| async_graphql::Error::new("Invalid model ID format"))?;

        // Registry miss is "not found", not an error, so the client can
        // distinguish a bad ID (error) from an unknown one (null).
        let model_metadata = match state.registry.get_model(model_uuid).await {
            Ok(metadata) => metadata,
            Err(_) => return Ok(None),
        };

        let loaded_models = state.models.read().await;
        let is_loaded = loaded_models.contains_key(&model_uuid);

        let (is_trained, stats) = if let Some(model) = loaded_models.get(&model_uuid) {
            let stats = model.get_stats();
            (model.is_trained(), stats)
        } else {
            (false, Default::default())
        };

        Ok(Some(ModelInfo {
            model_id: model_metadata.model_id.to_string(),
            name: model_metadata.name,
            model_type: model_metadata.model_type,
            is_loaded,
            is_trained,
            num_entities: stats.num_entities as i32,
            num_relations: stats.num_relations as i32,
            num_triples: stats.num_triples as i32,
            dimensions: stats.dimensions as i32,
            created_at: model_metadata.created_at.to_rfc3339(),
        }))
    }

    // Predicts the top-k object entities for (subject, predicate, ?).
    // top_k defaults to 10; negative values are clamped to 0.
    // NOTE(review): uses whichever loaded model the map yields first — with
    // several models loaded the choice is effectively arbitrary (hash-map
    // iteration order is unspecified). A model_id argument would make this
    // deterministic; applies to all predict_*/score_triple resolvers.
    async fn predict_objects(
        &self,
        ctx: &Context<'_>,
        subject: String,
        predicate: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        // BUGFIX: a negative client-supplied top_k cast straight to usize
        // wraps to an enormous value; clamp to 0 before the cast.
        let k = top_k.unwrap_or(10).max(0) as usize;
        let predictions = model
            .predict_objects(&subject, &predicate, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(predictions
            .into_iter()
            .map(|(entity, score)| PredictionResult { entity, score })
            .collect())
    }

    // Predicts the top-k subject entities for (?, predicate, object).
    // Same model-selection and top_k semantics as predict_objects.
    async fn predict_subjects(
        &self,
        ctx: &Context<'_>,
        predicate: String,
        object: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        // BUGFIX: clamp negative top_k to 0 before the usize cast (see
        // predict_objects).
        let k = top_k.unwrap_or(10).max(0) as usize;
        let predictions = model
            .predict_subjects(&predicate, &object, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(predictions
            .into_iter()
            .map(|(entity, score)| PredictionResult { entity, score })
            .collect())
    }

    // Predicts the top-k relations for (subject, ?, object).
    // Same model-selection and top_k semantics as predict_objects.
    async fn predict_relations(
        &self,
        ctx: &Context<'_>,
        subject: String,
        object: String,
        top_k: Option<i32>,
    ) -> GraphQLResult<Vec<PredictionResult>> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        // BUGFIX: clamp negative top_k to 0 before the usize cast (see
        // predict_objects).
        let k = top_k.unwrap_or(10).max(0) as usize;
        let predictions = model
            .predict_relations(&subject, &object, k)
            .map_err(|e| async_graphql::Error::new(format!("Prediction failed: {}", e)))?;

        Ok(predictions
            .into_iter()
            .map(|(entity, score)| PredictionResult { entity, score })
            .collect())
    }

    // Scores a fully-specified (subject, predicate, object) triple with the
    // first available trained model.
    async fn score_triple(
        &self,
        ctx: &Context<'_>,
        subject: String,
        predicate: String,
        object: String,
    ) -> GraphQLResult<f64> {
        let state = ctx.data::<Arc<ApiState>>()?;

        let models = state.models.read().await;
        let model = models
            .values()
            .next()
            .ok_or_else(|| async_graphql::Error::new("No models available"))?;

        if !model.is_trained() {
            return Err(async_graphql::Error::new("Model is not trained"));
        }

        let score = model
            .score_triple(&subject, &predicate, &object)
            .map_err(|e| async_graphql::Error::new(format!("Scoring failed: {}", e)))?;

        Ok(score)
    }
}
322
#[cfg(feature = "graphql")]
// Constructs the executable GraphQL schema: a read-only Query root with no
// mutations or subscriptions. `Schema::new` is equivalent to
// `Schema::build(..).finish()` when no builder configuration is applied.
pub fn create_schema() -> Schema<Query, EmptyMutation, EmptySubscription> {
    Schema::new(Query, EmptyMutation, EmptySubscription)
}
328
#[cfg(all(feature = "graphql", feature = "api-server"))]
// Axum POST handler for GraphQL requests: injects the shared ApiState into
// the request's context data (so resolvers can fetch it via ctx.data()),
// executes against the schema, and returns the response as JSON.
pub async fn graphql_handler(
    schema: axum::extract::Extension<Schema<Query, EmptyMutation, EmptySubscription>>,
    state: axum::extract::Extension<Arc<ApiState>>,
    req: axum::extract::Json<async_graphql::Request>,
) -> axum::Json<async_graphql::Response> {
    // Arc::clone makes the cheap refcount-bump clone explicit.
    let shared_state = Arc::clone(&state.0);
    let gql_request = req.0.data(shared_state);
    axum::Json(schema.execute(gql_request).await)
}
340
#[cfg(all(feature = "graphql", feature = "api-server"))]
// Serves the GraphiQL in-browser IDE, pre-configured to talk to the
// /graphql endpoint handled by `graphql_handler`.
pub async fn graphiql() -> impl axum::response::IntoResponse {
    let page = async_graphql::http::GraphiQLSource::build()
        .endpoint("/graphql")
        .finish();
    axum::response::Html(page)
}