Skip to main content

argentor_builtins/
knowledge_graph_skill.rs

1//! Knowledge graph skill — exposes the in-memory knowledge graph as a callable skill.
2//!
3//! Supported operations: `add_entity`, `add_relationship`, `query_entity`,
4//! `find_related`, `context`, `summarize`.
5
6use argentor_core::{ArgentorResult, ToolCall, ToolResult};
7use argentor_memory::KnowledgeGraph;
8use argentor_security::Capability;
9use argentor_skills::skill::{Skill, SkillDescriptor};
10use async_trait::async_trait;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
14/// Skill that wraps a [`KnowledgeGraph`] and exposes entity-relationship operations.
15pub struct KnowledgeGraphSkill {
16    descriptor: SkillDescriptor,
17    graph: Arc<RwLock<KnowledgeGraph>>,
18}
19
20impl KnowledgeGraphSkill {
21    /// Create a new knowledge graph skill with the given shared graph.
22    pub fn new(graph: Arc<RwLock<KnowledgeGraph>>) -> Self {
23        Self {
24            descriptor: SkillDescriptor {
25                name: "knowledge_graph".to_string(),
26                description:
27                    "Query and manipulate the knowledge graph of entities and relationships. \
28                     Supports operations: add_entity, add_relationship, query_entity, \
29                     find_related, context, summarize."
30                        .to_string(),
31                parameters_schema: serde_json::json!({
32                    "type": "object",
33                    "properties": {
34                        "operation": {
35                            "type": "string",
36                            "enum": ["add_entity", "add_relationship", "query_entity",
37                                     "find_related", "context", "summarize"],
38                            "description": "The operation to perform"
39                        },
40                        "name": {
41                            "type": "string",
42                            "description": "Entity name (for add_entity, query_entity)"
43                        },
44                        "entity_type": {
45                            "type": "string",
46                            "description": "Entity type: Person, Organization, Concept, Tool, File, Location, Event, Fact"
47                        },
48                        "entity_id": {
49                            "type": "string",
50                            "description": "Entity ID (for context, find_related)"
51                        },
52                        "from_entity": {
53                            "type": "string",
54                            "description": "Source entity ID (for add_relationship)"
55                        },
56                        "to_entity": {
57                            "type": "string",
58                            "description": "Target entity ID (for add_relationship)"
59                        },
60                        "relation_type": {
61                            "type": "string",
62                            "description": "Relationship type: IsA, HasProperty, RelatedTo, DependsOn, CreatedBy, Contains, WorksWith, Mentions, UsedTool, ProducedOutput"
63                        },
64                        "properties": {
65                            "type": "object",
66                            "description": "Key-value properties for entity or relationship",
67                            "additionalProperties": true
68                        },
69                        "depth": {
70                            "type": "integer",
71                            "description": "Traversal depth for context (default: 1)",
72                            "default": 1
73                        },
74                        "source": {
75                            "type": "string",
76                            "description": "Origin of the data: user, agent, tool_result, extracted",
77                            "default": "agent"
78                        }
79                    },
80                    "required": ["operation"]
81                }),
82                required_capabilities: vec![Capability::DatabaseQuery],
83                requires_approval: false,
84            },
85            graph,
86        }
87    }
88}
89
90#[async_trait]
91impl Skill for KnowledgeGraphSkill {
92    fn descriptor(&self) -> &SkillDescriptor {
93        &self.descriptor
94    }
95
96    async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
97        let op = call.arguments["operation"].as_str().unwrap_or_default();
98
99        match op {
100            "add_entity" => self.op_add_entity(&call).await,
101            "add_relationship" => self.op_add_relationship(&call).await,
102            "query_entity" => self.op_query_entity(&call).await,
103            "find_related" => self.op_find_related(&call).await,
104            "context" => self.op_context(&call).await,
105            "summarize" => self.op_summarize(&call).await,
106            other => Ok(ToolResult::error(
107                &call.id,
108                format!("Unknown operation: '{other}'. Use one of: add_entity, add_relationship, query_entity, find_related, context, summarize"),
109            )),
110        }
111    }
112}
113
114impl KnowledgeGraphSkill {
115    async fn op_add_entity(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
116        let name = call.arguments["name"]
117            .as_str()
118            .unwrap_or_default()
119            .to_string();
120        if name.is_empty() {
121            return Ok(ToolResult::error(&call.id, "Entity 'name' is required"));
122        }
123
124        let entity_type = parse_entity_type(
125            call.arguments
126                .get("entity_type")
127                .and_then(|v| v.as_str())
128                .unwrap_or("Concept"),
129        );
130
131        let properties: std::collections::HashMap<String, serde_json::Value> = call
132            .arguments
133            .get("properties")
134            .and_then(|p| serde_json::from_value(p.clone()).ok())
135            .unwrap_or_default();
136
137        let source = call.arguments["source"]
138            .as_str()
139            .unwrap_or("agent")
140            .to_string();
141
142        let now = chrono::Utc::now();
143        let entity = argentor_memory::Entity {
144            id: String::new(),
145            name: name.clone(),
146            entity_type,
147            properties,
148            created_at: now,
149            updated_at: now,
150            confidence: 1.0,
151            source,
152        };
153
154        let mut graph = self.graph.write().await;
155        let id = graph.add_entity(entity);
156
157        let response = serde_json::json!({
158            "added": true,
159            "entity_id": id,
160            "name": name,
161        });
162        Ok(ToolResult::success(&call.id, response.to_string()))
163    }
164
165    async fn op_add_relationship(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
166        let from = call.arguments["from_entity"]
167            .as_str()
168            .unwrap_or_default()
169            .to_string();
170        let to = call.arguments["to_entity"]
171            .as_str()
172            .unwrap_or_default()
173            .to_string();
174
175        if from.is_empty() || to.is_empty() {
176            return Ok(ToolResult::error(
177                &call.id,
178                "'from_entity' and 'to_entity' are required",
179            ));
180        }
181
182        let relation_type = parse_relation_type(
183            call.arguments
184                .get("relation_type")
185                .and_then(|v| v.as_str())
186                .unwrap_or("RelatedTo"),
187        );
188
189        let properties: std::collections::HashMap<String, serde_json::Value> = call
190            .arguments
191            .get("properties")
192            .and_then(|p| serde_json::from_value(p.clone()).ok())
193            .unwrap_or_default();
194
195        let source = call.arguments["source"]
196            .as_str()
197            .unwrap_or("agent")
198            .to_string();
199
200        let rel = argentor_memory::Relationship {
201            id: String::new(),
202            from_entity: from.clone(),
203            to_entity: to.clone(),
204            relation_type,
205            properties,
206            weight: 1.0,
207            created_at: chrono::Utc::now(),
208            source,
209        };
210
211        let mut graph = self.graph.write().await;
212        let id = graph.add_relationship(rel);
213
214        let response = serde_json::json!({
215            "added": true,
216            "relationship_id": id,
217            "from": from,
218            "to": to,
219        });
220        Ok(ToolResult::success(&call.id, response.to_string()))
221    }
222
223    async fn op_query_entity(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
224        let name = call.arguments["name"]
225            .as_str()
226            .unwrap_or_default()
227            .to_string();
228        if name.is_empty() {
229            return Ok(ToolResult::error(
230                &call.id,
231                "Entity 'name' is required for query",
232            ));
233        }
234
235        let graph = self.graph.read().await;
236        let entities = graph.find_entities(&name);
237
238        let results: Vec<serde_json::Value> = entities
239            .iter()
240            .map(|e| {
241                serde_json::json!({
242                    "id": e.id,
243                    "name": e.name,
244                    "entity_type": format!("{}", e.entity_type),
245                    "properties": e.properties,
246                    "confidence": e.confidence,
247                    "source": e.source,
248                })
249            })
250            .collect();
251
252        let response = serde_json::json!({
253            "query": name,
254            "results": results,
255            "total": results.len(),
256        });
257        Ok(ToolResult::success(&call.id, response.to_string()))
258    }
259
260    async fn op_find_related(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
261        let entity_id = call.arguments["entity_id"]
262            .as_str()
263            .unwrap_or_default()
264            .to_string();
265        if entity_id.is_empty() {
266            return Ok(ToolResult::error(
267                &call.id,
268                "'entity_id' is required for find_related",
269            ));
270        }
271
272        let depth = call.arguments["depth"].as_u64().unwrap_or(1) as usize;
273
274        let graph = self.graph.read().await;
275        let neighbors = graph.neighbors(&entity_id, depth);
276
277        let results: Vec<serde_json::Value> = neighbors
278            .iter()
279            .map(|e| {
280                serde_json::json!({
281                    "id": e.id,
282                    "name": e.name,
283                    "entity_type": format!("{}", e.entity_type),
284                })
285            })
286            .collect();
287
288        let response = serde_json::json!({
289            "entity_id": entity_id,
290            "depth": depth,
291            "related": results,
292            "total": results.len(),
293        });
294        Ok(ToolResult::success(&call.id, response.to_string()))
295    }
296
297    async fn op_context(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
298        let entity_id = call.arguments["entity_id"]
299            .as_str()
300            .unwrap_or_default()
301            .to_string();
302        if entity_id.is_empty() {
303            return Ok(ToolResult::error(
304                &call.id,
305                "'entity_id' is required for context",
306            ));
307        }
308        let depth = call.arguments["depth"].as_u64().unwrap_or(1) as usize;
309
310        let graph = self.graph.read().await;
311        let ctx = graph.to_context_string(&entity_id, depth);
312
313        Ok(ToolResult::success(&call.id, ctx))
314    }
315
316    async fn op_summarize(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
317        let graph = self.graph.read().await;
318        let summary = graph.summarize();
319
320        let response = serde_json::json!({
321            "entity_count": summary.entity_count,
322            "relationship_count": summary.relationship_count,
323            "entity_types": summary.entity_types,
324            "relationship_types": summary.relationship_types,
325            "most_connected": summary.most_connected,
326        });
327        Ok(ToolResult::success(&call.id, response.to_string()))
328    }
329}
330
331// ---------------------------------------------------------------------------
332// Helpers
333// ---------------------------------------------------------------------------
334
335fn parse_entity_type(s: &str) -> argentor_memory::EntityType {
336    match s {
337        "Person" => argentor_memory::EntityType::Person,
338        "Organization" => argentor_memory::EntityType::Organization,
339        "Concept" => argentor_memory::EntityType::Concept,
340        "Tool" => argentor_memory::EntityType::Tool,
341        "File" => argentor_memory::EntityType::File,
342        "Location" => argentor_memory::EntityType::Location,
343        "Event" => argentor_memory::EntityType::Event,
344        "Fact" => argentor_memory::EntityType::Fact,
345        other => argentor_memory::EntityType::Custom(other.to_string()),
346    }
347}
348
349fn parse_relation_type(s: &str) -> argentor_memory::RelationType {
350    match s {
351        "IsA" => argentor_memory::RelationType::IsA,
352        "HasProperty" => argentor_memory::RelationType::HasProperty,
353        "RelatedTo" => argentor_memory::RelationType::RelatedTo,
354        "DependsOn" => argentor_memory::RelationType::DependsOn,
355        "CreatedBy" => argentor_memory::RelationType::CreatedBy,
356        "Contains" => argentor_memory::RelationType::Contains,
357        "WorksWith" => argentor_memory::RelationType::WorksWith,
358        "Mentions" => argentor_memory::RelationType::Mentions,
359        "UsedTool" => argentor_memory::RelationType::UsedTool,
360        "ProducedOutput" => argentor_memory::RelationType::ProducedOutput,
361        other => argentor_memory::RelationType::Custom(other.to_string()),
362    }
363}
364
365// ===========================================================================
366// Tests
367// ===========================================================================
368
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A skill over a fresh, empty graph.
    fn make_skill() -> KnowledgeGraphSkill {
        KnowledgeGraphSkill::new(Arc::new(RwLock::new(KnowledgeGraph::new())))
    }

    /// Execute one `knowledge_graph` call with the given id and arguments.
    async fn run(skill: &KnowledgeGraphSkill, id: &str, arguments: serde_json::Value) -> ToolResult {
        let call = ToolCall {
            id: id.to_string(),
            name: "knowledge_graph".to_string(),
            arguments,
        };
        skill.execute(call).await.unwrap()
    }

    /// Decode a result's JSON payload.
    fn payload(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.content).unwrap()
    }

    /// Add an entity and return the `entity_id` the skill reported for it.
    async fn add_entity(
        skill: &KnowledgeGraphSkill,
        id: &str,
        name: &str,
        entity_type: &str,
    ) -> String {
        let result = run(
            skill,
            id,
            json!({"operation": "add_entity", "name": name, "entity_type": entity_type}),
        )
        .await;
        payload(&result)["entity_id"].as_str().unwrap().to_string()
    }

    #[test]
    fn test_descriptor() {
        assert_eq!(make_skill().descriptor().name, "knowledge_graph");
    }

    #[tokio::test]
    async fn test_add_entity_operation() {
        let skill = make_skill();
        let result = run(
            &skill,
            "t1",
            json!({"operation": "add_entity", "name": "Alice", "entity_type": "Person"}),
        )
        .await;
        assert!(!result.is_error);
        assert!(result.content.contains("\"added\":true"));
        assert!(result.content.contains("Alice"));
    }

    #[tokio::test]
    async fn test_add_entity_missing_name() {
        let skill = make_skill();
        let result = run(
            &skill,
            "t2",
            json!({"operation": "add_entity", "entity_type": "Person"}),
        )
        .await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_add_relationship_operation() {
        let skill = make_skill();
        let from = add_entity(&skill, "a", "A", "Concept").await;
        let to = add_entity(&skill, "b", "B", "Concept").await;

        let result = run(
            &skill,
            "r",
            json!({
                "operation": "add_relationship",
                "from_entity": from,
                "to_entity": to,
                "relation_type": "DependsOn"
            }),
        )
        .await;
        assert!(!result.is_error);
        assert!(result.content.contains("\"added\":true"));
    }

    #[tokio::test]
    async fn test_query_entity_operation() {
        let skill = make_skill();
        run(
            &skill,
            "a",
            json!({"operation": "add_entity", "name": "Rust", "entity_type": "Concept"}),
        )
        .await;

        // Lower-case query must still find the "Rust" entity.
        let result = run(&skill, "q", json!({"operation": "query_entity", "name": "rust"})).await;
        assert!(!result.is_error);
        assert_eq!(payload(&result)["total"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_summarize_operation() {
        let skill = make_skill();
        let result = run(&skill, "s", json!({"operation": "summarize"})).await;
        assert!(!result.is_error);
        assert_eq!(payload(&result)["entity_count"].as_u64().unwrap(), 0);
    }

    #[tokio::test]
    async fn test_unknown_operation() {
        let skill = make_skill();
        let result = run(&skill, "u", json!({"operation": "foobar"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("Unknown operation"));
    }

    #[tokio::test]
    async fn test_context_operation() {
        let skill = make_skill();
        let alice = add_entity(&skill, "a", "Alice", "Person").await;

        let result = run(&skill, "c", json!({"operation": "context", "entity_id": alice})).await;
        assert!(!result.is_error);
        assert!(result.content.contains("Alice"));
    }

    #[tokio::test]
    async fn test_find_related_operation() {
        let skill = make_skill();
        let first = add_entity(&skill, "a", "A", "Concept").await;
        let second = add_entity(&skill, "b", "B", "Concept").await;

        run(
            &skill,
            "r",
            json!({
                "operation": "add_relationship",
                "from_entity": first,
                "to_entity": second,
                "relation_type": "RelatedTo"
            }),
        )
        .await;

        let result = run(&skill, "f", json!({"operation": "find_related", "entity_id": first})).await;
        assert!(!result.is_error);
        assert_eq!(payload(&result)["total"].as_u64().unwrap(), 1);
    }
}