#![allow(clippy::doc_markdown)]
use std::fmt::Debug;
use thiserror::Error;
/// Error type returned by [`TextGenerator::generate`].
///
/// The enum is `#[non_exhaustive]`, so code in other crates has to keep a
/// wildcard arm when matching on it; new variants can then be added without
/// breaking those callers.
// NOTE(review): apart from `Network` and `EmptyCompletion`, nothing visible
// here constructs these variants; the per-variant notes below describe the
// `Display` output and hedge anything inferred from names only.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LlmError {
/// Displays as `network error: <message>`.
#[error("network error: {0}")]
Network(String),
/// Displays as `authentication failed: <message>`.
#[error("authentication failed: {0}")]
Auth(String),
/// Displays as `rate limited: <message>`.
#[error("rate limited: {0}")]
RateLimited(String),
/// Displays as `bad request (<status>): <body>`.
// NOTE(review): `status` / `body` are presumably the HTTP status code and
// response body of a client-side (4xx) failure — confirm at the call sites.
#[error("bad request ({status}): {body}")]
BadRequest {
status: u16,
body: String,
},
/// Displays as `server error (<status>): <body>`.
// NOTE(review): presumably the server-side (5xx) counterpart of
// `BadRequest` — confirm at the call sites.
#[error("server error ({status}): {body}")]
Server {
status: u16,
body: String,
},
/// Displays as `decode error: <message>`.
#[error("decode error: {0}")]
Decode(String),
/// Displays as `config error: <message>`.
#[error("config error: {0}")]
Config(String),
/// No completion was available; displays as `empty completion`.
///
/// `MockTextGenerator` returns this when it holds no canned completions.
#[error("empty completion")]
EmptyCompletion,
}
/// Per-call settings passed to `TextGenerator::generate`.
///
/// Every field except `n` is optional and defaults to `None`; `n` defaults
/// to `1`.
// NOTE(review): the field descriptions below are inferred from the field
// names (they match common LLM sampling parameters). Only `n` is read by the
// mock generators in this file — confirm the rest against real backends.
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// Optional limit on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f32>,
    /// Optional nucleus-sampling (`top_p`) setting.
    pub top_p: Option<f32>,
    /// Optional `top_k` sampling setting.
    pub top_k: Option<u32>,
    /// Optional list of stop sequences.
    pub stop: Option<Vec<String>>,
    /// Optional presence penalty.
    pub presence_penalty: Option<f32>,
    /// Optional frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Optional sampling seed.
    pub seed: Option<u64>,
    /// Number of completions requested from a single `generate` call.
    pub n: usize,
    /// Optional system prompt.
    pub system: Option<String>,
}

impl Default for GenOptions {
    /// All optional settings unset, and a single completion requested.
    fn default() -> Self {
        GenOptions {
            // The only non-`None` default: ask for exactly one completion.
            n: 1,
            max_tokens: None,
            temperature: None,
            top_p: None,
            top_k: None,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            seed: None,
            system: None,
        }
    }
}

impl GenOptions {
    /// Options for a single completion; identical to [`GenOptions::default`].
    #[must_use]
    pub fn single() -> Self {
        <Self as Default>::default()
    }
}
/// Interface implemented by text-generation backends.
///
/// Implementors must be `Send + Sync + Debug`.
pub trait TextGenerator: Send + Sync + Debug {
/// Returns the identifier of the model behind this generator.
///
/// The mock implementations in this file use a `provider:name` shape
/// (for example `mock:echo`).
fn model(&self) -> &str;
/// Generates completions for `prompt` using the settings in `opts`.
///
/// # Errors
///
/// Returns an [`LlmError`] when generation fails.
// NOTE(review): the mock implementation returns exactly `opts.n` strings;
// the trait itself does not enforce that for other implementors.
fn generate(&self, prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError>;
}
/// Generator that serves pre-recorded completions instead of calling a model.
///
/// It stores a model identifier and a list of canned completion strings; its
/// `TextGenerator` implementation hands those strings back.
#[derive(Debug, Clone)]
pub struct MockTextGenerator {
    /// Identifier reported by `model()`.
    model: String,
    /// Canned completions, in the order they are handed out.
    completions: Vec<String>,
}

impl Default for MockTextGenerator {
    /// A mock named `mock:echo` holding the single completion
    /// `(mock completion)`.
    fn default() -> Self {
        let canned = vec![String::from("(mock completion)")];
        Self::new("mock:echo", canned)
    }
}

impl MockTextGenerator {
    /// Creates a mock reporting `model` as its identifier and serving
    /// `completions` as its canned output.
    #[must_use]
    pub fn new(model: impl Into<String>, completions: Vec<String>) -> Self {
        let model = model.into();
        Self { model, completions }
    }
}
impl TextGenerator for MockTextGenerator {
    /// Returns the model identifier supplied at construction.
    fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns `opts.n` canned completions; the prompt is ignored.
    ///
    /// Completion `i` is the `i`-th stored string. Once the stored list is
    /// exhausted, the last stored string is repeated.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::EmptyCompletion`] when no completions are stored,
    /// regardless of `opts.n`.
    fn generate(&self, _prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError> {
        // Index of the final stored completion; `None` means the list is empty.
        let last = match self.completions.len().checked_sub(1) {
            Some(last) => last,
            None => return Err(LlmError::EmptyCompletion),
        };
        // Clamp every requested index to `last` so extra requests repeat it.
        let picked = (0..opts.n)
            .map(|i| self.completions[i.min(last)].clone())
            .collect();
        Ok(picked)
    }
}
/// Generator whose `TextGenerator::generate` implementation always returns an
/// error; it holds no state.
#[derive(Debug, Clone, Default)]
pub struct AlwaysFailGenerator;
impl TextGenerator for AlwaysFailGenerator {
    /// Returns the fixed identifier `mock:always-fail`.
    fn model(&self) -> &str {
        "mock:always-fail"
    }

    /// Ignores both arguments and fails unconditionally.
    ///
    /// # Errors
    ///
    /// Always returns [`LlmError::Network`] carrying the message
    /// `intentional failure for test`.
    fn generate(&self, _prompt: &str, _opts: &GenOptions) -> Result<Vec<String>, LlmError> {
        let reason = String::from("intentional failure for test");
        Err(LlmError::Network(reason))
    }
}
/// Prompt template that asks for a short factual passage answering a question.
///
/// Contains a single `{query}` placeholder, which `fill_template` replaces
/// with the caller's question. The trailing `\` on the first two lines of the
/// literal joins them into one paragraph; the later line breaks are part of
/// the string.
// NOTE(review): the name suggests HyDE (hypothetical document embeddings)
// style retrieval — confirm against the code that consumes this prompt.
pub const HYDE_PROMPT_TEMPLATE: &str =
"Write a short, factual passage (2-4 sentences) that would answer the \
question below. Focus on concrete entities, relations, and attributes \
a note-taking system might have recorded. Omit hedging and meta-talk.
Question: {query}
Passage:";
/// Prompt template that asks for several alternative phrasings of a query.
///
/// Contains two kinds of placeholder: `{n}` (appearing twice) for the number
/// of variations requested and `{query}` for the original query. Both are
/// substituted by `fill_multi_query_template`. The prompt asks for exactly
/// `{n}` output lines, one query per line, without numbering.
pub const MULTI_QUERY_PROMPT_TEMPLATE: &str =
"You are rewriting a user's query into search variations for a personal \
knowledge graph. Generate {n} alternative queries that together cover \
different angles of the same intent. Suggested angles:
- a broader/more abstract phrasing
- a narrower/more specific phrasing
- a synonymous rephrasing using different key terms
- an entity-centric phrasing (noun-heavy, no verbs)
Do NOT output minor rewordings. Do NOT repeat the original query.
Output exactly {n} lines, one query per line, no numbering.
Original: {query}";
/// Substitutes `query` for every `{query}` placeholder in `template`.
///
/// A template without the placeholder comes back unchanged, as a freshly
/// allocated `String`.
#[must_use]
pub fn fill_template(template: &str, query: &str) -> String {
    const PLACEHOLDER: &str = "{query}";
    str::replace(template, PLACEHOLDER, query)
}
/// Fills the `{n}` and `{query}` placeholders of a multi-query prompt template.
///
/// Every `{n}` in `template` becomes the decimal rendering of `n`, and every
/// `{query}` becomes `query`.
///
/// The query text is inserted verbatim: a query that happens to contain the
/// literal text `{n}` is left as the caller wrote it.
#[must_use]
pub fn fill_multi_query_template(template: &str, query: &str, n: usize) -> String {
    // Substitute `{n}` first. The inserted digits can never form `{query}`,
    // whereas inserting the caller-controlled query first would let a `{n}`
    // inside it be overwritten by the second replacement.
    template
        .replace("{n}", &n.to_string())
        .replace("{query}", query)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options with only the completion count overridden.
    fn with_n(n: usize) -> GenOptions {
        GenOptions {
            n,
            ..GenOptions::default()
        }
    }

    #[test]
    fn mock_generates_n_completions() {
        let g = MockTextGenerator::new(
            "mock:echo",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        let out = g.generate("ignored", &with_n(3)).unwrap();
        assert_eq!(out, vec!["a", "b", "c"]);
    }

    #[test]
    fn mock_repeats_last_when_n_exceeds_completion_count() {
        let g = MockTextGenerator::new("mock:echo", vec!["only".to_string()]);
        let out = g.generate("ignored", &with_n(4)).unwrap();
        assert_eq!(out, vec!["only"; 4]);
    }

    // Pins the current contract: asking for zero completions is not an error.
    #[test]
    fn mock_n_zero_yields_no_completions() {
        let g = MockTextGenerator::default();
        let out = g.generate("ignored", &with_n(0)).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn mock_empty_completions_errors() {
        let g = MockTextGenerator::new("mock:echo", vec![]);
        let e = g.generate("q", &GenOptions::default()).unwrap_err();
        assert!(matches!(e, LlmError::EmptyCompletion));
    }

    #[test]
    fn always_fail_generator_errors() {
        let g = AlwaysFailGenerator;
        assert!(g.generate("q", &GenOptions::default()).is_err());
    }

    #[test]
    fn model_id_has_provider_prefix() {
        let g = MockTextGenerator::default();
        assert!(g.model().contains(':'));
    }

    #[test]
    fn fill_template_substitutes_query() {
        let s = fill_template("ask {query} now", "why");
        assert_eq!(s, "ask why now");
    }

    #[test]
    fn fill_template_leaves_text_without_placeholder_alone() {
        let s = fill_template("no placeholder here", "ignored");
        assert_eq!(s, "no placeholder here");
    }

    #[test]
    fn fill_multi_query_template_substitutes_query_and_count() {
        let s = fill_multi_query_template("make {n} of {query}; exactly {n}", "tea", 3);
        assert_eq!(s, "make 3 of tea; exactly 3");
    }

    #[test]
    fn fill_multi_query_template_fills_default_prompt() {
        let s = fill_multi_query_template(MULTI_QUERY_PROMPT_TEMPLATE, "rust", 4);
        assert!(!s.contains("{n}"));
        assert!(!s.contains("{query}"));
        assert!(s.contains("Output exactly 4 lines"));
        assert!(s.ends_with("Original: rust"));
    }

    #[test]
    fn default_hyde_prompt_has_query_placeholder() {
        assert!(HYDE_PROMPT_TEMPLATE.contains("{query}"));
    }

    #[test]
    fn default_multi_query_prompt_has_query_placeholder() {
        assert!(MULTI_QUERY_PROMPT_TEMPLATE.contains("{query}"));
    }

    #[test]
    fn default_multi_query_prompt_has_n_placeholder() {
        assert!(MULTI_QUERY_PROMPT_TEMPLATE.contains("{n}"));
    }

    #[test]
    fn gen_options_default_is_n1() {
        let o = GenOptions::default();
        assert_eq!(o.n, 1);
        assert!(o.temperature.is_none());
        assert!(o.max_tokens.is_none());
        assert!(o.system.is_none());
    }
}