// mnem_core/llm.rs
1// HyDE, OpenAI, Ollama, BEIR, LangChain, Anthropic are well-known
2// external identifiers; backticking every mention in the module doc
3// and field docs would degrade readability in rendered rustdoc for
4// no information gain.
5#![allow(clippy::doc_markdown)]
6
//! Text-generation trait for HyDE and multi-query retrieval.
8//!
9//! # Why
10//!
11//! Two high-ROI retrieval techniques need an LLM in the read path, but
12//! mnem-core is WASM-clean and tokio-free, so the HTTP call can't live
13//! here:
14//!
15//! - **HyDE** (Gao et al. 2022, [arXiv:2212.10496]): instead of
16//!   embedding the raw query, ask an LLM to generate a hypothetical
17//!   answer and embed THAT. The encoder acts as a lossy compressor
18//!   that filters hallucinated specifics back to real-corpus
19//!   neighborhoods. On BEIR, HyDE beats plain Contriever on every
20//!   dataset we care about.
21//! - **Multi-query / RAG-Fusion** (Raudaschl 2023): ask the LLM to
22//!   generate N paraphrases of the query, retrieve top-K for each, and
23//!   fuse with RRF. Particularly strong when the user's phrasing is
24//!   sharply different from stored phrasing.
25//!
26//! Both techniques share the same primitive: `(prompt) -> completion`.
27//! The [`TextGenerator`] trait is that primitive. Adapter crates
28//! (OpenAI chat completions, Ollama chat, Anthropic messages, local
29//! llama.cpp) live outside `mnem-core` so the core stays tokio-free.
30//!
31//! # What this module provides
32//!
33//! A [`TextGenerator`] trait that adapter crates implement, plus a
34//! [`LlmError`] surface and a deterministic mock for tests.
35//!
36//! # How it plugs in today
37//!
38//! The trait + adapters live here and in `mnem-llm-providers`. HyDE is
39//! wired in the CLI layer (`mnem retrieve --hyde`) rather than as a
40//! `Retriever` builder method: the LLM call produces a hypothetical
41//! passage and the CLI feeds the passage into the embedder input. The
42//! symmetric multi-query variant (generate N variations, retrieve
43//! each, RRF-fuse) is planned. On LLM failure, the graceful-degrade
44//! policy is the same as the rerank pass: fall back to the plain
45//! query.
46//!
47//! [arXiv:2212.10496]: https://arxiv.org/abs/2212.10496
48
49use std::fmt::Debug;
50
51use thiserror::Error;
52
/// Error surface for text-generation adapters.
///
/// Marked `#[non_exhaustive]` so provider crates can grow their own
/// failure modes without a breaking change here. Only `BadRequest` and
/// `Server` carry a structured HTTP status; every other variant
/// flattens provider detail into a plain message string.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LlmError {
    /// TLS / TCP / DNS / timeout failure reaching the provider.
    #[error("network error: {0}")]
    Network(String),
    /// Provider rejected credentials.
    #[error("authentication failed: {0}")]
    Auth(String),
    /// Provider rate-limited the request. No retry-after is carried;
    /// adapters put any detail in the message string.
    #[error("rate limited: {0}")]
    RateLimited(String),
    /// 4xx from the provider.
    #[error("bad request ({status}): {body}")]
    BadRequest {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// 5xx from the provider.
    #[error("server error ({status}): {body}")]
    Server {
        /// HTTP status code.
        status: u16,
        /// Response body or best-effort error string.
        body: String,
    },
    /// Response decoder failed (malformed JSON, missing content field, ...).
    #[error("decode error: {0}")]
    Decode(String),
    /// Adapter config invalid (bad URL, missing env var, etc.).
    #[error("config error: {0}")]
    Config(String),
    /// Provider returned an empty completion.
    #[error("empty completion")]
    EmptyCompletion,
}
95
/// Per-request generation options. Provider-agnostic by design:
/// adapters translate these onto their own APIs and silently skip
/// anything they cannot express (every adapter MUST tolerate `None`
/// on every optional field without erroring).
#[derive(Debug, Clone)]
pub struct GenOptions {
    /// Completion length cap in tokens. `None` defers to the adapter
    /// default.
    pub max_tokens: Option<u32>,
    /// Sampling temperature. `None` defers to the adapter default.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability cutoff. `None` defers to the
    /// adapter default. OpenAI and most providers accept `top_p`;
    /// Ollama maps it onto its `options.top_p`.
    pub top_p: Option<f32>,
    /// Top-K sampling. `None` defers to the adapter default. Not
    /// universal: OpenAI v1 chat does NOT accept top_k, Ollama does
    /// via `options.top_k`. Adapters that can't honour it silently
    /// drop the field.
    pub top_k: Option<u32>,
    /// Stop sequences: the completion halts when any of these strings
    /// appears. Most providers accept 1-4 stop strings. `None` or an
    /// empty vec means no stop.
    pub stop: Option<Vec<String>>,
    /// Presence-penalty (discourage repeating tokens). OpenAI-family.
    pub presence_penalty: Option<f32>,
    /// Frequency-penalty (discourage high-frequency tokens).
    /// OpenAI-family.
    pub frequency_penalty: Option<f32>,
    /// Deterministic-sampling seed. OpenAI accepts a seed on the
    /// chat-completions endpoint; Ollama accepts `options.seed`.
    /// Useful for reproducing HyDE runs in benchmarks.
    pub seed: Option<u64>,
    /// How many completions to sample: the paraphrase count for
    /// multi-query; usually 1 for HyDE (or small-N averaged, per the
    /// paper).
    pub n: usize,
    /// Optional system prompt / role preamble.
    pub system: Option<String>,
}

impl GenOptions {
    /// Options asking for exactly one completion with every knob left
    /// at the adapter default. Identical to `GenOptions::default()`;
    /// exists so call sites can state intent ("one completion only").
    #[must_use]
    pub fn single() -> Self {
        GenOptions::default()
    }
}

impl Default for GenOptions {
    fn default() -> Self {
        GenOptions {
            n: 1,
            max_tokens: None,
            temperature: None,
            top_p: None,
            top_k: None,
            stop: None,
            presence_penalty: None,
            frequency_penalty: None,
            seed: None,
            system: None,
        }
    }
}
162
/// Text-generation primitive: given a user prompt (and optional system
/// preamble), return one or more completions.
///
/// The `Send + Sync` supertraits let a generator be shared across
/// threads; `Debug` keeps implementors printable in diagnostics.
///
/// The returned `Vec<String>` has length exactly `opts.n`; adapters
/// that only support `n=1` natively MUST batch-call n times and
/// surface a coherent error if one sub-call fails. Completion content
/// is implementation-defined: callers who need structure should
/// post-parse (e.g. split on newlines for multi-query).
pub trait TextGenerator: Send + Sync + Debug {
    /// Provider + model identifier. Lowercase, colon-separated by
    /// convention (e.g. `"openai:gpt-4o-mini"`, `"ollama:llama3.2:3b"`).
    fn model(&self) -> &str;

    /// Generate completions for `prompt`.
    ///
    /// # Errors
    ///
    /// Any [`LlmError`] the adapter surfaces. Callers that use this
    /// for HyDE / multi-query SHOULD fall back gracefully to the plain
    /// query on error (same policy as the reranker fallback), so an
    /// LLM outage does not break retrieval.
    fn generate(&self, prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError>;
}
186
/// Deterministic test-only generator: always returns its configured
/// responses, ignoring the prompt. Lets HyDE / multi-query wiring be
/// tested without a live provider.
#[derive(Debug, Clone)]
pub struct MockTextGenerator {
    model: String,
    /// Fixed completions returned on every call. When `opts.n` exceeds
    /// `completions.len()`, the final entry is repeated to pad.
    completions: Vec<String>,
}

impl MockTextGenerator {
    /// Build a mock from a model identifier and its canned completions.
    #[must_use]
    pub fn new(model: impl Into<String>, completions: Vec<String>) -> Self {
        let model = model.into();
        MockTextGenerator { model, completions }
    }
}

impl Default for MockTextGenerator {
    fn default() -> Self {
        MockTextGenerator::new("mock:echo", vec![String::from("(mock completion)")])
    }
}
214
215impl TextGenerator for MockTextGenerator {
216    fn model(&self) -> &str {
217        &self.model
218    }
219
220    fn generate(&self, _prompt: &str, opts: &GenOptions) -> Result<Vec<String>, LlmError> {
221        if self.completions.is_empty() {
222            return Err(LlmError::EmptyCompletion);
223        }
224        let mut out = Vec::with_capacity(opts.n);
225        for i in 0..opts.n {
226            let idx = i.min(self.completions.len() - 1);
227            out.push(self.completions[idx].clone());
228        }
229        Ok(out)
230    }
231}
232
233/// Test-only generator that always errors. Proves the graceful
234/// fallback path in HyDE / multi-query callers.
235#[derive(Debug, Clone, Default)]
236pub struct AlwaysFailGenerator;
237
238impl TextGenerator for AlwaysFailGenerator {
239    fn model(&self) -> &str {
240        "mock:always-fail"
241    }
242
243    fn generate(&self, _prompt: &str, _opts: &GenOptions) -> Result<Vec<String>, LlmError> {
244        Err(LlmError::Network(
245            "intentional failure for test".to_string(),
246        ))
247    }
248}
249
/// Default HyDE prompt template for short-fact agent memory.
///
/// Contains a single `{query}` placeholder; substitute it with
/// [`fill_template`] before handing the prompt to a [`TextGenerator`].
///
/// Tuned for mnem's typical payload: short node summaries, concrete
/// entities/relations/attributes. Avoids the LangChain BEIR-task-tuned
/// templates because mnem nodes are not BEIR docs.
pub const HYDE_PROMPT_TEMPLATE: &str =
    "Write a short, factual passage (2-4 sentences) that would answer the \
question below. Focus on concrete entities, relations, and attributes \
a note-taking system might have recorded. Omit hedging and meta-talk.

Question: {query}
Passage:";
262
/// Default multi-query prompt template, parameterised by `{n}` and
/// `{query}`. Both placeholders are replaced at call time via
/// [`fill_multi_query_template`]; every occurrence of each placeholder
/// is substituted. The listed angles (broader, narrower, synonymous,
/// entity-centric) are suggestions; when `n > 4` the model is expected
/// to mix and extend them.
pub const MULTI_QUERY_PROMPT_TEMPLATE: &str =
    "You are rewriting a user's query into search variations for a personal \
knowledge graph. Generate {n} alternative queries that together cover \
different angles of the same intent. Suggested angles:
  - a broader/more abstract phrasing
  - a narrower/more specific phrasing
  - a synonymous rephrasing using different key terms
  - an entity-centric phrasing (noun-heavy, no verbs)

Do NOT output minor rewordings. Do NOT repeat the original query.
Output exactly {n} lines, one query per line, no numbering.

Original: {query}";
281
/// Substitute the user's query into `template`.
///
/// The substitution is deliberately naive: every occurrence of the
/// literal substring `{query}` is replaced. Templates needing richer
/// placeholders should use a real templating engine; HyDE and
/// multi-query only need this much.
#[must_use]
pub fn fill_template(template: &str, query: &str) -> String {
    const PLACEHOLDER: &str = "{query}";
    template.replace(PLACEHOLDER, query)
}
291
/// Fill both `{query}` and `{n}` placeholders in a multi-query
/// template. Use this when generating N paraphrase variants so the
/// prompt honours the caller's requested count.
///
/// `{n}` is substituted BEFORE `{query}`: each `str::replace` pass
/// scans the whole intermediate string, so replacing the query first
/// would also rewrite any literal `{n}` contained in the user's own
/// query text. Doing `{n}` first means user-supplied text is injected
/// last and never reinterpreted as a placeholder.
#[must_use]
pub fn fill_multi_query_template(template: &str, query: &str, n: usize) -> String {
    template
        .replace("{n}", &n.to_string())
        .replace("{query}", query)
}
301
#[cfg(test)]
mod tests {
    use super::*;

    // --- MockTextGenerator behavior ---

    #[test]
    fn mock_generates_n_completions() {
        let g = MockTextGenerator::new(
            "mock:echo",
            vec!["a".to_string(), "b".to_string(), "c".to_string()],
        );
        let out = g
            .generate(
                "ignored",
                &GenOptions {
                    n: 3,
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(out, vec!["a", "b", "c"]);
    }

    #[test]
    fn mock_repeats_last_when_n_exceeds_completion_count() {
        let g = MockTextGenerator::new("mock:echo", vec!["only".to_string()]);
        let out = g
            .generate(
                "ignored",
                &GenOptions {
                    n: 4,
                    ..Default::default()
                },
            )
            .unwrap();
        assert_eq!(out, vec!["only"; 4]);
    }

    #[test]
    fn mock_empty_completions_errors() {
        // A mock with no canned completions surfaces EmptyCompletion.
        let g = MockTextGenerator::new("mock:echo", vec![]);
        let e = g.generate("q", &GenOptions::default()).unwrap_err();
        assert!(matches!(e, LlmError::EmptyCompletion));
    }

    // --- Failure-path generator ---

    #[test]
    fn always_fail_generator_errors() {
        let g = AlwaysFailGenerator;
        assert!(g.generate("q", &GenOptions::default()).is_err());
    }

    #[test]
    fn model_id_has_provider_prefix() {
        // Convention: "provider:model", so the id must contain a colon.
        let g = MockTextGenerator::default();
        assert!(g.model().contains(':'));
    }

    // --- Template helpers ---

    #[test]
    fn fill_template_substitutes_query() {
        let s = fill_template("ask {query} now", "why");
        assert_eq!(s, "ask why now");
    }

    #[test]
    fn fill_template_leaves_text_without_placeholder_alone() {
        let s = fill_template("no placeholder here", "ignored");
        assert_eq!(s, "no placeholder here");
    }

    #[test]
    fn default_hyde_prompt_has_query_placeholder() {
        assert!(HYDE_PROMPT_TEMPLATE.contains("{query}"));
    }

    #[test]
    fn default_multi_query_prompt_has_query_placeholder() {
        assert!(MULTI_QUERY_PROMPT_TEMPLATE.contains("{query}"));
    }

    // --- Options defaults ---

    #[test]
    fn gen_options_default_is_n1() {
        let o = GenOptions::default();
        assert_eq!(o.n, 1);
        assert!(o.temperature.is_none());
        assert!(o.max_tokens.is_none());
        assert!(o.system.is_none());
    }
}