Skip to main content

punch_memory/
knowledge.rs

1use serde::{Deserialize, Serialize};
2
3use punch_types::{FighterId, PunchError, PunchResult};
4use tracing::debug;
5
6use crate::MemorySubstrate;
7
8/// An entity in a fighter's knowledge graph.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct KnowledgeEntity {
11    pub id: i64,
12    pub name: String,
13    pub entity_type: String,
14    pub properties: serde_json::Value,
15    pub created_at: String,
16}
17
18/// A directed relation between two entities in the knowledge graph.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct KnowledgeRelation {
21    pub id: i64,
22    pub from_entity: String,
23    pub relation: String,
24    pub to_entity: String,
25    pub properties: serde_json::Value,
26    pub created_at: String,
27}
28
29impl MemorySubstrate {
30    /// Add (or upsert) an entity to a fighter's knowledge graph.
31    pub async fn add_entity(
32        &self,
33        fighter_id: &FighterId,
34        name: &str,
35        entity_type: &str,
36        properties: &serde_json::Value,
37    ) -> PunchResult<()> {
38        let fighter_str = fighter_id.to_string();
39        let props_str = properties.to_string();
40
41        let conn = self.conn.lock().await;
42        conn.execute(
43            "INSERT INTO knowledge_entities (fighter_id, name, entity_type, properties)
44             VALUES (?1, ?2, ?3, ?4)
45             ON CONFLICT(fighter_id, name, entity_type) DO UPDATE SET
46                properties = excluded.properties",
47            rusqlite::params![fighter_str, name, entity_type, props_str],
48        )
49        .map_err(|e| PunchError::KnowledgeGraph(format!("failed to add entity: {e}")))?;
50
51        debug!(fighter_id = %fighter_id, name = name, "knowledge entity added");
52        Ok(())
53    }
54
55    /// Add (or upsert) a relation between two entities.
56    pub async fn add_relation(
57        &self,
58        fighter_id: &FighterId,
59        from: &str,
60        relation: &str,
61        to: &str,
62        properties: &serde_json::Value,
63    ) -> PunchResult<()> {
64        let fighter_str = fighter_id.to_string();
65        let props_str = properties.to_string();
66
67        let conn = self.conn.lock().await;
68        conn.execute(
69            "INSERT INTO knowledge_relations (fighter_id, from_entity, relation, to_entity, properties)
70             VALUES (?1, ?2, ?3, ?4, ?5)
71             ON CONFLICT(fighter_id, from_entity, relation, to_entity) DO UPDATE SET
72                properties = excluded.properties",
73            rusqlite::params![fighter_str, from, relation, to, props_str],
74        )
75        .map_err(|e| PunchError::KnowledgeGraph(format!("failed to add relation: {e}")))?;
76
77        debug!(fighter_id = %fighter_id, from = from, relation = relation, to = to, "knowledge relation added");
78        Ok(())
79    }
80
81    /// Query entities matching a name or type substring.
82    pub async fn query_entities(
83        &self,
84        fighter_id: &FighterId,
85        query: &str,
86    ) -> PunchResult<Vec<KnowledgeEntity>> {
87        let fighter_str = fighter_id.to_string();
88        let pattern = format!("%{query}%");
89
90        let conn = self.conn.lock().await;
91        let mut stmt = conn
92            .prepare(
93                "SELECT id, name, entity_type, properties, created_at
94                 FROM knowledge_entities
95                 WHERE fighter_id = ?1 AND (name LIKE ?2 OR entity_type LIKE ?2)
96                 ORDER BY name",
97            )
98            .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query entities: {e}")))?;
99
100        let rows = stmt
101            .query_map(rusqlite::params![fighter_str, pattern], |row| {
102                let id: i64 = row.get(0)?;
103                let name: String = row.get(1)?;
104                let entity_type: String = row.get(2)?;
105                let props: String = row.get(3)?;
106                let created_at: String = row.get(4)?;
107                Ok((id, name, entity_type, props, created_at))
108            })
109            .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query entities: {e}")))?;
110
111        let mut entities = Vec::new();
112        for row in rows {
113            let (id, name, entity_type, props, created_at) = row.map_err(|e| {
114                PunchError::KnowledgeGraph(format!("failed to read entity row: {e}"))
115            })?;
116
117            let properties: serde_json::Value = serde_json::from_str(&props).map_err(|e| {
118                PunchError::KnowledgeGraph(format!("corrupt entity properties: {e}"))
119            })?;
120
121            entities.push(KnowledgeEntity {
122                id,
123                name,
124                entity_type,
125                properties,
126                created_at,
127            });
128        }
129
130        Ok(entities)
131    }
132
133    /// Query all relations involving a given entity (as source or target).
134    pub async fn query_relations(
135        &self,
136        fighter_id: &FighterId,
137        entity: &str,
138    ) -> PunchResult<Vec<KnowledgeRelation>> {
139        let fighter_str = fighter_id.to_string();
140
141        let conn = self.conn.lock().await;
142        let mut stmt = conn
143            .prepare(
144                "SELECT id, from_entity, relation, to_entity, properties, created_at
145                 FROM knowledge_relations
146                 WHERE fighter_id = ?1 AND (from_entity = ?2 OR to_entity = ?2)
147                 ORDER BY relation",
148            )
149            .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query relations: {e}")))?;
150
151        let rows = stmt
152            .query_map(rusqlite::params![fighter_str, entity], |row| {
153                let id: i64 = row.get(0)?;
154                let from_entity: String = row.get(1)?;
155                let relation: String = row.get(2)?;
156                let to_entity: String = row.get(3)?;
157                let props: String = row.get(4)?;
158                let created_at: String = row.get(5)?;
159                Ok((id, from_entity, relation, to_entity, props, created_at))
160            })
161            .map_err(|e| PunchError::KnowledgeGraph(format!("failed to query relations: {e}")))?;
162
163        let mut relations = Vec::new();
164        for row in rows {
165            let (id, from_entity, relation, to_entity, props, created_at) = row.map_err(|e| {
166                PunchError::KnowledgeGraph(format!("failed to read relation row: {e}"))
167            })?;
168
169            let properties: serde_json::Value = serde_json::from_str(&props).map_err(|e| {
170                PunchError::KnowledgeGraph(format!("corrupt relation properties: {e}"))
171            })?;
172
173            relations.push(KnowledgeRelation {
174                id,
175                from_entity,
176                relation,
177                to_entity,
178                properties,
179                created_at,
180            });
181        }
182
183        Ok(relations)
184    }
185}
186
#[cfg(test)]
mod tests {
    use punch_types::{
        FighterId, FighterManifest, FighterStatus, ModelConfig, Provider, WeightClass,
    };
    use serde_json::json;

    use crate::MemorySubstrate;

    /// Minimal manifest for the throwaway fighter these tests attach
    /// knowledge to.
    fn test_manifest() -> FighterManifest {
        let model = ModelConfig {
            provider: Provider::Anthropic,
            model: "claude-sonnet-4-20250514".into(),
            api_key_env: None,
            base_url: None,
            max_tokens: Some(4096),
            temperature: Some(0.7),
        };

        FighterManifest {
            name: "KG Fighter".into(),
            description: "knowledge graph test".into(),
            model,
            system_prompt: "test".into(),
            capabilities: Vec::new(),
            weight_class: WeightClass::Featherweight,
            tenant_id: None,
        }
    }

    /// Persist a brand-new idle fighter in `substrate` and hand back its id.
    async fn register_fighter(substrate: &MemorySubstrate) -> FighterId {
        let fighter = FighterId::new();
        substrate
            .save_fighter(&fighter, &test_manifest(), FighterStatus::Idle)
            .await
            .unwrap();
        fighter
    }

    #[tokio::test]
    async fn test_add_and_query_entities() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter = register_fighter(&substrate).await;

        substrate
            .add_entity(&fighter, "Rust", "language", &json!({"year": 2010}))
            .await
            .unwrap();

        let found = substrate.query_entities(&fighter, "Rust").await.unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].entity_type, "language");
    }

    #[tokio::test]
    async fn test_add_and_query_relations() {
        let substrate = MemorySubstrate::in_memory().unwrap();
        let fighter = register_fighter(&substrate).await;

        for person in ["Alice", "Bob"] {
            substrate
                .add_entity(&fighter, person, "person", &json!({}))
                .await
                .unwrap();
        }

        substrate
            .add_relation(&fighter, "Alice", "knows", "Bob", &json!({"since": 2020}))
            .await
            .unwrap();

        let found = substrate.query_relations(&fighter, "Alice").await.unwrap();
        assert_eq!(found.len(), 1);
        let edge = &found[0];
        assert_eq!(edge.relation, "knows");
        assert_eq!(edge.to_entity, "Bob");
    }
}