Skip to main content

sapient_generate/
pipeline.rs

1//! `Pipeline` — the main user-facing LLM inference API.
2//!
3//! One line to load any HuggingFace model, one line to generate text.
4
5use std::path::PathBuf;
6use std::sync::{Arc, Mutex};
7
8use anyhow::{Context, Result};
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tracing::debug;
12
13use sapient_hub::model_info::{ArchType, ModelInfo};
14use sapient_hub::resolver::ModelFiles;
15use sapient_hub::{tokenizer_fallback_model, HubClient, LoadOptions as HubOptions};
16use sapient_io::GgufLoader;
17use sapient_models::{ForwardEngine, LlmBackendKind};
18use sapient_tokenizers::{
19    chat::{builtin, ChatMessage, ChatTemplate},
20    tokenizer::{SapientTokenizer, TokenizerOptions},
21};
22
23use crate::sampler::{Sampler, SamplingStrategy};
24
25// ── GenerationConfig ──────────────────────────────────────────────────────────
26
27/// Controls how text is generated.
28#[derive(Debug, Clone)]
29pub struct GenerationConfig {
30    /// Maximum number of new tokens to generate.
31    pub max_new_tokens: usize,
32    /// Stop generating when this token ID is produced (usually EOS).
33    pub eos_token_id: Option<u32>,
34    /// Sampling strategy (default: greedy).
35    pub strategy: SamplingStrategy,
36    /// Stop strings — generation ends if any of these appear in output.
37    pub stop_sequences: Vec<String>,
38}
39
40impl Default for GenerationConfig {
41    fn default() -> Self {
42        Self {
43            max_new_tokens: 512,
44            eos_token_id: None,
45            strategy: SamplingStrategy::default(),
46            stop_sequences: vec![],
47        }
48    }
49}
50
51// ── LoadOptions ───────────────────────────────────────────────────────────────
52
53/// Options for loading a model from HuggingFace Hub or local disk.
54#[derive(Debug, Clone, Default)]
55pub struct LoadOptions {
56    /// HuggingFace Hub options.
57    pub hub: HubOptions,
58    /// Override the generation config.
59    pub generation: GenerationConfig,
60    /// Native LLM backend for Hub generation.
61    pub backend: LlmBackendKind,
62}
63
64// ── Pipeline ─────────────────────────────────────────────────────────────────
65
66/// A fully loaded LLM ready for inference.
67pub struct Pipeline {
68    tokenizer: Arc<SapientTokenizer>,
69    chat_template: Option<ChatTemplate>,
70    model_info: ModelInfo,
71    weight_paths: Vec<PathBuf>,
72    engine: Mutex<ForwardEngine>,
73    config: GenerationConfig,
74    backend: LlmBackendKind,
75}
76
77impl Pipeline {
78    // ── Constructors ──────────────────────────────────────────────────────────
79
80    /// Load any model from the HuggingFace Hub by model ID.
81    pub async fn from_pretrained(model_id: &str) -> Result<Self> {
82        Self::from_pretrained_with_opts(model_id, LoadOptions::default()).await
83    }
84
85    /// Load with custom hub and generation options.
86    pub async fn from_pretrained_with_opts(model_id: &str, opts: LoadOptions) -> Result<Self> {
87        debug!("Loading model: {model_id}");
88        let backend = opts.backend;
89
90        let mut hub_opts = opts.hub.clone();
91        if hub_opts.formats == LoadOptions::default().hub.formats {
92            // Prefer full-precision safetensors for native forward passes.
93            hub_opts.formats = vec!["safetensors".into(), "bin".into(), "gguf".into()];
94        }
95
96        let hub = HubClient::with_options(hub_opts)?;
97        let model_files = hub
98            .download(model_id)
99            .await
100            .with_context(|| format!("Failed to download model '{model_id}'"))?;
101
102        ensure_weights_present(&model_files)?;
103
104        // GGUF-only repos: the hub's config_path is a sentinel pointing at the
105        // GGUF file itself.  Route directly to from_gguf_with_backend instead of
106        // trying to parse a config.json that doesn't exist.
107        let single_gguf = model_files.weight_paths.len() == 1
108            && model_files.weight_paths[0]
109                .extension()
110                .and_then(|e| e.to_str())
111                == Some("gguf");
112        if single_gguf {
113            return Self::from_gguf_with_backend(&model_files.weight_paths[0], backend).await;
114        }
115
116        let model_info = ModelInfo::from_config_file(&model_files.config_path)
117            .context("Failed to parse config.json")?;
118        debug!("Detected architecture: {:?}", model_info.arch);
119
120        if model_info.raw.get("vision_config").is_some() {
121            debug!("Vision tower present — text-only mode (images not supported yet)");
122        }
123
124        let tok_opts = TokenizerOptions {
125            add_bos: true,
126            ..Default::default()
127        };
128        let tokenizer = if let Some(tok_path) = &model_files.tokenizer_path {
129            Arc::new(
130                SapientTokenizer::from_file(tok_path, tok_opts)
131                    .context("Failed to load tokenizer")?,
132            )
133        } else if let Some(fallback_id) = tokenizer_fallback_model(model_id) {
134            debug!("No local tokenizer — loading from fallback Hub model '{fallback_id}'");
135            Arc::new(
136                SapientTokenizer::from_pretrained(fallback_id).with_context(|| {
137                    format!(
138                        "Failed to load tokenizer from fallback model '{fallback_id}' \
139                         (GGUF repos often omit tokenizer files)"
140                    )
141                })?,
142            )
143        } else {
144            Arc::new(
145                SapientTokenizer::from_pretrained(model_id)
146                    .context("Failed to load tokenizer from Hub")?,
147            )
148        };
149
150        // Prefer the model's own chat template; otherwise fall back to a builtin
151        // and remember the stop string(s) that builtin uses to end a turn.
152        let mut builtin_stops: Vec<String> = Vec::new();
153        let chat_template = match model_files
154            .tokenizer_config_path
155            .as_ref()
156            .and_then(|p| ChatTemplate::from_tokenizer_config(p).ok())
157        {
158            Some(tmpl) => Some(tmpl),
159            None => {
160                let (tmpl, stops) =
161                    builtin_template_for(&model_info.arch, model_id, &model_info.model_type);
162                builtin_stops = stops;
163                Some(tmpl)
164            }
165        };
166
167        validate_tokenizer_model_compat(model_id, &model_info, &tokenizer)?;
168
169        let engine = ForwardEngine::from_weight_paths_with_backend(
170            model_info.clone(),
171            &model_files.weight_paths,
172            backend,
173        )
174        .context("Failed to initialize inference engine")?;
175
176        let mut config = opts.generation;
177        if config.eos_token_id.is_none() {
178            config.eos_token_id = tokenizer.eos_id;
179        }
180        // Register the builtin template's turn-terminator(s) as stop sequences.
181        for s in builtin_stops {
182            if !config.stop_sequences.contains(&s) {
183                config.stop_sequences.push(s);
184            }
185        }
186
187        debug!(
188            "Pipeline ready — vocab_size={} layers={} backend={}",
189            model_info.vocab_size, model_info.num_hidden_layers, backend
190        );
191
192        Ok(Self {
193            tokenizer,
194            chat_template,
195            model_info,
196            weight_paths: model_files.weight_paths.clone(),
197            engine: Mutex::new(engine),
198            config,
199            backend,
200        })
201    }
202
203    /// Load a GGUF model from a local `.gguf` file.
204    ///
205    /// Weights are kept quantized in memory (Q4_0/Q8_0 as packed blocks, no F32
206    /// expansion) so RAM ≈ file size.  The tokenizer is loaded from the embedded
207    /// GGUF vocabulary; if unavailable, a HuggingFace Hub fallback is fetched.
208    pub async fn from_gguf(path: impl Into<PathBuf>) -> Result<Self> {
209        Self::from_gguf_with_backend(path, LlmBackendKind::Auto).await
210    }
211
212    pub async fn from_gguf_with_backend(
213        path: impl Into<PathBuf>,
214        backend: LlmBackendKind,
215    ) -> Result<Self> {
216        let path = path.into();
217        debug!("Loading GGUF: {}", path.display());
218
219        // Parse metadata + load tensors (quantized types stay quantized).
220        let (metadata, _) = GgufLoader::load_tensors_with_metadata(&path)
221            .with_context(|| format!("failed to load GGUF: {}", path.display()))?;
222
223        // Build ModelInfo from GGUF KV metadata (no config.json needed).
224        let model_info = ModelInfo::from_gguf_metadata(&metadata)
225            .context("failed to build ModelInfo from GGUF metadata")?;
226
227        // Build the forward engine directly from the file (it calls
228        // load_gguf_hf_weights internally, which keeps quantized types intact).
229        let engine = ForwardEngine::from_gguf_with_backend(model_info.clone(), &path, backend)
230            .context("failed to initialise ForwardEngine from GGUF")?;
231
232        // Tokenizer: try the model ID from GGUF metadata, else arch-based fallback.
233        let model_id = metadata
234            .get("general.name")
235            .and_then(|v| v.as_str())
236            .unwrap_or("");
237        let tokenizer = if let Some(fallback) = tokenizer_fallback_model(model_id)
238            .or_else(|| tokenizer_fallback_model(model_info.model_type.as_str()))
239        {
240            Arc::new(
241                SapientTokenizer::from_pretrained(fallback)
242                    .with_context(|| format!("failed to load tokenizer from '{fallback}'"))?,
243            )
244        } else {
245            anyhow::bail!(
246                "Cannot determine tokenizer for GGUF model '{}' (arch: {}). \
247                 Load via `Pipeline::from_pretrained` with a registry alias instead.",
248                path.display(),
249                model_info.model_type
250            );
251        };
252
253        let (chat_template, builtin_stops) =
254            builtin_template_for(&model_info.arch, model_id, &model_info.model_type);
255
256        let mut config = GenerationConfig::default();
257        if config.eos_token_id.is_none() {
258            config.eos_token_id = tokenizer.eos_id;
259        }
260        for s in builtin_stops {
261            if !config.stop_sequences.contains(&s) {
262                config.stop_sequences.push(s);
263            }
264        }
265
266        validate_tokenizer_model_compat(model_id, &model_info, &tokenizer)?;
267
268        Ok(Self {
269            tokenizer,
270            chat_template: Some(chat_template),
271            model_info,
272            weight_paths: vec![path],
273            engine: Mutex::new(engine),
274            config,
275            backend,
276        })
277    }
278
279    // ── Inference API ─────────────────────────────────────────────────────────
280
281    /// Generate a completion for `prompt`.
282    pub async fn generate(&self, prompt: &str) -> Result<String> {
283        let input_ids = self.tokenizer.encode(prompt)?;
284        let output_ids = self.generate_from_tokens(input_ids).await?;
285        let text = self.tokenizer.decode(&output_ids, true)?;
286        Ok(self.trim_stop_sequences(text))
287    }
288
289    /// Generate with a custom generation config.
290    pub async fn generate_with_config(
291        &self,
292        prompt: &str,
293        config: &GenerationConfig,
294    ) -> Result<String> {
295        let input_ids = self.tokenizer.encode(prompt)?;
296        let output_ids = self
297            .generate_from_tokens_with_config(input_ids, config)
298            .await?;
299        let text = self.tokenizer.decode(&output_ids, true)?;
300        Ok(self.trim_stop_sequences(text))
301    }
302
303    /// All token ids that should terminate generation: the configured EOS plus
304    /// every end-of-turn marker the tokenizer knows (e.g. Qwen's `<|im_end|>`,
305    /// which `decode` strips, so it can't be caught as a stop *string*).
306    fn eos_token_ids(&self) -> Vec<u32> {
307        let mut ids = self.tokenizer.eos_ids.clone();
308        if let Some(e) = self.config.eos_token_id {
309            if !ids.contains(&e) {
310                ids.push(e);
311            }
312        }
313        ids
314    }
315
316    /// Cut the reply at the first stop sequence (for non-streaming callers).
317    fn trim_stop_sequences(&self, text: String) -> String {
318        match earliest_stop(&text, &self.config.stop_sequences) {
319            Some(idx) => text[..idx].to_string(),
320            None => text,
321        }
322    }
323
324    /// Render the chat prompt string for a message history.
325    pub fn format_chat_prompt(&self, messages: &[ChatMessage]) -> Result<String> {
326        if let Some(tmpl) = &self.chat_template {
327            tmpl.render(messages, true)
328                .context("Failed to render chat template")
329        } else {
330            Ok(messages
331                .iter()
332                .map(|m| format!("{}: {}", m.role, m.content))
333                .collect::<Vec<_>>()
334                .join("\n"))
335        }
336    }
337
338    /// Chat with the model (for instruct/chat tuned models).
339    pub async fn chat(&self, messages: &[ChatMessage]) -> Result<String> {
340        let prompt = self.format_chat_prompt(messages)?;
341        self.generate(&prompt).await
342    }
343
344    /// Stream a chat reply token-by-token.
345    pub async fn chat_stream(&self, messages: &[ChatMessage]) -> ReceiverStream<String> {
346        match self.format_chat_prompt(messages) {
347            Ok(prompt) => self.generate_stream(&prompt).await,
348            Err(e) => {
349                let (tx, rx) = mpsc::channel(1);
350                let _ = tx.try_send(format!("Error: {e}"));
351                ReceiverStream::new(rx)
352            }
353        }
354    }
355
356    /// Stream tokens as they are generated.
357    pub async fn generate_stream(&self, prompt: &str) -> ReceiverStream<String> {
358        let (tx, rx) = mpsc::channel(64);
359        let input_ids = self.tokenizer.encode(prompt).unwrap_or_default();
360        let eos_ids = self.eos_token_ids();
361        let max_new = self.config.max_new_tokens;
362        let strategy = self.config.strategy.clone();
363        let stop = self.config.stop_sequences.clone();
364        let tok = Arc::clone(&self.tokenizer);
365        let model_info = self.model_info.clone();
366        let weight_paths = self.weight_paths.clone();
367        let backend = self.configured_backend();
368
369        tokio::task::spawn_blocking(move || {
370            let mut engine = match ForwardEngine::from_weight_paths_with_backend(
371                model_info,
372                &weight_paths,
373                backend,
374            ) {
375                Ok(e) => e,
376                Err(e) => {
377                    let _ = tx.blocking_send(format!("Error: {e}"));
378                    return;
379                }
380            };
381            let mut sampler = Sampler::new(strategy);
382            let mut all_tokens = input_ids;
383            let mut generated: Vec<u32> = Vec::new();
384            // Bytes of the decoded reply already streamed to the caller. We decode
385            // the whole reply each step (stable, unlike per-token pieces) and only
386            // emit text that cannot be part of a stop marker, so markers like
387            // `<|im_end|>` never leak even though they span several tokens.
388            let mut emitted = 0usize;
389            let mut clean_stop = false;
390
391            engine.reset_cache();
392            for step in 0..max_new {
393                let chunk = if step == 0 {
394                    all_tokens.clone()
395                } else {
396                    vec![*all_tokens.last().unwrap()]
397                };
398                let logits = match engine.forward_logits(&chunk, true) {
399                    Ok(v) => v,
400                    Err(e) => {
401                        let _ = tx.blocking_send(format!("Error: {e}"));
402                        break;
403                    }
404                };
405
406                let next = match sampler.sample(&logits, &all_tokens) {
407                    Ok(t) => t,
408                    Err(e) => {
409                        let _ = tx.blocking_send(format!("Error: {e}"));
410                        break;
411                    }
412                };
413
414                generated.push(next);
415                all_tokens.push(next);
416
417                if eos_ids.contains(&next) {
418                    clean_stop = true;
419                    break;
420                }
421
422                let text = match tok.decode(&generated, true) {
423                    Ok(t) => t,
424                    Err(_) => continue,
425                };
426
427                // A stop sequence appeared: emit everything before it, then stop.
428                if let Some(idx) = earliest_stop(&text, &stop) {
429                    if idx > emitted {
430                        let _ = tx.blocking_send(text[emitted..idx].to_string());
431                    }
432                    clean_stop = true;
433                    break;
434                }
435
436                // Emit all but a trailing tail that could still grow into a stop.
437                let safe = safe_emit_end(&text, &stop);
438                if safe > emitted {
439                    if tx.blocking_send(text[emitted..safe].to_string()).is_err() {
440                        break;
441                    }
442                    emitted = safe;
443                }
444            }
445
446            // Reached max tokens without hitting a stop: flush the held-back tail.
447            if !clean_stop {
448                if let Ok(text) = tok.decode(&generated, true) {
449                    if text.len() > emitted {
450                        let _ = tx.blocking_send(text[emitted..].to_string());
451                    }
452                }
453            }
454        });
455
456        ReceiverStream::new(rx)
457    }
458
459    /// Compute sentence embeddings via mean-pooled hidden states.
460    pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
461        let ids = self.tokenizer.encode(text)?;
462        let mut engine = self.engine.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
463        engine.embed(&ids)
464    }
465
466    // ── Helpers ───────────────────────────────────────────────────────────────
467
468    async fn generate_from_tokens(&self, input_ids: Vec<u32>) -> Result<Vec<u32>> {
469        self.generate_from_tokens_with_config(input_ids, &self.config)
470            .await
471    }
472
473    async fn generate_from_tokens_with_config(
474        &self,
475        input_ids: Vec<u32>,
476        config: &GenerationConfig,
477    ) -> Result<Vec<u32>> {
478        let mut engine = self.engine.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
479        let mut sampler = Sampler::new(config.strategy.clone());
480        let mut generated: Vec<u32> = Vec::new();
481        let mut all_tokens = input_ids;
482        let eos_ids = self.eos_token_ids();
483
484        engine.reset_cache();
485
486        // Prefill must use the KV cache so decode steps see correct positions and context.
487        let logits = engine.forward_logits(&all_tokens, true)?;
488        let mut next = sampler.sample(&logits, &all_tokens)?;
489        generated.push(next);
490        all_tokens.push(next);
491
492        if eos_ids.contains(&next) {
493            return Ok(generated);
494        }
495
496        for step in 1..config.max_new_tokens {
497            let logits = engine.forward_logits(&[next], true)?;
498            next = sampler.sample(&logits, &all_tokens)?;
499            generated.push(next);
500            all_tokens.push(next);
501
502            if eos_ids.contains(&next) {
503                debug!("EOS token generated at step {step}");
504                break;
505            }
506
507            if !config.stop_sequences.is_empty() {
508                let decoded = self.tokenizer.decode(&generated, true).unwrap_or_default();
509                if config
510                    .stop_sequences
511                    .iter()
512                    .any(|s| decoded.contains(s.as_str()))
513                {
514                    break;
515                }
516            }
517        }
518
519        Ok(generated)
520    }
521
522    pub fn tokenizer(&self) -> &SapientTokenizer {
523        &self.tokenizer
524    }
525    pub fn model_info(&self) -> &ModelInfo {
526        &self.model_info
527    }
528    pub fn arch(&self) -> &ArchType {
529        &self.model_info.arch
530    }
531
532    fn configured_backend(&self) -> LlmBackendKind {
533        self.backend
534    }
535}
536
537fn ensure_weights_present(files: &ModelFiles) -> Result<()> {
538    if files.weight_paths.is_empty() {
539        anyhow::bail!("No weight files found for this model");
540    }
541    Ok(())
542}
543
544fn validate_tokenizer_model_compat(
545    model_id: &str,
546    model_info: &ModelInfo,
547    tokenizer: &SapientTokenizer,
548) -> Result<()> {
549    let tokenizer_vocab = tokenizer.vocab_size();
550    if tokenizer_vocab > model_info.vocab_size {
551        anyhow::bail!(
552            "tokenizer/model vocab mismatch for '{model_id}': tokenizer has {tokenizer_vocab} tokens but model config vocab_size is {}",
553            model_info.vocab_size
554        );
555    }
556
557    if let Some(eos) = tokenizer.eos_id {
558        if eos as usize >= model_info.vocab_size {
559            anyhow::bail!(
560                "tokenizer/model EOS mismatch for '{model_id}': eos_token_id {eos} is outside model vocab_size {}",
561                model_info.vocab_size
562            );
563        }
564    } else {
565        tracing::warn!(
566            model = model_id,
567            "tokenizer has no recognized EOS token; generation will stop only by max_new_tokens or stop strings"
568        );
569    }
570
571    Ok(())
572}
573
/// Find where the first stop sequence begins in `text`.
///
/// Returns the smallest byte offset at which any non-empty entry of `stops`
/// occurs, or `None` when none of them occurs. Empty stop strings are ignored.
fn earliest_stop(text: &str, stops: &[String]) -> Option<usize> {
    let mut best: Option<usize> = None;
    for stop in stops {
        if stop.is_empty() {
            continue;
        }
        if let Some(pos) = text.find(stop.as_str()) {
            best = Some(best.map_or(pos, |current| current.min(pos)));
        }
    }
    best
}
582
/// Largest byte index (always a char boundary of `text`) up to which `text`
/// can be streamed without emitting the beginning of a stop marker.
///
/// The function holds back the longest suffix of `text` that is a *proper*
/// prefix of any stop sequence (a complete stop sequence is the job of
/// `earliest_stop`). The candidate prefix length runs up to
/// `min(stop.len() - 1, text.len())` inclusive, so a text that is shorter
/// than a stop sequence and consists entirely of its beginning (e.g. `"<|im"`
/// against `"<|im_end|>"`) is held back completely.
fn safe_emit_end(text: &str, stops: &[String]) -> usize {
    let hold = stops
        .iter()
        .filter_map(|s| {
            // A proper prefix has at most `s.len() - 1` bytes and can never be
            // longer than the text it must be a suffix of.
            let longest = s.len().saturating_sub(1).min(text.len());
            (1..=longest)
                .rev()
                // Only slice the stop sequence on char boundaries.
                .filter(|&k| s.is_char_boundary(k))
                .find(|&k| text.ends_with(&s[..k]))
        })
        .max()
        .unwrap_or(0);
    // `text` ends with a valid `str` of `hold` bytes, so this is a boundary.
    text.len() - hold
}
602
603/// Pick a builtin chat template and the stop string(s) that terminate an
604/// assistant turn under that template. When we fall back to a builtin template
605/// (because the model ships no `chat_template`), the model's plain EOS often
606/// isn't what the template trains the turn to end with (e.g. ChatML uses
607/// `<|im_end|>`), so these stops must be registered or the marker leaks into
608/// the output.
609fn builtin_template_for(
610    arch: &ArchType,
611    model_id: &str,
612    model_type: &str,
613) -> (ChatTemplate, Vec<String>) {
614    let id = model_id.to_ascii_lowercase();
615    let mt = model_type.to_ascii_lowercase();
616    let chatml = || {
617        (
618            ChatTemplate::from_template(builtin::CHATML),
619            vec!["<|im_end|>".to_string()],
620        )
621    };
622    match arch {
623        ArchType::Llama if id.contains("tinyllama") => (
624            ChatTemplate::from_template(builtin::ZEPHYR),
625            vec!["</s>".to_string()],
626        ),
627        ArchType::Llama
628            if id.contains("llama-2")
629                || id.contains("llama2")
630                || (mt.contains("llama") && !id.contains("llama-3") && !id.contains("llama3")) =>
631        {
632            (
633                ChatTemplate::from_template(builtin::LLAMA2),
634                vec!["</s>".to_string()],
635            )
636        }
637        ArchType::Llama => (
638            ChatTemplate::from_template(builtin::LLAMA3),
639            vec!["<|eot_id|>".to_string()],
640        ),
641        ArchType::Gemma => (
642            ChatTemplate::from_template(builtin::GEMMA),
643            vec!["<end_of_turn>".to_string()],
644        ),
645        ArchType::Phi | ArchType::Qwen => chatml(),
646        _ => chatml(),
647    }
648}