1use crate::hnsw::{DistanceMetric, VectorIndex};
11use ipfrs_core::{Cid, Error, Result};
12use ipfrs_tensorlogic::{
13 CycleDetector, GoalDecomposition, InferenceEngine, KnowledgeBase, Predicate, ProofRule,
14 Substitution, Term,
15};
16use parking_lot::RwLock;
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet, VecDeque};
19use std::sync::Arc;
20
/// Tunable settings shared by [`LogicSolver`] and [`ProofSearch`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverConfig {
    /// Depth limit: `ProofSearch::search_proof_tree` does not expand nodes
    /// whose depth is at or beyond this value.
    pub max_depth: usize,
    /// Minimum similarity score a candidate needs before
    /// `LogicSolver::backward_chain` queries it as a fallback.
    pub similarity_threshold: f32,
    /// Number of neighbours requested from the vector index during the
    /// similarity fallback of `LogicSolver::backward_chain`.
    pub top_k_similar: usize,
    /// Length of the embedding vectors and dimension of the vector index.
    pub embedding_dim: usize,
    /// When `false`, `LogicSolver::would_cycle` always returns `false`.
    pub detect_cycles: bool,
}
35
36impl Default for SolverConfig {
37 fn default() -> Self {
38 Self {
39 max_depth: 100,
40 similarity_threshold: 0.8,
41 top_k_similar: 10,
42 embedding_dim: 384, detect_cycles: true,
44 }
45 }
46}
47
/// Turns predicates into fixed-length `f32` vectors using a deterministic,
/// hash-based scheme, and caches the result per predicate.
pub struct PredicateEmbedder {
    /// Length of every embedding vector this embedder produces.
    dim: usize,
    /// Cache of finished embeddings, keyed by the predicate's `to_string()`
    /// rendering.
    embeddings: Arc<RwLock<HashMap<String, Vec<f32>>>>,
}
55
56impl PredicateEmbedder {
57 pub fn new(dim: usize) -> Self {
59 Self {
60 dim,
61 embeddings: Arc::new(RwLock::new(HashMap::new())),
62 }
63 }
64
65 pub fn embed_predicate(&self, pred: &Predicate) -> Vec<f32> {
72 let cached = self.embeddings.read().get(&pred.to_string()).cloned();
73 if let Some(emb) = cached {
74 return emb;
75 }
76
77 let mut embedding = vec![0.0; self.dim];
78
79 let name_hash = self.hash_string(&pred.name);
81 for (i, val) in embedding.iter_mut().enumerate() {
82 *val += (((name_hash + i) as f32).sin() * 0.5).abs();
83 }
84
85 for (idx, term) in pred.args.iter().enumerate() {
87 let term_emb = self.embed_term(term, idx);
88 for i in 0..self.dim {
89 embedding[i] += term_emb[i] * 0.3; }
91 }
92
93 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
95 if norm > 1e-6 {
96 for x in &mut embedding {
97 *x /= norm;
98 }
99 }
100
101 self.embeddings
103 .write()
104 .insert(pred.to_string(), embedding.clone());
105
106 embedding
107 }
108
109 fn embed_term(&self, term: &Term, position: usize) -> Vec<f32> {
111 let mut embedding = vec![0.0; self.dim];
112
113 match term {
114 Term::Var(name) => {
115 let hash = self.hash_string(name) + position;
117 for (i, val) in embedding.iter_mut().enumerate() {
118 *val = (((hash + i) as f32).sin() * 0.2).abs();
119 }
120 }
121 Term::Const(constant) => {
122 let hash = self.hash_string(&format!("{:?}", constant));
124 for (i, val) in embedding.iter_mut().enumerate() {
125 *val = (((hash + i) as f32).sin() * 0.8).abs();
126 }
127 }
128 Term::Fun(functor, args) => {
129 let hash = self.hash_string(functor);
131 for (i, val) in embedding.iter_mut().enumerate() {
132 *val = (((hash + i) as f32).sin() * 0.6).abs();
133 }
134
135 for (idx, arg) in args.iter().enumerate() {
136 let arg_emb = self.embed_term(arg, idx);
137 for i in 0..self.dim {
138 embedding[i] += arg_emb[i] * 0.2;
139 }
140 }
141 }
142 Term::Ref(_) => {
143 let hash = position;
145 for (i, val) in embedding.iter_mut().enumerate() {
146 *val = (((hash + i) as f32).sin() * 0.5).abs();
147 }
148 }
149 }
150
151 embedding
152 }
153
154 fn hash_string(&self, s: &str) -> usize {
156 s.bytes().fold(0usize, |acc, b| {
157 acc.wrapping_mul(31).wrapping_add(b as usize)
158 })
159 }
160
161 pub fn similarity(&self, pred1: &Predicate, pred2: &Predicate) -> f32 {
163 let emb1 = self.embed_predicate(pred1);
164 let emb2 = self.embed_predicate(pred2);
165
166 self.cosine_similarity(&emb1, &emb2)
167 }
168
169 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
171 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
172 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
173 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
174
175 if norm_a < 1e-6 || norm_b < 1e-6 {
176 0.0
177 } else {
178 dot / (norm_a * norm_b)
179 }
180 }
181}
182
/// One node of the tree built by `ProofSearch::search_proof_tree`.
#[derive(Debug, Clone)]
pub struct ProofTreeNode {
    /// The goal (or matched fact / rule-body predicate) this node stands for.
    pub goal: Predicate,
    /// Bindings attached to this node; `search_proof_tree` always stores an
    /// empty map here.
    pub substitution: Substitution,
    /// Index of the parent node within the tree vector, or `None` for the root.
    pub parent: Option<usize>,
    /// Distance from the root, which has depth 0.
    pub depth: usize,
    /// Relevance score: 1.0 at the root, inherited unchanged by fact children
    /// and multiplied by 0.9 for rule-body children.
    pub relevance: f32,
}
197
/// Logic solver that pairs a knowledge base and inference engine with a
/// vector index over predicate embeddings for similarity lookups.
pub struct LogicSolver {
    /// Solver settings supplied at construction.
    config: SolverConfig,
    /// Knowledge base receiving the facts and rules added through this solver.
    kb: Arc<RwLock<KnowledgeBase>>,
    /// Inference engine used to answer queries against `kb`.
    engine: Arc<RwLock<InferenceEngine>>,
    /// Produces the embeddings stored in `predicate_index`.
    embedder: PredicateEmbedder,
    /// Vector index over fact embeddings; `None` until the first fact is
    /// added and again after `clear`.
    predicate_index: Arc<RwLock<Option<VectorIndex>>>,
    /// Detector consulted by `would_cycle`; replaced with a fresh one by `clear`.
    cycle_detector: Arc<RwLock<CycleDetector>>,
    /// Maps the CID given to `add_fact` back to the fact it was added with.
    cid_to_predicate: Arc<RwLock<HashMap<Cid, Predicate>>>,
}
215
216impl LogicSolver {
217 pub fn new(config: SolverConfig) -> Result<Self> {
219 let kb = KnowledgeBase::new();
220 let engine = InferenceEngine::new();
221
222 Ok(Self {
223 embedder: PredicateEmbedder::new(config.embedding_dim),
224 config,
225 kb: Arc::new(RwLock::new(kb)),
226 engine: Arc::new(RwLock::new(engine)),
227 predicate_index: Arc::new(RwLock::new(None)),
228 cycle_detector: Arc::new(RwLock::new(CycleDetector::new())),
229 cid_to_predicate: Arc::new(RwLock::new(HashMap::new())),
230 })
231 }
232
233 pub fn with_defaults() -> Result<Self> {
235 Self::new(SolverConfig::default())
236 }
237
238 pub fn add_fact(&mut self, fact: Predicate, cid: Cid) -> Result<()> {
240 self.kb.write().add_fact(fact.clone());
242
243 let embedding = self.embedder.embed_predicate(&fact);
245
246 {
248 let mut index_lock = self.predicate_index.write();
249 if index_lock.is_none() {
250 *index_lock = Some(VectorIndex::new(
251 self.config.embedding_dim,
252 DistanceMetric::Cosine,
253 32, 100, )?);
256 }
257
258 if let Some(ref mut index) = *index_lock {
259 index.insert(&cid, &embedding)?;
260 }
261 }
262
263 self.cid_to_predicate.write().insert(cid, fact);
265
266 Ok(())
267 }
268
269 pub fn add_rule(&mut self, head: Predicate, body: Vec<Predicate>) -> Result<()> {
271 use ipfrs_tensorlogic::Rule;
272 let rule = Rule { head, body };
273 self.kb.write().add_rule(rule);
274 Ok(())
275 }
276
277 pub fn find_similar_predicates(
279 &self,
280 query: &Predicate,
281 k: usize,
282 ) -> Result<Vec<(Cid, Predicate, f32)>> {
283 let embedding = self.embedder.embed_predicate(query);
284
285 let index_lock = self.predicate_index.read();
286 let index = index_lock
287 .as_ref()
288 .ok_or_else(|| Error::InvalidInput("Predicate index not initialized".to_string()))?;
289
290 let results = index.search(&embedding, k, 100)?; let cid_map = self.cid_to_predicate.read();
293 let mut similar = Vec::new();
294
295 for result in results {
296 if let Some(pred) = cid_map.get(&result.cid) {
297 similar.push((result.cid, pred.clone(), result.score));
298 }
299 }
300
301 Ok(similar)
302 }
303
304 pub fn query(&self, goal: &Predicate) -> Result<Vec<Substitution>> {
306 let engine = self.engine.write();
307 let kb = self.kb.read();
308 let substs = engine.query(goal, &kb)?;
309 Ok(substs)
310 }
311
312 pub fn query_with_depth(
314 &self,
315 goal: &Predicate,
316 max_depth: usize,
317 ) -> Result<Vec<Substitution>> {
318 let decomposition = GoalDecomposition::new(goal.clone(), max_depth);
320
321 let mut all_substs = Vec::new();
322
323 for subgoal in &decomposition.subgoals {
325 let engine = self.engine.write();
326 let kb = self.kb.read();
327 let substs = engine.query(subgoal, &kb)?;
328 all_substs.extend(substs);
329 }
330
331 Ok(all_substs)
332 }
333
334 pub fn backward_chain(&self, goal: &Predicate) -> Result<Vec<(Substitution, f32)>> {
341 let mut substs_with_scores = Vec::new();
342
343 let exact_substs = self.query(goal)?;
345 for subst in exact_substs {
346 substs_with_scores.push((subst, 1.0)); }
348
349 if substs_with_scores.is_empty() {
351 let similar = self.find_similar_predicates(goal, self.config.top_k_similar)?;
352
353 for (_, similar_pred, score) in similar {
354 let similarity = score;
356
357 if similarity >= self.config.similarity_threshold {
358 let substs = self.query(&similar_pred)?;
360 for subst in substs {
361 substs_with_scores.push((subst, similarity));
362 }
363 }
364 }
365 }
366
367 substs_with_scores
369 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
370
371 Ok(substs_with_scores)
372 }
373
374 pub fn would_cycle(&self, goal: &Predicate, _depth: usize) -> bool {
376 if !self.config.detect_cycles {
377 return false;
378 }
379
380 let detector = self.cycle_detector.read();
381 detector.would_cycle(goal)
382 }
383
384 pub fn stats(&self) -> SolverStats {
386 let kb = self.kb.read();
387 let kb_stats = kb.stats();
388
389 let index_lock = self.predicate_index.read();
390 let num_indexed = if let Some(ref index) = *index_lock {
391 index.len()
392 } else {
393 0
394 };
395
396 SolverStats {
397 num_facts: kb_stats.num_facts,
398 num_rules: kb_stats.num_rules,
399 num_indexed_predicates: num_indexed,
400 embedding_dim: self.config.embedding_dim,
401 }
402 }
403
404 pub fn clear(&mut self) {
406 let mut kb = self.kb.write();
407 kb.facts.clear();
408 kb.rules.clear();
409 *self.predicate_index.write() = None;
410 self.cid_to_predicate.write().clear();
411 *self.cycle_detector.write() = CycleDetector::new();
412 }
413}
414
/// Snapshot of solver state as returned by `LogicSolver::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverStats {
    /// Fact count reported by the knowledge base.
    pub num_facts: usize,
    /// Rule count reported by the knowledge base.
    pub num_rules: usize,
    /// Number of entries in the vector index, or 0 if no index exists.
    pub num_indexed_predicates: usize,
    /// Embedding dimension taken from the solver configuration.
    pub embedding_dim: usize,
}
427
/// Breadth-first proof-tree builder over a knowledge base.
// `dead_code` is allowed because `embedder` and `proof_index` are built in
// `new` but not read by any method visible in this file.
#[allow(dead_code)]
pub struct ProofSearch {
    /// Search settings; `max_depth` bounds the tree depth.
    config: SolverConfig,
    /// Predicate embedder sized to `config.embedding_dim`.
    embedder: PredicateEmbedder,
    /// Cosine-distance vector index sized to `config.embedding_dim`.
    proof_index: VectorIndex,
    /// String renderings of goals already expanded in the current search.
    visited: HashSet<String>,
}
440
441impl ProofSearch {
442 pub fn new(config: SolverConfig) -> Result<Self> {
444 Ok(Self {
445 embedder: PredicateEmbedder::new(config.embedding_dim),
446 proof_index: VectorIndex::new(
447 config.embedding_dim,
448 DistanceMetric::Cosine,
449 32, 100, )?,
452 visited: HashSet::new(),
453 config,
454 })
455 }
456
457 pub fn search_proof_tree(
459 &mut self,
460 goal: &Predicate,
461 kb: &KnowledgeBase,
462 ) -> Result<Vec<ProofTreeNode>> {
463 let mut queue: VecDeque<ProofTreeNode> = VecDeque::new();
464 let mut proof_tree = Vec::new();
465
466 let root = ProofTreeNode {
468 goal: goal.clone(),
469 substitution: HashMap::new(),
470 parent: None,
471 depth: 0,
472 relevance: 1.0,
473 };
474
475 queue.push_back(root);
476 self.visited.clear();
477
478 while let Some(node) = queue.pop_front() {
479 if node.depth >= self.config.max_depth {
481 continue;
482 }
483
484 let goal_str = node.goal.to_string();
486 if self.visited.contains(&goal_str) {
487 continue;
488 }
489 self.visited.insert(goal_str);
490
491 let node_id = proof_tree.len();
492 proof_tree.push(node.clone());
493
494 for fact in &kb.facts {
496 if node.goal.name == fact.name && node.goal.arity() == fact.arity() {
497 let child = ProofTreeNode {
498 goal: fact.clone(),
499 substitution: HashMap::new(),
500 parent: Some(node_id),
501 depth: node.depth + 1,
502 relevance: node.relevance * 1.0, };
504 queue.push_back(child);
505 }
506 }
507
508 for rule in &kb.rules {
510 if node.goal.name == rule.head.name && node.goal.arity() == rule.head.arity() {
511 for body_pred in &rule.body {
512 let child = ProofTreeNode {
513 goal: body_pred.clone(),
514 substitution: HashMap::new(),
515 parent: Some(node_id),
516 depth: node.depth + 1,
517 relevance: node.relevance * 0.9, };
519 queue.push_back(child);
520 }
521 }
522 }
523 }
524
525 Ok(proof_tree)
526 }
527
528 pub fn extract_proof(&self, tree: &[ProofTreeNode], leaf_idx: usize) -> Vec<ProofRule> {
530 let mut proof_rules = Vec::new();
531 let mut current_idx = Some(leaf_idx);
532
533 while let Some(idx) = current_idx {
534 if idx >= tree.len() {
535 break;
536 }
537
538 let node = &tree[idx];
539 proof_rules.push(ProofRule {
540 head: node.goal.clone(),
541 body: Vec::new(), is_fact: true,
543 });
544
545 current_idx = node.parent;
546 }
547
548 proof_rules.reverse();
549 proof_rules
550 }
551}
552
#[cfg(test)]
mod tests {
    use super::*;
    use ipfrs_tensorlogic::Constant;

    /// CID literal used by the tests that register a fact.
    const SAMPLE_CID: &str = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi";

    /// Builds a string-constant term.
    fn text(value: &str) -> Term {
        Term::Const(Constant::String(value.to_string()))
    }

    /// Builds a variable term.
    fn var(name: &str) -> Term {
        Term::Var(name.to_string())
    }

    /// Builds the fact `parent(first, second)` over string constants.
    fn parent_of(first: &str, second: &str) -> Predicate {
        Predicate::new("parent".to_string(), vec![text(first), text(second)])
    }

    /// Parses the shared sample CID.
    fn sample_cid() -> Cid {
        SAMPLE_CID.parse().unwrap()
    }

    #[test]
    fn test_predicate_embedder() {
        let embedder = PredicateEmbedder::new(128);

        let alice_bob = parent_of("Alice", "Bob");
        let alice_bob_again = parent_of("Alice", "Bob");
        let alice_charlie = parent_of("Alice", "Charlie");

        // Identical predicates must be (almost) perfectly similar.
        let sim_same = embedder.similarity(&alice_bob, &alice_bob_again);
        assert!(
            sim_same > 0.99,
            "Expected sim_same > 0.99, got {}",
            sim_same
        );

        // Changing one argument lowers the similarity, but not by much.
        let sim_diff_args = embedder.similarity(&alice_bob, &alice_charlie);
        assert!(
            sim_diff_args < sim_same,
            "Expected {} < {}",
            sim_diff_args,
            sim_same
        );
        assert!(
            sim_diff_args > 0.8,
            "Expected predicates with same name to have reasonable similarity, got {}",
            sim_diff_args
        );
    }

    #[test]
    fn test_solver_creation() {
        let created = LogicSolver::with_defaults();
        assert!(created.is_ok());

        let stats = created.unwrap().stats();
        assert_eq!(stats.num_facts, 0);
        assert_eq!(stats.num_rules, 0);
    }

    #[test]
    fn test_add_fact() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        let outcome = solver.add_fact(parent_of("Alice", "Bob"), sample_cid());
        assert!(outcome.is_ok());

        let stats = solver.stats();
        assert_eq!(stats.num_facts, 1);
        assert_eq!(stats.num_indexed_predicates, 1);
    }

    #[test]
    fn test_add_rule() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
        let head = Predicate::new("ancestor".to_string(), vec![var("X"), var("Z")]);
        let body = vec![
            Predicate::new("parent".to_string(), vec![var("X"), var("Y")]),
            Predicate::new("ancestor".to_string(), vec![var("Y"), var("Z")]),
        ];

        let outcome = solver.add_rule(head, body);
        assert!(outcome.is_ok());

        let stats = solver.stats();
        assert_eq!(stats.num_rules, 1);
    }

    #[test]
    fn test_query_empty() {
        let solver = LogicSolver::with_defaults().unwrap();

        let goal = parent_of("Alice", "Bob");

        let answers = solver.query(&goal);
        assert!(answers.is_ok());
        assert!(answers.unwrap().is_empty());
    }

    #[test]
    fn test_proof_search_creation() {
        let created = ProofSearch::new(SolverConfig::default());
        assert!(created.is_ok());
    }

    #[test]
    fn test_solver_clear() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        solver
            .add_fact(parent_of("Alice", "Bob"), sample_cid())
            .unwrap();
        assert_eq!(solver.stats().num_facts, 1);

        solver.clear();
        assert_eq!(solver.stats().num_facts, 0);
    }

    #[test]
    fn test_embedding_normalization() {
        let embedder = PredicateEmbedder::new(64);

        let person = Predicate::new("person".to_string(), vec![text("Alice")]);
        let embedding = embedder.embed_predicate(&person);

        // The embedding is expected to have unit length.
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}