1use serde::{Deserialize, Serialize};
8
/// Upper bound on a rerank instruction's length, enforced by [`validate_instruction`].
///
/// Measured in `char`s (Unicode scalar values, via `str::chars().count()`), not bytes.
pub const MAX_INSTRUCTION_CHARS: usize = 400;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
20pub enum RerankParam {
21 #[default]
23 #[serde(rename = "rerank-2.5")]
24 Rerank25,
25 #[serde(rename = "rerank-2.5-lite")]
27 Rerank25Lite,
28 #[serde(rename = "none")]
30 None,
31}
32
33impl RerankParam {
34 #[must_use]
36 pub const fn model_name(self) -> Option<&'static str> {
37 match self {
38 Self::Rerank25 => Some("rerank-2.5"),
39 Self::Rerank25Lite => Some("rerank-2.5-lite"),
40 Self::None => None,
41 }
42 }
43
44 #[must_use]
48 pub const fn billed_tokens(self, total_tokens: u64) -> u64 {
49 match self {
50 Self::Rerank25Lite => total_tokens.div_ceil(2),
51 _ => total_tokens,
52 }
53 }
54}
55
/// Closed set of reason strings that [`known_reason`] passes through.
///
/// Matching against this list is exact; any string not listed here maps to `None`.
pub const RERANK_REASONS: &[&str] =
    &["not_requested", "token_budget_exhausted", "provider_error", "disabled"];
68
69#[must_use]
79pub fn known_reason(raw: &str) -> Option<&'static str> {
80 RERANK_REASONS.iter().copied().find(|&r| r == raw)
81}
82
83pub fn validate_instruction(instruction: &str) -> Result<(), String> {
90 let n = instruction.chars().count();
91 if n > MAX_INSTRUCTION_CHARS {
92 return Err(format!(
93 "rerank_instructions is {n} characters; the cap is {MAX_INSTRUCTION_CHARS}. \
94 Shorter instructions also cost fewer tokens (the instruction is \
95 multiplied by the candidate-pool size)."
96 ));
97 }
98 Ok(())
99}
100
/// Builds a default rerank instruction from retrieval hints.
///
/// * `code_exclusive`: when `true`, contributes a sentence asking that
///   code-bearing chunks be ranked above prose.
/// * `version`: an optional `(name, version)` pair; when present, contributes
///   a sentence asking that content for that version be preferred.
///
/// Returns `None` when neither hint applies. Otherwise returns the sentences
/// joined by a single space, with the code hint first.
#[must_use]
pub fn default_instruction(code_exclusive: bool, version: Option<(&str, &str)>) -> Option<String> {
    let code_hint = code_exclusive.then(|| {
        "Prioritize chunks containing code examples, function signatures, and API usage \
         over prose."
            .to_owned()
    });
    let version_hint = version.map(|(name, ver)| {
        format!("Prefer content applying to {name} version {ver}; deprioritize other versions.")
    });

    // `Option` is iterable (zero or one item), so chaining keeps only present hints, in order.
    let sentences: Vec<String> = code_hint.into_iter().chain(version_hint).collect();
    (!sentences.is_empty()).then(|| sentences.join(" "))
}
128
/// Builds the query text sent to the reranker.
///
/// If `instruction` is present and non-blank after trimming, the trimmed text is
/// appended on a new line as `Instructions: <text>`. Otherwise `query` is
/// returned unchanged. `query` itself is never trimmed.
#[must_use]
pub fn compose_rerank_query(query: &str, instruction: Option<&str>) -> String {
    let Some(i) = instruction.map(str::trim).filter(|s| !s.is_empty()) else {
        return query.to_owned();
    };
    format!("{query}\nInstructions: {i}")
}
142
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rerank_param_wire_values_round_trip() {
        for (variant, wire) in [
            (RerankParam::Rerank25, "\"rerank-2.5\""),
            (RerankParam::Rerank25Lite, "\"rerank-2.5-lite\""),
            (RerankParam::None, "\"none\""),
        ] {
            assert_eq!(serde_json::to_string(&variant).unwrap(), wire);
            let back: RerankParam = serde_json::from_str(wire).unwrap();
            assert_eq!(back, variant);
        }
        // Wire names are exact; a differently-cased name must not deserialize.
        assert!(serde_json::from_str::<RerankParam>("\"Rerank-2.5\"").is_err());
        assert_eq!(RerankParam::default(), RerankParam::Rerank25);
    }

    #[test]
    fn model_name_is_none_only_for_none() {
        assert_eq!(RerankParam::Rerank25.model_name(), Some("rerank-2.5"));
        assert_eq!(RerankParam::Rerank25Lite.model_name(), Some("rerank-2.5-lite"));
        assert_eq!(RerankParam::None.model_name(), None);
    }

    #[test]
    fn lite_bills_half_rounded_up() {
        assert_eq!(RerankParam::Rerank25.billed_tokens(1001), 1001);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1000), 500);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1001), 501);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(0), 0);
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(1), 1);
        // Rounding up must not overflow at the top of the range.
        assert_eq!(RerankParam::Rerank25Lite.billed_tokens(u64::MAX), u64::MAX / 2 + 1);
        assert_eq!(RerankParam::None.billed_tokens(10), 10);
    }

    #[test]
    fn instruction_cap_is_400_chars() {
        assert!(validate_instruction(&"x".repeat(400)).is_ok());
        let err = validate_instruction(&"x".repeat(401)).unwrap_err();
        assert!(err.contains("400"), "error should name the cap: {err}");
        assert!(err.contains("401"), "error should name the actual length: {err}");
        // The cap counts chars, not bytes: 400 two-byte chars fit, 401 do not.
        assert!(validate_instruction(&"é".repeat(400)).is_ok());
        assert!(validate_instruction(&"é".repeat(401)).is_err());
        assert!(validate_instruction("").is_ok());
    }

    #[test]
    fn default_instruction_rule_table() {
        assert_eq!(default_instruction(false, None), None);
        let code = default_instruction(true, None).unwrap();
        assert!(code.contains("code examples"));
        let ver = default_instruction(false, Some(("compact", "0.31"))).unwrap();
        assert!(ver.contains("compact") && ver.contains("0.31"));
        let both = default_instruction(true, Some(("compact", "0.31"))).unwrap();
        assert!(both.contains("code examples") && both.contains("0.31"));
        // The code hint comes first, then the version hint.
        assert!(both.starts_with("Prioritize") && both.ends_with("deprioritize other versions."));
    }

    #[test]
    fn known_reason_passes_closed_set_and_drops_others() {
        for r in RERANK_REASONS {
            assert_eq!(known_reason(r), Some(*r));
        }
        assert_eq!(known_reason(""), None);
        assert_eq!(known_reason("applied"), None);
        assert_eq!(known_reason("Not_Requested"), None);
        assert_eq!(known_reason("rate limited: token=eyJhbGci"), None);
    }

    #[test]
    fn compose_appends_instruction_to_query() {
        assert_eq!(compose_rerank_query("how do circuits work", None), "how do circuits work");
        assert_eq!(compose_rerank_query("q", Some(" ")), "q");
        assert_eq!(compose_rerank_query("q", Some("")), "q");
        let composed = compose_rerank_query("q", Some("Prioritize code."));
        assert_eq!(composed, "q\nInstructions: Prioritize code.");
        // Surrounding whitespace on the instruction is trimmed before appending.
        let trimmed = compose_rerank_query("q", Some("  Prioritize code.\n"));
        assert_eq!(trimmed, "q\nInstructions: Prioritize code.");
    }
}