1use argentor_core::{ArgentorResult, ToolCall, ToolResult};
7use argentor_memory::KnowledgeGraph;
8use argentor_security::Capability;
9use argentor_skills::skill::{Skill, SkillDescriptor};
10use async_trait::async_trait;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13
/// Skill that exposes a shared [`KnowledgeGraph`] to tool calls.
pub struct KnowledgeGraphSkill {
    /// Static metadata (name, description, parameter schema) handed out by `descriptor()`.
    descriptor: SkillDescriptor,
    /// Shared graph handle; each operation takes the read or write side of the lock.
    graph: Arc<RwLock<KnowledgeGraph>>,
}
19
20impl KnowledgeGraphSkill {
21 pub fn new(graph: Arc<RwLock<KnowledgeGraph>>) -> Self {
23 Self {
24 descriptor: SkillDescriptor {
25 name: "knowledge_graph".to_string(),
26 description:
27 "Query and manipulate the knowledge graph of entities and relationships. \
28 Supports operations: add_entity, add_relationship, query_entity, \
29 find_related, context, summarize."
30 .to_string(),
31 parameters_schema: serde_json::json!({
32 "type": "object",
33 "properties": {
34 "operation": {
35 "type": "string",
36 "enum": ["add_entity", "add_relationship", "query_entity",
37 "find_related", "context", "summarize"],
38 "description": "The operation to perform"
39 },
40 "name": {
41 "type": "string",
42 "description": "Entity name (for add_entity, query_entity)"
43 },
44 "entity_type": {
45 "type": "string",
46 "description": "Entity type: Person, Organization, Concept, Tool, File, Location, Event, Fact"
47 },
48 "entity_id": {
49 "type": "string",
50 "description": "Entity ID (for context, find_related)"
51 },
52 "from_entity": {
53 "type": "string",
54 "description": "Source entity ID (for add_relationship)"
55 },
56 "to_entity": {
57 "type": "string",
58 "description": "Target entity ID (for add_relationship)"
59 },
60 "relation_type": {
61 "type": "string",
62 "description": "Relationship type: IsA, HasProperty, RelatedTo, DependsOn, CreatedBy, Contains, WorksWith, Mentions, UsedTool, ProducedOutput"
63 },
64 "properties": {
65 "type": "object",
66 "description": "Key-value properties for entity or relationship",
67 "additionalProperties": true
68 },
69 "depth": {
70 "type": "integer",
71 "description": "Traversal depth for context (default: 1)",
72 "default": 1
73 },
74 "source": {
75 "type": "string",
76 "description": "Origin of the data: user, agent, tool_result, extracted",
77 "default": "agent"
78 }
79 },
80 "required": ["operation"]
81 }),
82 required_capabilities: vec![Capability::DatabaseQuery],
83 requires_approval: false,
84 },
85 graph,
86 }
87 }
88}
89
90#[async_trait]
91impl Skill for KnowledgeGraphSkill {
92 fn descriptor(&self) -> &SkillDescriptor {
93 &self.descriptor
94 }
95
96 async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
97 let op = call.arguments["operation"].as_str().unwrap_or_default();
98
99 match op {
100 "add_entity" => self.op_add_entity(&call).await,
101 "add_relationship" => self.op_add_relationship(&call).await,
102 "query_entity" => self.op_query_entity(&call).await,
103 "find_related" => self.op_find_related(&call).await,
104 "context" => self.op_context(&call).await,
105 "summarize" => self.op_summarize(&call).await,
106 other => Ok(ToolResult::error(
107 &call.id,
108 format!("Unknown operation: '{other}'. Use one of: add_entity, add_relationship, query_entity, find_related, context, summarize"),
109 )),
110 }
111 }
112}
113
114impl KnowledgeGraphSkill {
115 async fn op_add_entity(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
116 let name = call.arguments["name"]
117 .as_str()
118 .unwrap_or_default()
119 .to_string();
120 if name.is_empty() {
121 return Ok(ToolResult::error(&call.id, "Entity 'name' is required"));
122 }
123
124 let entity_type = parse_entity_type(
125 call.arguments
126 .get("entity_type")
127 .and_then(|v| v.as_str())
128 .unwrap_or("Concept"),
129 );
130
131 let properties: std::collections::HashMap<String, serde_json::Value> = call
132 .arguments
133 .get("properties")
134 .and_then(|p| serde_json::from_value(p.clone()).ok())
135 .unwrap_or_default();
136
137 let source = call.arguments["source"]
138 .as_str()
139 .unwrap_or("agent")
140 .to_string();
141
142 let now = chrono::Utc::now();
143 let entity = argentor_memory::Entity {
144 id: String::new(),
145 name: name.clone(),
146 entity_type,
147 properties,
148 created_at: now,
149 updated_at: now,
150 confidence: 1.0,
151 source,
152 };
153
154 let mut graph = self.graph.write().await;
155 let id = graph.add_entity(entity);
156
157 let response = serde_json::json!({
158 "added": true,
159 "entity_id": id,
160 "name": name,
161 });
162 Ok(ToolResult::success(&call.id, response.to_string()))
163 }
164
165 async fn op_add_relationship(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
166 let from = call.arguments["from_entity"]
167 .as_str()
168 .unwrap_or_default()
169 .to_string();
170 let to = call.arguments["to_entity"]
171 .as_str()
172 .unwrap_or_default()
173 .to_string();
174
175 if from.is_empty() || to.is_empty() {
176 return Ok(ToolResult::error(
177 &call.id,
178 "'from_entity' and 'to_entity' are required",
179 ));
180 }
181
182 let relation_type = parse_relation_type(
183 call.arguments
184 .get("relation_type")
185 .and_then(|v| v.as_str())
186 .unwrap_or("RelatedTo"),
187 );
188
189 let properties: std::collections::HashMap<String, serde_json::Value> = call
190 .arguments
191 .get("properties")
192 .and_then(|p| serde_json::from_value(p.clone()).ok())
193 .unwrap_or_default();
194
195 let source = call.arguments["source"]
196 .as_str()
197 .unwrap_or("agent")
198 .to_string();
199
200 let rel = argentor_memory::Relationship {
201 id: String::new(),
202 from_entity: from.clone(),
203 to_entity: to.clone(),
204 relation_type,
205 properties,
206 weight: 1.0,
207 created_at: chrono::Utc::now(),
208 source,
209 };
210
211 let mut graph = self.graph.write().await;
212 let id = graph.add_relationship(rel);
213
214 let response = serde_json::json!({
215 "added": true,
216 "relationship_id": id,
217 "from": from,
218 "to": to,
219 });
220 Ok(ToolResult::success(&call.id, response.to_string()))
221 }
222
223 async fn op_query_entity(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
224 let name = call.arguments["name"]
225 .as_str()
226 .unwrap_or_default()
227 .to_string();
228 if name.is_empty() {
229 return Ok(ToolResult::error(
230 &call.id,
231 "Entity 'name' is required for query",
232 ));
233 }
234
235 let graph = self.graph.read().await;
236 let entities = graph.find_entities(&name);
237
238 let results: Vec<serde_json::Value> = entities
239 .iter()
240 .map(|e| {
241 serde_json::json!({
242 "id": e.id,
243 "name": e.name,
244 "entity_type": format!("{}", e.entity_type),
245 "properties": e.properties,
246 "confidence": e.confidence,
247 "source": e.source,
248 })
249 })
250 .collect();
251
252 let response = serde_json::json!({
253 "query": name,
254 "results": results,
255 "total": results.len(),
256 });
257 Ok(ToolResult::success(&call.id, response.to_string()))
258 }
259
260 async fn op_find_related(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
261 let entity_id = call.arguments["entity_id"]
262 .as_str()
263 .unwrap_or_default()
264 .to_string();
265 if entity_id.is_empty() {
266 return Ok(ToolResult::error(
267 &call.id,
268 "'entity_id' is required for find_related",
269 ));
270 }
271
272 let depth = call.arguments["depth"].as_u64().unwrap_or(1) as usize;
273
274 let graph = self.graph.read().await;
275 let neighbors = graph.neighbors(&entity_id, depth);
276
277 let results: Vec<serde_json::Value> = neighbors
278 .iter()
279 .map(|e| {
280 serde_json::json!({
281 "id": e.id,
282 "name": e.name,
283 "entity_type": format!("{}", e.entity_type),
284 })
285 })
286 .collect();
287
288 let response = serde_json::json!({
289 "entity_id": entity_id,
290 "depth": depth,
291 "related": results,
292 "total": results.len(),
293 });
294 Ok(ToolResult::success(&call.id, response.to_string()))
295 }
296
297 async fn op_context(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
298 let entity_id = call.arguments["entity_id"]
299 .as_str()
300 .unwrap_or_default()
301 .to_string();
302 if entity_id.is_empty() {
303 return Ok(ToolResult::error(
304 &call.id,
305 "'entity_id' is required for context",
306 ));
307 }
308 let depth = call.arguments["depth"].as_u64().unwrap_or(1) as usize;
309
310 let graph = self.graph.read().await;
311 let ctx = graph.to_context_string(&entity_id, depth);
312
313 Ok(ToolResult::success(&call.id, ctx))
314 }
315
316 async fn op_summarize(&self, call: &ToolCall) -> ArgentorResult<ToolResult> {
317 let graph = self.graph.read().await;
318 let summary = graph.summarize();
319
320 let response = serde_json::json!({
321 "entity_count": summary.entity_count,
322 "relationship_count": summary.relationship_count,
323 "entity_types": summary.entity_types,
324 "relationship_types": summary.relationship_types,
325 "most_connected": summary.most_connected,
326 });
327 Ok(ToolResult::success(&call.id, response.to_string()))
328 }
329}
330
331fn parse_entity_type(s: &str) -> argentor_memory::EntityType {
336 match s {
337 "Person" => argentor_memory::EntityType::Person,
338 "Organization" => argentor_memory::EntityType::Organization,
339 "Concept" => argentor_memory::EntityType::Concept,
340 "Tool" => argentor_memory::EntityType::Tool,
341 "File" => argentor_memory::EntityType::File,
342 "Location" => argentor_memory::EntityType::Location,
343 "Event" => argentor_memory::EntityType::Event,
344 "Fact" => argentor_memory::EntityType::Fact,
345 other => argentor_memory::EntityType::Custom(other.to_string()),
346 }
347}
348
349fn parse_relation_type(s: &str) -> argentor_memory::RelationType {
350 match s {
351 "IsA" => argentor_memory::RelationType::IsA,
352 "HasProperty" => argentor_memory::RelationType::HasProperty,
353 "RelatedTo" => argentor_memory::RelationType::RelatedTo,
354 "DependsOn" => argentor_memory::RelationType::DependsOn,
355 "CreatedBy" => argentor_memory::RelationType::CreatedBy,
356 "Contains" => argentor_memory::RelationType::Contains,
357 "WorksWith" => argentor_memory::RelationType::WorksWith,
358 "Mentions" => argentor_memory::RelationType::Mentions,
359 "UsedTool" => argentor_memory::RelationType::UsedTool,
360 "ProducedOutput" => argentor_memory::RelationType::ProducedOutput,
361 other => argentor_memory::RelationType::Custom(other.to_string()),
362 }
363}
364
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a skill backed by a fresh graph.
    fn make_skill() -> KnowledgeGraphSkill {
        KnowledgeGraphSkill::new(Arc::new(RwLock::new(KnowledgeGraph::new())))
    }

    /// Sends `arguments` to the skill as a `knowledge_graph` tool call and
    /// unwraps the outer `ArgentorResult`.
    async fn run(
        skill: &KnowledgeGraphSkill,
        id: &str,
        arguments: serde_json::Value,
    ) -> ToolResult {
        let call = ToolCall {
            id: id.to_string(),
            name: "knowledge_graph".to_string(),
            arguments,
        };
        skill.execute(call).await.unwrap()
    }

    /// Parses a tool result's content as JSON.
    fn body(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.content).unwrap()
    }

    /// Adds an entity through the skill and returns the id it reports.
    async fn add_entity(
        skill: &KnowledgeGraphSkill,
        id: &str,
        name: &str,
        entity_type: &str,
    ) -> String {
        let arguments = serde_json::json!({
            "operation": "add_entity",
            "name": name,
            "entity_type": entity_type
        });
        let result = run(skill, id, arguments).await;
        body(&result)["entity_id"].as_str().unwrap().to_string()
    }

    #[test]
    fn test_descriptor() {
        assert_eq!(make_skill().descriptor().name, "knowledge_graph");
    }

    #[tokio::test]
    async fn test_add_entity_operation() {
        let skill = make_skill();
        let arguments = serde_json::json!({
            "operation": "add_entity",
            "name": "Alice",
            "entity_type": "Person"
        });
        let result = run(&skill, "t1", arguments).await;
        assert!(!result.is_error);
        assert!(result.content.contains("\"added\":true"));
        assert!(result.content.contains("Alice"));
    }

    #[tokio::test]
    async fn test_add_entity_missing_name() {
        let skill = make_skill();
        let arguments = serde_json::json!({
            "operation": "add_entity",
            "entity_type": "Person"
        });
        let result = run(&skill, "t2", arguments).await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_add_relationship_operation() {
        let skill = make_skill();

        // Two endpoints, then the edge between them.
        let id_a = add_entity(&skill, "a", "A", "Concept").await;
        let id_b = add_entity(&skill, "b", "B", "Concept").await;

        let arguments = serde_json::json!({
            "operation": "add_relationship",
            "from_entity": id_a,
            "to_entity": id_b,
            "relation_type": "DependsOn"
        });
        let result = run(&skill, "r", arguments).await;
        assert!(!result.is_error);
        assert!(result.content.contains("\"added\":true"));
    }

    #[tokio::test]
    async fn test_query_entity_operation() {
        let skill = make_skill();

        let seed = serde_json::json!({
            "operation": "add_entity",
            "name": "Rust",
            "entity_type": "Concept"
        });
        run(&skill, "a", seed).await;

        // The query uses a different letter case than the stored name.
        let query = serde_json::json!({"operation": "query_entity", "name": "rust"});
        let result = run(&skill, "q", query).await;
        assert!(!result.is_error);
        assert_eq!(body(&result)["total"].as_u64().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_summarize_operation() {
        let skill = make_skill();
        let result = run(&skill, "s", serde_json::json!({"operation": "summarize"})).await;
        assert!(!result.is_error);
        assert_eq!(body(&result)["entity_count"].as_u64().unwrap(), 0);
    }

    #[tokio::test]
    async fn test_unknown_operation() {
        let skill = make_skill();
        let result = run(&skill, "u", serde_json::json!({"operation": "foobar"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("Unknown operation"));
    }

    #[tokio::test]
    async fn test_context_operation() {
        let skill = make_skill();

        let id = add_entity(&skill, "a", "Alice", "Person").await;

        let arguments = serde_json::json!({"operation": "context", "entity_id": id});
        let result = run(&skill, "c", arguments).await;
        assert!(!result.is_error);
        assert!(result.content.contains("Alice"));
    }

    #[tokio::test]
    async fn test_find_related_operation() {
        let skill = make_skill();

        let id_a = add_entity(&skill, "a", "A", "Concept").await;
        let id_b = add_entity(&skill, "b", "B", "Concept").await;

        let link = serde_json::json!({
            "operation": "add_relationship",
            "from_entity": id_a,
            "to_entity": id_b,
            "relation_type": "RelatedTo"
        });
        run(&skill, "r", link).await;

        let lookup = serde_json::json!({"operation": "find_related", "entity_id": id_a});
        let result = run(&skill, "f", lookup).await;
        assert!(!result.is_error);
        assert_eq!(body(&result)["total"].as_u64().unwrap(), 1);
    }
}