use serde::{Deserialize, Serialize};
/// Upper bound on the length of a rerank instruction.
///
/// Measured in `char`s (Unicode scalar values), not bytes.
pub const MAX_INSTRUCTION_CHARS: usize = 400;
/// Reranker model selection.
///
/// (De)serialized through serde using the wire names `"rerank-2.5"`,
/// `"rerank-2.5-lite"`, and `"none"`. The default is [`RerankParam::Rerank25`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum RerankParam {
    /// The `rerank-2.5` model; selected by `Default`.
    #[serde(rename = "rerank-2.5")]
    #[default]
    Rerank25,
    /// The `rerank-2.5-lite` model.
    #[serde(rename = "rerank-2.5-lite")]
    Rerank25Lite,
    /// No reranker model is selected.
    #[serde(rename = "none")]
    None,
}
impl RerankParam {
    /// Returns the model identifier for this setting.
    ///
    /// Returns `None` for [`RerankParam::None`]. The strings are the same as the
    /// serde wire names of the corresponding variants.
    #[must_use]
    pub const fn model_name(self) -> Option<&'static str> {
        match self {
            Self::Rerank25 => Some("rerank-2.5"),
            Self::Rerank25Lite => Some("rerank-2.5-lite"),
            Self::None => None,
        }
    }

    /// Returns the number of tokens billed for `total_tokens` of rerank input.
    ///
    /// For `Rerank25Lite` this is half of `total_tokens`, rounded up. For every
    /// other variant it is `total_tokens` unchanged.
    #[must_use]
    pub const fn billed_tokens(self, total_tokens: u64) -> u64 {
        match self {
            Self::Rerank25Lite => total_tokens.div_ceil(2),
            // Each variant is listed explicitly instead of using a `_` arm, so
            // adding a variant forces a conscious billing decision.
            // NOTE(review): `None` returns the count unchanged; confirm callers
            // never bill when no reranker model is selected.
            Self::Rerank25 | Self::None => total_tokens,
        }
    }
}
/// The closed set of rerank reason codes that `known_reason` recognizes.
///
/// Any string not in this list is rejected by `known_reason`.
pub const RERANK_REASONS: &[&str] =
    &["not_requested", "token_budget_exhausted", "provider_error", "disabled"];
/// Maps `raw` onto its canonical `'static` entry in `RERANK_REASONS`.
///
/// Matching is exact and case-sensitive. Anything not in the list yields `None`,
/// so arbitrary input text is never passed through.
#[must_use]
pub fn known_reason(raw: &str) -> Option<&'static str> {
    RERANK_REASONS
        .iter()
        .find(|&&candidate| candidate == raw)
        .copied()
}
/// Checks that a rerank instruction is no longer than `MAX_INSTRUCTION_CHARS`.
///
/// Length is counted in `char`s (Unicode scalar values), not bytes, so multi-byte
/// characters count once each.
///
/// # Errors
///
/// Returns a human-readable message when the instruction exceeds the cap. The
/// message includes both the actual length and the cap.
pub fn validate_instruction(instruction: &str) -> Result<(), String> {
    let char_count = instruction.chars().count();
    if char_count <= MAX_INSTRUCTION_CHARS {
        return Ok(());
    }
    Err(format!(
        "rerank_instructions is {char_count} characters; the cap is {MAX_INSTRUCTION_CHARS}. \
         Shorter instructions also cost fewer tokens (the instruction is \
         multiplied by the candidate-pool size)."
    ))
}
/// Builds a default rerank instruction from request context.
///
/// * `code_exclusive` adds a hint to favour code-bearing chunks over prose.
/// * `version` is a `(name, ver)` pair that adds a hint to favour content for that
///   version.
///
/// Hints are joined with a single space, code hint first. Returns `None` when
/// neither hint applies.
#[must_use]
pub fn default_instruction(code_exclusive: bool, version: Option<(&str, &str)>) -> Option<String> {
    let code_hint = code_exclusive.then(|| {
        String::from(
            "Prioritize chunks containing code examples, function signatures, and API usage \
             over prose.",
        )
    });
    let version_hint = version.map(|(name, ver)| {
        format!("Prefer content applying to {name} version {ver}; deprioritize other versions.")
    });

    let hints: Vec<String> = code_hint.into_iter().chain(version_hint).collect();
    (!hints.is_empty()).then(|| hints.join(" "))
}
/// Appends an optional instruction to a rerank query.
///
/// The instruction is trimmed first. If it is missing or blank after trimming,
/// `query` is returned unchanged. Otherwise the result is
/// `"{query}\nInstructions: {instruction}"`. The query itself is never trimmed.
#[must_use]
pub fn compose_rerank_query(query: &str, instruction: Option<&str>) -> String {
    let Some(trimmed) = instruction.map(str::trim).filter(|text| !text.is_empty()) else {
        return query.to_owned();
    };
    format!("{query}\nInstructions: {trimmed}")
}
// Unit tests for the rerank parameter, reason, and instruction helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rerank_param_wire_values_round_trip() {
        for (variant, wire) in [
            (RerankParam::Rerank25, "\"rerank-2.5\""),
            (RerankParam::Rerank25Lite, "\"rerank-2.5-lite\""),
            (RerankParam::None, "\"none\""),
        ] {
            assert_eq!(serde_json::to_string(&variant).unwrap(), wire);
            let back: RerankParam = serde_json::from_str(wire).unwrap();
            assert_eq!(back, variant);
        }
        assert_eq!(RerankParam::default(), RerankParam::Rerank25);
    }

    #[test]
    fn model_name_is_none_only_for_none() {
        assert_eq!(RerankParam::Rerank25.model_name(), Some("rerank-2.5"));
        assert_eq!(RerankParam::Rerank25Lite.model_name(), Some("rerank-2.5-lite"));
        assert_eq!(RerankParam::None.model_name(), None);
    }

    #[test]
    fn lite_bills_half_rounded_up() {
        assert_eq!(RerankParam::Rerank25.billed_tokens(1001), 1001);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1000), 500);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1001), 501);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(0), 0);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1), 1);
        assert_eq!(RerankParam::None.billed_tokens(10), 10);
    }

    #[test]
    fn instruction_cap_is_400_chars() {
        assert!(validate_instruction(&"x".repeat(400)).is_ok());
        let err = validate_instruction(&"x".repeat(401)).unwrap_err();
        assert!(err.contains("400"), "error should name the cap: {err}");
        // The cap counts chars, not bytes: 400 two-byte chars still pass.
        assert!(validate_instruction(&"é".repeat(400)).is_ok());
    }

    #[test]
    fn default_instruction_rule_table() {
        assert_eq!(default_instruction(false, None), None);
        let code = default_instruction(true, None).unwrap();
        assert!(code.contains("code examples"));
        let ver = default_instruction(false, Some(("compact", "0.31"))).unwrap();
        assert!(ver.contains("compact") && ver.contains("0.31"));
        let both = default_instruction(true, Some(("compact", "0.31"))).unwrap();
        assert!(both.contains("code examples") && both.contains("0.31"));
    }

    #[test]
    fn known_reason_passes_closed_set_and_drops_others() {
        for r in RERANK_REASONS {
            assert_eq!(known_reason(r), Some(*r));
        }
        assert_eq!(known_reason(""), None);
        assert_eq!(known_reason("applied"), None);
        // Matching is exact and case-sensitive.
        assert_eq!(known_reason("Not_Requested"), None);
        // Free-form text (here resembling a leaked credential) must not pass through.
        assert_eq!(known_reason("rate limited: token=eyJhbGci"), None);
    }

    #[test]
    fn compose_appends_instruction_to_query() {
        assert_eq!(compose_rerank_query("how do circuits work", None), "how do circuits work");
        assert_eq!(compose_rerank_query("q", Some(" ")), "q");
        let composed = compose_rerank_query("q", Some("Prioritize code."));
        assert_eq!(composed, "q\nInstructions: Prioritize code.");
        // Surrounding whitespace on the instruction is trimmed before composing.
        let trimmed = compose_rerank_query("q", Some("  Prioritize code.  "));
        assert_eq!(trimmed, "q\nInstructions: Prioritize code.");
    }
}