ipfrs_tensorlogic/
storage.rs1use crate::ir::{KnowledgeBase, Predicate, Rule, Term};
6use crate::reasoning::{InferenceEngine, Proof, Substitution};
7use ipfrs_core::{Block, Cid, Result};
8use ipfrs_storage::traits::BlockStore;
9use serde_json;
10use std::sync::Arc;
11
12pub struct TensorLogicStore<S: BlockStore> {
16 store: Arc<S>,
18 knowledge_base: std::sync::RwLock<KnowledgeBase>,
20 engine: InferenceEngine,
22}
23
24impl<S: BlockStore> TensorLogicStore<S> {
25 pub fn new(store: Arc<S>) -> Result<Self> {
27 Ok(Self {
28 store,
29 knowledge_base: std::sync::RwLock::new(KnowledgeBase::new()),
30 engine: InferenceEngine::new(),
31 })
32 }
33
34 pub async fn store_term(&self, term: &Term) -> Result<Cid> {
36 let json = serde_json::to_vec(term)
37 .map_err(|e| ipfrs_core::Error::Serialization(format!("Term serialization: {}", e)))?;
38
39 let block = Block::new(json.into())?;
40 let cid = *block.cid();
41
42 self.store.put(&block).await?;
43
44 Ok(cid)
45 }
46
47 pub async fn get_term(&self, cid: &Cid) -> Result<Option<Term>> {
49 match self.store.get(cid).await? {
50 Some(block) => {
51 let term = serde_json::from_slice(block.data()).map_err(|e| {
52 ipfrs_core::Error::Deserialization(format!("Term deserialization: {}", e))
53 })?;
54 Ok(Some(term))
55 }
56 None => Ok(None),
57 }
58 }
59
60 pub async fn store_predicate(&self, predicate: &Predicate) -> Result<Cid> {
62 let json = serde_json::to_vec(predicate).map_err(|e| {
63 ipfrs_core::Error::Serialization(format!("Predicate serialization: {}", e))
64 })?;
65
66 let block = Block::new(json.into())?;
67 let cid = *block.cid();
68
69 self.store.put(&block).await?;
70
71 Ok(cid)
72 }
73
74 pub async fn get_predicate(&self, cid: &Cid) -> Result<Option<Predicate>> {
76 match self.store.get(cid).await? {
77 Some(block) => {
78 let predicate = serde_json::from_slice(block.data()).map_err(|e| {
79 ipfrs_core::Error::Deserialization(format!("Predicate deserialization: {}", e))
80 })?;
81 Ok(Some(predicate))
82 }
83 None => Ok(None),
84 }
85 }
86
87 pub async fn store_rule(&self, rule: &Rule) -> Result<Cid> {
89 let json = serde_json::to_vec(rule)
90 .map_err(|e| ipfrs_core::Error::Serialization(format!("Rule serialization: {}", e)))?;
91
92 let block = Block::new(json.into())?;
93 let cid = *block.cid();
94
95 self.store.put(&block).await?;
96
97 Ok(cid)
98 }
99
100 pub async fn get_rule(&self, cid: &Cid) -> Result<Option<Rule>> {
102 match self.store.get(cid).await? {
103 Some(block) => {
104 let rule = serde_json::from_slice(block.data()).map_err(|e| {
105 ipfrs_core::Error::Deserialization(format!("Rule deserialization: {}", e))
106 })?;
107 Ok(Some(rule))
108 }
109 None => Ok(None),
110 }
111 }
112
113 pub async fn has(&self, cid: &Cid) -> Result<bool> {
115 self.store.has(cid).await
116 }
117
118 pub async fn delete(&self, cid: &Cid) -> Result<()> {
120 self.store.delete(cid).await
121 }
122
123 pub fn add_fact(&self, fact: Predicate) -> Result<()> {
125 let mut kb = self.knowledge_base.write().unwrap();
126 kb.add_fact(fact);
127 Ok(())
128 }
129
130 pub fn add_rule(&self, rule: Rule) -> Result<()> {
132 let mut kb = self.knowledge_base.write().unwrap();
133 kb.add_rule(rule);
134 Ok(())
135 }
136
137 pub fn infer(&self, goal: &Predicate) -> Result<Vec<Substitution>> {
139 let kb = self.knowledge_base.read().unwrap();
140 self.engine.query(goal, &kb)
141 }
142
143 pub fn prove(&self, goal: &Predicate) -> Result<Option<Proof>> {
145 let kb = self.knowledge_base.read().unwrap();
146 self.engine.prove(goal, &kb)
147 }
148
149 pub async fn store_proof(&self, proof: &Proof) -> Result<Cid> {
151 let json = serde_json::to_vec(proof)
152 .map_err(|e| ipfrs_core::Error::Serialization(format!("Proof serialization: {}", e)))?;
153
154 let block = Block::new(json.into())?;
155 let cid = *block.cid();
156
157 self.store.put(&block).await?;
158
159 Ok(cid)
160 }
161
162 pub async fn get_proof(&self, cid: &Cid) -> Result<Option<Proof>> {
164 match self.store.get(cid).await? {
165 Some(block) => {
166 let proof = serde_json::from_slice(block.data()).map_err(|e| {
167 ipfrs_core::Error::Deserialization(format!("Proof deserialization: {}", e))
168 })?;
169 Ok(Some(proof))
170 }
171 None => Ok(None),
172 }
173 }
174
175 pub fn verify_proof(&self, proof: &Proof) -> Result<bool> {
177 let kb = self.knowledge_base.read().unwrap();
178 self.engine.verify(proof, &kb)
179 }
180
181 pub fn kb_stats(&self) -> crate::ir::KnowledgeBaseStats {
183 let kb = self.knowledge_base.read().unwrap();
184 kb.stats()
185 }
186
187 pub async fn save_kb<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
195 use std::fs::File;
196 use std::io::Write;
197
198 let kb = self.knowledge_base.read().unwrap();
199
200 let encoded =
202 oxicode::serde::encode_to_vec(&*kb, oxicode::config::standard()).map_err(|e| {
203 ipfrs_core::Error::Serialization(format!("Failed to serialize KB: {}", e))
204 })?;
205
206 let mut file = File::create(path.as_ref())
208 .map_err(|e| ipfrs_core::Error::Storage(format!("Failed to create KB file: {}", e)))?;
209
210 file.write_all(&encoded)
211 .map_err(|e| ipfrs_core::Error::Storage(format!("Failed to write KB file: {}", e)))?;
212
213 Ok(())
214 }
215
216 pub async fn load_kb<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
223 use std::fs::File;
224 use std::io::Read;
225
226 let mut file = File::open(path.as_ref())
228 .map_err(|e| ipfrs_core::Error::Storage(format!("Failed to open KB file: {}", e)))?;
229
230 let mut buffer = Vec::new();
231 file.read_to_end(&mut buffer)
232 .map_err(|e| ipfrs_core::Error::Storage(format!("Failed to read KB file: {}", e)))?;
233
234 let kb: KnowledgeBase =
236 oxicode::serde::decode_owned_from_slice(&buffer, oxicode::config::standard())
237 .map(|(v, _)| v)
238 .map_err(|e| {
239 ipfrs_core::Error::Deserialization(format!("Failed to deserialize KB: {}", e))
240 })?;
241
242 *self.knowledge_base.write().unwrap() = kb;
244
245 Ok(())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;
    use ipfrs_storage::{BlockStoreConfig, SledBlockStore};

    /// Opens a sled-backed store in `<temp_dir>/ipfrs-test-tensorlogic-<name>`.
    /// Leftovers from a previous run are removed first. Uses the platform temp
    /// directory rather than a hard-coded Unix `/tmp` path.
    fn open_store(name: &str) -> TensorLogicStore<SledBlockStore> {
        let config = BlockStoreConfig {
            path: std::env::temp_dir().join(format!("ipfrs-test-tensorlogic-{}", name)),
            cache_size: 100 * 1024 * 1024,
        };
        let _ = std::fs::remove_dir_all(&config.path);
        let store = Arc::new(SledBlockStore::new(config).unwrap());
        TensorLogicStore::new(store).unwrap()
    }

    /// The ground predicate `parent("Alice", "Bob")`.
    fn alice_parent_of_bob() -> Predicate {
        Predicate::new(
            "parent".to_string(),
            vec![
                Term::Const(Constant::String("Alice".to_string())),
                Term::Const(Constant::String("Bob".to_string())),
            ],
        )
    }

    #[tokio::test]
    async fn test_term_storage() {
        let tl_store = open_store("term");

        let term = Term::Const(Constant::String("Alice".to_string()));
        let cid = tl_store.store_term(&term).await.unwrap();

        let retrieved = tl_store.get_term(&cid).await.unwrap();
        assert_eq!(retrieved, Some(term));
    }

    #[tokio::test]
    async fn test_predicate_storage() {
        let tl_store = open_store("pred");

        let predicate = alice_parent_of_bob();
        let cid = tl_store.store_predicate(&predicate).await.unwrap();

        let retrieved = tl_store.get_predicate(&cid).await.unwrap();
        assert_eq!(retrieved, Some(predicate));
    }

    #[tokio::test]
    async fn test_rule_storage() {
        let tl_store = open_store("rule");

        let rule = Rule::fact(alice_parent_of_bob());
        let cid = tl_store.store_rule(&rule).await.unwrap();

        let retrieved = tl_store
            .get_rule(&cid)
            .await
            .unwrap()
            .expect("rule block should exist after store_rule");
        // Storing the decoded rule again must land on the same CID,
        // i.e. the JSON round-trip preserved its content.
        let recid = tl_store.store_rule(&retrieved).await.unwrap();
        assert_eq!(recid, cid);
    }

    #[tokio::test]
    async fn test_missing_block_returns_none() {
        let tl_store = open_store("missing");

        // Build a block only to obtain a CID; it is never written to the store.
        let absent = Block::new(b"never stored".to_vec().into()).unwrap();
        assert!(!tl_store.has(absent.cid()).await.unwrap());
        assert_eq!(tl_store.get_term(absent.cid()).await.unwrap(), None);
    }

    #[tokio::test]
    async fn test_load_kb_missing_file_errors() {
        let tl_store = open_store("load-missing");

        let path = std::env::temp_dir().join("ipfrs-test-tensorlogic-no-such-kb.bin");
        let _ = std::fs::remove_file(&path);
        assert!(tl_store.load_kb(&path).await.is_err());
    }

    #[tokio::test]
    async fn test_kb_save_load_roundtrip() {
        let source = open_store("kb-save");
        source.add_fact(alice_parent_of_bob()).unwrap();

        let path = std::env::temp_dir().join("ipfrs-test-tensorlogic-kb.bin");
        source.save_kb(&path).await.unwrap();

        let target = open_store("kb-load");
        target.load_kb(&path).await.unwrap();

        // The loaded KB must answer the same query the same way as the original.
        let goal = alice_parent_of_bob();
        assert_eq!(
            target.infer(&goal).unwrap().len(),
            source.infer(&goal).unwrap().len()
        );

        let _ = std::fs::remove_file(&path);
    }
}