Skip to main content

fraiseql_core/federation/
entity_resolver.rs

1//! Entity resolution for federation _entities query.
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::Arc,
6    time::Instant,
7};
8
9use serde_json::Value;
10use tracing::info;
11use uuid::Uuid;
12
13use super::{
14    database_resolver::DatabaseEntityResolver,
15    logging::{FederationLogContext, FederationOperationType, ResolutionStrategy},
16    selection_parser::FieldSelection,
17    tracing::{FederationSpan, FederationTraceContext},
18    types::{EntityRepresentation, FederationResolver},
19};
20use crate::{db::traits::DatabaseAdapter, error::Result};
21
22/// Result of entity resolution
23#[derive(Debug)]
24pub struct EntityResolutionResult {
25    /// Resolved entities in same order as input representations
26    pub entities: Vec<Option<Value>>,
27
28    /// Any errors encountered during resolution
29    pub errors: Vec<String>,
30}
31
32/// Result of batch entity resolution with timing information
33#[derive(Debug)]
34pub struct EntityResolutionMetrics {
35    /// Resolved entities in same order as input representations
36    pub entities:    Vec<Option<Value>>,
37    /// Any errors encountered during resolution
38    pub errors:      Vec<String>,
39    /// Duration of resolution in microseconds
40    pub duration_us: u64,
41    /// Whether resolution succeeded (no errors)
42    pub success:     bool,
43}
44
45/// Deduplicate entity representations while preserving order
46pub fn deduplicate_representations(reps: &[EntityRepresentation]) -> Vec<EntityRepresentation> {
47    let mut seen = HashSet::new();
48    let mut result = Vec::with_capacity(reps.len());
49
50    for rep in reps {
51        // Create a key from typename + key_fields
52        let key = format!("{}:{:?}", rep.typename, rep.key_fields);
53        if seen.insert(key) {
54            result.push(rep.clone());
55        }
56    }
57
58    result
59}
60
61/// Group entities by typename and strategy
62pub fn group_entities_by_typename(
63    reps: &[EntityRepresentation],
64) -> HashMap<String, Vec<EntityRepresentation>> {
65    let mut groups: HashMap<String, Vec<EntityRepresentation>> = HashMap::new();
66
67    for rep in reps {
68        groups.entry(rep.typename.clone()).or_insert_with(Vec::new).push(rep.clone());
69    }
70
71    groups
72}
73
74/// Construct WHERE clause for batch query
75pub fn construct_batch_where_clause(
76    representations: &[EntityRepresentation],
77    key_columns: &[String],
78) -> Result<String> {
79    if representations.is_empty() || key_columns.is_empty() {
80        return Ok(String::new());
81    }
82
83    let mut conditions = Vec::new();
84
85    for key_col in key_columns {
86        let values: Vec<String> = representations
87            .iter()
88            .filter_map(|rep| rep.key_fields.get(key_col))
89            .filter_map(|v| v.as_str())
90            .map(|s| format!("'{}'", s.replace('\'', "''")))
91            .collect();
92
93        if !values.is_empty() && !values.iter().all(|v| v == "''") {
94            conditions.push(format!("{} IN ({})", key_col, values.join(", ")));
95        }
96    }
97
98    if conditions.is_empty() {
99        Ok(String::new())
100    } else {
101        Ok(format!("WHERE {}", conditions.join(" AND ")))
102    }
103}
104
105/// Resolve entities for a specific typename from local database
106pub async fn resolve_entities_from_db<A: DatabaseAdapter>(
107    representations: &[EntityRepresentation],
108    typename: &str,
109    adapter: Arc<A>,
110    fed_resolver: &FederationResolver,
111    selection: &FieldSelection,
112) -> EntityResolutionResult {
113    resolve_entities_from_db_with_tracing(
114        representations,
115        typename,
116        adapter,
117        fed_resolver,
118        selection,
119        None,
120    )
121    .await
122}
123
124/// Resolve entities for a specific typename from local database with optional distributed tracing.
125pub async fn resolve_entities_from_db_with_tracing<A: DatabaseAdapter>(
126    representations: &[EntityRepresentation],
127    typename: &str,
128    adapter: Arc<A>,
129    fed_resolver: &FederationResolver,
130    selection: &FieldSelection,
131    trace_context: Option<FederationTraceContext>,
132) -> EntityResolutionResult {
133    if representations.is_empty() {
134        return EntityResolutionResult {
135            entities: Vec::new(),
136            errors:   Vec::new(),
137        };
138    }
139
140    // Create database entity resolver
141    let db_resolver = DatabaseEntityResolver::new(adapter, fed_resolver.metadata.clone());
142
143    // Resolve from database with tracing
144    match db_resolver
145        .resolve_entities_from_db_with_tracing(typename, representations, selection, trace_context)
146        .await
147    {
148        Ok(entities) => EntityResolutionResult {
149            entities,
150            errors: Vec::new(),
151        },
152        Err(e) => EntityResolutionResult {
153            entities: vec![None; representations.len()],
154            errors:   vec![e.to_string()],
155        },
156    }
157}
158
159/// Batch load entities from database
160pub async fn batch_load_entities<A: DatabaseAdapter>(
161    representations: &[EntityRepresentation],
162    fed_resolver: &FederationResolver,
163    adapter: Arc<A>,
164    selection: &FieldSelection,
165) -> Result<Vec<Option<Value>>> {
166    batch_load_entities_with_tracing(representations, fed_resolver, adapter, selection, None).await
167}
168
169/// Batch load entities from database with optional distributed tracing and metrics.
170pub async fn batch_load_entities_with_tracing<A: DatabaseAdapter>(
171    representations: &[EntityRepresentation],
172    fed_resolver: &FederationResolver,
173    adapter: Arc<A>,
174    selection: &FieldSelection,
175    trace_context: Option<FederationTraceContext>,
176) -> Result<Vec<Option<Value>>> {
177    let result = batch_load_entities_with_tracing_and_metrics(
178        representations,
179        fed_resolver,
180        adapter,
181        selection,
182        trace_context,
183    )
184    .await?;
185    Ok(result.entities)
186}
187
188/// Batch load entities with full metrics for observability.
189///
190/// Returns both entities and timing information for metrics recording.
191pub async fn batch_load_entities_with_tracing_and_metrics<A: DatabaseAdapter>(
192    representations: &[EntityRepresentation],
193    fed_resolver: &FederationResolver,
194    adapter: Arc<A>,
195    selection: &FieldSelection,
196    trace_context: Option<FederationTraceContext>,
197) -> Result<EntityResolutionMetrics> {
198    let start_time = Instant::now();
199    let query_id = Uuid::new_v4().to_string();
200
201    if representations.is_empty() {
202        return Ok(EntityResolutionMetrics {
203            entities:    Vec::new(),
204            errors:      Vec::new(),
205            duration_us: 0,
206            success:     true,
207        });
208    }
209
210    // Create or use provided trace context
211    let trace_ctx = trace_context.unwrap_or_else(FederationTraceContext::new);
212
213    // Create span for federation query
214    let span = FederationSpan::new("federation.entities.batch_load", trace_ctx.clone())
215        .with_attribute("entity_count", representations.len().to_string())
216        .with_attribute("typename_count", count_unique_typenames(representations).to_string());
217
218    // Log entity resolution start
219    let log_ctx = FederationLogContext::new(
220        FederationOperationType::EntityResolution,
221        query_id.clone(),
222        representations.len(),
223    )
224    .with_entity_count_unique(deduplicate_representations(representations).len())
225    .with_trace_id(trace_ctx.trace_id.clone());
226
227    info!(
228        query_id = %query_id,
229        entity_count = representations.len(),
230        operation_type = "entity_resolution",
231        status = "started",
232        context = ?serde_json::to_value(&log_ctx).unwrap_or_default(),
233        "Entity resolution operation started"
234    );
235
236    // Group by typename
237    let grouped = group_entities_by_typename(representations);
238
239    let mut all_results: Vec<(usize, Option<Value>)> = Vec::new();
240    let mut current_index = 0;
241    let mut all_errors = Vec::new();
242
243    for (typename, reps) in grouped {
244        let batch_start = Instant::now();
245
246        // Create child span for this typename batch
247        let child_span = span
248            .create_child(format!("federation.entities.resolve.{}", typename))
249            .with_attribute("typename", typename.clone())
250            .with_attribute("batch_size", reps.len().to_string());
251
252        // Resolve this batch using database with trace context
253        let result = resolve_entities_from_db_with_tracing(
254            &reps,
255            &typename,
256            Arc::clone(&adapter),
257            fed_resolver,
258            selection,
259            Some(trace_ctx.clone()),
260        )
261        .await;
262
263        // Record batch metrics
264        let resolved_count = result.entities.iter().filter(|e| e.is_some()).count();
265        let error_count = result.errors.len();
266        let batch_duration_ms = batch_start.elapsed().as_secs_f64() * 1000.0;
267
268        // Log batch completion
269        let batch_log_ctx = FederationLogContext::new(
270            FederationOperationType::ResolveDb,
271            query_id.clone(),
272            reps.len(),
273        )
274        .with_typename(typename.clone())
275        .with_strategy(ResolutionStrategy::Db)
276        .with_entity_count_unique(reps.len())
277        .with_resolved_count(resolved_count)
278        .with_trace_id(trace_ctx.trace_id.clone())
279        .complete(batch_duration_ms);
280
281        if error_count > 0 {
282            info!(
283                query_id = %query_id,
284                typename = %typename,
285                batch_size = reps.len(),
286                resolved = resolved_count,
287                errors = error_count,
288                duration_ms = batch_duration_ms,
289                operation_type = "resolve_db",
290                status = "error",
291                context = ?serde_json::to_value(&batch_log_ctx).unwrap_or_default(),
292                "Entity batch resolution completed with errors"
293            );
294        } else {
295            info!(
296                query_id = %query_id,
297                typename = %typename,
298                batch_size = reps.len(),
299                resolved = resolved_count,
300                duration_ms = batch_duration_ms,
301                operation_type = "resolve_db",
302                status = "success",
303                context = ?serde_json::to_value(&batch_log_ctx).unwrap_or_default(),
304                "Entity batch resolution completed successfully"
305            );
306        }
307
308        // Map results back to original indices with proper ordering
309        for entity in result.entities {
310            all_results.push((current_index, entity));
311            current_index += 1;
312        }
313
314        // Collect errors
315        all_errors.extend(result.errors.clone());
316
317        // Drop child span
318        drop(child_span);
319    }
320
321    // Sort by original index to preserve order
322    all_results.sort_by_key(|(idx, _)| *idx);
323
324    // Record final span attributes
325    let _span_duration = span.duration_ms();
326    let resolved_count = all_results.iter().filter(|(_, e)| e.is_some()).count();
327
328    // Keep span alive until function returns
329    drop(span);
330
331    let duration_us = start_time.elapsed().as_micros() as u64;
332    let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
333    let entities = all_results.into_iter().map(|(_, e)| e).collect();
334    let success = all_errors.is_empty();
335
336    // Log overall completion
337    let final_log_ctx = if success {
338        log_ctx.with_resolved_count(resolved_count).complete(duration_ms)
339    } else {
340        let error_message = if all_errors.is_empty() {
341            "Unknown error".to_string()
342        } else {
343            all_errors.join("; ")
344        };
345        log_ctx.with_resolved_count(resolved_count).fail(duration_ms, error_message)
346    };
347
348    info!(
349        query_id = %query_id,
350        entity_count = representations.len(),
351        resolved_count = resolved_count,
352        error_count = all_errors.len(),
353        duration_ms = duration_ms,
354        operation_type = "entity_resolution",
355        status = if success { "success" } else { "error" },
356        context = ?serde_json::to_value(&final_log_ctx).unwrap_or_default(),
357        "Entity resolution operation completed"
358    );
359
360    Ok(EntityResolutionMetrics {
361        entities,
362        errors: all_errors,
363        duration_us,
364        success,
365    })
366}
367
368/// Count unique typenames in representations
369fn count_unique_typenames(representations: &[EntityRepresentation]) -> usize {
370    let mut typenames = HashSet::new();
371    for rep in representations {
372        typenames.insert(&rep.typename);
373    }
374    typenames.len()
375}
376
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Build a representation of `typename` whose key fields are the given string pairs.
    fn representation(typename: &str, keys: &[(&str, &str)]) -> EntityRepresentation {
        EntityRepresentation {
            typename:   typename.to_string(),
            key_fields: keys
                .iter()
                .map(|(field, value)| ((*field).to_string(), json!(*value)))
                .collect(),
            all_fields: HashMap::new(),
        }
    }

    #[test]
    fn test_deduplicate_representations() {
        // Two identical `User:123` entries plus one distinct `User:456`.
        let reps = vec![
            representation("User", &[("id", "123")]),
            representation("User", &[("id", "123")]),
            representation("User", &[("id", "456")]),
        ];

        let deduped = deduplicate_representations(&reps);
        assert_eq!(deduped.len(), 2);
    }

    #[test]
    fn test_group_entities_by_typename() {
        // Interleaved typenames with no key fields at all.
        let reps = vec![
            representation("User", &[]),
            representation("Order", &[]),
            representation("User", &[]),
        ];

        let grouped = group_entities_by_typename(&reps);
        assert_eq!(grouped.len(), 2);
        assert_eq!(grouped["User"].len(), 2);
        assert_eq!(grouped["Order"].len(), 1);
    }

    #[test]
    fn test_construct_batch_where_clause() {
        let reps = vec![
            representation("User", &[("id", "123")]),
            representation("User", &[("id", "456")]),
        ];
        let where_clause = construct_batch_where_clause(&reps, &["id".to_string()]).unwrap();

        assert!(where_clause.contains("WHERE"));
        assert!(where_clause.contains("id IN"));
        assert!(where_clause.contains("123"));
        assert!(where_clause.contains("456"));
    }
}
469}