#[cfg(feature = "candle")]
mod candle_backend;
#[cfg(feature = "candle")]
pub use candle_backend::*;
mod coreml_backend;
pub use coreml_backend::{AneCapabilities, ComputeUnits, CoreMLBackend};
#[cfg(feature = "hybrid-ane")]
mod hybrid_pipeline;
#[cfg(feature = "hybrid-ane")]
pub use hybrid_pipeline::{
AcceleratorMetrics, AcceleratorType, AneStrategy, DataFormat, HybridPipeline,
HybridPipelineConfig, HybridTensor, OperationType, RoutingDecision,
};
pub mod gemma2;
pub mod phi3;
pub use gemma2::{
logit_soft_cap, Gemma2Attention, Gemma2Config, Gemma2DecoderLayer, Gemma2MLP, Gemma2Model,
ATTENTION_SOFTCAP, FINAL_LOGIT_SOFTCAP,
};
pub use phi3::{Phi3Attention, Phi3Config, Phi3DecoderLayer, Phi3MLP, Phi3Model};
mod mistral_backend;
pub use mistral_backend::{
IsqConfig, IsqMethod, MistralBackend, MistralBackendConfig, MistralTokenizer,
PagedAttentionConfigExt, XLoraConfig, XLoraManager, XLoraManagerStats, XLoraMixingMode,
};
use crate::error::{Result, RuvLLMError};
use serde::{Deserialize, Serialize};
use std::sync::{mpsc, Arc};
use std::time::{Duration, Instant};
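/// Model architectures recognized by this crate's inference backends.
/// `Phi`/`Phi3` and `Gemma`/`Gemma2` are separate variants because the newer
/// generations differ structurally (e.g. Gemma 2's logit soft-capping).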
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelArchitecture {
Mistral,
Llama,
Phi,
Phi3,
Qwen,
Gemma,
Gemma2,
}
impl Default for ModelArchitecture {
fn default() -> Self {
Self::Llama
}
}
impl ModelArchitecture {
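    /// Lowercase architecture name as it typically appears in model
    /// configuration metadata (e.g. `"qwen2"` for [`Self::Qwen`]).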
pub fn config_name(&self) -> &'static str {
match self {
Self::Mistral => "mistral",
Self::Llama => "llama",
Self::Phi => "phi",
Self::Phi3 => "phi3",
Self::Qwen => "qwen2",
Self::Gemma => "gemma",
Self::Gemma2 => "gemma2",
}
}
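    /// Best-effort detection of the architecture from a model identifier
    /// (e.g. a Hugging Face repo id). Matching is case-insensitive and checks
    /// more specific names first, so "phi-3" resolves to [`Self::Phi3`]
    /// rather than [`Self::Phi`]. Returns `None` when nothing matches.
    ///
    /// ```ignore
    /// // Illustrative; the import path depends on how this module is re-exported.
    /// assert_eq!(
    ///     ModelArchitecture::detect_from_model_id("mistralai/Mistral-7B-Instruct-v0.2"),
    ///     Some(ModelArchitecture::Mistral),
    /// );
    /// assert_eq!(ModelArchitecture::detect_from_model_id("resnet-50"), None);
    /// ```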
pub fn detect_from_model_id(model_id: &str) -> Option<Self> {
let lower = model_id.to_lowercase();
if lower.contains("phi-3") || lower.contains("phi3") {
Some(Self::Phi3)
} else if lower.contains("phi") {
Some(Self::Phi)
} else if lower.contains("gemma-2") || lower.contains("gemma2") {
Some(Self::Gemma2)
} else if lower.contains("gemma") {
Some(Self::Gemma)
} else if lower.contains("mistral") || lower.contains("codestral") {
Some(Self::Mistral)
} else if lower.contains("llama") {
Some(Self::Llama)
} else if lower.contains("qwen") {
Some(Self::Qwen)
} else {
None
}
}
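    /// Whether this architecture family typically uses grouped-query
    /// attention (GQA), i.e. fewer key/value heads than query heads.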
pub fn uses_gqa(&self) -> bool {
matches!(
self,
Self::Mistral | Self::Llama | Self::Gemma | Self::Gemma2 | Self::Qwen
)
}
pub fn uses_sliding_window(&self) -> bool {
matches!(self, Self::Mistral | Self::Phi3 | Self::Gemma2)
}
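    /// Default sliding-window size for architectures that use one. These are
    /// common defaults for the model families; specific checkpoints may
    /// override them in their own configuration.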
pub fn default_sliding_window(&self) -> Option<usize> {
match self {
Self::Mistral => Some(4096),
Self::Phi3 => Some(2048),
            Self::Gemma2 => Some(4096),
            _ => None,
}
}
}
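/// Weight quantization schemes. The `Q*` variants map to GGUF quant types
/// (see [`Quantization::is_gguf`]); `None` keeps full f32 weights.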
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quantization {
None,
F16,
Bf16,
Q8,
Q4K,
Q4,
Q2K,
}
impl Default for Quantization {
fn default() -> Self {
Self::Q4K
}
}
impl Quantization {
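    /// Approximate storage cost per weight, in bytes. These are nominal
    /// figures: GGUF k-quant blocks carry per-block scale metadata, so files
    /// on disk run slightly larger (e.g. Q4_K is about 4.5 bits per weight).
    ///
    /// ```ignore
    /// // Rough memory estimate for a 7B-parameter model at Q4_K
    /// // (illustrative arithmetic only):
    /// let bytes = 7_000_000_000f64 * Quantization::Q4K.bytes_per_weight() as f64;
    /// assert!(bytes < 4.0 * 1024.0 * 1024.0 * 1024.0); // ~3.26 GiB
    /// ```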
pub fn bytes_per_weight(&self) -> f32 {
match self {
Self::None => 4.0,
Self::F16 | Self::Bf16 => 2.0,
Self::Q8 => 1.0,
Self::Q4K | Self::Q4 => 0.5,
Self::Q2K => 0.25,
}
}
pub fn is_gguf(&self) -> bool {
matches!(self, Self::Q8 | Self::Q4K | Self::Q4 | Self::Q2K)
}
}
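/// Configuration handed to [`LlmBackend::load_model`]. Optional fields
/// (`num_kv_heads`, `hidden_size`, ...) default to `None`, leaving the
/// concrete values to the backend or the model's own configuration.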
#[derive(Debug, Clone)]
pub struct ModelConfig {
pub architecture: ModelArchitecture,
pub quantization: Option<Quantization>,
pub use_flash_attention: bool,
pub max_sequence_length: usize,
pub num_kv_heads: Option<usize>,
pub hidden_size: Option<usize>,
pub num_layers: Option<usize>,
pub vocab_size: Option<usize>,
pub rope_theta: Option<f64>,
pub sliding_window: Option<usize>,
pub device: DeviceType,
pub dtype: DType,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
architecture: ModelArchitecture::default(),
quantization: Some(Quantization::Q4K),
use_flash_attention: true,
max_sequence_length: 4096,
num_kv_heads: None,
hidden_size: None,
num_layers: None,
vocab_size: None,
rope_theta: None,
sliding_window: None,
device: DeviceType::default(),
dtype: DType::default(),
}
}
}
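/// Compute device selection. Defaults to [`DeviceType::Metal`], reflecting
/// this crate's Apple-silicon focus (see the CoreML/ANE backends above).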
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum DeviceType {
Cpu,
#[default]
Metal,
Cuda(usize),
}
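/// Floating-point precision for non-quantized tensors; defaults to f16.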
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum DType {
F32,
#[default]
F16,
Bf16,
}
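/// Sampling parameters for generation, with conservative defaults
/// (256 tokens, temperature 0.7, top-p 0.9, top-k 40). Override individual
/// fields through the `with_*` builders:
///
/// ```ignore
/// // Illustrative; the import path depends on how this module is re-exported.
/// let params = GenerateParams::default()
///     .with_max_tokens(128)
///     .with_temperature(0.2)
///     .with_stop_sequence("\n\n");
/// ```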
#[derive(Debug, Clone)]
pub struct GenerateParams {
pub max_tokens: usize,
pub temperature: f32,
pub top_p: f32,
pub top_k: usize,
pub repetition_penalty: f32,
pub frequency_penalty: f32,
pub presence_penalty: f32,
pub stop_sequences: Vec<String>,
pub seed: Option<u64>,
}
impl Default for GenerateParams {
fn default() -> Self {
Self {
max_tokens: 256,
temperature: 0.7,
top_p: 0.9,
top_k: 40,
repetition_penalty: 1.1,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop_sequences: Vec::new(),
seed: None,
}
}
}
impl GenerateParams {
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = top_p;
self
}
pub fn with_top_k(mut self, top_k: usize) -> Self {
self.top_k = top_k;
self
}
pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
self.repetition_penalty = penalty;
self
}
pub fn with_stop_sequence(mut self, stop: impl Into<String>) -> Self {
self.stop_sequences.push(stop.into());
self
}
pub fn with_seed(mut self, seed: u64) -> Self {
self.seed = Some(seed);
self
}
}
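/// A single decoded token, as emitted by a streaming generation call.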
#[derive(Debug, Clone)]
pub struct GeneratedToken {
pub id: u32,
pub text: String,
pub logprob: Option<f32>,
pub is_special: bool,
}
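/// Events carried by a [`TokenStream`]: individual tokens, a terminal
/// `Done` summary, or a fatal `Error`.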
#[derive(Debug, Clone)]
pub enum StreamEvent {
Token(GeneratedToken),
Done {
total_tokens: usize,
duration_ms: u64,
tokens_per_second: f64,
},
Error(String),
}
pub struct TokenStream {
receiver: mpsc::Receiver<StreamEvent>,
finished: bool,
start_time: Instant,
token_count: usize,
}
impl TokenStream {
pub fn new(receiver: mpsc::Receiver<StreamEvent>) -> Self {
Self {
receiver,
finished: false,
start_time: Instant::now(),
token_count: 0,
}
}
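    /// Creates a connected producer/consumer pair: the backend pushes
    /// [`StreamEvent`]s into the sender, the returned stream receives them.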
pub fn channel() -> (mpsc::Sender<StreamEvent>, Self) {
let (tx, rx) = mpsc::channel();
(tx, Self::new(rx))
}
pub fn is_finished(&self) -> bool {
self.finished
}
pub fn tokens_received(&self) -> usize {
self.token_count
}
pub fn elapsed(&self) -> Duration {
self.start_time.elapsed()
}
pub fn tokens_per_second(&self) -> f64 {
let elapsed = self.elapsed().as_secs_f64();
if elapsed > 0.0 {
self.token_count as f64 / elapsed
} else {
0.0
}
}
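    /// Non-blocking poll. Returns `None` if no event is ready yet or the
    /// stream has finished; bookkeeping (token count, finished flag) is
    /// updated as events pass through.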
pub fn try_next(&mut self) -> Option<Result<StreamEvent>> {
if self.finished {
return None;
}
match self.receiver.try_recv() {
Ok(event) => {
match &event {
StreamEvent::Token(_) => self.token_count += 1,
StreamEvent::Done { .. } => self.finished = true,
StreamEvent::Error(_) => self.finished = true,
}
Some(Ok(event))
}
Err(mpsc::TryRecvError::Empty) => None,
Err(mpsc::TryRecvError::Disconnected) => {
self.finished = true;
None
}
}
}
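    /// Blocking poll with an upper wait bound. Returns `None` on timeout,
    /// after the stream has finished, or once the sender is dropped.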
pub fn recv_timeout(&mut self, timeout: Duration) -> Option<Result<StreamEvent>> {
if self.finished {
return None;
}
match self.receiver.recv_timeout(timeout) {
Ok(event) => {
match &event {
StreamEvent::Token(_) => self.token_count += 1,
StreamEvent::Done { .. } => self.finished = true,
StreamEvent::Error(_) => self.finished = true,
}
Some(Ok(event))
}
Err(mpsc::RecvTimeoutError::Timeout) => None,
Err(mpsc::RecvTimeoutError::Disconnected) => {
self.finished = true;
None
}
}
}
}
impl Iterator for TokenStream {
type Item = Result<StreamEvent>;
fn next(&mut self) -> Option<Self::Item> {
if self.finished {
return None;
}
match self.receiver.recv() {
Ok(event) => {
match &event {
StreamEvent::Token(_) => self.token_count += 1,
StreamEvent::Done { .. } => self.finished = true,
StreamEvent::Error(_) => self.finished = true,
}
Some(Ok(event))
}
Err(_) => {
self.finished = true;
None
}
}
}
}
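/// Common interface implemented by every inference backend.
///
/// `generate_stream` yields tokens through a boxed iterator that borrows the
/// backend, while `generate_stream_v2` returns an owned, channel-backed
/// [`TokenStream`]. A minimal usage sketch, assuming the `candle` feature is
/// enabled and the model id is resolvable (otherwise `load_model` errors):
///
/// ```ignore
/// let mut backend = create_backend();
/// backend.load_model("mistralai/Mistral-7B-Instruct-v0.2", ModelConfig::default())?;
/// let text = backend.generate("Hello", GenerateParams::default().with_max_tokens(32))?;
/// println!("{text}");
/// ```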
pub trait LlmBackend: Send + Sync {
fn load_model(&mut self, model_id: &str, config: ModelConfig) -> Result<()>;
fn generate(&self, prompt: &str, params: GenerateParams) -> Result<String>;
fn generate_stream(
&self,
prompt: &str,
params: GenerateParams,
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>>;
fn generate_stream_v2(&self, prompt: &str, params: GenerateParams) -> Result<TokenStream>;
fn get_embeddings(&self, text: &str) -> Result<Vec<f32>>;
fn tokenizer(&self) -> Option<&dyn Tokenizer>;
fn is_model_loaded(&self) -> bool;
fn model_info(&self) -> Option<ModelInfo>;
fn unload_model(&mut self);
}
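/// Minimal tokenizer interface a backend can expose alongside its model.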
pub trait Tokenizer: Send + Sync {
fn encode(&self, text: &str) -> Result<Vec<u32>>;
fn decode(&self, tokens: &[u32]) -> Result<String>;
fn vocab_size(&self) -> usize;
fn special_tokens(&self) -> SpecialTokens;
}
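/// Well-known token ids; all optional, since not every tokenizer defines them.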
#[derive(Debug, Clone, Default)]
pub struct SpecialTokens {
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub pad_token_id: Option<u32>,
pub unk_token_id: Option<u32>,
}
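/// Descriptive metadata for the currently loaded model.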
#[derive(Debug, Clone)]
pub struct ModelInfo {
pub name: String,
pub architecture: ModelArchitecture,
pub num_parameters: usize,
pub vocab_size: usize,
pub hidden_size: usize,
pub num_layers: usize,
pub max_context_length: usize,
pub quantization: Option<Quantization>,
pub memory_usage: usize,
}
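/// Stub backend compiled in when no inference feature is enabled; every
/// operation fails with a configuration error.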
pub struct NoopBackend;
impl LlmBackend for NoopBackend {
fn load_model(&mut self, _model_id: &str, _config: ModelConfig) -> Result<()> {
        Err(RuvLLMError::Config(
            "No inference backend enabled. Enable the 'candle' feature.".to_string(),
        ))
}
fn generate(&self, _prompt: &str, _params: GenerateParams) -> Result<String> {
Err(RuvLLMError::Config(
"No inference backend enabled.".to_string(),
))
}
fn generate_stream(
&self,
_prompt: &str,
_params: GenerateParams,
) -> Result<Box<dyn Iterator<Item = Result<GeneratedToken>> + Send + '_>> {
Err(RuvLLMError::Config(
"No inference backend enabled.".to_string(),
))
}
fn generate_stream_v2(&self, _prompt: &str, _params: GenerateParams) -> Result<TokenStream> {
Err(RuvLLMError::Config(
"No inference backend enabled.".to_string(),
))
}
fn get_embeddings(&self, _text: &str) -> Result<Vec<f32>> {
Err(RuvLLMError::Config(
"No inference backend enabled.".to_string(),
))
}
fn tokenizer(&self) -> Option<&dyn Tokenizer> {
None
}
fn is_model_loaded(&self) -> bool {
false
}
fn model_info(&self) -> Option<ModelInfo> {
None
}
fn unload_model(&mut self) {}
}
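/// Constructs the default backend for the active feature set: a
/// `CandleBackend` when the `candle` feature is enabled, otherwise the
/// always-failing [`NoopBackend`].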
pub fn create_backend() -> Box<dyn LlmBackend> {
#[cfg(feature = "candle")]
{
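        // If device initialization fails, fall back to a default-constructed
        // backend rather than propagating the error here.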
Box::new(CandleBackend::new().unwrap_or_else(|_| CandleBackend::default()))
}
#[cfg(not(feature = "candle"))]
{
Box::new(NoopBackend)
}
}
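/// Reference-counted backend handle for sharing across threads.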
pub type SharedBackend = Arc<dyn LlmBackend>;
#[cfg(feature = "async-runtime")]
pub mod async_stream {
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
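    /// Adapter exposing a [`TokenStream`] as a `futures_core::Stream` for use
    /// with async executors. Note that polling is busy-wait based (see
    /// `poll_next`): the task re-wakes itself until an event arrives.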
pub struct AsyncTokenStream {
inner: TokenStream,
}
impl AsyncTokenStream {
pub fn new(inner: TokenStream) -> Self {
Self { inner }
}
pub fn is_finished(&self) -> bool {
self.inner.is_finished()
}
pub fn tokens_received(&self) -> usize {
self.inner.tokens_received()
}
pub fn tokens_per_second(&self) -> f64 {
self.inner.tokens_per_second()
}
}
impl futures_core::Stream for AsyncTokenStream {
type Item = Result<StreamEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self.inner.try_next() {
Some(result) => Poll::Ready(Some(result)),
None => {
if self.inner.is_finished() {
Poll::Ready(None)
} else {
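                        // Busy-poll: immediately re-wake the task so the
                        // executor polls again. Simple, but burns cycles; a
                        // channel-integrated waker would avoid this.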
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
}
}
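    /// Blanket async facade: every [`LlmBackend`] gets
    /// `generate_stream_async`, which wraps the blocking
    /// `generate_stream_v2` stream in an [`AsyncTokenStream`].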
#[async_trait::async_trait]
pub trait LlmBackendAsync: Send + Sync {
async fn generate_stream_async(
&self,
prompt: &str,
params: GenerateParams,
) -> Result<AsyncTokenStream>;
}
#[async_trait::async_trait]
impl<T: LlmBackend + ?Sized> LlmBackendAsync for T {
async fn generate_stream_async(
&self,
prompt: &str,
params: GenerateParams,
) -> Result<AsyncTokenStream> {
let stream = self.generate_stream_v2(prompt, params)?;
Ok(AsyncTokenStream::new(stream))
}
}
}
#[cfg(feature = "async-runtime")]
pub use async_stream::{AsyncTokenStream, LlmBackendAsync};
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_quantization_bytes() {
assert_eq!(Quantization::None.bytes_per_weight(), 4.0);
assert_eq!(Quantization::F16.bytes_per_weight(), 2.0);
assert_eq!(Quantization::Q4K.bytes_per_weight(), 0.5);
}
#[test]
fn test_generate_params_builder() {
let params = GenerateParams::default()
.with_max_tokens(512)
.with_temperature(0.5)
.with_top_p(0.95)
.with_seed(42);
assert_eq!(params.max_tokens, 512);
assert_eq!(params.temperature, 0.5);
assert_eq!(params.top_p, 0.95);
assert_eq!(params.seed, Some(42));
}
#[test]
fn test_model_architecture() {
assert_eq!(ModelArchitecture::Mistral.config_name(), "mistral");
assert_eq!(ModelArchitecture::Llama.config_name(), "llama");
}
}