// Optional allocator swap: when the `use-mimalloc` feature is enabled, `MiMalloc`
// is installed as the process-wide global allocator.
// NOTE(review): a `#[global_allocator]` declared in a library crate applies to the
// whole final binary and cannot coexist with another one in the dependency graph —
// confirm this feature is meant to be opted into by end applications only.
#[cfg(feature = "use-mimalloc")]
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
pub mod arena;
pub mod batch;
pub mod capabilities;
pub mod context;
pub mod embedding;
pub mod error;
pub mod memory;
pub mod memory_monitor;
pub mod model;
pub mod sampling;
pub mod sampling_simd;
pub mod session;
pub mod sys;
pub mod token;
pub mod vocab;
#[cfg(feature = "async")]
pub mod async_support;
pub mod builder;
pub mod config;
#[cfg(feature = "daemon")]
pub mod daemon;
#[cfg(feature = "format-conversion")]
pub mod format_conversion;
#[cfg(feature = "late-interaction")]
pub mod late_interaction;
#[cfg(feature = "multimodal")]
pub mod multimodal;
#[cfg(feature = "parallel")]
pub mod parallel;
#[cfg(feature = "streaming")]
pub mod streaming;
#[cfg(feature = "streaming-audio")]
pub mod streaming_audio;
#[cfg(feature = "tokio-runtime")]
pub mod tokio_integration;
#[cfg(feature = "web")]
pub mod web;
#[cfg(feature = "websockets")]
pub mod websockets;
pub mod gpu_advanced;
pub mod grammar;
pub mod hf;
pub mod lora;
pub mod modelfile;
pub mod presets;
pub mod speculative;
pub mod structured_output;
pub use hf::{GGUFFile, HFClient, HFModelInfo, ModelSearchFilters, QuantizationType};
pub use modelfile::{
find_modelfile, Message as ModelfileMessage, Modelfile, ModelfileError, ModelfileParser,
ParameterValue,
};
pub use structured_output::{JsonSchemaConverter, StructuredOutputError};
pub use presets::HardwarePreset;
/// Initializes the native backend by calling `llama_backend_init`.
///
/// NOTE(review): presumably this must run before any model is loaded and should be
/// paired with [`backend_free`] — confirm against the llama.cpp header.
pub fn backend_init() {
    // SAFETY: the call takes no arguments; soundness relies only on the native
    // library being linked correctly.
    unsafe { sys::llama_backend_init() };
}
/// Tears down the native backend by calling `llama_backend_free`.
///
/// NOTE(review): presumably no llama.cpp object may be used after this returns —
/// confirm against the llama.cpp header.
pub fn backend_free() {
    // SAFETY: the call takes no arguments; soundness relies only on the native
    // library being linked correctly.
    unsafe { sys::llama_backend_free() };
}
/// Returns the value reported by `llama_time_us`.
///
/// NOTE(review): the name suggests a timestamp in microseconds; the epoch/origin is
/// not visible from here — confirm against the llama.cpp header.
pub fn time_us() -> i64 {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_time_us() }
}
/// Returns the value reported by `llama_max_devices`.
///
/// NOTE(review): presumably the maximum number of compute devices the linked
/// llama.cpp build can address — confirm against the llama.cpp header.
pub fn max_devices() -> usize {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_max_devices() }
}
/// Returns the answer of `llama_supports_gpu_offload`.
///
/// NOTE(review): presumably `true` when the linked build can offload layers to a
/// GPU — confirm against the llama.cpp header.
pub fn supports_gpu_offload() -> bool {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_supports_gpu_offload() }
}
/// Returns the answer of `llama_supports_mmap`.
///
/// NOTE(review): presumably `true` when memory-mapped model loading is available
/// on this platform/build — confirm against the llama.cpp header.
pub fn supports_mmap() -> bool {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_supports_mmap() }
}
/// Returns the answer of `llama_supports_mlock`.
///
/// NOTE(review): presumably `true` when model memory can be locked in RAM on this
/// platform/build — confirm against the llama.cpp header.
pub fn supports_mlock() -> bool {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_supports_mlock() }
}
/// Returns the answer of `llama_supports_rpc`.
///
/// NOTE(review): presumably `true` when the linked build includes the RPC
/// backend — confirm against the llama.cpp header.
pub fn supports_rpc() -> bool {
// SAFETY: argument-free FFI call that only returns a scalar.
unsafe { sys::llama_supports_rpc() }
}
/// Returns the text produced by `llama_print_system_info` as an owned `String`.
///
/// A null pointer from the native side yields an empty string. Bytes that are not
/// valid UTF-8 are replaced lossily rather than causing an error.
pub fn print_system_info() -> String {
    // SAFETY: argument-free FFI call; the returned pointer is checked before use.
    let raw = unsafe { sys::llama_print_system_info() };
    if raw.is_null() {
        return String::new();
    }

    // SAFETY: `raw` is non-null here.
    // NOTE(review): assumes the native side returns a NUL-terminated string that
    // stays valid for the duration of this call — confirm against llama.cpp.
    let info = unsafe { std::ffi::CStr::from_ptr(raw) };
    info.to_string_lossy().into_owned()
}
/// Forwards `strategy` to `llama_numa_init`.
///
/// NOTE(review): presumably configures NUMA placement for the native backend and
/// is meant to be called once during start-up — confirm against the llama.cpp header.
pub fn numa_init(strategy: sys::ggml_numa_strategy) {
    // SAFETY: `strategy` is passed by value; no pointers cross the FFI boundary.
    unsafe { sys::llama_numa_init(strategy) };
}
/// Gathers every capability query of this module into one [`SystemInfo`] snapshot.
///
/// Each field is filled by the free function of the same name, except `details`,
/// which holds the text returned by [`print_system_info`].
pub fn system_info() -> SystemInfo {
    // The queries run in field-declaration order, matching the struct literal
    // they replace.
    let devices = max_devices();
    let gpu_offload = supports_gpu_offload();
    let mmap = supports_mmap();
    let mlock = supports_mlock();
    let rpc = supports_rpc();
    let details = print_system_info();

    SystemInfo {
        max_devices: devices,
        supports_gpu_offload: gpu_offload,
        supports_mmap: mmap,
        supports_mlock: mlock,
        supports_rpc: rpc,
        details,
    }
}
/// Snapshot of the capability queries exposed by this module, as assembled by
/// `system_info()`.
#[derive(Debug, Clone)]
pub struct SystemInfo {
/// Value returned by `max_devices()`.
pub max_devices: usize,
/// Value returned by `supports_gpu_offload()`.
pub supports_gpu_offload: bool,
/// Value returned by `supports_mmap()`.
pub supports_mmap: bool,
/// Value returned by `supports_mlock()`.
pub supports_mlock: bool,
/// Value returned by `supports_rpc()`.
pub supports_rpc: bool,
/// Text returned by `print_system_info()`; empty when the native call yields null.
pub details: String,
}
/// Signature of the C-ABI function accepted by `log_set`.
///
/// - `level`: numeric log level supplied by the native side.
///   NOTE(review): the mapping of values to severities is not visible here —
///   presumably `ggml_log_level`; confirm.
/// - `text`: C string pointer holding the message.
///   NOTE(review): presumably NUL-terminated and only valid during the call — confirm.
/// - `user_data`: the opaque pointer that was passed to `log_set`.
pub type LogCallback = extern "C" fn(
level: i32,
text: *const std::os::raw::c_char,
user_data: *mut std::os::raw::c_void,
);
/// Registers `callback` with the native library through `llama_log_set`.
///
/// `user_data` is handed to the native side unchanged; this function never
/// dereferences it.
///
/// # Safety
///
/// This passes a raw function pointer and an opaque data pointer to native code,
/// which may use them after this function has returned. The caller must ensure:
///
/// - `user_data` is null or points to data that stays valid for as long as the
///   native library may invoke `callback` with it.
///   NOTE(review): presumably until a different callback is registered — confirm
///   against the llama.cpp header.
/// - `callback` tolerates being invoked from whatever thread the native library
///   logs on, and does not unwind across the `extern "C"` boundary.
pub unsafe fn log_set(callback: LogCallback, user_data: *mut std::os::raw::c_void) {
    // SAFETY: the caller upholds the contract documented in `# Safety`; both
    // arguments are forwarded untouched.
    unsafe {
        sys::llama_log_set(Some(callback), user_data);
    }
}
/// Builds a `llama_batch` over `tokens` by calling `llama_batch_get_one`.
///
/// The native function receives the slice's data pointer and its length as an
/// `i32`.
///
/// NOTE(review): the returned batch presumably keeps the raw pointer into
/// `tokens` rather than copying, so `tokens` must outlive every use of the
/// batch — the signature does not enforce this; confirm against the llama.cpp header.
///
/// # Panics
///
/// Panics if `tokens.len()` exceeds `i32::MAX`. The native API takes the count as
/// a 32-bit signed integer, and a silently wrapped length would give native code
/// a wrong (possibly negative) token count.
pub fn batch_get_one(tokens: &[i32]) -> sys::llama_batch {
    assert!(
        tokens.len() <= i32::MAX as usize,
        "batch_get_one: token count {} does not fit in an i32",
        tokens.len()
    );
    // Lossless: the assertion above bounds the length by `i32::MAX`.
    let n_tokens = tokens.len() as i32;

    // SAFETY: pointer and count describe the same live slice. The `*mut` cast only
    // satisfies the C signature.
    // NOTE(review): assumes the native side never writes through this pointer — confirm.
    unsafe { sys::llama_batch_get_one(tokens.as_ptr() as *mut i32, n_tokens) }
}
/// Returns the result of `llama_chat_builtin_templates` when called with a null
/// output pointer and a length of zero.
///
/// NOTE(review): presumably the native function reports the total number of
/// built-in chat templates when given an empty buffer — confirm against the
/// llama.cpp header.
pub fn chat_builtin_template_count() -> i32 {
// SAFETY: a null output buffer is paired with a length of 0.
// NOTE(review): assumes the native side writes nothing in that case — confirm.
unsafe { sys::llama_chat_builtin_templates(std::ptr::null_mut(), 0) }
}
/// Special-token ids and BOS/EOS insertion flags of a model's vocabulary.
///
/// Produced by `Model::vocab_info`, which fills each field from the
/// `llama_vocab_*` function named in the field's documentation.
///
/// The derives mirror `SystemInfo` (`Debug`, `Clone`) and add `Copy`,
/// `PartialEq` and `Eq`, since every field is a plain scalar.
///
/// NOTE(review): a token the vocabulary lacks is presumably reported with a
/// sentinel id (likely `LLAMA_TOKEN_NULL`) — confirm against the llama.cpp header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VocabInfo {
    /// Id returned by `llama_vocab_bos`.
    pub bos_token: i32,
    /// Id returned by `llama_vocab_eos`.
    pub eos_token: i32,
    /// Id returned by `llama_vocab_cls`.
    pub cls_token: i32,
    /// Id returned by `llama_vocab_sep`.
    pub sep_token: i32,
    /// Id returned by `llama_vocab_nl`.
    pub nl_token: i32,
    /// Id returned by `llama_vocab_pad`.
    pub pad_token: i32,
    /// Id returned by `llama_vocab_eot`.
    pub eot_token: i32,
    /// Flag returned by `llama_vocab_get_add_bos`.
    pub add_bos: bool,
    /// Flag returned by `llama_vocab_get_add_eos`.
    pub add_eos: bool,
}
/// Vocabulary queries on [`Model`].
///
/// Every method obtains the vocabulary handle from `llama_model_get_vocab` and
/// forwards to the matching `llama_vocab_*` function.
///
/// NOTE(review): `token` is passed to native code without a range check in all
/// per-token methods; presumably callers must supply an id that exists in the
/// vocabulary — confirm how llama.cpp treats out-of-range ids.
impl Model {
    /// Reads all special-token ids and the BOS/EOS insertion flags in one go.
    pub fn vocab_info(&self) -> VocabInfo {
        // SAFETY: assumes `self.as_ptr()` yields a live model handle for as long as
        // `self` is borrowed — TODO confirm against `Model`.
        unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            VocabInfo {
                bos_token: sys::llama_vocab_bos(vocab),
                eos_token: sys::llama_vocab_eos(vocab),
                cls_token: sys::llama_vocab_cls(vocab),
                sep_token: sys::llama_vocab_sep(vocab),
                nl_token: sys::llama_vocab_nl(vocab),
                pad_token: sys::llama_vocab_pad(vocab),
                eot_token: sys::llama_vocab_eot(vocab),
                add_bos: sys::llama_vocab_get_add_bos(vocab),
                add_eos: sys::llama_vocab_get_add_eos(vocab),
            }
        }
    }

    /// Returns the text stored for `token`, or `None` when the native side yields
    /// a null pointer. Invalid UTF-8 is replaced lossily.
    pub fn vocab_get_text(&self, token: i32) -> Option<String> {
        // SAFETY: same handle assumption as in `vocab_info`.
        let text_ptr = unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            sys::llama_vocab_get_text(vocab, token)
        };
        if text_ptr.is_null() {
            return None;
        }

        // SAFETY: `text_ptr` is non-null here.
        // NOTE(review): assumes a NUL-terminated string that stays valid during this
        // call — confirm against llama.cpp.
        let text = unsafe { std::ffi::CStr::from_ptr(text_ptr) };
        Some(text.to_string_lossy().into_owned())
    }

    /// Returns the score reported by `llama_vocab_get_score` for `token`.
    pub fn vocab_get_score(&self, token: i32) -> f32 {
        // SAFETY: same handle assumption as in `vocab_info`.
        unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            sys::llama_vocab_get_score(vocab, token)
        }
    }

    /// Returns the attribute value reported by `llama_vocab_get_attr` for `token`.
    pub fn vocab_get_attr(&self, token: i32) -> sys::llama_token_attr {
        // SAFETY: same handle assumption as in `vocab_info`.
        unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            sys::llama_vocab_get_attr(vocab, token)
        }
    }

    /// Returns the answer of `llama_vocab_is_control` for `token`.
    pub fn vocab_is_control(&self, token: i32) -> bool {
        // SAFETY: same handle assumption as in `vocab_info`.
        unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            sys::llama_vocab_is_control(vocab, token)
        }
    }

    /// Returns the answer of `llama_vocab_is_eog` for `token`.
    ///
    /// NOTE(review): "eog" presumably stands for end-of-generation — confirm.
    pub fn vocab_is_eog(&self, token: i32) -> bool {
        // SAFETY: same handle assumption as in `vocab_info`.
        unsafe {
            let vocab = sys::llama_model_get_vocab(self.as_ptr());
            sys::llama_vocab_is_eog(vocab, token)
        }
    }
}
pub mod control_vector;
pub mod quantization;
pub use arena::{
with_generation_arena, with_generation_arena_mut, ArenaCandidates, ArenaTokenCandidate,
GenerationArena,
};
pub use batch::Batch;
pub use capabilities::{
detect_capabilities, registry, Capabilities, CapabilityRegistry, ModelFamilyConfig,
ThinkingTokens, TokenConfig, ToolFormat,
};
pub use context::{Context, ContextParams, KvCacheType};
pub use embedding::{EmbeddingUtil, Embeddings};
pub use error::MullamaError;
pub use memory::{
ConstrainedMemoryConfig, MemoryManager, recommend_constrained_config,
};
pub use memory_monitor::{
MemoryConfig, MemoryMonitor, MemoryPressure, MemoryStats, RecoveryManager, RecoveryResult,
RecoveryStrategy,
};
pub use model::{Model, ModelKvOverride, ModelKvOverrideValue, ModelParams, Token};
pub use sampling::{
AlignedTokenData, AlignedTokenDataArray, LogitBias, Sampler, SamplerChain, SamplerChainParams,
SamplerParams, SamplerPerfData, TokenData, TokenDataArray,
};
pub use sampling_simd::{
has_avx2, has_avx512, has_neon, simd_max_f32, simd_softmax, simd_sum_f32, simd_top_k,
SimdCapabilities,
};
pub use session::Session;
pub use token::{GenerationBuffer, Token as TokenStruct, TokenBuffer, TokenId};
pub use vocab::Vocabulary;
#[cfg(feature = "async")]
pub use async_support::{AsyncConfig, AsyncContext, AsyncModel, ModelInfo};
pub use builder::{ContextBuilder, ModelBuilder, SamplerBuilder};
pub use config::{
ContextConfig, CpuOptimizations, GpuOptimizations, LoggingConfig, ModelConfig, MullamaConfig,
PerformanceConfig, SamplingConfig,
};
#[cfg(feature = "format-conversion")]
pub use format_conversion::{
AudioConversionResult, AudioConverter, AudioConverterConfig, ConversionConfig,
ImageConversionResult, ImageConverter, ImageConverterConfig,
};
#[cfg(feature = "late-interaction")]
pub use late_interaction::{
LateInteractionScorer, MultiVectorConfig, MultiVectorEmbedding, MultiVectorGenerator,
};
#[cfg(feature = "multimodal")]
pub use multimodal::{
AudioFeatures, AudioFormat, AudioInput, Bitmap, ChunkType, ImageInput, InputChunk, InputChunks,
MtmdContext, MtmdParams, MultimodalInput, MultimodalOutput, MultimodalProcessor, VideoInput,
};
#[cfg(feature = "parallel")]
pub use parallel::{BatchGenerationConfig, GenerationResult, ParallelProcessor, ThreadPoolConfig};
#[cfg(feature = "streaming")]
pub use streaming::{StreamConfig, TokenData as StreamTokenData, TokenStream};
#[cfg(feature = "streaming-audio")]
pub use streaming_audio::{
AudioChunk, AudioStream, AudioStreamConfig, DevicePreference, StreamingAudioProcessor,
StreamingMetrics,
};
#[cfg(feature = "tokio-runtime")]
pub use tokio_integration::{
ModelPool, MullamaRuntime, MullamaRuntimeBuilder, RuntimeMetrics, TaskManager,
};
#[cfg(feature = "web")]
pub use web::{
ApiMetrics, AppError, AppState, GenerateRequest, GenerateResponse, RouterBuilder,
TokenizeRequest, TokenizeResponse,
};
#[cfg(feature = "websockets")]
pub use websockets::{
AudioProcessor as WSAudioProcessor, ConnectionManager, ServerStats, WSMessage, WebSocketConfig,
WebSocketServer,
};
pub use control_vector::{ControlVector, ControlVectorManager};
pub use gpu_advanced::{AllocationStrategy, GpuDevice, GpuManager};
pub use grammar::{Grammar, GrammarRule};
pub use lora::{LoRAAdapter, LoRAManager};
pub use quantization::{
QuantizationEngine, QuantizationParams, QuantizationType as QuantizationKind,
};
pub use speculative::{SpeculativeConfig, SpeculativeDecoder};
pub use sys::{
ggml_numa_strategy, ggml_type, llama_attention_type, llama_ftype, llama_memory_t,
llama_model_kv_override_type, llama_pooling_type, llama_pos, llama_rope_scaling_type,
llama_rope_type, llama_seq_id, llama_split_mode, llama_token, llama_token_attr,
llama_token_type, llama_vocab_type, LLAMA_DEFAULT_SEED, LLAMA_TOKEN_NULL,
};
/// Convenience re-exports intended for a single glob import of this module.
///
/// Contains only `pub use` items: an unconditional core set followed by groups
/// that are compiled in only when the named Cargo feature is enabled.
pub mod prelude {
// Core model / context / sampling types, always available.
pub use crate::{
Batch, Context, ContextBuilder, ContextParams, Model, ModelBuilder, ModelParams,
MullamaConfig, MullamaError, SamplerBuilder, SamplerChain, SamplerParams,
};
// GPU, control-vector, grammar, LoRA, quantization and speculative-decoding
// types, also always available.
pub use crate::{
AllocationStrategy, ControlVector, ControlVectorManager, GpuDevice, GpuManager, Grammar,
GrammarRule, HardwarePreset, LoRAAdapter, LoRAManager, QuantizationEngine,
QuantizationKind, QuantizationParams, SpeculativeConfig, SpeculativeDecoder,
};
// Everything below is feature-gated; each group mirrors a subset of the
// crate-root re-exports guarded by the same feature.
#[cfg(feature = "async")]
pub use crate::{AsyncContext, AsyncModel};
#[cfg(feature = "streaming")]
pub use crate::{StreamConfig, StreamTokenData, TokenStream};
#[cfg(feature = "web")]
pub use crate::{AppState, GenerateRequest, GenerateResponse, RouterBuilder};
#[cfg(feature = "tokio-runtime")]
pub use crate::{ModelPool, MullamaRuntime, TaskManager};
#[cfg(feature = "parallel")]
pub use crate::{BatchGenerationConfig, ParallelProcessor};
#[cfg(feature = "late-interaction")]
pub use crate::{
LateInteractionScorer, MultiVectorConfig, MultiVectorEmbedding, MultiVectorGenerator,
};
#[cfg(feature = "websockets")]
pub use crate::{WSMessage, WebSocketServer};
#[cfg(feature = "streaming-audio")]
pub use crate::{AudioChunk, AudioStreamConfig, StreamingAudioProcessor};
#[cfg(feature = "format-conversion")]
pub use crate::{AudioConverter, ImageConverter};
#[cfg(feature = "multimodal")]
pub use crate::{AudioInput, ImageInput, MultimodalInput, MultimodalProcessor};
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the raw backend init/free pair can be called back to back.
    #[test]
    fn test_backend_initialization() {
        // SAFETY: both calls are argument-free FFI entry points.
        unsafe { sys::llama_backend_init() };
        unsafe { sys::llama_backend_free() };
        // Placeholder assertion; reaching this line means both calls returned.
        assert_eq!(2 + 2, 4);
    }

    /// Default model parameters: no GPU layers, mmap on, mlock off.
    #[test]
    fn test_model_params_default() {
        let defaults = ModelParams::default();
        assert_eq!(defaults.n_gpu_layers, 0);
        assert!(defaults.use_mmap);
        assert!(!defaults.use_mlock);
    }

    /// Default context parameters: `n_ctx` of zero, positive batch and thread counts.
    #[test]
    fn test_context_params_default() {
        let defaults = ContextParams::default();
        assert_eq!(defaults.n_ctx, 0);
        assert!(defaults.n_batch > 0);
        assert!(defaults.n_threads > 0);
    }

    /// A token built from a struct literal exposes the values it was given.
    #[test]
    fn test_token_structure() {
        // `TokenStruct` is the crate-root alias of `token::Token`.
        let sample = TokenStruct {
            id: 1234,
            text: String::from("test"),
            score: 0.5,
        };
        assert_eq!(sample.id, 1234);
        assert_eq!(sample.text, "test");
        assert_eq!(sample.score, 0.5);
    }

    /// A default batch starts out empty.
    #[test]
    fn test_batch_structure() {
        let empty_batch = Batch::default();
        assert!(empty_batch.is_empty());
    }

    /// A session can be built around an empty data buffer.
    #[test]
    fn test_session_structure() {
        let blank = Session { data: Vec::new() };
        assert!(blank.data.is_empty());
    }

    /// A sampler can be created, and the default sampler parameters match the
    /// expected temperature / top-p / top-k values.
    #[test]
    fn test_sampling_structure() {
        let _sampler = Sampler::new().expect("Failed to create sampler");
        let defaults = SamplerParams::default();
        assert_eq!(defaults.temperature, 0.8);
        assert_eq!(defaults.top_p, 0.95);
        assert_eq!(defaults.top_k, 40);
    }

    /// Three values with dimension three form exactly one embedding.
    #[test]
    fn test_embedding_structure() {
        let values = vec![0.1, 0.2, 0.3];
        let single = Embeddings::new(values, 3);
        assert_eq!(single.len(), 1);
        assert_eq!(single.dimension, 3);
    }

    /// A freshly constructed memory manager reports itself as not valid.
    #[test]
    fn test_memory_manager_structure() {
        let manager = MemoryManager::new();
        assert!(!manager.is_valid());
    }

    /// A freshly constructed vocabulary has its placeholder field set to zero.
    #[test]
    fn test_vocabulary_structure() {
        let vocabulary = Vocabulary::new();
        assert_eq!(vocabulary._placeholder, 0);
    }
}