use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::path::PathBuf;
/// Selects which chunker implementation is used for a [`ChunkingConfig`].
///
/// Serialized with lowercase variant names (`"text"` / `"markdown"`).
/// `Text` is the `#[default]` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
    /// Text chunker; the default (serialized as `"text"`).
    #[default]
    Text,
    /// Markdown chunker (serialized as `"markdown"`).
    Markdown,
}
/// Configuration for the post-processing stage.
///
/// `enabled_processors` / `disabled_processors` hold processor names as written in
/// the config. `enabled_set` / `disabled_set` are `HashSet` copies of those lists,
/// filled in by `build_lookup_sets`. Because they are `#[serde(skip)]`, they are
/// always `None` immediately after deserialization.
///
/// NOTE(review): how the two lists interact (allow-list vs. deny-list precedence)
/// is decided by code not shown here; confirm before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostProcessorConfig {
    /// Enable flag for post-processing; defaults to `true` (via `default_true`)
    /// when the key is absent.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Processor names listed as enabled; `None` when the key is absent.
    #[serde(default)]
    pub enabled_processors: Option<Vec<String>>,
    /// Processor names listed as disabled; `None` when the key is absent.
    #[serde(default)]
    pub disabled_processors: Option<Vec<String>>,
    /// Lookup set built from `enabled_processors`; never serialized or deserialized.
    #[serde(skip)]
    pub enabled_set: Option<HashSet<String>>,
    /// Lookup set built from `disabled_processors`; never serialized or deserialized.
    #[serde(skip)]
    pub disabled_set: Option<HashSet<String>>,
}
impl PostProcessorConfig {
pub fn build_lookup_sets(&mut self) {
if let Some(ref enabled) = self.enabled_processors {
self.enabled_set = Some(enabled.iter().cloned().collect());
}
if let Some(ref disabled) = self.disabled_processors {
self.disabled_set = Some(disabled.iter().cloned().collect());
}
}
}
impl Default for PostProcessorConfig {
fn default() -> Self {
Self {
enabled: true,
enabled_processors: None,
disabled_processors: None,
enabled_set: None,
disabled_set: None,
}
}
}
/// Settings for splitting input into chunks.
///
/// Every non-`Option` field falls back to the `default_*` helper named in its
/// `#[serde(default = "...")]` attribute when the key is missing. Missing `Option`
/// fields deserialize as `None`.
///
/// NOTE(review): nothing here checks that `overlap` is smaller than
/// `max_characters`; confirm the chunker validates this.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Maximum chunk size; the name suggests characters, but enforcement lives in
    /// the chunker (not shown). Defaults to 1000.
    ///
    /// Serialized as `max_chars`; deserialization accepts either `max_chars` or
    /// `max_characters`.
    #[serde(default = "default_chunk_size", rename = "max_chars", alias = "max_characters")]
    pub max_characters: usize,
    /// Chunk overlap amount, presumably between adjacent chunks; confirm in the
    /// chunker. Defaults to 200.
    ///
    /// Serialized as `max_overlap`; deserialization accepts either `max_overlap`
    /// or `overlap`.
    #[serde(default = "default_chunk_overlap", rename = "max_overlap", alias = "overlap")]
    pub overlap: usize,
    /// Trim flag consumed by the chunker (presumably whitespace trimming; confirm).
    /// Defaults to `true`.
    #[serde(default = "default_trim")]
    pub trim: bool,
    /// Which chunker to use; defaults to `ChunkerType::Text`.
    #[serde(default = "default_chunker_type")]
    pub chunker_type: ChunkerType,
    /// Optional embedding settings; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<EmbeddingConfig>,
    /// Optional preset name; omitted from serialized output when `None`. How a
    /// preset is resolved is not shown here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preset: Option<String>,
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
max_characters: 1000,
overlap: 200,
trim: true,
chunker_type: ChunkerType::Text,
embedding: None,
preset: None,
}
}
}
/// Settings for generating embeddings for chunks.
///
/// Missing keys fall back to the `default_*` helpers named in each
/// `#[serde(default = "...")]`. `show_download_progress` uses `bool::default()`
/// (`false`), and `cache_dir` is `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Model selection; defaults to the `"balanced"` preset.
    #[serde(default = "default_model")]
    pub model: EmbeddingModelType,
    /// Normalization flag (presumably vector normalization; confirm in the
    /// embedder). Defaults to `true`.
    #[serde(default = "default_normalize")]
    pub normalize: bool,
    /// Batch size passed to the embedder; defaults to 32.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Whether to show model-download progress; defaults to `false`.
    #[serde(default)]
    pub show_download_progress: bool,
    /// Optional cache directory; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_dir: Option<PathBuf>,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: EmbeddingModelType::Preset {
name: "balanced".to_string(),
},
normalize: true,
batch_size: 32,
show_download_progress: false,
cache_dir: None,
}
}
}
/// Which embedding model to use.
///
/// Internally tagged: the variant name, in snake_case, is stored under a `"type"`
/// key next to the variant's fields. For example:
/// - `{"type": "preset", "name": "fast"}`
/// - `{"type": "custom", "model_id": "...", "dimensions": 384}`
///
/// The externally-tagged form `{"preset": {...}}` is rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbeddingModelType {
    /// A named preset (e.g. `"balanced"`, the default). Tag `"preset"`.
    Preset { name: String },
    /// A FastEmbed model with its embedding dimensions. Tag `"fast_embed"`.
    /// Only compiled with the `embeddings` feature.
    #[cfg(feature = "embeddings")]
    FastEmbed { model: String, dimensions: usize },
    /// An arbitrary model identifier with its embedding dimensions. Tag `"custom"`.
    Custom { model_id: String, dimensions: usize },
}
/// Serde default for boolean switches that are on unless explicitly turned off.
fn default_true() -> bool {
    true
}

/// Serde default for `ChunkingConfig::max_characters`.
fn default_chunk_size() -> usize {
    1_000
}

/// Serde default for `ChunkingConfig::overlap`.
fn default_chunk_overlap() -> usize {
    200
}

/// Serde default for `ChunkingConfig::trim`.
fn default_trim() -> bool {
    true
}
/// Serde default for `ChunkingConfig::chunker_type`.
///
/// Delegates to the enum's `#[default]` variant, which is `ChunkerType::Text`.
fn default_chunker_type() -> ChunkerType {
    ChunkerType::default()
}
/// Serde default for `EmbeddingConfig::normalize`.
fn default_normalize() -> bool {
    true
}

/// Serde default for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    32
}
/// Serde default for `EmbeddingConfig::model`: the `"balanced"` preset.
fn default_model() -> EmbeddingModelType {
    let name = String::from("balanced");
    EmbeddingModelType::Preset { name }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postprocessor_config_default() {
        let config = PostProcessorConfig::default();
        assert!(config.enabled);
        assert!(config.enabled_processors.is_none());
        assert!(config.disabled_processors.is_none());
        // Lookup sets are only populated by `build_lookup_sets`.
        assert!(config.enabled_set.is_none());
        assert!(config.disabled_set.is_none());
    }

    #[test]
    fn test_postprocessor_config_build_lookup_sets() {
        let mut config = PostProcessorConfig {
            enabled: true,
            enabled_processors: Some(vec!["a".to_string(), "b".to_string()]),
            disabled_processors: Some(vec!["c".to_string()]),
            enabled_set: None,
            disabled_set: None,
        };
        config.build_lookup_sets();
        assert!(config.enabled_set.is_some());
        assert!(config.disabled_set.is_some());
        assert!(config.enabled_set.unwrap().contains("a"));
        assert!(config.disabled_set.unwrap().contains("c"));
    }

    #[test]
    fn test_chunking_config_defaults() {
        // Exercise the real `Default` impl rather than a hand-built literal.
        let config = ChunkingConfig::default();
        assert_eq!(config.max_characters, 1000);
        assert_eq!(config.overlap, 200);
        assert!(config.trim);
        assert_eq!(config.chunker_type, ChunkerType::Text);
        assert!(config.embedding.is_none());
        assert!(config.preset.is_none());
    }

    #[test]
    fn test_chunking_config_empty_json_matches_default() {
        let from_json: ChunkingConfig = serde_json::from_str("{}").unwrap();
        let default = ChunkingConfig::default();
        assert_eq!(from_json.max_characters, default.max_characters);
        assert_eq!(from_json.overlap, default.overlap);
        assert_eq!(from_json.trim, default.trim);
        assert_eq!(from_json.chunker_type, default.chunker_type);
        assert!(from_json.embedding.is_none());
        assert!(from_json.preset.is_none());
    }

    #[test]
    fn test_chunking_config_field_names_and_aliases() {
        // Aliases are accepted on input...
        let config: ChunkingConfig =
            serde_json::from_str(r#"{"max_characters": 500, "overlap": 50}"#).unwrap();
        assert_eq!(config.max_characters, 500);
        assert_eq!(config.overlap, 50);

        // ...but the renamed keys are what gets written out, and `None` options are skipped.
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains(r#""max_chars":500"#), "Should serialize as max_chars");
        assert!(json.contains(r#""max_overlap":50"#), "Should serialize as max_overlap");
        assert!(!json.contains("embedding"), "None embedding should be skipped");
        assert!(!json.contains("preset"), "None preset should be skipped");
    }

    #[test]
    fn test_embedding_config_default() {
        let config = EmbeddingConfig::default();
        match &config.model {
            EmbeddingModelType::Preset { name } => assert_eq!(name.as_str(), "balanced"),
            _ => panic!("Expected Preset variant"),
        }
        assert!(config.normalize);
        assert_eq!(config.batch_size, 32);
        assert!(!config.show_download_progress);
        assert!(config.cache_dir.is_none());
    }

    #[test]
    fn test_embedding_config_empty_json_matches_default() {
        let config: EmbeddingConfig = serde_json::from_str("{}").unwrap();
        match &config.model {
            EmbeddingModelType::Preset { name } => assert_eq!(name.as_str(), "balanced"),
            _ => panic!("Expected Preset variant"),
        }
        assert!(config.normalize);
        assert_eq!(config.batch_size, 32);
        assert!(!config.show_download_progress);
        assert!(config.cache_dir.is_none());
    }

    #[test]
    fn test_embedding_model_type_preset_serialization() {
        let model = EmbeddingModelType::Preset {
            name: "fast".to_string(),
        };
        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains(r#""type":"preset""#), "Should contain type:preset field");
        assert!(json.contains(r#""name":"fast""#), "Should contain name:fast field");
        assert!(
            !json.contains(r#"{"preset":"#),
            "Should NOT use adjacently-tagged format"
        );
    }

    #[test]
    fn test_embedding_model_type_preset_deserialization() {
        let json = r#"{"type": "preset", "name": "fast"}"#;
        let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
        match model {
            EmbeddingModelType::Preset { name } => {
                assert_eq!(name, "fast");
            }
            _ => panic!("Expected Preset variant"),
        }
    }

    #[test]
    fn test_embedding_model_type_rejects_wrong_format() {
        let wrong_json = r#"{"preset": {"name": "fast"}}"#;
        let result: Result<EmbeddingModelType, _> = serde_json::from_str(wrong_json);
        assert!(result.is_err(), "Should reject adjacently-tagged format");
    }

    #[test]
    fn test_embedding_config_roundtrip() {
        let config = EmbeddingConfig {
            model: EmbeddingModelType::Preset {
                name: "balanced".to_string(),
            },
            normalize: true,
            batch_size: 64,
            show_download_progress: false,
            cache_dir: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        match deserialized.model {
            EmbeddingModelType::Preset { name } => {
                assert_eq!(name, "balanced");
            }
            _ => panic!("Expected Preset variant"),
        }
        assert!(deserialized.normalize);
        assert_eq!(deserialized.batch_size, 64);
    }

    #[test]
    fn test_embedding_model_type_custom_serialization() {
        let model = EmbeddingModelType::Custom {
            model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimensions: 384,
        };
        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains(r#""type":"custom""#), "Should contain type:custom field");
        assert!(json.contains(r#""model_id":"#), "Should contain model_id field");
        assert!(json.contains(r#""dimensions":384"#), "Should contain dimensions field");
    }

    #[test]
    fn test_embedding_model_type_custom_deserialization() {
        let json = r#"{"type": "custom", "model_id": "test/model", "dimensions": 512}"#;
        let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
        match model {
            EmbeddingModelType::Custom { model_id, dimensions } => {
                assert_eq!(model_id, "test/model");
                assert_eq!(dimensions, 512);
            }
            _ => panic!("Expected Custom variant"),
        }
    }
}