use crate::embedding::{Embedding, VectorSpace};
use crate::structured::generate_structured;
use crate::TokenOutputStream;
use crate::UnknownVectorSpace;
use futures_util::{Stream, StreamExt};
use kalosm_sample::{Parser, Tokenizer};
use llm_samplers::configure::SamplerChainBuilder;
use llm_samplers::prelude::*;
use llm_samplers::types::Logits;
use std::any::Any;
use std::fmt::Display;
use std::future::IntoFuture;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::Mutex;
use url::Url;
/// A model that turns text into [`Embedding`]s that live in the vector space `S`.
#[async_trait::async_trait]
pub trait Embedder<S: VectorSpace + Send + Sync + 'static>: Send + Sync + 'static {
    /// Embed a single piece of text.
    async fn embed(&mut self, input: &str) -> anyhow::Result<Embedding<S>>;

    /// Embed every string in `inputs`, returning the embeddings in input order.
    ///
    /// The default implementation awaits [`Embedder::embed`] once per input, one after
    /// another, and returns the first error it encounters.
    async fn embed_batch(&mut self, inputs: &[&str]) -> anyhow::Result<Vec<Embedding<S>>> {
        let mut results = Vec::with_capacity(inputs.len());
        for text in inputs.iter() {
            let embedding = self.embed(text).await?;
            results.push(embedding);
        }
        Ok(results)
    }

    /// Erase the vector space of this embedder and box it as a [`DynEmbedder`].
    fn into_any_embedder(self) -> DynEmbedder
    where
        Self: Sized,
    {
        let erased = AnyEmbedder::<S, Self>(self, PhantomData);
        Box::new(erased)
    }
}
/// A boxed, type-erased [`Embedder`] whose embeddings are in the [`UnknownVectorSpace`].
pub type DynEmbedder = Box<dyn Embedder<UnknownVectorSpace>>;
// Adapter that wraps an embedder `E` for the vector space `S` so it can be exposed as an
// `Embedder<UnknownVectorSpace>`. The `PhantomData<S>` only records the original space.
struct AnyEmbedder<S: VectorSpace + Send + Sync + 'static, E: Embedder<S> + Send + Sync + 'static>(
E,
PhantomData<S>,
);
/// Forwards every call to the wrapped embedder and casts the resulting embeddings into
/// the [`UnknownVectorSpace`].
#[async_trait::async_trait]
impl<S: VectorSpace + Send + Sync + 'static, E: Embedder<S> + Send + Sync + 'static>
    Embedder<UnknownVectorSpace> for AnyEmbedder<S, E>
{
    /// Embed `input` with the wrapped embedder, then erase the vector space.
    async fn embed(&mut self, input: &str) -> anyhow::Result<Embedding<UnknownVectorSpace>> {
        let typed = self.0.embed(input).await?;
        Ok(typed.cast())
    }

    /// Embed `inputs` with the wrapped embedder's batch implementation, then erase the
    /// vector space of every embedding.
    async fn embed_batch(
        &mut self,
        inputs: &[&str],
    ) -> anyhow::Result<Vec<Embedding<UnknownVectorSpace>>> {
        let typed = self.0.embed_batch(inputs).await?;
        Ok(typed.into_iter().map(|embedding| embedding.cast()).collect())
    }
}
/// A type that can asynchronously construct itself.
#[async_trait::async_trait]
pub trait CreateModel {
/// Create a new instance of the model.
async fn start() -> Self;
/// Whether creating the model requires a download. The default implementation
/// returns `false`.
fn requires_download() -> bool {
false
}
}
/// A builder for a streaming text generation request.
///
/// It collects a prompt and [`GenerationParameters`]; awaiting it (through `IntoFuture`)
/// calls the stored `future` function with the model, prompt and parameters.
pub struct StreamTextBuilder<'a, M: Model> {
/// The model the text will be generated with.
self_: &'a M,
/// The prompt to generate text from.
prompt: &'a str,
/// The generation parameters accumulated by the `with_*` methods.
parameters: GenerationParameters,
/// The function that actually starts the generation once the builder is awaited.
#[allow(clippy::type_complexity)]
future: fn(
&'a M,
&'a str,
GenerationParameters,
) -> Pin<
Box<dyn std::future::Future<Output = anyhow::Result<M::TextStream>> + Send + 'a>,
>,
}
impl<'a, M: Model> StreamTextBuilder<'a, M> {
    /// Create a builder for `prompt` on the model `self_`, starting from the default
    /// [`GenerationParameters`]. `future` is invoked when the builder is awaited.
    #[allow(clippy::type_complexity)]
    pub fn new(
        prompt: &'a str,
        self_: &'a M,
        future: fn(
            &'a M,
            &'a str,
            GenerationParameters,
        ) -> Pin<
            Box<dyn std::future::Future<Output = anyhow::Result<M::TextStream>> + Send + 'a>,
        >,
    ) -> Self {
        let parameters = GenerationParameters::default();
        Self {
            self_,
            prompt,
            parameters,
            future,
        }
    }

    /// Replace all generation parameters at once.
    pub fn with_generation_parameters(self, parameters: GenerationParameters) -> Self {
        Self { parameters, ..self }
    }

    /// Set the sampling temperature.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            parameters: self.parameters.with_temperature(temperature),
            ..self
        }
    }

    /// Set the `mu` generation parameter.
    pub fn with_mu(self, mu: f32) -> Self {
        Self {
            parameters: self.parameters.with_mu(mu),
            ..self
        }
    }

    /// Set the `tau` generation parameter.
    pub fn with_tau(self, tau: f32) -> Self {
        Self {
            parameters: self.parameters.with_tau(tau),
            ..self
        }
    }

    /// Set the `eta` generation parameter.
    pub fn with_eta(self, eta: f32) -> Self {
        Self {
            parameters: self.parameters.with_eta(eta),
            ..self
        }
    }

    /// Set the repetition penalty.
    pub fn with_repetition_penalty(self, repetition_penalty: f32) -> Self {
        Self {
            parameters: self.parameters.with_repetition_penalty(repetition_penalty),
            ..self
        }
    }

    /// Set the range the repetition penalty is applied over.
    pub fn with_repetition_penalty_range(self, repetition_penalty_range: u32) -> Self {
        Self {
            parameters: self
                .parameters
                .with_repetition_penalty_range(repetition_penalty_range),
            ..self
        }
    }

    /// Set the maximum length of the generation.
    pub fn with_max_length(self, max_length: u32) -> Self {
        Self {
            parameters: self.parameters.with_max_length(max_length),
            ..self
        }
    }

    /// Set (or clear, with `None`) the text that stops the generation.
    pub fn with_stop_on(self, stop_on: impl Into<Option<String>>) -> Self {
        Self {
            parameters: self.parameters.with_stop_on(stop_on),
            ..self
        }
    }
}
/// Awaiting the builder starts the generation with the collected parameters.
impl<'a, M: Model> IntoFuture for StreamTextBuilder<'a, M> {
    type Output = anyhow::Result<M::TextStream>;
    type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        // `future` is a plain function pointer stored in a field, so it has to be
        // parenthesised to be called.
        (self.future)(self.self_, self.prompt, self.parameters)
    }
}
/// A builder for a text generation request that resolves to the full generated `String`.
///
/// It collects a prompt and [`GenerationParameters`]; awaiting it (through `IntoFuture`)
/// calls the stored `future` function with the model, prompt and parameters.
#[allow(clippy::type_complexity)]
pub struct GenerateTextBuilder<'a, M: Model> {
/// The model the text will be generated with.
self_: &'a M,
/// The prompt to generate text from.
prompt: &'a str,
/// The generation parameters accumulated by the `with_*` methods.
parameters: GenerationParameters,
/// The function that actually runs the generation once the builder is awaited.
future: fn(
&'a M,
&'a str,
GenerationParameters,
)
-> Pin<Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>>,
}
impl<'a, M: Model> GenerateTextBuilder<'a, M> {
    /// Create a builder for `prompt` on the model `self_`, starting from the default
    /// [`GenerationParameters`]. `future` is invoked when the builder is awaited.
    #[allow(clippy::type_complexity)]
    pub fn new(
        prompt: &'a str,
        self_: &'a M,
        future: fn(
            &'a M,
            &'a str,
            GenerationParameters,
        ) -> Pin<
            Box<dyn std::future::Future<Output = anyhow::Result<String>> + Send + 'a>,
        >,
    ) -> Self {
        let parameters = GenerationParameters::default();
        Self {
            self_,
            prompt,
            parameters,
            future,
        }
    }

    /// Replace all generation parameters at once.
    pub fn with_generation_parameters(self, parameters: GenerationParameters) -> Self {
        Self { parameters, ..self }
    }

    /// Set the sampling temperature.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            parameters: self.parameters.with_temperature(temperature),
            ..self
        }
    }

    /// Set the `mu` generation parameter.
    pub fn with_mu(self, mu: f32) -> Self {
        Self {
            parameters: self.parameters.with_mu(mu),
            ..self
        }
    }

    /// Set the `tau` generation parameter.
    pub fn with_tau(self, tau: f32) -> Self {
        Self {
            parameters: self.parameters.with_tau(tau),
            ..self
        }
    }

    /// Set the `eta` generation parameter.
    pub fn with_eta(self, eta: f32) -> Self {
        Self {
            parameters: self.parameters.with_eta(eta),
            ..self
        }
    }

    /// Set the repetition penalty.
    pub fn with_repetition_penalty(self, repetition_penalty: f32) -> Self {
        Self {
            parameters: self.parameters.with_repetition_penalty(repetition_penalty),
            ..self
        }
    }

    /// Set the range the repetition penalty is applied over.
    pub fn with_repetition_penalty_range(self, repetition_penalty_range: u32) -> Self {
        Self {
            parameters: self
                .parameters
                .with_repetition_penalty_range(repetition_penalty_range),
            ..self
        }
    }

    /// Set the maximum length of the generation.
    pub fn with_max_length(self, max_length: u32) -> Self {
        Self {
            parameters: self.parameters.with_max_length(max_length),
            ..self
        }
    }

    /// Set (or clear, with `None`) the text that stops the generation.
    pub fn with_stop_on(self, stop_on: impl Into<Option<String>>) -> Self {
        Self {
            parameters: self.parameters.with_stop_on(stop_on),
            ..self
        }
    }
}
/// Awaiting the builder runs the generation with the collected parameters.
impl<'a, M: Model> IntoFuture for GenerateTextBuilder<'a, M> {
    type Output = anyhow::Result<String>;
    type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        // `future` is a plain function pointer stored in a field, so it has to be
        // parenthesised to be called.
        (self.future)(self.self_, self.prompt, self.parameters)
    }
}
/// Convenience methods available on every [`Model`].
#[async_trait::async_trait]
pub trait ModelExt: Model + Send + Sync + 'static {
    /// Start building a request that generates text for `prompt`.
    ///
    /// Awaiting the returned builder calls [`Model::generate_text_inner`] with the
    /// parameters collected on the builder.
    fn generate_text<'a>(&'a self, prompt: &'a str) -> GenerateTextBuilder<'a, Self>
    where
        Self: Sized,
    {
        GenerateTextBuilder::new(prompt, self, |self_, prompt, generation_parameters| {
            Box::pin(async move {
                self_
                    .generate_text_inner(prompt, generation_parameters)
                    .await
            })
        })
    }

    /// Start building a request that streams text for `prompt`.
    ///
    /// Awaiting the returned builder calls [`Model::stream_text_inner`] with the
    /// parameters collected on the builder.
    fn stream_text<'a>(&'a self, prompt: &'a str) -> StreamTextBuilder<'a, Self>
    where
        Self: Sized,
    {
        StreamTextBuilder::new(prompt, self, |self_, prompt, generation_parameters| {
            Box::pin(async move { self_.stream_text_inner(prompt, generation_parameters).await })
        })
    }

    /// Run `f` with mutable access to the underlying [`SyncModel`].
    ///
    /// This boxes the closure and hands it to [`Model::run_sync_raw`].
    fn run_sync(
        &self,
        f: impl for<'a> FnOnce(
                &'a mut Self::SyncModel,
            ) -> Pin<Box<dyn std::future::Future<Output = ()> + 'a>>
            + Send
            + 'static,
    ) -> anyhow::Result<()> {
        self.run_sync_raw(Box::new(f))
    }

    /// Generate text that must match `parser`, using the sampler built from the default
    /// [`GenerationParameters`] and a fresh parser state.
    async fn stream_structured_text<P>(
        &self,
        prompt: &str,
        parser: P,
    ) -> anyhow::Result<StructureParserResult<Self::TextStream, P::Output>>
    where
        Self::TextStream: From<tokio::sync::mpsc::UnboundedReceiver<String>>,
        P: kalosm_sample::CreateParserState + Parser + Send + 'static,
        P::PartialState: Send + 'static,
        P::Output: Send + 'static,
    {
        let sampler = Arc::new(Mutex::new(GenerationParameters::default().sampler()));
        let parser_state = parser.create_parser_state();
        self.stream_structured_text_with_sampler(prompt, parser, parser_state, sampler)
            .await
    }

    /// Generate text that must match `parser`, starting from `parser_state` and sampling
    /// with `sampler`.
    ///
    /// The generation is scheduled through [`ModelExt::run_sync`]. Generated text is sent
    /// over a channel that backs the returned stream, and the final parse result (or the
    /// error that ended the generation) is delivered through the returned
    /// [`StructureParserResult`].
    ///
    /// # Errors
    ///
    /// Returns an error if the work cannot be scheduled on the model. Failures that
    /// happen later, including a failure to create the session, are reported through the
    /// result half of the returned [`StructureParserResult`].
    async fn stream_structured_text_with_sampler<P>(
        &self,
        prompt: &str,
        parser: P,
        parser_state: P::PartialState,
        sampler: Arc<Mutex<dyn Sampler>>,
    ) -> anyhow::Result<StructureParserResult<Self::TextStream, P::Output>>
    where
        Self::TextStream: From<tokio::sync::mpsc::UnboundedReceiver<String>>,
        P: Parser + Send + 'static,
        P::PartialState: Send + 'static,
        P::Output: Send + 'static,
    {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let (result_sender, result_receiver) = tokio::sync::oneshot::channel();
        let prompt = prompt.to_string();
        self.run_sync(move |llm: &mut Self::SyncModel| {
            // Creating the session can fail. Keep the outcome as a value instead of
            // unwrapping so the failure reaches the caller through the result channel
            // rather than panicking wherever the closure is run.
            let session = llm.new_session();
            Box::pin(async move {
                let result = match session {
                    Ok(mut session) => llm.generate_structured(
                        &mut session,
                        prompt,
                        parser,
                        parser_state,
                        sampler,
                        |token| Ok(sender.send(token)?),
                    ),
                    Err(err) => Err(err),
                };
                // `send` hands the value back if the receiving half is already gone.
                match result_sender.send(result) {
                    Ok(()) => {}
                    Err(Ok(_)) => {
                        log::error!("Error generating structured text: cancelled");
                    }
                    Err(Err(err)) => {
                        log::error!("Error generating structured text: {:?}", err);
                    }
                }
            })
        })?;
        Ok(StructureParserResult::new(
            Self::TextStream::from(receiver),
            result_receiver,
        ))
    }
}
/// The outcome of a structured generation: a stream of generated text plus a channel that
/// delivers the final parsed value (or the error that ended the generation).
pub struct StructureParserResult<S: Stream<Item = String> + Send + Unpin + 'static, O> {
/// The stream of generated text.
stream: S,
/// Receives the final result of the generation.
result: tokio::sync::oneshot::Receiver<anyhow::Result<O>>,
}
/// Dereferences to the inner text stream so the result can be used as a stream directly.
impl<S: Stream<Item = String> + Send + Unpin + 'static, O> Deref for StructureParserResult<S, O> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
/// Mutably dereferences to the inner text stream.
impl<S: Stream<Item = String> + Send + Unpin + 'static, O> DerefMut
for StructureParserResult<S, O>
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.stream
}
}
impl<S: Stream<Item = String> + Send + Unpin + 'static, O> StructureParserResult<S, O> {
fn new(stream: S, result: tokio::sync::oneshot::Receiver<anyhow::Result<O>>) -> Self {
Self { stream, result }
}
pub async fn result(self) -> anyhow::Result<O> {
self.result.await.unwrap()
}
pub fn split(self) -> (S, tokio::sync::oneshot::Receiver<anyhow::Result<O>>) {
(self.stream, self.result)
}
}
// Every `Model` that is `Send + Sync + 'static` gets the `ModelExt` methods for free.
impl<M: Model + Send + Sync + 'static> ModelExt for M {}
/// A model that is driven synchronously: text or tokens are fed into a session and the
/// model returns logits.
pub trait SyncModel {
/// The state a generation is run against.
type Session: Session;
/// Create a new session for this model.
fn new_session(&self) -> anyhow::Result<Self::Session>;
/// Feed `prompt` into the model for `session` and return the resulting logits.
fn feed_text(&self, session: &mut Self::Session, prompt: &str) -> anyhow::Result<Logits>;
/// Feed `tokens` into the model for `session` and return the resulting logits.
fn feed_tokens(&self, session: &mut Self::Session, tokens: &[u32]) -> anyhow::Result<Logits>;
/// The id of the model's stop token.
// NOTE(review): presumably the end-of-sequence token id; confirm against implementors.
fn stop_token(&self) -> anyhow::Result<u32>;
/// The tokenizer used by this model.
fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync>;
}
/// The state of a [`SyncModel`] generation, which may optionally be saved and loaded.
///
/// Both methods have default implementations that return a "Not implemented" error.
pub trait Session {
/// Save the session to `_path`. The default implementation always fails.
fn save_to(&self, _path: impl AsRef<Path>) -> anyhow::Result<()> {
Err(anyhow::Error::msg("Not implemented"))
}
/// Load a session from `_path`. The default implementation always fails.
fn load_from(_path: impl AsRef<Path>) -> anyhow::Result<Self>
where
Self: std::marker::Sized,
{
Err(anyhow::Error::msg("Not implemented"))
}
}
/// The unit type is a session with no state: saving and loading do nothing and succeed.
impl Session for () {
fn save_to(&self, _path: impl AsRef<Path>) -> anyhow::Result<()> {
Ok(())
}
fn load_from(_path: impl AsRef<Path>) -> anyhow::Result<()> {
Ok(())
}
}
/// Generation helpers available on every [`SyncModel`].
pub trait SyncModelExt: SyncModel {
    /// Generate text for `prompt` that is constrained by `parser`.
    ///
    /// This forwards to the crate's `generate_structured` function with this model, the
    /// given session and this model's tokenizer. `on_token` receives generated text.
    fn generate_structured<P: Parser>(
        &self,
        session: &mut Self::Session,
        prompt: impl Display,
        parser: P,
        parser_state: P::PartialState,
        sampler: Arc<Mutex<dyn Sampler>>,
        on_token: impl FnMut(String) -> anyhow::Result<()>,
    ) -> anyhow::Result<P::Output> {
        generate_structured(
            prompt,
            self,
            session,
            &self.tokenizer(),
            parser,
            parser_state,
            sampler,
            on_token,
        )
    }

    /// Generate text for `prompt`, passing each decoded chunk of text to `on_token`.
    ///
    /// Generation stops when any of the following happens:
    /// - `on_token` returns [`ModelFeedback::Stop`],
    /// - `max_tokens` tokens have been sampled (when `max_tokens` is `Some`),
    /// - the generated text contains `stop_on` (when `stop_on` is `Some`). The stop
    ///   sequence may be spread over several chunks. Text in the final chunk that comes
    ///   before the stop sequence is still passed to `on_token`; the stop sequence and
    ///   anything after it in that chunk are not. Text passed to `on_token` in earlier
    ///   chunks cannot be taken back, even if it turns out to be the start of the stop
    ///   sequence.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the tokenizer, the model, the sampler or `on_token`.
    #[allow(clippy::too_many_arguments)]
    fn stream_text_with_sampler(
        &self,
        session: &mut Self::Session,
        prompt: &str,
        max_tokens: Option<u32>,
        stop_on: Option<&str>,
        mut sampler: Arc<Mutex<dyn Sampler>>,
        mut on_token: impl FnMut(String) -> anyhow::Result<ModelFeedback>,
    ) -> anyhow::Result<()> {
        let tokens = self.tokenizer().encode(prompt)?;
        let mut text_stream = TokenOutputStream::new(self.tokenizer(), tokens.clone());
        let mut logits = self.feed_tokens(session, &tokens)?;
        let mut tokens_generated = 0;
        // Tail of the text generated so far, kept to detect a stop sequence that is split
        // across several chunks.
        let mut text_matching_buffer = String::new();
        loop {
            let new_token = text_stream.sample_token(&mut sampler, logits, stop_on)?;
            if let Some(new_text) = text_stream.next_token(new_token)? {
                if let Some(stop_on) = stop_on {
                    let previous_len = text_matching_buffer.len();
                    text_matching_buffer.push_str(&new_text);
                    if let Some(match_start) = text_matching_buffer.find(stop_on) {
                        // Only the part of this chunk that precedes the match is new,
                        // wanted output. When the match starts in earlier text there is
                        // nothing left to emit. `match_start` and `previous_len` are both
                        // char boundaries of the buffer, so the slice below is valid.
                        let keep = match_start.saturating_sub(previous_len);
                        if keep > 0 {
                            on_token(new_text[..keep].to_string())?;
                        }
                        break;
                    }
                    if text_matching_buffer.len() > stop_on.len() {
                        // Keep only as many trailing bytes as the stop sequence is long.
                        // If that would cut a character in half, keep the buffer as it is
                        // and try again after the next chunk.
                        let byte_idx = text_matching_buffer.len() - stop_on.len();
                        if text_matching_buffer.is_char_boundary(byte_idx) {
                            text_matching_buffer.drain(..byte_idx);
                        }
                    }
                }
                if let ModelFeedback::Stop = on_token(new_text)? {
                    break;
                }
            }
            tokens_generated += 1;
            if let Some(max_tokens) = max_tokens {
                if tokens_generated >= max_tokens {
                    break;
                }
            }
            logits = self.feed_tokens(session, &[new_token])?;
        }
        Ok(())
    }
}
/// What a token callback tells the generation loop to do next.
pub enum ModelFeedback {
/// Keep generating.
Continue,
/// Stop generating.
Stop,
}
// Every `SyncModel` gets the `SyncModelExt` methods for free.
impl<M: SyncModel> SyncModelExt for M {}
/// A placeholder [`SyncModel`] for models that cannot be driven synchronously.
///
/// Every fallible method returns a "Not implemented" error.
pub struct SyncModelNotSupported;
impl SyncModel for SyncModelNotSupported {
type Session = ();
fn new_session(&self) -> anyhow::Result<Self::Session> {
Err(anyhow::Error::msg("Not implemented"))
}
fn feed_text(&self, _session: &mut (), _prompt: &str) -> anyhow::Result<Logits> {
Err(anyhow::Error::msg("Not implemented"))
}
fn feed_tokens(&self, _session: &mut (), _tokens: &[u32]) -> anyhow::Result<Logits> {
Err(anyhow::Error::msg("Not implemented"))
}
fn stop_token(&self) -> anyhow::Result<u32> {
Err(anyhow::Error::msg("Not implemented"))
}
// The signature has no way to report an error, so this panics when called.
fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
unimplemented!()
}
}
/// A text generation model.
///
/// Implementors must provide [`Model::tokenizer`] and [`Model::stream_text_inner`]. The
/// sampler-based and synchronous entry points default to returning a "Not implemented"
/// error.
#[async_trait::async_trait]
pub trait Model: Send + Sync + 'static {
    /// The stream of generated text this model produces.
    type TextStream: Stream<Item = String> + Send + Sync + Unpin + 'static;

    /// The tokenizer used by this model.
    fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync>;

    /// The synchronous model that backs this model.
    type SyncModel: SyncModel;

    /// Run `_f` with mutable access to the synchronous model.
    ///
    /// The default implementation does not run the closure and returns an error.
    #[allow(clippy::type_complexity)]
    fn run_sync_raw(
        &self,
        _f: Box<
            dyn for<'a> FnOnce(
                    &'a mut Self::SyncModel,
                )
                    -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + 'a>>
                + Send,
        >,
    ) -> anyhow::Result<()> {
        Err(anyhow::Error::msg("Not implemented"))
    }

    /// Generate text with a custom sampler and return it as a single `String`.
    ///
    /// This drains the stream returned by [`Model::stream_text_with_sampler`].
    async fn generate_text_with_sampler(
        &self,
        prompt: &str,
        max_tokens: Option<u32>,
        stop_on: Option<&str>,
        sampler: Arc<Mutex<dyn Sampler>>,
    ) -> anyhow::Result<String> {
        let mut chunks = self
            .stream_text_with_sampler(prompt, max_tokens, stop_on, sampler)
            .await?;
        let mut output = String::new();
        while let Some(chunk) = chunks.next().await {
            output.push_str(&chunk);
        }
        Ok(output)
    }

    /// Generate text with the given parameters and return it as a single `String`.
    ///
    /// This drains the stream returned by [`Model::stream_text_inner`].
    async fn generate_text_inner(
        &self,
        prompt: &str,
        parameters: GenerationParameters,
    ) -> anyhow::Result<String> {
        let mut chunks = self.stream_text_inner(prompt, parameters).await?;
        let mut output = String::new();
        while let Some(chunk) = chunks.next().await {
            output.push_str(&chunk);
        }
        Ok(output)
    }

    /// Stream text generated with a custom sampler.
    ///
    /// The default implementation returns an error.
    async fn stream_text_with_sampler(
        &self,
        _prompt: &str,
        _max_tokens: Option<u32>,
        _stop_on: Option<&str>,
        _sampler: Arc<Mutex<dyn Sampler>>,
    ) -> anyhow::Result<Self::TextStream> {
        Err(anyhow::Error::msg("Not implemented"))
    }

    /// Stream text generated with the given parameters.
    async fn stream_text_inner(
        &self,
        prompt: &str,
        parameters: GenerationParameters,
    ) -> anyhow::Result<Self::TextStream>;

    /// Erase the concrete model type and box it as a [`DynModel`].
    fn into_any_model(self) -> DynModel
    where
        Self: Send + Sync + Sized,
    {
        let erased = AnyModel(self, PhantomData);
        Box::new(erased)
    }
}
/// A [`Model`] that exposes the marker strings delimiting the parts of a chat prompt.
pub trait ChatModel: Model {
/// The marker that starts a user message.
fn user_marker(&self) -> &str;
/// The marker that ends a user message.
fn end_user_marker(&self) -> &str;
/// The marker that starts an assistant message.
fn assistant_marker(&self) -> &str;
/// The marker that ends an assistant message.
fn end_assistant_marker(&self) -> &str;
/// The marker that starts the system prompt.
fn system_prompt_marker(&self) -> &str;
/// The marker that ends the system prompt.
fn end_system_prompt_marker(&self) -> &str;
}
/// A boxed, type-erased [`Model`] whose text stream is a boxed stream and whose
/// synchronous model is a [`BoxedSyncModel`].
pub type DynModel = Box<
dyn Model<
TextStream = Box<dyn Stream<Item = String> + Send + Sync + Unpin>,
SyncModel = BoxedSyncModel,
> + Send,
>;
#[async_trait::async_trait]
impl Model for DynModel {
type TextStream = Box<dyn Stream<Item = String> + Send + Sync + Unpin>;
type SyncModel = BoxedSyncModel;
fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
let self_ref: &(dyn Model<
TextStream = Box<dyn Stream<Item = String> + Send + Sync + Unpin>,
SyncModel = BoxedSyncModel,
> + Send) = self.as_ref();
self_ref.tokenizer()
}
async fn stream_text_inner(
&self,
prompt: &str,
parameters: GenerationParameters,
) -> anyhow::Result<Self::TextStream> {
let self_ref: &(dyn Model<
TextStream = Box<dyn Stream<Item = String> + Send + Sync + Unpin>,
SyncModel = BoxedSyncModel,
> + Send) = self.as_ref();
self_ref.stream_text_inner(prompt, parameters).await
}
}
/// A boxed, type-erased [`SyncModel`] whose sessions are [`AnySession`]s.
pub type BoxedSyncModel = Box<dyn SyncModel<Session = AnySession>>;
// Object-safe view of a session. `Session::save_to` takes `impl AsRef<Path>`, so this
// trait exposes a `&Path` version together with access to the session as `dyn Any`.
trait AnySessionTrait {
// Borrow the session as `dyn Any` so it can be downcast to its concrete type.
fn as_any_mut(&mut self) -> &mut dyn Any;
// Save the session to `path`.
fn save_to(&self, path: &Path) -> anyhow::Result<()>;
}
// Every `Session` that is also `Any` can be used as an `AnySessionTrait`.
impl<S: Any + Session> AnySessionTrait for S {
fn as_any_mut(&mut self) -> &mut dyn Any {
self
}
fn save_to(&self, path: &Path) -> anyhow::Result<()> {
// Fully qualified so the call goes to `Session::save_to`, not back to this method.
Session::save_to(self, path)
}
}
/// A type-erased [`Session`].
pub struct AnySession {
/// The boxed concrete session.
session: Box<dyn AnySessionTrait>,
}
impl AnySession {
// Borrow the wrapped session as `dyn Any` so callers can downcast it.
fn as_any_mut(&mut self) -> &mut dyn Any {
self.session.as_any_mut()
}
}
/// Saving delegates to the wrapped session. `load_from` is not overridden, so it uses
/// the trait default, which returns a "Not implemented" error.
impl Session for AnySession {
fn save_to(&self, path: impl AsRef<Path>) -> anyhow::Result<()> {
self.session.save_to(path.as_ref())
}
}
/// Lets a boxed synchronous model be used wherever a [`SyncModel`] is expected by
/// forwarding every call to the model inside the box.
impl SyncModel for BoxedSyncModel {
    type Session = AnySession;

    // Each method dereferences through the `Box` explicitly so that the call is
    // dispatched to the trait object and not back to this impl.

    fn new_session(&self) -> anyhow::Result<Self::Session> {
        (**self).new_session()
    }

    fn feed_text(&self, session: &mut Self::Session, prompt: &str) -> anyhow::Result<Logits> {
        (**self).feed_text(session, prompt)
    }

    fn feed_tokens(&self, session: &mut Self::Session, tokens: &[u32]) -> anyhow::Result<Logits> {
        (**self).feed_tokens(session, tokens)
    }

    fn stop_token(&self) -> anyhow::Result<u32> {
        (**self).stop_token()
    }

    fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
        (**self).tokenizer()
    }
}
// Adapter that wraps a `SyncModel` with session type `S` so it can be used as a
// `SyncModel` whose sessions are `AnySession`s.
struct AnySyncModel<M: SyncModel<Session = S>, S: Any>(M, PhantomData<S>);
impl<M: SyncModel<Session = S>, S: Session + Any> SyncModel for AnySyncModel<M, S> {
type Session = AnySession;
fn new_session(&self) -> anyhow::Result<Self::Session> {
self.0.new_session().map(|s| AnySession {
session: Box::new(s),
})
}
fn feed_text(&self, session: &mut Self::Session, prompt: &str) -> anyhow::Result<Logits> {
self.0.feed_text(
match session.as_any_mut().downcast_mut() {
Some(s) => s,
None => {
return Err(anyhow::Error::msg(format!(
"Invalid session type expected {:?}",
std::any::type_name::<S>()
)))
}
},
prompt,
)
}
fn feed_tokens(&self, session: &mut Self::Session, tokens: &[u32]) -> anyhow::Result<Logits> {
self.0.feed_tokens(
match session.as_any_mut().downcast_mut() {
Some(s) => s,
None => {
return Err(anyhow::Error::msg(format!(
"Invalid session type expected {:?}",
std::any::type_name::<S>()
)))
}
},
tokens,
)
}
fn stop_token(&self) -> anyhow::Result<u32> {
self.0.stop_token()
}
fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
self.0.tokenizer()
}
}
// Adapter that wraps a model `M` with text stream type `S` so it can be exposed with a
// boxed text stream. The `PhantomData<S>` only records the original stream type.
struct AnyModel<M: Model<TextStream = S>, S: Stream<Item = String> + Send + Sync + Unpin + 'static>(
M,
PhantomData<S>,
);
/// Forwards to the wrapped model and boxes the text streams it returns.
///
/// `run_sync_raw` is not overridden here, so it uses the trait default.
#[async_trait::async_trait]
impl<M, S> Model for AnyModel<M, S>
where
    S: Stream<Item = String> + Send + Sync + Unpin + 'static,
    M: Model<TextStream = S> + Send + Sync,
{
    type TextStream = Box<dyn Stream<Item = String> + Send + Sync + Unpin>;
    type SyncModel = BoxedSyncModel;

    /// The tokenizer of the wrapped model.
    fn tokenizer(&self) -> Arc<dyn Tokenizer + Send + Sync> {
        self.0.tokenizer()
    }

    /// Stream text from the wrapped model using the given parameters.
    async fn stream_text_inner(
        &self,
        prompt: &str,
        params: GenerationParameters,
    ) -> anyhow::Result<Self::TextStream> {
        let stream = self.0.stream_text_inner(prompt, params).await?;
        let boxed: Box<dyn Stream<Item = String> + Send + Sync + Unpin> = Box::new(stream);
        Ok(boxed)
    }

    /// Stream text from the wrapped model using a custom sampler.
    async fn stream_text_with_sampler(
        &self,
        prompt: &str,
        max_tokens: Option<u32>,
        stop_on: Option<&str>,
        sampler: Arc<Mutex<dyn Sampler>>,
    ) -> anyhow::Result<Self::TextStream> {
        let stream = self
            .0
            .stream_text_with_sampler(prompt, max_tokens, stop_on, sampler)
            .await?;
        let boxed: Box<dyn Stream<Item = String> + Send + Sync + Unpin> = Box::new(stream);
        Ok(boxed)
    }
}
/// The parameters that control text generation.
#[derive(Debug, Clone, PartialEq)]
pub struct GenerationParameters {
/// The temperature given to the temperature sampler.
pub(crate) temperature: f32,
/// The `tau` value given to the Mirostat 2 sampler.
pub(crate) tau: f32,
/// The `eta` value given to the Mirostat 2 sampler.
pub(crate) eta: f32,
/// The `mu` value given to the Mirostat 2 sampler.
pub(crate) mu: f32,
/// The penalty given to the repetition sampler.
pub(crate) repetition_penalty: f32,
/// How many of the last tokens the repetition sampler looks at.
pub(crate) repetition_penalty_range: u32,
/// The maximum length of the generation.
// NOTE(review): presumably measured in tokens; confirm against the code that reads it.
pub(crate) max_length: u32,
/// Text that stops the generation, if any.
pub(crate) stop_on: Option<String>,
}
/// Default parameters: temperature 0.8, eta 0.1, tau 5, mu 10, repetition penalty 1.3
/// over a range of 64, a maximum length of 128 and no stop text.
impl Default for GenerationParameters {
fn default() -> Self {
Self {
temperature: 0.8,
eta: 0.1,
tau: 5.,
mu: 10.,
repetition_penalty: 1.3,
repetition_penalty_range: 64,
max_length: 128,
stop_on: None,
}
}
}
impl crate::model::GenerationParameters {
    /// Build the full sampler chain for these parameters.
    ///
    /// The chain contains, in order, the repetition, frequency/presence, sequence
    /// repetition, temperature and Mirostat 2 samplers.
    pub fn sampler(self) -> SamplerChain {
        use llm_samplers::configure::SamplerSlot;

        let Self {
            temperature,
            tau,
            eta,
            mu,
            repetition_penalty,
            repetition_penalty_range,
            ..
        } = self;

        let repetition = SamplerSlot::new_static(move || {
            Box::new(
                SampleRepetition::default()
                    .penalty(repetition_penalty)
                    .last_n(repetition_penalty_range as usize),
            )
        });
        let freq_presence =
            SamplerSlot::new_static(move || Box::new(SampleFreqPresence::default().last_n(64)));
        let seq_repetition = SamplerSlot::new_static(move || Box::<SampleSeqRepetition>::default());
        let temperature_slot = SamplerSlot::new_static(move || {
            Box::new(SampleTemperature::default().temperature(temperature))
        });
        let mirostat2 = SamplerSlot::new_static(move || {
            Box::new(SampleMirostat2::default().tau(tau).eta(eta).mu(mu))
        });

        SamplerChainBuilder::from([
            ("repetition", repetition),
            ("freqpresence", freq_presence),
            ("seqrepetition", seq_repetition),
            ("temperature", temperature_slot),
            ("mirostat2", mirostat2),
        ])
        .into_chain()
    }

    /// Build only the Mirostat 2 sampler, configured with `tau`, `eta` and `mu`.
    pub fn mirostat2_sampler(self) -> SampleMirostat2 {
        let Self { tau, eta, mu, .. } = self;
        SampleMirostat2::default().tau(tau).eta(eta).mu(mu)
    }

    /// Build the sampler chain without the final Mirostat 2 sampler.
    ///
    /// The chain contains, in order, the repetition, frequency/presence, sequence
    /// repetition and temperature samplers.
    pub fn bias_only_sampler(self) -> SamplerChain {
        use llm_samplers::configure::SamplerSlot;

        let Self {
            temperature,
            repetition_penalty,
            repetition_penalty_range,
            ..
        } = self;

        let repetition = SamplerSlot::new_static(move || {
            Box::new(
                SampleRepetition::default()
                    .penalty(repetition_penalty)
                    .last_n(repetition_penalty_range as usize),
            )
        });
        let freq_presence =
            SamplerSlot::new_static(move || Box::new(SampleFreqPresence::default().last_n(64)));
        let seq_repetition = SamplerSlot::new_static(move || Box::<SampleSeqRepetition>::default());
        let temperature_slot = SamplerSlot::new_static(move || {
            Box::new(SampleTemperature::default().temperature(temperature))
        });

        SamplerChainBuilder::from([
            ("repetition", repetition),
            ("freqpresence", freq_presence),
            ("seqrepetition", seq_repetition),
            ("temperature", temperature_slot),
        ])
        .into_chain()
    }

    /// Return the parameters with the temperature replaced.
    pub fn with_temperature(self, temperature: f32) -> Self {
        Self {
            temperature,
            ..self
        }
    }

    /// Return the parameters with `tau` replaced.
    pub fn with_tau(self, tau: f32) -> Self {
        Self { tau, ..self }
    }

    /// Return the parameters with `eta` replaced.
    pub fn with_eta(self, eta: f32) -> Self {
        Self { eta, ..self }
    }

    /// Return the parameters with `mu` replaced.
    pub fn with_mu(self, mu: f32) -> Self {
        Self { mu, ..self }
    }

    /// Return the parameters with the repetition penalty replaced.
    pub fn with_repetition_penalty(self, repetition_penalty: f32) -> Self {
        Self {
            repetition_penalty,
            ..self
        }
    }

    /// Return the parameters with the repetition penalty range replaced.
    pub fn with_repetition_penalty_range(self, repetition_penalty_range: u32) -> Self {
        Self {
            repetition_penalty_range,
            ..self
        }
    }

    /// Return the parameters with the maximum length replaced.
    pub fn with_max_length(self, max_length: u32) -> Self {
        Self { max_length, ..self }
    }

    /// Return the parameters with the stop text replaced (or cleared, with `None`).
    pub fn with_stop_on(self, stop_on: impl Into<Option<String>>) -> Self {
        Self {
            stop_on: stop_on.into(),
            ..self
        }
    }

    /// The temperature.
    pub fn temperature(&self) -> f32 {
        self.temperature
    }

    /// The `tau` value.
    pub fn tau(&self) -> f32 {
        self.tau
    }

    /// The `eta` value.
    pub fn eta(&self) -> f32 {
        self.eta
    }

    /// The `mu` value.
    pub fn mu(&self) -> f32 {
        self.mu
    }

    /// The repetition penalty.
    pub fn repetition_penalty(&self) -> f32 {
        self.repetition_penalty
    }

    /// The repetition penalty range.
    pub fn repetition_penalty_range(&self) -> u32 {
        self.repetition_penalty_range
    }

    /// The maximum length.
    pub fn max_length(&self) -> u32 {
        self.max_length
    }

    /// The stop text, if one is set.
    pub fn stop_on(&self) -> Option<&str> {
        self.stop_on.as_deref()
    }
}
/// Identifies a model by family; each variant carries the specific model of that family.
#[allow(missing_docs)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum ModelType {
Mpt(MptType),
GptNeoX(GptNeoXType),
Llama(LlamaType),
}
/// The specific model of the Llama family: one of the named models or a custom model
/// identified by a URL.
// NOTE(review): `Custom` presumably points at the model weights; confirm against the loader.
#[allow(missing_docs)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum LlamaType {
Vicuna,
Guanaco,
WizardLm,
Orca,
LlamaSevenChat,
LlamaThirteenChat,
Custom(Url),
}
/// The specific model of the MPT family: one of the named models or a custom model
/// identified by a URL.
// NOTE(review): `Custom` presumably points at the model weights; confirm against the loader.
#[allow(missing_docs)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum MptType {
Base,
Story,
Instruct,
Chat,
Custom(Url),
}
/// The specific model of the GPT-NeoX family: one of the named models or a custom model
/// identified by a URL.
// NOTE(review): `Custom` presumably points at the model weights; confirm against the loader.
#[allow(missing_docs)]
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum GptNeoXType {
LargePythia,
TinyPythia,
DollySevenB,
StableLm,
Custom(Url),
}
// Declares a public unit struct named `$ty` and implements `VectorSpace` for it.
// The three `#[doc]` attributes together form the generated item's documentation:
// "A vector space for the <name> model."
macro_rules! embedding {
($ty: ident) => {
#[doc = "A vector space for the "]
#[doc = stringify!($ty)]
#[doc = " model."]
pub struct $ty;
impl VectorSpace for $ty {}
};
}
// One marker vector space per named model. The names follow the non-`Custom` variants
// of `LlamaType`, `MptType` (prefixed with `Mpt`) and `GptNeoXType`.
embedding!(VicunaSpace);
embedding!(GuanacoSpace);
embedding!(WizardLmSpace);
embedding!(OrcaSpace);
embedding!(LlamaSevenChatSpace);
embedding!(LlamaThirteenChatSpace);
embedding!(MptBaseSpace);
embedding!(MptStorySpace);
embedding!(MptInstructSpace);
embedding!(MptChatSpace);
embedding!(LargePythiaSpace);
embedding!(TinyPythiaSpace);
embedding!(DollySevenBSpace);
embedding!(StableLmSpace);