use std::collections::HashMap;
use std::sync::Arc;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::auth::{ApiKeyAuth, Auth, AuthConfig, JwtAuth, Target};
use crate::client::{ApiKeyPosition, ClientRequest, OramaClient};
use crate::error::Result;
use crate::stream_manager::OramaCoreStream;
use crate::types::*;
use crate::utils::{current_time_millis, format_duration};
/// Reader endpoint used when the cluster configuration provides no `read_url`.
const DEFAULT_READER_URL: &'static str = "https://collections.orama.com";
/// JWT endpoint used (for `p_`-prefixed keys) when no `auth_jwt_url` is configured.
const DEFAULT_JWT_URL: &'static str = "https://app.orama.com/api/user/jwt";
/// Settings consumed by `CollectionManager::new` to build an authenticated client.
#[derive(Clone, Debug)]
pub struct CollectionManagerConfig {
    /// Collection every request issued by the manager is scoped to.
    pub collection_id: String,
    /// API key; a value starting with `p_` makes `CollectionManager::new` use JWT auth.
    pub api_key: String,
    /// Optional reader/writer endpoint overrides.
    pub cluster: Option<ClusterConfig>,
    /// JWT endpoint override (JWT auth only); `DEFAULT_JWT_URL` is used when `None`.
    pub auth_jwt_url: Option<String>,
}
/// Endpoint overrides for the writer and reader services.
#[derive(Clone, Debug)]
pub struct ClusterConfig {
    /// Writer endpoint; when `None`, `CollectionManager::new` forwards an empty string.
    pub writer_url: Option<String>,
    /// Reader endpoint; when `None`, `DEFAULT_READER_URL` is used.
    pub read_url: Option<String>,
}
/// Request body for `AiNamespace::nlp_search`.
#[derive(Clone, Debug, Serialize)]
pub struct NlpSearchParams {
    /// Query text, serialized as `query`.
    pub query: String,
    /// Optional LLM configuration, serialized as `LLMConfig` and omitted when `None`.
    #[serde(rename = "LLMConfig")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_config: Option<LlmConfig>,
    /// Optional user identifier, serialized as `userID` and omitted when `None`.
    #[serde(rename = "userID")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_id: Option<String>,
}
/// Parameters for `IndexNamespace::create`.
///
/// Note: `IndexNamespace::create` builds its own JSON body and sends
/// `embeddings` under the key `embedding`; the serde attributes below only
/// apply when this struct is serialized directly.
#[derive(Clone, Debug, Serialize)]
pub struct CreateIndexParams {
    /// Optional index identifier; omitted from direct serialization when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Optional embeddings configuration as raw JSON; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embeddings: Option<serde_json::Value>,
}
/// Input for `HooksNamespace::insert`: which hook to set and the code to attach.
#[derive(Clone, Debug, Serialize)]
pub struct AddHookConfig {
    /// Hook being registered, serialized as `name`.
    pub name: Hook,
    /// Hook code, serialized as `code`.
    pub code: String,
}
/// Value returned by `HooksNamespace::insert`.
#[derive(Clone, Debug, Deserialize)]
pub struct NewHookResponse {
    /// Identifier of the hook; read from `hookID` when deserialized.
    #[serde(rename = "hookID")]
    pub hook_id: String,
    /// Code that was registered for the hook.
    pub code: String,
}
/// Request body for `ToolsNamespace::execute`.
#[derive(Clone, Debug, Serialize)]
pub struct ExecuteToolsBody {
    /// Restricts execution to these tool ids; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_ids: Option<Vec<String>>,
    /// Conversation messages sent to the tool runner.
    pub messages: Vec<Message>,
    /// Optional LLM configuration; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub llm_config: Option<LlmConfig>,
}
/// AI-related operations (NLP search, streaming sessions) for one collection.
#[derive(Clone, Debug)]
pub struct AiNamespace {
    client: OramaClient,
    collection_id: String,
}

impl AiNamespace {
    pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
        AiNamespace {
            client,
            collection_id,
        }
    }

    /// Runs an NLP search against the reader.
    ///
    /// Sends `params` as a POST body to `/v1/collections/{id}/nlp_search`,
    /// passing the API key as a query parameter, and deserializes the
    /// response into a list of `NlpSearchResult<T>`.
    pub async fn nlp_search<T>(&self, params: NlpSearchParams) -> Result<Vec<NlpSearchResult<T>>>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let path = format!("/v1/collections/{}/nlp_search", self.collection_id);
        self.client
            .request(ClientRequest::post(
                path,
                Target::Reader,
                ApiKeyPosition::QueryParams,
                params,
            ))
            .await
    }

    /// Opens a new `OramaCoreStream` bound to this collection and client.
    pub async fn create_ai_session(&self) -> Result<OramaCoreStream> {
        let collection_id = self.collection_id.clone();
        let client = self.client.clone();
        OramaCoreStream::new(collection_id, client).await
    }
}
/// Collection-level read operations (statistics, document listing).
#[derive(Clone, Debug)]
pub struct CollectionsNamespace {
    client: OramaClient,
    collection_id: String,
}

impl CollectionsNamespace {
    pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
        CollectionsNamespace {
            client,
            collection_id,
        }
    }

    /// Fetches `/v1/collections/{id}/stats` from the reader as untyped JSON.
    pub async fn get_stats(&self) -> Result<serde_json::Value> {
        let path = format!("/v1/collections/{}/stats", self.collection_id);
        self.client
            .request(ClientRequest::<()>::get(
                path,
                Target::Reader,
                ApiKeyPosition::QueryParams,
            ))
            .await
    }

    /// Lists documents by POSTing `{"id": id}` to `/v1/collections/list` on the writer.
    ///
    /// NOTE(review): this request is not scoped to the namespace's own
    /// `collection_id`; the caller-supplied `id` is sent instead. Confirm this is intended.
    pub async fn get_all_docs<T>(&self, id: &str) -> Result<Vec<T>>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let payload = serde_json::json!({ "id": id });
        self.client
            .request(ClientRequest::post(
                "/v1/collections/list".to_string(),
                Target::Writer,
                ApiKeyPosition::Header,
                payload,
            ))
            .await
    }
}
#[derive(Debug, Clone)]
pub struct IndexNamespace {
client: OramaClient,
collection_id: String,
}
impl IndexNamespace {
pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
Self {
client,
collection_id,
}
}
pub async fn create(&self, config: CreateIndexParams) -> Result<()> {
let body = serde_json::json!({
"id": config.id,
"embedding": config.embeddings
});
let request = ClientRequest::post(
format!("/v1/collections/{}/indexes/create", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
body,
);
let _: serde_json::Value = self.client.request(request).await?;
Ok(())
}
pub async fn delete(&self, index_id: &str) -> Result<()> {
let body = serde_json::json!({
"index_id_to_delete": index_id
});
let request = ClientRequest::post(
format!("/v1/collections/{}/indexes/delete", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
body,
);
let _: serde_json::Value = self.client.request(request).await?;
Ok(())
}
pub fn set(&self, id: String) -> Index {
Index::new(self.client.clone(), self.collection_id.clone(), id)
}
}
#[derive(Debug, Clone)]
pub struct HooksNamespace {
client: OramaClient,
collection_id: String,
}
impl HooksNamespace {
pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
Self {
client,
collection_id,
}
}
pub async fn insert(&self, config: AddHookConfig) -> Result<NewHookResponse> {
let body = serde_json::json!({
"name": config.name,
"code": config.code
});
let request = ClientRequest::post(
format!("/v1/collections/{}/hooks/set", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
body,
);
let _: serde_json::Value = self.client.request(request).await?;
Ok(NewHookResponse {
hook_id: serde_json::to_string(&config.name)?,
code: config.code,
})
}
pub async fn list(&self) -> Result<HashMap<String, Option<String>>> {
let request = ClientRequest::<()>::get(
format!("/v1/collections/{}/hooks/list", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
);
let response: serde_json::Value = self.client.request(request).await?;
let empty_map = serde_json::Map::new();
let hooks = response["hooks"].as_object().unwrap_or(&empty_map);
let mut result = HashMap::new();
for (key, value) in hooks {
let val = value.as_str().map(|s| s.to_string());
result.insert(key.clone(), val);
}
Ok(result)
}
pub async fn delete(&self, hook: Hook) -> Result<()> {
let body = serde_json::json!({
"name_to_delete": hook
});
let request = ClientRequest::post(
format!("/v1/collections/{}/hooks/delete", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
body,
);
let _: serde_json::Value = self.client.request(request).await?;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct SystemPromptsNamespace {
client: OramaClient,
collection_id: String,
}
impl SystemPromptsNamespace {
pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
Self {
client,
collection_id,
}
}
pub async fn insert(&self, system_prompt: InsertSystemPromptBody) -> Result<serde_json::Value> {
let request = ClientRequest::post(
format!(
"/v1/collections/{}/system_prompts/insert",
self.collection_id
),
Target::Writer,
ApiKeyPosition::Header,
system_prompt,
);
self.client.request(request).await
}
pub async fn get(&self, id: &str) -> Result<SystemPrompt> {
let request = ClientRequest::<()>::get(
format!("/v1/collections/{}/system_prompts/get", self.collection_id),
Target::Reader,
ApiKeyPosition::QueryParams,
)
.with_param("system_prompt_id", id);
let response: serde_json::Value = self.client.request(request).await?;
let prompt = response["system_prompt"].clone();
Ok(serde_json::from_value(prompt)?)
}
pub async fn get_all(&self) -> Result<Vec<SystemPrompt>> {
let request = ClientRequest::<()>::get(
format!("/v1/collections/{}/system_prompts/all", self.collection_id),
Target::Reader,
ApiKeyPosition::QueryParams,
);
let response: serde_json::Value = self.client.request(request).await?;
let prompts = response["system_prompts"].clone();
Ok(serde_json::from_value(prompts)?)
}
pub async fn delete(&self, id: &str) -> Result<serde_json::Value> {
let body = serde_json::json!({ "id": id });
let request = ClientRequest::post(
format!(
"/v1/collections/{}/system_prompts/delete",
self.collection_id
),
Target::Writer,
ApiKeyPosition::Header,
body,
);
self.client.request(request).await
}
pub async fn update(&self, system_prompt: SystemPrompt) -> Result<serde_json::Value> {
let request = ClientRequest::post(
format!(
"/v1/collections/{}/system_prompts/update",
self.collection_id
),
Target::Writer,
ApiKeyPosition::Header,
system_prompt,
);
self.client.request(request).await
}
pub async fn validate(
&self,
system_prompt: SystemPrompt,
) -> Result<SystemPromptValidationResponse> {
let request = ClientRequest::post(
format!(
"/v1/collections/{}/system_prompts/validate",
self.collection_id
),
Target::Writer,
ApiKeyPosition::Header,
system_prompt,
);
let response: serde_json::Value = self.client.request(request).await?;
let result = response["result"].clone();
Ok(serde_json::from_value(result)?)
}
}
#[derive(Debug, Clone)]
pub struct ToolsNamespace {
client: OramaClient,
collection_id: String,
}
impl ToolsNamespace {
pub(crate) fn new(client: OramaClient, collection_id: String) -> Self {
Self {
client,
collection_id,
}
}
pub async fn insert(&self, tool: InsertToolBody) -> Result<()> {
let request = ClientRequest::post(
format!("/v1/collections/{}/tools/insert", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
tool,
);
let _: serde_json::Value = self.client.request(request).await?;
Ok(())
}
pub async fn get(&self, id: &str) -> Result<Tool> {
let request = ClientRequest::<()>::get(
format!("/v1/collections/{}/tools/get", self.collection_id),
Target::Reader,
ApiKeyPosition::QueryParams,
)
.with_param("tool_id", id);
let response: serde_json::Value = self.client.request(request).await?;
let tool = response["tool"].clone();
Ok(serde_json::from_value(tool)?)
}
pub async fn get_all(&self) -> Result<Vec<Tool>> {
let request = ClientRequest::<()>::get(
format!("/v1/collections/{}/tools/all", self.collection_id),
Target::Reader,
ApiKeyPosition::QueryParams,
);
let response: serde_json::Value = self.client.request(request).await?;
let tools = response["tools"].clone();
Ok(serde_json::from_value(tools)?)
}
pub async fn delete(&self, id: &str) -> Result<serde_json::Value> {
let body = serde_json::json!({ "id": id });
let request = ClientRequest::post(
format!("/v1/collections/{}/tools/delete", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
body,
);
self.client.request(request).await
}
pub async fn update(&self, tool: UpdateToolBody) -> Result<serde_json::Value> {
let request = ClientRequest::post(
format!("/v1/collections/{}/tools/update", self.collection_id),
Target::Writer,
ApiKeyPosition::Header,
tool,
);
self.client.request(request).await
}
pub async fn execute<T>(&self, tools: ExecuteToolsBody) -> Result<ExecuteToolsParsedResponse<T>>
where
T: for<'de> serde::Deserialize<'de>,
{
let request = ClientRequest::post(
format!("/v1/collections/{}/tools/run", self.collection_id),
Target::Reader,
ApiKeyPosition::QueryParams,
tools,
);
self.client.request(request).await
}
}
/// Handle to one index inside a collection (obtained e.g. via `IndexNamespace::set`).
#[derive(Debug, Clone)]
pub struct Index {
    client: OramaClient,
    collection_id: String,
    index_id: String,
}

impl Index {
    pub(crate) fn new(client: OramaClient, collection_id: String, index_id: String) -> Self {
        Self {
            client,
            collection_id,
            index_id,
        }
    }

    /// Builds `/v1/collections/{collection_id}/indexes/{index_id}/{action}`.
    fn endpoint(&self, action: &str) -> String {
        format!(
            "/v1/collections/{}/indexes/{}/{}",
            self.collection_id, self.index_id, action
        )
    }

    /// Serializes `documents` into a `{ "documents": [...] }` body.
    ///
    /// `serde_json::json!` panics if an interpolated value fails to serialize.
    /// Because `T` is caller-supplied, that failure is returned as an error instead.
    fn documents_body<T: Serialize>(documents: &[T]) -> Result<serde_json::Value> {
        let mut body = serde_json::Map::new();
        body.insert("documents".to_owned(), serde_json::to_value(documents)?);
        Ok(serde_json::Value::Object(body))
    }

    /// POSTs `body` to this index's `action` endpoint on the writer and
    /// discards the JSON response payload.
    async fn post_writer(&self, action: &str, body: serde_json::Value) -> Result<()> {
        let request = ClientRequest::post(
            self.endpoint(action),
            Target::Writer,
            ApiKeyPosition::Header,
            body,
        );
        let _: serde_json::Value = self.client.request(request).await?;
        Ok(())
    }

    /// Triggers a reindex via `POST .../reindex` with a `()` body.
    pub async fn reindex(&self) -> Result<()> {
        let request = ClientRequest::<()>::post(
            self.endpoint("reindex"),
            Target::Writer,
            ApiKeyPosition::Header,
            (),
        );
        let _: serde_json::Value = self.client.request(request).await?;
        Ok(())
    }

    /// Inserts `documents` via `POST .../documents/insert`.
    ///
    /// # Errors
    /// Returns an error if any document fails to serialize or the request fails.
    pub async fn insert_documents<T>(&self, documents: Vec<T>) -> Result<()>
    where
        T: Serialize,
    {
        let body = Self::documents_body(&documents)?;
        self.post_writer("documents/insert", body).await
    }

    /// Deletes the documents with the given ids via `POST .../documents/delete`.
    pub async fn delete_documents(&self, document_ids: Vec<String>) -> Result<()> {
        let body = serde_json::json!({
            "document_ids": document_ids
        });
        self.post_writer("documents/delete", body).await
    }

    /// Upserts `documents` via `POST .../documents/upsert`.
    ///
    /// # Errors
    /// Returns an error if any document fails to serialize or the request fails.
    pub async fn upsert_documents<T>(&self, documents: Vec<T>) -> Result<()>
    where
        T: Serialize,
    {
        let body = Self::documents_body(&documents)?;
        self.post_writer("documents/upsert", body).await
    }
}
/// Entry point for working with one collection; exposes per-area namespaces
/// that share a single `OramaClient`.
#[derive(Clone, Debug)]
pub struct CollectionManager {
    client: OramaClient,
    collection_id: String,
    pub ai: AiNamespace,
    pub collections: CollectionsNamespace,
    pub index: IndexNamespace,
    pub hooks: HooksNamespace,
    pub system_prompts: SystemPromptsNamespace,
    pub tools: ToolsNamespace,
}

impl CollectionManager {
    /// Builds the authenticated client and all namespaces from `config`.
    ///
    /// An `api_key` starting with `p_` selects JWT auth, using `auth_jwt_url`
    /// or `DEFAULT_JWT_URL`. Any other key selects plain API-key auth. In both
    /// cases the reader URL falls back to `DEFAULT_READER_URL`, and a missing
    /// writer URL is forwarded as an empty string.
    ///
    /// # Errors
    /// Propagates any error from `OramaClient::new`.
    pub async fn new(config: CollectionManagerConfig) -> Result<Self> {
        let cluster = config.cluster.as_ref();
        let reader_url = cluster
            .and_then(|c| c.read_url.as_deref())
            .unwrap_or(DEFAULT_READER_URL);
        // NOTE(review): how `crate::auth` interprets an empty writer URL is not visible here.
        let writer_url = cluster
            .and_then(|c| c.writer_url.as_deref())
            .unwrap_or("");

        let auth_config = if config.api_key.starts_with("p_") {
            let jwt_url = config.auth_jwt_url.as_deref().unwrap_or(DEFAULT_JWT_URL);
            let jwt = JwtAuth::new(jwt_url, &config.collection_id, &config.api_key)
                .with_reader_url(reader_url)
                .with_writer_url(writer_url);
            AuthConfig::Jwt(jwt)
        } else {
            let key_auth = ApiKeyAuth::new(&config.api_key)
                .with_reader_url(reader_url)
                .with_writer_url(writer_url);
            AuthConfig::ApiKey(key_auth)
        };

        let http = Arc::new(Client::new());
        let shared = OramaClient::new(Auth::new(auth_config, http))?;
        let collection_id = config.collection_id;

        Ok(Self {
            ai: AiNamespace::new(shared.clone(), collection_id.clone()),
            collections: CollectionsNamespace::new(shared.clone(), collection_id.clone()),
            index: IndexNamespace::new(shared.clone(), collection_id.clone()),
            hooks: HooksNamespace::new(shared.clone(), collection_id.clone()),
            system_prompts: SystemPromptsNamespace::new(shared.clone(), collection_id.clone()),
            tools: ToolsNamespace::new(shared.clone(), collection_id.clone()),
            client: shared,
            collection_id,
        })
    }

    /// Runs a search on the reader and records client-side elapsed time.
    ///
    /// The elapsed value is measured around the request with
    /// `current_time_millis` and stored in `result.elapsed` along with its
    /// `format_duration` rendering.
    pub async fn search<T>(&self, query: &SearchParams) -> Result<SearchResult<T>>
    where
        T: for<'de> serde::Deserialize<'de>,
    {
        let started_at = current_time_millis();
        let path = format!("/v1/collections/{}/search", self.collection_id);
        let request = ClientRequest::post(path, Target::Reader, ApiKeyPosition::QueryParams, query);
        let mut outcome: SearchResult<T> = self.client.request(request).await?;
        let took = current_time_millis() - started_at;
        outcome.elapsed = Some(Elapsed {
            raw: took,
            formatted: format_duration(took),
        });
        Ok(outcome)
    }
}
impl CollectionManagerConfig {
    /// Creates a configuration for `collection_id` authenticated with `api_key`.
    ///
    /// The two arguments may be different string-like types, for example a
    /// `&str` id with an owned `String` key. Previously a single type parameter
    /// forced both to be the same type. No cluster override and no custom JWT
    /// endpoint are set.
    pub fn new<C, K>(collection_id: C, api_key: K) -> Self
    where
        C: Into<String>,
        K: Into<String>,
    {
        Self {
            collection_id: collection_id.into(),
            api_key: api_key.into(),
            cluster: None,
            auth_jwt_url: None,
        }
    }

    /// Sets reader/writer endpoint overrides.
    pub fn with_cluster(mut self, cluster: ClusterConfig) -> Self {
        self.cluster = Some(cluster);
        self
    }

    /// Sets the JWT endpoint (used only for `p_`-prefixed keys).
    pub fn with_auth_jwt_url<S: Into<String>>(mut self, url: S) -> Self {
        self.auth_jwt_url = Some(url.into());
        self
    }
}
impl ClusterConfig {
pub fn new() -> Self {
Self {
writer_url: None,
read_url: None,
}
}
pub fn with_writer_url<S: Into<String>>(mut self, url: S) -> Self {
self.writer_url = Some(url.into());
self
}
pub fn with_read_url<S: Into<String>>(mut self, url: S) -> Self {
self.read_url = Some(url.into());
self
}
}
impl Default for ClusterConfig {
fn default() -> Self {
Self::new()
}
}