use crate::core::providers::unified_provider::ProviderError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KnowledgeBaseRetrievalRequest {
pub retrieval_query: RetrievalQuery,
#[serde(skip_serializing_if = "Option::is_none")]
pub retrieval_configuration: Option<RetrievalConfiguration>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_token: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RetrievalQuery {
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalConfiguration {
pub vector_search_configuration: VectorSearchConfiguration,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct VectorSearchConfiguration {
pub number_of_results: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub override_search_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub filter: Option<RetrievalFilter>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RetrievalFilter {
pub equals: Option<Value>,
pub not_equals: Option<Value>,
pub greater_than: Option<Value>,
pub greater_than_or_equals: Option<Value>,
pub less_than: Option<Value>,
pub less_than_or_equals: Option<Value>,
#[serde(rename = "in")]
pub in_values: Option<Vec<Value>>,
pub not_in: Option<Vec<Value>>,
pub starts_with: Option<String>,
pub list_contains: Option<Value>,
pub string_contains: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct KnowledgeBaseRetrievalResponse {
pub retrieval_results: Vec<RetrievalResult>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_token: Option<String>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalResult {
pub content: RetrievalContent,
pub location: RetrievalLocation,
pub score: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<Value>,
}
#[derive(Debug, Deserialize)]
pub struct RetrievalContent {
pub text: String,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RetrievalLocation {
pub type_: String,
pub s3_location: Option<S3Location>,
pub web_location: Option<WebLocation>,
pub confluence_location: Option<ConfluenceLocation>,
pub salesforce_location: Option<SalesforceLocation>,
pub sharepoint_location: Option<SharepointLocation>,
}
#[derive(Debug, Deserialize)]
pub struct S3Location {
pub uri: String,
}
#[derive(Debug, Deserialize)]
pub struct WebLocation {
pub url: String,
}
#[derive(Debug, Deserialize)]
pub struct ConfluenceLocation {
pub url: String,
}
#[derive(Debug, Deserialize)]
pub struct SalesforceLocation {
pub url: String,
}
#[derive(Debug, Deserialize)]
pub struct SharepointLocation {
pub url: String,
}
pub struct KnowledgeBaseClient<'a> {
client: &'a crate::core::providers::bedrock::client::BedrockClient,
}
impl<'a> KnowledgeBaseClient<'a> {
pub fn new(client: &'a crate::core::providers::bedrock::client::BedrockClient) -> Self {
Self { client }
}
pub async fn retrieve(
&self,
knowledge_base_id: &str,
query: &str,
number_of_results: u32,
) -> Result<KnowledgeBaseRetrievalResponse, ProviderError> {
let request = KnowledgeBaseRetrievalRequest {
retrieval_query: RetrievalQuery {
text: query.to_string(),
},
retrieval_configuration: Some(RetrievalConfiguration {
vector_search_configuration: VectorSearchConfiguration {
number_of_results,
override_search_type: None,
filter: None,
},
}),
next_token: None,
};
let url = format!("knowledgebases/{}/retrieve", knowledge_base_id);
let response = self
.client
.send_request("", &url, &serde_json::to_value(request)?)
.await?;
let kb_response: KnowledgeBaseRetrievalResponse = response
.json()
.await
.map_err(|e| ProviderError::response_parsing("bedrock", e.to_string()))?;
Ok(kb_response)
}
}