1use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet, VecDeque};
8use std::sync::{Arc, RwLock};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PersonaNode {
13 pub persona_id: String,
15 pub entity_type: String,
17 pub relationships: HashMap<String, Vec<String>>,
21 #[serde(default)]
23 pub metadata: HashMap<String, serde_json::Value>,
24}
25
26impl PersonaNode {
27 pub fn new(persona_id: String, entity_type: String) -> Self {
29 Self {
30 persona_id,
31 entity_type,
32 relationships: HashMap::new(),
33 metadata: HashMap::new(),
34 }
35 }
36
37 pub fn add_relationship(&mut self, relationship_type: String, related_persona_id: String) {
39 self.relationships
40 .entry(relationship_type)
41 .or_default()
42 .push(related_persona_id);
43 }
44
45 pub fn get_related(&self, relationship_type: &str) -> Vec<String> {
47 self.relationships.get(relationship_type).cloned().unwrap_or_default()
48 }
49
50 pub fn get_relationship_types(&self) -> Vec<String> {
52 self.relationships.keys().cloned().collect()
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Edge {
59 pub from: String,
61 pub to: String,
63 pub relationship_type: String,
65 #[serde(default = "default_edge_weight")]
67 pub weight: f64,
68}
69
/// Serde fallback for `Edge::weight`: edges are unit-weight unless stated otherwise.
fn default_edge_weight() -> f64 {
    const UNIT_WEIGHT: f64 = 1.0;
    UNIT_WEIGHT
}
73
74#[derive(Debug, Clone)]
79pub struct PersonaGraph {
80 nodes: Arc<RwLock<HashMap<String, PersonaNode>>>,
82 edges: Arc<RwLock<HashMap<String, Vec<Edge>>>>,
84 reverse_edges: Arc<RwLock<HashMap<String, Vec<Edge>>>>,
86}
87
88impl PersonaGraph {
89 pub fn new() -> Self {
91 Self {
92 nodes: Arc::new(RwLock::new(HashMap::new())),
93 edges: Arc::new(RwLock::new(HashMap::new())),
94 reverse_edges: Arc::new(RwLock::new(HashMap::new())),
95 }
96 }
97
98 pub fn add_node(&self, node: PersonaNode) {
100 let mut nodes = self.nodes.write().unwrap();
101 nodes.insert(node.persona_id.clone(), node);
102 }
103
104 pub fn get_node(&self, persona_id: &str) -> Option<PersonaNode> {
106 let nodes = self.nodes.read().unwrap();
107 nodes.get(persona_id).cloned()
108 }
109
110 pub fn add_edge(&self, from: String, to: String, relationship_type: String) {
112 let to_clone = to.clone();
113 let edge = Edge {
114 from: from.clone(),
115 to: to_clone.clone(),
116 relationship_type: relationship_type.clone(),
117 weight: 1.0,
118 };
119
120 let mut edges = self.edges.write().unwrap();
122 edges.entry(from.clone()).or_default().push(edge.clone());
123
124 let mut reverse_edges = self.reverse_edges.write().unwrap();
126 reverse_edges.entry(to_clone.clone()).or_default().push(edge);
127
128 if let Some(node) = self.get_node(&from) {
130 let mut updated_node = node;
131 updated_node.add_relationship(relationship_type, to_clone);
132 self.add_node(updated_node);
133 }
134 }
135
136 pub fn get_edges_from(&self, persona_id: &str) -> Vec<Edge> {
138 let edges = self.edges.read().unwrap();
139 edges.get(persona_id).cloned().unwrap_or_default()
140 }
141
142 pub fn get_edges_to(&self, persona_id: &str) -> Vec<Edge> {
144 let reverse_edges = self.reverse_edges.read().unwrap();
145 reverse_edges.get(persona_id).cloned().unwrap_or_default()
146 }
147
148 pub fn find_related_bfs(
161 &self,
162 start_persona_id: &str,
163 relationship_types: Option<&[String]>,
164 max_depth: Option<usize>,
165 ) -> Vec<String> {
166 let mut visited = HashSet::new();
167 let mut queue = VecDeque::new();
168 let mut result = Vec::new();
169
170 queue.push_back((start_persona_id.to_string(), 0));
171 visited.insert(start_persona_id.to_string());
172
173 while let Some((current_id, depth)) = queue.pop_front() {
174 if let Some(max) = max_depth {
175 if depth >= max {
176 continue;
177 }
178 }
179
180 let edges = self.get_edges_from(¤t_id);
181 for edge in edges {
182 if let Some(types) = relationship_types {
184 if !types.contains(&edge.relationship_type) {
185 continue;
186 }
187 }
188
189 if !visited.contains(&edge.to) {
190 visited.insert(edge.to.clone());
191 result.push(edge.to.clone());
192 queue.push_back((edge.to.clone(), depth + 1));
193 }
194 }
195 }
196
197 result
198 }
199
200 pub fn find_related_dfs(
213 &self,
214 start_persona_id: &str,
215 relationship_types: Option<&[String]>,
216 max_depth: Option<usize>,
217 ) -> Vec<String> {
218 let mut visited = HashSet::new();
219 let mut result = Vec::new();
220
221 self.dfs_recursive(
222 start_persona_id,
223 relationship_types,
224 max_depth,
225 0,
226 &mut visited,
227 &mut result,
228 );
229
230 result
231 }
232
233 fn dfs_recursive(
235 &self,
236 current_id: &str,
237 relationship_types: Option<&[String]>,
238 max_depth: Option<usize>,
239 current_depth: usize,
240 visited: &mut HashSet<String>,
241 result: &mut Vec<String>,
242 ) {
243 if visited.contains(current_id) {
244 return;
245 }
246
247 if let Some(max) = max_depth {
248 if current_depth >= max {
249 return;
250 }
251 }
252
253 visited.insert(current_id.to_string());
254 if current_depth > 0 {
255 result.push(current_id.to_string());
257 }
258
259 let edges = self.get_edges_from(current_id);
260 for edge in edges {
261 if let Some(types) = relationship_types {
263 if !types.contains(&edge.relationship_type) {
264 continue;
265 }
266 }
267
268 self.dfs_recursive(
269 &edge.to,
270 relationship_types,
271 max_depth,
272 current_depth + 1,
273 visited,
274 result,
275 );
276 }
277 }
278
279 pub fn get_subgraph(&self, start_persona_id: &str) -> (Vec<PersonaNode>, Vec<Edge>) {
283 let related_ids = self.find_related_bfs(start_persona_id, None, None);
284 let mut all_ids = vec![start_persona_id.to_string()];
285 all_ids.extend(related_ids);
286
287 let nodes = self.nodes.read().unwrap();
288 let edges = self.edges.read().unwrap();
289
290 let subgraph_nodes: Vec<PersonaNode> =
291 all_ids.iter().filter_map(|id| nodes.get(id).cloned()).collect();
292
293 let subgraph_edges: Vec<Edge> = all_ids
294 .iter()
295 .flat_map(|id| edges.get(id).cloned().unwrap_or_default())
296 .filter(|edge| all_ids.contains(&edge.to))
297 .collect();
298
299 (subgraph_nodes, subgraph_edges)
300 }
301
302 pub fn get_all_nodes(&self) -> Vec<PersonaNode> {
304 let nodes = self.nodes.read().unwrap();
305 nodes.values().cloned().collect()
306 }
307
308 pub fn remove_node(&self, persona_id: &str) {
310 let mut nodes = self.nodes.write().unwrap();
311 nodes.remove(persona_id);
312
313 let mut edges = self.edges.write().unwrap();
315 edges.remove(persona_id);
316
317 let mut reverse_edges = self.reverse_edges.write().unwrap();
319 reverse_edges.remove(persona_id);
320
321 for edges_list in edges.values_mut() {
323 edges_list.retain(|e| e.to != persona_id);
324 }
325 for edges_list in reverse_edges.values_mut() {
326 edges_list.retain(|e| e.from != persona_id);
327 }
328 }
329
330 pub fn clear(&self) {
332 let mut nodes = self.nodes.write().unwrap();
333 nodes.clear();
334
335 let mut edges = self.edges.write().unwrap();
336 edges.clear();
337
338 let mut reverse_edges = self.reverse_edges.write().unwrap();
339 reverse_edges.clear();
340 }
341
342 pub fn get_stats(&self) -> GraphStats {
344 let nodes = self.nodes.read().unwrap();
345 let edges = self.edges.read().unwrap();
346
347 let mut relationship_type_counts = HashMap::new();
348 for edges_list in edges.values() {
349 for edge in edges_list {
350 *relationship_type_counts.entry(edge.relationship_type.clone()).or_insert(0) += 1;
351 }
352 }
353
354 GraphStats {
355 node_count: nodes.len(),
356 edge_count: edges.values().map(|e| e.len()).sum(),
357 relationship_types: relationship_type_counts,
358 }
359 }
360
361 pub fn link_entity_types(
376 &self,
377 from_persona_id: &str,
378 from_entity_type: &str,
379 to_persona_id: &str,
380 to_entity_type: &str,
381 ) {
382 let relationship_type: String = match (from_entity_type, to_entity_type) {
384 ("user", "order") | ("user", "orders") => "has_orders".to_string(),
385 ("user", "account") | ("user", "accounts") => "has_accounts".to_string(),
386 ("user", "webhook") | ("user", "webhooks") => "has_webhooks".to_string(),
387 ("user", "tcp_message") | ("user", "tcp_messages") => "has_tcp_messages".to_string(),
388 ("order", "payment") | ("order", "payments") => "has_payments".to_string(),
389 ("account", "order") | ("account", "orders") => "has_orders".to_string(),
390 ("account", "payment") | ("account", "payments") => "has_payments".to_string(),
391 _ => {
392 format!("has_{}", to_entity_type.to_lowercase().trim_end_matches('s'))
394 }
395 };
396
397 if self.get_node(from_persona_id).is_none() {
399 let node = PersonaNode::new(from_persona_id.to_string(), from_entity_type.to_string());
400 self.add_node(node);
401 }
402
403 if self.get_node(to_persona_id).is_none() {
404 let node = PersonaNode::new(to_persona_id.to_string(), to_entity_type.to_string());
405 self.add_node(node);
406 }
407
408 self.add_edge(
410 from_persona_id.to_string(),
411 to_persona_id.to_string(),
412 relationship_type.to_string(),
413 );
414 }
415
416 pub fn find_related_by_entity_type(
429 &self,
430 start_persona_id: &str,
431 target_entity_type: &str,
432 relationship_type: Option<&str>,
433 ) -> Vec<String> {
434 let related_ids = if let Some(rel_type) = relationship_type {
435 let rel_types = vec![rel_type.to_string()];
436 self.find_related_bfs(start_persona_id, Some(&rel_types), Some(2))
437 } else {
438 self.find_related_bfs(start_persona_id, None, Some(2))
439 };
440
441 related_ids
443 .into_iter()
444 .filter_map(|persona_id| {
445 if let Some(node) = self.get_node(&persona_id) {
446 if node.entity_type.to_lowercase() == target_entity_type.to_lowercase() {
447 Some(persona_id)
448 } else {
449 None
450 }
451 } else {
452 None
453 }
454 })
455 .collect()
456 }
457
458 pub fn get_or_create_node_with_links(
469 &self,
470 persona_id: &str,
471 entity_type: &str,
472 related_entity_id: Option<&str>,
473 related_entity_type: Option<&str>,
474 ) -> PersonaNode {
475 let node = if let Some(existing) = self.get_node(persona_id) {
477 existing
478 } else {
479 let new_node = PersonaNode::new(persona_id.to_string(), entity_type.to_string());
480 self.add_node(new_node.clone());
481 new_node
482 };
483
484 if let (Some(related_id), Some(related_type)) = (related_entity_id, related_entity_type) {
486 self.link_entity_types(persona_id, entity_type, related_id, related_type);
487 }
488
489 node
490 }
491}
492
493impl Default for PersonaGraph {
494 fn default() -> Self {
495 Self::new()
496 }
497}
498
499#[derive(Debug, Clone, Serialize, Deserialize)]
501pub struct GraphStats {
502 pub node_count: usize,
504 pub edge_count: usize,
506 pub relationship_types: HashMap<String, usize>,
508}
509
510#[derive(Debug, Clone, Serialize, Deserialize)]
512pub struct GraphVisualization {
513 pub nodes: Vec<VisualizationNode>,
515 pub edges: Vec<VisualizationEdge>,
517}
518
519#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct VisualizationNode {
522 pub id: String,
524 pub entity_type: String,
526 pub label: String,
528 #[serde(skip_serializing_if = "Option::is_none")]
530 pub position: Option<(f64, f64)>,
531}
532
533#[derive(Debug, Clone, Serialize, Deserialize)]
535pub struct VisualizationEdge {
536 pub from: String,
538 pub to: String,
540 pub relationship_type: String,
542 pub label: String,
544}
545
546impl PersonaGraph {
547 pub fn to_visualization(&self) -> GraphVisualization {
549 let nodes = self.get_all_nodes();
550 let edges = self.edges.read().unwrap();
551
552 let vis_nodes: Vec<VisualizationNode> = nodes
553 .iter()
554 .map(|node| VisualizationNode {
555 id: node.persona_id.clone(),
556 entity_type: node.entity_type.clone(),
557 label: format!("{} ({})", node.persona_id, node.entity_type),
558 position: None,
559 })
560 .collect();
561
562 let vis_edges: Vec<VisualizationEdge> = edges
563 .values()
564 .flatten()
565 .map(|edge| VisualizationEdge {
566 from: edge.from.clone(),
567 to: edge.to.clone(),
568 relationship_type: edge.relationship_type.clone(),
569 label: edge.relationship_type.clone(),
570 })
571 .collect();
572
573 GraphVisualization {
574 nodes: vis_nodes,
575 edges: vis_edges,
576 }
577 }
578}