use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// How [`PromptExpanderAgent::expand_sync`] should expand a prompt.
///
/// Serialized in `snake_case` (e.g. `"basic"`). `Auto` is the `Default` variant;
/// `expand_sync` resolves it to a concrete strategy via
/// [`PromptExpanderAgent::detect_strategy`] before expanding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ExpandStrategy {
/// Add context and clarity while keeping the original intent (`ExpandPrompts::BASIC`).
Basic,
/// Comprehensive request: objectives, context, constraints, quality criteria, examples.
Detailed,
/// Reorganize into Objective / Context / Requirements / Output Format / Success Criteria.
Structured,
/// Add imaginative elements and creative angles while keeping the core intent.
Creative,
/// Let the agent pick a strategy from the prompt text; this is the default.
#[default]
Auto,
}
impl std::fmt::Display for ExpandStrategy {
    /// Writes the canonical lowercase name (same spelling serde uses).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl ExpandStrategy {
    /// Every variant, in declaration order.
    pub fn all() -> Vec<ExpandStrategy> {
        vec![
            ExpandStrategy::Basic,
            ExpandStrategy::Detailed,
            ExpandStrategy::Structured,
            ExpandStrategy::Creative,
            ExpandStrategy::Auto,
        ]
    }

    /// Canonical lowercase name of this strategy (`"basic"`, `"detailed"`, ...).
    ///
    /// Single source of truth for both `Display` and parsing, so the two cannot drift apart.
    pub fn as_str(&self) -> &'static str {
        match self {
            ExpandStrategy::Basic => "basic",
            ExpandStrategy::Detailed => "detailed",
            ExpandStrategy::Structured => "structured",
            ExpandStrategy::Creative => "creative",
            ExpandStrategy::Auto => "auto",
        }
    }

    /// Parses a strategy name case-insensitively, returning `None` for unknown names.
    ///
    /// Kept for backward compatibility; prefer `s.parse::<ExpandStrategy>()`, which goes
    /// through the [`std::str::FromStr`] impl and reports a typed error.
    #[allow(clippy::should_implement_trait)] // inherent Option-returning API predates FromStr
    pub fn from_str(s: &str) -> Option<Self> {
        Self::parse_name(s)
    }

    /// Shared lookup used by both the inherent `from_str` and the `FromStr` impl.
    fn parse_name(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        Self::all()
            .into_iter()
            .find(|strategy| strategy.as_str() == wanted)
    }
}

/// Error returned when parsing an unrecognized [`ExpandStrategy`] name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseExpandStrategyError {
    input: String,
}

impl ParseExpandStrategyError {
    /// The text that failed to parse.
    pub fn input(&self) -> &str {
        &self.input
    }
}

impl std::fmt::Display for ParseExpandStrategyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unknown expand strategy: {:?}", self.input)
    }
}

impl std::error::Error for ParseExpandStrategyError {}

impl std::str::FromStr for ExpandStrategy {
    type Err = ParseExpandStrategyError;

    /// Case-insensitive parse of a strategy name (see [`ExpandStrategy::as_str`]).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_name(s).ok_or_else(|| ParseExpandStrategyError {
            input: s.to_string(),
        })
    }
}
/// Outcome of a prompt expansion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpandResult {
/// Prompt exactly as supplied by the caller.
pub original_prompt: String,
/// Expanded text produced for the prompt.
pub expanded_prompt: String,
/// Strategy that was applied. When produced by `PromptExpanderAgent::expand_sync` this is
/// already resolved (never `Auto`), because `Auto` is replaced by `detect_strategy` first.
pub strategy_used: ExpandStrategy,
/// Free-form extra data; `expand_sync` stores `"expansion_prompt"` and `"model"` here.
pub metadata: HashMap<String, serde_json::Value>,
}
impl ExpandResult {
    /// Creates a result with empty metadata.
    pub fn new(
        original: impl Into<String>,
        expanded: impl Into<String>,
        strategy: ExpandStrategy,
    ) -> Self {
        Self {
            original_prompt: original.into(),
            expanded_prompt: expanded.into(),
            strategy_used: strategy,
            metadata: HashMap::new(),
        }
    }

    /// Adds (or overwrites) one metadata entry, builder style.
    pub fn with_metadata(
        mut self,
        key: impl Into<String>,
        value: impl Into<serde_json::Value>,
    ) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Ratio of expanded length to original length, measured in characters
    /// (Unicode scalar values).
    ///
    /// Characters rather than bytes are counted so that prompts in multi-byte scripts
    /// (e.g. CJK, 3 bytes per char in UTF-8) mixed with ASCII expansion text yield a
    /// meaningful ratio. Returns `0.0` when the original prompt is empty, to avoid
    /// dividing by zero.
    pub fn expansion_ratio(&self) -> f64 {
        let original_len = self.original_prompt.chars().count();
        if original_len == 0 {
            return 0.0;
        }
        self.expanded_prompt.chars().count() as f64 / original_len as f64
    }
}
/// Prompt templates for each [`ExpandStrategy`].
///
/// Every template contains a literal `{prompt}` placeholder, which
/// [`PromptExpanderAgent::expand_sync`] replaces with the user's prompt.
pub struct ExpandPrompts;

impl ExpandPrompts {
    /// Template for [`ExpandStrategy::Basic`].
    pub const BASIC: &'static str = r#"Expand the following prompt with additional context and clarity while maintaining the original intent. Add relevant details that would help an AI assistant understand and respond better.
Original prompt: {prompt}
Expanded prompt:"#;

    /// Template for [`ExpandStrategy::Detailed`].
    pub const DETAILED: &'static str = r#"Transform the following prompt into a comprehensive, detailed request. Include:
- Clear objectives and expected outcomes
- Relevant context and background information
- Specific requirements or constraints
- Quality criteria for the response
- Any relevant examples or references
Original prompt: {prompt}
Detailed expanded prompt:"#;

    /// Template for [`ExpandStrategy::Structured`].
    pub const STRUCTURED: &'static str = r#"Restructure the following prompt into a well-organized format with clear sections:
1. **Objective**: What needs to be accomplished
2. **Context**: Background information
3. **Requirements**: Specific needs and constraints
4. **Output Format**: Expected format of the response
5. **Success Criteria**: How to evaluate the response
Original prompt: {prompt}
Structured expanded prompt:"#;

    /// Template for [`ExpandStrategy::Creative`].
    pub const CREATIVE: &'static str = r#"Expand the following prompt with creative and imaginative elements while maintaining the core intent. Add engaging context, vivid descriptions, and innovative angles that could inspire a more creative response.
Original prompt: {prompt}
Creative expanded prompt:"#;

    /// Template for [`ExpandStrategy::Auto`]: asks the model to pick an approach itself.
    pub const AUTO: &'static str = r#"Analyze the following prompt and expand it using the most appropriate strategy. Consider:
- The type of task (creative, technical, analytical, etc.)
- The level of detail needed
- The expected output format
First, briefly identify the best expansion approach, then provide the expanded prompt.
Original prompt: {prompt}
Analysis and expanded prompt:"#;

    /// Looks up the template associated with `strategy`.
    pub fn get(strategy: ExpandStrategy) -> &'static str {
        match strategy {
            ExpandStrategy::Auto => ExpandPrompts::AUTO,
            ExpandStrategy::Creative => ExpandPrompts::CREATIVE,
            ExpandStrategy::Structured => ExpandPrompts::STRUCTURED,
            ExpandStrategy::Detailed => ExpandPrompts::DETAILED,
            ExpandStrategy::Basic => ExpandPrompts::BASIC,
        }
    }
}
/// Settings for a [`PromptExpanderAgent`], normally assembled through
/// [`PromptExpanderAgentBuilder`]; defaults come from the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptExpanderConfig {
/// Display name returned by `PromptExpanderAgent::name`.
pub name: String,
/// Model identifier; `expand_sync` records it under the `"model"` metadata key.
pub model: String,
/// Optional extra instructions. NOTE(review): not read by `expand_sync`.
pub instructions: Option<String>,
/// Set by the builder's `verbose()`. NOTE(review): not read by `expand_sync`.
pub verbose: bool,
/// Presumably the sampling temperature for a model call; not read by `expand_sync`.
pub temperature: f32,
/// Presumably the response token budget for a model call; not read by `expand_sync`.
pub max_tokens: usize,
}
impl Default for PromptExpanderConfig {
fn default() -> Self {
Self {
name: "PromptExpanderAgent".to_string(),
model: "gpt-4o-mini".to_string(),
instructions: None,
verbose: false,
temperature: 0.7,
max_tokens: 1000,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct PromptExpanderAgentBuilder {
config: PromptExpanderConfig,
}
impl PromptExpanderAgentBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.config.name = name.into();
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.config.model = model.into();
self
}
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
self.config.instructions = Some(instructions.into());
self
}
pub fn verbose(mut self) -> Self {
self.config.verbose = true;
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.config.temperature = temp;
self
}
pub fn max_tokens(mut self, tokens: usize) -> Self {
self.config.max_tokens = tokens;
self
}
pub fn build(self) -> PromptExpanderAgent {
PromptExpanderAgent {
config: self.config,
}
}
}
/// Expands terse prompts into richer ones using the [`ExpandStrategy`] templates.
///
/// `expand_sync` makes no model call. It builds the LLM-ready expansion prompt, stores it
/// in the result metadata, and returns a templated expansion of the input.
#[derive(Debug, Clone)]
pub struct PromptExpanderAgent {
    config: PromptExpanderConfig,
}

impl PromptExpanderAgent {
    /// Returns a builder (not an agent); call `.build()` to finish.
    pub fn new() -> PromptExpanderAgentBuilder {
        PromptExpanderAgentBuilder::new()
    }

    /// Configured display name.
    pub fn name(&self) -> &str {
        self.config.name.as_str()
    }

    /// Configured model identifier.
    pub fn model(&self) -> &str {
        self.config.model.as_str()
    }

    /// Picks a concrete strategy from keyword cues in `prompt` (case-insensitive substring
    /// match).
    ///
    /// Precedence is creative cues, then structured cues, then detailed cues. A prompt with
    /// fewer than 10 whitespace-separated words also gets `Detailed`; anything else is `Basic`.
    /// Never returns `Auto`.
    pub fn detect_strategy(&self, prompt: &str) -> ExpandStrategy {
        const CREATIVE_CUES: &[&str] = &["creative", "story", "imagine", "write a poem", "fiction"];
        const STRUCTURED_CUES: &[&str] = &["analyze", "compare", "evaluate", "report", "document"];
        const DETAILED_CUES: &[&str] = &["explain", "describe in detail", "comprehensive", "thorough"];

        let lowered = prompt.to_lowercase();
        let mentions = |cues: &[&str]| cues.iter().any(|cue| lowered.contains(*cue));

        if mentions(CREATIVE_CUES) {
            ExpandStrategy::Creative
        } else if mentions(STRUCTURED_CUES) {
            ExpandStrategy::Structured
        } else if mentions(DETAILED_CUES) || prompt.split_whitespace().count() < 10 {
            // Short prompts benefit from the detailed template too.
            ExpandStrategy::Detailed
        } else {
            ExpandStrategy::Basic
        }
    }

    /// Expands `prompt` offline.
    ///
    /// `Auto` is resolved through [`Self::detect_strategy`]. The filled-in template, plus
    /// `context` when given, is stored under the `"expansion_prompt"` metadata key, and the
    /// configured model under `"model"`.
    pub fn expand_sync(
        &self,
        prompt: &str,
        strategy: ExpandStrategy,
        context: Option<&str>,
    ) -> ExpandResult {
        let chosen = match strategy {
            ExpandStrategy::Auto => self.detect_strategy(prompt),
            explicit => explicit,
        };

        let template = ExpandPrompts::get(chosen).replace("{prompt}", prompt);
        let expansion_prompt = match context {
            Some(extra) => format!("{template}\n\nAdditional context: {extra}"),
            None => template,
        };

        let expanded = format!(
            "## Expanded Prompt\n\n{prompt}\n\n### Original Intent\n{prompt}\n\n### Additional Context\nThis prompt has been expanded using the {chosen} strategy."
        );

        ExpandResult::new(prompt, expanded, chosen)
            .with_metadata("expansion_prompt", expansion_prompt)
            .with_metadata("model", self.config.model.clone())
    }
}

impl Default for PromptExpanderAgent {
    /// Agent built from the default configuration.
    fn default() -> Self {
        PromptExpanderAgentBuilder::new().build()
    }
}
/// How [`QueryRewriterAgent::rewrite_sync`] should rewrite a retrieval query.
///
/// Serialized in `snake_case` (e.g. `"step_back"`). `Auto` is the `Default` variant;
/// `rewrite_sync` resolves it via [`QueryRewriterAgent::detect_strategy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RewriteStrategy {
/// Clarify the query for retrieval.
Basic,
/// HyDE: pair the query with a hypothetical answer document.
Hyde,
/// Add a broader "step-back" question alongside the original query.
StepBack,
/// Decompose the query into simpler sub-queries.
SubQueries,
/// Produce several variations of the query.
MultiQuery,
/// Make a follow-up query self-contained using chat history.
Contextual,
/// Let the agent pick a strategy; this is the default.
#[default]
Auto,
}
impl std::fmt::Display for RewriteStrategy {
    /// Writes the canonical snake_case name (same spelling serde uses).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

impl RewriteStrategy {
    /// Every variant, in declaration order.
    pub fn all() -> Vec<RewriteStrategy> {
        vec![
            RewriteStrategy::Basic,
            RewriteStrategy::Hyde,
            RewriteStrategy::StepBack,
            RewriteStrategy::SubQueries,
            RewriteStrategy::MultiQuery,
            RewriteStrategy::Contextual,
            RewriteStrategy::Auto,
        ]
    }

    /// Canonical snake_case name of this strategy (`"step_back"`, `"multi_query"`, ...).
    ///
    /// Single source of truth for both `Display` and parsing.
    pub fn as_str(&self) -> &'static str {
        match self {
            RewriteStrategy::Basic => "basic",
            RewriteStrategy::Hyde => "hyde",
            RewriteStrategy::StepBack => "step_back",
            RewriteStrategy::SubQueries => "sub_queries",
            RewriteStrategy::MultiQuery => "multi_query",
            RewriteStrategy::Contextual => "contextual",
            RewriteStrategy::Auto => "auto",
        }
    }

    /// Parses a strategy name case-insensitively; `None` for unknown names.
    ///
    /// Accepts the canonical names plus the underscore-free aliases `stepback`,
    /// `subqueries` and `multiquery`. Kept for backward compatibility; prefer
    /// `s.parse::<RewriteStrategy>()` via the [`std::str::FromStr`] impl.
    #[allow(clippy::should_implement_trait)] // inherent Option-returning API predates FromStr
    pub fn from_str(s: &str) -> Option<Self> {
        Self::parse_name(s)
    }

    /// Shared lookup used by both the inherent `from_str` and the `FromStr` impl.
    fn parse_name(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        let canonical = match wanted.as_str() {
            "stepback" => "step_back",
            "subqueries" => "sub_queries",
            "multiquery" => "multi_query",
            other => other,
        };
        Self::all()
            .into_iter()
            .find(|strategy| strategy.as_str() == canonical)
    }
}

/// Error returned when parsing an unrecognized [`RewriteStrategy`] name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseRewriteStrategyError {
    input: String,
}

impl ParseRewriteStrategyError {
    /// The text that failed to parse.
    pub fn input(&self) -> &str {
        &self.input
    }
}

impl std::fmt::Display for ParseRewriteStrategyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "unknown rewrite strategy: {:?}", self.input)
    }
}

impl std::error::Error for ParseRewriteStrategyError {}

impl std::str::FromStr for RewriteStrategy {
    type Err = ParseRewriteStrategyError;

    /// Case-insensitive parse, accepting the same aliases as [`RewriteStrategy::from_str`].
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_name(s).ok_or_else(|| ParseRewriteStrategyError {
            input: s.to_string(),
        })
    }
}
/// Outcome of a query rewrite.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RewriteResult {
/// Query as supplied by the caller; `rewrite_sync` stores it before abbreviation expansion.
pub original_query: String,
/// Rewritten queries; `primary_query` returns the first one.
pub rewritten_queries: Vec<String>,
/// Strategy that was applied (already resolved from `Auto` when produced by `rewrite_sync`).
pub strategy_used: RewriteStrategy,
/// Filled in by `rewrite_sync` for `Hyde` rewrites.
pub hypothetical_document: Option<String>,
/// Filled in by `rewrite_sync` for `StepBack` rewrites.
pub step_back_question: Option<String>,
/// Filled in for `SubQueries` rewrites; `all_queries` appends these after `rewritten_queries`.
pub sub_queries: Option<Vec<String>>,
/// Free-form extra data; `rewrite_sync` stores `"model"` and, when given, `"context"`.
pub metadata: HashMap<String, serde_json::Value>,
}
impl RewriteResult {
pub fn new(
original: impl Into<String>,
rewritten: Vec<String>,
strategy: RewriteStrategy,
) -> Self {
Self {
original_query: original.into(),
rewritten_queries: rewritten,
strategy_used: strategy,
hypothetical_document: None,
step_back_question: None,
sub_queries: None,
metadata: HashMap::new(),
}
}
pub fn with_hypothetical_document(mut self, doc: impl Into<String>) -> Self {
self.hypothetical_document = Some(doc.into());
self
}
pub fn with_step_back_question(mut self, question: impl Into<String>) -> Self {
self.step_back_question = Some(question.into());
self
}
pub fn with_sub_queries(mut self, queries: Vec<String>) -> Self {
self.sub_queries = Some(queries);
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
pub fn primary_query(&self) -> Option<&str> {
self.rewritten_queries.first().map(|s| s.as_str())
}
pub fn all_queries(&self) -> Vec<&str> {
let mut queries: Vec<&str> = self.rewritten_queries.iter().map(|s| s.as_str()).collect();
if let Some(sub) = &self.sub_queries {
queries.extend(sub.iter().map(|s| s.as_str()));
}
queries
}
}
/// Prompt templates for each [`RewriteStrategy`].
///
/// Every template has a `{query}` placeholder. `MULTI_QUERY` additionally has
/// `{num_queries}`, and `CONTEXTUAL` has `{chat_history}`.
pub struct RewritePrompts;

impl RewritePrompts {
    /// Template for [`RewriteStrategy::Basic`].
    pub const BASIC: &'static str = r#"Rewrite the following query to be clearer and more specific for information retrieval. Maintain the original intent while improving clarity.
Original query: {query}
Rewritten query:"#;

    /// Template for [`RewriteStrategy::Hyde`].
    pub const HYDE: &'static str = r#"Given the following query, write a hypothetical document that would perfectly answer this query. This document will be used to find similar real documents.
Query: {query}
Hypothetical document that answers this query:"#;

    /// Template for [`RewriteStrategy::StepBack`].
    pub const STEP_BACK: &'static str = r#"Given the following specific query, generate a more general "step-back" question that would provide broader context helpful for answering the original query.
Original query: {query}
Step-back question (more general):"#;

    /// Template for [`RewriteStrategy::SubQueries`]; ends with `1.` to prime a numbered list.
    pub const SUB_QUERIES: &'static str = r#"Break down the following complex query into simpler sub-queries that together would help answer the original question. Generate 2-4 sub-queries.
Original query: {query}
Sub-queries (one per line):
1."#;

    /// Template for [`RewriteStrategy::MultiQuery`]; ends with `1.` to prime a numbered list.
    pub const MULTI_QUERY: &'static str = r#"Generate {num_queries} different variations of the following query. Each variation should approach the question from a different angle while maintaining the same intent.
Original query: {query}
Query variations:
1."#;

    /// Template for [`RewriteStrategy::Contextual`].
    pub const CONTEXTUAL: &'static str = r#"Given the chat history and the current query, rewrite the query to be self-contained and clear without requiring the chat history for context.
Chat history:
{chat_history}
Current query: {query}
Self-contained rewritten query:"#;

    /// Template for [`RewriteStrategy::Auto`]: asks the model to choose a strategy itself.
    pub const AUTO: &'static str = r#"Analyze the following query and determine the best rewriting strategy, then apply it.
Query: {query}
First identify if this query would benefit from:
- Basic rewriting (simple clarification)
- HyDE (generating a hypothetical answer document)
- Step-back (asking a broader question first)
- Sub-queries (breaking into smaller questions)
- Multi-query (generating variations)
Then provide the rewritten query/queries."#;

    /// Looks up the template associated with `strategy`.
    pub fn get(strategy: RewriteStrategy) -> &'static str {
        match strategy {
            RewriteStrategy::Auto => RewritePrompts::AUTO,
            RewriteStrategy::Contextual => RewritePrompts::CONTEXTUAL,
            RewriteStrategy::MultiQuery => RewritePrompts::MULTI_QUERY,
            RewriteStrategy::SubQueries => RewritePrompts::SUB_QUERIES,
            RewriteStrategy::StepBack => RewritePrompts::STEP_BACK,
            RewriteStrategy::Hyde => RewritePrompts::HYDE,
            RewriteStrategy::Basic => RewritePrompts::BASIC,
        }
    }
}
/// Settings for a [`QueryRewriterAgent`], normally assembled through
/// [`QueryRewriterAgentBuilder`]; defaults come from the `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryRewriterConfig {
/// Display name returned by `QueryRewriterAgent::name`.
pub name: String,
/// Model identifier; `rewrite_sync` records it under the `"model"` metadata key.
pub model: String,
/// Optional extra instructions. NOTE(review): not read by `rewrite_sync`.
pub instructions: Option<String>,
/// Set by the builder's `verbose()`. NOTE(review): not read by `rewrite_sync`.
pub verbose: bool,
/// Number of MultiQuery variations `rewrite_sync` uses when `num_queries` is `None`.
pub max_queries: usize,
/// Abbreviation -> expansion pairs applied by `QueryRewriterAgent::expand_abbreviations`
/// before a query is rewritten.
pub abbreviations: HashMap<String, String>,
/// Presumably the sampling temperature for a model call; not read by `rewrite_sync`.
pub temperature: f32,
/// Presumably the response token budget for a model call; not read by `rewrite_sync`.
pub max_tokens: usize,
}
impl Default for QueryRewriterConfig {
fn default() -> Self {
Self {
name: "QueryRewriterAgent".to_string(),
model: "gpt-4o-mini".to_string(),
instructions: None,
verbose: false,
max_queries: 5,
abbreviations: HashMap::new(),
temperature: 0.3,
max_tokens: 500,
}
}
}
#[derive(Debug, Clone, Default)]
pub struct QueryRewriterAgentBuilder {
config: QueryRewriterConfig,
}
impl QueryRewriterAgentBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn name(mut self, name: impl Into<String>) -> Self {
self.config.name = name.into();
self
}
pub fn model(mut self, model: impl Into<String>) -> Self {
self.config.model = model.into();
self
}
pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
self.config.instructions = Some(instructions.into());
self
}
pub fn verbose(mut self) -> Self {
self.config.verbose = true;
self
}
pub fn max_queries(mut self, max: usize) -> Self {
self.config.max_queries = max;
self
}
pub fn abbreviation(mut self, abbrev: impl Into<String>, expansion: impl Into<String>) -> Self {
self.config.abbreviations.insert(abbrev.into(), expansion.into());
self
}
pub fn abbreviations(mut self, abbrevs: HashMap<String, String>) -> Self {
self.config.abbreviations = abbrevs;
self
}
pub fn temperature(mut self, temp: f32) -> Self {
self.config.temperature = temp;
self
}
pub fn max_tokens(mut self, tokens: usize) -> Self {
self.config.max_tokens = tokens;
self
}
pub fn build(self) -> QueryRewriterAgent {
QueryRewriterAgent {
config: self.config,
}
}
}
/// Rewrites retrieval queries using one of the [`RewriteStrategy`] variants.
///
/// `rewrite_sync` makes no model call; it produces deterministic, templated rewrites.
#[derive(Debug, Clone)]
pub struct QueryRewriterAgent {
    config: QueryRewriterConfig,
}

impl QueryRewriterAgent {
    /// Returns a builder (not an agent); call `.build()` to finish.
    pub fn new() -> QueryRewriterAgentBuilder {
        QueryRewriterAgentBuilder::new()
    }

    /// Configured display name.
    pub fn name(&self) -> &str {
        &self.config.name
    }

    /// Configured model identifier.
    pub fn model(&self) -> &str {
        &self.config.model
    }

    /// Replaces configured abbreviations in `query` with their expansions.
    ///
    /// Matching is case-sensitive and whole-word. An abbreviation is only replaced when it is
    /// not glued to another word character (letter, digit or `_`) on an edge where the
    /// abbreviation itself is word-like. So `ML` expands in "What is ML?" but not inside
    /// "HTML", while `C++` still expands before `!`.
    ///
    /// The scan is a single left-to-right pass, so expansions are never re-expanded. At each
    /// position the longest matching abbreviation wins. Both properties make the result
    /// independent of `HashMap` iteration order. Empty keys are ignored; previously
    /// `str::replace("", ..)` would have inserted the expansion between every character.
    pub fn expand_abbreviations(&self, query: &str) -> String {
        let mut candidates: Vec<(&str, &str)> = self
            .config
            .abbreviations
            .iter()
            .filter(|(abbrev, _)| !abbrev.is_empty())
            .map(|(abbrev, expansion)| (abbrev.as_str(), expansion.as_str()))
            .collect();
        if candidates.is_empty() {
            return query.to_string();
        }
        // Longest first so "MLOps" beats "ML"; key order breaks ties deterministically.
        candidates.sort_unstable_by(|(a, _), (b, _)| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));

        let mut expanded = String::with_capacity(query.len());
        let mut rest = query;
        // Last character of the *original* text consumed so far, for the left word boundary.
        let mut previous: Option<char> = None;
        while let Some(current) = rest.chars().next() {
            let hit = candidates
                .iter()
                .find(|(abbrev, _)| Self::is_whole_word_at(rest, abbrev, previous));
            match hit {
                Some(&(abbrev, expansion)) => {
                    expanded.push_str(expansion);
                    previous = abbrev.chars().next_back();
                    rest = &rest[abbrev.len()..];
                }
                None => {
                    expanded.push(current);
                    previous = Some(current);
                    rest = &rest[current.len_utf8()..];
                }
            }
        }
        expanded
    }

    /// Whether `abbrev` starts `rest` as a whole word, given the character before `rest`.
    fn is_whole_word_at(rest: &str, abbrev: &str, previous: Option<char>) -> bool {
        let after = match rest.strip_prefix(abbrev) {
            Some(after) => after,
            None => return false,
        };
        let is_word = |c: char| c.is_alphanumeric() || c == '_';
        let starts_with_word_char = abbrev.chars().next().map_or(false, is_word);
        let ends_with_word_char = abbrev.chars().next_back().map_or(false, is_word);
        let glued_before = starts_with_word_char && previous.map_or(false, is_word);
        let glued_after = ends_with_word_char && after.chars().next().map_or(false, is_word);
        !glued_before && !glued_after
    }

    /// Picks a concrete strategy from simple heuristics. Never returns `Auto`.
    ///
    /// The checks run in this order:
    /// 1. Chat history present: `Contextual`.
    /// 2. Compound query — contains " and " / " or ", has more than 20 words, or has more
    ///    than one `?`: `SubQueries`.
    /// 3. Definitional or explanatory prefix/keyword: `StepBack`.
    /// 4. Who/when/where prefix or "specific": `Hyde`.
    /// 5. Fewer than 5 words: `MultiQuery`.
    /// 6. Otherwise: `Basic`.
    pub fn detect_strategy(&self, query: &str, has_chat_history: bool) -> RewriteStrategy {
        if has_chat_history {
            return RewriteStrategy::Contextual;
        }
        let query_lower = query.to_lowercase();
        let word_count = query.split_whitespace().count();

        if query_lower.contains(" and ")
            || query_lower.contains(" or ")
            || word_count > 20
            || query.matches('?').count() > 1
        {
            return RewriteStrategy::SubQueries;
        }
        if query_lower.starts_with("what is")
            || query_lower.starts_with("how does")
            || query_lower.starts_with("why")
            || query_lower.contains("explain")
        {
            return RewriteStrategy::StepBack;
        }
        if query_lower.starts_with("who")
            || query_lower.starts_with("when")
            || query_lower.starts_with("where")
            || query_lower.contains("specific")
        {
            return RewriteStrategy::Hyde;
        }
        if word_count < 5 {
            return RewriteStrategy::MultiQuery;
        }
        RewriteStrategy::Basic
    }

    /// Rewrites `query` offline.
    ///
    /// Abbreviations are expanded first, and `Auto` is resolved via
    /// [`Self::detect_strategy`], using whether `chat_history` is non-empty. For
    /// `MultiQuery`, the number of variations is `num_queries` (default
    /// `config.max_queries`), capped at `config.max_queries`. `context`, when given, is
    /// stored under the `"context"` metadata key; the model is always stored under
    /// `"model"`.
    pub fn rewrite_sync(
        &self,
        query: &str,
        strategy: RewriteStrategy,
        chat_history: Option<&[HashMap<String, String>]>,
        context: Option<&str>,
        num_queries: Option<usize>,
    ) -> RewriteResult {
        let expanded_query = self.expand_abbreviations(query);
        let has_history = chat_history.map_or(false, |history| !history.is_empty());
        let actual_strategy = if strategy == RewriteStrategy::Auto {
            self.detect_strategy(&expanded_query, has_history)
        } else {
            strategy
        };
        // `max_queries` is the configured upper bound. The cap used to be a hard-coded 5
        // (the default), which made `.max_queries(n)` with n > 5 ineffective.
        let max_queries = self.config.max_queries;
        let num = num_queries.map_or(max_queries, |requested| requested.min(max_queries));

        let mut result = match actual_strategy {
            RewriteStrategy::Basic => {
                let rewritten = format!("{} (clarified and optimized for retrieval)", expanded_query);
                RewriteResult::new(query, vec![rewritten], actual_strategy)
            }
            RewriteStrategy::Hyde => {
                let hypothetical = format!(
                    "This document discusses {}. It provides detailed information about the topic, \
                     including key concepts, examples, and practical applications.",
                    expanded_query
                );
                RewriteResult::new(query, vec![expanded_query.clone()], actual_strategy)
                    .with_hypothetical_document(hypothetical)
            }
            RewriteStrategy::StepBack => {
                let step_back = format!("What are the fundamental concepts related to: {}", expanded_query);
                RewriteResult::new(query, vec![expanded_query.clone(), step_back.clone()], actual_strategy)
                    .with_step_back_question(step_back)
            }
            RewriteStrategy::SubQueries => {
                // First three words serve as the topic of every generated sub-query.
                let head = expanded_query
                    .split_whitespace()
                    .take(3)
                    .collect::<Vec<_>>()
                    .join(" ");
                let sub = vec![
                    format!("What is {}?", head),
                    format!("How does {} work?", head),
                    format!("Examples of {}", head),
                ];
                RewriteResult::new(query, vec![expanded_query.clone()], actual_strategy)
                    .with_sub_queries(sub)
            }
            RewriteStrategy::MultiQuery => {
                let variations: Vec<String> = (0..num)
                    .map(|i| format!("{} (variation {})", expanded_query, i + 1))
                    .collect();
                RewriteResult::new(query, variations, actual_strategy)
            }
            RewriteStrategy::Contextual => {
                let rewritten = format!("{} (contextualized from chat history)", expanded_query);
                RewriteResult::new(query, vec![rewritten], actual_strategy)
            }
            // Unreachable in practice: `Auto` was resolved above and `detect_strategy`
            // never returns it.
            RewriteStrategy::Auto => RewriteResult::new(query, vec![expanded_query], actual_strategy),
        };
        if let Some(ctx) = context {
            result = result.with_metadata("context", ctx.to_string());
        }
        result.with_metadata("model", self.config.model.clone())
    }
}

impl Default for QueryRewriterAgent {
    /// Agent built from the default configuration.
    fn default() -> Self {
        Self::new().build()
    }
}
/// Unit tests for the expander and rewriter agents and their supporting types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expand_strategy_display() {
        for (strategy, text) in [
            (ExpandStrategy::Basic, "basic"),
            (ExpandStrategy::Detailed, "detailed"),
            (ExpandStrategy::Auto, "auto"),
        ] {
            assert_eq!(strategy.to_string(), text);
        }
    }

    #[test]
    fn test_expand_strategy_from_str() {
        for (input, expected) in [
            ("basic", Some(ExpandStrategy::Basic)),
            ("DETAILED", Some(ExpandStrategy::Detailed)),
            ("invalid", None),
        ] {
            assert_eq!(ExpandStrategy::from_str(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_expand_result() {
        let outcome = ExpandResult::new("test", "expanded test", ExpandStrategy::Basic)
            .with_metadata("key", "value");
        assert_eq!(outcome.original_prompt, "test");
        assert_eq!(outcome.expanded_prompt, "expanded test");
        assert_eq!(outcome.strategy_used, ExpandStrategy::Basic);
        assert!(outcome.metadata.contains_key("key"));
    }

    #[test]
    fn test_expand_result_ratio() {
        let grown = ExpandResult::new("test", "expanded test content", ExpandStrategy::Basic);
        assert!(grown.expansion_ratio() > 1.0);

        let from_empty = ExpandResult::new("", "expanded", ExpandStrategy::Basic);
        assert_eq!(from_empty.expansion_ratio(), 0.0);
    }

    #[test]
    fn test_prompt_expander_builder() {
        let expander = PromptExpanderAgent::new()
            .name("TestExpander")
            .model("gpt-4")
            .temperature(0.5)
            .verbose()
            .build();
        assert_eq!(expander.name(), "TestExpander");
        assert_eq!(expander.model(), "gpt-4");
    }

    #[test]
    fn test_prompt_expander_detect_strategy() {
        let expander = PromptExpanderAgent::default();
        for (prompt, expected) in [
            ("Write a creative story about dragons", ExpandStrategy::Creative),
            ("Analyze the market trends", ExpandStrategy::Structured),
            ("Hello", ExpandStrategy::Detailed),
        ] {
            assert_eq!(expander.detect_strategy(prompt), expected, "prompt: {prompt}");
        }
    }

    #[test]
    fn test_prompt_expander_expand_sync() {
        let expander = PromptExpanderAgent::default();
        let outcome = expander.expand_sync("Write a blog post", ExpandStrategy::Basic, None);
        assert_eq!(outcome.original_prompt, "Write a blog post");
        assert!(!outcome.expanded_prompt.is_empty());
        assert_eq!(outcome.strategy_used, ExpandStrategy::Basic);
    }

    #[test]
    fn test_rewrite_strategy_display() {
        for (strategy, text) in [
            (RewriteStrategy::Basic, "basic"),
            (RewriteStrategy::Hyde, "hyde"),
            (RewriteStrategy::StepBack, "step_back"),
        ] {
            assert_eq!(strategy.to_string(), text);
        }
    }

    #[test]
    fn test_rewrite_strategy_from_str() {
        for (input, expected) in [
            ("basic", Some(RewriteStrategy::Basic)),
            ("hyde", Some(RewriteStrategy::Hyde)),
            ("step_back", Some(RewriteStrategy::StepBack)),
            ("stepback", Some(RewriteStrategy::StepBack)),
            ("invalid", None),
        ] {
            assert_eq!(RewriteStrategy::from_str(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_rewrite_result() {
        let outcome = RewriteResult::new(
            "test query",
            vec!["rewritten".to_string()],
            RewriteStrategy::Basic,
        )
        .with_hypothetical_document("hypothetical doc")
        .with_step_back_question("broader question")
        .with_sub_queries(vec!["sub1".to_string(), "sub2".to_string()])
        .with_metadata("key", "value");

        assert_eq!(outcome.original_query, "test query");
        assert_eq!(outcome.primary_query(), Some("rewritten"));
        assert!(outcome.hypothetical_document.is_some());
        assert!(outcome.step_back_question.is_some());
        assert!(outcome.sub_queries.is_some());
    }

    #[test]
    fn test_rewrite_result_all_queries() {
        let outcome = RewriteResult::new(
            "test",
            vec!["q1".to_string(), "q2".to_string()],
            RewriteStrategy::Basic,
        )
        .with_sub_queries(vec!["sub1".to_string(), "sub2".to_string()]);
        assert_eq!(outcome.all_queries().len(), 4);
    }

    #[test]
    fn test_query_rewriter_builder() {
        let rewriter = QueryRewriterAgent::new()
            .name("TestRewriter")
            .model("gpt-4")
            .max_queries(3)
            .abbreviation("ML", "Machine Learning")
            .temperature(0.2)
            .verbose()
            .build();
        assert_eq!(rewriter.name(), "TestRewriter");
        assert_eq!(rewriter.model(), "gpt-4");
    }

    #[test]
    fn test_query_rewriter_expand_abbreviations() {
        let rewriter = QueryRewriterAgent::new()
            .abbreviation("ML", "Machine Learning")
            .abbreviation("AI", "Artificial Intelligence")
            .build();
        let text = rewriter.expand_abbreviations("What is ML and AI?");
        for needle in ["Machine Learning", "Artificial Intelligence"] {
            assert!(text.contains(needle), "missing {needle:?} in {text:?}");
        }
    }

    #[test]
    fn test_query_rewriter_detect_strategy() {
        let rewriter = QueryRewriterAgent::default();
        for (query, has_history, expected) in [
            ("What is X and how does Y relate to Z?", false, RewriteStrategy::SubQueries),
            ("What is machine learning?", false, RewriteStrategy::StepBack),
            ("Who invented the telephone?", false, RewriteStrategy::Hyde),
            ("Tell me more", true, RewriteStrategy::Contextual),
            ("Python", false, RewriteStrategy::MultiQuery),
        ] {
            assert_eq!(rewriter.detect_strategy(query, has_history), expected, "query: {query}");
        }
    }

    #[test]
    fn test_query_rewriter_rewrite_sync() {
        let rewriter = QueryRewriterAgent::default();
        let outcome = rewriter.rewrite_sync("What is Rust?", RewriteStrategy::Basic, None, None, None);
        assert_eq!(outcome.original_query, "What is Rust?");
        assert!(!outcome.rewritten_queries.is_empty());
        assert_eq!(outcome.strategy_used, RewriteStrategy::Basic);
    }

    #[test]
    fn test_query_rewriter_hyde_strategy() {
        let rewriter = QueryRewriterAgent::default();
        let outcome = rewriter.rewrite_sync("What is Rust?", RewriteStrategy::Hyde, None, None, None);
        assert!(outcome.hypothetical_document.is_some());
    }

    #[test]
    fn test_query_rewriter_sub_queries_strategy() {
        let rewriter = QueryRewriterAgent::default();
        let outcome = rewriter.rewrite_sync("Complex query", RewriteStrategy::SubQueries, None, None, None);
        let subs = outcome.sub_queries.as_ref();
        assert!(subs.is_some());
        assert!(!subs.unwrap().is_empty());
    }
}