use std::time::Duration;
use secrecy::SecretString;
use serde::{Deserialize, Serialize};
use crate::embedding::EmbedderConfig;
use crate::reranker::RerankerConfig;
use crate::{Dialect, ModelFamily, Result};
/// Top-level application configuration, holding the embedding settings and the reranker
/// settings as two separate sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
    /// Embedding section; converted into an [`EmbedderConfig`] by
    /// [`Embedding::to_embedder_config`].
    pub embedding: Embedding,
    /// Reranker section; converted into a [`RerankerConfig`] by
    /// [`Reranker::to_reranker_config`].
    pub reranker: Reranker,
}
/// Embedding section of the configuration file.
///
/// Numeric limits are checked (must be non-zero) only when [`Embedding::to_embedder_config`]
/// runs, not at deserialization time. `requests_per_minute`, `max_concurrent_requests` and
/// `tokens_per_minute` have no serde default here (unlike [`Reranker`]), so they are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Embedding {
    /// Base URL, forwarded as `EmbedderConfig::base_url`. Empty string when omitted.
    #[serde(default)]
    pub url: String,
    /// Optional API key. Accepted on deserialization but never serialized, so it does not
    /// appear in dumped configuration.
    #[serde(skip_serializing)]
    pub api_key: Option<SecretString>,
    /// Model identifier, forwarded unchanged.
    pub model: String,
    /// Tokenizer identifier; empty string when omitted. Not forwarded by
    /// `to_embedder_config` (NOTE(review): presumably consumed elsewhere — confirm).
    #[serde(default)]
    pub tokenizer: String,
    /// Dialect forwarded to the embedder configuration (the string `"llama.cpp"` is
    /// accepted and maps to `Dialect::LlamaCpp`).
    pub dialect: Dialect,
    /// Model family; `ModelFamily::default()` when omitted.
    #[serde(default)]
    pub model_family: ModelFamily,
    /// Optional query instruction, forwarded unchanged; left out of serialized output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query_instruction: Option<String>,
    /// Timeout in whole seconds, converted to a `Duration`; must be greater than zero.
    pub timeout_seconds: u64,
    /// Embedding dimension; must be greater than zero.
    pub embedding_dim: usize,
    /// Context length; defaults to 32_768 and must be greater than zero. Validated but not
    /// forwarded by `to_embedder_config`.
    #[serde(default = "default_context_length")]
    pub context_length: usize,
    /// Maximum batch size; defaults to 15 and must be greater than zero. Validated but not
    /// forwarded by `to_embedder_config`.
    #[serde(default = "default_max_batch_size")]
    pub max_batch_size: usize,
    /// Worker count; defaults to 5 and must be greater than zero. Validated but not
    /// forwarded by `to_embedder_config`.
    #[serde(default = "default_embedding_workers")]
    pub workers: usize,
    /// Requests-per-minute limit; required and must be greater than zero.
    pub requests_per_minute: usize,
    /// Concurrent-request limit; required and must be greater than zero.
    pub max_concurrent_requests: usize,
    /// Tokens-per-minute limit; required and must be greater than zero.
    pub tokens_per_minute: u32,
}
/// Reranker section of the configuration file.
///
/// Numeric limits are checked (must be non-zero) only when [`Reranker::to_reranker_config`]
/// runs. Unlike [`Embedding`], the rate limits have serde defaults and may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reranker {
    /// Base URL, forwarded as `RerankerConfig::base_url`. Empty string when omitted.
    #[serde(default)]
    pub url: String,
    /// Optional API key. Accepted on deserialization but never serialized, so it does not
    /// appear in dumped configuration.
    #[serde(skip_serializing)]
    pub api_key: Option<SecretString>,
    /// Model identifier, forwarded unchanged.
    pub model: String,
    /// Dialect forwarded to the reranker configuration (the string `"llama.cpp"` is
    /// accepted and maps to `Dialect::LlamaCpp`).
    pub dialect: Dialect,
    /// Model family; `ModelFamily::default()` when omitted.
    #[serde(default)]
    pub model_family: ModelFamily,
    /// Timeout in whole seconds, converted to a `Duration`; must be greater than zero.
    pub timeout_seconds: u64,
    /// Optional instruction, forwarded unchanged; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub instruction: Option<String>,
    /// Requests-per-minute limit; defaults to 1000 and must be greater than zero.
    #[serde(default = "default_reranker_requests_per_minute")]
    pub requests_per_minute: usize,
    /// Concurrent-request limit; defaults to 50 and must be greater than zero.
    #[serde(default = "default_reranker_max_concurrent_requests")]
    pub max_concurrent_requests: usize,
    /// Tokens-per-minute limit; defaults to 1_000_000 and must be greater than zero.
    #[serde(default = "default_reranker_tokens_per_minute")]
    pub tokens_per_minute: u32,
}
impl Embedding {
    /// Validates the embedding section and builds the runtime [`EmbedderConfig`] from it.
    ///
    /// Only the fields `EmbedderConfig` carries are copied across. `tokenizer`,
    /// `context_length`, `max_batch_size` and `workers` are checked here but not forwarded.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::InvalidConfiguration` for the first numeric limit that is zero.
    /// Fields are checked in this order: timeout, the `usize` limits in declaration order,
    /// then `tokens_per_minute`.
    pub fn to_embedder_config(&self) -> Result<EmbedderConfig> {
        validate_positive_u64(self.timeout_seconds, "embedding.timeout_seconds")?;

        // Checked in order; `?` stops at the first zero, so the reported field is stable.
        let usize_limits = [
            (self.embedding_dim, "embedding.embedding_dim"),
            (self.context_length, "embedding.context_length"),
            (self.max_batch_size, "embedding.max_batch_size"),
            (self.workers, "embedding.workers"),
            (self.requests_per_minute, "embedding.requests_per_minute"),
            (self.max_concurrent_requests, "embedding.max_concurrent_requests"),
        ];
        for (value, field) in usize_limits {
            validate_positive_usize(value, field)?;
        }

        validate_positive_u32(self.tokens_per_minute, "embedding.tokens_per_minute")?;

        let timeout = Duration::from_secs(self.timeout_seconds);
        Ok(EmbedderConfig {
            model: self.model.clone(),
            base_url: self.url.clone(),
            api_key: self.api_key.clone(),
            timeout,
            dialect: self.dialect,
            model_family: self.model_family,
            query_instruction: self.query_instruction.clone(),
            embedding_dim: self.embedding_dim,
            requests_per_minute: self.requests_per_minute,
            max_concurrent_requests: self.max_concurrent_requests,
            tokens_per_minute: self.tokens_per_minute,
        })
    }
}
impl Reranker {
    /// Validates the reranker section and builds the runtime [`RerankerConfig`] from it.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::InvalidConfiguration` for the first numeric limit that is zero.
    /// Fields are checked in this order: timeout, requests per minute, concurrent requests,
    /// tokens per minute.
    pub fn to_reranker_config(&self) -> Result<RerankerConfig> {
        validate_positive_u64(self.timeout_seconds, "reranker.timeout_seconds")?;

        let usize_limits = [
            (self.requests_per_minute, "reranker.requests_per_minute"),
            (self.max_concurrent_requests, "reranker.max_concurrent_requests"),
        ];
        for (value, field) in usize_limits {
            validate_positive_usize(value, field)?;
        }

        validate_positive_u32(self.tokens_per_minute, "reranker.tokens_per_minute")?;

        let timeout = Duration::from_secs(self.timeout_seconds);
        Ok(RerankerConfig {
            model: self.model.clone(),
            base_url: self.url.clone(),
            api_key: self.api_key.clone(),
            timeout,
            dialect: self.dialect,
            model_family: self.model_family,
            instruction: self.instruction.clone(),
            requests_per_minute: self.requests_per_minute,
            max_concurrent_requests: self.max_concurrent_requests,
            tokens_per_minute: self.tokens_per_minute,
        })
    }
}
/// Default for `Embedding::context_length` (32 Ki).
const DEFAULT_CONTEXT_LENGTH: usize = 32 * 1024;
/// Default for `Embedding::max_batch_size`.
const DEFAULT_MAX_BATCH_SIZE: usize = 15;
/// Default for `Embedding::workers`.
const DEFAULT_EMBEDDING_WORKERS: usize = 5;
/// Default for `Reranker::requests_per_minute`.
const DEFAULT_RERANKER_REQUESTS_PER_MINUTE: usize = 1_000;
/// Default for `Reranker::max_concurrent_requests`.
const DEFAULT_RERANKER_MAX_CONCURRENT_REQUESTS: usize = 50;
/// Default for `Reranker::tokens_per_minute`.
const DEFAULT_RERANKER_TOKENS_PER_MINUTE: u32 = 1_000_000;

// `#[serde(default = "...")]` takes a function path, so each constant gets a tiny accessor.

/// Serde default for `Embedding::context_length`.
fn default_context_length() -> usize {
    DEFAULT_CONTEXT_LENGTH
}
/// Serde default for `Embedding::max_batch_size`.
fn default_max_batch_size() -> usize {
    DEFAULT_MAX_BATCH_SIZE
}
/// Serde default for `Embedding::workers`.
fn default_embedding_workers() -> usize {
    DEFAULT_EMBEDDING_WORKERS
}
/// Serde default for `Reranker::requests_per_minute`.
fn default_reranker_requests_per_minute() -> usize {
    DEFAULT_RERANKER_REQUESTS_PER_MINUTE
}
/// Serde default for `Reranker::max_concurrent_requests`.
fn default_reranker_max_concurrent_requests() -> usize {
    DEFAULT_RERANKER_MAX_CONCURRENT_REQUESTS
}
/// Serde default for `Reranker::tokens_per_minute`.
fn default_reranker_tokens_per_minute() -> u32 {
    DEFAULT_RERANKER_TOKENS_PER_MINUTE
}
/// Rejects a zero value for a configuration field.
///
/// This is the single implementation behind the width-specific wrappers below. Keeping one
/// copy means every numeric limit produces exactly the same error text, whatever its type.
///
/// # Errors
///
/// Returns `crate::Error::InvalidConfiguration` with the message
/// `"{field} must be greater than zero"` when `value` is zero.
fn validate_positive<T>(value: T, field: &'static str) -> Result<()>
where
    T: PartialEq + From<u8>,
{
    // `From<u8>` gives a type-correct zero for every unsigned width used by this module.
    if value == T::from(0_u8) {
        return Err(crate::Error::InvalidConfiguration {
            message: format!("{field} must be greater than zero"),
        });
    }
    Ok(())
}

/// Rejects a zero `usize` configuration value; see [`validate_positive`].
fn validate_positive_usize(value: usize, field: &'static str) -> Result<()> {
    validate_positive(value, field)
}

/// Rejects a zero `u64` configuration value; see [`validate_positive`].
fn validate_positive_u64(value: u64, field: &'static str) -> Result<()> {
    validate_positive(value, field)
}

/// Rejects a zero `u32` configuration value; see [`validate_positive`].
fn validate_positive_u32(value: u32, field: &'static str) -> Result<()> {
    validate_positive(value, field)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid embedding section that omits every field with a serde default.
    fn minimal_embedding_json() -> serde_json::Value {
        serde_json::json!({
            "url": "",
            "model": "Qwen/Qwen3-Embedding-0.6B",
            "dialect": "deepinfra",
            "model_family": "qwen3",
            "timeout_seconds": 10,
            "embedding_dim": 1024,
            "requests_per_minute": 1,
            "max_concurrent_requests": 1,
            "tokens_per_minute": 1
        })
    }

    /// A valid reranker section that omits every field with a serde default.
    fn minimal_reranker_json() -> serde_json::Value {
        serde_json::json!({
            "url": "",
            "model": "Qwen/Qwen3-Reranker-0.6B",
            "dialect": "deepinfra",
            "model_family": "qwen3",
            "timeout_seconds": 10
        })
    }

    /// Asserts that `result` is an `InvalidConfiguration` error blaming exactly `field`.
    ///
    /// A bare "some error" check would also pass on a wrong field name, or on a failure in an
    /// unrelated field.
    fn assert_rejects_field<T>(result: crate::Result<T>, field: &str) {
        match result {
            Err(crate::Error::InvalidConfiguration { message }) => {
                assert_eq!(message, format!("{field} must be greater than zero"));
            }
            Err(other) => panic!("expected InvalidConfiguration for {field}, got {other:?}"),
            Ok(_) => panic!("expected {field} to be rejected, but conversion succeeded"),
        }
    }

    #[test]
    fn embedding_config_accepts_llama_dot_cpp_alias() {
        let config: Embedding = serde_json::from_value(serde_json::json!({
            "url": "",
            "model": "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf",
            "tokenizer": "",
            "dialect": "llama.cpp",
            "model_family": "gemma",
            "query_instruction": "retrieve docs",
            "timeout_seconds": 30,
            "embedding_dim": 768,
            "context_length": default_context_length(),
            "max_batch_size": default_max_batch_size(),
            "workers": default_embedding_workers(),
            "requests_per_minute": 1,
            "max_concurrent_requests": 1,
            "tokens_per_minute": 1_000_000
        }))
        .unwrap();
        let converted = config.to_embedder_config().unwrap();
        assert_eq!(converted.dialect, Dialect::LlamaCpp);
        assert_eq!(converted.model_family, ModelFamily::Gemma);
        assert_eq!(
            converted.query_instruction.as_deref(),
            Some("retrieve docs")
        );
    }

    #[test]
    fn reranker_config_accepts_llama_dot_cpp_alias() {
        let config: Reranker = serde_json::from_value(serde_json::json!({
            "url": "",
            "model": "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf",
            "dialect": "llama.cpp",
            "model_family": "qwen3",
            "timeout_seconds": 30,
            "instruction": "rank docs",
            "requests_per_minute": 1,
            "max_concurrent_requests": 1,
            "tokens_per_minute": 1_000_000
        }))
        .unwrap();
        let converted = config.to_reranker_config().unwrap();
        assert_eq!(converted.dialect, Dialect::LlamaCpp);
        assert_eq!(converted.model_family, ModelFamily::Qwen3);
        assert_eq!(converted.instruction.as_deref(), Some("rank docs"));
    }

    #[test]
    fn config_conversion_rejects_zero_limits() {
        // Zero each validated embedding field in turn; all others stay valid, so the error
        // must name that field and no other.
        for field in [
            "timeout_seconds",
            "embedding_dim",
            "context_length",
            "max_batch_size",
            "workers",
            "requests_per_minute",
            "max_concurrent_requests",
            "tokens_per_minute",
        ] {
            let mut json = minimal_embedding_json();
            json[field] = serde_json::json!(0);
            let embedding: Embedding = serde_json::from_value(json).unwrap();
            assert_rejects_field(embedding.to_embedder_config(), &format!("embedding.{field}"));
        }

        for field in [
            "timeout_seconds",
            "requests_per_minute",
            "max_concurrent_requests",
            "tokens_per_minute",
        ] {
            let mut json = minimal_reranker_json();
            json[field] = serde_json::json!(0);
            let reranker: Reranker = serde_json::from_value(json).unwrap();
            assert_rejects_field(reranker.to_reranker_config(), &format!("reranker.{field}"));
        }
    }

    #[test]
    fn omitted_fields_fall_back_to_defaults() {
        let embedding: Embedding = serde_json::from_value(minimal_embedding_json()).unwrap();
        assert_eq!(embedding.context_length, default_context_length());
        assert_eq!(embedding.max_batch_size, default_max_batch_size());
        assert_eq!(embedding.workers, default_embedding_workers());
        assert!(embedding.tokenizer.is_empty());
        assert!(embedding.api_key.is_none());
        assert!(embedding.query_instruction.is_none());
        assert!(embedding.to_embedder_config().is_ok());

        let reranker: Reranker = serde_json::from_value(minimal_reranker_json()).unwrap();
        assert_eq!(
            reranker.requests_per_minute,
            default_reranker_requests_per_minute()
        );
        assert_eq!(
            reranker.max_concurrent_requests,
            default_reranker_max_concurrent_requests()
        );
        assert_eq!(
            reranker.tokens_per_minute,
            default_reranker_tokens_per_minute()
        );
        assert!(reranker.api_key.is_none());
        assert!(reranker.instruction.is_none());
        assert!(reranker.to_reranker_config().is_ok());
    }

    #[test]
    fn api_keys_are_never_serialized() {
        let mut embedding_json = minimal_embedding_json();
        embedding_json["api_key"] = serde_json::json!("embedding-secret");
        let embedding: Embedding = serde_json::from_value(embedding_json).unwrap();
        assert!(embedding.api_key.is_some());
        let dumped = serde_json::to_value(&embedding).unwrap();
        assert!(dumped.get("api_key").is_none());

        let mut reranker_json = minimal_reranker_json();
        reranker_json["api_key"] = serde_json::json!("reranker-secret");
        let reranker: Reranker = serde_json::from_value(reranker_json).unwrap();
        assert!(reranker.api_key.is_some());
        let dumped = serde_json::to_value(&reranker).unwrap();
        assert!(dumped.get("api_key").is_none());
    }
}