// mnem_core/llm.rs
1// HyDE, OpenAI, Ollama, BEIR, LangChain, Anthropic are well-known
2// external identifiers; backticking every mention in the module doc
3// and field docs would degrade readability in rendered rustdoc for
4// no information gain.
5#![allow(clippy::doc_markdown)]
6
//! Text-generation trait for HyDE and multi-query retrieval.
8//!
9//! # Why
10//!
11//! Two high-ROI retrieval techniques need an LLM in the read path, but
12//! mnem-core is WASM-clean and tokio-free, so the HTTP call can't live
13//! here:
14//!
15//! - **HyDE** (Gao et al. 2022, [arXiv:2212.10496]): instead of
16//! embedding the raw query, ask an LLM to generate a hypothetical
17//! answer and embed THAT. The encoder acts as a lossy compressor
18//! that filters hallucinated specifics back to real-corpus
19//! neighborhoods. On BEIR, HyDE beats plain Contriever on every
20//! dataset we care about.
21//! - **Multi-query / RAG-Fusion** (Raudaschl 2023): ask the LLM to
22//! generate N paraphrases of the query, retrieve top-K for each, and
23//! fuse with RRF. Particularly strong when the user's phrasing is
24//! sharply different from stored phrasing.
25//!
26//! Both techniques share the same primitive: `(prompt) -> completion`.
27//! The [`TextGenerator`] trait is that primitive. Adapter crates
28//! (OpenAI chat completions, Ollama chat, Anthropic messages, local
29//! llama.cpp) live outside `mnem-core` so the core stays tokio-free.
30//!
31//! # What this module provides
32//!
33//! A [`TextGenerator`] trait that adapter crates implement, plus a
34//! [`LlmError`] surface and a deterministic mock for tests.
35//!
36//! # How it plugs in today
37//!
38//! The trait + adapters live here and in `mnem-llm-providers`. HyDE is
39//! wired in the CLI layer (`mnem retrieve --hyde`) rather than as a
40//! `Retriever` builder method: the LLM call produces a hypothetical
41//! passage and the CLI feeds the passage into the embedder input. The
42//! symmetric multi-query variant (generate N variations, retrieve
43//! each, RRF-fuse) is planned. On LLM failure, the graceful-degrade
44//! policy is the same as the rerank pass: fall back to the plain
45//! query.
46//!
47//! [arXiv:2212.10496]: https://arxiv.org/abs/2212.10496
48
49use std::fmt::Debug;
50
51use thiserror::Error;
52
/// Error surface for text-generation adapters.
///
/// Marked `#[non_exhaustive]` so provider crates can grow their own
/// failure modes without a breaking change here.
///
/// Adapters map provider responses onto the closest variant; callers
/// on the retrieval path treat any variant as a signal to fall back to
/// the plain query (see [`TextGenerator::generate`]).
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LlmError {
    /// TLS / TCP / DNS / timeout failure reaching the provider.
    #[error("network error: {0}")]
    Network(String),
    /// Provider rejected credentials.
    #[error("authentication failed: {0}")]
    Auth(String),
    /// Provider rate-limited the request.
    #[error("rate limited: {0}")]
    RateLimited(String),
    /// 4xx from the provider (other than auth / rate-limit, which have
    /// their own variants above).
    #[error("bad request ({status}): {body}")]
    BadRequest {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// 5xx from the provider.
    #[error("server error ({status}): {body}")]
    Server {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// Response decoder failed (malformed JSON, missing content field, ...).
    #[error("decode error: {0}")]
    Decode(String),
    /// Adapter config invalid (bad URL, missing env var, etc.).
    #[error("config error: {0}")]
    Config(String),
    /// Provider returned an empty completion.
    #[error("empty completion")]
    EmptyCompletion,
}
95
/// Generation options that the caller supplies per request. Kept
/// provider-agnostic; adapters map these onto their own APIs and
/// ignore fields they don't support (all adapters MUST tolerate a
/// `None` on every optional field without erroring).
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// Maximum tokens in the completion. `None` means adapter default.
    pub max_tokens: Option<u32>,
    /// Sampling temperature. `None` means adapter default.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability cutoff. `None` means adapter
    /// default. OpenAI and most providers accept `top_p`; Ollama maps
    /// it to its `options.top_p`.
    pub top_p: Option<f32>,
    /// Top-K sampling. `None` means adapter default. Not supported by
    /// every provider (OpenAI v1 chat does NOT accept top_k; Ollama
    /// does via `options.top_k`). Adapters that can't honour it
    /// silently drop the field.
    pub top_k: Option<u32>,
    /// Stop sequences. Completion halts when any of these strings
    /// appears. Most providers accept 1-4 stop strings. `None` or
    /// empty vec means no stop.
    pub stop: Option<Vec<String>>,
    /// Presence-penalty (discourage repeating tokens). OpenAI-family.
    pub presence_penalty: Option<f32>,
    /// Frequency-penalty (discourage high-frequency tokens).
    /// OpenAI-family.
    pub frequency_penalty: Option<f32>,
    /// Deterministic-sampling seed. OpenAI accepts a seed on the
    /// chat-completions endpoint; Ollama accepts `options.seed`.
    /// Useful for reproducing HyDE runs in benchmarks.
    pub seed: Option<u64>,
    /// Number of completions to sample. For multi-query this is the
    /// number of paraphrases; for HyDE this is usually 1 (or small-N
    /// averaged, per the paper). The [`TextGenerator`] contract
    /// requires the returned vec length to equal this value.
    pub n: usize,
    /// Optional system prompt / role preamble.
    pub system: Option<String>,
}
135
136impl GenOptions {
137 /// Construct options with `n = 1` and everything else None.
138 /// Equivalent to `GenOptions::default()`; provided for
139 /// self-documentation at call sites ("I only want 1 completion").
140 #[must_use]
141 pub fn single() -> Self {
142 Self::default()
143 }
144}
145
146impl Default for GenOptions {
147 fn default() -> Self {
148 Self {
149 max_tokens: None,
150 temperature: None,
151 top_p: None,
152 top_k: None,
153 stop: None,
154 presence_penalty: None,
155 frequency_penalty: None,
156 seed: None,
157 n: 1,
158 system: None,
159 }
160 }
161}
162
/// Text-generation primitive: given a user prompt (and optional system
/// preamble), return one or more completions.
///
/// The returned `Vec<String>` has length exactly `opts.n`; adapters
/// that only support `n=1` natively MUST batch-call n times and
/// surface a coherent error if one sub-call fails. Completion content
/// is implementation-defined: callers who need structure should
/// post-parse (e.g. split on newlines for multi-query).
///
/// The `Send + Sync + Debug` supertraits let implementors be shared
/// across threads and printed in diagnostics.
pub trait TextGenerator: Send + Sync + Debug {
    /// Provider + model identifier. Lowercase, colon-separated by
    /// convention (e.g. `"openai:gpt-4o-mini"`, `"ollama:llama3.2:3b"`).
    fn model(&self) -> &str;

    /// Generate completions for `prompt`.
    ///
    /// # Errors
    ///
    /// Any [`LlmError`] the adapter surfaces. Callers that use this
    /// for HyDE / multi-query SHOULD fall back gracefully to the plain
    /// query on error (same policy as the reranker fallback), so an
    /// LLM outage does not break retrieval.
    fn generate(&self, prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError>;
}
186
/// Deterministic test-only generator. Returns a configured response
/// regardless of input. Useful for wiring HyDE / multi-query tests
/// without a live provider.
#[derive(Debug, Clone)]
pub struct MockTextGenerator {
    // Reported verbatim by `model()`.
    model: String,
    /// Fixed completions to return on every call. If `opts.n` is
    /// larger than `completions.len()`, the last one is repeated.
    /// An empty vec makes `generate` return
    /// [`LlmError::EmptyCompletion`].
    completions: Vec<String>,
}
197
198impl MockTextGenerator {
199 /// Construct a mock with the given `(model, completions)`.
200 #[must_use]
201 pub fn new(model: impl Into<String>, completions: Vec<String>) -> Self {
202 Self {
203 model: model.into(),
204 completions,
205 }
206 }
207}
208
209impl Default for MockTextGenerator {
210 fn default() -> Self {
211 Self::new("mock:echo", vec!["(mock completion)".to_string()])
212 }
213}
214
215impl TextGenerator for MockTextGenerator {
216 fn model(&self) -> &str {
217 &self.model
218 }
219
220 fn generate(&self, _prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError> {
221 if self.completions.is_empty() {
222 return Err(LlmError::EmptyCompletion);
223 }
224 let mut out = Vec::with_capacity(opts.n);
225 for i in 0..opts.n {
226 let idx = i.min(self.completions.len() - 1);
227 out.push(self.completions[idx].clone());
228 }
229 Ok(out)
230 }
231}
232
/// Test-only generator that always errors (with [`LlmError::Network`]).
/// Proves the graceful fallback path in HyDE / multi-query callers.
#[derive(Debug, Clone, Default)]
pub struct AlwaysFailGenerator;
237
238impl TextGenerator for AlwaysFailGenerator {
239 fn model(&self) -> &str {
240 "mock:always-fail"
241 }
242
243 fn generate(&self, _prompt: &str, _opts: &GenOptions) -> Result<Vec<String>, LlmError> {
244 Err(LlmError::Network(
245 "intentional failure for test".to_string(),
246 ))
247 }
248}
249
/// Default HyDE prompt template for short-fact agent memory.
///
/// Contains a single `{query}` placeholder, substituted at call time
/// via [`fill_template`].
///
/// Tuned for mnem's typical payload: short node summaries, concrete
/// entities/relations/attributes. Avoids the LangChain BEIR-task-tuned
/// templates because mnem nodes are not BEIR docs.
pub const HYDE_PROMPT_TEMPLATE: &str =
    "Write a short, factual passage (2-4 sentences) that would answer the \
question below. Focus on concrete entities, relations, and attributes \
a note-taking system might have recorded. Omit hedging and meta-talk.

Question: {query}
Passage:";
262
/// Default multi-query prompt template, parameterised by `{n}` and
/// `{query}`. Both placeholders are replaced at call time via
/// [`fill_multi_query_template`]. The listed angles (broader,
/// narrower, synonymous, entity-centric) are suggestions; when
/// `n > 4` the model is expected to mix and extend them.
///
/// The template asks for exactly `{n}` output lines, one query per
/// line, so callers can post-parse by splitting on newlines.
pub const MULTI_QUERY_PROMPT_TEMPLATE: &str =
    "You are rewriting a user's query into search variations for a personal \
knowledge graph. Generate {n} alternative queries that together cover \
different angles of the same intent. Suggested angles:
 - a broader/more abstract phrasing
 - a narrower/more specific phrasing
 - a synonymous rephrasing using different key terms
 - an entity-centric phrasing (noun-heavy, no verbs)

Do NOT output minor rewordings. Do NOT repeat the original query.
Output exactly {n} lines, one query per line, no numbering.

Original: {query}";
281
/// Substitute every literal `{query}` occurrence in `template` with
/// the user's query string.
///
/// Deliberately naive: no escaping, no other placeholders. HyDE and
/// multi-query prompts only need this; anything fancier belongs in a
/// real templating engine.
#[must_use]
pub fn fill_template(template: &str, query: &str) -> String {
    template.replace("{query}", query)
}
291
/// Fill both `{query}` and `{n}` placeholders in a multi-query
/// template. Use this when generating N paraphrase variants so the
/// prompt honours the caller's requested count.
///
/// `{n}` is substituted *before* `{query}` so that a user query which
/// happens to contain the literal substring `{n}` passes through
/// untouched instead of being rewritten into the count. (The reverse
/// order would mangle such queries.)
#[must_use]
pub fn fill_multi_query_template(template: &str, query: &str, n: usize) -> String {
    template
        .replace("{n}", &n.to_string())
        .replace("{query}", query)
}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mock_generates_n_completions() {
        let g = MockTextGenerator::new(
            "mock:echo",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        let out = g
            .generate(
                "ignored",
                &GenOptions {
                    n: 3,
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(out, vec!["a", "b", "c"]);
    }

    #[test]
    fn mock_repeats_last_when_n_exceeds_completion_count() {
        let g = MockTextGenerator::new("mock:echo", vec!["only".to_string()]);
        let out = g
            .generate(
                "ignored",
                &GenOptions {
                    n: 4,
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(out, vec!["only"; 4]);
    }

    #[test]
    fn mock_empty_completions_errors() {
        let g = MockTextGenerator::new("mock:echo", vec![]);
        let e = g.generate("q", &GenOptions::default()).unwrap_err();
        assert!(matches!(e, LlmError::EmptyCompletion));
    }

    #[test]
    fn always_fail_generator_errors() {
        let g = AlwaysFailGenerator;
        assert!(g.generate("q", &GenOptions::default()).is_err());
    }

    #[test]
    fn model_id_has_provider_prefix() {
        let g = MockTextGenerator::default();
        assert!(g.model().contains(':'));
    }

    #[test]
    fn gen_options_single_matches_default() {
        // `single()` is documented as equivalent to `default()`.
        let s = GenOptions::single();
        assert_eq!(s.n, 1);
        assert!(s.max_tokens.is_none());
        assert!(s.seed.is_none());
    }

    #[test]
    fn fill_template_substitutes_query() {
        let s = fill_template("ask {query} now", "why");
        assert_eq!(s, "ask why now");
    }

    #[test]
    fn fill_template_leaves_text_without_placeholder_alone() {
        let s = fill_template("no placeholder here", "ignored");
        assert_eq!(s, "no placeholder here");
    }

    #[test]
    fn fill_multi_query_template_substitutes_count_and_query() {
        let s = fill_multi_query_template("give {n} takes on {query}", "rust", 3);
        assert_eq!(s, "give 3 takes on rust");
    }

    #[test]
    fn filled_multi_query_prompt_has_no_leftover_placeholders() {
        let s = fill_multi_query_template(MULTI_QUERY_PROMPT_TEMPLATE, "deploy pipeline", 4);
        assert!(!s.contains("{n}"));
        assert!(!s.contains("{query}"));
    }

    #[test]
    fn default_hyde_prompt_has_query_placeholder() {
        assert!(HYDE_PROMPT_TEMPLATE.contains("{query}"));
    }

    #[test]
    fn default_multi_query_prompt_has_query_placeholder() {
        assert!(MULTI_QUERY_PROMPT_TEMPLATE.contains("{query}"));
    }

    #[test]
    fn default_multi_query_prompt_has_n_placeholder() {
        // The count placeholder is as load-bearing as {query}; without
        // it the model never learns how many variants to emit.
        assert!(MULTI_QUERY_PROMPT_TEMPLATE.contains("{n}"));
    }

    #[test]
    fn gen_options_default_is_n1() {
        let o = GenOptions::default();
        assert_eq!(o.n, 1);
        assert!(o.temperature.is_none());
        assert!(o.max_tokens.is_none());
        assert!(o.system.is_none());
    }
}