use std::sync::Arc;
use cognee_chunking::TokenCounterKind;
use cognee_embedding::engine::EmbeddingEngine;
use cognee_llm::{Llm, Transcriber};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Configuration for cognify processing: chunking, extraction,
/// summarization, embedding, and batching settings.
///
/// Start from `CognifyConfig::default()`, adjust values through the `with_*`
/// builder methods, and call `CognifyConfig::validate()` before use.
///
/// Fields marked `#[serde(skip)]` are left out of serialized output and are
/// filled with their `Default` value (`None`) when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognifyConfig {
    // ---- Chunking ----------------------------------------------------------
    /// Upper bound on the size of one chunk (default `1500`).
    ///
    /// `validate` rejects `0`.
    /// NOTE(review): the unit is presumably tokens — confirm with the chunker.
    pub max_chunk_size: usize,
    /// Overlap between chunks (default `10`).
    ///
    /// `validate` requires this to be strictly less than `max_chunk_size`.
    pub chunk_overlap: usize,
    /// Chunking strategy to apply (default `ChunkStrategy::Paragraph`).
    pub chunk_strategy: ChunkStrategy,

    // ---- Extraction --------------------------------------------------------
    /// Number of chunks per batch (default `100`); `validate` rejects `0`.
    pub chunks_per_batch: usize,
    /// Limit on parallel extractions (default `20`); `validate` rejects `0`.
    pub max_parallel_extractions: usize,
    /// Optional custom extraction prompt (default `None`).
    pub custom_extraction_prompt: Option<String>,

    // ---- Summarization -----------------------------------------------------
    /// Whether summarization is enabled (default `true`).
    pub enable_summarization: bool,
    /// Summarization batch size (default `50`); `validate` rejects `0`.
    pub summarization_batch_size: usize,

    // ---- Embedding ---------------------------------------------------------
    /// Whether triplets are embedded (default `false`).
    pub embed_triplets: bool,
    /// Embedding batch size (default `100`); `validate` rejects `0`.
    pub embedding_batch_size: usize,
    /// Prefix for vector collections (default: empty string).
    pub vector_collection_prefix: String,

    // ---- Pipeline behaviour ------------------------------------------------
    /// Incremental-loading switch (default `true`).
    pub incremental_loading: bool,
    /// Pipeline-cache switch (default `false`).
    pub use_pipeline_cache: bool,
    /// Temporal-cognify switch (default `false`).
    pub temporal_cognify: bool,
    /// Web-page-node creation switch (default `true`).
    pub create_web_page_nodes: bool,
    /// Number of data items per batch (default `20`); `validate` rejects `0`.
    pub data_per_batch: usize,
    /// Token counter selection; the default comes from
    /// `TokenCounterKind::from_env()`.
    pub token_counter_kind: TokenCounterKind,

    // ---- Runtime-only values (never serialized) -----------------------------
    /// Optional JSON schema for the graph (default `None`).
    #[serde(skip)]
    pub graph_schema: Option<serde_json::Value>,
    /// Optional JSON schema for summaries (default `None`).
    ///
    /// `with_summary_schema` checks the value with `validate_summary_schema`
    /// before storing it; assigning the field directly skips that check.
    #[serde(skip)]
    pub summary_schema: Option<serde_json::Value>,
    /// Optional user-supplied chunking function (default `None`).
    #[serde(skip)]
    pub custom_chunker: Option<CustomChunker>,
    /// Optional transcriber implementation (default `None`).
    #[serde(skip)]
    pub transcriber: Option<TranscriberHandle>,
}
/// Cloneable wrapper around a user-supplied chunking function.
///
/// The wrapped callable takes a `&str` and a `usize` and returns a
/// `Vec<String>`. It lives behind an `Arc`, so cloning the wrapper shares the
/// same callable.
/// NOTE(review): the `usize` argument is presumably the maximum chunk size —
/// confirm against the call site.
///
/// The newtype exists so a hand-written `Debug` impl can be attached, which
/// lets structs holding it derive `Debug`.
#[derive(Clone)]
#[allow(clippy::type_complexity)]
pub struct CustomChunker(pub Arc<dyn Fn(&str, usize) -> Vec<String> + Send + Sync>);

impl std::fmt::Debug for CustomChunker {
    /// Emits a fixed placeholder; the wrapped callable has no printable form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CustomChunker(…)")
    }
}
/// Cloneable wrapper around a shared `Transcriber` trait object.
///
/// The newtype exists so a hand-written `Debug` impl can be attached, which
/// lets structs holding it derive `Debug`. Cloning shares the same
/// underlying `Arc`.
#[derive(Clone)]
pub struct TranscriberHandle(pub Arc<dyn Transcriber>);

impl std::fmt::Debug for TranscriberHandle {
    /// Emits a fixed placeholder instead of inspecting the trait object.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "TranscriberHandle(…)")
    }
}
/// Chunking strategy selector stored in `CognifyConfig::chunk_strategy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ChunkStrategy {
    /// Paragraph strategy; this is the value `CognifyConfig::default()` uses.
    Paragraph,
    /// Recursive strategy.
    Recursive,
}
impl Default for CognifyConfig {
fn default() -> Self {
Self {
max_chunk_size: 1500,
chunk_overlap: 10,
chunk_strategy: ChunkStrategy::Paragraph,
chunks_per_batch: 100,
max_parallel_extractions: 20,
custom_extraction_prompt: None,
enable_summarization: true,
summarization_batch_size: 50,
embed_triplets: false,
embedding_batch_size: 100,
vector_collection_prefix: String::new(),
incremental_loading: true,
use_pipeline_cache: false,
temporal_cognify: false,
create_web_page_nodes: true,
data_per_batch: 20,
token_counter_kind: TokenCounterKind::from_env(),
graph_schema: None,
summary_schema: None,
custom_chunker: None,
transcriber: None,
}
}
}
impl CognifyConfig {
    /// Sets `max_chunk_size` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        self.max_chunk_size = size;
        self
    }

    /// Sets `chunk_overlap` and returns the updated config.
    ///
    /// The value is not checked here; `validate` requires it to be strictly
    /// less than `max_chunk_size`.
    #[must_use]
    pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
        self.chunk_overlap = overlap;
        self
    }

    /// Sets `chunk_strategy` and returns the updated config.
    #[must_use]
    pub fn with_chunk_strategy(mut self, strategy: ChunkStrategy) -> Self {
        self.chunk_strategy = strategy;
        self
    }

    /// Sets `chunks_per_batch` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_chunks_per_batch(mut self, batch_size: usize) -> Self {
        self.chunks_per_batch = batch_size;
        self
    }

    /// Sets `max_parallel_extractions` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_max_parallel_extractions(mut self, limit: usize) -> Self {
        self.max_parallel_extractions = limit;
        self
    }

    /// Stores `prompt` as `custom_extraction_prompt` and returns the updated
    /// config.
    #[must_use]
    pub fn with_custom_prompt(mut self, prompt: String) -> Self {
        self.custom_extraction_prompt = Some(prompt);
        self
    }

    /// Sets `enable_summarization` and returns the updated config.
    #[must_use]
    pub fn with_summarization(mut self, enable: bool) -> Self {
        self.enable_summarization = enable;
        self
    }

    /// Sets `summarization_batch_size` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_summarization_batch_size(mut self, batch_size: usize) -> Self {
        self.summarization_batch_size = batch_size;
        self
    }

    /// Sets `embed_triplets` and returns the updated config.
    #[must_use]
    pub fn with_triplet_embeddings(mut self, enable: bool) -> Self {
        self.embed_triplets = enable;
        self
    }

    /// Sets `embedding_batch_size` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_embedding_batch_size(mut self, batch_size: usize) -> Self {
        self.embedding_batch_size = batch_size;
        self
    }

    /// Sets `vector_collection_prefix` and returns the updated config.
    #[must_use]
    pub fn with_collection_prefix(mut self, prefix: String) -> Self {
        self.vector_collection_prefix = prefix;
        self
    }

    /// Sets `incremental_loading` and returns the updated config.
    #[must_use]
    pub fn with_incremental_loading(mut self, enable: bool) -> Self {
        self.incremental_loading = enable;
        self
    }

    /// Sets `use_pipeline_cache` and returns the updated config.
    #[must_use]
    pub fn with_pipeline_cache(mut self, enable: bool) -> Self {
        self.use_pipeline_cache = enable;
        self
    }

    /// Sets `temporal_cognify` and returns the updated config.
    #[must_use]
    pub fn with_temporal_cognify(mut self, enable: bool) -> Self {
        self.temporal_cognify = enable;
        self
    }

    /// Sets `create_web_page_nodes` and returns the updated config.
    #[must_use]
    pub fn with_web_page_nodes(mut self, enable: bool) -> Self {
        self.create_web_page_nodes = enable;
        self
    }

    /// Sets `data_per_batch` and returns the updated config.
    ///
    /// The value is not checked here; `validate` rejects `0`.
    #[must_use]
    pub fn with_data_per_batch(mut self, batch_size: usize) -> Self {
        self.data_per_batch = batch_size;
        self
    }

    /// Sets `token_counter_kind` and returns the updated config.
    #[must_use]
    pub fn with_token_counter(mut self, kind: TokenCounterKind) -> Self {
        self.token_counter_kind = kind;
        self
    }

    /// Stores `schema` as `graph_schema` and returns the updated config.
    ///
    /// Unlike `with_summary_schema`, no validation is performed on the value.
    #[must_use]
    pub fn with_graph_schema(mut self, schema: serde_json::Value) -> Self {
        self.graph_schema = Some(schema);
        self
    }

    /// Checks `schema` with `validate_summary_schema` and, on success, stores
    /// it as `summary_schema`.
    ///
    /// # Errors
    ///
    /// Returns the `ConfigError` produced by `validate_summary_schema` when
    /// the schema is rejected. The config is consumed in that case.
    // No `#[must_use]` here: `Result` already carries it.
    pub fn with_summary_schema(mut self, schema: serde_json::Value) -> Result<Self, ConfigError> {
        validate_summary_schema(&schema)?;
        self.summary_schema = Some(schema);
        Ok(self)
    }

    /// Wraps `chunker` in a `CustomChunker`, stores it as `custom_chunker`,
    /// and returns the updated config.
    #[must_use]
    #[allow(clippy::type_complexity)]
    pub fn with_custom_chunker(
        mut self,
        chunker: Arc<dyn Fn(&str, usize) -> Vec<String> + Send + Sync>,
    ) -> Self {
        self.custom_chunker = Some(CustomChunker(chunker));
        self
    }

    /// Wraps `transcriber` in a `TranscriberHandle`, stores it as
    /// `transcriber`, and returns the updated config.
    #[must_use]
    pub fn with_transcriber(mut self, transcriber: Arc<dyn Transcriber>) -> Self {
        self.transcriber = Some(TranscriberHandle(transcriber));
        self
    }

    /// Derives a chunk size from the embedding engine's limits.
    ///
    /// Returns the smaller of `embedding_engine.max_sequence_length()` and a
    /// fixed cutoff of `16_384 / 2 = 8_192`, never less than `1`.
    ///
    /// `_llm` is accepted but not consulted: the cutoff comes from a constant
    /// rather than from the model.
    /// NOTE(review): the constant's name suggests it mirrors a default from
    /// the Python implementation — confirm before changing it.
    #[must_use]
    pub fn auto_chunk_size(embedding_engine: &dyn EmbeddingEngine, _llm: &dyn Llm) -> usize {
        const PY_LLM_MAX_COMPLETION_TOKENS: usize = 16_384;

        let llm_cutoff = PY_LLM_MAX_COMPLETION_TOKENS / 2;
        let embed_max = embedding_engine.max_sequence_length();

        // Floor at 1 so an engine reporting 0 cannot yield a zero chunk
        // size, which `validate` would reject.
        llm_cutoff.min(embed_max).max(1)
    }

    /// Sets `max_chunk_size` to `Self::auto_chunk_size(embedding_engine, llm)`
    /// and returns the updated config. No other field is touched.
    #[must_use]
    pub fn with_auto_chunk_size(
        mut self,
        embedding_engine: &dyn EmbeddingEngine,
        llm: &dyn Llm,
    ) -> Self {
        self.max_chunk_size = Self::auto_chunk_size(embedding_engine, llm);
        self
    }

    /// Checks the numeric fields for values that cannot be used.
    ///
    /// Checks run in a fixed order and stop at the first failure.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::InvalidParameter` when any of the following
    /// holds:
    /// - `max_chunk_size` is `0`
    /// - `chunk_overlap` is greater than or equal to `max_chunk_size`
    /// - `chunks_per_batch`, `max_parallel_extractions`,
    ///   `embedding_batch_size`, `summarization_batch_size`, or
    ///   `data_per_batch` is `0`
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.max_chunk_size == 0 {
            return Err(ConfigError::InvalidParameter(
                "max_chunk_size must be greater than 0".to_string(),
            ));
        }
        if self.chunk_overlap >= self.max_chunk_size {
            return Err(ConfigError::InvalidParameter(
                "chunk_overlap must be less than max_chunk_size".to_string(),
            ));
        }
        if self.chunks_per_batch == 0 {
            return Err(ConfigError::InvalidParameter(
                "chunks_per_batch must be greater than 0".to_string(),
            ));
        }
        if self.max_parallel_extractions == 0 {
            return Err(ConfigError::InvalidParameter(
                "max_parallel_extractions must be greater than 0".to_string(),
            ));
        }
        if self.embedding_batch_size == 0 {
            return Err(ConfigError::InvalidParameter(
                "embedding_batch_size must be greater than 0".to_string(),
            ));
        }
        if self.summarization_batch_size == 0 {
            return Err(ConfigError::InvalidParameter(
                "summarization_batch_size must be greater than 0".to_string(),
            ));
        }
        if self.data_per_batch == 0 {
            return Err(ConfigError::InvalidParameter(
                "data_per_batch must be greater than 0".to_string(),
            ));
        }
        Ok(())
    }
}
#[derive(Error, Debug)]
pub enum ConfigError {
#[error("Invalid configuration parameter: {0}")]
InvalidParameter(String),
#[error("Invalid summary schema: {0}")]
InvalidSummarySchema(String),
}
/// Checks that `schema` has the minimal shape required of a summary schema.
///
/// The schema is accepted when:
/// 1. it is a JSON object,
/// 2. it has a `properties` member that is itself a JSON object,
/// 3. `properties` contains a `summary` entry, and
/// 4. if that entry carries a `type` member, its value is the string
///    `"string"`. An entry without a `type` member is accepted.
///
/// # Errors
///
/// Returns `ConfigError::InvalidSummarySchema` with a message naming the first
/// requirement that is not met.
pub fn validate_summary_schema(schema: &serde_json::Value) -> Result<(), ConfigError> {
    // Single place where rejection reasons become errors.
    let reject = |reason: &str| ConfigError::InvalidSummarySchema(reason.to_string());

    let Some(root) = schema.as_object() else {
        return Err(reject("schema must be a JSON object"));
    };

    let Some(properties) = root.get("properties").and_then(|value| value.as_object()) else {
        return Err(reject("schema must have a 'properties' object"));
    };

    let Some(summary) = properties.get("summary") else {
        return Err(reject("schema 'properties' must include a 'summary' field"));
    };

    // Only a declared type can be wrong; a missing `type` is not an error.
    match summary.get("type") {
        Some(declared) if declared.as_str() != Some("string") => {
            Err(reject("'summary' field must be of type 'string'"))
        }
        _ => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use cognee_embedding::error::EmbeddingResult;
    use cognee_llm::types::GenerationOptions;
    use serde_json::json;

    /// Embedding engine stub; only `max_seq` matters to these tests.
    struct MockEmbedding {
        max_seq: usize,
    }

    #[async_trait]
    impl EmbeddingEngine for MockEmbedding {
        async fn embed(&self, _texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
            Ok(vec![])
        }

        fn dimension(&self) -> usize {
            384
        }

        fn batch_size(&self) -> usize {
            32
        }

        fn max_sequence_length(&self) -> usize {
            self.max_seq
        }
    }

    /// LLM stub; the generation methods are never called by these tests.
    struct MockLlm {
        max_ctx: u32,
    }

    #[async_trait]
    impl Llm for MockLlm {
        async fn generate(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<cognee_llm::GenerationResponse> {
            unimplemented!()
        }

        async fn create_structured_output_with_messages_raw(
            &self,
            _messages: Vec<cognee_llm::Message>,
            _json_schema: &serde_json::Value,
            _options: Option<GenerationOptions>,
        ) -> cognee_llm::LlmResult<serde_json::Value> {
            unimplemented!()
        }

        fn model(&self) -> &str {
            "mock"
        }

        fn max_context_length(&self) -> u32 {
            self.max_ctx
        }
    }

    // ---- Defaults and builders ----------------------------------------------

    #[test]
    fn test_default_config() {
        let config = CognifyConfig::default();
        assert_eq!(config.max_chunk_size, 1500);
        assert_eq!(config.chunk_overlap, 10);
        assert_eq!(config.chunk_strategy, ChunkStrategy::Paragraph);
        assert_eq!(config.chunks_per_batch, 100);
        assert_eq!(config.max_parallel_extractions, 20);
        assert!(config.custom_extraction_prompt.is_none());
        assert!(config.enable_summarization);
        assert_eq!(config.summarization_batch_size, 50);
        assert!(!config.embed_triplets);
        assert_eq!(config.embedding_batch_size, 100);
        assert_eq!(config.vector_collection_prefix, "");
        assert!(config.incremental_loading);
        assert!(!config.use_pipeline_cache);
        assert!(!config.temporal_cognify);
        assert!(config.create_web_page_nodes);
        assert_eq!(config.data_per_batch, 20);
        assert!(config.graph_schema.is_none());
        assert!(config.summary_schema.is_none());
        assert!(config.custom_chunker.is_none());
        assert!(config.transcriber.is_none());
    }

    #[test]
    fn test_config_builder_chunking() {
        let config = CognifyConfig::default()
            .with_chunk_size(2000)
            .with_chunk_overlap(50)
            .with_chunk_strategy(ChunkStrategy::Recursive);
        assert_eq!(config.max_chunk_size, 2000);
        assert_eq!(config.chunk_overlap, 50);
        assert_eq!(config.chunk_strategy, ChunkStrategy::Recursive);
    }

    #[test]
    fn test_config_builder_graph_extraction() {
        let config = CognifyConfig::default()
            .with_chunks_per_batch(50)
            .with_max_parallel_extractions(25)
            .with_custom_prompt("Extract entities:".to_string());
        assert_eq!(config.chunks_per_batch, 50);
        assert_eq!(config.max_parallel_extractions, 25);
        assert_eq!(
            config.custom_extraction_prompt,
            Some("Extract entities:".to_string())
        );
    }

    #[test]
    fn test_config_builder_all_features() {
        let config = CognifyConfig::default()
            .with_chunk_size(2000)
            .with_triplet_embeddings(true)
            .with_incremental_loading(false)
            .with_summarization(false)
            .with_temporal_cognify(true);
        assert_eq!(config.max_chunk_size, 2000);
        assert!(config.embed_triplets);
        assert!(!config.incremental_loading);
        assert!(!config.enable_summarization);
        assert!(config.temporal_cognify);
    }

    /// Covers the builders that the tests above do not touch.
    #[test]
    fn test_config_builder_remaining_setters() {
        let graph_schema = json!({ "type": "object" });
        let config = CognifyConfig::default()
            .with_summarization_batch_size(7)
            .with_embedding_batch_size(9)
            .with_collection_prefix("tenant_".to_string())
            .with_pipeline_cache(true)
            .with_web_page_nodes(false)
            .with_data_per_batch(3)
            .with_graph_schema(graph_schema.clone())
            .with_custom_chunker(Arc::new(|text: &str, _limit: usize| {
                vec![text.to_string()]
            }));

        assert_eq!(config.summarization_batch_size, 7);
        assert_eq!(config.embedding_batch_size, 9);
        assert_eq!(config.vector_collection_prefix, "tenant_");
        assert!(config.use_pipeline_cache);
        assert!(!config.create_web_page_nodes);
        assert_eq!(config.data_per_batch, 3);
        assert_eq!(config.graph_schema, Some(graph_schema));

        let chunker = config.custom_chunker.expect("chunker was set");
        assert_eq!(format!("{chunker:?}"), "CustomChunker(…)");
        assert_eq!((chunker.0)("abc", 10), vec!["abc".to_string()]);
    }

    // ---- validate() ----------------------------------------------------------

    #[test]
    fn test_config_validation_success() {
        let config = CognifyConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_zero_chunk_size() {
        let config = CognifyConfig {
            max_chunk_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidParameter(_))
        ));
    }

    #[test]
    fn test_config_validation_overlap_too_large() {
        let config = CognifyConfig {
            max_chunk_size: 100,
            chunk_overlap: 100,
            ..Default::default()
        };
        assert!(matches!(
            config.validate(),
            Err(ConfigError::InvalidParameter(_))
        ));
    }

    /// The overlap check is `>=`, so one below the chunk size must pass.
    #[test]
    fn test_config_validation_overlap_just_below_chunk_size() {
        let config = CognifyConfig {
            max_chunk_size: 100,
            chunk_overlap: 99,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_zero_batch_sizes() {
        let config1 = CognifyConfig {
            chunks_per_batch: 0,
            ..Default::default()
        };
        assert!(config1.validate().is_err());

        let config2 = CognifyConfig {
            embedding_batch_size: 0,
            ..Default::default()
        };
        assert!(config2.validate().is_err());

        let config3 = CognifyConfig {
            summarization_batch_size: 0,
            ..Default::default()
        };
        assert!(config3.validate().is_err());
    }

    #[test]
    fn test_config_validation_zero_parallelism_and_data_batch() {
        let no_parallelism = CognifyConfig {
            max_parallel_extractions: 0,
            ..Default::default()
        };
        assert!(matches!(
            no_parallelism.validate(),
            Err(ConfigError::InvalidParameter(_))
        ));

        let no_data = CognifyConfig {
            data_per_batch: 0,
            ..Default::default()
        };
        assert!(matches!(
            no_data.validate(),
            Err(ConfigError::InvalidParameter(_))
        ));
    }

    // ---- Summary schema validation -------------------------------------------

    #[test]
    fn test_summary_schema_accepts_string_summary() {
        let schema = json!({
            "type": "object",
            "properties": { "summary": { "type": "string" } }
        });
        assert!(validate_summary_schema(&schema).is_ok());

        let config = CognifyConfig::default()
            .with_summary_schema(schema.clone())
            .expect("a schema with a string 'summary' is valid");
        assert_eq!(config.summary_schema, Some(schema));
    }

    /// A `summary` entry without a `type` member is accepted.
    #[test]
    fn test_summary_schema_accepts_untyped_summary() {
        let schema = json!({
            "properties": { "summary": { "description": "free text" } }
        });
        assert!(validate_summary_schema(&schema).is_ok());
    }

    #[test]
    fn test_summary_schema_rejects_invalid_shapes() {
        let rejected = [
            json!("not an object"),
            json!({ "title": "no properties" }),
            json!({ "properties": "not an object" }),
            json!({ "properties": { "title": { "type": "string" } } }),
            json!({ "properties": { "summary": { "type": "integer" } } }),
        ];
        for schema in &rejected {
            assert!(
                matches!(
                    validate_summary_schema(schema),
                    Err(ConfigError::InvalidSummarySchema(_))
                ),
                "schema should have been rejected: {schema}"
            );
        }
    }

    #[test]
    fn test_with_summary_schema_rejects_invalid_schema() {
        let result = CognifyConfig::default().with_summary_schema(json!({ "properties": {} }));
        assert!(matches!(
            result,
            Err(ConfigError::InvalidSummarySchema(_))
        ));
    }

    // ---- auto_chunk_size -----------------------------------------------------
    //
    // The cutoff used by `auto_chunk_size` is 16_384 / 2 = 8_192 and does not
    // depend on the LLM's `max_ctx`.

    #[test]
    fn auto_chunk_size_matches_python_default() {
        let embed = MockEmbedding { max_seq: 512 };
        let llm = MockLlm { max_ctx: 4096 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 512);
    }

    #[test]
    fn test_auto_chunk_size_embed_is_smaller() {
        let embed = MockEmbedding { max_seq: 512 };
        let llm = MockLlm { max_ctx: 4096 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 512);
    }

    /// A small LLM context does not lower the result: the LLM is not consulted.
    #[test]
    fn test_auto_chunk_size_llm_cutoff_unused() {
        let embed = MockEmbedding { max_seq: 512 };
        let llm = MockLlm { max_ctx: 256 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 512);
    }

    #[test]
    fn test_auto_chunk_size_large_embedding() {
        let embed = MockEmbedding { max_seq: 10_000 };
        let llm = MockLlm { max_ctx: 4096 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 8192);
    }

    /// The embedding limit (1024) is below the 8192 cutoff, so it wins.
    #[test]
    fn test_auto_chunk_size_equal_values() {
        let embed = MockEmbedding { max_seq: 1024 };
        let llm = MockLlm { max_ctx: 2048 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 1024);
    }

    #[test]
    fn test_auto_chunk_size_floor_at_one() {
        let embed = MockEmbedding { max_seq: 0 };
        let llm = MockLlm { max_ctx: 0 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 1);
    }

    #[test]
    fn test_auto_chunk_size_embed_exactly_at_llm_cutoff() {
        let embed = MockEmbedding { max_seq: 8192 };
        let llm = MockLlm { max_ctx: 4096 };
        assert_eq!(CognifyConfig::auto_chunk_size(&embed, &llm), 8192);
    }

    #[test]
    fn test_with_auto_chunk_size_builder() {
        let embed = MockEmbedding { max_seq: 512 };
        let llm = MockLlm { max_ctx: 4096 };
        let config = CognifyConfig::default().with_auto_chunk_size(&embed, &llm);
        assert_eq!(config.max_chunk_size, 512);
        // Only the chunk size changes; the other defaults are untouched.
        assert_eq!(config.chunk_overlap, 10);
        assert_eq!(config.chunks_per_batch, 100);
    }
}