Skip to main content

datasynth_generators/relationships/
generator.rs

1//! Relationship generator implementation.
2//!
3//! Provides generation of relationships between entities based on
4//! cardinality rules and property generation configurations.
5
6use std::collections::{HashMap, HashSet};
7
8use chrono::{DateTime, Utc};
9use datasynth_core::utils::seeded_rng;
10use rand::prelude::*;
11use rand_chacha::ChaCha8Rng;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14
15use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
16
17use super::rules::{
18    CardinalityRule, PropertyGenerator, PropertyValueType, RelationshipConfig,
19    RelationshipTypeConfig, RelationshipValidation,
20};
21
22/// Generated relationship output.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct GeneratedRelationship {
25    /// Relationship type name.
26    pub relationship_type: String,
27    /// Unique relationship ID.
28    pub id: String,
29    /// Source entity ID.
30    pub source_id: String,
31    /// Target entity ID.
32    pub target_id: String,
33    /// Relationship properties.
34    pub properties: HashMap<String, Value>,
35    /// Relationship metadata.
36    pub metadata: RelationshipMetadata,
37}
38
39/// Metadata for a generated relationship.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct RelationshipMetadata {
42    /// Data source.
43    pub source: String,
44    /// Generation timestamp.
45    pub generated_at: DateTime<Utc>,
46    /// Relationship weight.
47    pub weight: Option<f64>,
48    /// Valid from timestamp.
49    pub valid_from: Option<DateTime<Utc>>,
50    /// Valid to timestamp.
51    pub valid_to: Option<DateTime<Utc>>,
52    /// Custom labels.
53    pub labels: HashMap<String, String>,
54    /// Feature vector for ML.
55    pub features: Option<Vec<f64>>,
56    /// Whether the relationship is directed.
57    pub is_directed: bool,
58}
59
60impl Default for RelationshipMetadata {
61    fn default() -> Self {
62        Self {
63            source: "datasynth".to_string(),
64            generated_at: Utc::now(),
65            weight: None,
66            valid_from: None,
67            valid_to: None,
68            labels: HashMap::new(),
69            features: None,
70            is_directed: true,
71        }
72    }
73}
74
75/// Simple node representation for relationship generation.
76#[derive(Debug, Clone)]
77pub struct NodeRef {
78    /// Node ID.
79    pub id: String,
80    /// Node type.
81    pub node_type: String,
82    /// Node properties.
83    pub properties: HashMap<String, Value>,
84}
85
86impl NodeRef {
87    /// Creates a new node reference.
88    pub fn new(id: impl Into<String>, node_type: impl Into<String>) -> Self {
89        Self {
90            id: id.into(),
91            node_type: node_type.into(),
92            properties: HashMap::new(),
93        }
94    }
95
96    /// Adds a property.
97    pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
98        self.properties.insert(key.into(), value);
99        self
100    }
101}
102
103/// Generator for relationships between entities.
104pub struct RelationshipGenerator {
105    /// Configuration.
106    config: RelationshipConfig,
107    /// Random number generator.
108    rng: ChaCha8Rng,
109    /// Deterministic UUID factory.
110    uuid_factory: DeterministicUuidFactory,
111    /// Generation count.
112    count: u64,
113    /// Track relationships by source ID for cardinality validation.
114    relationships_by_source: HashMap<String, HashMap<String, Vec<String>>>,
115    /// Track relationships by target ID for cardinality validation.
116    relationships_by_target: HashMap<String, HashMap<String, Vec<String>>>,
117    /// Visited nodes for circular detection.
118    visited: HashSet<String>,
119}
120
121impl RelationshipGenerator {
122    /// Creates a new relationship generator.
123    pub fn new(config: RelationshipConfig, seed: u64) -> Self {
124        Self {
125            config,
126            rng: seeded_rng(seed, 0),
127            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::Customer),
128            count: 0,
129            relationships_by_source: HashMap::new(),
130            relationships_by_target: HashMap::new(),
131            visited: HashSet::new(),
132        }
133    }
134
135    /// Creates a generator with default configuration.
136    pub fn with_defaults(seed: u64) -> Self {
137        Self::new(RelationshipConfig::default(), seed)
138    }
139
140    /// Generates relationships for a set of nodes.
141    pub fn generate_relationships(&mut self, nodes: &[NodeRef]) -> Vec<GeneratedRelationship> {
142        let mut relationships = Vec::new();
143
144        // Group nodes by type
145        let nodes_by_type = self.group_nodes_by_type(nodes);
146
147        // Clone relationship types to avoid borrow issues
148        let relationship_types = self.config.relationship_types.clone();
149
150        // For each relationship type, generate relationships
151        for rel_type in &relationship_types {
152            let rels = self.generate_for_type(rel_type, &nodes_by_type);
153            relationships.extend(rels);
154        }
155
156        relationships
157    }
158
159    /// Generates relationships for a single node.
160    pub fn generate_for_node(
161        &mut self,
162        node: &NodeRef,
163        available_targets: &HashMap<String, Vec<NodeRef>>,
164    ) -> Vec<GeneratedRelationship> {
165        // Check for orphan generation
166        if self.config.allow_orphans && self.rng.gen_bool(self.config.orphan_probability) {
167            return Vec::new();
168        }
169
170        let mut relationships = Vec::new();
171
172        // Clone applicable relationship types to avoid borrow issues
173        let applicable_types: Vec<_> = self
174            .config
175            .relationship_types
176            .iter()
177            .filter(|rt| rt.source_type == node.node_type)
178            .cloned()
179            .collect();
180
181        for rel_type in &applicable_types {
182            if let Some(targets) = available_targets.get(&rel_type.target_type) {
183                let rels = self.generate_edges_for_node(node, targets, rel_type);
184                relationships.extend(rels);
185            }
186        }
187
188        relationships
189    }
190
191    /// Checks if a relationship would create a valid cardinality.
192    pub fn check_cardinality(
193        &self,
194        source_id: &str,
195        target_id: &str,
196        rel_type: &str,
197    ) -> RelationshipValidation {
198        // Find the relationship type config
199        let type_config = self
200            .config
201            .relationship_types
202            .iter()
203            .find(|rt| rt.name == rel_type);
204
205        let Some(type_config) = type_config else {
206            return RelationshipValidation::invalid(format!(
207                "Unknown relationship type: {}",
208                rel_type
209            ));
210        };
211
212        let (_min, max) = type_config.cardinality.bounds();
213
214        // Check source-side cardinality
215        let current_count = self
216            .relationships_by_source
217            .get(source_id)
218            .and_then(|m| m.get(rel_type))
219            .map(|v| v.len())
220            .unwrap_or(0);
221
222        if current_count >= max as usize {
223            return RelationshipValidation::invalid(format!(
224                "Source {} already has maximum {} {} relationships",
225                source_id, max, rel_type
226            ));
227        }
228
229        // For OneToOne and ManyToOne, check if target already has a relationship
230        if matches!(
231            type_config.cardinality,
232            CardinalityRule::OneToOne | CardinalityRule::ManyToOne { .. }
233        ) {
234            let target_count = self
235                .relationships_by_target
236                .get(target_id)
237                .and_then(|m| m.get(rel_type))
238                .map(|v| v.len())
239                .unwrap_or(0);
240
241            if target_count > 0 {
242                return RelationshipValidation::invalid(format!(
243                    "Target {} already has a {} relationship",
244                    target_id, rel_type
245                ));
246            }
247        }
248
249        RelationshipValidation::valid()
250    }
251
252    /// Checks if a relationship would create a circular reference.
253    pub fn check_circular(&mut self, source_id: &str, target_id: &str) -> bool {
254        if !self.config.allow_circular {
255            // Simple check: direct circular reference
256            if source_id == target_id {
257                return true;
258            }
259
260            // DFS to check for circular paths
261            self.visited.clear();
262            self.visited.insert(source_id.to_string());
263
264            return self.has_path_to(target_id, source_id, 0);
265        }
266
267        false
268    }
269
270    /// Returns the number of relationships generated.
271    pub fn count(&self) -> u64 {
272        self.count
273    }
274
275    /// Resets the generator.
276    pub fn reset(&mut self, seed: u64) {
277        self.rng = seeded_rng(seed, 0);
278        self.uuid_factory = DeterministicUuidFactory::new(seed, GeneratorType::Customer);
279        self.count = 0;
280        self.relationships_by_source.clear();
281        self.relationships_by_target.clear();
282        self.visited.clear();
283    }
284
285    /// Returns the configuration.
286    pub fn config(&self) -> &RelationshipConfig {
287        &self.config
288    }
289
290    /// Groups nodes by their type.
291    fn group_nodes_by_type(&self, nodes: &[NodeRef]) -> HashMap<String, Vec<NodeRef>> {
292        let mut grouped: HashMap<String, Vec<NodeRef>> = HashMap::new();
293
294        for node in nodes {
295            grouped
296                .entry(node.node_type.clone())
297                .or_default()
298                .push(node.clone());
299        }
300
301        grouped
302    }
303
304    /// Generates relationships for a specific relationship type.
305    fn generate_for_type(
306        &mut self,
307        rel_type: &RelationshipTypeConfig,
308        nodes_by_type: &HashMap<String, Vec<NodeRef>>,
309    ) -> Vec<GeneratedRelationship> {
310        let mut relationships = Vec::new();
311
312        let Some(source_nodes) = nodes_by_type.get(&rel_type.source_type) else {
313            return relationships;
314        };
315
316        let Some(target_nodes) = nodes_by_type.get(&rel_type.target_type) else {
317            return relationships;
318        };
319
320        for source in source_nodes {
321            let rels = self.generate_edges_for_node(source, target_nodes, rel_type);
322            relationships.extend(rels);
323        }
324
325        relationships
326    }
327
328    /// Generates edges from a single source node.
329    fn generate_edges_for_node(
330        &mut self,
331        source: &NodeRef,
332        targets: &[NodeRef],
333        rel_type: &RelationshipTypeConfig,
334    ) -> Vec<GeneratedRelationship> {
335        let mut relationships = Vec::new();
336
337        if targets.is_empty() {
338            return relationships;
339        }
340
341        // Determine number of relationships based on cardinality
342        let (min, max) = rel_type.cardinality.bounds();
343        let count = if min == max {
344            min as usize
345        } else {
346            self.rng.gen_range(min..=max) as usize
347        };
348
349        // Filter available targets
350        let available_targets: Vec<_> = targets
351            .iter()
352            .filter(|t| {
353                // Check if this relationship is valid
354                let validation = self.check_cardinality(&source.id, &t.id, &rel_type.name);
355                if !validation.valid {
356                    return false;
357                }
358
359                // Check for circular references
360                if self.check_circular(&source.id, &t.id) {
361                    return false;
362                }
363
364                true
365            })
366            .collect();
367
368        if available_targets.is_empty() && rel_type.required {
369            // Log warning or handle required relationship with no valid targets
370            return relationships;
371        }
372
373        // Select targets
374        let selected_count = count.min(available_targets.len());
375        let mut selected_indices: Vec<usize> = (0..available_targets.len()).collect();
376        selected_indices.shuffle(&mut self.rng);
377        selected_indices.truncate(selected_count);
378
379        for idx in selected_indices {
380            let target = available_targets[idx];
381            let relationship = self.create_relationship(source, target, rel_type);
382
383            // Track the relationship for cardinality validation
384            self.track_relationship(&source.id, &target.id, &rel_type.name);
385
386            relationships.push(relationship);
387        }
388
389        relationships
390    }
391
392    /// Creates a single relationship.
393    fn create_relationship(
394        &mut self,
395        source: &NodeRef,
396        target: &NodeRef,
397        rel_type: &RelationshipTypeConfig,
398    ) -> GeneratedRelationship {
399        self.count += 1;
400
401        let id = self.uuid_factory.next().to_string();
402        let properties = self.generate_properties(source, target, &rel_type.properties);
403
404        let metadata = RelationshipMetadata {
405            source: "datasynth".to_string(),
406            generated_at: Utc::now(),
407            weight: Some(rel_type.weight),
408            valid_from: None,
409            valid_to: None,
410            labels: HashMap::new(),
411            features: None,
412            is_directed: rel_type.directed,
413        };
414
415        GeneratedRelationship {
416            relationship_type: rel_type.name.clone(),
417            id,
418            source_id: source.id.clone(),
419            target_id: target.id.clone(),
420            properties,
421            metadata,
422        }
423    }
424
425    /// Generates properties for a relationship.
426    fn generate_properties(
427        &mut self,
428        source: &NodeRef,
429        target: &NodeRef,
430        rules: &[super::rules::PropertyGenerationRule],
431    ) -> HashMap<String, Value> {
432        let mut properties = HashMap::new();
433
434        for rule in rules {
435            let value =
436                self.generate_property_value(source, target, &rule.generator, &rule.value_type);
437            properties.insert(rule.name.clone(), value);
438        }
439
440        properties
441    }
442
443    /// Generates a single property value.
444    fn generate_property_value(
445        &mut self,
446        source: &NodeRef,
447        target: &NodeRef,
448        generator: &PropertyGenerator,
449        value_type: &PropertyValueType,
450    ) -> Value {
451        match generator {
452            PropertyGenerator::Constant(value) => value.clone(),
453
454            PropertyGenerator::RandomChoice(choices) => {
455                if choices.is_empty() {
456                    Value::Null
457                } else {
458                    let idx = self.rng.gen_range(0..choices.len());
459                    choices[idx].clone()
460                }
461            }
462
463            PropertyGenerator::Range { min, max } => {
464                let value = self.rng.gen_range(*min..=*max);
465                match value_type {
466                    PropertyValueType::Integer => {
467                        Value::Number(serde_json::Number::from(value as i64))
468                    }
469                    _ => Value::Number(
470                        serde_json::Number::from_f64(value)
471                            .unwrap_or_else(|| serde_json::Number::from(0)),
472                    ),
473                }
474            }
475
476            PropertyGenerator::FromSourceProperty(prop_name) => source
477                .properties
478                .get(prop_name)
479                .cloned()
480                .unwrap_or(Value::Null),
481
482            PropertyGenerator::FromTargetProperty(prop_name) => target
483                .properties
484                .get(prop_name)
485                .cloned()
486                .unwrap_or(Value::Null),
487
488            PropertyGenerator::Uuid => Value::String(self.uuid_factory.next().to_string()),
489
490            PropertyGenerator::Timestamp => Value::String(Utc::now().to_rfc3339()),
491        }
492    }
493
494    /// Tracks a relationship for cardinality validation.
495    fn track_relationship(&mut self, source_id: &str, target_id: &str, rel_type: &str) {
496        // Track by source
497        self.relationships_by_source
498            .entry(source_id.to_string())
499            .or_default()
500            .entry(rel_type.to_string())
501            .or_default()
502            .push(target_id.to_string());
503
504        // Track by target
505        self.relationships_by_target
506            .entry(target_id.to_string())
507            .or_default()
508            .entry(rel_type.to_string())
509            .or_default()
510            .push(source_id.to_string());
511    }
512
513    /// DFS to check if there's a path from current to target.
514    fn has_path_to(&mut self, current: &str, target: &str, depth: u32) -> bool {
515        if depth >= self.config.max_circular_depth {
516            return false;
517        }
518
519        if current == target {
520            return true;
521        }
522
523        if self.visited.contains(current) {
524            return false;
525        }
526
527        self.visited.insert(current.to_string());
528
529        // Collect all next nodes to avoid holding borrow during recursion
530        let next_nodes: Vec<String> = self
531            .relationships_by_source
532            .get(current)
533            .map(|outgoing| outgoing.values().flatten().cloned().collect())
534            .unwrap_or_default();
535
536        // Now check paths without holding the borrow
537        for next in next_nodes {
538            if self.has_path_to(&next, target, depth + 1) {
539                return true;
540            }
541        }
542
543        false
544    }
545}
546
547/// Builder for relationship configuration.
548pub struct RelationshipConfigBuilder {
549    config: RelationshipConfig,
550}
551
552impl RelationshipConfigBuilder {
553    /// Creates a new builder.
554    pub fn new() -> Self {
555        Self {
556            config: RelationshipConfig::default(),
557        }
558    }
559
560    /// Adds a relationship type.
561    pub fn add_type(mut self, type_config: RelationshipTypeConfig) -> Self {
562        self.config.relationship_types.push(type_config);
563        self
564    }
565
566    /// Sets whether orphans are allowed.
567    pub fn allow_orphans(mut self, allow: bool) -> Self {
568        self.config.allow_orphans = allow;
569        self
570    }
571
572    /// Sets the orphan probability.
573    pub fn orphan_probability(mut self, prob: f64) -> Self {
574        self.config.orphan_probability = prob.clamp(0.0, 1.0);
575        self
576    }
577
578    /// Sets whether circular relationships are allowed.
579    pub fn allow_circular(mut self, allow: bool) -> Self {
580        self.config.allow_circular = allow;
581        self
582    }
583
584    /// Sets the maximum circular depth.
585    pub fn max_circular_depth(mut self, depth: u32) -> Self {
586        self.config.max_circular_depth = depth;
587        self
588    }
589
590    /// Builds the configuration.
591    pub fn build(self) -> RelationshipConfig {
592        self.config
593    }
594}
595
596impl Default for RelationshipConfigBuilder {
597    fn default() -> Self {
598        Self::new()
599    }
600}
601
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::rules::PropertyGenerationRule;
    use super::*;

    /// Mixed node set: two journal entries, three accounts and one user.
    fn create_test_nodes() -> Vec<NodeRef> {
        const NODES: [(&str, &str); 6] = [
            ("je_1", "journal_entry"),
            ("je_2", "journal_entry"),
            ("acc_1", "account"),
            ("acc_2", "account"),
            ("acc_3", "account"),
            ("user_1", "user"),
        ];
        NODES
            .iter()
            .map(|&(id, node_type)| NodeRef::new(id, node_type))
            .collect()
    }

    /// Configuration containing exactly one relationship type.
    fn config_with(rel: RelationshipTypeConfig) -> RelationshipConfig {
        RelationshipConfig::with_types(vec![rel])
    }

    /// A single `source` node `s1` and a single `target` node `t1`.
    fn source_and_target() -> Vec<NodeRef> {
        vec![NodeRef::new("s1", "source"), NodeRef::new("t1", "target")]
    }

    #[test]
    fn test_generate_relationships() {
        let debits = RelationshipTypeConfig::new("debits", "journal_entry", "account")
            .with_cardinality(CardinalityRule::one_to_many(1, 2));
        let mut generator = RelationshipGenerator::new(config_with(debits), 42);

        let generated = generator.generate_relationships(&create_test_nodes());

        assert!(!generated.is_empty());
        for edge in generated.iter() {
            assert_eq!(edge.relationship_type, "debits");
            assert!(edge.source_id.starts_with("je_"));
            assert!(edge.target_id.starts_with("acc_"));
        }
    }

    #[test]
    fn test_cardinality_validation() {
        let debits = RelationshipTypeConfig::new("debits", "journal_entry", "account")
            .with_cardinality(CardinalityRule::one_to_one());
        let generator = RelationshipGenerator::new(config_with(debits), 42);

        // Nothing tracked yet, so a configured type is accepted...
        assert!(generator.check_cardinality("je_1", "acc_1", "debits").valid);
        // ...while an unconfigured type is always rejected.
        assert!(!generator.check_cardinality("je_1", "acc_1", "unknown").valid);
    }

    #[test]
    fn test_circular_detection() {
        let config = RelationshipConfig::default()
            .allow_circular(false)
            .max_circular_depth(3);
        let mut generator = RelationshipGenerator::new(config, 42);

        // A self-loop is circular by definition.
        assert!(generator.check_circular("a", "a"));
        // No edges are tracked, so two distinct nodes cannot form a cycle.
        assert!(!generator.check_circular("a", "b"));
    }

    #[test]
    fn test_property_generation() {
        let rel = RelationshipTypeConfig::new("test", "source", "target")
            .with_property(PropertyGenerationRule::range("amount", 100.0, 1000.0))
            .with_property(PropertyGenerationRule::constant_string("status", "active"));
        let mut generator = RelationshipGenerator::new(config_with(rel), 42);

        let generated = generator.generate_relationships(&source_and_target());

        assert!(!generated.is_empty());
        let props = &generated[0].properties;
        assert!(props.contains_key("amount"));
        assert!(props.contains_key("status"));
        assert_eq!(props.get("status"), Some(&Value::String("active".into())));
    }

    #[test]
    fn test_orphan_generation() {
        let rel = RelationshipTypeConfig::new("test", "source", "target")
            .with_cardinality(CardinalityRule::one_to_one());
        // Probability 1.0: every node must come out as an orphan.
        let config = config_with(rel)
            .allow_orphans(true)
            .orphan_probability(1.0);
        let mut generator = RelationshipGenerator::new(config, 42);

        let candidates: HashMap<String, Vec<NodeRef>> = HashMap::from([(
            "target".to_string(),
            vec![NodeRef::new("t1", "target")],
        )]);
        let generated = generator.generate_for_node(&NodeRef::new("s1", "source"), &candidates);

        assert!(generated.is_empty());
    }

    #[test]
    fn test_config_builder() {
        let built = RelationshipConfigBuilder::new()
            .add_type(RelationshipTypeConfig::new("test", "a", "b"))
            .allow_orphans(false)
            .orphan_probability(0.1)
            .allow_circular(true)
            .max_circular_depth(5)
            .build();

        assert_eq!(built.relationship_types.len(), 1);
        assert!(!built.allow_orphans);
        assert_eq!(built.orphan_probability, 0.1);
        assert!(built.allow_circular);
        assert_eq!(built.max_circular_depth, 5);
    }

    #[test]
    fn test_generator_count_and_reset() {
        let rel = RelationshipTypeConfig::new("test", "source", "target")
            .with_cardinality(CardinalityRule::one_to_one());
        let mut generator = RelationshipGenerator::new(config_with(rel), 42);
        assert_eq!(generator.count(), 0);

        generator.generate_relationships(&source_and_target());
        assert!(generator.count() > 0);

        generator.reset(42);
        assert_eq!(generator.count(), 0);
    }
}