pub mod hf;
pub use hf::SlmHfModel;
mod answer;
pub mod core;
pub mod errors;
mod formatter;
pub mod inference;
pub mod models;
mod oracle;
use std::path::Path;
use std::result::Result;
pub use answer::SlmAnswer;
use errors::*;
pub use formatter::SlmFormatter;
pub use inference::{SlmBoxedBrakeFn, SlmBrake, SlmInference, SlmSimpleInference};
pub use models::SlmDynamicFormatter;
pub use oracle::{SlmOracle, SlmSimpleOracle};
/// A backend token handle that is cheap to copy (`Copy` supertrait).
///
/// Implementors expose the token through [`SlmToken::as_i32`].
pub trait SlmToken: Copy {
    /// Returns this token as an `i32`.
    fn as_i32(&self) -> i32;
}
/// Role attached to a message; [`SlmRole::as_str`] gives the label of each variant.
///
/// The `Tool` variant additionally carries the tool's name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlmRole {
    /// Labelled `"system"`.
    System,
    /// Labelled `"user"`.
    User,
    /// Labelled `"assistant"`.
    Assistant,
    /// Labelled `"tool"`; holds the tool name.
    Tool(String),
}

impl SlmRole {
    /// Returns the lowercase label: `"system"`, `"user"`, `"assistant"` or `"tool"`.
    ///
    /// The tool name is not part of the label; use [`SlmRole::tool_name`] for it.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::Tool(..) => "tool",
        }
    }

    /// Returns `true` only for the `Tool` variant.
    pub fn is_tool(&self) -> bool {
        self.tool_name().is_some()
    }

    /// Returns the tool name for `Tool`, and `None` for every other role.
    pub fn tool_name(&self) -> Option<&str> {
        if let Self::Tool(name) = self {
            Some(name.as_str())
        } else {
            None
        }
    }

    /// Builds a `Tool` role, copying `name` into an owned `String`.
    pub fn tool(name: &str) -> SlmRole {
        Self::Tool(String::from(name))
    }
}
/// A position in the context: a token index inside a given fork.
///
/// It is plain data (two `usize`s), so it derives `Copy`. `Hash` lets it be
/// used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SlmPos {
    /// Token index.
    pub token_pos: usize,
    /// Identifier of the fork this position belongs to.
    pub fork_id: usize,
}

impl SlmPos {
    /// Creates a position at `token_pos` inside fork `fork_id`.
    ///
    /// `const` so positions can be used in constants.
    #[must_use]
    pub const fn new(token_pos: usize, fork_id: usize) -> SlmPos {
        Self { token_pos, fork_id }
    }

    /// Returns the fork identifier (same as the `fork_id` field).
    ///
    /// Public, so it is not dead code and needs no `allow(dead_code)`.
    #[must_use]
    pub const fn fork_id(&self) -> usize {
        self.fork_id
    }

    /// Returns the token index (same as the `token_pos` field).
    #[must_use]
    pub const fn token_pos(&self) -> usize {
        self.token_pos
    }
}
/// A batch of tokens that is handed to [`SlmContext::decode`].
pub trait SlmBatch<Token: SlmToken> {
    /// Appends `token` at position `pos`.
    ///
    /// `logits` presumably requests output logits for this token. Confirm this
    /// against the backend implementation.
    ///
    /// # Errors
    /// Returns a [`BatchError`] if the token cannot be added.
    fn add(&mut self, token: Token, pos: SlmPos, logits: bool) -> Result<(), BatchError>;

    /// Empties the batch. It presumably resets [`SlmBatch::n_tokens`] to 0.
    fn clear(&mut self);

    /// Number of tokens currently held by the batch.
    fn n_tokens(&self) -> usize;

    /// Capacity of the batch, presumably in tokens.
    fn n_max(&self) -> usize;
}
/// Configuration that is consumed to load an [`SlmModel`] from a GGUF file.
pub trait SlmModelConfig {
    /// Context type used by the loaded model.
    type Context: SlmContext;

    /// Model type returned by [`SlmModelConfig::load_gguf`]. Its contexts are `Self::Context`.
    type Model: SlmModel<Context = Self::Context>;

    /// Consumes the configuration and loads the model stored at `path`.
    ///
    /// # Errors
    /// Returns a [`GgufLoaderError`] when loading fails.
    fn load_gguf(self, path: impl AsRef<Path>) -> Result<Self::Model, GgufLoaderError>;
}
/// A loaded model from which contexts are built.
pub trait SlmModel {
    /// Context type produced through [`SlmModel::context`].
    type Context: SlmContext;

    /// Returns a builder that configures and then builds a `Self::Context`.
    fn context(&self) -> impl SlmContextBuilder<Self::Context>;
}
/// Element type for the key/value cache, as passed to
/// [`SlmContextBuilder::with_gen_type_kv`].
///
/// This is a small fieldless public enum, so it derives the common traits.
/// Values can be printed (`Debug`), compared (`PartialEq`/`Eq`), hashed, and
/// reused after being passed by value (`Copy`).
///
/// The variant names presumably denote quantization widths or float formats.
/// The exact backend mapping (e.g. of `RawQ8` vs `Q8`) is not visible here;
/// confirm it in the backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SlmKvType {
    Q4,
    Q5,
    Q6,
    Q8,
    RawQ8,
    F16,
    F32,
}
/// Builder for a context of type `T`, as returned by [`SlmModel::context`].
///
/// Each `with_*` setter takes the builder by value and returns it, so calls chain
/// and end with [`SlmContextBuilder::build`].
pub trait SlmContextBuilder<T> {
    /// Sets the sampler parameters `temperature`, `top_k` and `top_p`.
    fn with_sampler(self, temperature: f32, top_k: i32, top_p: f32) -> Self;

    /// Sets `n_ctx`, presumably the context length in tokens.
    fn with_n_ctx(self, n_ctx: usize) -> Self;

    /// Sets the cache element types for keys (`k`) and values (`v`).
    fn with_gen_type_kv(self, k: SlmKvType, v: SlmKvType) -> Self;

    /// Sets `n_batch`, presumably the batch size in tokens.
    fn with_n_batch(self, n_batch: usize) -> Self;

    /// Consumes the builder and creates the context.
    ///
    /// # Errors
    /// Returns a [`ContextBuilderError`] if the context cannot be created.
    fn build(self) -> Result<T, ContextBuilderError>;
}
/// Editing capability level, as reported by [`SlmContext::edit_level`].
///
/// NOTE(review): this presumably says which of the context editing operations
/// (`dump`/`restore`, `cut`, `truncate`) a backend supports. Confirm this
/// against the implementations.
///
/// Variants are ordered by their discriminants (`DumpRestore < Cut < Truncate`).
/// The ordering is total, so `Ord` is derived alongside `PartialOrd`. This
/// enables `max`/`min`/`cmp` and ordered collections. `Hash` allows use as a
/// map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SlmEditLevel {
    /// Lowest level (discriminant 0); the `Default`.
    #[default]
    DumpRestore = 0,
    /// Discriminant 1.
    Cut = 1,
    /// Highest level (discriminant 2).
    Truncate = 2,
}
/// An inference context. It creates and decodes batches, samples tokens,
/// converts between text and tokens, and edits or serializes its own state.
pub trait SlmContext {
    /// Token type used for tokenization, batching and sampling.
    type Token: SlmToken;

    /// Batch type consumed by [`SlmContext::decode`].
    type Batch: SlmBatch<Self::Token>;

    // ---- batching & decoding ----

    /// Creates a new batch parameterized by `tokens` and `sequences`.
    ///
    /// This presumably sets token capacity and sequence count; confirm per backend.
    ///
    /// # Errors
    /// Returns a [`BatchError`] if the batch cannot be created.
    fn new_batch(&self, tokens: usize, sequences: usize) -> Result<Self::Batch, BatchError>;

    /// Maximum batch length accepted by this context.
    fn max_batch_len(&self) -> usize;

    /// Decodes the tokens in `batch`.
    ///
    /// # Errors
    /// Returns a [`DecodeError`] if decoding fails.
    fn decode(&mut self, batch: &mut Self::Batch) -> Result<(), DecodeError>;

    /// Samples a token using the logits at index `logit_idx`.
    ///
    /// `Ok(None)` means no token was produced. This presumably signals end of
    /// generation; confirm per backend.
    ///
    /// # Errors
    /// Returns a [`SamplingError`] if sampling fails.
    fn sample(&mut self, logit_idx: usize) -> Result<Option<Self::Token>, SamplingError>;

    // ---- tokenization ----

    /// Converts `token` to its byte representation.
    ///
    /// `buffer_size`, `special` and `lstrip` are forwarded to the backend. Their
    /// exact meaning is backend-defined; confirm per implementation.
    ///
    /// # Errors
    /// Returns a [`TokenToStringError`] if the conversion fails.
    fn token_to_bytes(
        &self,
        token: Self::Token,
        buffer_size: usize,
        special: bool,
        lstrip: Option<usize>,
    ) -> Result<Vec<u8>, TokenToStringError>;

    /// Tokenizes `text`.
    ///
    /// `add_special` and `parse_special` control special-token handling, as
    /// defined by the backend.
    ///
    /// # Errors
    /// Returns a [`StringToTokenError`] if tokenization fails.
    fn str_to_tokens(
        &self,
        text: &str,
        add_special: bool,
        parse_special: bool,
    ) -> Result<Vec<Self::Token>, StringToTokenError>;

    // ---- state editing ----

    /// Clears the context state.
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn clear(&mut self) -> Result<(), ContextError>;

    /// Drops fork `fork_id`. This presumably discards that fork's cached state.
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn drop(&mut self, fork_id: usize) -> Result<(), ContextError>;

    /// Truncates the context at `pos` and returns a position.
    ///
    /// The returned position is presumably the new end; confirm per backend.
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn truncate(&mut self, pos: &SlmPos) -> Result<SlmPos, ContextError>;

    /// Cuts the span between `start_pos` and `end_pos` and returns a position.
    ///
    /// Bound inclusivity and the meaning of the returned position are
    /// backend-defined; confirm per implementation.
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn cut(&mut self, start_pos: &SlmPos, end_pos: &SlmPos) -> Result<SlmPos, ContextError>;

    /// Serializes the context state into bytes.
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn dump(&mut self) -> Result<Vec<u8>, ContextError>;

    /// Restores context state from `data`, presumably bytes produced by
    /// [`SlmContext::dump`].
    ///
    /// # Errors
    /// Returns a [`ContextError`] on failure.
    fn restore(&mut self, data: Vec<u8>) -> Result<(), ContextError>;

    /// Reports this context's [`SlmEditLevel`].
    fn edit_level(&self) -> SlmEditLevel;
}