// datasynth_generators/relationships/generator.rs

//! Relationship generator implementation.
//!
//! Provides generation of relationships between entities based on
//! cardinality rules and property generation configurations.

use std::collections::{HashMap, HashSet};

use chrono::{DateTime, Utc};
use datasynth_core::utils::seeded_rng;
use datasynth_core::uuid_factory::{DeterministicUuidFactory, GeneratorType};
use rand::prelude::*;
use rand_chacha::ChaCha8Rng;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use super::rules::{
    CardinalityRule, PropertyGenerator, PropertyValueType, RelationshipConfig,
    RelationshipTypeConfig, RelationshipValidation,
};

22/// Generated relationship output.
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct GeneratedRelationship {
25    /// Relationship type name.
26    pub relationship_type: String,
27    /// Unique relationship ID.
28    pub id: String,
29    /// Source entity ID.
30    pub source_id: String,
31    /// Target entity ID.
32    pub target_id: String,
33    /// Relationship properties.
34    pub properties: HashMap<String, Value>,
35    /// Relationship metadata.
36    pub metadata: RelationshipMetadata,
37}
38
39/// Metadata for a generated relationship.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct RelationshipMetadata {
42    /// Data source.
43    pub source: String,
44    /// Generation timestamp.
45    pub generated_at: DateTime<Utc>,
46    /// Relationship weight.
47    pub weight: Option<f64>,
48    /// Valid from timestamp.
49    pub valid_from: Option<DateTime<Utc>>,
50    /// Valid to timestamp.
51    pub valid_to: Option<DateTime<Utc>>,
52    /// Custom labels.
53    pub labels: HashMap<String, String>,
54    /// Feature vector for ML.
55    pub features: Option<Vec<f64>>,
56    /// Whether the relationship is directed.
57    pub is_directed: bool,
58}
59
60impl Default for RelationshipMetadata {
61    fn default() -> Self {
62        Self {
63            source: "datasynth".to_string(),
64            generated_at: Utc::now(),
65            weight: None,
66            valid_from: None,
67            valid_to: None,
68            labels: HashMap::new(),
69            features: None,
70            is_directed: true,
71        }
72    }
73}
74
75/// Simple node representation for relationship generation.
76#[derive(Debug, Clone)]
77pub struct NodeRef {
78    /// Node ID.
79    pub id: String,
80    /// Node type.
81    pub node_type: String,
82    /// Node properties.
83    pub properties: HashMap<String, Value>,
84}
85
86impl NodeRef {
87    /// Creates a new node reference.
88    pub fn new(id: impl Into<String>, node_type: impl Into<String>) -> Self {
89        Self {
90            id: id.into(),
91            node_type: node_type.into(),
92            properties: HashMap::new(),
93        }
94    }
95
96    /// Adds a property.
97    pub fn with_property(mut self, key: impl Into<String>, value: Value) -> Self {
98        self.properties.insert(key.into(), value);
99        self
100    }
101}
102
103/// Generator for relationships between entities.
104pub struct RelationshipGenerator {
105    /// Configuration.
106    config: RelationshipConfig,
107    /// Random number generator.
108    rng: ChaCha8Rng,
109    /// Deterministic UUID factory.
110    uuid_factory: DeterministicUuidFactory,
111    /// Generation count.
112    count: u64,
113    /// Track relationships by source ID for cardinality validation.
114    relationships_by_source: HashMap<String, HashMap<String, Vec<String>>>,
115    /// Track relationships by target ID for cardinality validation.
116    relationships_by_target: HashMap<String, HashMap<String, Vec<String>>>,
117    /// Visited nodes for circular detection.
118    visited: HashSet<String>,
119}
120
121impl RelationshipGenerator {
122    /// Creates a new relationship generator.
123    pub fn new(config: RelationshipConfig, seed: u64) -> Self {
124        Self {
125            config,
126            rng: seeded_rng(seed, 0),
127            uuid_factory: DeterministicUuidFactory::new(seed, GeneratorType::Customer),
128            count: 0,
129            relationships_by_source: HashMap::new(),
130            relationships_by_target: HashMap::new(),
131            visited: HashSet::new(),
132        }
133    }
134
135    /// Creates a generator with default configuration.
136    pub fn with_defaults(seed: u64) -> Self {
137        Self::new(RelationshipConfig::default(), seed)
138    }
139
140    /// Generates relationships for a set of nodes.
141    pub fn generate_relationships(&mut self, nodes: &[NodeRef]) -> Vec<GeneratedRelationship> {
142        let mut relationships = Vec::new();
143
144        // Group nodes by type
145        let nodes_by_type = self.group_nodes_by_type(nodes);
146
147        // Clone relationship types to avoid borrow issues
148        let relationship_types = self.config.relationship_types.clone();
149
150        // For each relationship type, generate relationships
151        for rel_type in &relationship_types {
152            let rels = self.generate_for_type(rel_type, &nodes_by_type);
153            relationships.extend(rels);
154        }
155
156        relationships
157    }
158
159    /// Generates relationships for a single node.
160    pub fn generate_for_node(
161        &mut self,
162        node: &NodeRef,
163        available_targets: &HashMap<String, Vec<NodeRef>>,
164    ) -> Vec<GeneratedRelationship> {
165        // Check for orphan generation
166        if self.config.allow_orphans && self.rng.random_bool(self.config.orphan_probability) {
167            return Vec::new();
168        }
169
170        let mut relationships = Vec::new();
171
172        // Clone applicable relationship types to avoid borrow issues
173        let applicable_types: Vec<_> = self
174            .config
175            .relationship_types
176            .iter()
177            .filter(|rt| rt.source_type == node.node_type)
178            .cloned()
179            .collect();
180
181        for rel_type in &applicable_types {
182            if let Some(targets) = available_targets.get(&rel_type.target_type) {
183                let rels = self.generate_edges_for_node(node, targets, rel_type);
184                relationships.extend(rels);
185            }
186        }
187
188        relationships
189    }
190
191    /// Checks if a relationship would create a valid cardinality.
192    pub fn check_cardinality(
193        &self,
194        source_id: &str,
195        target_id: &str,
196        rel_type: &str,
197    ) -> RelationshipValidation {
198        // Find the relationship type config
199        let type_config = self
200            .config
201            .relationship_types
202            .iter()
203            .find(|rt| rt.name == rel_type);
204
205        let Some(type_config) = type_config else {
206            return RelationshipValidation::invalid(format!(
207                "Unknown relationship type: {rel_type}"
208            ));
209        };
210
211        let (_min, max) = type_config.cardinality.bounds();
212
213        // Check source-side cardinality
214        let current_count = self
215            .relationships_by_source
216            .get(source_id)
217            .and_then(|m| m.get(rel_type))
218            .map(std::vec::Vec::len)
219            .unwrap_or(0);
220
221        if current_count >= max as usize {
222            return RelationshipValidation::invalid(format!(
223                "Source {source_id} already has maximum {max} {rel_type} relationships"
224            ));
225        }
226
227        // For OneToOne and ManyToOne, check if target already has a relationship
228        if matches!(
229            type_config.cardinality,
230            CardinalityRule::OneToOne | CardinalityRule::ManyToOne { .. }
231        ) {
232            let target_count = self
233                .relationships_by_target
234                .get(target_id)
235                .and_then(|m| m.get(rel_type))
236                .map(std::vec::Vec::len)
237                .unwrap_or(0);
238
239            if target_count > 0 {
240                return RelationshipValidation::invalid(format!(
241                    "Target {target_id} already has a {rel_type} relationship"
242                ));
243            }
244        }
245
246        RelationshipValidation::valid()
247    }
248
249    /// Checks if a relationship would create a circular reference.
250    pub fn check_circular(&mut self, source_id: &str, target_id: &str) -> bool {
251        if !self.config.allow_circular {
252            // Simple check: direct circular reference
253            if source_id == target_id {
254                return true;
255            }
256
257            // DFS to check for circular paths
258            self.visited.clear();
259            self.visited.insert(source_id.to_string());
260
261            return self.has_path_to(target_id, source_id, 0);
262        }
263
264        false
265    }
266
267    /// Returns the number of relationships generated.
268    pub fn count(&self) -> u64 {
269        self.count
270    }
271
272    /// Resets the generator.
273    pub fn reset(&mut self, seed: u64) {
274        self.rng = seeded_rng(seed, 0);
275        self.uuid_factory = DeterministicUuidFactory::new(seed, GeneratorType::Customer);
276        self.count = 0;
277        self.relationships_by_source.clear();
278        self.relationships_by_target.clear();
279        self.visited.clear();
280    }
281
282    /// Returns the configuration.
283    pub fn config(&self) -> &RelationshipConfig {
284        &self.config
285    }
286
287    /// Groups nodes by their type.
288    fn group_nodes_by_type(&self, nodes: &[NodeRef]) -> HashMap<String, Vec<NodeRef>> {
289        let mut grouped: HashMap<String, Vec<NodeRef>> = HashMap::new();
290
291        for node in nodes {
292            grouped
293                .entry(node.node_type.clone())
294                .or_default()
295                .push(node.clone());
296        }
297
298        grouped
299    }
300
301    /// Generates relationships for a specific relationship type.
302    fn generate_for_type(
303        &mut self,
304        rel_type: &RelationshipTypeConfig,
305        nodes_by_type: &HashMap<String, Vec<NodeRef>>,
306    ) -> Vec<GeneratedRelationship> {
307        let mut relationships = Vec::new();
308
309        let Some(source_nodes) = nodes_by_type.get(&rel_type.source_type) else {
310            return relationships;
311        };
312
313        let Some(target_nodes) = nodes_by_type.get(&rel_type.target_type) else {
314            return relationships;
315        };
316
317        for source in source_nodes {
318            let rels = self.generate_edges_for_node(source, target_nodes, rel_type);
319            relationships.extend(rels);
320        }
321
322        relationships
323    }
324
325    /// Generates edges from a single source node.
326    fn generate_edges_for_node(
327        &mut self,
328        source: &NodeRef,
329        targets: &[NodeRef],
330        rel_type: &RelationshipTypeConfig,
331    ) -> Vec<GeneratedRelationship> {
332        let mut relationships = Vec::new();
333
334        if targets.is_empty() {
335            return relationships;
336        }
337
338        // Determine number of relationships based on cardinality
339        let (min, max) = rel_type.cardinality.bounds();
340        let count = if min == max {
341            min as usize
342        } else {
343            self.rng.random_range(min..=max) as usize
344        };
345
346        // Filter available targets
347        let available_targets: Vec<_> = targets
348            .iter()
349            .filter(|t| {
350                // Check if this relationship is valid
351                let validation = self.check_cardinality(&source.id, &t.id, &rel_type.name);
352                if !validation.valid {
353                    return false;
354                }
355
356                // Check for circular references
357                if self.check_circular(&source.id, &t.id) {
358                    return false;
359                }
360
361                true
362            })
363            .collect();
364
365        if available_targets.is_empty() && rel_type.required {
366            // Log warning or handle required relationship with no valid targets
367            return relationships;
368        }
369
370        // Select targets
371        let selected_count = count.min(available_targets.len());
372        let mut selected_indices: Vec<usize> = (0..available_targets.len()).collect();
373        selected_indices.shuffle(&mut self.rng);
374        selected_indices.truncate(selected_count);
375
376        for idx in selected_indices {
377            let target = available_targets[idx];
378            let relationship = self.create_relationship(source, target, rel_type);
379
380            // Track the relationship for cardinality validation
381            self.track_relationship(&source.id, &target.id, &rel_type.name);
382
383            relationships.push(relationship);
384        }
385
386        relationships
387    }
388
389    /// Creates a single relationship.
390    fn create_relationship(
391        &mut self,
392        source: &NodeRef,
393        target: &NodeRef,
394        rel_type: &RelationshipTypeConfig,
395    ) -> GeneratedRelationship {
396        self.count += 1;
397
398        let id = self.uuid_factory.next().to_string();
399        let properties = self.generate_properties(source, target, &rel_type.properties);
400
401        let metadata = RelationshipMetadata {
402            source: "datasynth".to_string(),
403            generated_at: Utc::now(),
404            weight: Some(rel_type.weight),
405            valid_from: None,
406            valid_to: None,
407            labels: HashMap::new(),
408            features: None,
409            is_directed: rel_type.directed,
410        };
411
412        GeneratedRelationship {
413            relationship_type: rel_type.name.clone(),
414            id,
415            source_id: source.id.clone(),
416            target_id: target.id.clone(),
417            properties,
418            metadata,
419        }
420    }
421
422    /// Generates properties for a relationship.
423    fn generate_properties(
424        &mut self,
425        source: &NodeRef,
426        target: &NodeRef,
427        rules: &[super::rules::PropertyGenerationRule],
428    ) -> HashMap<String, Value> {
429        let mut properties = HashMap::new();
430
431        for rule in rules {
432            let value =
433                self.generate_property_value(source, target, &rule.generator, &rule.value_type);
434            properties.insert(rule.name.clone(), value);
435        }
436
437        properties
438    }
439
440    /// Generates a single property value.
441    fn generate_property_value(
442        &mut self,
443        source: &NodeRef,
444        target: &NodeRef,
445        generator: &PropertyGenerator,
446        value_type: &PropertyValueType,
447    ) -> Value {
448        match generator {
449            PropertyGenerator::Constant(value) => value.clone(),
450
451            PropertyGenerator::RandomChoice(choices) => {
452                if choices.is_empty() {
453                    Value::Null
454                } else {
455                    let idx = self.rng.random_range(0..choices.len());
456                    choices[idx].clone()
457                }
458            }
459
460            PropertyGenerator::Range { min, max } => {
461                let value = self.rng.random_range(*min..=*max);
462                match value_type {
463                    PropertyValueType::Integer => {
464                        Value::Number(serde_json::Number::from(value as i64))
465                    }
466                    _ => Value::Number(
467                        serde_json::Number::from_f64(value)
468                            .unwrap_or_else(|| serde_json::Number::from(0)),
469                    ),
470                }
471            }
472
473            PropertyGenerator::FromSourceProperty(prop_name) => source
474                .properties
475                .get(prop_name)
476                .cloned()
477                .unwrap_or(Value::Null),
478
479            PropertyGenerator::FromTargetProperty(prop_name) => target
480                .properties
481                .get(prop_name)
482                .cloned()
483                .unwrap_or(Value::Null),
484
485            PropertyGenerator::Uuid => Value::String(self.uuid_factory.next().to_string()),
486
487            PropertyGenerator::Timestamp => Value::String(Utc::now().to_rfc3339()),
488        }
489    }
490
491    /// Tracks a relationship for cardinality validation.
492    fn track_relationship(&mut self, source_id: &str, target_id: &str, rel_type: &str) {
493        // Track by source
494        self.relationships_by_source
495            .entry(source_id.to_string())
496            .or_default()
497            .entry(rel_type.to_string())
498            .or_default()
499            .push(target_id.to_string());
500
501        // Track by target
502        self.relationships_by_target
503            .entry(target_id.to_string())
504            .or_default()
505            .entry(rel_type.to_string())
506            .or_default()
507            .push(source_id.to_string());
508    }
509
510    /// DFS to check if there's a path from current to target.
511    fn has_path_to(&mut self, current: &str, target: &str, depth: u32) -> bool {
512        if depth >= self.config.max_circular_depth {
513            return false;
514        }
515
516        if current == target {
517            return true;
518        }
519
520        if self.visited.contains(current) {
521            return false;
522        }
523
524        self.visited.insert(current.to_string());
525
526        // Collect all next nodes to avoid holding borrow during recursion
527        let next_nodes: Vec<String> = self
528            .relationships_by_source
529            .get(current)
530            .map(|outgoing| outgoing.values().flatten().cloned().collect())
531            .unwrap_or_default();
532
533        // Now check paths without holding the borrow
534        for next in next_nodes {
535            if self.has_path_to(&next, target, depth + 1) {
536                return true;
537            }
538        }
539
540        false
541    }
542}
543
544/// Builder for relationship configuration.
545pub struct RelationshipConfigBuilder {
546    config: RelationshipConfig,
547}
548
549impl RelationshipConfigBuilder {
550    /// Creates a new builder.
551    pub fn new() -> Self {
552        Self {
553            config: RelationshipConfig::default(),
554        }
555    }
556
557    /// Adds a relationship type.
558    pub fn add_type(mut self, type_config: RelationshipTypeConfig) -> Self {
559        self.config.relationship_types.push(type_config);
560        self
561    }
562
563    /// Sets whether orphans are allowed.
564    pub fn allow_orphans(mut self, allow: bool) -> Self {
565        self.config.allow_orphans = allow;
566        self
567    }
568
569    /// Sets the orphan probability.
570    pub fn orphan_probability(mut self, prob: f64) -> Self {
571        self.config.orphan_probability = prob.clamp(0.0, 1.0);
572        self
573    }
574
575    /// Sets whether circular relationships are allowed.
576    pub fn allow_circular(mut self, allow: bool) -> Self {
577        self.config.allow_circular = allow;
578        self
579    }
580
581    /// Sets the maximum circular depth.
582    pub fn max_circular_depth(mut self, depth: u32) -> Self {
583        self.config.max_circular_depth = depth;
584        self
585    }
586
587    /// Builds the configuration.
588    pub fn build(self) -> RelationshipConfig {
589        self.config
590    }
591}
592
593impl Default for RelationshipConfigBuilder {
594    fn default() -> Self {
595        Self::new()
596    }
597}
598
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::super::rules::PropertyGenerationRule;
    use super::*;

    /// Two journal entries, three accounts and one user.
    fn create_test_nodes() -> Vec<NodeRef> {
        vec![
            NodeRef::new("je_1", "journal_entry"),
            NodeRef::new("je_2", "journal_entry"),
            NodeRef::new("acc_1", "account"),
            NodeRef::new("acc_2", "account"),
            NodeRef::new("acc_3", "account"),
            NodeRef::new("user_1", "user"),
        ]
    }

    #[test]
    fn test_generate_relationships() {
        let config = RelationshipConfig::with_types(vec![RelationshipTypeConfig::new(
            "debits",
            "journal_entry",
            "account",
        )
        .with_cardinality(CardinalityRule::one_to_many(1, 2))]);

        let mut generator = RelationshipGenerator::new(config, 42);
        let nodes = create_test_nodes();
        let relationships = generator.generate_relationships(&nodes);

        assert!(!relationships.is_empty());
        for rel in &relationships {
            assert_eq!(rel.relationship_type, "debits");
            assert!(rel.source_id.starts_with("je_"));
            assert!(rel.target_id.starts_with("acc_"));
        }
    }

    #[test]
    fn test_cardinality_validation() {
        let config = RelationshipConfig::with_types(vec![RelationshipTypeConfig::new(
            "debits",
            "journal_entry",
            "account",
        )
        .with_cardinality(CardinalityRule::one_to_one())]);

        let generator = RelationshipGenerator::new(config, 42);

        let validation = generator.check_cardinality("je_1", "acc_1", "debits");
        assert!(validation.valid);

        let validation = generator.check_cardinality("je_1", "acc_1", "unknown");
        assert!(!validation.valid);
    }

    #[test]
    fn test_circular_detection() {
        let config = RelationshipConfig::default()
            .allow_circular(false)
            .max_circular_depth(3);

        let mut generator = RelationshipGenerator::new(config, 42);

        // Direct circular
        assert!(generator.check_circular("a", "a"));

        // No circular (different nodes)
        assert!(!generator.check_circular("a", "b"));
    }

    #[test]
    fn test_indirect_circular_detection() {
        let config = RelationshipConfig::default()
            .allow_circular(false)
            .max_circular_depth(3);

        let mut generator = RelationshipGenerator::new(config, 42);
        generator.track_relationship("a", "b", "link");

        // b -> a would close the cycle a -> b -> a.
        assert!(generator.check_circular("b", "a"));

        // a -> c cannot close a cycle: c has no outgoing relationships.
        assert!(!generator.check_circular("a", "c"));
    }

    #[test]
    fn test_property_generation() {
        let config = RelationshipConfig::with_types(vec![RelationshipTypeConfig::new(
            "test", "source", "target",
        )
        .with_property(PropertyGenerationRule::range("amount", 100.0, 1000.0))
        .with_property(PropertyGenerationRule::constant_string("status", "active"))]);

        let mut generator = RelationshipGenerator::new(config, 42);
        let nodes = vec![NodeRef::new("s1", "source"), NodeRef::new("t1", "target")];

        let relationships = generator.generate_relationships(&nodes);

        assert!(!relationships.is_empty());
        let rel = &relationships[0];
        assert!(rel.properties.contains_key("amount"));
        assert!(rel.properties.contains_key("status"));
        assert_eq!(
            rel.properties.get("status"),
            Some(&Value::String("active".into()))
        );
    }

    #[test]
    fn test_orphan_generation() {
        let config = RelationshipConfig::with_types(vec![RelationshipTypeConfig::new(
            "test", "source", "target",
        )
        .with_cardinality(CardinalityRule::one_to_one())])
        .allow_orphans(true)
        .orphan_probability(1.0); // Always create orphans

        let mut generator = RelationshipGenerator::new(config, 42);

        let source = NodeRef::new("s1", "source");
        let available: HashMap<String, Vec<NodeRef>> =
            [("target".to_string(), vec![NodeRef::new("t1", "target")])]
                .into_iter()
                .collect();

        let relationships = generator.generate_for_node(&source, &available);
        assert!(relationships.is_empty());
    }

    #[test]
    fn test_config_builder() {
        let config = RelationshipConfigBuilder::new()
            .add_type(RelationshipTypeConfig::new("test", "a", "b"))
            .allow_orphans(false)
            .orphan_probability(0.1)
            .allow_circular(true)
            .max_circular_depth(5)
            .build();

        assert_eq!(config.relationship_types.len(), 1);
        assert!(!config.allow_orphans);
        assert_eq!(config.orphan_probability, 0.1);
        assert!(config.allow_circular);
        assert_eq!(config.max_circular_depth, 5);
    }

    #[test]
    fn test_generator_count_and_reset() {
        let config = RelationshipConfig::with_types(vec![RelationshipTypeConfig::new(
            "test", "source", "target",
        )
        .with_cardinality(CardinalityRule::one_to_one())]);

        let mut generator = RelationshipGenerator::new(config, 42);
        assert_eq!(generator.count(), 0);

        let nodes = vec![NodeRef::new("s1", "source"), NodeRef::new("t1", "target")];
        generator.generate_relationships(&nodes);

        assert!(generator.count() > 0);

        generator.reset(42);
        assert_eq!(generator.count(), 0);
    }
}