1use std::{
4 collections::{HashMap, HashSet},
5 sync::Arc,
6 time::Instant,
7};
8
9use serde_json::Value;
10use tracing::info;
11use uuid::Uuid;
12
13use super::{
14 database_resolver::DatabaseEntityResolver,
15 logging::{FederationLogContext, FederationOperationType, ResolutionStrategy},
16 selection_parser::FieldSelection,
17 tracing::{FederationSpan, FederationTraceContext},
18 types::{EntityRepresentation, FederationResolver},
19};
20use crate::{db::traits::DatabaseAdapter, error::Result};
21
22#[derive(Debug)]
24pub struct EntityResolutionResult {
25 pub entities: Vec<Option<Value>>,
27
28 pub errors: Vec<String>,
30}
31
32#[derive(Debug)]
34pub struct EntityResolutionMetrics {
35 pub entities: Vec<Option<Value>>,
37 pub errors: Vec<String>,
39 pub duration_us: u64,
41 pub success: bool,
43}
44
45pub fn deduplicate_representations(reps: &[EntityRepresentation]) -> Vec<EntityRepresentation> {
47 let mut seen = HashSet::new();
48 let mut result = Vec::with_capacity(reps.len());
49
50 for rep in reps {
51 let key = format!("{}:{:?}", rep.typename, rep.key_fields);
53 if seen.insert(key) {
54 result.push(rep.clone());
55 }
56 }
57
58 result
59}
60
61pub fn group_entities_by_typename(
63 reps: &[EntityRepresentation],
64) -> HashMap<String, Vec<EntityRepresentation>> {
65 let mut groups: HashMap<String, Vec<EntityRepresentation>> = HashMap::new();
66
67 for rep in reps {
68 groups.entry(rep.typename.clone()).or_insert_with(Vec::new).push(rep.clone());
69 }
70
71 groups
72}
73
74pub fn construct_batch_where_clause(
76 representations: &[EntityRepresentation],
77 key_columns: &[String],
78) -> Result<String> {
79 if representations.is_empty() || key_columns.is_empty() {
80 return Ok(String::new());
81 }
82
83 let mut conditions = Vec::new();
84
85 for key_col in key_columns {
86 let values: Vec<String> = representations
87 .iter()
88 .filter_map(|rep| rep.key_fields.get(key_col))
89 .filter_map(|v| v.as_str())
90 .map(|s| format!("'{}'", s.replace('\'', "''")))
91 .collect();
92
93 if !values.is_empty() && !values.iter().all(|v| v == "''") {
94 conditions.push(format!("{} IN ({})", key_col, values.join(", ")));
95 }
96 }
97
98 if conditions.is_empty() {
99 Ok(String::new())
100 } else {
101 Ok(format!("WHERE {}", conditions.join(" AND ")))
102 }
103}
104
105pub async fn resolve_entities_from_db<A: DatabaseAdapter>(
107 representations: &[EntityRepresentation],
108 typename: &str,
109 adapter: Arc<A>,
110 fed_resolver: &FederationResolver,
111 selection: &FieldSelection,
112) -> EntityResolutionResult {
113 resolve_entities_from_db_with_tracing(
114 representations,
115 typename,
116 adapter,
117 fed_resolver,
118 selection,
119 None,
120 )
121 .await
122}
123
124pub async fn resolve_entities_from_db_with_tracing<A: DatabaseAdapter>(
126 representations: &[EntityRepresentation],
127 typename: &str,
128 adapter: Arc<A>,
129 fed_resolver: &FederationResolver,
130 selection: &FieldSelection,
131 trace_context: Option<FederationTraceContext>,
132) -> EntityResolutionResult {
133 if representations.is_empty() {
134 return EntityResolutionResult {
135 entities: Vec::new(),
136 errors: Vec::new(),
137 };
138 }
139
140 let db_resolver = DatabaseEntityResolver::new(adapter, fed_resolver.metadata.clone());
142
143 match db_resolver
145 .resolve_entities_from_db_with_tracing(typename, representations, selection, trace_context)
146 .await
147 {
148 Ok(entities) => EntityResolutionResult {
149 entities,
150 errors: Vec::new(),
151 },
152 Err(e) => EntityResolutionResult {
153 entities: vec![None; representations.len()],
154 errors: vec![e.to_string()],
155 },
156 }
157}
158
159pub async fn batch_load_entities<A: DatabaseAdapter>(
161 representations: &[EntityRepresentation],
162 fed_resolver: &FederationResolver,
163 adapter: Arc<A>,
164 selection: &FieldSelection,
165) -> Result<Vec<Option<Value>>> {
166 batch_load_entities_with_tracing(representations, fed_resolver, adapter, selection, None).await
167}
168
169pub async fn batch_load_entities_with_tracing<A: DatabaseAdapter>(
171 representations: &[EntityRepresentation],
172 fed_resolver: &FederationResolver,
173 adapter: Arc<A>,
174 selection: &FieldSelection,
175 trace_context: Option<FederationTraceContext>,
176) -> Result<Vec<Option<Value>>> {
177 let result = batch_load_entities_with_tracing_and_metrics(
178 representations,
179 fed_resolver,
180 adapter,
181 selection,
182 trace_context,
183 )
184 .await?;
185 Ok(result.entities)
186}
187
188pub async fn batch_load_entities_with_tracing_and_metrics<A: DatabaseAdapter>(
192 representations: &[EntityRepresentation],
193 fed_resolver: &FederationResolver,
194 adapter: Arc<A>,
195 selection: &FieldSelection,
196 trace_context: Option<FederationTraceContext>,
197) -> Result<EntityResolutionMetrics> {
198 let start_time = Instant::now();
199 let query_id = Uuid::new_v4().to_string();
200
201 if representations.is_empty() {
202 return Ok(EntityResolutionMetrics {
203 entities: Vec::new(),
204 errors: Vec::new(),
205 duration_us: 0,
206 success: true,
207 });
208 }
209
210 let trace_ctx = trace_context.unwrap_or_else(FederationTraceContext::new);
212
213 let span = FederationSpan::new("federation.entities.batch_load", trace_ctx.clone())
215 .with_attribute("entity_count", representations.len().to_string())
216 .with_attribute("typename_count", count_unique_typenames(representations).to_string());
217
218 let log_ctx = FederationLogContext::new(
220 FederationOperationType::EntityResolution,
221 query_id.clone(),
222 representations.len(),
223 )
224 .with_entity_count_unique(deduplicate_representations(representations).len())
225 .with_trace_id(trace_ctx.trace_id.clone());
226
227 info!(
228 query_id = %query_id,
229 entity_count = representations.len(),
230 operation_type = "entity_resolution",
231 status = "started",
232 context = ?serde_json::to_value(&log_ctx).unwrap_or_default(),
233 "Entity resolution operation started"
234 );
235
236 let grouped = group_entities_by_typename(representations);
238
239 let mut all_results: Vec<(usize, Option<Value>)> = Vec::new();
240 let mut current_index = 0;
241 let mut all_errors = Vec::new();
242
243 for (typename, reps) in grouped {
244 let batch_start = Instant::now();
245
246 let child_span = span
248 .create_child(format!("federation.entities.resolve.{}", typename))
249 .with_attribute("typename", typename.clone())
250 .with_attribute("batch_size", reps.len().to_string());
251
252 let result = resolve_entities_from_db_with_tracing(
254 &reps,
255 &typename,
256 Arc::clone(&adapter),
257 fed_resolver,
258 selection,
259 Some(trace_ctx.clone()),
260 )
261 .await;
262
263 let resolved_count = result.entities.iter().filter(|e| e.is_some()).count();
265 let error_count = result.errors.len();
266 let batch_duration_ms = batch_start.elapsed().as_secs_f64() * 1000.0;
267
268 let batch_log_ctx = FederationLogContext::new(
270 FederationOperationType::ResolveDb,
271 query_id.clone(),
272 reps.len(),
273 )
274 .with_typename(typename.clone())
275 .with_strategy(ResolutionStrategy::Db)
276 .with_entity_count_unique(reps.len())
277 .with_resolved_count(resolved_count)
278 .with_trace_id(trace_ctx.trace_id.clone())
279 .complete(batch_duration_ms);
280
281 if error_count > 0 {
282 info!(
283 query_id = %query_id,
284 typename = %typename,
285 batch_size = reps.len(),
286 resolved = resolved_count,
287 errors = error_count,
288 duration_ms = batch_duration_ms,
289 operation_type = "resolve_db",
290 status = "error",
291 context = ?serde_json::to_value(&batch_log_ctx).unwrap_or_default(),
292 "Entity batch resolution completed with errors"
293 );
294 } else {
295 info!(
296 query_id = %query_id,
297 typename = %typename,
298 batch_size = reps.len(),
299 resolved = resolved_count,
300 duration_ms = batch_duration_ms,
301 operation_type = "resolve_db",
302 status = "success",
303 context = ?serde_json::to_value(&batch_log_ctx).unwrap_or_default(),
304 "Entity batch resolution completed successfully"
305 );
306 }
307
308 for entity in result.entities {
310 all_results.push((current_index, entity));
311 current_index += 1;
312 }
313
314 all_errors.extend(result.errors.clone());
316
317 drop(child_span);
319 }
320
321 all_results.sort_by_key(|(idx, _)| *idx);
323
324 let _span_duration = span.duration_ms();
326 let resolved_count = all_results.iter().filter(|(_, e)| e.is_some()).count();
327
328 drop(span);
330
331 let duration_us = start_time.elapsed().as_micros() as u64;
332 let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
333 let entities = all_results.into_iter().map(|(_, e)| e).collect();
334 let success = all_errors.is_empty();
335
336 let final_log_ctx = if success {
338 log_ctx.with_resolved_count(resolved_count).complete(duration_ms)
339 } else {
340 let error_message = if all_errors.is_empty() {
341 "Unknown error".to_string()
342 } else {
343 all_errors.join("; ")
344 };
345 log_ctx.with_resolved_count(resolved_count).fail(duration_ms, error_message)
346 };
347
348 info!(
349 query_id = %query_id,
350 entity_count = representations.len(),
351 resolved_count = resolved_count,
352 error_count = all_errors.len(),
353 duration_ms = duration_ms,
354 operation_type = "entity_resolution",
355 status = if success { "success" } else { "error" },
356 context = ?serde_json::to_value(&final_log_ctx).unwrap_or_default(),
357 "Entity resolution operation completed"
358 );
359
360 Ok(EntityResolutionMetrics {
361 entities,
362 errors: all_errors,
363 duration_us,
364 success,
365 })
366}
367
368fn count_unique_typenames(representations: &[EntityRepresentation]) -> usize {
370 let mut typenames = HashSet::new();
371 for rep in representations {
372 typenames.insert(&rep.typename);
373 }
374 typenames.len()
375}
376
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Builds a representation of `typename` with the given key fields and no
    /// additional fields.
    fn representation(typename: &str, keys: &[(&str, Value)]) -> EntityRepresentation {
        EntityRepresentation {
            typename: typename.to_string(),
            key_fields: keys
                .iter()
                .map(|(name, value)| ((*name).to_string(), value.clone()))
                .collect(),
            all_fields: HashMap::new(),
        }
    }

    /// Two representations with the same typename and key collapse into one.
    #[test]
    fn test_deduplicate_representations() {
        let reps = vec![
            representation("User", &[("id", json!("123"))]),
            representation("User", &[("id", json!("123"))]),
            representation("User", &[("id", json!("456"))]),
        ];

        let deduped = deduplicate_representations(&reps);
        assert_eq!(deduped.len(), 2);
    }

    /// Representations land in one bucket per typename.
    #[test]
    fn test_group_entities_by_typename() {
        let reps = vec![
            representation("User", &[]),
            representation("Order", &[]),
            representation("User", &[]),
        ];

        let grouped = group_entities_by_typename(&reps);
        assert_eq!(grouped.len(), 2);
        assert_eq!(grouped["User"].len(), 2);
        assert_eq!(grouped["Order"].len(), 1);
    }

    /// A single key column produces an `IN` condition listing every key value.
    #[test]
    fn test_construct_batch_where_clause() {
        let reps = vec![
            representation("User", &[("id", json!("123"))]),
            representation("User", &[("id", json!("456"))]),
        ];
        let where_clause = construct_batch_where_clause(&reps, &["id".to_string()]).unwrap();

        assert!(where_clause.contains("WHERE"));
        assert!(where_clause.contains("id IN"));
        assert!(where_clause.contains("123"));
        assert!(where_clause.contains("456"));
    }
}