use serde::{Deserialize, Serialize};
/// Model identifier carried in the `model` field of a rerank request.
///
/// Serde renames variants to lowercase, so [`RerankModel::Rerank`] is written
/// (and read back) as `"rerank"`. `Rerank` is also the [`Default`] value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RerankModel {
    /// The `rerank` model; selected by `RerankModel::default()`.
    #[default]
    Rerank,
}
/// Request body for a rerank call: a `query` plus the candidate `documents`.
///
/// Serialized with serde's derived `Serialize`, which emits keys in field
/// declaration order. Every `Option` field is omitted entirely when `None`.
/// Size limits are checked locally by `RerankBody::validate_constraints`
/// (query at most 4096 characters, 1-128 documents of at most 4096 characters
/// each, `top_n` no larger than the number of documents).
#[derive(Debug, Clone, Serialize)]
pub struct RerankBody {
/// Model to use; see [`RerankModel`].
pub model: RerankModel,
/// Query text. `validate_constraints` rejects more than 4096 characters.
pub query: String,
/// Candidate documents. `validate_constraints` requires 1-128 entries, each at
/// most 4096 characters.
pub documents: Vec<String>,
/// Optional; omitted from the payload when `None`. Presumably limits how many
/// ranked results are returned — confirm against the API docs.
/// `validate_constraints` rejects values greater than `documents.len()`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_n: Option<usize>,
/// Optional flag; omitted from the payload when `None`. Presumably asks the
/// service to include document text in results — TODO confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub return_documents: Option<bool>,
/// Optional flag; omitted from the payload when `None`. Presumably requests
/// unnormalized scores — TODO confirm against the API docs.
#[serde(skip_serializing_if = "Option::is_none")]
pub return_raw_scores: Option<bool>,
/// Optional caller-supplied request identifier; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
/// Optional caller-supplied user identifier; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub user_id: Option<String>,
}
impl RerankBody {
pub fn new(model: RerankModel, query: impl Into<String>, documents: Vec<String>) -> Self {
Self {
model,
query: query.into(),
documents,
top_n: None,
return_documents: None,
return_raw_scores: None,
request_id: None,
user_id: None,
}
}
pub fn with_top_n(mut self, n: usize) -> Self {
self.top_n = Some(n);
self
}
pub fn with_return_documents(mut self, v: bool) -> Self {
self.return_documents = Some(v);
self
}
pub fn with_return_raw_scores(mut self, v: bool) -> Self {
self.return_raw_scores = Some(v);
self
}
pub fn with_request_id(mut self, v: impl Into<String>) -> Self {
self.request_id = Some(v.into());
self
}
pub fn with_user_id(mut self, v: impl Into<String>) -> Self {
self.user_id = Some(v.into());
self
}
pub fn validate_constraints(&self) -> crate::ZaiResult<()> {
if self.query.chars().count() > 4096 {
return Err(crate::client::error::ZaiError::ApiError {
code: 1200,
message: "query length exceeds 4096 characters".to_string(),
});
}
if self.documents.is_empty() {
return Err(crate::client::error::ZaiError::ApiError {
code: 1200,
message: "documents must not be empty".to_string(),
});
}
if self.documents.len() > 128 {
return Err(crate::client::error::ZaiError::ApiError {
code: 1200,
message: "documents length exceeds 128".to_string(),
});
}
for (i, d) in self.documents.iter().enumerate() {
if d.chars().count() > 4096 {
return Err(crate::client::error::ZaiError::ApiError {
code: 1200,
message: format!("document at index {} exceeds 4096 characters", i),
});
}
}
if let Some(n) = self.top_n
&& n > self.documents.len()
{
return Err(crate::client::error::ZaiError::ApiError {
code: 1200,
message: "top_n cannot exceed documents length".to_string(),
});
}
Ok(())
}
}