1use crate::persona::PersonaProfile;
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::sync::{Arc, RwLock};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PersonaNode {
14 pub persona_id: String,
16 pub entity_type: String,
18 pub relationships: HashMap<String, Vec<String>>,
22 #[serde(default)]
24 pub metadata: HashMap<String, serde_json::Value>,
25}
26
27impl PersonaNode {
28 pub fn new(persona_id: String, entity_type: String) -> Self {
30 Self {
31 persona_id,
32 entity_type,
33 relationships: HashMap::new(),
34 metadata: HashMap::new(),
35 }
36 }
37
38 pub fn add_relationship(&mut self, relationship_type: String, related_persona_id: String) {
40 self.relationships
41 .entry(relationship_type)
42 .or_insert_with(Vec::new)
43 .push(related_persona_id);
44 }
45
46 pub fn get_related(&self, relationship_type: &str) -> Vec<String> {
48 self.relationships.get(relationship_type).cloned().unwrap_or_default()
49 }
50
51 pub fn get_relationship_types(&self) -> Vec<String> {
53 self.relationships.keys().cloned().collect()
54 }
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Edge {
60 pub from: String,
62 pub to: String,
64 pub relationship_type: String,
66 #[serde(default = "default_edge_weight")]
68 pub weight: f64,
69}
70
/// Serde default for an edge's `weight` field when it is absent from the input.
fn default_edge_weight() -> f64 {
    1_f64
}
74
75#[derive(Debug, Clone)]
80pub struct PersonaGraph {
81 nodes: Arc<RwLock<HashMap<String, PersonaNode>>>,
83 edges: Arc<RwLock<HashMap<String, Vec<Edge>>>>,
85 reverse_edges: Arc<RwLock<HashMap<String, Vec<Edge>>>>,
87}
88
89impl PersonaGraph {
90 pub fn new() -> Self {
92 Self {
93 nodes: Arc::new(RwLock::new(HashMap::new())),
94 edges: Arc::new(RwLock::new(HashMap::new())),
95 reverse_edges: Arc::new(RwLock::new(HashMap::new())),
96 }
97 }
98
99 pub fn add_node(&self, node: PersonaNode) {
101 let mut nodes = self.nodes.write().unwrap();
102 nodes.insert(node.persona_id.clone(), node);
103 }
104
105 pub fn get_node(&self, persona_id: &str) -> Option<PersonaNode> {
107 let nodes = self.nodes.read().unwrap();
108 nodes.get(persona_id).cloned()
109 }
110
111 pub fn add_edge(&self, from: String, to: String, relationship_type: String) {
113 let to_clone = to.clone();
114 let edge = Edge {
115 from: from.clone(),
116 to: to_clone.clone(),
117 relationship_type: relationship_type.clone(),
118 weight: 1.0,
119 };
120
121 let mut edges = self.edges.write().unwrap();
123 edges.entry(from.clone()).or_insert_with(Vec::new).push(edge.clone());
124
125 let mut reverse_edges = self.reverse_edges.write().unwrap();
127 reverse_edges.entry(to_clone.clone()).or_insert_with(Vec::new).push(edge);
128
129 if let Some(node) = self.get_node(&from) {
131 let mut updated_node = node;
132 updated_node.add_relationship(relationship_type, to_clone);
133 self.add_node(updated_node);
134 }
135 }
136
137 pub fn get_edges_from(&self, persona_id: &str) -> Vec<Edge> {
139 let edges = self.edges.read().unwrap();
140 edges.get(persona_id).cloned().unwrap_or_default()
141 }
142
143 pub fn get_edges_to(&self, persona_id: &str) -> Vec<Edge> {
145 let reverse_edges = self.reverse_edges.read().unwrap();
146 reverse_edges.get(persona_id).cloned().unwrap_or_default()
147 }
148
149 pub fn find_related_bfs(
162 &self,
163 start_persona_id: &str,
164 relationship_types: Option<&[String]>,
165 max_depth: Option<usize>,
166 ) -> Vec<String> {
167 let mut visited = HashSet::new();
168 let mut queue = VecDeque::new();
169 let mut result = Vec::new();
170
171 queue.push_back((start_persona_id.to_string(), 0));
172 visited.insert(start_persona_id.to_string());
173
174 while let Some((current_id, depth)) = queue.pop_front() {
175 if let Some(max) = max_depth {
176 if depth >= max {
177 continue;
178 }
179 }
180
181 let edges = self.get_edges_from(¤t_id);
182 for edge in edges {
183 if let Some(types) = relationship_types {
185 if !types.contains(&edge.relationship_type) {
186 continue;
187 }
188 }
189
190 if !visited.contains(&edge.to) {
191 visited.insert(edge.to.clone());
192 result.push(edge.to.clone());
193 queue.push_back((edge.to.clone(), depth + 1));
194 }
195 }
196 }
197
198 result
199 }
200
201 pub fn find_related_dfs(
214 &self,
215 start_persona_id: &str,
216 relationship_types: Option<&[String]>,
217 max_depth: Option<usize>,
218 ) -> Vec<String> {
219 let mut visited = HashSet::new();
220 let mut result = Vec::new();
221
222 self.dfs_recursive(
223 start_persona_id,
224 relationship_types,
225 max_depth,
226 0,
227 &mut visited,
228 &mut result,
229 );
230
231 result
232 }
233
234 fn dfs_recursive(
236 &self,
237 current_id: &str,
238 relationship_types: Option<&[String]>,
239 max_depth: Option<usize>,
240 current_depth: usize,
241 visited: &mut HashSet<String>,
242 result: &mut Vec<String>,
243 ) {
244 if visited.contains(current_id) {
245 return;
246 }
247
248 if let Some(max) = max_depth {
249 if current_depth >= max {
250 return;
251 }
252 }
253
254 visited.insert(current_id.to_string());
255 if current_depth > 0 {
256 result.push(current_id.to_string());
258 }
259
260 let edges = self.get_edges_from(current_id);
261 for edge in edges {
262 if let Some(types) = relationship_types {
264 if !types.contains(&edge.relationship_type) {
265 continue;
266 }
267 }
268
269 self.dfs_recursive(
270 &edge.to,
271 relationship_types,
272 max_depth,
273 current_depth + 1,
274 visited,
275 result,
276 );
277 }
278 }
279
280 pub fn get_subgraph(&self, start_persona_id: &str) -> (Vec<PersonaNode>, Vec<Edge>) {
284 let related_ids = self.find_related_bfs(start_persona_id, None, None);
285 let mut all_ids = vec![start_persona_id.to_string()];
286 all_ids.extend(related_ids);
287
288 let nodes = self.nodes.read().unwrap();
289 let edges = self.edges.read().unwrap();
290
291 let subgraph_nodes: Vec<PersonaNode> =
292 all_ids.iter().filter_map(|id| nodes.get(id).cloned()).collect();
293
294 let subgraph_edges: Vec<Edge> = all_ids
295 .iter()
296 .flat_map(|id| edges.get(id).cloned().unwrap_or_default())
297 .filter(|edge| all_ids.contains(&edge.to))
298 .collect();
299
300 (subgraph_nodes, subgraph_edges)
301 }
302
303 pub fn get_all_nodes(&self) -> Vec<PersonaNode> {
305 let nodes = self.nodes.read().unwrap();
306 nodes.values().cloned().collect()
307 }
308
309 pub fn remove_node(&self, persona_id: &str) {
311 let mut nodes = self.nodes.write().unwrap();
312 nodes.remove(persona_id);
313
314 let mut edges = self.edges.write().unwrap();
316 edges.remove(persona_id);
317
318 let mut reverse_edges = self.reverse_edges.write().unwrap();
320 reverse_edges.remove(persona_id);
321
322 for edges_list in edges.values_mut() {
324 edges_list.retain(|e| e.to != persona_id);
325 }
326 for edges_list in reverse_edges.values_mut() {
327 edges_list.retain(|e| e.from != persona_id);
328 }
329 }
330
331 pub fn clear(&self) {
333 let mut nodes = self.nodes.write().unwrap();
334 nodes.clear();
335
336 let mut edges = self.edges.write().unwrap();
337 edges.clear();
338
339 let mut reverse_edges = self.reverse_edges.write().unwrap();
340 reverse_edges.clear();
341 }
342
343 pub fn get_stats(&self) -> GraphStats {
345 let nodes = self.nodes.read().unwrap();
346 let edges = self.edges.read().unwrap();
347
348 let mut relationship_type_counts = HashMap::new();
349 for edges_list in edges.values() {
350 for edge in edges_list {
351 *relationship_type_counts.entry(edge.relationship_type.clone()).or_insert(0) += 1;
352 }
353 }
354
355 GraphStats {
356 node_count: nodes.len(),
357 edge_count: edges.values().map(|e| e.len()).sum(),
358 relationship_types: relationship_type_counts,
359 }
360 }
361
362 pub fn link_entity_types(
377 &self,
378 from_persona_id: &str,
379 from_entity_type: &str,
380 to_persona_id: &str,
381 to_entity_type: &str,
382 ) {
383 let relationship_type: String = match (from_entity_type, to_entity_type) {
385 ("user", "order") | ("user", "orders") => "has_orders".to_string(),
386 ("user", "account") | ("user", "accounts") => "has_accounts".to_string(),
387 ("user", "webhook") | ("user", "webhooks") => "has_webhooks".to_string(),
388 ("user", "tcp_message") | ("user", "tcp_messages") => "has_tcp_messages".to_string(),
389 ("order", "payment") | ("order", "payments") => "has_payments".to_string(),
390 ("account", "order") | ("account", "orders") => "has_orders".to_string(),
391 ("account", "payment") | ("account", "payments") => "has_payments".to_string(),
392 _ => {
393 format!("has_{}", to_entity_type.to_lowercase().trim_end_matches('s'))
395 }
396 };
397
398 if self.get_node(from_persona_id).is_none() {
400 let node = PersonaNode::new(from_persona_id.to_string(), from_entity_type.to_string());
401 self.add_node(node);
402 }
403
404 if self.get_node(to_persona_id).is_none() {
405 let node = PersonaNode::new(to_persona_id.to_string(), to_entity_type.to_string());
406 self.add_node(node);
407 }
408
409 self.add_edge(
411 from_persona_id.to_string(),
412 to_persona_id.to_string(),
413 relationship_type.to_string(),
414 );
415 }
416
417 pub fn find_related_by_entity_type(
430 &self,
431 start_persona_id: &str,
432 target_entity_type: &str,
433 relationship_type: Option<&str>,
434 ) -> Vec<String> {
435 let related_ids = if let Some(rel_type) = relationship_type {
436 let rel_types = vec![rel_type.to_string()];
437 self.find_related_bfs(start_persona_id, Some(&rel_types), Some(2))
438 } else {
439 self.find_related_bfs(start_persona_id, None, Some(2))
440 };
441
442 related_ids
444 .into_iter()
445 .filter_map(|persona_id| {
446 if let Some(node) = self.get_node(&persona_id) {
447 if node.entity_type.to_lowercase() == target_entity_type.to_lowercase() {
448 Some(persona_id)
449 } else {
450 None
451 }
452 } else {
453 None
454 }
455 })
456 .collect()
457 }
458
459 pub fn get_or_create_node_with_links(
470 &self,
471 persona_id: &str,
472 entity_type: &str,
473 related_entity_id: Option<&str>,
474 related_entity_type: Option<&str>,
475 ) -> PersonaNode {
476 let node = if let Some(existing) = self.get_node(persona_id) {
478 existing
479 } else {
480 let new_node = PersonaNode::new(persona_id.to_string(), entity_type.to_string());
481 self.add_node(new_node.clone());
482 new_node
483 };
484
485 if let (Some(related_id), Some(related_type)) = (related_entity_id, related_entity_type) {
487 self.link_entity_types(persona_id, entity_type, related_id, related_type);
488 }
489
490 node
491 }
492}
493
494impl Default for PersonaGraph {
495 fn default() -> Self {
496 Self::new()
497 }
498}
499
500#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct GraphStats {
503 pub node_count: usize,
505 pub edge_count: usize,
507 pub relationship_types: HashMap<String, usize>,
509}
510
511#[derive(Debug, Clone, Serialize, Deserialize)]
513pub struct GraphVisualization {
514 pub nodes: Vec<VisualizationNode>,
516 pub edges: Vec<VisualizationEdge>,
518}
519
520#[derive(Debug, Clone, Serialize, Deserialize)]
522pub struct VisualizationNode {
523 pub id: String,
525 pub entity_type: String,
527 pub label: String,
529 #[serde(skip_serializing_if = "Option::is_none")]
531 pub position: Option<(f64, f64)>,
532}
533
534#[derive(Debug, Clone, Serialize, Deserialize)]
536pub struct VisualizationEdge {
537 pub from: String,
539 pub to: String,
541 pub relationship_type: String,
543 pub label: String,
545}
546
547impl PersonaGraph {
548 pub fn to_visualization(&self) -> GraphVisualization {
550 let nodes = self.get_all_nodes();
551 let edges = self.edges.read().unwrap();
552
553 let vis_nodes: Vec<VisualizationNode> = nodes
554 .iter()
555 .map(|node| VisualizationNode {
556 id: node.persona_id.clone(),
557 entity_type: node.entity_type.clone(),
558 label: format!("{} ({})", node.persona_id, node.entity_type),
559 position: None,
560 })
561 .collect();
562
563 let vis_edges: Vec<VisualizationEdge> = edges
564 .values()
565 .flatten()
566 .map(|edge| VisualizationEdge {
567 from: edge.from.clone(),
568 to: edge.to.clone(),
569 relationship_type: edge.relationship_type.clone(),
570 label: edge.relationship_type.clone(),
571 })
572 .collect();
573
574 GraphVisualization {
575 nodes: vis_nodes,
576 edges: vis_edges,
577 }
578 }
579}