#![warn(missing_docs)]
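//! A quantized implementation of the Phi family of language models for the Kalosm
//! framework.
//!
//! [`Phi`] is the main entry point: load a model with [`Phi::builder`], [`Phi::v2`], or
//! [`Phi::new_chat`], then drive it through the `kalosm_language_model` traits
//! re-exported in the [`prelude`].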
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod language_model;
mod model;
mod raw;
mod source;
use kalosm_common::accelerated_device_if_available;
use kalosm_common::ModelLoadingProgress;
pub use kalosm_language_model;
use kalosm_language_model::ChatMarkers;
use raw::PhiCache;
pub use source::*;
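
/// A prelude of commonly used items from this crate and `kalosm_language_model`.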
pub mod prelude {
pub use crate::{Phi, PhiBuilder, PhiSource};
pub use kalosm_language_model::*;
}
use anyhow::Error as E;
use crate::raw::Config;
use crate::raw::MixFormerSequentialForCausalLM as QMixFormer;
use candle_core::Device;
use llm_samplers::prelude::Sampler;
use model::PhiModel;
use std::sync::Arc;
use std::sync::Mutex;
use tokenizers::Tokenizer;
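
/// A unit of work sent to the background thread that owns the model.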
enum Task {
Kill,
Infer {
settings: InferenceSettings,
sender: tokio::sync::mpsc::UnboundedSender<String>,
sampler: Arc<Mutex<dyn Sampler>>,
},
RunSync {
callback: SyncCallback,
},
}
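
/// A boxed callback that runs on the model thread with exclusive access to the
/// `PhiModel`, returning a future that is driven to completion on that same thread.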
type SyncCallback = Box<
dyn for<'a> FnOnce(
&'a mut PhiModel,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>>
+ Send,
>;
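
/// A quantized Phi language model.
///
/// `Phi` is a cheap-to-clone handle: the weights live on a dedicated background thread,
/// and every clone shares that thread through a task channel. Text generation goes
/// through the `kalosm_language_model` traits re-exported in the [`prelude`].
///
/// A minimal usage sketch (hedged: it assumes this crate is built as `rphi`, that a
/// Tokio runtime is running, and that the `stream_text` API from
/// `kalosm_language_model` has this shape, which may differ between versions):
///
/// ```ignore
/// use rphi::prelude::*;
///
/// #[tokio::main]
/// async fn main() -> anyhow::Result<()> {
///     let model = Phi::v2().await?;
///     let mut stream = model.stream_text("The capital of France is").await?;
///     while let Some(token) = stream.next().await {
///         print!("{token}");
///     }
///     Ok(())
/// }
/// ```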
#[derive(Clone)]
pub struct Phi {
task_sender: tokio::sync::mpsc::UnboundedSender<Task>,
tokenizer: Arc<Tokenizer>,
chat_markers: Arc<Option<ChatMarkers>>,
}
impl Drop for Phi {
fn drop(&mut self) {
        if Arc::strong_count(&self.chat_markers) == 1 {
            // `chat_markers` is shared only between clones of this handle, so a strong
            // count of one means this is the last `Phi`. Ask the model thread to shut
            // down, ignoring the error if it has already exited.
            let _ = self.task_sender.send(Task::Kill);
}
}
}
impl Phi {
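    /// Create a builder for a Phi model, which lets the [`PhiSource`] be customized
    /// before loading.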
pub fn builder() -> PhiBuilder {
PhiBuilder::default()
}
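
    /// Load the quantized Phi-2 model with default settings; shorthand for
    /// `Phi::builder().with_source(PhiSource::v2()).build()`.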
pub async fn v2() -> anyhow::Result<Phi> {
Phi::builder().with_source(PhiSource::v2()).build().await
}
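
    /// Load a chat-tuned Phi model (the Dolphin Phi v2 source), whose chat markers make
    /// it suitable for multi-turn chat sessions.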
pub async fn new_chat() -> anyhow::Result<Phi> {
Phi::builder()
.with_source(PhiSource::dolphin_phi_v2())
.build()
.await
}
    /// Wire up the background thread that owns the model and return the shared handle.
fn new(
model: QMixFormer,
tokenizer: Tokenizer,
device: Device,
cache: PhiCache,
chat_markers: Option<ChatMarkers>,
) -> Self {
let (task_sender, mut task_receiver) = tokio::sync::mpsc::unbounded_channel();
let arc_tokenizer = Arc::new(tokenizer);
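        // Spawn a dedicated thread to own the model: inference is blocking work, so it
        // is kept off the caller's async executor and fed tasks through the channel. A
        // current-thread Tokio runtime on that thread drives `RunSync` callbacks.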
std::thread::spawn({
let arc_tokenizer = arc_tokenizer.clone();
move || {
let mut inner = PhiModel::new(model, arc_tokenizer, device, cache);
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap()
.block_on(async move {
while let Some(task) = task_receiver.recv().await {
match task {
Task::Kill => break,
Task::Infer {
settings,
sender,
sampler,
} => {
if let Err(err) = inner._infer(settings, sampler, sender) {
tracing::error!("Error in PhiModel::_infer: {}", err);
}
}
Task::RunSync { callback } => {
callback(&mut inner).await;
}
}
}
})
}
});
Self {
task_sender,
tokenizer: arc_tokenizer,
chat_markers: chat_markers.into(),
}
}
pub(crate) fn get_tokenizer(&self) -> Arc<Tokenizer> {
self.tokenizer.clone()
}
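
    /// Queue an inference task on the model thread and return a receiver that yields
    /// generated text chunks as they are produced.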
fn run(
&self,
settings: InferenceSettings,
sampler: Arc<Mutex<dyn Sampler>>,
) -> anyhow::Result<tokio::sync::mpsc::UnboundedReceiver<String>> {
let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        self.task_sender
            .send(Task::Infer {
                settings,
                sender,
                sampler,
            })
            .map_err(|_| anyhow::anyhow!("the model thread has shut down"))?;
Ok(receiver)
}
}
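
/// A builder for [`Phi`] that lets the model source be customized before loading.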
#[derive(Default)]
pub struct PhiBuilder {
source: source::PhiSource,
}
impl PhiBuilder {
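    /// Set the [`PhiSource`] the tokenizer, weights, and configuration are loaded from.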
pub fn with_source(mut self, source: source::PhiSource) -> Self {
self.source = source;
self
}
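
    /// Build the model, reporting download progress through the default multi-bar
    /// loading indicator.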
pub async fn build(self) -> anyhow::Result<Phi> {
self.build_with_loading_handler(ModelLoadingProgress::multi_bar_loading_indicator())
.await
}
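
    /// Build the model, reporting download progress for the tokenizer and weights to
    /// the given handler.
    ///
    /// A minimal sketch of a custom handler (hedged: this one simply discards the
    /// `ModelLoadingProgress` updates where a real handler might update a UI):
    ///
    /// ```ignore
    /// let model = Phi::builder()
    ///     .build_with_loading_handler(|_progress| {
    ///         // Inspect the `ModelLoadingProgress` here to drive a custom display.
    ///     })
    ///     .await?;
    /// ```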
pub async fn build_with_loading_handler(
self,
mut progress_handler: impl FnMut(ModelLoadingProgress) + Send + Sync + 'static,
) -> anyhow::Result<Phi> {
let PhiSource {
tokenizer, model, ..
} = self.source;
let tokenizer_source = format!("Tokenizer ({})", tokenizer);
let mut create_progress = ModelLoadingProgress::downloading_progress(tokenizer_source);
let tokenizer_filename = tokenizer
.download(|progress| progress_handler(create_progress(progress)))
.await?;
let model_source = format!("Model ({})", model);
let mut create_progress = ModelLoadingProgress::downloading_progress(model_source);
let filename = model
.download(|progress| progress_handler(create_progress(progress)))
.await?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let config = self.source.phi_config;
let device = accelerated_device_if_available()?;
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
let model = if self.source.phi2 {
QMixFormer::new_v2(&config, vb)?
} else {
QMixFormer::new(&config, vb)?
};
let cache = PhiCache::new(&config);
Ok(Phi::new(
model,
tokenizer,
device,
cache,
self.source.chat_markers,
))
}
}
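
/// Settings for a single inference request: the prompt to complete, the maximum number
/// of tokens to sample, and an optional string to stop generation on.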
#[derive(Debug)]
pub(crate) struct InferenceSettings {
prompt: String,
sample_len: usize,
stop_on: Option<String>,
}
impl InferenceSettings {
pub fn new(prompt: impl Into<String>) -> Self {
Self {
prompt: prompt.into(),
sample_len: 100,
stop_on: None,
}
}
pub fn with_sample_len(mut self, sample_len: usize) -> Self {
self.sample_len = sample_len;
self
}
pub fn with_stop_on(mut self, stop_on: impl Into<Option<String>>) -> Self {
self.stop_on = stop_on.into();
self
}
}