use crate::config::Config;
use crate::core::Result;
use std::marker::PhantomData;
/// Type-state marker: no output directory has been chosen on the builder yet.
///
/// `Debug` is derived so that a `#[derive(Debug)]` on a struct that is generic
/// over this marker actually yields a usable `Debug` impl (the derive adds a
/// `T: Debug` bound for every type parameter).
#[derive(Debug, Clone, Copy)]
pub struct NoOutput;

/// Type-state marker: an output directory has been chosen on the builder.
#[derive(Debug, Clone, Copy)]
pub struct HasOutput;

/// Type-state marker: no LLM / embedding backend has been chosen on the builder yet.
#[derive(Debug, Clone, Copy)]
pub struct NoLlm;

/// Type-state marker: an LLM / embedding backend has been chosen on the builder.
#[derive(Debug, Clone, Copy)]
pub struct HasLlm;
/// Type-state builder for [`crate::GraphRAG`].
///
/// The two type parameters record, at compile time, which required steps have
/// been performed:
///
/// * `Output` starts as [`NoOutput`] and becomes [`HasOutput`] after
///   `with_output_dir`.
/// * `Llm` starts as [`NoLlm`] and becomes [`HasLlm`] after one of
///   `with_ollama`, `with_ollama_custom`, `with_hash_embeddings` or
///   `with_candle_embeddings`.
///
/// `build` / `build_and_init` are only defined for
/// `TypedBuilder<HasOutput, HasLlm>`.
///
/// Note: the derived `Debug` impl is bounded on `Output: Debug` and
/// `Llm: Debug`, so it is only available when both marker types implement
/// `Debug`.
#[derive(Debug)]
pub struct TypedBuilder<Output = NoOutput, Llm = NoLlm> {
// Configuration accumulated by the `with_*` methods.
config: Config,
// Zero-sized marker carrying the output-directory state.
_output: PhantomData<Output>,
// Zero-sized marker carrying the LLM/embedding-backend state.
_llm: PhantomData<Llm>,
}
impl TypedBuilder<NoOutput, NoLlm> {
pub fn new() -> Self {
Self {
config: Config::default(),
_output: PhantomData,
_llm: PhantomData,
}
}
}
impl Default for TypedBuilder<NoOutput, NoLlm> {
fn default() -> Self {
Self::new()
}
}
impl<Llm> TypedBuilder<NoOutput, Llm> {
    /// Stores `dir` as the configured output directory and moves the builder
    /// into the [`HasOutput`] state. The `Llm` state is carried over unchanged.
    pub fn with_output_dir(self, dir: &str) -> TypedBuilder<HasOutput, Llm> {
        let mut config = self.config;
        config.output_dir = dir.to_owned();
        TypedBuilder {
            config,
            _output: PhantomData,
            _llm: PhantomData,
        }
    }
}
impl<Output> TypedBuilder<Output, NoLlm> {
    /// Enables Ollama at `localhost:11434`, selects `"ollama"` as the
    /// embeddings backend and moves the builder into the [`HasLlm`] state.
    pub fn with_ollama(mut self) -> TypedBuilder<Output, HasLlm> {
        let ollama = &mut self.config.ollama;
        ollama.enabled = true;
        ollama.host = "localhost".to_owned();
        ollama.port = 11434;
        self.config.embeddings.backend = "ollama".to_owned();
        self.into_llm_configured()
    }

    /// Enables Ollama with a caller-supplied `host`, `port` and `chat_model`,
    /// selects `"ollama"` as the embeddings backend and moves the builder into
    /// the [`HasLlm`] state.
    pub fn with_ollama_custom(
        mut self,
        host: &str,
        port: u16,
        chat_model: &str,
    ) -> TypedBuilder<Output, HasLlm> {
        let ollama = &mut self.config.ollama;
        ollama.enabled = true;
        ollama.host = host.to_owned();
        ollama.port = port;
        ollama.chat_model = chat_model.to_owned();
        self.config.embeddings.backend = "ollama".to_owned();
        self.into_llm_configured()
    }

    /// Disables Ollama, selects the `"hash"` embeddings backend, sets the
    /// approach to `"algorithmic"` and moves the builder into the [`HasLlm`]
    /// state.
    pub fn with_hash_embeddings(mut self) -> TypedBuilder<Output, HasLlm> {
        let cfg = &mut self.config;
        cfg.ollama.enabled = false;
        cfg.embeddings.backend = "hash".to_owned();
        cfg.approach = "algorithmic".to_owned();
        self.into_llm_configured()
    }

    /// Selects the `"candle"` embeddings backend and moves the builder into the
    /// [`HasLlm`] state. Ollama settings are left untouched.
    pub fn with_candle_embeddings(mut self) -> TypedBuilder<Output, HasLlm> {
        self.config.embeddings.backend = "candle".to_owned();
        self.into_llm_configured()
    }

    /// Re-wraps the accumulated configuration in a builder whose `Llm` marker
    /// is [`HasLlm`]; the `Output` marker is preserved.
    fn into_llm_configured(self) -> TypedBuilder<Output, HasLlm> {
        TypedBuilder {
            config: self.config,
            _output: PhantomData,
            _llm: PhantomData,
        }
    }
}
impl<Output, Llm> TypedBuilder<Output, Llm> {
    /// Sets the chunk size on both the top-level config and the `text` section.
    pub fn with_chunk_size(mut self, size: usize) -> Self {
        let cfg = &mut self.config;
        cfg.text.chunk_size = size;
        cfg.chunk_size = size;
        self
    }

    /// Sets the chunk overlap on both the top-level config and the `text`
    /// section.
    pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
        let cfg = &mut self.config;
        cfg.text.chunk_overlap = overlap;
        cfg.chunk_overlap = overlap;
        self
    }

    /// Sets the number of results to retrieve, both as `top_k_results` and as
    /// `retrieval.top_k`.
    pub fn with_top_k(mut self, k: usize) -> Self {
        let cfg = &mut self.config;
        cfg.retrieval.top_k = k;
        cfg.top_k_results = Some(k);
        self
    }

    /// Sets the similarity threshold, both at the top level and in the `graph`
    /// section.
    pub fn with_similarity_threshold(mut self, threshold: f32) -> Self {
        let cfg = &mut self.config;
        cfg.graph.similarity_threshold = threshold;
        cfg.similarity_threshold = Some(threshold);
        self
    }

    /// Sets the `approach` string in the configuration.
    pub fn with_approach(mut self, approach: &str) -> Self {
        self.config.approach = approach.to_owned();
        self
    }

    /// Turns parallel processing on or off.
    pub fn with_parallel(mut self, enabled: bool) -> Self {
        self.config.parallel.enabled = enabled;
        self
    }

    /// Enables gleaning for entity extraction and sets the maximum number of
    /// gleaning rounds.
    pub fn with_gleaning(mut self, max_rounds: usize) -> Self {
        let entities = &mut self.config.entities;
        entities.max_gleaning_rounds = max_rounds;
        entities.use_gleaning = true;
        self
    }

    /// Borrows the configuration accumulated so far.
    pub fn config(&self) -> &Config {
        let Self { config, .. } = self;
        config
    }
}
impl TypedBuilder<HasOutput, HasLlm> {
pub fn build(self) -> Result<crate::GraphRAG> {
crate::GraphRAG::new(self.config)
}
pub fn build_and_init(self) -> Result<crate::GraphRAG> {
let mut graphrag = crate::GraphRAG::new(self.config)?;
graphrag.initialize()?;
Ok(graphrag)
}
}
/// Untyped fluent builder for [`crate::GraphRAG`].
///
/// Every `with_*` method consumes the builder, updates the wrapped [`Config`]
/// and returns the builder again. Unlike [`TypedBuilder`], `build` is available
/// at any point; there are no compile-time checks on which settings were made.
#[derive(Debug, Clone)]
pub struct GraphRAGBuilder {
// Configuration accumulated by the `with_*` methods; starts as `Config::default()`.
config: Config,
}
impl Default for GraphRAGBuilder {
fn default() -> Self {
Self::new()
}
}
impl GraphRAGBuilder {
pub fn new() -> Self {
Self {
config: Config::default(),
}
}
pub fn with_output_dir(mut self, dir: &str) -> Self {
self.config.output_dir = dir.to_string();
self
}
pub fn with_chunk_size(mut self, size: usize) -> Self {
self.config.chunk_size = size;
self.config.text.chunk_size = size;
self
}
pub fn with_chunk_overlap(mut self, overlap: usize) -> Self {
self.config.chunk_overlap = overlap;
self.config.text.chunk_overlap = overlap;
self
}
pub fn with_embedding_dimension(mut self, dimension: usize) -> Self {
self.config.embeddings.dimension = dimension;
self
}
pub fn with_embedding_model(mut self, model: &str) -> Self {
self.config.embeddings.model = Some(model.to_string());
self
}
pub fn with_embedding_backend(mut self, backend: &str) -> Self {
self.config.embeddings.backend = backend.to_string();
self
}
pub fn with_ollama_host(mut self, host: &str) -> Self {
self.config.ollama.host = host.to_string();
self
}
pub fn with_ollama_port(mut self, port: u16) -> Self {
self.config.ollama.port = port;
self
}
pub fn with_ollama_enabled(mut self, enabled: bool) -> Self {
self.config.ollama.enabled = enabled;
self
}
pub fn with_chat_model(mut self, model: &str) -> Self {
self.config.ollama.chat_model = model.to_string();
self
}
pub fn with_ollama_embedding_model(mut self, model: &str) -> Self {
self.config.ollama.embedding_model = model.to_string();
self
}
pub fn with_top_k(mut self, k: usize) -> Self {
self.config.top_k_results = Some(k);
self.config.retrieval.top_k = k;
self
}
pub fn with_similarity_threshold(mut self, threshold: f32) -> Self {
self.config.similarity_threshold = Some(threshold);
self.config.graph.similarity_threshold = threshold;
self
}
pub fn with_approach(mut self, approach: &str) -> Self {
self.config.approach = approach.to_string();
self
}
pub fn with_parallel_processing(mut self, enabled: bool) -> Self {
self.config.parallel.enabled = enabled;
self
}
pub fn with_num_threads(mut self, num_threads: usize) -> Self {
self.config.parallel.num_threads = num_threads;
self
}
pub fn with_auto_save(mut self, enabled: bool, interval_seconds: u64) -> Self {
self.config.auto_save.enabled = enabled;
self.config.auto_save.interval_seconds = interval_seconds;
self
}
pub fn with_auto_save_workspace(mut self, name: &str) -> Self {
self.config.auto_save.workspace_name = Some(name.to_string());
self
}
pub fn with_local_defaults(mut self) -> Self {
self.config.ollama.enabled = true;
self.config.ollama.host = "localhost".to_string();
self.config.ollama.port = 11434;
self.config.embeddings.backend = "candle".to_string();
self
}
pub fn build(self) -> Result<crate::GraphRAG> {
crate::GraphRAG::new(self.config)
}
pub fn config(&self) -> &Config {
&self.config
}
pub fn config_mut(&mut self) -> &mut Config {
&mut self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- GraphRAGBuilder -------------------------------------------------

    #[test]
    fn test_builder_default() {
        let fresh = GraphRAGBuilder::new();
        let cfg = fresh.config();
        assert_eq!(cfg.output_dir, "./output");
    }

    #[test]
    fn test_builder_with_output_dir() {
        let configured = GraphRAGBuilder::new().with_output_dir("./custom");
        let cfg = configured.config();
        assert_eq!(cfg.output_dir, "./custom");
    }

    #[test]
    fn test_builder_with_chunk_size() {
        let configured = GraphRAGBuilder::new().with_chunk_size(512);
        let cfg = configured.config();
        // Both copies of the setting must be updated.
        assert_eq!(cfg.chunk_size, 512);
        assert_eq!(cfg.text.chunk_size, 512);
    }

    #[test]
    fn test_builder_with_embedding_config() {
        let configured = GraphRAGBuilder::new()
            .with_embedding_dimension(384)
            .with_embedding_model("test-model");
        let embeddings = &configured.config().embeddings;
        assert_eq!(embeddings.dimension, 384);
        assert_eq!(embeddings.model, Some("test-model".to_string()));
    }

    #[test]
    fn test_builder_with_ollama() {
        let configured = GraphRAGBuilder::new()
            .with_ollama_enabled(true)
            .with_ollama_host("custom-host")
            .with_ollama_port(8080)
            .with_chat_model("custom-model");
        let ollama = &configured.config().ollama;
        assert!(ollama.enabled);
        assert_eq!(ollama.host, "custom-host");
        assert_eq!(ollama.port, 8080);
        assert_eq!(ollama.chat_model, "custom-model");
    }

    #[test]
    fn test_builder_with_retrieval() {
        let configured = GraphRAGBuilder::new()
            .with_top_k(20)
            .with_similarity_threshold(0.8);
        let cfg = configured.config();
        assert_eq!(cfg.top_k_results, Some(20));
        assert_eq!(cfg.retrieval.top_k, 20);
        assert_eq!(cfg.similarity_threshold, Some(0.8));
        assert_eq!(cfg.graph.similarity_threshold, 0.8);
    }

    #[test]
    fn test_builder_with_parallel() {
        let configured = GraphRAGBuilder::new()
            .with_parallel_processing(false)
            .with_num_threads(8);
        let parallel = &configured.config().parallel;
        assert!(!parallel.enabled);
        assert_eq!(parallel.num_threads, 8);
    }

    #[test]
    fn test_builder_with_auto_save() {
        let configured = GraphRAGBuilder::new()
            .with_auto_save(true, 600)
            .with_auto_save_workspace("test");
        let auto_save = &configured.config().auto_save;
        assert!(auto_save.enabled);
        assert_eq!(auto_save.interval_seconds, 600);
        assert_eq!(auto_save.workspace_name, Some("test".to_string()));
    }

    #[test]
    fn test_builder_local_defaults() {
        let configured = GraphRAGBuilder::new().with_local_defaults();
        let cfg = configured.config();
        assert!(cfg.ollama.enabled);
        assert_eq!(cfg.ollama.host, "localhost");
        assert_eq!(cfg.ollama.port, 11434);
        assert_eq!(cfg.embeddings.backend, "candle");
    }

    #[test]
    fn test_builder_fluent_api() {
        let configured = GraphRAGBuilder::new()
            .with_output_dir("./test")
            .with_chunk_size(256)
            .with_chunk_overlap(32)
            .with_top_k(15)
            .with_approach("hybrid");
        let cfg = configured.config();
        assert_eq!(cfg.output_dir, "./test");
        assert_eq!(cfg.chunk_size, 256);
        assert_eq!(cfg.chunk_overlap, 32);
        assert_eq!(cfg.top_k_results, Some(15));
        assert_eq!(cfg.approach, "hybrid");
    }

    // ---- TypedBuilder ----------------------------------------------------

    #[test]
    fn test_typed_builder_state_transitions() {
        // NoOutput -> HasOutput
        let with_dir = TypedBuilder::new().with_output_dir("./test_output");
        assert_eq!(with_dir.config().output_dir, "./test_output");

        // NoLlm -> HasLlm
        let with_llm = with_dir.with_ollama();
        let ollama = &with_llm.config().ollama;
        assert!(ollama.enabled);
        assert_eq!(ollama.host, "localhost");
        assert_eq!(ollama.port, 11434);
    }

    #[test]
    fn test_typed_builder_with_hash_embeddings() {
        let configured = TypedBuilder::new()
            .with_output_dir("./test")
            .with_hash_embeddings();
        let cfg = configured.config();
        assert!(!cfg.ollama.enabled);
        assert_eq!(cfg.embeddings.backend, "hash");
        assert_eq!(cfg.approach, "algorithmic");
    }

    #[test]
    fn test_typed_builder_with_ollama_custom() {
        let configured = TypedBuilder::new()
            .with_output_dir("./test")
            .with_ollama_custom("my-server", 8080, "mistral:latest");
        let ollama = &configured.config().ollama;
        assert!(ollama.enabled);
        assert_eq!(ollama.host, "my-server");
        assert_eq!(ollama.port, 8080);
        assert_eq!(ollama.chat_model, "mistral:latest");
    }

    #[test]
    fn test_typed_builder_with_candle() {
        let configured = TypedBuilder::new()
            .with_output_dir("./test")
            .with_candle_embeddings();
        let cfg = configured.config();
        assert_eq!(cfg.embeddings.backend, "candle");
    }

    #[test]
    fn test_typed_builder_optional_methods() {
        // Optional setters are usable before any required step has been taken.
        let configured = TypedBuilder::new()
            .with_chunk_size(512)
            .with_chunk_overlap(64)
            .with_top_k(20)
            .with_similarity_threshold(0.75)
            .with_approach("hybrid")
            .with_parallel(true)
            .with_gleaning(3);
        let cfg = configured.config();
        assert_eq!(cfg.chunk_size, 512);
        assert_eq!(cfg.chunk_overlap, 64);
        assert_eq!(cfg.top_k_results, Some(20));
        assert_eq!(cfg.similarity_threshold, Some(0.75));
        assert_eq!(cfg.approach, "hybrid");
        assert!(cfg.parallel.enabled);
        assert!(cfg.entities.use_gleaning);
        assert_eq!(cfg.entities.max_gleaning_rounds, 3);
    }

    #[test]
    fn test_typed_builder_order_independence() {
        let optional_first = TypedBuilder::new()
            .with_chunk_size(512)
            .with_output_dir("./test1")
            .with_ollama();
        let required_first = TypedBuilder::new()
            .with_output_dir("./test2")
            .with_chunk_size(512)
            .with_ollama();
        assert_eq!(
            optional_first.config().chunk_size,
            required_first.config().chunk_size
        );
    }

    #[test]
    fn test_typed_builder_llm_before_output() {
        // The LLM step may come before the output-directory step.
        let configured = TypedBuilder::new()
            .with_ollama()
            .with_output_dir("./test");
        let cfg = configured.config();
        assert!(cfg.ollama.enabled);
        assert_eq!(cfg.output_dir, "./test");
    }
}