pub mod chat_completion;
pub mod completion;
pub mod embedding;
pub mod hf_tokenizer;
pub mod infill;
pub mod rerank;
pub mod tokenizer;
#[cfg(feature = "hf-tokenizer")]
#[cfg_attr(docsrs, doc(cfg(feature = "hf-tokenizer")))]
pub use self::hf_tokenizer::HfTokenizer;
pub use self::tokenizer::{FimTokens, LlamaTokenizer, Tokenizer};
use std::path::{Path, PathBuf};
use crate::backend::LlamaBackend;
use crate::context::{LlamaContext, LlamaContextParams};
use crate::error::Result;
use crate::model::LlamaModel;
use crate::model::params::LlamaModelParams;
pub use self::chat_completion::{create_chat_completion, ChatMessage};
pub use self::completion::{create_completion, Completion, StopReason};
/// Owns a backend, a model loaded through it, and one context created from that model.
///
/// `Llama::load` stores the context with its borrow lifetime erased to `'static`, so the
/// compiler no longer enforces that the context is destroyed before the values it was
/// created from. Rust drops struct fields in declaration order, which makes the order of
/// the fields below load-bearing: `context` first, then `model`, then `_backend`.
/// Do not reorder them.
#[derive(Debug)]
pub struct Llama {
    /// Dropped first: it was created from `model` (and `_backend`) and must not outlive them.
    context: LlamaContext<'static>,
    /// Dropped after `context`, before the backend it was loaded through.
    model: LlamaModel,
    /// Dropped last; held only to keep the backend alive for `model` and `context`.
    _backend: LlamaBackend,
    /// A raw-pointer marker is neither `Send` nor `Sync`, so `Llama` is neither as well.
    _not_send_sync: std::marker::PhantomData<*mut ()>,
}
impl Llama {
    /// Initializes the backend, loads the model at `params.model_path` with `params.model`,
    /// and creates a context from that model using `params.context`.
    ///
    /// # Errors
    ///
    /// Returns the error from `LlamaBackend::init`, `LlamaModel::load_from_file`, or
    /// `LlamaModel::new_context`, whichever fails first.
    pub fn load(params: LlamaParams) -> Result<Self> {
        let backend = LlamaBackend::init()?;
        let model = LlamaModel::load_from_file(&backend, &params.model_path, &params.model)?;
        // `params` is owned and not used again, so the context settings are moved, not cloned.
        let ctx = model.new_context(&backend, params.context)?;
        // SAFETY: the borrow lifetime is erased so the context can be stored in the same
        // struct as the values it was created from. This relies on two invariants that the
        // compiler can no longer check:
        //   1. `Llama` must drop `context` before `model` and `_backend` (struct fields are
        //      dropped in declaration order).
        //   2. `LlamaContext` must not depend on the *address* of the `LlamaModel` /
        //      `LlamaBackend` values, because both are moved into `Self` below.
        // NOTE(review): confirm (2) against the definition of `LlamaContext`; if it stores a
        // plain `&LlamaModel`, the model needs a stable (heap) address before this is sound.
        let ctx: LlamaContext<'static> =
            unsafe { std::mem::transmute::<LlamaContext<'_>, LlamaContext<'static>>(ctx) };
        Ok(Self {
            _backend: backend,
            model,
            context: ctx,
            _not_send_sync: std::marker::PhantomData,
        })
    }

    /// Returns a shared reference to the loaded model.
    #[must_use]
    pub const fn model(&self) -> &LlamaModel {
        &self.model
    }

    /// Returns an exclusive reference to the context.
    ///
    /// The `'static` in the type is an artifact of how the context is stored; the context
    /// is only valid for as long as this `Llama` is alive.
    #[must_use]
    pub const fn context(&mut self) -> &mut LlamaContext<'static> {
        &mut self.context
    }

    /// Convenience wrapper that forwards to the free function
    /// [`create_completion`](self::completion::create_completion) with `self`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the free function reports.
    pub fn create_completion(&mut self, prompt: &str, max_tokens: usize) -> Result<Completion> {
        create_completion(self, prompt, max_tokens)
    }

    /// Convenience wrapper that forwards to the free function
    /// [`create_chat_completion`](self::chat_completion::create_chat_completion) with `self`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the free function reports.
    pub fn create_chat_completion(
        &mut self,
        messages: &[ChatMessage],
        max_tokens: usize,
    ) -> Result<ChatMessage> {
        create_chat_completion(self, messages, max_tokens)
    }
}
/// Everything `Llama::load` needs: where the model file is, and the parameter sets
/// it hands to the model loader and to context creation.
#[derive(Debug, Clone)]
pub struct LlamaParams {
/// Path passed to `LlamaModel::load_from_file`.
pub model_path: PathBuf,
/// Model parameters passed to `LlamaModel::load_from_file`.
pub model: LlamaModelParams,
/// Context parameters passed to `LlamaModel::new_context`.
pub context: LlamaContextParams,
}
impl LlamaParams {
    /// Creates parameters for the model file at `model_path`, with the model and context
    /// parameter sets at their `Default` values.
    #[must_use]
    pub fn new(model_path: impl AsRef<Path>) -> Self {
        let path: &Path = model_path.as_ref();
        Self {
            model_path: path.to_owned(),
            model: LlamaModelParams::default(),
            context: LlamaContextParams::default(),
        }
    }

    /// Replaces the model path with an owned copy of `p`.
    #[must_use]
    pub fn with_model_path(self, p: impl AsRef<Path>) -> Self {
        let path: &Path = p.as_ref();
        Self {
            model_path: path.to_owned(),
            ..self
        }
    }

    /// Forwards `n` to `LlamaModelParams::with_n_gpu_layers` on the model parameters.
    #[must_use]
    pub fn with_n_gpu_layers(self, n: i32) -> Self {
        Self {
            model: self.model.with_n_gpu_layers(n),
            ..self
        }
    }

    /// Forwards `yes` to `LlamaModelParams::with_use_mmap` on the model parameters.
    #[must_use]
    pub fn with_use_mmap(self, yes: bool) -> Self {
        Self {
            model: self.model.with_use_mmap(yes),
            ..self
        }
    }

    /// Forwards `n` to `LlamaContextParams::with_n_ctx` on the context parameters.
    #[must_use]
    pub fn with_n_ctx(self, n: u32) -> Self {
        Self {
            context: self.context.with_n_ctx(n),
            ..self
        }
    }

    /// Forwards `yes` to `LlamaContextParams::with_embeddings` on the context parameters.
    #[must_use]
    pub fn with_embeddings(self, yes: bool) -> Self {
        Self {
            context: self.context.with_embeddings(yes),
            ..self
        }
    }

    /// Forwards `n` to `LlamaContextParams::with_n_threads` on the context parameters.
    #[must_use]
    pub fn with_n_threads(self, n: i32) -> Self {
        Self {
            context: self.context.with_n_threads(n),
            ..self
        }
    }

    /// Forwards `n` to `LlamaContextParams::with_n_threads_batch` on the context parameters.
    #[must_use]
    pub fn with_n_threads_batch(self, n: i32) -> Self {
        Self {
            context: self.context.with_n_threads_batch(n),
            ..self
        }
    }
}
impl Default for LlamaParams {
    /// Produces an empty model path together with the `Default` model and context
    /// parameter sets.
    fn default() -> Self {
        let model_path = PathBuf::new();
        let model = LlamaModelParams::default();
        let context = LlamaContextParams::default();
        Self {
            model_path,
            model,
            context,
        }
    }
}
#[doc(inline)]
pub use StopReason as _StopReasonShim;