ipfrs_semantic/solver.rs

//! Logic solver for reasoning queries with semantic integration
//!
//! This module provides integration between semantic search and logic reasoning:
//! - Predicate-to-embedding mapping for similarity-based matching
//! - Logic term similarity for fuzzy unification
//! - Proof tree search using vector indices
//! - Backward chaining with semantic relevance
//! - Subgoal decomposition and dependency tracking
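//!
//! # Example
//!
//! A minimal end-to-end sketch of the solver API (marked `ignore` so it is not
//! compiled as a doctest; the crate import paths are assumptions, and the flow
//! mirrors the unit tests at the bottom of this file):
//!
//! ```ignore
//! use ipfrs_core::Cid;
//! use ipfrs_semantic::solver::LogicSolver;
//! use ipfrs_tensorlogic::{Constant, Predicate, Term};
//!
//! let mut solver = LogicSolver::with_defaults()?;
//!
//! // Add a ground fact parent(Alice, Bob), stored under a content CID.
//! let alice = Term::Const(Constant::String("Alice".to_string()));
//! let bob = Term::Const(Constant::String("Bob".to_string()));
//! let fact = Predicate::new("parent".to_string(), vec![alice.clone(), bob.clone()]);
//! let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
//! solver.add_fact(fact, cid)?;
//!
//! // Prove the goal via backward chaining; each substitution is one answer.
//! let goal = Predicate::new("parent".to_string(), vec![alice, bob]);
//! let answers = solver.query(&goal)?;
//! ```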

use crate::hnsw::{DistanceMetric, VectorIndex};
use ipfrs_core::{Cid, Error, Result};
use ipfrs_tensorlogic::{
    CycleDetector, GoalDecomposition, InferenceEngine, KnowledgeBase, Predicate, ProofRule,
    Substitution, Term,
};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;

/// Configuration for the logic solver
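///
/// # Example
///
/// A small sketch of overriding individual fields while keeping the remaining
/// defaults (marked `ignore`; field meanings are documented on the struct below):
///
/// ```ignore
/// let config = SolverConfig {
///     max_depth: 32,
///     similarity_threshold: 0.9,
///     ..SolverConfig::default()
/// };
/// ```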
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverConfig {
    /// Maximum recursion depth for backward chaining
    pub max_depth: usize,
    /// Similarity threshold for fuzzy matching (0.0 to 1.0)
    pub similarity_threshold: f32,
    /// Number of similar predicates to consider
    pub top_k_similar: usize,
    /// Embedding dimension for predicates
    pub embedding_dim: usize,
    /// Whether to use cycle detection
    pub detect_cycles: bool,
}

impl Default for SolverConfig {
    fn default() -> Self {
        Self {
            max_depth: 100,
            similarity_threshold: 0.8,
            top_k_similar: 10,
            embedding_dim: 384, // Standard embedding size
            detect_cycles: true,
        }
    }
}

/// Maps logic predicates to vector embeddings for similarity search
pub struct PredicateEmbedder {
    /// Embedding dimension
    dim: usize,
    /// Cache of predicate embeddings
    embeddings: Arc<RwLock<HashMap<String, Vec<f32>>>>,
}

impl PredicateEmbedder {
    /// Create a new predicate embedder
    pub fn new(dim: usize) -> Self {
        Self {
            dim,
            embeddings: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Generate embedding for a predicate
    ///
    /// This uses a simple compositional approach:
    /// - Predicate name contributes to the embedding
    /// - Each term contributes based on its structure
    /// - Ground terms have higher weight than variables
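    ///
    /// # Example
    ///
    /// A small sketch (marked `ignore`; it mirrors the embedder unit tests at the
    /// bottom of this file):
    ///
    /// ```ignore
    /// use ipfrs_tensorlogic::{Constant, Predicate, Term};
    ///
    /// let embedder = PredicateEmbedder::new(128);
    /// let alice = Term::Const(Constant::String("Alice".to_string()));
    /// let bob = Term::Const(Constant::String("Bob".to_string()));
    /// let pred = Predicate::new("parent".to_string(), vec![alice, bob]);
    ///
    /// // Embeddings are L2-normalized, so the resulting norm is ~1.0.
    /// let embedding = embedder.embed_predicate(&pred);
    /// assert_eq!(embedding.len(), 128);
    /// ```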
    pub fn embed_predicate(&self, pred: &Predicate) -> Vec<f32> {
        let cached = self.embeddings.read().get(&pred.to_string()).cloned();
        if let Some(emb) = cached {
            return emb;
        }

        let mut embedding = vec![0.0; self.dim];

        // Hash predicate name to embedding space
        let name_hash = self.hash_string(&pred.name);
        for (i, val) in embedding.iter_mut().enumerate() {
            *val += ((name_hash.wrapping_add(i) as f32).sin() * 0.5).abs();
        }

        // Add contribution from each argument
        for (idx, term) in pred.args.iter().enumerate() {
            let term_emb = self.embed_term(term, idx);
            // Weight terms less than the predicate name
            for (dst, src) in embedding.iter_mut().zip(term_emb.iter()) {
                *dst += src * 0.3;
            }
        }

        // Normalize to unit length
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-6 {
            for x in &mut embedding {
                *x /= norm;
            }
        }

        // Cache the embedding
        self.embeddings
            .write()
            .insert(pred.to_string(), embedding.clone());

        embedding
    }

    /// Embed a logic term
    fn embed_term(&self, term: &Term, position: usize) -> Vec<f32> {
        let mut embedding = vec![0.0; self.dim];

        match term {
            Term::Var(name) => {
                // Variables get a low-weight embedding based on position
                let hash = self.hash_string(name).wrapping_add(position);
                for (i, val) in embedding.iter_mut().enumerate() {
                    *val = ((hash.wrapping_add(i) as f32).sin() * 0.2).abs();
                }
            }
            Term::Const(constant) => {
                // Constants get higher weight
                let hash = self.hash_string(&format!("{:?}", constant));
                for (i, val) in embedding.iter_mut().enumerate() {
                    *val = ((hash.wrapping_add(i) as f32).sin() * 0.8).abs();
                }
            }
            Term::Fun(functor, args) => {
                // Function terms combine functor and arg embeddings
                let hash = self.hash_string(functor);
                for (i, val) in embedding.iter_mut().enumerate() {
                    *val = ((hash.wrapping_add(i) as f32).sin() * 0.6).abs();
                }

                for (idx, arg) in args.iter().enumerate() {
                    let arg_emb = self.embed_term(arg, idx);
                    for (dst, src) in embedding.iter_mut().zip(arg_emb.iter()) {
                        *dst += src * 0.2;
                    }
                }
            }
            Term::Ref(_) => {
                // References get medium weight based on position
                let hash = position;
                for (i, val) in embedding.iter_mut().enumerate() {
                    *val = ((hash.wrapping_add(i) as f32).sin() * 0.5).abs();
                }
            }
        }

        embedding
    }

    /// Simple polynomial string hash (wrapping arithmetic, not cryptographic)
    fn hash_string(&self, s: &str) -> usize {
        s.bytes().fold(0usize, |acc, b| {
            acc.wrapping_mul(31).wrapping_add(b as usize)
        })
    }

    /// Compute similarity between two predicates (cosine similarity)
    pub fn similarity(&self, pred1: &Predicate, pred2: &Predicate) -> f32 {
        let emb1 = self.embed_predicate(pred1);
        let emb2 = self.embed_predicate(pred2);

        self.cosine_similarity(&emb1, &emb2)
    }

    /// Cosine similarity between two vectors
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a < 1e-6 || norm_b < 1e-6 {
            0.0
        } else {
            dot / (norm_a * norm_b)
        }
    }
}

/// Proof tree node for search
#[derive(Debug, Clone)]
pub struct ProofTreeNode {
    /// Current goal predicate
    pub goal: Predicate,
    /// Substitutions applied so far
    pub substitution: Substitution,
    /// Parent node ID (None for root)
    pub parent: Option<usize>,
    /// Depth in the proof tree
    pub depth: usize,
    /// Relevance score (from semantic search)
    pub relevance: f32,
}

/// Logic solver with semantic integration
pub struct LogicSolver {
    /// Configuration
    config: SolverConfig,
    /// Knowledge base for reasoning
    kb: Arc<RwLock<KnowledgeBase>>,
    /// Inference engine
    engine: Arc<RwLock<InferenceEngine>>,
    /// Predicate embedder for similarity search
    embedder: PredicateEmbedder,
    /// Vector index for predicate search
    predicate_index: Arc<RwLock<Option<VectorIndex>>>,
    /// Cycle detector
    cycle_detector: Arc<RwLock<CycleDetector>>,
    /// Cache of CID to predicate mapping
    cid_to_predicate: Arc<RwLock<HashMap<Cid, Predicate>>>,
}

impl LogicSolver {
    /// Create a new logic solver
    pub fn new(config: SolverConfig) -> Result<Self> {
        let kb = KnowledgeBase::new();
        let engine = InferenceEngine::new();

        Ok(Self {
            embedder: PredicateEmbedder::new(config.embedding_dim),
            config,
            kb: Arc::new(RwLock::new(kb)),
            engine: Arc::new(RwLock::new(engine)),
            predicate_index: Arc::new(RwLock::new(None)),
            cycle_detector: Arc::new(RwLock::new(CycleDetector::new())),
            cid_to_predicate: Arc::new(RwLock::new(HashMap::new())),
        })
    }

    /// Create with default configuration
    pub fn with_defaults() -> Result<Self> {
        Self::new(SolverConfig::default())
    }

    /// Add a fact to the knowledge base and index it
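    ///
    /// # Example
    ///
    /// A sketch (marked `ignore`; see `test_add_fact` below for the runnable
    /// equivalent):
    ///
    /// ```ignore
    /// let mut solver = LogicSolver::with_defaults()?;
    /// let fact = Predicate::new(
    ///     "parent".to_string(),
    ///     vec![
    ///         Term::Const(Constant::String("Alice".to_string())),
    ///         Term::Const(Constant::String("Bob".to_string())),
    ///     ],
    /// );
    /// let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi".parse()?;
    ///
    /// // The fact is stored in the knowledge base and its embedding is indexed.
    /// solver.add_fact(fact, cid)?;
    /// assert_eq!(solver.stats().num_facts, 1);
    /// ```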
    pub fn add_fact(&mut self, fact: Predicate, cid: Cid) -> Result<()> {
        // Add to knowledge base
        self.kb.write().add_fact(fact.clone());

        // Generate embedding
        let embedding = self.embedder.embed_predicate(&fact);

        // Add to vector index (create if needed)
        {
            let mut index_lock = self.predicate_index.write();
            if index_lock.is_none() {
                *index_lock = Some(VectorIndex::new(
                    self.config.embedding_dim,
                    DistanceMetric::Cosine,
                    32,  // max_nb_connection
                    100, // ef_construction
                )?);
            }

            if let Some(ref mut index) = *index_lock {
                index.insert(&cid, &embedding)?;
            }
        }

        // Store CID mapping
        self.cid_to_predicate.write().insert(cid, fact);

        Ok(())
    }

    /// Add a rule to the knowledge base
    pub fn add_rule(&mut self, head: Predicate, body: Vec<Predicate>) -> Result<()> {
        use ipfrs_tensorlogic::Rule;
        let rule = Rule { head, body };
        self.kb.write().add_rule(rule);
        Ok(())
    }

    /// Find similar predicates using semantic search
    pub fn find_similar_predicates(
        &self,
        query: &Predicate,
        k: usize,
    ) -> Result<Vec<(Cid, Predicate, f32)>> {
        let embedding = self.embedder.embed_predicate(query);

        let index_lock = self.predicate_index.read();
        let index = index_lock
            .as_ref()
            .ok_or_else(|| Error::InvalidInput("Predicate index not initialized".to_string()))?;

        let results = index.search(&embedding, k, 100)?; // ef_search = 100

        let cid_map = self.cid_to_predicate.read();
        let mut similar = Vec::new();

        for result in results {
            if let Some(pred) = cid_map.get(&result.cid) {
                similar.push((result.cid, pred.clone(), result.score));
            }
        }

        Ok(similar)
    }

    /// Query using backward chaining with semantic relevance
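    ///
    /// # Example
    ///
    /// A sketch using the classic `ancestor` rules (marked `ignore`; assumes a
    /// `solver` built as in the module-level example, and the rule construction
    /// matches `test_add_rule` below):
    ///
    /// ```ignore
    /// let x = Term::Var("X".to_string());
    /// let y = Term::Var("Y".to_string());
    /// let z = Term::Var("Z".to_string());
    ///
    /// // ancestor(X, Y) :- parent(X, Y).
    /// solver.add_rule(
    ///     Predicate::new("ancestor".to_string(), vec![x.clone(), y.clone()]),
    ///     vec![Predicate::new("parent".to_string(), vec![x.clone(), y.clone()])],
    /// )?;
    /// // ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
    /// solver.add_rule(
    ///     Predicate::new("ancestor".to_string(), vec![x.clone(), z.clone()]),
    ///     vec![
    ///         Predicate::new("parent".to_string(), vec![x.clone(), y.clone()]),
    ///         Predicate::new("ancestor".to_string(), vec![y, z]),
    ///     ],
    /// )?;
    ///
    /// // Every substitution in the result is one way to prove the goal.
    /// let goal = Predicate::new("ancestor".to_string(), vec![x, Term::Var("Who".to_string())]);
    /// let answers = solver.query(&goal)?;
    /// ```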
    pub fn query(&self, goal: &Predicate) -> Result<Vec<Substitution>> {
        let engine = self.engine.write();
        let kb = self.kb.read();
        let substs = engine.query(goal, &kb)?;
        Ok(substs)
    }

    /// Query with a depth limit, decomposing the goal into subgoals
    pub fn query_with_depth(
        &self,
        goal: &Predicate,
        max_depth: usize,
    ) -> Result<Vec<Substitution>> {
        // Use goal decomposition for complex queries
        let decomposition = GoalDecomposition::new(goal.clone(), max_depth);

        let mut all_substs = Vec::new();

        // Solve each subgoal
        for subgoal in &decomposition.subgoals {
            let engine = self.engine.write();
            let kb = self.kb.read();
            let substs = engine.query(subgoal, &kb)?;
            all_substs.extend(substs);
        }

        Ok(all_substs)
    }

    /// Perform backward chaining with semantic search fallback
    ///
    /// This combines traditional backward chaining with semantic search:
    /// 1. Try exact unification first
    /// 2. If that fails, search for similar predicates
    /// 3. Use similarity scores to rank results
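    ///
    /// # Example
    ///
    /// A sketch of consuming the ranked results (marked `ignore`; assumes a
    /// populated `solver` as in the module-level example):
    ///
    /// ```ignore
    /// let goal = Predicate::new(
    ///     "parent".to_string(),
    ///     vec![
    ///         Term::Const(Constant::String("Alice".to_string())),
    ///         Term::Var("Child".to_string()),
    ///     ],
    /// );
    ///
    /// // Results are (substitution, score) pairs sorted by descending score:
    /// // exact proofs score 1.0, semantic fallbacks carry their similarity score.
    /// for (subst, score) in solver.backward_chain(&goal)? {
    ///     println!("score {score:.2}: {subst:?}");
    /// }
    /// ```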
    pub fn backward_chain(&self, goal: &Predicate) -> Result<Vec<(Substitution, f32)>> {
        let mut substs_with_scores = Vec::new();

        // Try exact backward chaining first
        let exact_substs = self.query(goal)?;
        for subst in exact_substs {
            substs_with_scores.push((subst, 1.0)); // Exact match = score 1.0
        }

        // If there are no exact results, fall back to semantic search
        if substs_with_scores.is_empty() {
            let similar = self.find_similar_predicates(goal, self.config.top_k_similar)?;

            for (_, similar_pred, score) in similar {
                // Score from vector search is already a similarity score
                let similarity = score;

                if similarity >= self.config.similarity_threshold {
                    // Try to prove the similar predicate
                    let substs = self.query(&similar_pred)?;
                    for subst in substs {
                        substs_with_scores.push((subst, similarity));
                    }
                }
            }
        }

        // Sort by score (descending)
        substs_with_scores
            .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        Ok(substs_with_scores)
    }

    /// Check if a query would create a cycle
    pub fn would_cycle(&self, goal: &Predicate, _depth: usize) -> bool {
        if !self.config.detect_cycles {
            return false;
        }

        let detector = self.cycle_detector.read();
        detector.would_cycle(goal)
    }

    /// Get knowledge base statistics
    pub fn stats(&self) -> SolverStats {
        let kb = self.kb.read();
        let kb_stats = kb.stats();

        let index_lock = self.predicate_index.read();
        let num_indexed = if let Some(ref index) = *index_lock {
            index.len()
        } else {
            0
        };

        SolverStats {
            num_facts: kb_stats.num_facts,
            num_rules: kb_stats.num_rules,
            num_indexed_predicates: num_indexed,
            embedding_dim: self.config.embedding_dim,
        }
    }

    /// Clear all data
    pub fn clear(&mut self) {
        let mut kb = self.kb.write();
        kb.facts.clear();
        kb.rules.clear();
        *self.predicate_index.write() = None;
        self.cid_to_predicate.write().clear();
        *self.cycle_detector.write() = CycleDetector::new();
    }
}

/// Statistics for the logic solver
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverStats {
    /// Number of facts in knowledge base
    pub num_facts: usize,
    /// Number of rules in knowledge base
    pub num_rules: usize,
    /// Number of indexed predicates
    pub num_indexed_predicates: usize,
    /// Embedding dimension
    pub embedding_dim: usize,
}

/// Proof search with semantic guidance
#[allow(dead_code)]
pub struct ProofSearch {
    /// Configuration
    config: SolverConfig,
    /// Embedder for similarity
    embedder: PredicateEmbedder,
    /// Vector index for proof fragments
    proof_index: VectorIndex,
    /// Visited goals (for cycle detection)
    visited: HashSet<String>,
}

impl ProofSearch {
    /// Create a new proof search
    pub fn new(config: SolverConfig) -> Result<Self> {
        Ok(Self {
            embedder: PredicateEmbedder::new(config.embedding_dim),
            proof_index: VectorIndex::new(
                config.embedding_dim,
                DistanceMetric::Cosine,
                32,  // max_nb_connection
                100, // ef_construction
            )?,
            visited: HashSet::new(),
            config,
        })
    }

    /// Search for proof trees using BFS with semantic guidance
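    ///
    /// # Example
    ///
    /// A sketch against a small knowledge base (marked `ignore`; `KnowledgeBase::new`
    /// and `add_fact` are the same calls `LogicSolver` uses internally):
    ///
    /// ```ignore
    /// use ipfrs_tensorlogic::{Constant, KnowledgeBase, Predicate, Term};
    ///
    /// let mut kb = KnowledgeBase::new();
    /// kb.add_fact(Predicate::new(
    ///     "parent".to_string(),
    ///     vec![
    ///         Term::Const(Constant::String("Alice".to_string())),
    ///         Term::Const(Constant::String("Bob".to_string())),
    ///     ],
    /// ));
    ///
    /// let mut search = ProofSearch::new(SolverConfig::default())?;
    /// let goal = Predicate::new(
    ///     "parent".to_string(),
    ///     vec![Term::Var("X".to_string()), Term::Var("Y".to_string())],
    /// );
    ///
    /// // Each node records its goal, parent index, depth, and relevance score.
    /// let tree = search.search_proof_tree(&goal, &kb)?;
    /// assert!(!tree.is_empty()); // the root goal is always present
    /// ```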
    pub fn search_proof_tree(
        &mut self,
        goal: &Predicate,
        kb: &KnowledgeBase,
    ) -> Result<Vec<ProofTreeNode>> {
        let mut queue: VecDeque<ProofTreeNode> = VecDeque::new();
        let mut proof_tree = Vec::new();

        // Initialize with root goal
        let root = ProofTreeNode {
            goal: goal.clone(),
            substitution: HashMap::new(),
            parent: None,
            depth: 0,
            relevance: 1.0,
        };

        queue.push_back(root);
        self.visited.clear();

        while let Some(node) = queue.pop_front() {
            // Check depth limit
            if node.depth >= self.config.max_depth {
                continue;
            }

            // Check if already visited (cycle detection)
            let goal_str = node.goal.to_string();
            if self.visited.contains(&goal_str) {
                continue;
            }
            self.visited.insert(goal_str);

            let node_id = proof_tree.len();
            proof_tree.push(node.clone());

            // Try to match with facts (simplified - exact name match)
            for fact in &kb.facts {
                if node.goal.name == fact.name && node.goal.arity() == fact.arity() {
                    let child = ProofTreeNode {
                        goal: fact.clone(),
                        substitution: HashMap::new(),
                        parent: Some(node_id),
                        depth: node.depth + 1,
                        relevance: node.relevance * 1.0, // Exact match
                    };
                    queue.push_back(child);
                }
            }

            // Try to match with rules (simplified - exact name match)
            for rule in &kb.rules {
                if node.goal.name == rule.head.name && node.goal.arity() == rule.head.arity() {
                    for body_pred in &rule.body {
                        let child = ProofTreeNode {
                            goal: body_pred.clone(),
                            substitution: HashMap::new(),
                            parent: Some(node_id),
                            depth: node.depth + 1,
                            relevance: node.relevance * 0.9, // Rule inference slightly lower
                        };
                        queue.push_back(child);
                    }
                }
            }
        }

        Ok(proof_tree)
    }

    /// Extract proof from proof tree
    pub fn extract_proof(&self, tree: &[ProofTreeNode], leaf_idx: usize) -> Vec<ProofRule> {
        let mut proof_rules = Vec::new();
        let mut current_idx = Some(leaf_idx);

        while let Some(idx) = current_idx {
            if idx >= tree.len() {
                break;
            }

            let node = &tree[idx];
            proof_rules.push(ProofRule {
                head: node.goal.clone(),
                body: Vec::new(), // Simplified for now
                is_fact: true,
            });

            current_idx = node.parent;
        }

        proof_rules.reverse();
        proof_rules
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ipfrs_tensorlogic::Constant;

    #[test]
    fn test_predicate_embedder() {
        let embedder = PredicateEmbedder::new(128);

        let alice = Term::Const(Constant::String("Alice".to_string()));
        let bob = Term::Const(Constant::String("Bob".to_string()));
        let charlie = Term::Const(Constant::String("Charlie".to_string()));

        let pred1 = Predicate::new("parent".to_string(), vec![alice.clone(), bob.clone()]);
        let pred2 = Predicate::new("parent".to_string(), vec![alice.clone(), bob.clone()]);
        let pred3 = Predicate::new("parent".to_string(), vec![alice.clone(), charlie.clone()]);

        // Same predicates should have high similarity
        let sim_same = embedder.similarity(&pred1, &pred2);
        assert!(
            sim_same > 0.99,
            "Expected sim_same > 0.99, got {}",
            sim_same
        );

        // Different arguments should have lower similarity
        let sim_diff_args = embedder.similarity(&pred1, &pred3);
        assert!(
            sim_diff_args < sim_same,
            "Expected {} < {}",
            sim_diff_args,
            sim_same
        );
        assert!(
            sim_diff_args > 0.8,
            "Expected predicates with same name to have reasonable similarity, got {}",
            sim_diff_args
        );
    }

    #[test]
    fn test_solver_creation() {
        let solver = LogicSolver::with_defaults();
        assert!(solver.is_ok());

        let stats = solver.unwrap().stats();
        assert_eq!(stats.num_facts, 0);
        assert_eq!(stats.num_rules, 0);
    }

    #[test]
    fn test_add_fact() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        let alice = Term::Const(Constant::String("Alice".to_string()));
        let bob = Term::Const(Constant::String("Bob".to_string()));
        let fact = Predicate::new("parent".to_string(), vec![alice, bob]);

        let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap();

        let result = solver.add_fact(fact, cid);
        assert!(result.is_ok());

        let stats = solver.stats();
        assert_eq!(stats.num_facts, 1);
        assert_eq!(stats.num_indexed_predicates, 1);
    }

    #[test]
    fn test_add_rule() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        let x = Term::Var("X".to_string());
        let y = Term::Var("Y".to_string());
        let z = Term::Var("Z".to_string());

        let head = Predicate::new("ancestor".to_string(), vec![x.clone(), z.clone()]);
        let body1 = Predicate::new("parent".to_string(), vec![x.clone(), y.clone()]);
        let body2 = Predicate::new("ancestor".to_string(), vec![y.clone(), z.clone()]);

        let result = solver.add_rule(head, vec![body1, body2]);
        assert!(result.is_ok());

        let stats = solver.stats();
        assert_eq!(stats.num_rules, 1);
    }

    #[test]
    fn test_query_empty() {
        let solver = LogicSolver::with_defaults().unwrap();

        let alice = Term::Const(Constant::String("Alice".to_string()));
        let bob = Term::Const(Constant::String("Bob".to_string()));
        let query = Predicate::new("parent".to_string(), vec![alice, bob]);

        let result = solver.query(&query);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_proof_search_creation() {
        let config = SolverConfig::default();
        let search = ProofSearch::new(config);
        assert!(search.is_ok());
    }

    #[test]
    fn test_solver_clear() {
        let mut solver = LogicSolver::with_defaults().unwrap();

        let alice = Term::Const(Constant::String("Alice".to_string()));
        let bob = Term::Const(Constant::String("Bob".to_string()));
        let fact = Predicate::new("parent".to_string(), vec![alice, bob]);

        let cid: Cid = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse()
            .unwrap();

        solver.add_fact(fact, cid).unwrap();
        assert_eq!(solver.stats().num_facts, 1);

        solver.clear();
        assert_eq!(solver.stats().num_facts, 0);
    }

    #[test]
    fn test_embedding_normalization() {
        let embedder = PredicateEmbedder::new(64);

        let alice = Term::Const(Constant::String("Alice".to_string()));
        let pred = Predicate::new("person".to_string(), vec![alice]);

        let embedding = embedder.embed_predicate(&pred);

        // Check that embedding is normalized
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }
}