use async_trait::async_trait;
use reqwest::Client;
use serde_json::{Value, json};
use crate::mcp::error::McpError;
/// A single search hit, normalized from a backend response.
///
/// Field descriptions below reflect how `parse_hits` fills them for Azure responses;
/// other adapters (e.g. test doubles) may populate them directly.
#[derive(Debug, Clone, PartialEq)]
pub struct RawHit {
    /// Document identifier taken from the hit's `id` field (empty string if absent).
    pub id: String,
    /// Relevance score from `@search.score` (0.0 if absent).
    pub score: f64,
    /// The hit's document with every `@`-prefixed system key removed.
    pub fields: Value,
    /// First highlighted fragment of the `body` field, when highlights were returned.
    pub snippet: Option<String>,
    /// Full `body` text, only populated when full bodies were requested.
    pub body: Option<String>,
}
/// One page of search results, normalized from a backend response.
#[derive(Debug, Clone, PartialEq)]
pub struct RawSearchResponse {
    /// Hits in the order the backend returned them.
    pub hits: Vec<RawHit>,
    /// Synthesized answer (`synthesis.answer`); only knowledge-base searches set this here.
    pub answer: Option<String>,
    /// Citations for `answer` (`synthesis.citations`), passed through as raw JSON.
    pub citations: Option<Vec<Value>>,
    /// Opaque cursor for the next page, or `None` when no further page was advertised.
    pub next_cursor: Option<String>,
    /// Estimated total matches: `@odata.count` for index searches, or the number of
    /// hits in this page for knowledge-base searches.
    pub total_estimate: u64,
}
/// Backend-agnostic search interface over knowledge bases and indexes.
///
/// Implemented in this file by `AzureSearchAdapter` (HTTP), `NoOpSearch` (always an
/// empty page) and, under `cfg(test)`, `mock::MockSearchApi` (a recording test double).
#[async_trait]
// Each search option is its own parameter, which exceeds clippy's default argument
// limit; the lint is silenced deliberately rather than introducing an options struct.
#[allow(clippy::too_many_arguments)]
pub trait SearchApiAdapter: Send + Sync {
/// Searches the knowledge base `knowledge_base_name` for `query`.
///
/// - `odata_filter`: optional filter expression restricting the results.
/// - `top`: number of hits requested.
/// - `cursor`: continuation cursor; presumably the `next_cursor` of a previous
///   response (TODO confirm against callers).
/// - `include_synthesis`: ask for a synthesized `answer` / `citations`.
/// - `include_full_body`: ask for each hit's full `body` text.
async fn search_knowledge_base(
&self,
knowledge_base_name: &str,
query: &str,
odata_filter: Option<&str>,
top: usize,
cursor: Option<&str>,
include_synthesis: bool,
include_full_body: bool,
) -> Result<RawSearchResponse, McpError>;
/// Searches the index `index_name` for `query`.
///
/// Parameters have the same meaning as in `search_knowledge_base`; there is no
/// synthesis option for plain index searches.
async fn search_index(
&self,
index_name: &str,
query: &str,
odata_filter: Option<&str>,
top: usize,
cursor: Option<&str>,
include_full_body: bool,
) -> Result<RawSearchResponse, McpError>;
}
/// [`SearchApiAdapter`] backed by the Azure AI Search REST API.
///
/// The access token is obtained once in [`AzureSearchAdapter::new`] and reused for every
/// request; it is never refreshed. NOTE(review): if the auth provider issues short-lived
/// tokens, long-lived adapters will start receiving 401s — consider re-acquiring tokens.
pub struct AzureSearchAdapter {
    http: Client,
    /// Base service URL, stored without a trailing `/`.
    service_url: String,
    /// Sent as the `api-version` query parameter on every request.
    api_version: String,
    /// Raw access token, sent as `Authorization: Bearer <token>`.
    token: String,
}

impl AzureSearchAdapter {
    /// Creates an adapter for the search service at `service_url`.
    ///
    /// Trailing `/` characters are stripped from `service_url`, so request URLs built as
    /// `{service_url}/indexes/...` never contain a double slash.
    ///
    /// # Errors
    /// - `McpError::Unauthenticated` if the auth provider or its token cannot be obtained.
    /// - `McpError::Internal` if the HTTP client cannot be built.
    pub fn new(service_url: String, api_version: String) -> Result<Self, McpError> {
        let auth = rigg_client::auth::get_auth_provider()
            .map_err(|e| McpError::Unauthenticated(format!("auth: {e}")))?;
        let token = auth
            .get_token()
            .map_err(|e| McpError::Unauthenticated(format!("token: {e}")))?;
        // Total per-request timeout applied to every search call.
        let http = Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| McpError::Internal(format!("http client: {e}")))?;
        let service_url = service_url.trim_end_matches('/').to_string();
        Ok(Self {
            http,
            service_url,
            api_version,
            token,
        })
    }

    /// Value for the `Authorization` header.
    fn bearer(&self) -> String {
        format!("Bearer {}", self.token)
    }
}
#[async_trait]
impl SearchApiAdapter for AzureSearchAdapter {
async fn search_knowledge_base(
&self,
knowledge_base_name: &str,
query: &str,
odata_filter: Option<&str>,
top: usize,
cursor: Option<&str>,
include_synthesis: bool,
include_full_body: bool,
) -> Result<RawSearchResponse, McpError> {
let url = format!(
"{}/knowledgebases/{}/retrieve?api-version={}",
self.service_url, knowledge_base_name, self.api_version,
);
let mut body = json!({
"search": query,
"top": top,
"includeSynthesis": include_synthesis,
"includeFullBody": include_full_body,
});
if let Some(filter) = odata_filter {
body["filter"] = json!(filter);
}
if let Some(c) = cursor {
body["continuationToken"] = json!(c);
}
let resp = self
.http
.post(&url)
.header("Authorization", self.bearer())
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| McpError::Unavailable(format!("KB search request failed: {e}")))?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| McpError::Internal(format!("read KB response: {e}")))?;
if !status.is_success() {
return Err(McpError::Unavailable(format!(
"KB search returned {status}: {text}"
)));
}
let val: Value = serde_json::from_str(&text)
.map_err(|e| McpError::Internal(format!("parse KB response: {e}")))?;
parse_kb_response(val)
}
async fn search_index(
&self,
index_name: &str,
query: &str,
odata_filter: Option<&str>,
top: usize,
cursor: Option<&str>,
include_full_body: bool,
) -> Result<RawSearchResponse, McpError> {
let url = format!(
"{}/indexes/{}/docs/search?api-version={}",
self.service_url, index_name, self.api_version,
);
let mut body = json!({
"search": query,
"top": top,
"select": "*",
"count": true,
"queryType": "full",
"searchMode": "any",
"vectorQueries": [],
});
if let Some(filter) = odata_filter {
body["filter"] = json!(filter);
}
if let Some(c) = cursor {
body["continuationToken"] = json!(c);
}
if !include_full_body {
body["highlight"] = json!("body");
body["highlightPreTag"] = json!("<mark>");
body["highlightPostTag"] = json!("</mark>");
}
let resp = self
.http
.post(&url)
.header("Authorization", self.bearer())
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.map_err(|e| McpError::Unavailable(format!("index search request failed: {e}")))?;
let status = resp.status();
let text = resp
.text()
.await
.map_err(|e| McpError::Internal(format!("read index response: {e}")))?;
if !status.is_success() {
return Err(McpError::Unavailable(format!(
"index search returned {status}: {text}"
)));
}
let val: Value = serde_json::from_str(&text)
.map_err(|e| McpError::Internal(format!("parse index response: {e}")))?;
parse_index_response(val, include_full_body)
}
}
/// Normalizes a knowledge-base `retrieve` response.
///
/// Hits come from `value` (full bodies always kept). The answer and citations come from
/// the `synthesis` object. `total_estimate` is just the number of hits in this page.
fn parse_kb_response(val: Value) -> Result<RawSearchResponse, McpError> {
    let hits = parse_hits(val.get("value"), true)?;
    let synthesis = val.get("synthesis");
    Ok(RawSearchResponse {
        answer: synthesis
            .and_then(|s| s.get("answer"))
            .and_then(Value::as_str)
            .map(str::to_owned),
        citations: synthesis
            .and_then(|s| s.get("citations"))
            .and_then(Value::as_array)
            .cloned(),
        next_cursor: extract_continuation(&val),
        // Evaluated before `hits` is moved into the struct below.
        total_estimate: hits.len() as u64,
        hits,
    })
}
/// Normalizes an index `docs/search` response.
///
/// `total_estimate` is `@odata.count` (requested via `count: true` in `search_index`),
/// defaulting to 0 when absent. Index searches never carry an answer or citations.
fn parse_index_response(
    val: Value,
    include_full_body: bool,
) -> Result<RawSearchResponse, McpError> {
    let hits = parse_hits(val.get("value"), include_full_body)?;
    let count = val.get("@odata.count").and_then(Value::as_u64);
    Ok(RawSearchResponse {
        hits,
        answer: None,
        citations: None,
        next_cursor: extract_continuation(&val),
        total_estimate: count.unwrap_or_default(),
    })
}
/// Converts the response's `value` array into [`RawHit`]s.
///
/// A missing or non-array `value` yields no hits. Missing `id` becomes `""` and missing
/// `@search.score` becomes 0.0. The snippet is the first `body` highlight, and `body` is
/// copied only when `include_full_body` is set. Never returns `Err` today.
fn parse_hits(value: Option<&Value>, include_full_body: bool) -> Result<Vec<RawHit>, McpError> {
    let Some(items) = value.and_then(Value::as_array) else {
        return Ok(Vec::new());
    };
    let hits = items
        .iter()
        .map(|item| {
            let full_body = if include_full_body {
                item.get("body").and_then(Value::as_str).map(str::to_owned)
            } else {
                None
            };
            RawHit {
                id: item
                    .get("id")
                    .and_then(Value::as_str)
                    .unwrap_or_default()
                    .to_owned(),
                score: item
                    .get("@search.score")
                    .and_then(Value::as_f64)
                    .unwrap_or(0.0),
                fields: strip_system_fields(item),
                snippet: item
                    .pointer("/@search.highlights/body/0")
                    .and_then(Value::as_str)
                    .map(str::to_owned),
                body: full_body,
            }
        })
        .collect();
    Ok(hits)
}
/// Derives the opaque `next_cursor` for a response, if a further page is advertised.
///
/// - If `@search.nextPageParameters` is a JSON object, it is serialized to compact JSON
///   and base64url-encoded without padding, yielding a single opaque string.
/// - Otherwise a string `continuationToken` is returned as-is.
/// - `None` when neither is present.
fn extract_continuation(val: &Value) -> Option<String> {
    use base64::Engine;
    // Only an object is a usable parameter set for the next request; encoding `null` (or
    // any other non-object) would produce a cursor for a page that does not exist.
    if let Some(params) = val
        .get("@search.nextPageParameters")
        .filter(|p| p.is_object())
    {
        // `Value`'s `Display` emits the same compact JSON as `serde_json::to_string`.
        return Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(params.to_string()));
    }
    val.get("continuationToken")
        .and_then(Value::as_str)
        .map(String::from)
}
/// Returns a copy of `item` without its system keys (those starting with `@`).
///
/// Non-object values are returned unchanged (cloned).
fn strip_system_fields(item: &Value) -> Value {
    let Some(object) = item.as_object() else {
        return item.clone();
    };
    let user_fields: serde_json::Map<String, Value> = object
        .iter()
        .filter(|(key, _)| !key.starts_with('@'))
        .map(|(key, field)| (key.clone(), field.clone()))
        .collect();
    Value::Object(user_fields)
}
/// A [`SearchApiAdapter`] that never performs a search.
///
/// Every call succeeds with an empty page: no hits, no answer, no citations, no cursor,
/// and a `total_estimate` of 0.
pub struct NoOpSearch;

impl NoOpSearch {
    /// The empty page returned by every `NoOpSearch` query.
    fn empty_page() -> RawSearchResponse {
        RawSearchResponse {
            hits: Vec::new(),
            answer: None,
            citations: None,
            next_cursor: None,
            total_estimate: 0,
        }
    }
}

#[async_trait]
impl SearchApiAdapter for NoOpSearch {
    async fn search_knowledge_base(
        &self,
        _knowledge_base_name: &str,
        _query: &str,
        _odata_filter: Option<&str>,
        _top: usize,
        _cursor: Option<&str>,
        _include_synthesis: bool,
        _include_full_body: bool,
    ) -> Result<RawSearchResponse, McpError> {
        Ok(Self::empty_page())
    }

    async fn search_index(
        &self,
        _index_name: &str,
        _query: &str,
        _odata_filter: Option<&str>,
        _top: usize,
        _cursor: Option<&str>,
        _include_full_body: bool,
    ) -> Result<RawSearchResponse, McpError> {
        Ok(Self::empty_page())
    }
}
#[cfg(test)]
pub mod mock {
    //! Recording test double for [`SearchApiAdapter`].
    use super::*;
    use std::sync::{Arc, Mutex};

    /// One recorded invocation of [`MockSearchApi`], with every argument owned.
    #[derive(Debug, Clone)]
    pub enum SearchCall {
        /// A `search_knowledge_base` call.
        KnowledgeBase {
            knowledge_base_name: String,
            query: String,
            odata_filter: Option<String>,
            top: usize,
            cursor: Option<String>,
            include_synthesis: bool,
            include_full_body: bool,
        },
        /// A `search_index` call.
        Index {
            index_name: String,
            query: String,
            odata_filter: Option<String>,
            top: usize,
            cursor: Option<String>,
            include_full_body: bool,
        },
    }

    /// Factory invoked on every call to produce a fresh canned response.
    pub type ResponseFn = Arc<dyn Fn() -> RawSearchResponse + Send + Sync>;

    /// In-memory [`SearchApiAdapter`] that records calls and answers with canned data.
    ///
    /// Clones share one call log, because it lives behind an `Arc<Mutex<_>>`.
    #[derive(Default, Clone)]
    pub struct MockSearchApi {
        /// Every call received, in arrival order.
        pub calls: Arc<Mutex<Vec<SearchCall>>>,
        /// Factory for `search_knowledge_base` results; `None` means `default_response()`.
        pub kb_response: Option<ResponseFn>,
        /// Factory for `search_index` results; `None` means `default_response()`.
        pub index_response: Option<ResponseFn>,
    }

    impl MockSearchApi {
        /// A mock with an empty call log and default responses for both searches.
        pub fn new() -> Self {
            Self::default()
        }

        /// Uses `f` to build every `search_knowledge_base` response.
        pub fn with_kb_response<F>(self, f: F) -> Self
        where
            F: Fn() -> RawSearchResponse + Send + Sync + 'static,
        {
            Self {
                kb_response: Some(Arc::new(f)),
                ..self
            }
        }

        /// Uses `f` to build every `search_index` response.
        pub fn with_index_response<F>(self, f: F) -> Self
        where
            F: Fn() -> RawSearchResponse + Send + Sync + 'static,
        {
            Self {
                index_response: Some(Arc::new(f)),
                ..self
            }
        }

        /// Copy of the calls recorded so far.
        pub fn calls_snapshot(&self) -> Vec<SearchCall> {
            let log = self.calls.lock().unwrap();
            log.to_vec()
        }

        /// Wraps one hit in a page with no answer, no citations, no cursor, total 1.
        fn single_hit_page(hit: RawHit) -> RawSearchResponse {
            RawSearchResponse {
                hits: vec![hit],
                answer: None,
                citations: None,
                next_cursor: None,
                total_estimate: 1,
            }
        }

        /// Plain one-hit page with a snippet and no body.
        pub fn default_response() -> RawSearchResponse {
            Self::single_hit_page(RawHit {
                id: "hit-1".to_string(),
                score: 0.9,
                fields: serde_json::json!({
                    "source_link": "https://example.com/issue/1",
                    "summary": "Test hit",
                }),
                snippet: Some("…relevant snippet…".to_string()),
                body: None,
            })
        }

        /// One-hit page that also carries a synthesized answer and one citation.
        pub fn agentic_response() -> RawSearchResponse {
            let mut page = Self::single_hit_page(RawHit {
                id: "hit-a".to_string(),
                score: 0.95,
                fields: serde_json::json!({
                    "source_link": "https://example.com/page/1",
                    "title": "Relevant Page",
                }),
                snippet: Some("…agentic snippet…".to_string()),
                body: None,
            });
            page.answer = Some("The synthesised answer is: 42.".to_string());
            page.citations = Some(vec![
                serde_json::json!({"url": "https://example.com/page/1"}),
            ]);
            page
        }

        /// One-hit page whose hit includes a full body.
        pub fn full_body_response() -> RawSearchResponse {
            Self::single_hit_page(RawHit {
                id: "hit-b".to_string(),
                score: 0.85,
                fields: serde_json::json!({
                    "source_link": "https://example.com/issue/2",
                    "summary": "Full body test",
                }),
                snippet: Some("…snippet…".to_string()),
                body: Some("Full body content here.".to_string()),
            })
        }

        /// Appends `call` to the shared log.
        fn record(&self, call: SearchCall) {
            self.calls.lock().unwrap().push(call);
        }

        /// Runs `factory` if configured, otherwise falls back to `default_response()`.
        fn respond(factory: Option<&ResponseFn>) -> RawSearchResponse {
            factory.map_or_else(Self::default_response, |f| f())
        }
    }

    #[async_trait]
    impl SearchApiAdapter for MockSearchApi {
        async fn search_knowledge_base(
            &self,
            knowledge_base_name: &str,
            query: &str,
            odata_filter: Option<&str>,
            top: usize,
            cursor: Option<&str>,
            include_synthesis: bool,
            include_full_body: bool,
        ) -> Result<RawSearchResponse, McpError> {
            self.record(SearchCall::KnowledgeBase {
                knowledge_base_name: knowledge_base_name.to_string(),
                query: query.to_string(),
                odata_filter: odata_filter.map(String::from),
                top,
                cursor: cursor.map(String::from),
                include_synthesis,
                include_full_body,
            });
            Ok(Self::respond(self.kb_response.as_ref()))
        }

        async fn search_index(
            &self,
            index_name: &str,
            query: &str,
            odata_filter: Option<&str>,
            top: usize,
            cursor: Option<&str>,
            include_full_body: bool,
        ) -> Result<RawSearchResponse, McpError> {
            self.record(SearchCall::Index {
                index_name: index_name.to_string(),
                query: query.to_string(),
                odata_filter: odata_filter.map(String::from),
                top,
                cursor: cursor.map(String::from),
                include_full_body,
            });
            Ok(Self::respond(self.index_response.as_ref()))
        }
    }
}