Skip to main content

oxide_rs/
lib.rs

1//! Oxide-rs
2//!
3//! Fast AI Inference Library & CLI in Rust - A lightweight, CPU-based LLM inference engine inspired by llama.cpp.
4//!
5//! # Features
6//!
7//! - GGUF model support (LLaMA, LFM2 architectures)
8//! - Full tokenizer compatibility (SPM, BPE, WPM, UGM, RWKV)
9//! - Automatic chat templates from GGUF files
10//! - Streaming token generation
11//! - Multiple sampling strategies (temperature, top-k, top-p)
12//! - Interactive REPL and one-shot modes
13//! - Memory-mapped loading for instant startup
14//!
15//! # Quick Start
16//!
17//! ## CLI Usage
18//!
19//! ```bash
20//! # Install via cargo
21//! cargo install oxide-rs
22//!
23//! # Run interactively
24//! oxide-rs -m model.gguf
25//!
26//! # One-shot generation
27//! oxide-rs -m model.gguf --once --prompt "Hello!"
28//! ```
29//!
30//! ## Library Usage
31//!
32//! ```rust,ignore
33//! use oxide_rs::{generate, GenerateOptions};
34//!
35//! fn main() -> Result<(), Box<dyn std::error::Error>> {
36//!     let result = generate(
37//!         "model.gguf",
38//!         GenerateOptions::default(),
39//!         "Hello, how are you?",
40//!     )?;
41//!     println!("{}", result);
42//!     Ok(())
43//! }
44//! ```
45//!
46//! ## Builder API
47//!
48//! For more control, use the `Model` builder:
49//!
50//! ```rust,ignore
51//! use oxide_rs::Model;
52//!
53//! fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let mut model = Model::new("model.gguf")?.with_options(oxide_rs::GenerateOptions {
//!         max_tokens: 256,
//!         temperature: 0.7,
//!         ..Default::default()
//!     });
//!     model.load()?;
61//!
62//!     let response = model.generate("What is Rust?")?;
63//!     println!("{}", response);
64//!     Ok(())
65//! }
66//! ```
67//!
68//! # Requirements
69//!
70//! - Rust 1.70+ (2021 edition)
71//! - A GGUF quantized model file with embedded chat template
72//!
73//! # Links
74//!
75//! - [GitHub Repository](https://github.com/theawakener0/oxide-rs)
76//! - [crates.io](https://crates.io/crates/oxide-rs)
77//! - [Documentation](https://docs.rs/oxide-rs)
78
79pub mod cli;
80pub mod inference;
81pub mod model;
82pub mod server;
83pub mod tui;
84
85use std::path::Path;
86use std::path::PathBuf;
87
88pub use inference::{
89    BatchConfig, DynamicBatcher, Generator, PagedAttentionConfig, PagedKvCache, 
90    PrefixCache, PrefixCacheConfig, SimdLevel, StreamEvent,
91    ThreadPinnerConfig, ThreadPinner,
92};
93pub use model::{
94    download, format_size, get_hf_cache_dir, get_model_info, list_models, list_repo_files,
95    register_model, unregister_model, ModelEntry, GgufMetadata, Model as ModelWrapper, 
96    TokenizerWrapper,
97};
98
/// Configuration options for text generation.
///
/// Prefer struct-update syntax (`..Default::default()`) when constructing this
/// type so that any field you do not set keeps its documented default.
///
/// [`Model`] reads `temperature`, `top_p`, `top_k`, `seed`, `system_prompt` and
/// `batch_size` when it loads, and `max_tokens`, `repeat_penalty` and
/// `repeat_last_n` on every generation call. The remaining fields
/// (batching, prefix cache, threading, SIMD) are not read by [`Model`].
///
/// # Example
///
/// ```rust,ignore
/// use oxide_rs::GenerateOptions;
///
/// let options = GenerateOptions {
///     max_tokens: 512,
///     temperature: 0.3,
///     repeat_penalty: 1.1,
///     repeat_last_n: 64,
///     seed: 299792458,
///     ..Default::default()
/// };
/// ```
#[derive(Clone, Debug, PartialEq)]
pub struct GenerateOptions {
    /// Maximum number of tokens to generate.
    ///
    /// Default: `512`
    pub max_tokens: usize,

    /// Sampling temperature. Higher values produce more diverse output,
    /// lower values produce more focused output.
    ///
    /// Set to `0.0` for greedy/argmax sampling.
    ///
    /// Default: `0.3`
    pub temperature: f64,

    /// Nucleus sampling (top-p) threshold. Limits sampling to the smallest
    /// set of tokens whose cumulative probability exceeds this threshold.
    ///
    /// Default: `None`
    pub top_p: Option<f64>,

    /// Top-k sampling. Limits sampling to the k most likely tokens.
    ///
    /// Default: `None`
    pub top_k: Option<usize>,

    /// Penalty applied to repeated tokens. Values > 1.0 reduce repetition.
    ///
    /// Default: `1.1`
    pub repeat_penalty: f32,

    /// Number of previous tokens to consider for the repeat penalty.
    ///
    /// Default: `64`
    pub repeat_last_n: usize,

    /// Batch size for warmup/prefill.
    ///
    /// Default: `128`
    pub batch_size: usize,

    /// Random seed for reproducibility. Same seed + same input = same output.
    ///
    /// Default: `299792458`
    pub seed: u64,

    /// System prompt to prepend to the conversation.
    ///
    /// Default: `None`
    pub system_prompt: Option<String>,

    /// Maximum batch size for dynamic batching.
    ///
    /// Default: `4`
    pub max_batch_size: usize,

    /// Time window (in ms) to wait for batching requests.
    ///
    /// Default: `1`
    pub batch_window_ms: u64,

    /// Enable prefix caching for faster time-to-first-token.
    ///
    /// Default: `true`
    pub enable_prefix_cache: bool,

    /// Memory budget for the prefix cache (in MB).
    ///
    /// Default: `512`
    pub cache_memory_mb: usize,

    /// Number of CPU threads (`0` = auto-detect, use n-1).
    ///
    /// Default: `0` (auto)
    pub cpu_threads: usize,

    /// Number of cores to reserve for the OS.
    ///
    /// Default: `0`
    pub reserve_cores: usize,

    /// SIMD level (`auto`, `avx512`, `avx2`, `neon`, `scalar`).
    ///
    /// Default: `"auto"`
    pub simd_level: String,
}

impl Default for GenerateOptions {
    /// Returns the defaults documented on each field of [`GenerateOptions`].
    fn default() -> Self {
        Self {
            max_tokens: 512,
            temperature: 0.3,
            top_p: None,
            top_k: None,
            repeat_penalty: 1.1,
            repeat_last_n: 64,
            batch_size: 128,
            seed: 299792458,
            system_prompt: None,
            max_batch_size: 4,
            batch_window_ms: 1,
            enable_prefix_cache: true,
            cache_memory_mb: 512,
            cpu_threads: 0,
            reserve_cores: 0,
            simd_level: "auto".to_string(),
        }
    }
}
226
227/// High-level model wrapper with builder pattern for text generation.
228///
229/// Use this when you need to:
230/// - Generate multiple times with the same model
231/// - Use streaming callbacks
232/// - Maintain conversation history
233/// - Access model metadata
234///
235/// # Example
236///
237/// ```rust,ignore,ignore
238/// use oxide_rs::Model;
239///
240/// let mut model = Model::new("model.gguf")?
241///     .with_options(oxide_rs::GenerateOptions {
242///         max_tokens: 256,
243///         temperature: 0.7,
244///         ..Default::default()
245///     })
246///     .load()?;
247///
248/// let response = model.generate("Hello!")?;
249/// println!("{}", response);
250/// ```
251pub struct Model {
252    generator: Option<Generator>,
253    model_path: PathBuf,
254    tokenizer_path: Option<PathBuf>,
255    options: GenerateOptions,
256}
257
258impl Model {
259    /// Create a new Model instance.
260    ///
261    /// This only creates the Model struct - use `load()` to actually load the model.
262    ///
263    /// # Arguments
264    ///
265    /// * `model_path` - Path to a GGUF model file
266    ///
267    /// # Example
268    ///
269    /// ```rust,ignore
270    /// let model = Model::new("model.gguf")?;
271    /// ```
272    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self, Box<dyn std::error::Error>> {
273        Ok(Self {
274            generator: None,
275            model_path: model_path.as_ref().to_path_buf(),
276            tokenizer_path: None,
277            options: GenerateOptions::default(),
278        })
279    }
280
281    /// Set generation options.
282    ///
283    /// # Example
284    ///
285    /// ```rust,ignore
286    /// let model = Model::new("model.gguf")
287    ///     .with_options(GenerateOptions {
288    ///         max_tokens: 256,
289    ///         temperature: 0.8,
290    ///         ..Default::default()
291    ///     });
292    /// ```
293    pub fn with_options(mut self, options: GenerateOptions) -> Self {
294        self.options = options;
295        self
296    }
297
298    /// Set a custom tokenizer path.
299    ///
300    /// If not provided, the tokenizer will be extracted from the GGUF file.
301    ///
302    /// # Example
303    ///
304    /// ```rust,ignore
305    /// let model = Model::new("model.gguf")
306    ///     .with_tokenizer("tokenizer.json");
307    /// ```
308    pub fn with_tokenizer<P: AsRef<Path>>(mut self, tokenizer_path: P) -> Self {
309        self.tokenizer_path = Some(tokenizer_path.as_ref().to_path_buf());
310        self
311    }
312
313    /// Load the model into memory.
314    ///
315    /// This must be called before `generate()`.
316    ///
317    /// # Example
318    ///
319    /// ```rust,ignore
320    /// let mut model = Model::new("model.gguf")?.load()?;
321    /// ```
322    pub fn load(&mut self) -> Result<(), Box<dyn std::error::Error>> {
323        let generator = Generator::new(
324            &self.model_path,
325            self.tokenizer_path.as_ref(),
326            self.options.temperature,
327            self.options.top_p,
328            self.options.top_k,
329            self.options.seed,
330            self.options.system_prompt.clone(),
331            self.options.batch_size,
332        )?;
333        self.generator = Some(generator);
334        Ok(())
335    }
336
337    /// Generate text from a prompt.
338    ///
339    /// Requires `load()` to be called first.
340    ///
341    /// # Arguments
342    ///
343    /// * `prompt` - The input prompt
344    ///
345    /// # Example
346    ///
347    /// ```rust,ignore
348    /// let response = model.generate("What is Rust?")?;
349    /// println!("{}", response);
350    /// ```
351    pub fn generate(&mut self, prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
352        let generator = self
353            .generator
354            .as_mut()
355            .ok_or("Model not loaded. Call load() first.")?;
356
357        let result = generator.generate(
358            prompt,
359            self.options.max_tokens,
360            self.options.repeat_penalty,
361            self.options.repeat_last_n,
362            |_event| {},
363        )?;
364
365        Ok(result)
366    }
367
368    /// Generate text with streaming callback.
369    ///
370    /// Tokens are passed to the callback as they're generated, enabling
371    /// real-time output display.
372    ///
373    /// Requires `load()` to be called first.
374    ///
375    /// # Arguments
376    ///
377    /// * `prompt` - The input prompt
378    /// * `callback` - Function called for each generated token
379    ///
380    /// # Example
381    ///
382    /// ```rust,ignore
383    /// model.generate_stream("Tell me a story", |token| {
384    ///     print!("{}", token);
385    /// })?;
386    /// ```
387    pub fn generate_stream<F>(
388        &mut self,
389        prompt: &str,
390        mut callback: F,
391    ) -> Result<String, Box<dyn std::error::Error>>
392    where
393        F: FnMut(String),
394    {
395        let generator = self
396            .generator
397            .as_mut()
398            .ok_or("Model not loaded. Call load() first.")?;
399
400        let mut output = String::new();
401        generator.generate(
402            prompt,
403            self.options.max_tokens,
404            self.options.repeat_penalty,
405            self.options.repeat_last_n,
406            |event| match event {
407                StreamEvent::Token(t) => {
408                    output.push_str(&t);
409                    callback(t);
410                }
411                StreamEvent::Done => {}
412                StreamEvent::PrefillStatus(_) => {}
413            },
414        )?;
415
416        Ok(output)
417    }
418
419    /// Generate text from multiple prompts in batch.
420    ///
421    /// Processes multiple prompts sequentially, sharing the loaded model for efficiency.
422    /// Each prompt generates independently with its own output.
423    ///
424    /// Requires `load()` to be called first.
425    ///
426    /// # Arguments
427    ///
428    /// * `prompts` - Vector of input prompts
429    ///
430    /// # Example
431    ///
432    /// ```rust,ignore
433    /// let prompts = vec!["Hello!", "How are you?", "What's up?"];
434    /// let results = model.generate_batch(prompts)?;
435    /// for result in results {
436    ///     println!("{}", result);
437    /// }
438    /// ```
439    pub fn generate_batch(
440        &mut self,
441        prompts: Vec<String>,
442    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
443        let generator = self
444            .generator
445            .as_mut()
446            .ok_or("Model not loaded. Call load() first.")?;
447
448        let result = generator.generate_batch(
449            prompts,
450            self.options.max_tokens,
451            self.options.repeat_penalty,
452            self.options.repeat_last_n,
453        )?;
454
455        Ok(result)
456    }
457
458    /// Pre-compile compute kernels for faster first-token generation.
459    ///
460    /// Call this after `load()` to warm up the model before first use.
461    ///
462    /// # Arguments
463    ///
464    /// * `num_tokens` - Number of tokens to use for warmup (default: 128)
465    ///
466    /// # Example
467    ///
468    /// ```rust,ignore
469    /// model.load()?;
470    /// model.warmup(128)?;
471    /// // First generation will be faster
472    /// ```
473    pub fn warmup(&mut self, num_tokens: usize) -> Result<(), Box<dyn std::error::Error>> {
474        let generator = self
475            .generator
476            .as_mut()
477            .ok_or("Model not loaded. Call load() first.")?;
478        generator.warmup(num_tokens)?;
479        Ok(())
480    }
481
482    /// Clear conversation history.
483    ///
484    /// Removes all previous messages from the conversation context.
485    ///
486    /// # Example
487    ///
488    /// ```rust,ignore
489    /// model.generate("Hello")?;
490    /// model.clear_history();
491    /// ```
492    pub fn clear_history(&mut self) {
493        if let Some(ref mut generator) = self.generator {
494            generator.clear_history();
495        }
496    }
497
498    /// Get model metadata.
499    ///
500    /// Returns information about the loaded model including name,
501    /// architecture, layer count, embedding size, etc.
502    ///
503    /// # Example
504    ///
505    /// ```rust,ignore
506    /// if let Some(meta) = model.metadata() {
507    ///     println!("Model: {}", meta.name);
508    ///     println!("Architecture: {}", meta.architecture);
509    /// }
510    /// ```
511    pub fn metadata(&self) -> Option<&GgufMetadata> {
512        self.generator.as_ref().map(|g| g.metadata())
513    }
514
515    /// Get current context usage.
516    ///
517    /// Returns the number of tokens currently in the context.
518    ///
519    /// # Example
520    ///
521    /// ```rust,ignore
522    /// println!("Using {} tokens", model.context_used());
523    /// ```
524    pub fn context_used(&self) -> Option<usize> {
525        self.generator.as_ref().map(|g| g.context_used())
526    }
527
528    /// Get context limit.
529    ///
530    /// Returns the maximum context window size.
531    ///
532    /// # Example
533    ///
534    /// ```rust,ignore
535    /// println!("Context limit: {} tokens", model.context_limit());
536    /// ```
537    pub fn context_limit(&self) -> Option<usize> {
538        self.generator.as_ref().map(|g| g.context_limit())
539    }
540
541    /// Get context usage percentage.
542    ///
543    /// Returns the percentage of context used (0.0 - 100.0).
544    ///
545    /// # Example
546    ///
547    /// ```rust,ignore
548    /// println!("{:.1}% context used", model.context_percentage());
549    /// ```
550    pub fn context_percentage(&self) -> Option<f32> {
551        self.generator.as_ref().map(|g| g.context_percentage())
552    }
553}
554
555/// Simple one-shot text generation function.
556///
557/// This is the easiest way to generate text - just provide the model path,
558/// options, and prompt. The model is loaded and used in a single call.
559///
560/// For multiple generations, use [`Model`] instead to avoid reloading.
561///
562/// # Arguments
563///
564/// * `model_path` - Path to GGUF model file
565/// * `options` - Generation configuration
566/// * `prompt` - Input prompt
567///
568/// # Returns
569///
570/// Generated text string
571///
572/// # Example
573///
574/// ```rust,ignore,ignore
575/// use oxide_rs::{generate, GenerateOptions};
576///
577/// let result = generate(
578///     "model.gguf",
579///     GenerateOptions::default(),
580///     "Hello, how are you?",
581/// )?;
582/// println!("{}", result);
583/// ```
584pub fn generate<P: AsRef<Path>>(
585    model_path: P,
586    options: GenerateOptions,
587    prompt: &str,
588) -> Result<String, Box<dyn std::error::Error>> {
589    let mut model = Model::new(model_path)?.with_options(options);
590    model.load()?;
591    model.generate(prompt)
592}