use std::collections::HashMap;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use crate::error::{OxiRagError, PipelineError};
use crate::layer1_echo::filter::MetadataFilter;
use crate::types::Query;
/// Per-query hints selecting which pipeline layers should run.
///
/// The derived `Default` sets every flag to `false` (no layer enabled); use
/// [`LayerHints::all`], one of the preset constructors, or the `with_*` /
/// `without_*` builders to choose layers explicitly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
// One independent on/off flag per pipeline layer is the intended shape, so the
// "too many bools" lint is deliberately silenced here.
#[allow(clippy::struct_excessive_bools)]
pub struct LayerHints {
    /// Whether the echo layer is requested.
    pub use_echo: bool,
    /// Whether the speculator layer is requested.
    pub use_speculator: bool,
    /// Whether the judge layer is requested.
    pub use_judge: bool,
    /// Whether the graph layer is requested (only with the `graphrag` feature).
    ///
    /// Falls back to `false` when the field is missing from serialized input,
    /// so hints produced by a build without `graphrag` still deserialize.
    #[cfg(feature = "graphrag")]
    #[serde(default)]
    pub use_graph: bool,
}
impl LayerHints {
    /// Requests every layer compiled into this build.
    #[must_use]
    pub fn all() -> Self {
        let mut hints = Self::echo_and_speculator();
        hints.use_judge = true;
        #[cfg(feature = "graphrag")]
        {
            hints.use_graph = true;
        }
        hints
    }

    /// Requests only the echo layer; every other flag is `false`.
    #[must_use]
    pub fn echo_only() -> Self {
        Self::default().with_echo()
    }

    /// Requests the echo and speculator layers; every other flag is `false`.
    #[must_use]
    pub fn echo_and_speculator() -> Self {
        Self::echo_only().with_speculator()
    }

    /// Turns the echo flag on, leaving the other flags untouched.
    #[must_use]
    pub fn with_echo(self) -> Self {
        Self {
            use_echo: true,
            ..self
        }
    }

    /// Turns the speculator flag on, leaving the other flags untouched.
    #[must_use]
    pub fn with_speculator(self) -> Self {
        Self {
            use_speculator: true,
            ..self
        }
    }

    /// Turns the judge flag on, leaving the other flags untouched.
    #[must_use]
    pub fn with_judge(self) -> Self {
        Self {
            use_judge: true,
            ..self
        }
    }

    /// Turns the graph flag on, leaving the other flags untouched.
    #[cfg(feature = "graphrag")]
    #[must_use]
    pub fn with_graph(self) -> Self {
        Self {
            use_graph: true,
            ..self
        }
    }

    /// Turns the echo flag off, leaving the other flags untouched.
    #[must_use]
    pub fn without_echo(self) -> Self {
        Self {
            use_echo: false,
            ..self
        }
    }

    /// Turns the speculator flag off, leaving the other flags untouched.
    #[must_use]
    pub fn without_speculator(self) -> Self {
        Self {
            use_speculator: false,
            ..self
        }
    }

    /// Turns the judge flag off, leaving the other flags untouched.
    #[must_use]
    pub fn without_judge(self) -> Self {
        Self {
            use_judge: false,
            ..self
        }
    }

    /// Turns the graph flag off, leaving the other flags untouched.
    #[cfg(feature = "graphrag")]
    #[must_use]
    pub fn without_graph(self) -> Self {
        Self {
            use_graph: false,
            ..self
        }
    }
}
/// Optional graph-traversal parameters carried by an [`ExtendedQuery`]
/// (compiled only with the `graphrag` feature).
///
/// `Default` yields no start entities and every optional constraint `None`.
#[cfg(feature = "graphrag")]
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct GraphContext {
    /// Entities (as strings) the traversal should start from.
    pub start_entities: Vec<String>,
    /// Optional hop limit — presumably the maximum traversal depth; confirm
    /// against the graph layer.
    pub max_hops: Option<usize>,
    /// Optional minimum confidence threshold (no range is enforced here).
    pub min_confidence: Option<f32>,
    /// Optional set of relationship types to consider.
    pub relationship_types: Option<Vec<crate::layer4_graph::types::RelationshipType>>,
    /// Optional set of entity types to consider.
    pub entity_types: Option<Vec<crate::layer4_graph::types::EntityType>>,
    /// Optional traversal direction.
    pub direction: Option<crate::layer4_graph::types::Direction>,
}
#[cfg(feature = "graphrag")]
impl GraphContext {
    /// Returns an empty context; identical to `GraphContext::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces any previously added start entities with `entities`.
    #[must_use]
    pub fn with_start_entities(self, entities: Vec<String>) -> Self {
        Self {
            start_entities: entities,
            ..self
        }
    }

    /// Appends one start entity, keeping those already present.
    #[must_use]
    pub fn with_start_entity(mut self, entity: impl Into<String>) -> Self {
        let entity: String = entity.into();
        self.start_entities.push(entity);
        self
    }

    /// Stores `max_hops` as the hop limit.
    #[must_use]
    pub fn with_max_hops(self, max_hops: usize) -> Self {
        Self {
            max_hops: Some(max_hops),
            ..self
        }
    }

    /// Stores `min_confidence` as given; no range check is performed here.
    #[must_use]
    pub fn with_min_confidence(self, min_confidence: f32) -> Self {
        Self {
            min_confidence: Some(min_confidence),
            ..self
        }
    }

    /// Stores the relationship-type filter, replacing any previous one.
    #[must_use]
    pub fn with_relationship_types(
        self,
        types: Vec<crate::layer4_graph::types::RelationshipType>,
    ) -> Self {
        Self {
            relationship_types: Some(types),
            ..self
        }
    }

    /// Stores the entity-type filter, replacing any previous one.
    #[must_use]
    pub fn with_entity_types(self, types: Vec<crate::layer4_graph::types::EntityType>) -> Self {
        Self {
            entity_types: Some(types),
            ..self
        }
    }

    /// Stores the traversal direction, replacing any previous one.
    #[must_use]
    pub fn with_direction(self, direction: crate::layer4_graph::types::Direction) -> Self {
        Self {
            direction: Some(direction),
            ..self
        }
    }
}
/// A [`Query`] bundled with pipeline-level execution options.
///
/// Produced by [`QueryBuilder::build_extended`] or converted from a plain
/// `Query`; options that were never set are `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExtendedQuery {
    /// The wrapped retrieval query.
    pub query: Query,
    /// Optional timeout for this query.
    pub timeout: Option<Duration>,
    /// Optional caller-supplied cache key.
    pub cache_key: Option<String>,
    /// Optional per-layer hints.
    pub layer_hints: Option<LayerHints>,
    /// Optional graph traversal parameters (`graphrag` feature only).
    #[cfg(feature = "graphrag")]
    pub graph_context: Option<GraphContext>,
}
impl ExtendedQuery {
    /// Wraps `query` with every extended option left as `None`.
    #[must_use]
    pub fn new(query: Query) -> Self {
        Self {
            timeout: None,
            cache_key: None,
            layer_hints: None,
            #[cfg(feature = "graphrag")]
            graph_context: None,
            query,
        }
    }

    /// Text of the wrapped query.
    #[must_use]
    pub fn text(&self) -> &str {
        let inner = &self.query;
        &inner.text
    }

    /// `top_k` of the wrapped query.
    #[must_use]
    pub fn top_k(&self) -> usize {
        let inner = &self.query;
        inner.top_k
    }

    /// Minimum score of the wrapped query, if one was set.
    #[must_use]
    pub fn min_score(&self) -> Option<f32> {
        let inner = &self.query;
        inner.min_score
    }

    /// Structured metadata filter of the wrapped query, if any.
    #[must_use]
    pub fn metadata_filter(&self) -> Option<&MetadataFilter> {
        let inner = &self.query;
        inner.metadata_filter.as_ref()
    }

    /// Timeout option, if set.
    #[must_use]
    pub fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Cache key option, borrowed as `&str`, if set.
    #[must_use]
    pub fn cache_key(&self) -> Option<&str> {
        self.cache_key.as_deref()
    }

    /// Layer hints option, if set.
    #[must_use]
    pub fn layer_hints(&self) -> Option<&LayerHints> {
        self.layer_hints.as_ref()
    }

    /// Graph context option, if set (`graphrag` feature only).
    #[cfg(feature = "graphrag")]
    #[must_use]
    pub fn graph_context(&self) -> Option<&GraphContext> {
        self.graph_context.as_ref()
    }

    /// Drops the extended options and returns the wrapped query.
    #[must_use]
    pub fn into_query(self) -> Query {
        let Self { query, .. } = self;
        query
    }
}
impl From<ExtendedQuery> for Query {
    /// Discards the extended options, keeping only the inner query.
    fn from(extended: ExtendedQuery) -> Self {
        extended.into_query()
    }
}
impl From<Query> for ExtendedQuery {
fn from(query: Query) -> Self {
Self::new(query)
}
}
/// Fluent builder for [`Query`] and [`ExtendedQuery`].
///
/// Every field starts unset (`Default`); only the text is required at build
/// time.
#[derive(Clone, Debug, Default)]
pub struct QueryBuilder {
    /// Query text; building fails while this is `None`.
    text: Option<String>,
    /// Overrides the query's default `top_k` when set.
    top_k: Option<usize>,
    /// Minimum score, forwarded when set.
    min_score: Option<f32>,
    /// Simple key/value filters; inserting a repeated key keeps the last value.
    filters: HashMap<String, String>,
    /// Structured metadata filter, forwarded when set.
    metadata_filter: Option<MetadataFilter>,
    /// Extended-only: timeout carried into an [`ExtendedQuery`].
    timeout: Option<Duration>,
    /// Extended-only: cache key carried into an [`ExtendedQuery`].
    cache_key: Option<String>,
    /// Extended-only: layer hints carried into an [`ExtendedQuery`].
    layer_hints: Option<LayerHints>,
    /// Extended-only: graph context (`graphrag` feature).
    #[cfg(feature = "graphrag")]
    graph_context: Option<GraphContext>,
}
impl QueryBuilder {
    /// Creates a builder with nothing set.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the query text (required by [`build`](Self::build)).
    #[must_use]
    pub fn text(mut self, text: impl Into<String>) -> Self {
        self.text = Some(text.into());
        self
    }

    /// Overrides the number of results requested.
    #[must_use]
    pub fn with_top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Sets the minimum score threshold.
    #[must_use]
    pub fn with_min_score(mut self, min_score: f32) -> Self {
        self.min_score = Some(min_score);
        self
    }

    /// Adds a simple key/value filter; a repeated key overwrites the earlier value.
    #[must_use]
    pub fn with_filter(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.filters.insert(key.into(), value.into());
        self
    }

    /// Sets the structured metadata filter, replacing any previous one.
    #[must_use]
    pub fn with_metadata(mut self, filter: MetadataFilter) -> Self {
        self.metadata_filter = Some(filter);
        self
    }

    /// Sets the timeout (only kept by [`build_extended`](Self::build_extended)).
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Sets the cache key (only kept by [`build_extended`](Self::build_extended)).
    #[must_use]
    pub fn with_cache_key(mut self, key: impl Into<String>) -> Self {
        self.cache_key = Some(key.into());
        self
    }

    /// Sets the layer hints (only kept by [`build_extended`](Self::build_extended)).
    #[must_use]
    pub fn with_layer_hints(mut self, hints: LayerHints) -> Self {
        self.layer_hints = Some(hints);
        self
    }

    /// Sets the graph context (only kept by [`build_extended`](Self::build_extended)).
    #[cfg(feature = "graphrag")]
    #[must_use]
    pub fn with_graph_context(mut self, context: GraphContext) -> Self {
        self.graph_context = Some(context);
        self
    }

    /// Builds the plain [`Query`].
    ///
    /// Unset options keep whatever `Query::new` provides. Extended-only
    /// settings (timeout, cache key, layer hints, graph context) are dropped;
    /// use [`build_extended`](Self::build_extended) to keep them. Only the
    /// presence of text is checked, so an empty string is accepted.
    ///
    /// # Errors
    ///
    /// Returns `OxiRagError::Pipeline(PipelineError::BuildError(_))` when no
    /// text was set.
    pub fn build(self) -> Result<Query, OxiRagError> {
        let text = self.text.ok_or_else(|| {
            OxiRagError::Pipeline(PipelineError::BuildError(
                "Query text is required".to_string(),
            ))
        })?;
        let mut query = Query::new(text);
        if let Some(top_k) = self.top_k {
            query = query.with_top_k(top_k);
        }
        if let Some(min_score) = self.min_score {
            query = query.with_min_score(min_score);
        }
        for (key, value) in self.filters {
            query = query.with_filter(key, value);
        }
        if let Some(filter) = self.metadata_filter {
            query = query.with_metadata_filter(filter);
        }
        Ok(query)
    }

    /// Builds an [`ExtendedQuery`] that also carries the extended settings.
    ///
    /// # Errors
    ///
    /// Same as [`build`](Self::build): fails when no text was set.
    pub fn build_extended(mut self) -> Result<ExtendedQuery, OxiRagError> {
        // `build` never reads the extended-only fields, so move them out
        // rather than cloning them before `self` is consumed.
        let timeout = self.timeout;
        let cache_key = self.cache_key.take();
        let layer_hints = self.layer_hints.take();
        #[cfg(feature = "graphrag")]
        let graph_context = self.graph_context.take();
        let query = self.build()?;
        Ok(ExtendedQuery {
            query,
            timeout,
            cache_key,
            layer_hints,
            #[cfg(feature = "graphrag")]
            graph_context,
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LayerHints`, `GraphContext`, `ExtendedQuery` and
    //! `QueryBuilder`.
    use super::*;

    #[test]
    fn test_builder_simple_query() {
        let query = QueryBuilder::new()
            .text("What is Rust?")
            .build()
            .expect("Failed to build query");
        assert_eq!(query.text, "What is Rust?");
        // Defaults supplied by `Query::new`.
        assert_eq!(query.top_k, 10);
        assert!(query.min_score.is_none());
    }

    #[test]
    fn test_builder_with_top_k() {
        let query = QueryBuilder::new()
            .text("query")
            .with_top_k(5)
            .build()
            .expect("Failed to build query");
        assert_eq!(query.top_k, 5);
    }

    #[test]
    fn test_builder_with_min_score() {
        let query = QueryBuilder::new()
            .text("query")
            .with_min_score(0.7)
            .build()
            .expect("Failed to build query");
        assert!((query.min_score.unwrap() - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_builder_with_simple_filters() {
        let query = QueryBuilder::new()
            .text("query")
            .with_filter("category", "science")
            .with_filter("status", "published")
            .build()
            .expect("Failed to build query");
        assert_eq!(query.filters.get("category"), Some(&"science".to_string()));
        assert_eq!(query.filters.get("status"), Some(&"published".to_string()));
    }

    #[test]
    fn test_builder_with_metadata_filter() {
        let filter = MetadataFilter::and(vec![
            MetadataFilter::eq("status", "published"),
            MetadataFilter::or(vec![
                MetadataFilter::eq("category", "science"),
                MetadataFilter::eq("category", "tech"),
            ]),
        ]);
        let query = QueryBuilder::new()
            .text("query")
            .with_metadata(filter.clone())
            .build()
            .expect("Failed to build query");
        assert_eq!(query.metadata_filter, Some(filter));
    }

    #[test]
    fn test_builder_missing_text_error() {
        let result = QueryBuilder::new().build();
        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("Query text is required"));
    }

    #[test]
    fn test_builder_extended_missing_text_error() {
        let result = QueryBuilder::new()
            .with_cache_key("unused")
            .with_timeout(Duration::from_secs(1))
            .build_extended();
        let err = result.expect_err("build_extended without text must fail");
        assert!(err.to_string().contains("Query text is required"));
    }

    #[test]
    fn test_builder_extended_query() {
        let extended = QueryBuilder::new()
            .text("query")
            .with_top_k(5)
            .with_timeout(Duration::from_secs(30))
            .with_cache_key("test-key")
            .with_layer_hints(LayerHints::echo_only())
            .build_extended()
            .expect("Failed to build extended query");
        assert_eq!(extended.query.text, "query");
        assert_eq!(extended.query.top_k, 5);
        assert_eq!(extended.timeout, Some(Duration::from_secs(30)));
        assert_eq!(extended.cache_key, Some("test-key".to_string()));
        let hints = extended.layer_hints.unwrap();
        assert!(hints.use_echo);
        assert!(!hints.use_speculator);
        assert!(!hints.use_judge);
    }

    #[test]
    fn test_builder_extended_defaults_to_none() {
        let extended = QueryBuilder::new()
            .text("plain")
            .build_extended()
            .expect("Failed to build");
        assert!(extended.timeout().is_none());
        assert!(extended.cache_key().is_none());
        assert!(extended.layer_hints().is_none());
    }

    #[test]
    fn test_extended_query_accessors() {
        let extended = QueryBuilder::new()
            .text("test query")
            .with_top_k(20)
            .with_min_score(0.5)
            .with_metadata(MetadataFilter::eq("key", "value"))
            .with_timeout(Duration::from_secs(60))
            .with_cache_key("cache-key")
            .with_layer_hints(LayerHints::all())
            .build_extended()
            .expect("Failed to build");
        assert_eq!(extended.text(), "test query");
        assert_eq!(extended.top_k(), 20);
        assert!((extended.min_score().unwrap() - 0.5).abs() < f32::EPSILON);
        assert!(extended.metadata_filter().is_some());
        assert_eq!(extended.timeout(), Some(Duration::from_secs(60)));
        assert_eq!(extended.cache_key(), Some("cache-key"));
        assert!(extended.layer_hints().is_some());
    }

    #[test]
    fn test_extended_query_into_query() {
        let extended = QueryBuilder::new()
            .text("test")
            .with_top_k(15)
            .build_extended()
            .expect("Failed to build");
        let query = extended.into_query();
        assert_eq!(query.text, "test");
        assert_eq!(query.top_k, 15);
    }

    #[test]
    fn test_extended_query_from_query() {
        let query = Query::new("original").with_top_k(8);
        let extended: ExtendedQuery = query.into();
        assert_eq!(extended.query.text, "original");
        assert_eq!(extended.query.top_k, 8);
        assert!(extended.timeout.is_none());
        assert!(extended.cache_key.is_none());
    }

    #[test]
    fn test_query_from_extended_query() {
        let extended = ExtendedQuery::new(Query::new("wrapped").with_top_k(3));
        let query: Query = extended.into();
        assert_eq!(query.text, "wrapped");
        assert_eq!(query.top_k, 3);
    }

    #[test]
    fn test_layer_hints_default_disables_everything() {
        let hints = LayerHints::default();
        assert!(!hints.use_echo);
        assert!(!hints.use_speculator);
        assert!(!hints.use_judge);
    }

    #[test]
    fn test_layer_hints_all() {
        let hints = LayerHints::all();
        assert!(hints.use_echo);
        assert!(hints.use_speculator);
        assert!(hints.use_judge);
    }

    #[test]
    fn test_layer_hints_echo_only() {
        let hints = LayerHints::echo_only();
        assert!(hints.use_echo);
        assert!(!hints.use_speculator);
        assert!(!hints.use_judge);
    }

    #[test]
    fn test_layer_hints_echo_and_speculator() {
        let hints = LayerHints::echo_and_speculator();
        assert!(hints.use_echo);
        assert!(hints.use_speculator);
        assert!(!hints.use_judge);
    }

    #[test]
    fn test_layer_hints_builder_methods() {
        let hints = LayerHints::default()
            .with_echo()
            .with_speculator()
            .with_judge();
        assert!(hints.use_echo);
        assert!(hints.use_speculator);
        assert!(hints.use_judge);
        let hints2 = hints.without_speculator().without_judge();
        assert!(hints2.use_echo);
        assert!(!hints2.use_speculator);
        assert!(!hints2.use_judge);
        let hints3 = hints2.without_echo();
        assert!(!hints3.use_echo);
    }

    #[test]
    fn test_query_builder_clone() {
        let builder = QueryBuilder::new()
            .text("query")
            .with_top_k(5)
            .with_min_score(0.6);
        let cloned = builder.clone();
        let query = cloned.build().expect("Failed to build");
        assert_eq!(query.text, "query");
        assert_eq!(query.top_k, 5);
    }

    #[test]
    fn test_query_builder_default() {
        let builder = QueryBuilder::default();
        let result = builder.text("test").build();
        assert!(result.is_ok());
    }

    #[test]
    fn test_layer_hints_equality() {
        let hints1 = LayerHints::all();
        let hints2 = LayerHints::all();
        let hints3 = LayerHints::echo_only();
        assert_eq!(hints1, hints2);
        assert_ne!(hints1, hints3);
    }

    #[cfg(feature = "graphrag")]
    mod graphrag_tests {
        use super::*;
        use crate::layer4_graph::types::{Direction, EntityType, RelationshipType};

        #[test]
        fn test_graph_context_new() {
            let ctx = GraphContext::new();
            assert!(ctx.start_entities.is_empty());
            assert!(ctx.max_hops.is_none());
            assert!(ctx.min_confidence.is_none());
        }

        #[test]
        fn test_graph_context_builder() {
            let ctx = GraphContext::new()
                .with_start_entity("rust")
                .with_start_entity("programming")
                .with_max_hops(3)
                .with_min_confidence(0.5)
                .with_direction(Direction::Outgoing)
                .with_relationship_types(vec![RelationshipType::RelatedTo])
                .with_entity_types(vec![EntityType::Technology]);
            assert_eq!(ctx.start_entities.len(), 2);
            assert_eq!(ctx.max_hops, Some(3));
            assert!((ctx.min_confidence.unwrap() - 0.5).abs() < f32::EPSILON);
            assert_eq!(ctx.direction, Some(Direction::Outgoing));
            assert!(ctx.relationship_types.is_some());
            assert!(ctx.entity_types.is_some());
        }

        #[test]
        fn test_graph_context_with_start_entities() {
            let ctx =
                GraphContext::new().with_start_entities(vec!["a".to_string(), "b".to_string()]);
            assert_eq!(ctx.start_entities, vec!["a", "b"]);
        }

        #[test]
        fn test_builder_with_graph_context() {
            let ctx = GraphContext::new()
                .with_start_entity("test")
                .with_max_hops(2);
            let extended = QueryBuilder::new()
                .text("query")
                .with_graph_context(ctx)
                .build_extended()
                .expect("Failed to build");
            let graph_ctx = extended.graph_context().unwrap();
            assert_eq!(graph_ctx.start_entities, vec!["test"]);
            assert_eq!(graph_ctx.max_hops, Some(2));
        }

        #[test]
        fn test_builder_extended_without_graph_context() {
            let extended = QueryBuilder::new()
                .text("query")
                .build_extended()
                .expect("Failed to build");
            assert!(extended.graph_context().is_none());
        }

        #[test]
        fn test_layer_hints_with_graph() {
            let hints = LayerHints::all();
            assert!(hints.use_graph);
            let hints2 = hints.without_graph();
            assert!(!hints2.use_graph);
            let hints3 = LayerHints::default().with_graph();
            assert!(hints3.use_graph);
        }
    }

    #[test]
    fn test_query_builder_full_chain() {
        let query = QueryBuilder::new()
            .text("Find documents about machine learning")
            .with_top_k(20)
            .with_min_score(0.6)
            .with_filter("language", "en")
            .with_filter("year", "2024")
            .with_metadata(MetadataFilter::exists("author"))
            .build()
            .expect("Failed to build");
        assert_eq!(query.text, "Find documents about machine learning");
        assert_eq!(query.top_k, 20);
        assert!((query.min_score.unwrap() - 0.6).abs() < f32::EPSILON);
        assert_eq!(query.filters.get("language"), Some(&"en".to_string()));
        assert_eq!(query.filters.get("year"), Some(&"2024".to_string()));
        assert!(query.metadata_filter.is_some());
    }

    #[test]
    fn test_extended_query_serialization() {
        let extended = QueryBuilder::new()
            .text("test")
            .with_top_k(5)
            .with_cache_key("key")
            .with_layer_hints(LayerHints::echo_only())
            .build_extended()
            .expect("Failed to build");
        let json = serde_json::to_string(&extended).expect("Failed to serialize");
        let parsed: ExtendedQuery = serde_json::from_str(&json).expect("Failed to deserialize");
        assert_eq!(parsed.query.text, "test");
        assert_eq!(parsed.query.top_k, 5);
        assert_eq!(parsed.cache_key, Some("key".to_string()));
        assert_eq!(parsed.layer_hints, Some(LayerHints::echo_only()));
    }
}