1use std::path::PathBuf;
6use std::sync::{Arc, Mutex};
7
8use anyhow::{Context, Result};
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use tracing::debug;
12
13use sapient_hub::model_info::{ArchType, ModelInfo};
14use sapient_hub::resolver::ModelFiles;
15use sapient_hub::{tokenizer_fallback_model, HubClient, LoadOptions as HubOptions};
16use sapient_io::GgufLoader;
17use sapient_models::{ForwardEngine, LlmBackendKind};
18use sapient_tokenizers::{
19 chat::{builtin, ChatMessage, ChatTemplate},
20 tokenizer::{SapientTokenizer, TokenizerOptions},
21};
22
23use crate::sampler::{Sampler, SamplingStrategy};
24
25#[derive(Debug, Clone)]
29pub struct GenerationConfig {
30 pub max_new_tokens: usize,
32 pub eos_token_id: Option<u32>,
34 pub strategy: SamplingStrategy,
36 pub stop_sequences: Vec<String>,
38}
39
40impl Default for GenerationConfig {
41 fn default() -> Self {
42 Self {
43 max_new_tokens: 512,
44 eos_token_id: None,
45 strategy: SamplingStrategy::default(),
46 stop_sequences: vec![],
47 }
48 }
49}
50
51#[derive(Debug, Clone, Default)]
55pub struct LoadOptions {
56 pub hub: HubOptions,
58 pub generation: GenerationConfig,
60 pub backend: LlmBackendKind,
62}
63
64pub struct Pipeline {
68 tokenizer: Arc<SapientTokenizer>,
69 chat_template: Option<ChatTemplate>,
70 model_info: ModelInfo,
71 weight_paths: Vec<PathBuf>,
72 engine: Mutex<ForwardEngine>,
73 config: GenerationConfig,
74 backend: LlmBackendKind,
75}
76
77impl Pipeline {
78 pub async fn from_pretrained(model_id: &str) -> Result<Self> {
82 Self::from_pretrained_with_opts(model_id, LoadOptions::default()).await
83 }
84
85 pub async fn from_pretrained_with_opts(model_id: &str, opts: LoadOptions) -> Result<Self> {
87 debug!("Loading model: {model_id}");
88 let backend = opts.backend;
89
90 let mut hub_opts = opts.hub.clone();
91 if hub_opts.formats == LoadOptions::default().hub.formats {
92 hub_opts.formats = vec!["safetensors".into(), "bin".into(), "gguf".into()];
94 }
95
96 let hub = HubClient::with_options(hub_opts)?;
97 let model_files = hub
98 .download(model_id)
99 .await
100 .with_context(|| format!("Failed to download model '{model_id}'"))?;
101
102 ensure_weights_present(&model_files)?;
103
104 let single_gguf = model_files.weight_paths.len() == 1
108 && model_files.weight_paths[0]
109 .extension()
110 .and_then(|e| e.to_str())
111 == Some("gguf");
112 if single_gguf {
113 return Self::from_gguf_with_backend(&model_files.weight_paths[0], backend).await;
114 }
115
116 let model_info = ModelInfo::from_config_file(&model_files.config_path)
117 .context("Failed to parse config.json")?;
118 debug!("Detected architecture: {:?}", model_info.arch);
119
120 if model_info.raw.get("vision_config").is_some() {
121 debug!("Vision tower present — text-only mode (images not supported yet)");
122 }
123
124 let tok_opts = TokenizerOptions {
125 add_bos: true,
126 ..Default::default()
127 };
128 let tokenizer = if let Some(tok_path) = &model_files.tokenizer_path {
129 Arc::new(
130 SapientTokenizer::from_file(tok_path, tok_opts)
131 .context("Failed to load tokenizer")?,
132 )
133 } else if let Some(fallback_id) = tokenizer_fallback_model(model_id) {
134 debug!("No local tokenizer — loading from fallback Hub model '{fallback_id}'");
135 Arc::new(
136 SapientTokenizer::from_pretrained(fallback_id).with_context(|| {
137 format!(
138 "Failed to load tokenizer from fallback model '{fallback_id}' \
139 (GGUF repos often omit tokenizer files)"
140 )
141 })?,
142 )
143 } else {
144 Arc::new(
145 SapientTokenizer::from_pretrained(model_id)
146 .context("Failed to load tokenizer from Hub")?,
147 )
148 };
149
150 let mut builtin_stops: Vec<String> = Vec::new();
153 let chat_template = match model_files
154 .tokenizer_config_path
155 .as_ref()
156 .and_then(|p| ChatTemplate::from_tokenizer_config(p).ok())
157 {
158 Some(tmpl) => Some(tmpl),
159 None => {
160 let (tmpl, stops) =
161 builtin_template_for(&model_info.arch, model_id, &model_info.model_type);
162 builtin_stops = stops;
163 Some(tmpl)
164 }
165 };
166
167 validate_tokenizer_model_compat(model_id, &model_info, &tokenizer)?;
168
169 let engine = ForwardEngine::from_weight_paths_with_backend(
170 model_info.clone(),
171 &model_files.weight_paths,
172 backend,
173 )
174 .context("Failed to initialize inference engine")?;
175
176 let mut config = opts.generation;
177 if config.eos_token_id.is_none() {
178 config.eos_token_id = tokenizer.eos_id;
179 }
180 for s in builtin_stops {
182 if !config.stop_sequences.contains(&s) {
183 config.stop_sequences.push(s);
184 }
185 }
186
187 debug!(
188 "Pipeline ready — vocab_size={} layers={} backend={}",
189 model_info.vocab_size, model_info.num_hidden_layers, backend
190 );
191
192 Ok(Self {
193 tokenizer,
194 chat_template,
195 model_info,
196 weight_paths: model_files.weight_paths.clone(),
197 engine: Mutex::new(engine),
198 config,
199 backend,
200 })
201 }
202
203 pub async fn from_gguf(path: impl Into<PathBuf>) -> Result<Self> {
209 Self::from_gguf_with_backend(path, LlmBackendKind::Auto).await
210 }
211
212 pub async fn from_gguf_with_backend(
213 path: impl Into<PathBuf>,
214 backend: LlmBackendKind,
215 ) -> Result<Self> {
216 let path = path.into();
217 debug!("Loading GGUF: {}", path.display());
218
219 let (metadata, _) = GgufLoader::load_tensors_with_metadata(&path)
221 .with_context(|| format!("failed to load GGUF: {}", path.display()))?;
222
223 let model_info = ModelInfo::from_gguf_metadata(&metadata)
225 .context("failed to build ModelInfo from GGUF metadata")?;
226
227 let engine = ForwardEngine::from_gguf_with_backend(model_info.clone(), &path, backend)
230 .context("failed to initialise ForwardEngine from GGUF")?;
231
232 let model_id = metadata
234 .get("general.name")
235 .and_then(|v| v.as_str())
236 .unwrap_or("");
237 let tokenizer = if let Some(fallback) = tokenizer_fallback_model(model_id)
238 .or_else(|| tokenizer_fallback_model(model_info.model_type.as_str()))
239 {
240 Arc::new(
241 SapientTokenizer::from_pretrained(fallback)
242 .with_context(|| format!("failed to load tokenizer from '{fallback}'"))?,
243 )
244 } else {
245 anyhow::bail!(
246 "Cannot determine tokenizer for GGUF model '{}' (arch: {}). \
247 Load via `Pipeline::from_pretrained` with a registry alias instead.",
248 path.display(),
249 model_info.model_type
250 );
251 };
252
253 let (chat_template, builtin_stops) =
254 builtin_template_for(&model_info.arch, model_id, &model_info.model_type);
255
256 let mut config = GenerationConfig::default();
257 if config.eos_token_id.is_none() {
258 config.eos_token_id = tokenizer.eos_id;
259 }
260 for s in builtin_stops {
261 if !config.stop_sequences.contains(&s) {
262 config.stop_sequences.push(s);
263 }
264 }
265
266 validate_tokenizer_model_compat(model_id, &model_info, &tokenizer)?;
267
268 Ok(Self {
269 tokenizer,
270 chat_template: Some(chat_template),
271 model_info,
272 weight_paths: vec![path],
273 engine: Mutex::new(engine),
274 config,
275 backend,
276 })
277 }
278
279 pub async fn generate(&self, prompt: &str) -> Result<String> {
283 let input_ids = self.tokenizer.encode(prompt)?;
284 let output_ids = self.generate_from_tokens(input_ids).await?;
285 let text = self.tokenizer.decode(&output_ids, true)?;
286 Ok(self.trim_stop_sequences(text))
287 }
288
289 pub async fn generate_with_config(
291 &self,
292 prompt: &str,
293 config: &GenerationConfig,
294 ) -> Result<String> {
295 let input_ids = self.tokenizer.encode(prompt)?;
296 let output_ids = self
297 .generate_from_tokens_with_config(input_ids, config)
298 .await?;
299 let text = self.tokenizer.decode(&output_ids, true)?;
300 Ok(self.trim_stop_sequences(text))
301 }
302
303 fn eos_token_ids(&self) -> Vec<u32> {
307 let mut ids = self.tokenizer.eos_ids.clone();
308 if let Some(e) = self.config.eos_token_id {
309 if !ids.contains(&e) {
310 ids.push(e);
311 }
312 }
313 ids
314 }
315
316 fn trim_stop_sequences(&self, text: String) -> String {
318 match earliest_stop(&text, &self.config.stop_sequences) {
319 Some(idx) => text[..idx].to_string(),
320 None => text,
321 }
322 }
323
324 pub fn format_chat_prompt(&self, messages: &[ChatMessage]) -> Result<String> {
326 if let Some(tmpl) = &self.chat_template {
327 tmpl.render(messages, true)
328 .context("Failed to render chat template")
329 } else {
330 Ok(messages
331 .iter()
332 .map(|m| format!("{}: {}", m.role, m.content))
333 .collect::<Vec<_>>()
334 .join("\n"))
335 }
336 }
337
338 pub async fn chat(&self, messages: &[ChatMessage]) -> Result<String> {
340 let prompt = self.format_chat_prompt(messages)?;
341 self.generate(&prompt).await
342 }
343
344 pub async fn chat_stream(&self, messages: &[ChatMessage]) -> ReceiverStream<String> {
346 match self.format_chat_prompt(messages) {
347 Ok(prompt) => self.generate_stream(&prompt).await,
348 Err(e) => {
349 let (tx, rx) = mpsc::channel(1);
350 let _ = tx.try_send(format!("Error: {e}"));
351 ReceiverStream::new(rx)
352 }
353 }
354 }
355
356 pub async fn generate_stream(&self, prompt: &str) -> ReceiverStream<String> {
358 let (tx, rx) = mpsc::channel(64);
359 let input_ids = self.tokenizer.encode(prompt).unwrap_or_default();
360 let eos_ids = self.eos_token_ids();
361 let max_new = self.config.max_new_tokens;
362 let strategy = self.config.strategy.clone();
363 let stop = self.config.stop_sequences.clone();
364 let tok = Arc::clone(&self.tokenizer);
365 let model_info = self.model_info.clone();
366 let weight_paths = self.weight_paths.clone();
367 let backend = self.configured_backend();
368
369 tokio::task::spawn_blocking(move || {
370 let mut engine = match ForwardEngine::from_weight_paths_with_backend(
371 model_info,
372 &weight_paths,
373 backend,
374 ) {
375 Ok(e) => e,
376 Err(e) => {
377 let _ = tx.blocking_send(format!("Error: {e}"));
378 return;
379 }
380 };
381 let mut sampler = Sampler::new(strategy);
382 let mut all_tokens = input_ids;
383 let mut generated: Vec<u32> = Vec::new();
384 let mut emitted = 0usize;
389 let mut clean_stop = false;
390
391 engine.reset_cache();
392 for step in 0..max_new {
393 let chunk = if step == 0 {
394 all_tokens.clone()
395 } else {
396 vec![*all_tokens.last().unwrap()]
397 };
398 let logits = match engine.forward_logits(&chunk, true) {
399 Ok(v) => v,
400 Err(e) => {
401 let _ = tx.blocking_send(format!("Error: {e}"));
402 break;
403 }
404 };
405
406 let next = match sampler.sample(&logits, &all_tokens) {
407 Ok(t) => t,
408 Err(e) => {
409 let _ = tx.blocking_send(format!("Error: {e}"));
410 break;
411 }
412 };
413
414 generated.push(next);
415 all_tokens.push(next);
416
417 if eos_ids.contains(&next) {
418 clean_stop = true;
419 break;
420 }
421
422 let text = match tok.decode(&generated, true) {
423 Ok(t) => t,
424 Err(_) => continue,
425 };
426
427 if let Some(idx) = earliest_stop(&text, &stop) {
429 if idx > emitted {
430 let _ = tx.blocking_send(text[emitted..idx].to_string());
431 }
432 clean_stop = true;
433 break;
434 }
435
436 let safe = safe_emit_end(&text, &stop);
438 if safe > emitted {
439 if tx.blocking_send(text[emitted..safe].to_string()).is_err() {
440 break;
441 }
442 emitted = safe;
443 }
444 }
445
446 if !clean_stop {
448 if let Ok(text) = tok.decode(&generated, true) {
449 if text.len() > emitted {
450 let _ = tx.blocking_send(text[emitted..].to_string());
451 }
452 }
453 }
454 });
455
456 ReceiverStream::new(rx)
457 }
458
459 pub async fn embed(&self, text: &str) -> Result<Vec<f32>> {
461 let ids = self.tokenizer.encode(text)?;
462 let mut engine = self.engine.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
463 engine.embed(&ids)
464 }
465
466 async fn generate_from_tokens(&self, input_ids: Vec<u32>) -> Result<Vec<u32>> {
469 self.generate_from_tokens_with_config(input_ids, &self.config)
470 .await
471 }
472
473 async fn generate_from_tokens_with_config(
474 &self,
475 input_ids: Vec<u32>,
476 config: &GenerationConfig,
477 ) -> Result<Vec<u32>> {
478 let mut engine = self.engine.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
479 let mut sampler = Sampler::new(config.strategy.clone());
480 let mut generated: Vec<u32> = Vec::new();
481 let mut all_tokens = input_ids;
482 let eos_ids = self.eos_token_ids();
483
484 engine.reset_cache();
485
486 let logits = engine.forward_logits(&all_tokens, true)?;
488 let mut next = sampler.sample(&logits, &all_tokens)?;
489 generated.push(next);
490 all_tokens.push(next);
491
492 if eos_ids.contains(&next) {
493 return Ok(generated);
494 }
495
496 for step in 1..config.max_new_tokens {
497 let logits = engine.forward_logits(&[next], true)?;
498 next = sampler.sample(&logits, &all_tokens)?;
499 generated.push(next);
500 all_tokens.push(next);
501
502 if eos_ids.contains(&next) {
503 debug!("EOS token generated at step {step}");
504 break;
505 }
506
507 if !config.stop_sequences.is_empty() {
508 let decoded = self.tokenizer.decode(&generated, true).unwrap_or_default();
509 if config
510 .stop_sequences
511 .iter()
512 .any(|s| decoded.contains(s.as_str()))
513 {
514 break;
515 }
516 }
517 }
518
519 Ok(generated)
520 }
521
522 pub fn tokenizer(&self) -> &SapientTokenizer {
523 &self.tokenizer
524 }
525 pub fn model_info(&self) -> &ModelInfo {
526 &self.model_info
527 }
528 pub fn arch(&self) -> &ArchType {
529 &self.model_info.arch
530 }
531
532 fn configured_backend(&self) -> LlmBackendKind {
533 self.backend
534 }
535}
536
537fn ensure_weights_present(files: &ModelFiles) -> Result<()> {
538 if files.weight_paths.is_empty() {
539 anyhow::bail!("No weight files found for this model");
540 }
541 Ok(())
542}
543
544fn validate_tokenizer_model_compat(
545 model_id: &str,
546 model_info: &ModelInfo,
547 tokenizer: &SapientTokenizer,
548) -> Result<()> {
549 let tokenizer_vocab = tokenizer.vocab_size();
550 if tokenizer_vocab > model_info.vocab_size {
551 anyhow::bail!(
552 "tokenizer/model vocab mismatch for '{model_id}': tokenizer has {tokenizer_vocab} tokens but model config vocab_size is {}",
553 model_info.vocab_size
554 );
555 }
556
557 if let Some(eos) = tokenizer.eos_id {
558 if eos as usize >= model_info.vocab_size {
559 anyhow::bail!(
560 "tokenizer/model EOS mismatch for '{model_id}': eos_token_id {eos} is outside model vocab_size {}",
561 model_info.vocab_size
562 );
563 }
564 } else {
565 tracing::warn!(
566 model = model_id,
567 "tokenizer has no recognized EOS token; generation will stop only by max_new_tokens or stop strings"
568 );
569 }
570
571 Ok(())
572}
573
/// Returns the byte offset of the first occurrence of any stop string in
/// `text`, or `None` when none occurs. Empty stop strings are ignored.
fn earliest_stop(text: &str, stops: &[String]) -> Option<usize> {
    let mut first: Option<usize> = None;
    for stop in stops {
        if stop.is_empty() {
            continue;
        }
        if let Some(pos) = text.find(stop.as_str()) {
            first = Some(first.map_or(pos, |best| best.min(pos)));
        }
    }
    first
}
582
/// Returns the byte offset up to which `text` can be emitted without risking
/// that a stop sequence is partially sent.
///
/// If `text` ends with a proper, non-empty prefix of any stop string, the
/// longest such prefix is held back and the returned offset points at its
/// start; otherwise the offset is `text.len()`. Complete stop strings are not
/// looked for here. The prefix may cover the whole of `text` when `text` is
/// shorter than the stop string (e.g. `"<|im"` against `"<|im_end|>"`).
fn safe_emit_end(text: &str, stops: &[String]) -> usize {
    let held = stops
        .iter()
        .map(|stop| {
            // Longest candidate: a proper prefix of `stop` that still fits in `text`.
            let longest = stop.len().saturating_sub(1).min(text.len());
            (1..=longest)
                .rev()
                .find(|&k| stop.is_char_boundary(k) && text.ends_with(&stop[..k]))
                .unwrap_or(0)
        })
        .max()
        .unwrap_or(0);
    text.len() - held
}
602
603fn builtin_template_for(
610 arch: &ArchType,
611 model_id: &str,
612 model_type: &str,
613) -> (ChatTemplate, Vec<String>) {
614 let id = model_id.to_ascii_lowercase();
615 let mt = model_type.to_ascii_lowercase();
616 let chatml = || {
617 (
618 ChatTemplate::from_template(builtin::CHATML),
619 vec!["<|im_end|>".to_string()],
620 )
621 };
622 match arch {
623 ArchType::Llama if id.contains("tinyllama") => (
624 ChatTemplate::from_template(builtin::ZEPHYR),
625 vec!["</s>".to_string()],
626 ),
627 ArchType::Llama
628 if id.contains("llama-2")
629 || id.contains("llama2")
630 || (mt.contains("llama") && !id.contains("llama-3") && !id.contains("llama3")) =>
631 {
632 (
633 ChatTemplate::from_template(builtin::LLAMA2),
634 vec!["</s>".to_string()],
635 )
636 }
637 ArchType::Llama => (
638 ChatTemplate::from_template(builtin::LLAMA3),
639 vec!["<|eot_id|>".to_string()],
640 ),
641 ArchType::Gemma => (
642 ChatTemplate::from_template(builtin::GEMMA),
643 vec!["<end_of_turn>".to_string()],
644 ),
645 ArchType::Phi | ArchType::Qwen => chatml(),
646 _ => chatml(),
647 }
648}