use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use super::auth::EndpointAuthConfig;
use crate::llm::ReasoningEffort;
/// Client adapter selected for the embedding endpoint.
///
/// Serialized in kebab-case, i.e. `"tei"` or `"open-ai"`.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingAdapter {
    /// Serialized as `"tei"`.
    Tei,
    /// Serialized as `"open-ai"`.
    OpenAi,
}
/// Connection and model settings for the embedding endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct EmbeddingConfig {
    /// Base URL of the embedding service.
    pub base_url: String,
    /// Authentication settings, read from the same section as the other fields (`flatten`).
    #[serde(flatten)]
    pub auth: EndpointAuthConfig,
    /// Which embedding adapter to use; required.
    pub adapter: EmbeddingAdapter,
    /// Primary model name; also used for documents unless `document_model` overrides it.
    pub model: String,
    /// Optional document-side model; `None` or an empty string falls back to `model`.
    #[serde(default)]
    pub document_model: Option<String>,
    /// Token budget for this endpoint; 512 when omitted.
    #[serde(default = "default_embedding_context_tokens")]
    pub context_tokens: u32,
}

impl EmbeddingConfig {
    /// Returns the model name to use for documents.
    ///
    /// Yields `document_model` when it is set to a non-empty string, otherwise `model`.
    /// An empty override is treated as unset, in the same way `ChatAskConfig` resolves its
    /// `base_url`/`model` overrides. A blank `document_model` entry therefore cannot
    /// produce an empty model name.
    #[must_use]
    pub fn document_model(&self) -> &str {
        self.document_model
            .as_deref()
            .filter(|model| !model.is_empty())
            .unwrap_or(&self.model)
    }
}
/// Client adapter selected for the rerank endpoint.
///
/// Serialized in kebab-case: `"tei"`, `"minimal"`, `"cohere"` or `"jina"`.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum RerankAdapter {
    /// Serialized as `"tei"`.
    Tei,
    /// Serialized as `"minimal"`.
    Minimal,
    /// Serialized as `"cohere"`.
    Cohere,
    /// Serialized as `"jina"`.
    Jina,
}
/// How scores returned by the reranker are to be interpreted.
///
/// Serialized in kebab-case (`"normalized"` / `"logits"`); defaults to `Normalized`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum RerankScoreScale {
    /// Default scale. Presumably scores already lie in a bounded range; confirm against the rerank client.
    #[default]
    Normalized,
    /// Presumably raw, unbounded model logits; confirm against the rerank client.
    Logits,
}
/// Connection and model settings for the rerank endpoint.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct RerankConfig {
    /// Base URL of the rerank service; required.
    pub base_url: String,
    /// Authentication settings, read from the same section as the other fields (`flatten`).
    #[serde(flatten)]
    pub auth: EndpointAuthConfig,
    /// Which rerank adapter to use; required.
    pub adapter: RerankAdapter,
    /// Model name to request; required.
    pub model: String,
    /// Score interpretation; `RerankScoreScale::Normalized` when omitted.
    #[serde(default)]
    pub score_scale: RerankScoreScale,
    /// Truncation flag; `true` when omitted.
    #[serde(default = "default_rerank_truncate")]
    pub truncate: bool,
}
/// Client adapter selected for chat endpoints.
///
/// Serialized in kebab-case; the only (and default) variant is `"open-ai"`.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "kebab-case")]
pub enum ChatAdapter {
    /// Serialized as `"open-ai"`; used when no adapter is configured.
    #[default]
    OpenAi,
}
/// Connection and model settings for the expansion chat endpoint.
///
/// Its connection fields also act as fallbacks for `ChatAskConfig`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ChatExpansionConfig {
    /// Base URL of the chat service; required.
    pub base_url: String,
    /// Authentication settings, read from the same section as the other fields (`flatten`).
    #[serde(flatten)]
    pub auth: EndpointAuthConfig,
    /// Chat adapter; `ChatAdapter::OpenAi` when omitted.
    #[serde(default)]
    pub adapter: ChatAdapter,
    /// Model name to request; required.
    pub model: String,
    /// Token budget; 32 768 when omitted.
    #[serde(default = "default_chat_context_tokens")]
    pub context_tokens: u32,
    /// Optional output-token cap; `None` when omitted.
    #[serde(default)]
    pub max_output_tokens: Option<u32>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(deny_unknown_fields)]
pub struct ChatAskConfig {
#[serde(default)]
pub base_url: Option<String>,
#[serde(flatten)]
pub auth: EndpointAuthConfig,
#[serde(default)]
pub adapter: Option<ChatAdapter>,
#[serde(default)]
pub model: Option<String>,
#[serde(default = "default_ask_context_tokens")]
pub context_tokens: u32,
#[serde(default = "default_ask_max_output_tokens")]
pub max_output_tokens: u32,
#[serde(default)]
pub planning_enable_thinking: Option<bool>,
#[serde(default)]
pub synthesis_enable_thinking: Option<bool>,
#[serde(default)]
pub planning_reasoning_effort: Option<ReasoningEffort>,
#[serde(default)]
pub synthesis_reasoning_effort: Option<ReasoningEffort>,
#[serde(default)]
pub planning_chat_template_kwargs: Option<BTreeMap<String, serde_json::Value>>,
#[serde(default)]
pub synthesis_chat_template_kwargs: Option<BTreeMap<String, serde_json::Value>>,
}
impl ChatAskConfig {
    /// Base URL for ask requests.
    ///
    /// Uses this config's `base_url` when it is present and non-empty; otherwise uses the
    /// expansion endpoint's.
    #[must_use]
    pub fn resolved_base_url<'a>(&'a self, expansion: &'a ChatExpansionConfig) -> &'a str {
        match self.base_url.as_deref() {
            Some(url) if !url.is_empty() => url,
            _ => expansion.base_url.as_str(),
        }
    }

    /// Model name for ask requests.
    ///
    /// Uses this config's `model` when it is present and non-empty; otherwise uses the
    /// expansion model.
    #[must_use]
    pub fn resolved_model<'a>(&'a self, expansion: &'a ChatExpansionConfig) -> &'a str {
        match self.model.as_deref() {
            Some(model) if !model.is_empty() => model,
            _ => expansion.model.as_str(),
        }
    }

    /// Chat adapter for ask requests: the local override if set, else the expansion adapter.
    #[must_use]
    pub fn resolved_adapter(&self, expansion: &ChatExpansionConfig) -> ChatAdapter {
        self.adapter.unwrap_or(expansion.adapter)
    }

    /// Authentication for ask requests, layered over the expansion endpoint's auth.
    ///
    /// `credential`, `api_key` and `api_key_env` each come from this config when set,
    /// otherwise from the expansion config. `extra_headers` are merged, and ask-level
    /// entries replace expansion entries that share a key.
    ///
    /// NOTE(review): the three credential fields fall back independently. An ask-level
    /// `api_key` can therefore coexist with an inherited `credential`. Confirm that the
    /// consumer of `EndpointAuthConfig` gives the intended one precedence.
    #[must_use]
    pub fn resolved_auth(&self, expansion: &ChatExpansionConfig) -> EndpointAuthConfig {
        let own = &self.auth;
        let inherited = &expansion.auth;
        EndpointAuthConfig {
            credential: own.credential.as_ref().or(inherited.credential.as_ref()).cloned(),
            api_key: own.api_key.as_ref().or(inherited.api_key.as_ref()).cloned(),
            api_key_env: own.api_key_env.as_ref().or(inherited.api_key_env.as_ref()).cloned(),
            extra_headers: merge_headers(&inherited.extra_headers, &own.extra_headers),
        }
    }
}
/// Chat configuration: mandatory `expansion` settings plus optional `ask` overrides.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ChatSection {
    /// Required; also supplies the fallbacks used by `ChatAskConfig`'s `resolved_*` methods.
    pub expansion: ChatExpansionConfig,
    /// Optional; `ChatAskConfig::default()` is used when the key is absent.
    #[serde(default)]
    pub ask: ChatAskConfig,
}
/// MCP settings; every field may be omitted.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct McpConfig {
    /// Hook settings; `McpHooksConfig::default()` when omitted.
    #[serde(default)]
    pub hooks: McpHooksConfig,
}
/// Settings for MCP hooks.
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(deny_unknown_fields)]
pub struct McpHooksConfig {
    /// Recall deadline in milliseconds; 20 000 when omitted.
    #[serde(default = "default_recall_deadline_ms")]
    pub recall_deadline_ms: u64,
}

/// Hand-written so `McpHooksConfig::default()` matches the serde field default.
impl Default for McpHooksConfig {
    fn default() -> Self {
        let recall_deadline_ms = default_recall_deadline_ms();
        Self { recall_deadline_ms }
    }
}
/// Returns a copy of `base` with every entry of `override_headers` layered on top.
///
/// When a key appears in both maps, the value from `override_headers` wins. Keys found in
/// only one map are kept as they are. Neither input is modified.
fn merge_headers(
    base: &BTreeMap<String, String>,
    override_headers: &BTreeMap<String, String>,
) -> BTreeMap<String, String> {
    let mut combined = base.clone();
    for (name, value) in override_headers {
        combined.insert(name.clone(), value.clone());
    }
    combined
}
/// Serde default for `EmbeddingConfig::context_tokens`: 512.
const fn default_embedding_context_tokens() -> u32 {
    512
}

/// Serde default for `RerankConfig::truncate`: enabled.
const fn default_rerank_truncate() -> bool {
    true
}

/// Serde default for `ChatExpansionConfig::context_tokens`: 32 768 (32 Ki).
const fn default_chat_context_tokens() -> u32 {
    32 * 1024
}

/// Serde default for `ChatAskConfig::context_tokens`: 65 536 (64 Ki).
const fn default_ask_context_tokens() -> u32 {
    64 * 1024
}

/// Serde default for `ChatAskConfig::max_output_tokens`: 2 048 (2 Ki).
const fn default_ask_max_output_tokens() -> u32 {
    2 * 1024
}

/// Serde default for `McpHooksConfig::recall_deadline_ms`: 20 000 ms (20 s).
const fn default_recall_deadline_ms() -> u64 {
    20 * 1_000
}