1use serde::{Deserialize, Serialize};
2
3use punch_types::{FighterId, PunchError, PunchResult};
4use tracing::debug;
5
6use crate::MemorySubstrate;
7
/// A named, typed node in a fighter's knowledge graph.
///
/// Values of this type are produced by `MemorySubstrate::query_entities`,
/// one per matching row of the `knowledge_entities` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeEntity {
    /// Value of the row's `id` column.
    pub id: i64,
    /// Entity name; part of the `(fighter_id, name, entity_type)` conflict key
    /// used when entities are upserted.
    pub name: String,
    /// Free-form type label for the entity (the tests use values such as
    /// `"language"` and `"person"`).
    pub entity_type: String,
    /// Arbitrary JSON attached to the entity. It is stored as serialized text
    /// and parsed back into a JSON value when the row is read.
    pub properties: serde_json::Value,
    /// Value of the row's `created_at` column, read as text.
    // NOTE(review): the timestamp format is set by the table schema, which is
    // not visible in this file — confirm before parsing it.
    pub created_at: String,
}
17
/// A labelled edge between two named entities in a fighter's knowledge graph.
///
/// Values of this type are produced by `MemorySubstrate::query_relations`,
/// one per matching row of the `knowledge_relations` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeRelation {
    /// Value of the row's `id` column.
    pub id: i64,
    /// Name of the entity the relation starts from.
    pub from_entity: String,
    /// Label of the relation (the tests use `"knows"`).
    pub relation: String,
    /// Name of the entity the relation points to.
    pub to_entity: String,
    /// Arbitrary JSON attached to the relation. It is stored as serialized
    /// text and parsed back into a JSON value when the row is read.
    pub properties: serde_json::Value,
    /// Value of the row's `created_at` column, read as text.
    // NOTE(review): the timestamp format is set by the table schema, which is
    // not visible in this file — confirm before parsing it.
    pub created_at: String,
}
28
29impl MemorySubstrate {
30 pub async fn add_entity(
32 &self,
33 fighter_id: &FighterId,
34 name: &str,
35 entity_type: &str,
36 properties: &serde_json::Value,
37 ) -> PunchResult<()> {
38 let fighter_str = fighter_id.to_string();
39 let props_str = properties.to_string();
40
41 let conn = self.conn.lock().await;
42 conn.execute(
43 "INSERT INTO knowledge_entities (fighter_id, name, entity_type, properties)
44 VALUES (?1, ?2, ?3, ?4)
45 ON CONFLICT(fighter_id, name, entity_type) DO UPDATE SET
46 properties = excluded.properties",
47 rusqlite::params![fighter_str, name, entity_type, props_str],
48 )
49 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to add entity: {e}")))?;
50
51 debug!(fighter_id = %fighter_id, name = name, "knowledge entity added");
52 Ok(())
53 }
54
55 pub async fn add_relation(
57 &self,
58 fighter_id: &FighterId,
59 from: &str,
60 relation: &str,
61 to: &str,
62 properties: &serde_json::Value,
63 ) -> PunchResult<()> {
64 let fighter_str = fighter_id.to_string();
65 let props_str = properties.to_string();
66
67 let conn = self.conn.lock().await;
68 conn.execute(
69 "INSERT INTO knowledge_relations (fighter_id, from_entity, relation, to_entity, properties)
70 VALUES (?1, ?2, ?3, ?4, ?5)
71 ON CONFLICT(fighter_id, from_entity, relation, to_entity) DO UPDATE SET
72 properties = excluded.properties",
73 rusqlite::params![fighter_str, from, relation, to, props_str],
74 )
75 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to add relation: {e}")))?;
76
77 debug!(fighter_id = %fighter_id, from = from, relation = relation, to = to, "knowledge relation added");
78 Ok(())
79 }
80
81 pub async fn query_entities(
83 &self,
84 fighter_id: &FighterId,
85 query: &str,
86 ) -> PunchResult<Vec<KnowledgeEntity>> {
87 let fighter_str = fighter_id.to_string();
88 let pattern = format!("%{query}%");
89
90 let conn = self.conn.lock().await;
91 let mut stmt = conn
92 .prepare(
93 "SELECT id, name, entity_type, properties, created_at
94 FROM knowledge_entities
95 WHERE fighter_id = ?1 AND (name LIKE ?2 OR entity_type LIKE ?2)
96 ORDER BY name",
97 )
98 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query entities: {e}")))?;
99
100 let rows = stmt
101 .query_map(rusqlite::params![fighter_str, pattern], |row| {
102 let id: i64 = row.get(0)?;
103 let name: String = row.get(1)?;
104 let entity_type: String = row.get(2)?;
105 let props: String = row.get(3)?;
106 let created_at: String = row.get(4)?;
107 Ok((id, name, entity_type, props, created_at))
108 })
109 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query entities: {e}")))?;
110
111 let mut entities = Vec::new();
112 for row in rows {
113 let (id, name, entity_type, props, created_at) = row.map_err(|e| {
114 PunchError::KnowledgeGraph(format!("failed to read entity row: {e}"))
115 })?;
116
117 let properties: serde_json::Value = serde_json::from_str(&props).map_err(|e| {
118 PunchError::KnowledgeGraph(format!("corrupt entity properties: {e}"))
119 })?;
120
121 entities.push(KnowledgeEntity {
122 id,
123 name,
124 entity_type,
125 properties,
126 created_at,
127 });
128 }
129
130 Ok(entities)
131 }
132
133 pub async fn query_relations(
135 &self,
136 fighter_id: &FighterId,
137 entity: &str,
138 ) -> PunchResult<Vec<KnowledgeRelation>> {
139 let fighter_str = fighter_id.to_string();
140
141 let conn = self.conn.lock().await;
142 let mut stmt = conn
143 .prepare(
144 "SELECT id, from_entity, relation, to_entity, properties, created_at
145 FROM knowledge_relations
146 WHERE fighter_id = ?1 AND (from_entity = ?2 OR to_entity = ?2)
147 ORDER BY relation",
148 )
149 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query relations: {e}")))?;
150
151 let rows = stmt
152 .query_map(rusqlite::params![fighter_str, entity], |row| {
153 let id: i64 = row.get(0)?;
154 let from_entity: String = row.get(1)?;
155 let relation: String = row.get(2)?;
156 let to_entity: String = row.get(3)?;
157 let props: String = row.get(4)?;
158 let created_at: String = row.get(5)?;
159 Ok((id, from_entity, relation, to_entity, props, created_at))
160 })
161 .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query relations: {e}")))?;
162
163 let mut relations = Vec::new();
164 for row in rows {
165 let (id, from_entity, relation, to_entity, props, created_at) = row.map_err(|e| {
166 PunchError::KnowledgeGraph(format!("failed to read relation row: {e}"))
167 })?;
168
169 let properties: serde_json::Value = serde_json::from_str(&props).map_err(|e| {
170 PunchError::KnowledgeGraph(format!("corrupt relation properties: {e}"))
171 })?;
172
173 relations.push(KnowledgeRelation {
174 id,
175 from_entity,
176 relation,
177 to_entity,
178 properties,
179 created_at,
180 });
181 }
182
183 Ok(relations)
184 }
185}
186
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass,
    };
    use serde_json::json;

    use crate::MemorySubstrate;

    /// Builds a throwaway manifest for the fighter that owns the test graph.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        FighterManifest {
            name: "KG Fighter".into(),
            description: "knowledge graph test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Creates an in-memory substrate with a single saved, idle fighter.
    // NOTE(review): the fighter is saved first presumably because the
    // knowledge tables reference it — confirm against the schema.
    async fn substrate_with_fighter() -> (MemorySubstrate, FighterId) {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fid = FighterId::new();
        substrate
            .save_fighter(&fid, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        (substrate, fid)
    }

    #[tokio::test]
    async fn test_add_and_query_entities() {
        let (substrate, fid) = substrate_with_fighter().await;

        let props = json!({"year": 2010});
        substrate
            .add_entity(&fid, "Rust", "language", &props)
            .await
            .unwrap();

        let found = substrate.query_entities(&fid, "Rust").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].entity_type, "language");
    }

    #[tokio::test]
    async fn test_add_and_query_relations() {
        let (substrate, fid) = substrate_with_fighter().await;

        // Two endpoints, then the edge between them.
        for person in ["Alice", "Bob"] {
            substrate
                .add_entity(&fid, person, "person", &json!({}))
                .await
                .unwrap();
        }
        let edge_props = json!({"since": 2020});
        substrate
            .add_relation(&fid, "Alice", "knows", "Bob", &edge_props)
            .await
            .unwrap();

        let found = substrate.query_relations(&fid, "Alice").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].relation, "knows");
        assert_eq!(found[0].to_entity, "Bob");
    }
}