use tracing::warn;
use crate::{
config::constants::{MAX_APEX_QUERY_CHARS, MAX_APEX_RESULTS, MAX_APEX_SNIPPET_CHARS},
integrations::{context::truncate_on_char_boundary, search_client::SearchClient},
};
/// One APEX context hit, as produced by `fetch_apex_context`.
#[derive(Clone, Debug, PartialEq)]
pub struct ApexContextResult {
    /// Path of the matching file, copied from the search hit.
    pub file: String,
    /// Matched text; `fetch_apex_context` truncates it to `MAX_APEX_SNIPPET_CHARS`
    /// and uses an empty string when the hit carried no snippet.
    pub snippet: String,
    /// Relevance score, copied unchanged from the search hit.
    pub score: f32,
    /// Starting line of the snippet, if the search client reported one.
    pub start_line: Option<u32>,
}
/// Searches `apex_index` for snippets relevant to `cross_query` and returns at most
/// [`MAX_APEX_RESULTS`] of them.
///
/// This is fail-open. An empty index id, a blank query, or a search error all yield an
/// empty `Vec` rather than an error, so the caller can continue without APEX context.
///
/// * `search` – client used to run the query.
/// * `apex_index` – index id to query. When it is empty, no search call is made.
/// * `apex_path_prefixes` – when it contains any non-empty entry, only hits whose `file`
///   starts with one of those prefixes are kept. Empty-string entries are ignored,
///   because they would match every file and silently disable the filter.
/// * `cross_query` – free-text query. It is trimmed, then truncated to
///   [`MAX_APEX_QUERY_CHARS`] before searching.
///
/// Each returned snippet is truncated to [`MAX_APEX_SNIPPET_CHARS`]. A hit without a
/// snippet yields an empty string. Results keep the order returned by the search client.
pub async fn fetch_apex_context(
    search: &dyn SearchClient,
    apex_index: &str,
    apex_path_prefixes: &[String],
    cross_query: &str,
) -> Vec<ApexContextResult> {
    // Prefix filtering may discard hits, so request more than we intend to keep.
    const OVERFETCH_FACTOR: usize = 4;

    if apex_index.is_empty() {
        return Vec::new();
    }
    let query_full = cross_query.trim();
    if query_full.is_empty() {
        return Vec::new();
    }
    let query = truncate_on_char_boundary(query_full, MAX_APEX_QUERY_CHARS);

    // Checked conversion: a plain `as u32` would silently wrap an oversized value.
    let top_k = u32::try_from(MAX_APEX_RESULTS.saturating_mul(OVERFETCH_FACTOR))
        .unwrap_or(u32::MAX);

    let raw_results = match search.search(apex_index, query, Some(top_k)).await {
        Ok(r) => r,
        Err(e) => {
            warn!(
                apex_index,
                "APEX context search failed (fail-open, review continues): {e}"
            );
            return Vec::new();
        }
    };

    // Resolve the effective prefix list once instead of re-checking it per hit.
    let prefixes: Vec<&str> = apex_path_prefixes
        .iter()
        .map(String::as_str)
        .filter(|p| !p.is_empty())
        .collect();

    raw_results
        .into_iter()
        .filter(|r| prefixes.is_empty() || prefixes.iter().any(|&p| r.file.starts_with(p)))
        .take(MAX_APEX_RESULTS)
        .map(|r| {
            let raw_snippet = r.snippet.unwrap_or_default();
            let snippet =
                truncate_on_char_boundary(&raw_snippet, MAX_APEX_SNIPPET_CHARS).to_string();
            ApexContextResult {
                file: r.file,
                snippet,
                score: r.score,
                start_line: r.start_line,
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::integrations::search_client::{
        EmbedderState, HealthResponse, IndexInfo, SearchClientError, SearchResult,
    };
    use async_trait::async_trait;

    /// Arguments captured from the most recent `search` call on [`MockSearch`].
    #[derive(Debug, Clone)]
    struct RecordedCall {
        query: String,
        top_k: Option<u32>,
    }

    /// In-memory `SearchClient` that returns canned results, or a canned error.
    /// It records the arguments of the last `search` call so tests can inspect them.
    struct MockSearch {
        results: Vec<SearchResult>,
        error: Option<SearchClientError>,
        recorded: std::sync::Mutex<Option<RecordedCall>>,
    }

    impl MockSearch {
        fn with_results(results: Vec<SearchResult>) -> Self {
            Self {
                results,
                error: None,
                recorded: std::sync::Mutex::new(None),
            }
        }

        fn with_error(err: SearchClientError) -> Self {
            Self {
                results: Vec::new(),
                error: Some(err),
                recorded: std::sync::Mutex::new(None),
            }
        }

        fn was_called(&self) -> bool {
            self.recorded.lock().unwrap().is_some()
        }

        fn last_call(&self) -> Option<RecordedCall> {
            self.recorded.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl SearchClient for MockSearch {
        async fn health(&self) -> Result<HealthResponse, SearchClientError> {
            Ok(HealthResponse {
                status: "ok".to_string(),
                embedder: EmbedderState::Bool(true),
            })
        }

        async fn list_indexes(&self) -> Result<Vec<IndexInfo>, SearchClientError> {
            Ok(vec![])
        }

        async fn search(
            &self,
            _index_id: &str,
            query: &str,
            top_k: Option<u32>,
        ) -> Result<Vec<SearchResult>, SearchClientError> {
            *self.recorded.lock().unwrap() = Some(RecordedCall {
                query: query.to_string(),
                top_k,
            });
            if let Some(ref e) = self.error {
                return Err(SearchClientError::Transport(e.to_string()));
            }
            Ok(self.results.clone())
        }
    }

    fn make_result(file: &str, snippet: &str, score: f32) -> SearchResult {
        SearchResult {
            file: file.to_string(),
            snippet: Some(snippet.to_string()),
            score,
            start_line: Some(1),
            end_line: None,
        }
    }

    #[tokio::test]
    async fn apex_context_empty_index_returns_empty_no_search_call() {
        let mock = MockSearch::with_results(vec![make_result("apex/spec.md", "content", 0.9)]);
        let result = fetch_apex_context(&mock, "", &[], "PR title").await;
        assert!(result.is_empty(), "empty index must return empty results");
        assert!(
            !mock.was_called(),
            "search must not be called when apex_index is empty"
        );
    }

    #[tokio::test]
    async fn apex_context_empty_query_returns_empty_no_search_call() {
        let mock = MockSearch::with_results(vec![make_result("apex/spec.md", "content", 0.9)]);
        let result = fetch_apex_context(&mock, "my-index", &[], " \n\t ").await;
        assert!(result.is_empty(), "blank query must return empty results");
        assert!(
            !mock.was_called(),
            "search must not be called when query is blank"
        );
    }

    #[tokio::test]
    async fn apex_context_search_error_fail_open() {
        let mock = MockSearch::with_error(SearchClientError::Transport("refused".to_string()));
        let result = fetch_apex_context(&mock, "my-index", &[], "some PR query").await;
        // Without this assertion, the test would also pass if an early return skipped the search.
        assert!(
            mock.was_called(),
            "search must actually be attempted before failing open"
        );
        assert!(
            result.is_empty(),
            "search error must produce empty result (fail-open)"
        );
    }

    #[tokio::test]
    async fn apex_context_trims_query_and_requests_enough_results() {
        let mock = MockSearch::with_results(vec![]);
        let _ = fetch_apex_context(&mock, "idx", &[], "  auth flow \n").await;
        let call = mock
            .last_call()
            .expect("search must be called for a non-empty query");
        assert_eq!(call.query, "auth flow", "query must be trimmed before searching");
        let top_k = call.top_k.expect("top_k must be supplied");
        assert!(
            top_k as usize >= MAX_APEX_RESULTS,
            "top_k must be large enough to fill MAX_APEX_RESULTS"
        );
    }

    #[tokio::test]
    async fn apex_context_maps_result_fields_correctly() {
        let mock = MockSearch::with_results(vec![SearchResult {
            file: "apex/auth-spec.md".to_string(),
            snippet: Some("The auth flow must verify token expiry.".to_string()),
            score: 0.87,
            start_line: Some(15),
            end_line: None,
        }]);
        let result = fetch_apex_context(&mock, "my-index", &[], "auth token flow").await;
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].file, "apex/auth-spec.md");
        assert_eq!(result[0].snippet, "The auth flow must verify token expiry.");
        assert!((result[0].score - 0.87_f32).abs() < 1e-4);
        assert_eq!(result[0].start_line, Some(15));
    }

    #[tokio::test]
    async fn apex_context_missing_snippet_becomes_empty() {
        let mock = MockSearch::with_results(vec![SearchResult {
            file: "apex/spec.md".to_string(),
            snippet: None,
            score: 0.4,
            start_line: None,
            end_line: None,
        }]);
        let result = fetch_apex_context(&mock, "idx", &[], "query").await;
        assert_eq!(result.len(), 1);
        assert!(
            result[0].snippet.is_empty(),
            "a hit without a snippet must map to an empty snippet"
        );
    }

    #[tokio::test]
    async fn apex_context_truncates_snippet_to_max_chars() {
        let long_snippet = "x".repeat(MAX_APEX_SNIPPET_CHARS + 200);
        let mock = MockSearch::with_results(vec![SearchResult {
            file: "apex/spec.md".to_string(),
            snippet: Some(long_snippet),
            score: 0.5,
            start_line: None,
            end_line: None,
        }]);
        let result = fetch_apex_context(&mock, "idx", &[], "query").await;
        assert_eq!(result.len(), 1);
        let actual_chars = result[0].snippet.chars().count();
        assert_eq!(
            actual_chars, MAX_APEX_SNIPPET_CHARS,
            "snippet must be exactly MAX_APEX_SNIPPET_CHARS chars"
        );
    }

    #[tokio::test]
    async fn apex_context_prefix_filter_retains_only_apex_hits() {
        let mock = MockSearch::with_results(vec![
            make_result("apex/spec.md", "spec content", 0.9),
            make_result("src/main.rs", "fn main() {}", 0.85),
        ]);
        let prefixes = vec!["apex/".to_string()];
        let result = fetch_apex_context(&mock, "idx", &prefixes, "auth flow").await;
        assert_eq!(result.len(), 1, "only apex/ hit must survive prefix filter");
        assert_eq!(result[0].file, "apex/spec.md");
    }

    #[tokio::test]
    async fn apex_context_no_prefix_all_results_are_apex() {
        let mock = MockSearch::with_results(vec![
            make_result("apex/a.md", "a", 0.9),
            make_result("docs/b.md", "b", 0.8),
        ]);
        let result = fetch_apex_context(&mock, "idx", &[], "query").await;
        assert_eq!(
            result.len(),
            2,
            "no prefix filter → all results returned (up to cap)"
        );
    }

    #[tokio::test]
    async fn apex_context_caps_at_max_results() {
        let results: Vec<SearchResult> = (0..MAX_APEX_RESULTS + 2)
            .map(|i| {
                make_result(
                    &format!("apex/spec-{i}.md"),
                    "content",
                    0.9 - i as f32 * 0.01,
                )
            })
            .collect();
        let mock = MockSearch::with_results(results);
        let prefixes = vec!["apex/".to_string()];
        let result = fetch_apex_context(&mock, "idx", &prefixes, "query").await;
        assert_eq!(
            result.len(),
            MAX_APEX_RESULTS,
            "results must be capped at MAX_APEX_RESULTS"
        );
    }

    #[tokio::test]
    async fn apex_context_prefix_filter_no_match_returns_empty() {
        let mock = MockSearch::with_results(vec![
            make_result("src/a.rs", "code", 0.8),
            make_result("tests/b.rs", "test", 0.7),
        ]);
        let prefixes = vec!["apex/".to_string()];
        let result = fetch_apex_context(&mock, "idx", &prefixes, "query").await;
        assert!(
            result.is_empty(),
            "no matching prefix ⇒ empty (no code hits treated as APEX)"
        );
    }

    #[tokio::test]
    async fn apex_context_long_query_does_not_panic() {
        let long_query = "a".repeat(MAX_APEX_QUERY_CHARS + 500);
        let mock = MockSearch::with_results(vec![make_result("apex/spec.md", "content", 0.7)]);
        let result = fetch_apex_context(&mock, "idx", &[], &long_query).await;
        assert!(result.len() <= MAX_APEX_RESULTS);
        let call = mock
            .last_call()
            .expect("search must be called for a non-empty query");
        assert_eq!(
            call.query.chars().count(),
            MAX_APEX_QUERY_CHARS,
            "query forwarded to search must be truncated to MAX_APEX_QUERY_CHARS"
        );
    }
}