use ahash::AHashSet;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Strategy used to split input documents into chunks.
///
/// Serialized in lowercase (e.g. `"text"`, `"markdown"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ChunkerType {
    /// Plain-text chunking; the default strategy.
    #[default]
    Text,
    /// Markdown-aware chunking.
    Markdown,
    /// YAML-aware chunking.
    Yaml,
}
/// How the size of a chunk is measured when enforcing `max_characters`.
///
/// Internally tagged as `{"type": "..."}` in snake_case.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChunkSizing {
    /// Size chunks by raw character count; the default.
    #[default]
    Characters,
    /// Size chunks by token count from a tokenizer model.
    /// Only available with the `chunking-tokenizers` feature.
    #[cfg(feature = "chunking-tokenizers")]
    Tokenizer {
        /// Tokenizer model identifier to load.
        model: String,
        /// Optional directory for cached tokenizer files; omitted from
        /// serialized output when `None`.
        // Consistency fix: use the `PathBuf` already imported at the top of
        // the file (as `EmbeddingConfig::cache_dir` does) instead of the
        // fully-qualified `std::path::PathBuf`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cache_dir: Option<PathBuf>,
    },
}
/// Configuration for the post-processing pipeline.
///
/// `enabled_processors` / `disabled_processors` are the user-facing lists;
/// the `*_set` fields are derived lookup caches built by
/// [`PostProcessorConfig::build_lookup_sets`] and are never serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostProcessorConfig {
    /// Master switch for post-processing; defaults to `true`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Explicit allow-list of processor names; `None` means no allow-list.
    #[serde(default)]
    pub enabled_processors: Option<Vec<String>>,
    /// Explicit deny-list of processor names; `None` means no deny-list.
    #[serde(default)]
    pub disabled_processors: Option<Vec<String>>,
    /// Cached set form of `enabled_processors` for O(1) membership tests.
    /// Not serialized; populate via `build_lookup_sets`.
    #[serde(skip)]
    pub enabled_set: Option<AHashSet<String>>,
    /// Cached set form of `disabled_processors`. Not serialized.
    #[serde(skip)]
    pub disabled_set: Option<AHashSet<String>>,
}
impl PostProcessorConfig {
pub fn build_lookup_sets(&mut self) {
if let Some(ref enabled) = self.enabled_processors {
self.enabled_set = Some(enabled.iter().cloned().collect());
}
if let Some(ref disabled) = self.disabled_processors {
self.disabled_set = Some(disabled.iter().cloned().collect());
}
}
}
impl Default for PostProcessorConfig {
fn default() -> Self {
Self {
enabled: true,
enabled_processors: None,
disabled_processors: None,
enabled_set: None,
disabled_set: None,
}
}
}
/// Configuration for document chunking.
///
/// Note the serde renames: the wire field for `max_characters` is
/// `max_chars` (with `max_characters` accepted as an alias on input), and
/// the wire field for `overlap` is `max_overlap` (with `overlap` accepted
/// as an alias).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkingConfig {
    /// Maximum chunk size; units depend on `sizing`. Defaults to 1000.
    #[serde(default = "default_chunk_size", rename = "max_chars", alias = "max_characters")]
    pub max_characters: usize,
    /// Overlap between consecutive chunks. Defaults to 200.
    #[serde(default = "default_chunk_overlap", rename = "max_overlap", alias = "overlap")]
    pub overlap: usize,
    /// Whether to trim whitespace from chunks. Defaults to `true`.
    #[serde(default = "default_trim")]
    pub trim: bool,
    /// Which chunking strategy to use. Defaults to `ChunkerType::Text`.
    #[serde(default = "default_chunker_type")]
    pub chunker_type: ChunkerType,
    /// Optional embedding configuration; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding: Option<EmbeddingConfig>,
    /// Optional named preset; resolved by `resolve_preset`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub preset: Option<String>,
    /// How chunk size is measured (characters vs. tokens).
    #[serde(default)]
    pub sizing: ChunkSizing,
    /// Whether to prepend heading context to chunks. Defaults to `false`.
    #[serde(default)]
    pub prepend_heading_context: bool,
}
impl ChunkingConfig {
    /// Creates a config with explicit size, overlap, and trim settings;
    /// every other field starts at its default value.
    pub fn new(max_characters: usize, overlap: usize, trim: bool) -> Self {
        Self {
            max_characters,
            overlap,
            trim,
            ..Self::default()
        }
    }

    /// Builder: selects the chunking strategy.
    pub fn with_chunker_type(mut self, kind: ChunkerType) -> Self {
        self.chunker_type = kind;
        self
    }

    /// Builder: selects how chunk size is measured.
    pub fn with_sizing(mut self, mode: ChunkSizing) -> Self {
        self.sizing = mode;
        self
    }

    /// Builder: toggles prepending heading context to chunks.
    pub fn with_prepend_heading_context(mut self, enable: bool) -> Self {
        self.prepend_heading_context = enable;
        self
    }

    /// Returns a copy of this config with the named preset applied.
    ///
    /// With no preset set, or an unknown preset name (logged as a warning),
    /// the config is returned unchanged. Otherwise the preset supplies
    /// `max_characters` and `overlap`, and an embedding config pointing at
    /// the preset is synthesized unless one was given explicitly.
    #[cfg(feature = "embeddings")]
    pub fn resolve_preset(&self) -> Self {
        let Some(preset_name) = self.preset.as_ref() else {
            return self.clone();
        };
        let Some(preset) = crate::embeddings::get_preset(preset_name) else {
            tracing::warn!(
                "Unknown chunking preset '{}', using manual config. Available: {:?}",
                preset_name,
                crate::embeddings::list_presets()
            );
            return self.clone();
        };
        // An explicitly supplied embedding config wins; otherwise build one
        // that references the preset by name.
        let embedding = self.embedding.clone().or_else(|| {
            Some(EmbeddingConfig {
                model: EmbeddingModelType::Preset {
                    name: preset_name.clone(),
                },
                ..EmbeddingConfig::default()
            })
        });
        Self {
            max_characters: preset.chunk_size,
            overlap: preset.overlap,
            embedding,
            ..self.clone()
        }
    }

    /// Feature-gated fallback: presets need the `embeddings` feature, so
    /// warn if one was requested and return the config unchanged.
    #[cfg(not(feature = "embeddings"))]
    pub fn resolve_preset(&self) -> Self {
        if self.preset.is_some() {
            tracing::warn!("Chunking presets require the 'embeddings' feature");
        }
        self.clone()
    }
}
impl Default for ChunkingConfig {
fn default() -> Self {
Self {
max_characters: 1000,
overlap: 200,
trim: true,
chunker_type: ChunkerType::Text,
embedding: None,
preset: None,
sizing: ChunkSizing::default(),
prepend_heading_context: false,
}
}
}
/// Configuration for embedding generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which embedding model to use; defaults to the "balanced" preset.
    #[serde(default = "default_model")]
    pub model: EmbeddingModelType,
    /// Whether to normalize embedding vectors. Defaults to `true`.
    #[serde(default = "default_normalize")]
    pub normalize: bool,
    /// Number of texts embedded per batch. Defaults to 32.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,
    /// Whether to show progress while downloading model files.
    #[serde(default)]
    pub show_download_progress: bool,
    /// Optional model cache directory; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_dir: Option<PathBuf>,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model: EmbeddingModelType::Preset {
name: "balanced".to_string(),
},
normalize: true,
batch_size: 32,
show_download_progress: false,
cache_dir: None,
}
}
}
/// Source of the embedding model.
///
/// Internally tagged as `{"type": "preset" | "custom" | "llm", ...}` with
/// the variant's fields inlined alongside the tag (see the serialization
/// tests below for the exact wire format).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbeddingModelType {
    /// A named built-in preset (e.g. "balanced", "fast").
    Preset { name: String },
    /// A custom model identified by id, with its output dimensionality.
    Custom { model_id: String, dimensions: usize },
    /// Embeddings produced via an LLM provider configuration.
    Llm { llm: super::llm::LlmConfig },
}
// --- serde default helpers -------------------------------------------------
// serde's `#[serde(default = "path")]` attribute takes a function path, not a
// literal, so each defaulted field above needs one of these small functions.
// Keep them in sync with the hand-written `Default` impls in this file.
fn default_true() -> bool {
    true
}
fn default_chunk_size() -> usize {
    1000
}
fn default_chunk_overlap() -> usize {
    200
}
fn default_trim() -> bool {
    true
}
fn default_chunker_type() -> ChunkerType {
    ChunkerType::Text
}
fn default_normalize() -> bool {
    true
}
fn default_batch_size() -> usize {
    32
}
fn default_model() -> EmbeddingModelType {
    EmbeddingModelType::Preset {
        name: "balanced".to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postprocessor_config_default() {
        let config = PostProcessorConfig::default();
        assert!(config.enabled);
        assert!(config.enabled_processors.is_none());
        assert!(config.disabled_processors.is_none());
    }

    #[test]
    fn test_postprocessor_config_build_lookup_sets() {
        let mut config = PostProcessorConfig {
            enabled: true,
            enabled_processors: Some(vec!["a".to_string(), "b".to_string()]),
            disabled_processors: Some(vec!["c".to_string()]),
            enabled_set: None,
            disabled_set: None,
        };
        config.build_lookup_sets();
        assert!(config.enabled_set.is_some());
        assert!(config.disabled_set.is_some());
        assert!(config.enabled_set.unwrap().contains("a"));
        assert!(config.disabled_set.unwrap().contains("c"));
    }

    #[test]
    fn test_chunking_config_defaults() {
        let config = ChunkingConfig::default();
        assert_eq!(config.max_characters, 1000);
        assert_eq!(config.overlap, 200);
        assert!(config.trim);
        assert_eq!(config.chunker_type, ChunkerType::Text);
        assert!(matches!(config.sizing, ChunkSizing::Characters));
    }

    #[test]
    fn test_embedding_config_default() {
        let config = EmbeddingConfig::default();
        assert!(config.normalize);
        assert_eq!(config.batch_size, 32);
        assert!(config.cache_dir.is_none());
    }

    // The internally-tagged representation must produce a flat object with a
    // "type" discriminator, not serde's adjacently/externally-tagged shapes.
    #[test]
    fn test_embedding_model_type_preset_serialization() {
        let model = EmbeddingModelType::Preset {
            name: "fast".to_string(),
        };
        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains(r#""type":"preset""#), "Should contain type:preset field");
        assert!(json.contains(r#""name":"fast""#), "Should contain name:fast field");
        assert!(
            !json.contains(r#"{"preset":"#),
            "Should NOT use adjacently-tagged format"
        );
    }

    #[test]
    fn test_embedding_model_type_preset_deserialization() {
        let json = r#"{"type": "preset", "name": "fast"}"#;
        let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
        match model {
            EmbeddingModelType::Preset { name } => {
                assert_eq!(name, "fast");
            }
            _ => panic!("Expected Preset variant"),
        }
    }

    // Guard against regressions to an externally-tagged wire format.
    #[test]
    fn test_embedding_model_type_rejects_wrong_format() {
        let wrong_json = r#"{"preset": {"name": "fast"}}"#;
        let result: Result<EmbeddingModelType, _> = serde_json::from_str(wrong_json);
        assert!(result.is_err(), "Should reject adjacently-tagged format");
    }

    #[test]
    fn test_embedding_config_roundtrip() {
        let config = EmbeddingConfig {
            model: EmbeddingModelType::Preset {
                name: "balanced".to_string(),
            },
            normalize: true,
            batch_size: 64,
            show_download_progress: false,
            cache_dir: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: EmbeddingConfig = serde_json::from_str(&json).unwrap();
        match deserialized.model {
            EmbeddingModelType::Preset { name } => {
                assert_eq!(name, "balanced");
            }
            _ => panic!("Expected Preset variant"),
        }
        assert!(deserialized.normalize);
        assert_eq!(deserialized.batch_size, 64);
    }

    #[test]
    fn test_embedding_model_type_custom_serialization() {
        let model = EmbeddingModelType::Custom {
            model_id: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
            dimensions: 384,
        };
        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains(r#""type":"custom""#), "Should contain type:custom field");
        assert!(json.contains(r#""model_id":"#), "Should contain model_id field");
        assert!(json.contains(r#""dimensions":384"#), "Should contain dimensions field");
    }

    // Expected sizes here come from the registered "balanced" preset
    // (1024 chars / 100 overlap) in crate::embeddings.
    #[test]
    #[cfg(feature = "embeddings")]
    fn test_resolve_preset_balanced() {
        let config = ChunkingConfig {
            preset: Some("balanced".to_string()),
            ..Default::default()
        };
        let resolved = config.resolve_preset();
        assert_eq!(resolved.max_characters, 1024);
        assert_eq!(resolved.overlap, 100);
        assert!(resolved.embedding.is_some());
        match &resolved.embedding.unwrap().model {
            EmbeddingModelType::Preset { name } => assert_eq!(name, "balanced"),
            _ => panic!("Expected Preset model type"),
        }
    }

    // An explicitly supplied embedding config must survive preset resolution;
    // only the size/overlap come from the preset.
    #[test]
    #[cfg(feature = "embeddings")]
    fn test_resolve_preset_preserves_explicit_embedding() {
        let explicit_embedding = EmbeddingConfig {
            model: EmbeddingModelType::Custom {
                model_id: "custom/model".to_string(),
                dimensions: 512,
            },
            batch_size: 64,
            ..Default::default()
        };
        let config = ChunkingConfig {
            preset: Some("fast".to_string()),
            embedding: Some(explicit_embedding),
            ..Default::default()
        };
        let resolved = config.resolve_preset();
        assert_eq!(resolved.max_characters, 512);
        assert_eq!(resolved.overlap, 50);
        match &resolved.embedding.unwrap().model {
            EmbeddingModelType::Custom { model_id, .. } => assert_eq!(model_id, "custom/model"),
            _ => panic!("Expected Custom model type to be preserved"),
        }
    }

    #[test]
    fn test_resolve_preset_no_preset_returns_unchanged() {
        let config = ChunkingConfig {
            max_characters: 500,
            overlap: 50,
            ..Default::default()
        };
        let resolved = config.resolve_preset();
        assert_eq!(resolved.max_characters, 500);
        assert_eq!(resolved.overlap, 50);
        assert!(resolved.embedding.is_none());
    }

    #[test]
    fn test_resolve_preset_unknown_name_returns_unchanged() {
        let config = ChunkingConfig {
            max_characters: 500,
            preset: Some("nonexistent".to_string()),
            ..Default::default()
        };
        let resolved = config.resolve_preset();
        assert_eq!(resolved.max_characters, 500);
    }

    #[test]
    fn test_embedding_model_type_llm_roundtrip() {
        let model_type = EmbeddingModelType::Llm {
            llm: crate::core::config::llm::LlmConfig {
                model: "openai/text-embedding-3-small".to_string(),
                api_key: None,
                base_url: None,
                timeout_secs: None,
                max_retries: None,
                temperature: None,
                max_tokens: None,
            },
        };
        let json = serde_json::to_string(&model_type).unwrap();
        assert!(json.contains("\"type\":\"llm\""));
        assert!(json.contains("openai/text-embedding-3-small"));
        let deserialized: EmbeddingModelType = serde_json::from_str(&json).unwrap();
        match deserialized {
            EmbeddingModelType::Llm { llm } => {
                assert_eq!(llm.model, "openai/text-embedding-3-small");
            }
            _ => panic!("Expected Llm variant"),
        }
    }

    #[test]
    fn test_embedding_model_type_custom_deserialization() {
        let json = r#"{"type": "custom", "model_id": "test/model", "dimensions": 512}"#;
        let model: EmbeddingModelType = serde_json::from_str(json).unwrap();
        match model {
            EmbeddingModelType::Custom { model_id, dimensions } => {
                assert_eq!(model_id, "test/model");
                assert_eq!(dimensions, 512);
            }
            _ => panic!("Expected Custom variant"),
        }
    }
}