use genai::Client;
use genai::chat::{ChatMessage as GenAiMessage, ChatRequest};
use crate::graph::CodeGraph;
use crate::rag::embedding::EmbeddingEngine;
use crate::rag::retrieval::{Citation, retrieve, retrieve_structural};
use crate::rag::session::{ChatMessage, ChatRole, SessionStore};
use crate::rag::vector_store::VectorStore;
/// Coarse category of a user query, used to pick a retrieval strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// Lookup-style query naming a concrete symbol/location ("where is X").
    Structural,
    /// Open-ended question about behavior or intent ("how does X work").
    Conceptual,
    /// Neither pattern matched; combine both retrieval strategies.
    Hybrid,
}

/// Classify `query` by case-insensitive prefix matching.
///
/// Structural prefixes are checked before conceptual ones, so e.g.
/// "what calls …" is Structural even though "what does …" is Conceptual.
/// Anything that matches neither list falls back to `QueryKind::Hybrid`.
pub fn classify_query(query: &str) -> QueryKind {
    // Normalize once so the prefix match is case-insensitive.
    let q = query.trim().to_lowercase();

    const STRUCTURAL_PREFIXES: [&str; 7] = [
        "where is",
        "find ",
        "what calls",
        "references to",
        "who uses",
        "show me",
        "locate",
    ];
    // NOTE: "how " and "why " subsume the former "how does" / "why is" /
    // "why does" entries, which were dead (shadowed) and have been removed.
    const CONCEPTUAL_PREFIXES: [&str; 7] = [
        "how ",
        "explain",
        "what does",
        "why ",
        "describe",
        "what is the purpose",
        "summarize",
    ];

    if STRUCTURAL_PREFIXES.iter().any(|p| q.starts_with(p)) {
        QueryKind::Structural
    } else if CONCEPTUAL_PREFIXES.iter().any(|p| q.starts_with(p)) {
        QueryKind::Conceptual
    } else {
        QueryKind::Hybrid
    }
}
/// Result of one RAG question: the LLM's answer plus the retrieval
/// metadata that produced its context.
#[derive(Debug, Clone)]
pub struct RagResponse {
// Answer text from the LLM (empty string when the model returned no text part).
pub answer: String,
// Footnote-style citations ([1], [2], ...) for the retrieved code snippets.
pub citations: Vec<Citation>,
// Names of the retrieval tools/strategies that contributed context.
pub tools_used: Vec<String>,
}
/// Assemble the system prompt for the LLM, embedding `project_stats` into
/// the project-overview section. Output is a single string with the intro,
/// overview, and instruction list separated by blank lines.
pub fn build_system_prompt(project_stats: &str) -> String {
    let intro = "You are a codebase expert assistant. You answer questions about a specific software project using the provided code context extracted from the project's dependency graph and symbol index.";
    let instructions = "Instructions:\n\
        - Answer only from the provided code context. Do not speculate about code not shown.\n\
        - When referencing a specific symbol, file, or code snippet, add a footnote citation like [1], [2], etc.\n\
        - Keep answers concise and developer-focused.\n\
        - If the context does not contain enough information to answer, say so clearly.\n\
        - Use markdown for code snippets.";
    format!("{intro}\n\nProject overview:\n{project_stats}\n\n{instructions}")
}
/// Wrap the user's question with the retrieved codebase context; when no
/// context was retrieved, pass the question through unchanged.
pub fn build_user_prompt(query: &str, retrieval_context: &str) -> String {
    match retrieval_context {
        "" => query.to_string(),
        ctx => format!("Codebase context:\n{ctx}\n\nQuestion: {query}"),
    }
}
/// Marker type grouping the RAG question-answering entry points.
pub struct RagAgent;

impl RagAgent {
    /// Snapshot the stored chat history for `session_id`; an unknown session
    /// yields an empty history.
    fn history_snapshot(session_store: &mut SessionStore, session_id: &str) -> Vec<ChatMessage> {
        session_store
            .peek_history(session_id)
            .map(|h| h.to_vec())
            .unwrap_or_default()
    }

    /// Convert stored history into LLM chat messages and append the new user
    /// turn. System-role entries are dropped: the system prompt is attached
    /// to the request separately via `ChatRequest::with_system`.
    fn to_llm_messages(history: &[ChatMessage], user_message_content: &str) -> Vec<GenAiMessage> {
        let mut messages: Vec<GenAiMessage> = Vec::with_capacity(history.len() + 1);
        for msg in history {
            match msg.role {
                ChatRole::User => messages.push(GenAiMessage::user(&msg.content)),
                ChatRole::Assistant => messages.push(GenAiMessage::assistant(&msg.content)),
                ChatRole::System => {}
            }
        }
        messages.push(GenAiMessage::user(user_message_content));
        messages
    }

    /// Execute the chat request, then persist the user/assistant exchange in
    /// the session store. Returns the assistant's answer text (empty when the
    /// model produced no text part).
    async fn exec_and_record(
        session_store: &mut SessionStore,
        session_id: &str,
        query: &str,
        llm_client: &Client,
        model: &str,
        messages: Vec<GenAiMessage>,
        system_prompt: String,
    ) -> anyhow::Result<String> {
        let request = ChatRequest::new(messages).with_system(system_prompt);
        let response = llm_client.exec_chat(model, request, None).await?;
        let answer = response.first_text().unwrap_or_default().to_string();
        session_store.add_message(session_id, ChatMessage::user(query))?;
        session_store.add_message(session_id, ChatMessage::assistant(&answer))?;
        Ok(answer)
    }

    /// Answer `query` using hybrid retrieval (graph + vector search),
    /// carrying over any prior conversation stored under `session_id`.
    ///
    /// # Errors
    /// Propagates retrieval, LLM, and session-store failures.
    // Thin orchestration entry point; all dependencies are injected by the caller.
    #[allow(clippy::too_many_arguments)]
    pub async fn ask(
        graph: &CodeGraph,
        vector_store: &VectorStore,
        engine: &EmbeddingEngine,
        session_store: &mut SessionStore,
        session_id: &str,
        query: &str,
        llm_client: &Client,
        model: &str,
    ) -> anyhow::Result<RagResponse> {
        let kind = classify_query(query);
        let retrieval = retrieve(graph, vector_store, engine, query, kind).await?;
        let history = Self::history_snapshot(session_store, session_id);
        // NOTE(review): stats placeholder — no caller currently supplies real
        // project stats to the system prompt.
        let system_prompt = build_system_prompt("(codebase stats not available)");
        let user_message_content = build_user_prompt(query, &retrieval.context_text);
        let messages = Self::to_llm_messages(&history, &user_message_content);
        let answer = Self::exec_and_record(
            session_store,
            session_id,
            query,
            llm_client,
            model,
            messages,
            system_prompt,
        )
        .await?;
        Ok(RagResponse {
            answer,
            citations: retrieval.citations,
            tools_used: retrieval.tools_used,
        })
    }

    /// Answer `query` using only the structural graph index — no embeddings
    /// or vector search involved.
    ///
    /// # Errors
    /// Propagates LLM and session-store failures.
    pub async fn ask_structural(
        graph: &CodeGraph,
        session_store: &mut SessionStore,
        session_id: &str,
        query: &str,
        llm_client: &Client,
        model: &str,
    ) -> anyhow::Result<RagResponse> {
        let (context_text, citations, mut tools_used) = retrieve_structural(graph, query);
        tools_used.push("structural-only (no embeddings)".to_string());
        let history = Self::history_snapshot(session_store, session_id);
        let system_prompt = build_system_prompt("(codebase stats not available)");
        let user_message_content = build_user_prompt(query, &context_text);
        let messages = Self::to_llm_messages(&history, &user_message_content);
        let answer = Self::exec_and_record(
            session_store,
            session_id,
            query,
            llm_client,
            model,
            messages,
            system_prompt,
        )
        .await?;
        Ok(RagResponse {
            answer,
            citations,
            tools_used,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that every query in `cases` classifies as `expected`.
    fn assert_classified(cases: &[&str], expected: QueryKind) {
        for query in cases {
            assert_eq!(classify_query(query), expected, "query: {query}");
        }
    }

    #[test]
    fn classify_query_structural_where_is() {
        assert_classified(
            &["where is auth", "where is the authentication handler"],
            QueryKind::Structural,
        );
    }

    #[test]
    fn classify_query_structural_find() {
        assert_classified(
            &["find UserService", "Find all instances of fetch_user"],
            QueryKind::Structural,
        );
    }

    #[test]
    fn classify_query_structural_what_calls() {
        assert_classified(
            &["what calls handleAuth", "What calls the login function"],
            QueryKind::Structural,
        );
    }

    #[test]
    fn classify_query_structural_references_to() {
        assert_classified(&["references to UserService"], QueryKind::Structural);
    }

    #[test]
    fn classify_query_structural_who_uses() {
        assert_classified(&["who uses the cache module"], QueryKind::Structural);
    }

    #[test]
    fn classify_query_conceptual_how_does() {
        assert_classified(
            &[
                "how does the caching system work",
                "how does authentication work",
            ],
            QueryKind::Conceptual,
        );
    }

    #[test]
    fn classify_query_conceptual_explain() {
        assert_classified(
            &["explain the error handling", "Explain the retry mechanism"],
            QueryKind::Conceptual,
        );
    }

    #[test]
    fn classify_query_conceptual_why() {
        assert_classified(
            &[
                "why is this function slow",
                "why does the server crash on startup",
            ],
            QueryKind::Conceptual,
        );
    }

    #[test]
    fn classify_query_conceptual_what_does() {
        assert_classified(&["what does the auth module do"], QueryKind::Conceptual);
    }

    #[test]
    fn classify_query_hybrid_default() {
        assert_classified(
            &["database connection pool", "UserService", "caching layer"],
            QueryKind::Hybrid,
        );
    }

    #[test]
    fn build_system_prompt_includes_instructions() {
        let stats = "10 files, 42 symbols";
        let prompt = build_system_prompt(stats);
        assert!(prompt.contains("codebase"), "prompt should mention codebase");
        let mentions_citations = prompt.contains("[1]") || prompt.contains("[N]");
        assert!(mentions_citations, "prompt should mention citation format");
        assert!(prompt.contains(stats), "prompt should embed project stats");
    }

    #[test]
    fn build_user_prompt_wraps_context_and_query() {
        let prompt = build_user_prompt("what is foo?", "[1] function foo in src/lib.rs:10");
        for (needle, reason) in [
            ("what is foo?", "prompt should contain the query"),
            ("[1] function foo", "prompt should embed retrieval context"),
            ("Codebase context:", "prompt should have context header"),
        ] {
            assert!(prompt.contains(needle), "{reason}");
        }
    }

    #[test]
    fn build_user_prompt_without_context_returns_query() {
        assert_eq!(build_user_prompt("what is foo?", ""), "what is foo?");
    }

    #[test]
    fn citation_has_required_fields() {
        let citation = Citation {
            index: 1,
            file_path: "src/auth.rs".to_string(),
            line_start: 42,
            symbol_name: "authenticate_user".to_string(),
        };
        assert_eq!(
            (
                citation.index,
                citation.file_path.as_str(),
                citation.line_start,
                citation.symbol_name.as_str(),
            ),
            (1, "src/auth.rs", 42, "authenticate_user")
        );
    }
}