use serde::{Deserialize, Serialize};
use crate::types::SearchRequest;
/// Optional per-search settings.
///
/// Every field is an `Option`; the derived `Default` yields a value with every field `None`.
/// A populated value can be built from a [`SearchRequest`] with `SearchParams::from(&request)`.
///
/// Field declaration order is observable: the derived `Serialize` writes fields, and the
/// derived `Debug` prints them, in the order they are declared here.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchParams {
/// Copied from `SearchRequest::top_k`; `top_k_or` reads it with a caller-supplied fallback.
/// NOTE(review): presumably the number of results to return — confirm with consumers.
pub top_k: Option<usize>,
/// Copied (cloned) from `SearchRequest::system_prompt`.
/// NOTE(review): precedence relative to `system_prompt_path` is decided by consumers, not here.
pub system_prompt: Option<String>,
/// Copied (cloned) from `SearchRequest::system_prompt_path`.
pub system_prompt_path: Option<String>,
/// Copied from `SearchRequest::wide_search_top_k`; `wide_search_top_k_or` reads it with a fallback.
pub wide_search_top_k: Option<usize>,
/// Copied from `SearchRequest::triplet_distance_penalty`; `triplet_distance_penalty_or`
/// reads it with a fallback.
pub triplet_distance_penalty: Option<f32>,
/// Copied (cloned) from `SearchRequest::node_type`.
pub node_type: Option<String>,
/// Copied (cloned) from `SearchRequest::node_name`.
pub node_name: Option<Vec<String>>,
/// Copied (cloned) from `SearchRequest::node_name_filter_operator`.
/// NOTE(review): stored as a free-form string; the accepted values are not visible here.
pub node_name_filter_operator: Option<String>,
/// Copied from `SearchRequest::feedback_influence`; `feedback_influence_or` reads it with a fallback.
pub feedback_influence: Option<f32>,
/// Not a dedicated request field: looked up under the `"max_iter"` key of
/// `SearchRequest::retriever_specific_config`. `None` when the config or key is missing,
/// or when the value is not accepted by `as_u64`.
pub max_iter: Option<usize>,
/// Not a dedicated request field: looked up under the `"context_extension_rounds"` key of
/// `SearchRequest::retriever_specific_config`. `None` when the config or key is missing,
/// or when the value is not accepted by `as_u64`.
pub context_extension_rounds: Option<usize>,
/// Copied (cloned) from `SearchRequest::response_schema`; kept as untyped JSON.
pub response_schema: Option<serde_json::Value>,
/// Copied from `SearchRequest::neighborhood_depth`.
pub neighborhood_depth: Option<usize>,
/// Copied from `SearchRequest::neighborhood_seed_top_k`.
pub neighborhood_seed_top_k: Option<usize>,
}
impl SearchParams {
    /// Yields the wrapped value when present, otherwise `fallback`.
    ///
    /// Every `*_or` accessor below delegates here, so unset fields are resolved uniformly.
    fn or_fallback<T>(value: Option<T>, fallback: T) -> T {
        value.unwrap_or(fallback)
    }

    /// Returns `top_k` if it was set, otherwise `default`.
    pub fn top_k_or(&self, default: usize) -> usize {
        Self::or_fallback(self.top_k, default)
    }

    /// Returns `wide_search_top_k` if it was set, otherwise `default`.
    pub fn wide_search_top_k_or(&self, default: usize) -> usize {
        Self::or_fallback(self.wide_search_top_k, default)
    }

    /// Returns `triplet_distance_penalty` if it was set, otherwise `default`.
    pub fn triplet_distance_penalty_or(&self, default: f32) -> f32 {
        Self::or_fallback(self.triplet_distance_penalty, default)
    }

    /// Returns `feedback_influence` if it was set, otherwise `default`.
    pub fn feedback_influence_or(&self, default: f32) -> f32 {
        Self::or_fallback(self.feedback_influence, default)
    }
}
impl From<&SearchRequest> for SearchParams {
fn from(req: &SearchRequest) -> Self {
Self {
top_k: req.top_k,
system_prompt: req.system_prompt.clone(),
system_prompt_path: req.system_prompt_path.clone(),
wide_search_top_k: req.wide_search_top_k,
triplet_distance_penalty: req.triplet_distance_penalty,
node_type: req.node_type.clone(),
node_name: req.node_name.clone(),
node_name_filter_operator: req.node_name_filter_operator.clone(),
feedback_influence: req.feedback_influence,
max_iter: req
.retriever_specific_config
.as_ref()
.and_then(|c| c.get("max_iter"))
.and_then(|v| v.as_u64())
.map(|v| v as usize),
context_extension_rounds: req
.retriever_specific_config
.as_ref()
.and_then(|c| c.get("context_extension_rounds"))
.and_then(|v| v.as_u64())
.map(|v| v as usize),
response_schema: req.response_schema.clone(),
neighborhood_depth: req.neighborhood_depth,
neighborhood_seed_top_k: req.neighborhood_seed_top_k,
}
}
}