use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use crate::{DomainInfo, PredicateInfo, SymbolTable, ValueRange};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InferenceConfig {
pub min_confidence: f64,
pub infer_hierarchies: bool,
pub infer_constraints: bool,
pub infer_dependencies: bool,
pub cardinality_multiplier: f64,
pub max_nesting_depth: usize,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
min_confidence: 0.7,
infer_hierarchies: true,
infer_constraints: true,
infer_dependencies: true,
cardinality_multiplier: 2.0,
max_nesting_depth: 5,
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct ConfidenceScore {
pub score: f64,
pub evidence_count: usize,
pub reasoning: String,
}
impl ConfidenceScore {
pub fn new(score: f64, evidence_count: usize, reasoning: impl Into<String>) -> Self {
Self {
score: score.clamp(0.0, 1.0),
evidence_count,
reasoning: reasoning.into(),
}
}
pub fn is_confident(&self, threshold: f64) -> bool {
self.score >= threshold
}
}
#[derive(Clone, Debug)]
pub struct DataSample {
records: Vec<HashMap<String, Value>>,
}
impl DataSample {
pub fn from_json(json: &str) -> Result<Self> {
let value: Value = serde_json::from_str(json)?;
let records = match value {
Value::Array(arr) => arr
.into_iter()
.filter_map(|v| {
if let Value::Object(map) = v {
Some(map.into_iter().collect::<HashMap<_, _>>())
} else {
None
}
})
.collect(),
Value::Object(map) => {
vec![map.into_iter().collect()]
}
_ => return Err(anyhow!("Expected JSON array or object")),
};
Ok(Self { records })
}
pub fn from_csv(csv: &str) -> Result<Self> {
let mut lines = csv.lines();
let headers: Vec<String> = lines
.next()
.ok_or_else(|| anyhow!("Empty CSV"))?
.split(',')
.map(|s| s.trim().to_string())
.collect();
let records = lines
.map(|line| {
let values: Vec<&str> = line.split(',').map(|s| s.trim()).collect();
headers
.iter()
.zip(values.iter())
.map(|(k, v)| {
let json_val = if let Ok(num) = v.parse::<f64>() {
Value::Number(
serde_json::Number::from_f64(num)
.unwrap_or_else(|| serde_json::Number::from(0i64)),
)
} else if *v == "true" || *v == "false" {
Value::Bool(*v == "true")
} else {
Value::String(v.to_string())
};
(k.clone(), json_val)
})
.collect()
})
.collect();
Ok(Self { records })
}
pub fn field_names(&self) -> HashSet<String> {
self.records
.iter()
.flat_map(|record| record.keys().cloned())
.collect()
}
pub fn field_values(&self, field: &str) -> Vec<&Value> {
self.records
.iter()
.filter_map(|record| record.get(field))
.collect()
}
pub fn len(&self) -> usize {
self.records.len()
}
pub fn is_empty(&self) -> bool {
self.records.is_empty()
}
}
#[derive(Clone, Debug, Default)]
pub struct LearningStatistics {
pub domains_inferred: usize,
pub predicates_inferred: usize,
pub constraints_inferred: usize,
pub hierarchies_inferred: usize,
pub dependencies_inferred: usize,
pub total_samples_analyzed: usize,
pub inference_time_ms: u128,
}
pub struct SchemaLearner {
config: InferenceConfig,
statistics: LearningStatistics,
confidence_scores: HashMap<String, ConfidenceScore>,
}
impl SchemaLearner {
pub fn new(config: InferenceConfig) -> Self {
Self {
config,
statistics: LearningStatistics::default(),
confidence_scores: HashMap::new(),
}
}
pub fn learn_from_sample(&mut self, sample: &DataSample) -> Result<SymbolTable> {
let start = std::time::Instant::now();
let mut table = SymbolTable::new();
self.infer_domains(sample, &mut table)?;
self.infer_predicates(sample, &mut table)?;
if self.config.infer_constraints {
self.infer_constraints(sample, &mut table)?;
}
if self.config.infer_hierarchies {
self.infer_hierarchies(sample, &mut table)?;
}
self.statistics.total_samples_analyzed = sample.len();
self.statistics.inference_time_ms = start.elapsed().as_millis();
Ok(table)
}
fn infer_domains(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
let mut domain_types: HashMap<String, HashSet<String>> = HashMap::new();
for field in sample.field_names() {
let values = sample.field_values(&field);
let types: HashSet<String> = values.iter().map(|v| self.infer_type(v)).collect();
domain_types.insert(field.clone(), types);
}
let mut inferred_types: HashSet<String> = HashSet::new();
for types in domain_types.values() {
inferred_types.extend(types.clone());
}
for type_name in inferred_types {
let cardinality = self.estimate_cardinality(sample, &type_name);
let domain = DomainInfo::new(&type_name, cardinality);
if table.add_domain(domain).is_ok() {
self.statistics.domains_inferred += 1;
self.confidence_scores.insert(
format!("domain:{}", type_name),
ConfidenceScore::new(
0.9,
sample.len(),
format!("Inferred from {} samples", sample.len()),
),
);
}
}
Ok(())
}
fn infer_predicates(&mut self, sample: &DataSample, table: &mut SymbolTable) -> Result<()> {
let fields: Vec<String> = sample.field_names().into_iter().collect();
for field in &fields {
let values = sample.field_values(field);
if values.is_empty() {
continue;
}
let type_name = self.infer_type(values[0]);
let predicate = PredicateInfo::new(field, vec![type_name.clone()]);
if table.add_predicate(predicate).is_ok() {
self.statistics.predicates_inferred += 1;
self.confidence_scores.insert(
format!("predicate:{}", field),
ConfidenceScore::new(
0.85,
values.len(),
format!("Inferred from {} values", values.len()),
),
);
}
}
for i in 0..fields.len() {
for j in (i + 1)..fields.len() {
let field1 = &fields[i];
let field2 = &fields[j];
if self.has_relationship(sample, field1, field2) {
let type1 = self.infer_type(sample.field_values(field1)[0]);
let type2 = self.infer_type(sample.field_values(field2)[0]);
let rel_name = format!("{}_{}", field1, field2);
let predicate = PredicateInfo::new(&rel_name, vec![type1, type2]);
if table.add_predicate(predicate).is_ok() {
self.statistics.predicates_inferred += 1;
}
}
}
}
Ok(())
}
fn infer_constraints(&mut self, sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
for field in sample.field_names() {
let values = sample.field_values(&field);
if let Some(range) = self.infer_value_range(&values) {
self.statistics.constraints_inferred += 1;
self.confidence_scores.insert(
format!("constraint:{}:range", field),
ConfidenceScore::new(
0.8,
values.len(),
"Inferred from numeric values".to_string(),
),
);
let _ = range; }
}
Ok(())
}
fn infer_hierarchies(&mut self, _sample: &DataSample, _table: &mut SymbolTable) -> Result<()> {
Ok(())
}
fn infer_type(&self, value: &Value) -> String {
match value {
Value::Number(_) => "Number".to_string(),
Value::String(_) => "String".to_string(),
Value::Bool(_) => "Boolean".to_string(),
Value::Array(_) => "Array".to_string(),
Value::Object(_) => "Object".to_string(),
Value::Null => "Unknown".to_string(),
}
}
fn estimate_cardinality(&self, sample: &DataSample, type_name: &str) -> usize {
let mut unique_values = HashSet::new();
for record in &sample.records {
for value in record.values() {
if self.infer_type(value) == type_name {
unique_values.insert(format!("{:?}", value));
}
}
}
((unique_values.len() as f64) * self.config.cardinality_multiplier).ceil() as usize
}
fn has_relationship(&self, sample: &DataSample, field1: &str, field2: &str) -> bool {
let values1 = sample.field_values(field1);
let values2 = sample.field_values(field2);
let co_occurrence = values1.len().min(values2.len());
let threshold = (sample.len() as f64 * 0.8).ceil() as usize;
co_occurrence >= threshold
}
fn infer_value_range(&self, values: &[&Value]) -> Option<ValueRange> {
let numbers: Vec<f64> = values.iter().filter_map(|v| v.as_f64()).collect();
if numbers.is_empty() {
return None;
}
let min = numbers.iter().copied().fold(f64::INFINITY, f64::min);
let max = numbers.iter().copied().fold(f64::NEG_INFINITY, f64::max);
Some(ValueRange::new().with_min(min, true).with_max(max, true))
}
pub fn statistics(&self) -> &LearningStatistics {
&self.statistics
}
pub fn confidence(&self, element: &str) -> Option<&ConfidenceScore> {
self.confidence_scores.get(element)
}
pub fn all_confidences(&self) -> &HashMap<String, ConfidenceScore> {
&self.confidence_scores
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_data_sample_from_json() {
let json = r#"[
{"id": 1, "name": "Alice"},
{"id": 2, "name": "Bob"}
]"#;
let sample = DataSample::from_json(json).expect("unwrap");
assert_eq!(sample.len(), 2);
assert_eq!(sample.field_names().len(), 2);
}
#[test]
fn test_data_sample_from_csv() {
let csv = "id,name,age\n1,Alice,30\n2,Bob,25";
let sample = DataSample::from_csv(csv).expect("unwrap");
assert_eq!(sample.len(), 2);
assert_eq!(sample.field_names().len(), 3);
}
#[test]
fn test_schema_learner_basic() {
let json = r#"[
{"id": 1, "name": "Alice", "age": 30},
{"id": 2, "name": "Bob", "age": 25}
]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let config = InferenceConfig::default();
let mut learner = SchemaLearner::new(config);
let _schema = learner.learn_from_sample(&sample).expect("unwrap");
let stats = learner.statistics();
assert!(stats.domains_inferred > 0);
assert!(stats.predicates_inferred > 0);
assert_eq!(stats.total_samples_analyzed, 2);
}
#[test]
fn test_type_inference() {
let config = InferenceConfig::default();
let learner = SchemaLearner::new(config);
assert_eq!(learner.infer_type(&Value::Number(42.into())), "Number");
assert_eq!(learner.infer_type(&Value::String("test".into())), "String");
assert_eq!(learner.infer_type(&Value::Bool(true)), "Boolean");
}
#[test]
fn test_value_range_inference() {
let val1 = Value::Number(10.into());
let val2 = Value::Number(20.into());
let val3 = Value::Number(30.into());
let values = vec![&val1, &val2, &val3];
let config = InferenceConfig::default();
let learner = SchemaLearner::new(config);
let range = learner.infer_value_range(&values).expect("unwrap");
assert_eq!(range.min, Some(10.0));
assert_eq!(range.max, Some(30.0));
}
#[test]
fn test_confidence_score() {
let score = ConfidenceScore::new(0.85, 100, "High confidence");
assert_eq!(score.score, 0.85);
assert_eq!(score.evidence_count, 100);
assert!(score.is_confident(0.7));
assert!(!score.is_confident(0.9));
}
#[test]
fn test_inference_config_default() {
let config = InferenceConfig::default();
assert_eq!(config.min_confidence, 0.7);
assert!(config.infer_hierarchies);
assert!(config.infer_constraints);
}
#[test]
fn test_cardinality_estimation() {
let json = r#"[
{"id": 1, "type": "A"},
{"id": 2, "type": "B"},
{"id": 3, "type": "A"}
]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let config = InferenceConfig::default();
let learner = SchemaLearner::new(config);
let cardinality = learner.estimate_cardinality(&sample, "Number");
assert!(cardinality > 0);
}
#[test]
fn test_field_values_extraction() {
let json = r#"[
{"name": "Alice", "age": 30},
{"name": "Bob", "age": 25}
]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let names = sample.field_values("name");
assert_eq!(names.len(), 2);
}
#[test]
fn test_relationship_detection() {
let json = r#"[
{"person": "Alice", "city": "NYC"},
{"person": "Bob", "city": "LA"}
]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let config = InferenceConfig::default();
let learner = SchemaLearner::new(config);
assert!(learner.has_relationship(&sample, "person", "city"));
}
#[test]
fn test_empty_sample() {
let json = "[]";
let sample = DataSample::from_json(json).expect("unwrap");
assert!(sample.is_empty());
assert_eq!(sample.len(), 0);
}
#[test]
fn test_single_object_json() {
let json = r#"{"id": 1, "name": "Alice"}"#;
let sample = DataSample::from_json(json).expect("unwrap");
assert_eq!(sample.len(), 1);
}
#[test]
fn test_csv_type_detection() {
let csv = "id,name,active\n1,Alice,true\n2,Bob,false";
let sample = DataSample::from_csv(csv).expect("unwrap");
let active_values = sample.field_values("active");
assert!(active_values.iter().all(|v| v.is_boolean()));
}
#[test]
fn test_confidence_scores_tracking() {
let json = r#"[{"id": 1, "name": "Alice"}]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let config = InferenceConfig::default();
let mut learner = SchemaLearner::new(config);
learner.learn_from_sample(&sample).expect("unwrap");
assert!(!learner.all_confidences().is_empty());
}
#[test]
fn test_learning_statistics() {
let json = r#"[{"id": 1}, {"id": 2}, {"id": 3}]"#;
let sample = DataSample::from_json(json).expect("unwrap");
let config = InferenceConfig::default();
let mut learner = SchemaLearner::new(config);
learner.learn_from_sample(&sample).expect("unwrap");
let stats = learner.statistics();
assert_eq!(stats.total_samples_analyzed, 3);
assert!(stats.domains_inferred > 0 || stats.predicates_inferred > 0);
}
}