use crate::ir::{KnowledgeBase, Predicate, Rule, Term};
use crate::proof_storage::{ProofFragment, ProofFragmentRef};
use crate::reasoning::{Proof, Substitution};
use async_trait::async_trait;
use ipfrs_core::Cid;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use thiserror::Error;
/// Errors produced by [`RemoteKnowledgeProvider`] implementations and by the
/// distributed resolver / proof-assembler types in this module.
#[derive(Debug, Error)]
pub enum RemoteReasoningError {
/// Network failure; the payload is a free-form description.
///
/// NOTE(review): `DistributedProofAssembler::fetch_fragment` also returns this
/// variant to signal "not yet implemented", not only real transport errors.
#[error("Network error: {0}")]
NetworkError(String),
/// No response arrived from the remote side in time.
#[error("Timeout waiting for remote response")]
Timeout,
/// A peer answered, but the answer could not be accepted; payload explains why.
#[error("Invalid response from peer: {0}")]
InvalidResponse(String),
/// The named peer (payload is the peer identifier) is not known.
#[error("Peer not found: {0}")]
PeerNotFound(String),
/// There were no peers to send the query to.
#[error("No peers available for query")]
NoPeersAvailable,
/// Encoding or decoding a message failed; payload describes the failure.
#[error("Serialization error: {0}")]
SerializationError(String),
/// The remote side reported that the query itself failed.
#[error("Remote query failed: {0}")]
QueryFailed(String),
}
/// Request for the facts and rules of a single predicate, sent via
/// [`RemoteKnowledgeProvider::query_predicate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRequest {
/// Name of the predicate whose facts/rules are requested.
pub predicate_name: String,
/// Argument values supplied with the query.
/// NOTE(review): presumably the bound (ground) arguments to filter on; the
/// mock provider ignores this field — confirm semantics with real providers.
pub ground_args: Vec<String>,
/// Upper bound on returned entries (the mock caps facts and rules separately).
pub max_results: usize,
/// Depth limit for the query; not read by the mock provider.
pub max_depth: usize,
/// Caller-chosen identifier, echoed back in [`QueryResponse::request_id`] by the mock.
pub request_id: String,
}
/// Answer to a [`QueryRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryResponse {
/// Identifier of the request this response answers.
pub request_id: String,
/// Facts returned for the predicate.
pub predicates: Vec<Predicate>,
/// Rules returned for the predicate.
pub rules: Vec<Rule>,
/// References to proof fragments accompanying the answer (empty in the mock).
pub proof_fragments: Vec<ProofFragmentRef>,
/// Identifier of the peer that produced the response.
pub peer_id: String,
/// Whether the answer is incomplete and more results exist.
pub has_more: bool,
/// Opaque token for fetching further results, if the peer supports paging.
pub continuation_token: Option<String>,
}
/// Request to discover facts for a predicate, sent via
/// [`RemoteKnowledgeProvider::discover_facts`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactDiscoveryRequest {
/// Name of the predicate to discover facts for.
pub predicate_name: String,
/// Per-argument patterns.
/// NOTE(review): presumably `None` means "any value" for that position;
/// the mock provider ignores this field — confirm with real providers.
pub arg_patterns: Vec<Option<String>>,
/// Hop limit for the discovery (`DistributedGoalResolver::prefetch_facts` sends 3).
pub max_hops: usize,
/// Time-to-live value (`prefetch_facts` sends 30).
/// NOTE(review): unit is not established in this file — confirm.
pub ttl: u32,
/// Peer identifiers to leave out of the discovery (presumably; unused by the mock).
pub exclude_peers: HashSet<String>,
}
/// Result of a [`FactDiscoveryRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FactDiscoveryResponse {
/// Facts that were discovered.
pub facts: Vec<Predicate>,
/// Peer that supplied each fact, keyed by index into `facts`
/// (this is how the mock provider populates it).
pub sources: HashMap<usize, String>,
/// Number of peers that were queried.
pub peers_queried: usize,
/// Hop count at which each fact was found, keyed by index into `facts`
/// (the mock provider reports 0 for every fact).
pub hops: HashMap<usize, usize>, }
/// Request for one page of a predicate's facts, sent via
/// [`RemoteKnowledgeProvider::load_incremental`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalLoadRequest {
/// Name of the predicate to page through.
pub predicate_name: String,
/// Maximum number of facts in the returned page.
pub batch_size: usize,
/// Index of the first fact to return.
pub offset: usize,
/// Optional key/value filter; ignored by the mock provider.
/// NOTE(review): key meaning (argument name? position?) is not established here.
pub filter: Option<HashMap<String, String>>,
}
/// One page of facts returned for an [`IncrementalLoadRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalLoadResponse {
/// Facts in this page.
pub batch: Vec<Predicate>,
/// Total number of facts available for the predicate.
pub total_count: usize,
/// Offset to request next; `None` when this is the last page.
pub next_offset: Option<usize>,
/// Whether this page reaches the end of the fact list.
pub is_last: bool,
}
/// Request to resolve a goal remotely, sent via
/// [`RemoteKnowledgeProvider::resolve_goal`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoalResolutionRequest {
/// Goal predicate to resolve; may contain variables.
pub goal: Predicate,
/// Variable bindings the goal should be resolved under.
pub substitution: HashMap<String, Term>,
/// Current resolution depth (callers in this module send 0).
pub depth: usize,
/// Identifier of the requester (callers in this module send `"local"`).
pub requester: String,
/// Identifier of this request, echoed back in the response by the mock.
pub request_id: String,
}
/// Answer to a [`GoalResolutionRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GoalResolutionResponse {
/// Identifier of the request this response answers.
pub request_id: String,
/// Whether at least one solution was found.
pub solved: bool,
/// Variable bindings, one map per solution.
pub solutions: Vec<HashMap<String, Term>>,
/// Proof of the goal, when the responder provides one.
pub proof: Option<Proof>,
/// References to proof fragments accompanying the answer (empty in the mock).
pub proof_fragments: Vec<ProofFragmentRef>,
}
/// Source of knowledge held outside the local knowledge base.
///
/// Implementors must be `Send + Sync`; methods are async via `async_trait`.
/// [`MockRemoteKnowledgeProvider`] in this module is an in-memory implementation.
#[async_trait]
pub trait RemoteKnowledgeProvider: Send + Sync {
/// Returns facts and rules for `request.predicate_name`.
async fn query_predicate(
&self,
request: QueryRequest,
) -> Result<QueryResponse, RemoteReasoningError>;
/// Discovers facts for a predicate, possibly across several peers.
async fn discover_facts(
&self,
request: FactDiscoveryRequest,
) -> Result<FactDiscoveryResponse, RemoteReasoningError>;
/// Returns one page of a predicate's facts.
async fn load_incremental(
&self,
request: IncrementalLoadRequest,
) -> Result<IncrementalLoadResponse, RemoteReasoningError>;
/// Resolves a goal and returns its solutions (and optionally a proof).
async fn resolve_goal(
&self,
request: GoalResolutionRequest,
) -> Result<GoalResolutionResponse, RemoteReasoningError>;
/// Lists identifiers of peers that can currently be queried.
async fn get_available_peers(&self) -> Result<Vec<String>, RemoteReasoningError>;
}
/// Resolves goals against a local knowledge base, falling back to an optional
/// remote provider when no local fact matches.
pub struct DistributedGoalResolver {
/// Knowledge base consulted first for every goal.
local_kb: Arc<KnowledgeBase>,
/// Provider used when local resolution finds nothing; `None` disables fallback.
remote_provider: Option<Arc<dyn RemoteKnowledgeProvider>>,
/// Configured depth limit (default 10).
/// NOTE(review): not read by any method in this file.
max_depth: usize,
/// Configured timeout in milliseconds (default 5000).
/// NOTE(review): not applied to remote calls anywhere in this file.
timeout_ms: u64,
/// Facts fetched by `prefetch_facts`, keyed by predicate name.
remote_fact_cache: HashMap<String, Vec<Predicate>>,
}
impl DistributedGoalResolver {
    /// Creates a resolver over `local_kb` with no remote provider, a
    /// `max_depth` of 10, a 5000 ms timeout and an empty remote-fact cache.
    pub fn new(local_kb: Arc<KnowledgeBase>) -> Self {
        Self {
            local_kb,
            remote_provider: None,
            max_depth: 10,
            timeout_ms: 5000,
            remote_fact_cache: HashMap::new(),
        }
    }

    /// Attaches the provider consulted when local resolution finds nothing.
    pub fn with_provider(mut self, provider: Arc<dyn RemoteKnowledgeProvider>) -> Self {
        self.remote_provider = Some(provider);
        self
    }

    /// Sets the configured maximum depth.
    ///
    /// NOTE(review): the value is stored but not yet read by any resolution path.
    pub fn with_max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the configured timeout in milliseconds.
    ///
    /// NOTE(review): the value is stored but not yet applied to remote calls.
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Resolves `goal` under the caller-supplied `substitution`.
    ///
    /// Local facts are tried first. If any unify, those solutions are returned
    /// and the remote provider is not contacted. Otherwise the goal is forwarded
    /// to the remote provider, if one is configured; with no provider the result
    /// is empty. Facts cached by [`Self::prefetch_facts`] are not consulted here.
    ///
    /// # Errors
    ///
    /// Propagates any [`RemoteReasoningError`] returned by the remote provider.
    pub async fn resolve(
        &mut self,
        goal: &Predicate,
        substitution: &Substitution,
    ) -> Result<Vec<Substitution>, RemoteReasoningError> {
        let local_solutions = self.resolve_local(goal, substitution);
        if !local_solutions.is_empty() {
            return Ok(local_solutions);
        }
        match &self.remote_provider {
            Some(provider) => self.resolve_remote(goal, substitution, provider).await,
            None => Ok(Vec::new()),
        }
    }

    /// Unifies `goal` with every local fact of the same name.
    ///
    /// Each unification starts from `substitution`, so bindings supplied by the
    /// caller constrain the match. This is the same substitution that
    /// `resolve_remote` forwards to peers. Returns one extended substitution
    /// per matching fact.
    fn resolve_local(&self, goal: &Predicate, substitution: &Substitution) -> Vec<Substitution> {
        self.local_kb
            .get_predicates(&goal.name)
            .into_iter()
            .filter_map(|fact| crate::reasoning::unify_predicates(goal, fact, substitution))
            .collect()
    }

    /// Forwards `goal` (with `substitution`, depth 0, requester `"local"` and a
    /// fresh UUID request id) to `provider` and returns its solutions.
    async fn resolve_remote(
        &self,
        goal: &Predicate,
        substitution: &Substitution,
        provider: &Arc<dyn RemoteKnowledgeProvider>,
    ) -> Result<Vec<Substitution>, RemoteReasoningError> {
        let request = GoalResolutionRequest {
            goal: goal.clone(),
            substitution: substitution.clone(),
            depth: 0,
            requester: "local".to_string(),
            request_id: uuid::Uuid::new_v4().to_string(),
        };
        let response = provider.resolve_goal(request).await?;
        Ok(response.solutions)
    }

    /// Discovers facts for `predicate_name` through the remote provider and
    /// caches them, replacing any earlier cache entry for that predicate.
    ///
    /// The discovery uses fixed parameters: `max_hops` 3, `ttl` 30, no argument
    /// patterns and no excluded peers. Returns the number of facts cached, or
    /// `Ok(0)` without caching anything when no provider is configured.
    ///
    /// # Errors
    ///
    /// Propagates any [`RemoteReasoningError`] from the discovery call; the
    /// cache is left untouched in that case.
    pub async fn prefetch_facts(
        &mut self,
        predicate_name: &str,
    ) -> Result<usize, RemoteReasoningError> {
        let Some(provider) = &self.remote_provider else {
            return Ok(0);
        };
        let request = FactDiscoveryRequest {
            predicate_name: predicate_name.to_string(),
            arg_patterns: Vec::new(),
            max_hops: 3,
            ttl: 30,
            exclude_peers: HashSet::new(),
        };
        let response = provider.discover_facts(request).await?;
        let count = response.facts.len();
        self.remote_fact_cache
            .insert(predicate_name.to_string(), response.facts);
        Ok(count)
    }

    /// Returns the facts cached for `predicate_name` by `prefetch_facts`, if any.
    pub fn get_cached_facts(&self, predicate_name: &str) -> Option<&[Predicate]> {
        self.remote_fact_cache
            .get(predicate_name)
            .map(|v| v.as_slice())
    }

    /// Drops every cached remote fact.
    pub fn clear_cache(&mut self) {
        self.remote_fact_cache.clear();
    }
}
/// Obtains proofs for goals from a remote provider and holds a cache of proof
/// fragments keyed by CID.
pub struct DistributedProofAssembler {
/// Provider that goal resolution requests are sent to.
remote_provider: Arc<dyn RemoteKnowledgeProvider>,
/// Fragments available locally, keyed by content identifier.
/// NOTE(review): nothing in this file inserts into this map.
fragment_cache: HashMap<Cid, ProofFragment>,
// Stored for a future depth-limited assembly; no method reads it yet, hence
// the lint allowance.
#[allow(dead_code)]
max_depth: usize,
}
impl DistributedProofAssembler {
pub fn new(remote_provider: Arc<dyn RemoteKnowledgeProvider>) -> Self {
Self {
remote_provider,
fragment_cache: HashMap::new(),
max_depth: 100,
}
}
pub async fn assemble_proof(
&mut self,
goal: &Predicate,
) -> Result<Option<Proof>, RemoteReasoningError> {
let request = GoalResolutionRequest {
goal: goal.clone(),
substitution: HashMap::new(),
depth: 0,
requester: "local".to_string(),
request_id: uuid::Uuid::new_v4().to_string(),
};
let response = self.remote_provider.resolve_goal(request).await?;
if response.solved {
Ok(response.proof)
} else {
Ok(None)
}
}
pub async fn fetch_fragment(
&mut self,
cid: Cid,
) -> Result<ProofFragment, RemoteReasoningError> {
if let Some(fragment) = self.fragment_cache.get(&cid) {
return Ok(fragment.clone());
}
Err(RemoteReasoningError::NetworkError(
"Fragment fetch not yet implemented".to_string(),
))
}
}
/// In-memory [`RemoteKnowledgeProvider`] that answers every request from a
/// single knowledge base, reporting itself as peer `"mock_peer"`.
pub struct MockRemoteKnowledgeProvider {
/// Knowledge base that stands in for the remote peers.
mock_kb: Arc<KnowledgeBase>,
}
impl MockRemoteKnowledgeProvider {
/// Wraps `mock_kb` so it can be queried through the provider trait.
pub fn new(mock_kb: Arc<KnowledgeBase>) -> Self {
Self { mock_kb }
}
}
#[async_trait]
impl RemoteKnowledgeProvider for MockRemoteKnowledgeProvider {
async fn query_predicate(
&self,
request: QueryRequest,
) -> Result<QueryResponse, RemoteReasoningError> {
let predicates = self
.mock_kb
.get_predicates(&request.predicate_name)
.into_iter()
.take(request.max_results)
.cloned()
.collect();
let rules = self
.mock_kb
.get_rules(&request.predicate_name)
.into_iter()
.take(request.max_results)
.cloned()
.collect();
Ok(QueryResponse {
request_id: request.request_id,
predicates,
rules,
proof_fragments: Vec::new(),
peer_id: "mock_peer".to_string(),
has_more: false,
continuation_token: None,
})
}
async fn discover_facts(
&self,
request: FactDiscoveryRequest,
) -> Result<FactDiscoveryResponse, RemoteReasoningError> {
let facts: Vec<Predicate> = self
.mock_kb
.get_predicates(&request.predicate_name)
.into_iter()
.cloned()
.collect();
let sources: HashMap<usize, String> = (0..facts.len())
.map(|i| (i, "mock_peer".to_string()))
.collect();
let hops: HashMap<usize, usize> = (0..facts.len()).map(|i| (i, 0)).collect();
Ok(FactDiscoveryResponse {
facts,
sources,
peers_queried: 1,
hops,
})
}
async fn load_incremental(
&self,
request: IncrementalLoadRequest,
) -> Result<IncrementalLoadResponse, RemoteReasoningError> {
let all_facts: Vec<Predicate> = self
.mock_kb
.get_predicates(&request.predicate_name)
.into_iter()
.cloned()
.collect();
let total_count = all_facts.len();
let start = request.offset;
let end = (start + request.batch_size).min(total_count);
let batch = all_facts[start..end].to_vec();
let is_last = end >= total_count;
let next_offset = if is_last { None } else { Some(end) };
Ok(IncrementalLoadResponse {
batch,
total_count,
next_offset,
is_last,
})
}
async fn resolve_goal(
&self,
request: GoalResolutionRequest,
) -> Result<GoalResolutionResponse, RemoteReasoningError> {
let facts = self.mock_kb.get_predicates(&request.goal.name);
let mut solutions = Vec::new();
for fact in facts {
if let Some(subst) =
crate::reasoning::unify_predicates(&request.goal, fact, &Substitution::new())
{
solutions.push(subst);
}
}
let solved = !solutions.is_empty();
let proof = if solved {
Some(Proof::fact(request.goal.clone()))
} else {
None
};
Ok(GoalResolutionResponse {
request_id: request.request_id,
solved,
solutions,
proof,
proof_fragments: Vec::new(),
})
}
async fn get_available_peers(&self) -> Result<Vec<String>, RemoteReasoningError> {
Ok(vec!["mock_peer".to_string()])
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::Constant;

    /// Builds the ground fact `parent(parent, child)` from string constants.
    fn parent_fact(parent: &str, child: &str) -> Predicate {
        Predicate::new(
            "parent".to_string(),
            vec![
                Term::Const(Constant::String(parent.to_string())),
                Term::Const(Constant::String(child.to_string())),
            ],
        )
    }

    /// Builds the goal `parent(parent, X)` with `X` unbound.
    fn parent_goal(parent: &str) -> Predicate {
        Predicate::new(
            "parent".to_string(),
            vec![
                Term::Const(Constant::String(parent.to_string())),
                Term::Var("X".to_string()),
            ],
        )
    }

    #[tokio::test]
    async fn test_query_request_serialization() {
        let request = QueryRequest {
            predicate_name: "parent".to_string(),
            ground_args: vec!["Alice".to_string()],
            max_results: 10,
            max_depth: 5,
            request_id: "test_123".to_string(),
        };
        let json = serde_json::to_string(&request).unwrap();
        let decoded: QueryRequest = serde_json::from_str(&json).unwrap();
        // Every field must survive the JSON round trip, not just the first two.
        assert_eq!(request.predicate_name, decoded.predicate_name);
        assert_eq!(request.ground_args, decoded.ground_args);
        assert_eq!(request.max_results, decoded.max_results);
        assert_eq!(request.max_depth, decoded.max_depth);
        assert_eq!(request.request_id, decoded.request_id);
    }

    #[tokio::test]
    async fn test_mock_provider_query() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(parent_fact("Alice", "Bob"));
        let provider = MockRemoteKnowledgeProvider::new(Arc::new(kb));
        let request = QueryRequest {
            predicate_name: "parent".to_string(),
            ground_args: vec![],
            max_results: 10,
            max_depth: 5,
            request_id: "test_123".to_string(),
        };
        let response = provider.query_predicate(request).await.unwrap();
        assert_eq!(response.predicates.len(), 1);
        assert_eq!(response.predicates[0].name, "parent");
        assert_eq!(response.request_id, "test_123");
        assert_eq!(response.peer_id, "mock_peer");
    }

    #[tokio::test]
    async fn test_distributed_resolver() {
        let mut local_kb = KnowledgeBase::new();
        local_kb.add_fact(parent_fact("Alice", "Bob"));
        let mut resolver = DistributedGoalResolver::new(Arc::new(local_kb));
        let goal = parent_goal("Alice");
        let solutions = resolver.resolve(&goal, &Substitution::new()).await.unwrap();
        assert!(!solutions.is_empty());
    }

    #[tokio::test]
    async fn test_distributed_resolver_without_provider_returns_empty() {
        // No local facts and no remote provider: resolution succeeds with no answers.
        let mut resolver = DistributedGoalResolver::new(Arc::new(KnowledgeBase::new()));
        let goal = parent_goal("Alice");
        let solutions = resolver.resolve(&goal, &Substitution::new()).await.unwrap();
        assert!(solutions.is_empty());
    }

    #[tokio::test]
    async fn test_fact_discovery() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(parent_fact("Alice", "Bob"));
        kb.add_fact(parent_fact("Bob", "Charlie"));
        let provider = MockRemoteKnowledgeProvider::new(Arc::new(kb));
        let request = FactDiscoveryRequest {
            predicate_name: "parent".to_string(),
            arg_patterns: vec![],
            max_hops: 3,
            ttl: 30,
            exclude_peers: HashSet::new(),
        };
        let response = provider.discover_facts(request).await.unwrap();
        assert_eq!(response.facts.len(), 2);
        assert_eq!(response.peers_queried, 1);
        // Every discovered fact carries a source and a hop count.
        assert_eq!(response.sources.len(), 2);
        assert_eq!(response.hops.len(), 2);
    }

    #[tokio::test]
    async fn test_incremental_loading() {
        let mut kb = KnowledgeBase::new();
        for i in 0..10 {
            kb.add_fact(Predicate::new(
                "number".to_string(),
                vec![Term::Const(Constant::Int(i))],
            ));
        }
        let provider = MockRemoteKnowledgeProvider::new(Arc::new(kb));

        let first = provider
            .load_incremental(IncrementalLoadRequest {
                predicate_name: "number".to_string(),
                batch_size: 3,
                offset: 0,
                filter: None,
            })
            .await
            .unwrap();
        assert_eq!(first.batch.len(), 3);
        assert_eq!(first.total_count, 10);
        assert!(!first.is_last);
        assert_eq!(first.next_offset, Some(3));

        // A final, partial page ends the pagination.
        let last = provider
            .load_incremental(IncrementalLoadRequest {
                predicate_name: "number".to_string(),
                batch_size: 3,
                offset: 9,
                filter: None,
            })
            .await
            .unwrap();
        assert_eq!(last.batch.len(), 1);
        assert_eq!(last.total_count, 10);
        assert!(last.is_last);
        assert_eq!(last.next_offset, None);
    }

    #[tokio::test]
    async fn test_goal_resolution() {
        let mut kb = KnowledgeBase::new();
        kb.add_fact(parent_fact("Alice", "Bob"));
        let provider = MockRemoteKnowledgeProvider::new(Arc::new(kb));
        let request = GoalResolutionRequest {
            goal: parent_goal("Alice"),
            substitution: HashMap::new(),
            depth: 0,
            requester: "test".to_string(),
            request_id: "test_123".to_string(),
        };
        let response = provider.resolve_goal(request).await.unwrap();
        assert_eq!(response.request_id, "test_123");
        assert!(response.solved);
        assert!(!response.solutions.is_empty());
        assert!(response.proof.is_some());
    }
}