use serde::{Deserialize, Serialize};
/// Settings bundle for an inference run.
///
/// The type round-trips through serde with camelCase keys (for example
/// `sample_size` is written as `sampleSize`). The code that consumes these
/// settings is not part of this file, so the field notes below state the
/// stored type and default value, and mark intended meaning as an assumption.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InferenceConfig {
    /// Sample size setting. Default: `0`.
    ///
    /// NOTE(review): presumably the number of input records to sample, with
    /// `0` possibly meaning "no limit" — confirm against the consumer.
    pub sample_size: usize,

    /// Minimum field frequency. Default: `0.0`.
    ///
    /// The builder restricts this to `[0.0, 1.0]`; assigning the field
    /// directly or deserializing it performs no range check.
    pub min_field_frequency: f64,

    /// Format-detection switch. Default: `true`.
    pub detect_formats: bool,

    /// Depth limit. Default: `10`.
    ///
    /// NOTE(review): presumably the maximum nesting depth that is inspected —
    /// confirm against the consumer.
    pub max_depth: usize,

    /// Example-collection switch. Default: `true`.
    pub collect_examples: bool,

    /// Example count limit. Default: `5`.
    ///
    /// NOTE(review): presumably only relevant while `collect_examples` is
    /// `true` — confirm against the consumer.
    pub max_examples: usize,

    /// Nullable-assumption switch. Default: `false`.
    pub assume_nullable: bool,

    /// Confidence threshold for formats. Default: `0.9`.
    ///
    /// The builder restricts this to `[0.0, 1.0]`; assigning the field
    /// directly or deserializing it performs no range check.
    pub format_confidence_threshold: f64,
}
impl Default for InferenceConfig {
    /// Produces the baseline configuration.
    ///
    /// Numeric settings: `sample_size = 0`, `min_field_frequency = 0.0`,
    /// `max_depth = 10`, `max_examples = 5`,
    /// `format_confidence_threshold = 0.9`.
    ///
    /// Switches: `detect_formats` and `collect_examples` are on,
    /// `assume_nullable` is off.
    fn default() -> Self {
        // Numeric settings.
        let sample_size = 0;
        let min_field_frequency = 0.0;
        let max_depth = 10;
        let max_examples = 5;
        let format_confidence_threshold = 0.9;

        // Boolean switches.
        let detect_formats = true;
        let collect_examples = true;
        let assume_nullable = false;

        Self {
            sample_size,
            min_field_frequency,
            detect_formats,
            max_depth,
            collect_examples,
            max_examples,
            assume_nullable,
            format_confidence_threshold,
        }
    }
}
impl InferenceConfig {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> InferenceConfigBuilder {
InferenceConfigBuilder::default()
}
}
/// Fluent builder that accumulates settings for an [`InferenceConfig`].
///
/// The derived `Default` wraps `InferenceConfig::default()`, so a new builder
/// begins with every setting at its default value.
#[derive(Debug, Default)]
pub struct InferenceConfigBuilder {
    /// The configuration under construction.
    config: InferenceConfig,
}
impl InferenceConfigBuilder {
    /// Normalizes a ratio-style setting into `[0.0, 1.0]`.
    ///
    /// `f64::clamp` hands NaN back unchanged, which would let an unordered
    /// value slip past the range restriction. NaN carries no usable setting,
    /// so it is rejected here and `current` is returned instead.
    fn unit_interval_or(value: f64, current: f64) -> f64 {
        if value.is_nan() {
            current
        } else {
            value.clamp(0.0, 1.0)
        }
    }

    /// Sets `sample_size` to `size`.
    #[must_use]
    pub fn sample_size(mut self, size: usize) -> Self {
        self.config.sample_size = size;
        self
    }

    /// Sets `min_field_frequency`.
    ///
    /// Values outside `[0.0, 1.0]` are clamped to the nearest bound. A NaN
    /// argument is ignored and the previously configured value is kept.
    #[must_use]
    pub fn min_field_frequency(mut self, freq: f64) -> Self {
        self.config.min_field_frequency =
            Self::unit_interval_or(freq, self.config.min_field_frequency);
        self
    }

    /// Sets the `detect_formats` switch.
    #[must_use]
    pub fn detect_formats(mut self, detect: bool) -> Self {
        self.config.detect_formats = detect;
        self
    }

    /// Sets `max_depth` to `depth`.
    #[must_use]
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.config.max_depth = depth;
        self
    }

    /// Sets the `collect_examples` switch.
    #[must_use]
    pub fn collect_examples(mut self, collect: bool) -> Self {
        self.config.collect_examples = collect;
        self
    }

    /// Sets `max_examples` to `max`.
    #[must_use]
    pub fn max_examples(mut self, max: usize) -> Self {
        self.config.max_examples = max;
        self
    }

    /// Sets the `assume_nullable` switch.
    #[must_use]
    pub fn assume_nullable(mut self, nullable: bool) -> Self {
        self.config.assume_nullable = nullable;
        self
    }

    /// Sets `format_confidence_threshold`.
    ///
    /// Values outside `[0.0, 1.0]` are clamped to the nearest bound. A NaN
    /// argument is ignored and the previously configured value is kept.
    #[must_use]
    pub fn format_confidence_threshold(mut self, threshold: f64) -> Self {
        self.config.format_confidence_threshold =
            Self::unit_interval_or(threshold, self.config.format_confidence_threshold);
        self
    }

    /// Consumes the builder and returns the accumulated configuration.
    #[must_use]
    pub fn build(self) -> InferenceConfig {
        self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration reports the expected baseline values.
    #[test]
    fn test_default_config() {
        let InferenceConfig {
            sample_size,
            detect_formats,
            max_depth,
            ..
        } = InferenceConfig::default();

        assert_eq!(sample_size, 0);
        assert!(detect_formats);
        assert_eq!(max_depth, 10);
    }

    /// Values passed to the builder setters end up in the built config.
    #[test]
    fn test_builder() {
        // The setters are independent of each other, so call order is irrelevant.
        let built = InferenceConfig::builder()
            .max_depth(5)
            .detect_formats(false)
            .min_field_frequency(0.5)
            .sample_size(1000)
            .build();

        assert_eq!((built.sample_size, built.max_depth), (1000, 5));
        assert_eq!(built.min_field_frequency, 0.5);
        assert!(!built.detect_formats);
    }

    /// A frequency above the upper bound is clamped down to `1.0`.
    #[test]
    fn test_frequency_clamping() {
        let builder = InferenceConfig::builder().min_field_frequency(1.5);
        let clamped = builder.build().min_field_frequency;

        assert_eq!(clamped, 1.0);
    }
}