Skip to main content

datasynth_graph/models/
nodes.rs

1//! Node models for graph representation.
2
3use chrono::NaiveDate;
4use rust_decimal::Decimal;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// Unique identifier for a node.
///
/// A plain `u64` alias: nothing in this module enforces uniqueness, so code that
/// constructs nodes is responsible for assigning distinct ids.
pub type NodeId = u64;
10
/// Type of node in the graph.
///
/// The built-in variants name common accounting/ERP entities; [`NodeType::Custom`]
/// carries a free-form type name (for example the name used by
/// [`GraphNode::from_entity`]). Use [`NodeType::as_str`] for a string form.
// Keep variant order stable: index-based serde formats may encode variants by position.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum NodeType {
    /// GL Account node.
    Account,
    /// Journal Entry document node.
    JournalEntry,
    /// Vendor node.
    Vendor,
    /// Customer node.
    Customer,
    /// User/Employee node.
    User,
    /// Company/Legal Entity node.
    Company,
    /// Cost Center node.
    CostCenter,
    /// Profit Center node.
    ProfitCenter,
    /// Material node.
    Material,
    /// Fixed Asset node.
    FixedAsset,
    /// Custom node type identified by an arbitrary name.
    Custom(String),
}
37
38impl NodeType {
39    /// Returns the type name as a string.
40    pub fn as_str(&self) -> &str {
41        match self {
42            NodeType::Account => "Account",
43            NodeType::JournalEntry => "JournalEntry",
44            NodeType::Vendor => "Vendor",
45            NodeType::Customer => "Customer",
46            NodeType::User => "User",
47            NodeType::Company => "Company",
48            NodeType::CostCenter => "CostCenter",
49            NodeType::ProfitCenter => "ProfitCenter",
50            NodeType::Material => "Material",
51            NodeType::FixedAsset => "FixedAsset",
52            NodeType::Custom(s) => s,
53        }
54    }
55}
56
/// A node in the graph.
///
/// Carries an identity (`id`, `node_type`, `external_id`, `label`), ML inputs
/// (`features`, `categorical_features`), typed `properties`, and supervision
/// targets (`labels`, `is_anomaly`, `anomaly_type`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    /// Unique node ID (supplied by the caller of `GraphNode::new`).
    pub id: NodeId,
    /// Node type.
    pub node_type: NodeType,
    /// External ID (e.g., account code, vendor ID).
    pub external_id: String,
    /// Human-readable node label for display.
    pub label: String,
    /// Numeric features for ML, in the order they were appended.
    pub features: Vec<f64>,
    /// Categorical features keyed by feature name (intended for one-hot encoding downstream).
    pub categorical_features: HashMap<String, String>,
    /// Node properties keyed by property name.
    pub properties: HashMap<String, NodeProperty>,
    /// Labels for supervised learning.
    pub labels: Vec<String>,
    /// Is this node an anomaly?
    pub is_anomaly: bool,
    /// Anomaly type if anomalous (set together with `is_anomaly` by `GraphNode::as_anomaly`).
    pub anomaly_type: Option<String>,
}
81
82impl GraphNode {
83    /// Creates a new graph node.
84    pub fn new(id: NodeId, node_type: NodeType, external_id: String, label: String) -> Self {
85        Self {
86            id,
87            node_type,
88            external_id,
89            label,
90            features: Vec::new(),
91            categorical_features: HashMap::new(),
92            properties: HashMap::new(),
93            labels: Vec::new(),
94            is_anomaly: false,
95            anomaly_type: None,
96        }
97    }
98
99    /// Adds a numeric feature.
100    pub fn with_feature(mut self, value: f64) -> Self {
101        self.features.push(value);
102        self
103    }
104
105    /// Adds multiple numeric features.
106    pub fn with_features(mut self, values: Vec<f64>) -> Self {
107        self.features.extend(values);
108        self
109    }
110
111    /// Adds a categorical feature.
112    pub fn with_categorical(mut self, name: &str, value: &str) -> Self {
113        self.categorical_features
114            .insert(name.to_string(), value.to_string());
115        self
116    }
117
118    /// Adds a property.
119    pub fn with_property(mut self, name: &str, value: NodeProperty) -> Self {
120        self.properties.insert(name.to_string(), value);
121        self
122    }
123
124    /// Marks the node as anomalous.
125    pub fn as_anomaly(mut self, anomaly_type: &str) -> Self {
126        self.is_anomaly = true;
127        self.anomaly_type = Some(anomaly_type.to_string());
128        self
129    }
130
131    /// Adds a label.
132    pub fn with_label(mut self, label: &str) -> Self {
133        self.labels.push(label.to_string());
134        self
135    }
136
137    /// Returns the feature vector dimension.
138    pub fn feature_dim(&self) -> usize {
139        self.features.len()
140    }
141
142    /// Create a graph node from any type implementing `ToNodeProperties`.
143    ///
144    /// This bridges the domain model structs (with `ToNodeProperties`) to the
145    /// graph export pipeline by converting all typed properties into `NodeProperty` values.
146    pub fn from_entity(id: NodeId, entity: &dyn datasynth_core::models::ToNodeProperties) -> Self {
147        let type_name = entity.node_type_name();
148        let mut node = GraphNode::new(
149            id,
150            NodeType::Custom(type_name.to_string()),
151            type_name.to_string(),
152            type_name.to_string(),
153        );
154        for (key, value) in entity.to_node_properties() {
155            node.properties.insert(key, NodeProperty::from(value));
156        }
157        node
158    }
159}
160
/// Property value for a node.
///
/// A typed value stored in `GraphNode::properties`; see
/// `NodeProperty::to_string_value` and `NodeProperty::to_numeric` for conversions.
// Keep variant order stable: index-based serde formats may encode variants by position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NodeProperty {
    /// String value.
    String(String),
    /// Integer value.
    Int(i64),
    /// Float value.
    Float(f64),
    /// Decimal value backed by `rust_decimal::Decimal`.
    Decimal(Decimal),
    /// Boolean value.
    Bool(bool),
    /// Calendar date without a time zone (`chrono::NaiveDate`).
    Date(NaiveDate),
    /// List of strings.
    StringList(Vec<String>),
}
179
180impl NodeProperty {
181    /// Converts to string representation.
182    pub fn to_string_value(&self) -> String {
183        match self {
184            NodeProperty::String(s) => s.clone(),
185            NodeProperty::Int(i) => i.to_string(),
186            NodeProperty::Float(f) => f.to_string(),
187            NodeProperty::Decimal(d) => d.to_string(),
188            NodeProperty::Bool(b) => b.to_string(),
189            NodeProperty::Date(d) => d.to_string(),
190            NodeProperty::StringList(v) => v.join(","),
191        }
192    }
193
194    /// Converts to numeric value (for features).
195    pub fn to_numeric(&self) -> Option<f64> {
196        match self {
197            NodeProperty::Int(i) => Some(*i as f64),
198            NodeProperty::Float(f) => Some(*f),
199            NodeProperty::Decimal(d) => (*d).try_into().ok(),
200            NodeProperty::Bool(b) => Some(if *b { 1.0 } else { 0.0 }),
201            _ => None,
202        }
203    }
204}
205
206impl From<datasynth_core::models::GraphPropertyValue> for NodeProperty {
207    fn from(v: datasynth_core::models::GraphPropertyValue) -> Self {
208        use datasynth_core::models::GraphPropertyValue;
209        match v {
210            GraphPropertyValue::String(s) => NodeProperty::String(s),
211            GraphPropertyValue::Int(i) => NodeProperty::Int(i),
212            GraphPropertyValue::Float(f) => NodeProperty::Float(f),
213            GraphPropertyValue::Decimal(d) => NodeProperty::Decimal(d),
214            GraphPropertyValue::Bool(b) => NodeProperty::Bool(b),
215            GraphPropertyValue::Date(d) => NodeProperty::Date(d),
216            GraphPropertyValue::StringList(v) => NodeProperty::StringList(v),
217        }
218    }
219}
220
/// Account node with accounting-specific features.
///
/// Wraps a [`GraphNode`] of type [`NodeType::Account`]; call
/// `AccountNode::compute_features` to populate the base node's features from
/// the fields below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountNode {
    /// Base node.
    pub node: GraphNode,
    /// Account code (also the base node's external id when built with `new`).
    pub account_code: String,
    /// Account name.
    pub account_name: String,
    /// Account type (Asset, Liability, Equity, Revenue, Expense; other strings are accepted).
    pub account_type: String,
    /// Account category (`None` after `new`; not used by `compute_features`).
    pub account_category: Option<String>,
    /// Is balance sheet account (`false` after `new`).
    pub is_balance_sheet: bool,
    /// Normal balance (Debit/Credit); `"Debit"` after `new`. Only the exact
    /// string `"Debit"` is encoded as a debit by `compute_features`.
    pub normal_balance: String,
    /// Company code.
    pub company_code: String,
    /// Country code (ISO 3166-1 alpha-2) of the owning company.
    pub country: Option<String>,
}
243
244impl AccountNode {
245    /// Creates a new account node.
246    pub fn new(
247        id: NodeId,
248        account_code: String,
249        account_name: String,
250        account_type: String,
251        company_code: String,
252    ) -> Self {
253        let node = GraphNode::new(
254            id,
255            NodeType::Account,
256            account_code.clone(),
257            format!("{account_code} - {account_name}"),
258        );
259
260        Self {
261            node,
262            account_code,
263            account_name,
264            account_type,
265            account_category: None,
266            is_balance_sheet: false,
267            normal_balance: "Debit".to_string(),
268            company_code,
269            country: None,
270        }
271    }
272
273    /// Computes features for the account node.
274    pub fn compute_features(&mut self) {
275        // Account type encoding
276        let type_feature = match self.account_type.as_str() {
277            "Asset" => 0.0,
278            "Liability" => 1.0,
279            "Equity" => 2.0,
280            "Revenue" => 3.0,
281            "Expense" => 4.0,
282            _ => 5.0,
283        };
284        self.node.features.push(type_feature);
285
286        // Balance sheet indicator
287        self.node
288            .features
289            .push(if self.is_balance_sheet { 1.0 } else { 0.0 });
290
291        // Normal balance encoding
292        self.node.features.push(if self.normal_balance == "Debit" {
293            1.0
294        } else {
295            0.0
296        });
297
298        // Account code as normalized numeric feature [0, 1]
299        // Parse up to 4 leading digits and divide by 10000.
300        let code_prefix: String = self
301            .account_code
302            .chars()
303            .take(4)
304            .take_while(|c| c.is_ascii_digit())
305            .collect();
306        if let Ok(code_num) = code_prefix.parse::<f64>() {
307            self.node.features.push(code_num / 10000.0);
308        } else {
309            self.node.features.push(0.0);
310        }
311
312        // Add categorical features
313        self.node
314            .categorical_features
315            .insert("account_type".to_string(), self.account_type.clone());
316        self.node
317            .categorical_features
318            .insert("company_code".to_string(), self.company_code.clone());
319        if let Some(ref country) = self.country {
320            self.node
321                .categorical_features
322                .insert("country".to_string(), country.clone());
323        }
324    }
325}
326
/// User node for approval networks.
///
/// Wraps a [`GraphNode`] of type [`NodeType::User`] together with user
/// attributes. `UserNode::new` leaves the optional attributes unset and marks
/// the user active.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserNode {
    /// Base node.
    pub node: GraphNode,
    /// User ID (also the base node's external id when built with `new`).
    pub user_id: String,
    /// User name (also the base node's label when built with `new`).
    pub user_name: String,
    /// Department; recorded as the `department` categorical feature when set.
    pub department: Option<String>,
    /// Role; recorded as the `role` categorical feature when set.
    pub role: Option<String>,
    /// Manager ID (presumably another user's `user_id` — confirm with callers).
    pub manager_id: Option<String>,
    /// Approval limit; log-scaled into a numeric feature by `compute_features`.
    pub approval_limit: Option<Decimal>,
    /// Is active (`true` after `new`).
    pub is_active: bool,
}
347
348impl UserNode {
349    /// Creates a new user node.
350    pub fn new(id: NodeId, user_id: String, user_name: String) -> Self {
351        let node = GraphNode::new(id, NodeType::User, user_id.clone(), user_name.clone());
352
353        Self {
354            node,
355            user_id,
356            user_name,
357            department: None,
358            role: None,
359            manager_id: None,
360            approval_limit: None,
361            is_active: true,
362        }
363    }
364
365    /// Computes features for the user node.
366    pub fn compute_features(&mut self) {
367        // Active status
368        self.node
369            .features
370            .push(if self.is_active { 1.0 } else { 0.0 });
371
372        // Approval limit (log-scaled)
373        if let Some(limit) = self.approval_limit {
374            let limit_f64: f64 = limit.try_into().unwrap_or(0.0);
375            self.node.features.push((limit_f64 + 1.0).ln());
376        } else {
377            self.node.features.push(0.0);
378        }
379
380        // Add categorical features
381        if let Some(ref dept) = self.department {
382            self.node
383                .categorical_features
384                .insert("department".to_string(), dept.clone());
385        }
386        if let Some(ref role) = self.role {
387            self.node
388                .categorical_features
389                .insert("role".to_string(), role.clone());
390        }
391    }
392}
393
/// Company/Entity node for entity relationship graphs.
///
/// Wraps a [`GraphNode`] of type [`NodeType::Company`]. `CompanyNode::new`
/// defaults to country `"US"`, currency `"USD"`, not a parent, and no parent
/// code or ownership percentage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompanyNode {
    /// Base node.
    pub node: GraphNode,
    /// Company code (also the base node's external id when built with `new`).
    pub company_code: String,
    /// Company name (also the base node's label when built with `new`).
    pub company_name: String,
    /// Country (`"US"` after `new`).
    pub country: String,
    /// Currency (`"USD"` after `new`).
    pub currency: String,
    /// Is parent company.
    pub is_parent: bool,
    /// Parent company code.
    pub parent_code: Option<String>,
    /// Ownership percentage (if subsidiary). `compute_features` divides it by
    /// 100, so it is presumably on a 0–100 scale.
    pub ownership_percent: Option<Decimal>,
}
414
415impl CompanyNode {
416    /// Creates a new company node.
417    pub fn new(id: NodeId, company_code: String, company_name: String) -> Self {
418        let node = GraphNode::new(
419            id,
420            NodeType::Company,
421            company_code.clone(),
422            company_name.clone(),
423        );
424
425        Self {
426            node,
427            company_code,
428            company_name,
429            country: "US".to_string(),
430            currency: "USD".to_string(),
431            is_parent: false,
432            parent_code: None,
433            ownership_percent: None,
434        }
435    }
436
437    /// Computes features for the company node.
438    pub fn compute_features(&mut self) {
439        // Is parent
440        self.node
441            .features
442            .push(if self.is_parent { 1.0 } else { 0.0 });
443
444        // Ownership percentage
445        if let Some(pct) = self.ownership_percent {
446            let pct_f64: f64 = pct.try_into().unwrap_or(0.0);
447            self.node.features.push(pct_f64 / 100.0);
448        } else {
449            self.node.features.push(1.0); // 100% for parent
450        }
451
452        // Add categorical features
453        self.node
454            .categorical_features
455            .insert("country".to_string(), self.country.clone());
456        self.node
457            .categorical_features
458            .insert("currency".to_string(), self.currency.clone());
459    }
460}
461
#[cfg(test)]
// Tests may panic freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_node_creation() {
        let node = GraphNode::new(1, NodeType::Account, "1000".to_string(), "Cash".to_string())
            .with_feature(100.0)
            .with_categorical("type", "Asset");

        assert_eq!(node.id, 1);
        assert_eq!(node.features.len(), 1);
        assert!(node.categorical_features.contains_key("type"));
    }

    #[test]
    fn test_graph_node_builders() {
        let node = GraphNode::new(7, NodeType::Vendor, "V-1".to_string(), "Acme".to_string())
            .with_features(vec![1.0, 2.0])
            .with_feature(3.0)
            .with_label("train")
            .with_property("active", NodeProperty::Bool(true))
            .as_anomaly("duplicate");

        assert_eq!(node.features, vec![1.0, 2.0, 3.0]);
        assert_eq!(node.feature_dim(), 3);
        assert_eq!(node.labels, vec!["train".to_string()]);
        assert!(node.is_anomaly);
        assert_eq!(node.anomaly_type.as_deref(), Some("duplicate"));
        assert!(matches!(
            node.properties.get("active"),
            Some(NodeProperty::Bool(true))
        ));
    }

    #[test]
    fn test_node_type_as_str() {
        assert_eq!(NodeType::CostCenter.as_str(), "CostCenter");
        assert_eq!(NodeType::Custom("Invoice".to_string()).as_str(), "Invoice");
    }

    #[test]
    fn test_account_node() {
        let mut account = AccountNode::new(
            1,
            "1000".to_string(),
            "Cash".to_string(),
            "Asset".to_string(),
            "1000".to_string(),
        );
        account.is_balance_sheet = true;
        account.compute_features();

        // [type ordinal (Asset), balance-sheet flag, debit flag, code / 10000]
        assert_eq!(account.node.features, vec![0.0, 1.0, 1.0, 0.1]);
        assert_eq!(account.node.label, "1000 - Cash");
        assert_eq!(
            account
                .node
                .categorical_features
                .get("account_type")
                .map(String::as_str),
            Some("Asset")
        );
        assert!(!account.node.categorical_features.contains_key("country"));
    }

    #[test]
    fn test_account_code_without_leading_digits() {
        let mut account = AccountNode::new(
            2,
            "AB12".to_string(),
            "Misc".to_string(),
            "Other".to_string(),
            "1000".to_string(),
        );
        account.compute_features();

        // Unknown type -> 5.0; default normal balance is Debit; no leading digits -> 0.0.
        assert_eq!(account.node.features, vec![5.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_user_node_features_without_limit() {
        let mut user = UserNode::new(3, "U1".to_string(), "Alice".to_string());
        user.department = Some("Finance".to_string());
        user.compute_features();

        assert_eq!(user.node.features, vec![1.0, 0.0]);
        assert_eq!(
            user.node
                .categorical_features
                .get("department")
                .map(String::as_str),
            Some("Finance")
        );
        assert!(!user.node.categorical_features.contains_key("role"));
    }

    #[test]
    fn test_company_node_features() {
        let mut parent = CompanyNode::new(4, "1000".to_string(), "Parent Co".to_string());
        parent.is_parent = true;
        parent.compute_features();
        assert_eq!(parent.node.features, vec![1.0, 1.0]);
        assert_eq!(
            parent
                .node
                .categorical_features
                .get("currency")
                .map(String::as_str),
            Some("USD")
        );

        let mut subsidiary = CompanyNode::new(5, "2000".to_string(), "Sub Co".to_string());
        subsidiary.ownership_percent = Some(Decimal::new(60, 0));
        subsidiary.compute_features();
        assert_eq!(subsidiary.node.features, vec![0.0, 0.6]);
    }

    #[test]
    fn test_node_property_conversions() {
        assert_eq!(NodeProperty::Int(7).to_numeric(), Some(7.0));
        assert_eq!(NodeProperty::Bool(false).to_numeric(), Some(0.0));
        assert_eq!(
            NodeProperty::Decimal(Decimal::new(25, 1)).to_numeric(),
            Some(2.5)
        );
        assert_eq!(NodeProperty::String("x".to_string()).to_numeric(), None);
        assert_eq!(NodeProperty::Int(-3).to_string_value(), "-3");
        assert_eq!(
            NodeProperty::StringList(vec!["a".to_string(), "b".to_string()]).to_string_value(),
            "a,b"
        );
    }

    #[test]
    fn test_from_graph_property_value() {
        use datasynth_core::models::GraphPropertyValue;

        let prop: NodeProperty = GraphPropertyValue::Bool(true).into();
        assert!(matches!(prop, NodeProperty::Bool(true)));

        let prop: NodeProperty = GraphPropertyValue::Int(42).into();
        assert!(matches!(prop, NodeProperty::Int(42)));

        let prop: NodeProperty = GraphPropertyValue::String("hello".into()).into();
        assert!(matches!(prop, NodeProperty::String(ref s) if s == "hello"));
    }

    #[test]
    fn test_from_entity() {
        use datasynth_core::models::{GraphPropertyValue, ToNodeProperties};
        use std::collections::HashMap;

        struct TestEntity;
        impl ToNodeProperties for TestEntity {
            fn node_type_name(&self) -> &'static str {
                "test_entity"
            }
            fn node_type_code(&self) -> u16 {
                999
            }
            fn to_node_properties(&self) -> HashMap<String, GraphPropertyValue> {
                let mut p = HashMap::new();
                p.insert("name".into(), GraphPropertyValue::String("Test".into()));
                p.insert("active".into(), GraphPropertyValue::Bool(true));
                p
            }
        }

        let node = GraphNode::from_entity(42, &TestEntity);
        assert_eq!(node.id, 42);
        assert_eq!(node.node_type, NodeType::Custom("test_entity".into()));
        assert!(node.properties.contains_key("name"));
        assert!(node.properties.contains_key("active"));
    }
}