use std::sync::Arc;
mod alcove;
mod bridge;
pub use alcove::AlcoveRag;
pub use bridge::BridgeRag;
use crate::config::RagConfig;
/// One piece of retrieved context, as produced by a [`RagAdapter`].
///
/// Derives `Debug`, `Clone` and `PartialEq` so chunks can be logged,
/// duplicated and compared in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct RagChunk {
    /// Static label printed in square brackets in the formatted output.
    /// Presumably the name of the adapter that produced the chunk — confirm
    /// against the adapter implementations.
    pub adapter: &'static str,
    /// Origin label printed next to the adapter label; its length counts
    /// toward the manager's size budget.
    pub source: String,
    /// Chunk text. It is scanned for constraint keywords and emitted as a
    /// `>`-quoted block in the formatted output.
    pub content: String,
    /// Ranking score; chunks with higher scores are returned first.
    pub score: f32,
}
/// A retrieval backend that [`RagManager`] queries for context chunks.
///
/// Implementors must be `Send + Sync` because the manager stores them as
/// `Box<dyn RagAdapter>` and is itself handed out behind an `Arc`.
#[async_trait::async_trait]
pub trait RagAdapter: Send + Sync {
/// Returns the chunks this adapter finds for `query`.
///
/// `project` is an optional project identifier; how (or whether) it narrows
/// the search is up to the implementor and is not visible from this file.
/// The signature has no error channel, so a failed or fruitless lookup can
/// only be reported as an empty vector.
async fn retrieve(&self, query: &str, project: Option<&str>) -> Vec<RagChunk>;
/// Static name of the adapter; recorded as the `adapter` field of the debug
/// log event emitted by [`RagManager::retrieve`].
fn name(&self) -> &'static str;
}
/// Fans a query out to a set of [`RagAdapter`]s and trims the merged result
/// to a size budget.
pub struct RagManager {
// Adapters queried one after another, in this order, on every retrieval.
adapters: Vec<Box<dyn RagAdapter>>,
// Upper bound on the summed estimated size of the chunks returned by one
// retrieval; `from_config` sets it to four times the configured
// `max_tokens` (2000 when unset).
max_chars: usize,
}
impl RagManager {
/// Builds a manager from `cfg`, registering every adapter that is enabled.
///
/// * When the `rag-alcove` feature is compiled in, the Alcove adapter is
///   attempted unless `cfg.alcove` is present with `enabled == false`; it is
///   registered only if [`AlcoveRag::from_config`] returns `Some`.
/// * The bridge adapter is registered when `cfg.bridge` is present and
///   enabled.
///
/// The size budget is `cfg.max_tokens` (2000 when unset) multiplied by four
/// — presumably a rough characters-per-token estimate; confirm against the
/// tokenizer actually used downstream.
///
/// Returns `None` when no adapter ends up registered.
pub fn from_config(cfg: &RagConfig) -> Option<Arc<Self>> {
    let mut adapters: Vec<Box<dyn RagAdapter>> = Vec::new();

    #[cfg(feature = "rag-alcove")]
    if cfg.alcove.as_ref().map(|a| a.enabled).unwrap_or(true)
        && let Some(adapter) = AlcoveRag::from_config(cfg)
    {
        adapters.push(Box::new(adapter));
    }

    if let Some(bridge_cfg) = &cfg.bridge
        && bridge_cfg.enabled
    {
        adapters.push(Box::new(BridgeRag::new(bridge_cfg)));
    }

    if adapters.is_empty() {
        return None;
    }

    // Saturate rather than overflow: a very large configured `max_tokens`
    // would otherwise panic in debug builds or wrap to a bogus, possibly tiny
    // budget in release builds.
    let max_chars = cfg.max_tokens.unwrap_or(2000).saturating_mul(4);

    Some(Arc::new(Self {
        adapters,
        max_chars,
    }))
}
/// Lower-case phrases that mark a chunk as carrying a rule or constraint.
///
/// `retrieve` lower-cases each chunk's content and adds `0.15` to the chunk's
/// score when it contains any of these phrases as a substring. The match is a
/// plain substring test, so a phrase also matches inside a longer word.
const CONSTRAINT_KEYWORDS: &'static [&'static str] = &[
"must not",
"should not",
"must be",
"must have",
"required",
"constraint",
"important",
"critical",
"mandatory",
"prohibited",
"forbidden",
"never",
"always",
"ensure that",
"do not",
];
/// Queries every adapter, ranks the merged chunks and returns the best ones
/// that fit the size budget.
///
/// Steps, in order:
/// 1. Each adapter is awaited one after another (registration order) and its
///    chunks are appended; a debug event is logged for every adapter that
///    returned at least one chunk.
/// 2. Any chunk whose lower-cased content contains one of
///    `CONSTRAINT_KEYWORDS` gets `0.15` added to its score.
/// 3. Chunks are sorted by score, highest first. Chunks with a NaN score are
///    placed last; the sort is stable, so ties keep adapter order.
/// 4. At most `max_results` chunks are taken from the front, stopping at the
///    first chunk whose estimated size would push the running total past
///    `max_chars`. Later, smaller chunks are *not* considered once that
///    happens.
///
/// The per-chunk size estimate is `content.len() + source.len() + 50`; the
/// lengths are byte lengths (`String::len`), and the constant `50` presumably
/// covers formatting overhead — confirm with the output format in use.
pub async fn retrieve(
    &self,
    query: &str,
    project: Option<&str>,
    max_results: usize,
) -> Vec<RagChunk> {
    let mut chunks: Vec<RagChunk> = Vec::new();
    for adapter in &self.adapters {
        let before = chunks.len();
        chunks.extend(adapter.retrieve(query, project).await);
        let found = chunks.len() - before;
        if found > 0 {
            tracing::debug!(
                adapter = adapter.name(),
                found,
                "RAG adapter returned chunks"
            );
        }
    }

    for chunk in &mut chunks {
        let lower = chunk.content.to_lowercase();
        if Self::CONSTRAINT_KEYWORDS
            .iter()
            .any(|kw| lower.contains(kw))
        {
            chunk.score += 0.15;
        }
    }

    // `sort_by` requires a total order. Treating NaN as "equal to everything"
    // is not one (1.0 == NaN == 2.0 but 1.0 != 2.0), which yields an
    // unspecified order and may panic. NaN scores are therefore ranked
    // explicitly after every real score; real scores compare exactly as before.
    chunks.sort_by(|a, b| match (a.score.is_nan(), b.score.is_nan()) {
        (true, true) => std::cmp::Ordering::Equal,
        (true, false) => std::cmp::Ordering::Greater,
        (false, true) => std::cmp::Ordering::Less,
        (false, false) => b
            .score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal),
    });

    // Running total of the estimated size of the chunks accepted so far.
    let mut used = 0usize;
    chunks
        .into_iter()
        .take(max_results)
        .take_while(|chunk| {
            let len = chunk.content.len() + chunk.source.len() + 50;
            if used + len > self.max_chars {
                return false;
            }
            used += len;
            true
        })
        .collect()
}
/// Retrieves context for `query` and renders it as a Markdown section.
///
/// Delegates to `retrieve` with no cap on the number of results
/// (`usize::MAX`), so only the size budget limits the output. Returns an
/// empty string when nothing was retrieved; otherwise the text starts with a
/// `## RAG Context` heading followed by one entry per chunk: a
/// `[adapter] source (score: x.xx)` line, then the chunk content with every
/// line prefixed by `> `, then a blank line.
pub async fn retrieve_and_format(&self, query: &str, project: Option<&str>) -> String {
    let hits = self.retrieve(query, project, usize::MAX).await;

    if hits.is_empty() {
        String::new()
    } else {
        hits.into_iter()
            .fold(String::from("## RAG Context\n\n"), |mut rendered, hit| {
                // Continue the block quote across embedded line breaks.
                let quoted = hit.content.replace('\n', "\n> ");
                rendered.push_str(&format!(
                    "[{}] {} (score: {:.2})\n> {}\n\n",
                    hit.adapter, hit.source, hit.score, quoted
                ));
                rendered
            })
    }
}
}