use std::rc::Rc;
use serde_json::{Value, json};
use crate::{
error::{Error, InvariantViolationPayload, ParsePayload, Result},
lm::{
cache::{CacheConfig, KvCache, make_prompt_cache, save_prompt_cache},
generate::{FinishReason, GenConfig, GenerationResponse, Generator, build_generator},
model::Model,
speculative::{DraftConfig, SpeculativeStream, speculative_stream_generate},
},
tokenizer::{StreamingDetokenizer as _, Tokenizer, wrapper::BoxedDetokenizer},
};
/// Speaker of a [`ChatMessage`].
///
/// `Display` prints the same lowercase string that [`Role::as_str`] returns.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
#[non_exhaustive]
pub enum Role {
/// Rendered as `"system"`.
System,
/// Rendered as `"user"`.
User,
/// Rendered as `"assistant"`.
Assistant,
/// Rendered as `"tool"`.
Tool,
}
impl Role {
pub const fn as_str(self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
/// A single chat turn: who spoke and what was said.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
/// Speaker of this message.
pub role: Role,
/// Message text; read it through [`ChatMessage::content`].
content: String,
}
impl ChatMessage {
    /// Creates a message spoken by `role` with the given text.
    pub fn new(role: Role, content: impl Into<String>) -> Self {
        Self {
            role,
            content: content.into(),
        }
    }

    /// Returns the message text.
    // `#[inline]` is enough to allow cross-crate inlining of this trivial
    // getter; `#[inline(always)]` would override the optimiser for no benefit.
    #[inline]
    pub fn content(&self) -> &str {
        &self.content
    }

    /// Shorthand for a [`Role::System`] message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new(Role::System, content)
    }

    /// Shorthand for a [`Role::User`] message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new(Role::User, content)
    }

    /// Shorthand for a [`Role::Assistant`] message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new(Role::Assistant, content)
    }

    /// Shorthand for a [`Role::Tool`] message.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::new(Role::Tool, content)
    }
}
/// Settings that switch a chat session to speculative decoding.
pub struct SpeculativeDecodingConfig {
/// Model passed to the speculative generator as the draft model.
pub draft_model: Rc<dyn Model>,
/// Forwarded to the speculative generator as its `n_draft_tokens` setting.
pub num_draft_tokens: usize,
/// Configuration used to build the draft model's KV cache.
pub draft_cache_config: CacheConfig,
}
impl SpeculativeDecodingConfig {
    /// Value [`SpeculativeDecodingConfig::new`] assigns to `num_draft_tokens`.
    pub const DEFAULT_NUM_DRAFT_TOKENS: usize = 5;

    /// Creates a configuration for `draft_model` that uses
    /// [`Self::DEFAULT_NUM_DRAFT_TOKENS`]; adjust the public
    /// `num_draft_tokens` field afterwards to override it.
    pub fn new(draft_model: Rc<dyn Model>, draft_cache_config: CacheConfig) -> Self {
        let num_draft_tokens = Self::DEFAULT_NUM_DRAFT_TOKENS;
        Self {
            draft_model,
            num_draft_tokens,
            draft_cache_config,
        }
    }
}
/// Owned set of boxed KV caches, as returned by `make_prompt_cache` and handed to the generators.
type KvCaches = Vec<Box<dyn KvCache>>;
/// Bookkeeping for what a realised KV cache currently holds.
#[derive(Clone)]
struct CachedTokens {
/// Number of leading cache positions whose token ids are not known to the session
/// (for example a cache supplied through `ChatSessionBuilder::cache`).
opaque_len: usize,
/// Token ids the session itself fed into the cache after the opaque prefix; the next
/// prompt must start with these for the cache to be reused.
known: Vec<u32>,
}
impl CachedTokens {
    /// Record for a cache that holds nothing yet.
    fn empty() -> Self {
        Self::opaque(0)
    }

    /// Record for a cache whose first `offset` positions hold tokens the
    /// session cannot name; no known tokens are recorded.
    fn opaque(offset: usize) -> Self {
        Self {
            known: Vec::new(),
            opaque_len: offset,
        }
    }
}
/// State of the session's KV cache between turns.
enum CacheSlot {
/// No cache and nothing to replay; the next turn starts from a fresh cache.
Empty,
/// Live caches carried over from a previous turn or supplied via `ChatSessionBuilder::cache`.
Realised {
/// Cache used with the session's main model.
cache: KvCaches,
/// Draft-model cache, if one exists.
draft_cache: Option<KvCaches>,
/// Record of which tokens `cache` already contains.
cached: CachedTokens,
},
/// Messages supplied via `ChatSessionBuilder::history`; they are included in the next
/// prompt and then moved into the session history.
History(Vec<ChatMessage>),
/// Left behind after a speculative turn is committed; no cache is retained in this state.
SpeculativeSpent,
}
/// Failures specific to chat-session cache handling; converts into the crate-wide [`Error`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChatSessionError {
/// `save_cache` was called while the session holds no realised KV cache.
NoCacheAvailable,
/// `save_cache` was called after a speculative-decoding turn, which leaves no cache to persist.
SpeculativeCacheUnsupported,
/// The builder was given both a restored cache and a speculative-decoding config.
SpeculativeCacheRestoreUnsupported,
}
impl std::fmt::Display for ChatSessionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ChatSessionError::NoCacheAvailable => f.write_str(
"no KV cache is available: call respond() / stream_respond() before save_cache()",
),
ChatSessionError::SpeculativeCacheUnsupported => f.write_str(
"speculative-decoding sessions do not support cache save: the speculative \
generator consumes its KV caches and rebuilds them each turn, so there is \
no advanced cache that encodes the conversation to persist",
),
ChatSessionError::SpeculativeCacheRestoreUnsupported => f.write_str(
"ChatSessionBuilder::cache() combined with ChatSessionBuilder::speculative() \
is unsupported: speculative_stream_generate consumes its KV caches and does \
not return them, so the restored opaque prefix would be used on the first \
turn and silently lost on every subsequent turn (the opaque prefix's token \
ids are unknown, so the session cannot re-prefill it). Drop .cache(..) to \
run speculative decoding without a restored prefix, or drop .speculative(..) \
to use the standard path (which reuses its KV cache across turns and \
preserves the restored prefix)",
),
}
}
}
/// Marks [`ChatSessionError`] as a standard error type; no trait methods are overridden.
impl std::error::Error for ChatSessionError {}
impl From<ChatSessionError> for Error {
fn from(e: ChatSessionError) -> Self {
match e {
ChatSessionError::NoCacheAvailable => {
Error::InvariantViolation(InvariantViolationPayload::new(
"ChatSession::save_cache",
"no KV cache is available: call respond() / stream_respond() before save_cache()",
))
}
ChatSessionError::SpeculativeCacheUnsupported => {
Error::InvariantViolation(InvariantViolationPayload::new(
"ChatSession::save_cache",
"speculative-decoding sessions do not support cache save",
))
}
ChatSessionError::SpeculativeCacheRestoreUnsupported => {
Error::InvariantViolation(InvariantViolationPayload::new(
"ChatSessionBuilder::build",
"cache() and speculative() are mutually exclusive; build with only .cache() or only .speculative()",
))
}
}
}
}
/// Builder for [`ChatSession`]; obtain one from [`ChatSession::builder`].
pub struct ChatSessionBuilder {
/// Model the session will generate with.
model: Box<dyn Model>,
/// Tokenizer used for chat templating and detokenisation.
tokenizer: Tokenizer,
/// Configuration for building the main model's KV cache.
cache_config: CacheConfig,
/// Optional system instructions placed first in every prompt.
instructions: Option<String>,
/// Generation settings cloned for every turn.
generate_params: GenConfig,
/// Present when the session should use speculative decoding.
speculative: Option<SpeculativeDecodingConfig>,
/// Starting cache state: empty, a restored cache, or history to replay.
initial: CacheSlot,
}
impl ChatSessionBuilder {
    /// Sets the system instructions placed first in every prompt.
    pub fn instructions(self, instructions: impl Into<String>) -> Self {
        Self {
            instructions: Some(instructions.into()),
            ..self
        }
    }

    /// Replaces the generation settings used for each turn.
    pub fn generate_params(self, params: GenConfig) -> Self {
        Self {
            generate_params: params,
            ..self
        }
    }

    /// Enables speculative decoding with the given draft configuration.
    pub fn speculative(self, config: SpeculativeDecodingConfig) -> Self {
        Self {
            speculative: Some(config),
            ..self
        }
    }

    /// Seeds the session with earlier messages to replay on the first turn.
    ///
    /// This overwrites the starting state, so it replaces anything set by a
    /// previous call to [`Self::cache`] or [`Self::history`].
    pub fn history(self, history: Vec<ChatMessage>) -> Self {
        Self {
            initial: CacheSlot::History(history),
            ..self
        }
    }

    /// Starts the session from an existing KV cache.
    ///
    /// The offset of the first cache entry (0 when `cache` is empty) is
    /// recorded as an opaque prefix whose token ids are unknown. This
    /// overwrites the starting state, so it replaces anything set by a
    /// previous call to [`Self::cache`] or [`Self::history`].
    pub fn cache(self, cache: Vec<Box<dyn KvCache>>) -> Self {
        let offset = cache.first().map_or(0, |first| first.offset());
        let cached = CachedTokens::opaque(offset);
        Self {
            initial: CacheSlot::Realised {
                cache,
                draft_cache: None,
                cached,
            },
            ..self
        }
    }

    /// Finishes the builder.
    ///
    /// # Errors
    ///
    /// Returns the error converted from
    /// [`ChatSessionError::SpeculativeCacheRestoreUnsupported`] when both a
    /// restored cache and a speculative configuration were supplied.
    pub fn build(self) -> Result<ChatSession> {
        let Self {
            model,
            tokenizer,
            cache_config,
            instructions,
            generate_params,
            speculative,
            initial,
        } = self;

        let restores_cache = matches!(initial, CacheSlot::Realised { .. });
        if speculative.is_some() && restores_cache {
            return Err(ChatSessionError::SpeculativeCacheRestoreUnsupported.into());
        }

        Ok(ChatSession {
            model,
            tokenizer,
            cache_config,
            instructions,
            generate_params,
            speculative,
            cache: initial,
            history: Vec::new(),
        })
    }
}
/// A multi-turn chat conversation with a model; build one with [`ChatSession::builder`].
pub struct ChatSession {
/// Model used for generation.
model: Box<dyn Model>,
/// Tokenizer used for chat templating and detokenisation.
tokenizer: Tokenizer,
/// Configuration for building fresh KV caches for `model`.
cache_config: CacheConfig,
/// Optional system instructions placed first in every prompt.
instructions: Option<String>,
/// Generation settings cloned for every turn.
generate_params: GenConfig,
/// Present when turns run through the speculative generator.
speculative: Option<SpeculativeDecodingConfig>,
/// KV-cache state carried between turns.
cache: CacheSlot,
/// Messages exchanged so far (system instructions are not stored here).
history: Vec<ChatMessage>,
}
impl ChatSession {
/// Starts a [`ChatSessionBuilder`] with default generation settings, no
/// instructions, no speculative decoding and an empty starting cache.
pub fn builder(
    model: Box<dyn Model>,
    tokenizer: Tokenizer,
    cache_config: CacheConfig,
) -> ChatSessionBuilder {
    let generate_params = GenConfig::default();
    ChatSessionBuilder {
        model,
        tokenizer,
        cache_config,
        generate_params,
        instructions: None,
        speculative: None,
        initial: CacheSlot::Empty,
    }
}
/// Returns the system instructions, if any are set.
pub fn instructions(&self) -> Option<&str> {
self.instructions.as_deref()
}
/// Replaces the system instructions; `None` removes them.
pub fn set_instructions(&mut self, instructions: Option<String>) {
self.instructions = instructions;
}
/// Returns the generation settings used for each turn.
pub fn generate_params(&self) -> &GenConfig {
&self.generate_params
}
/// Returns the generation settings for in-place modification.
pub fn generate_params_mut(&mut self) -> &mut GenConfig {
&mut self.generate_params
}
/// Returns the messages recorded so far.
pub fn history(&self) -> &[ChatMessage] {
&self.history
}
pub fn has_cache(&self) -> bool {
matches!(self.cache, CacheSlot::Realised { .. })
}
/// Borrows the realised KV cache, or returns `None` when the session does
/// not currently hold one.
pub fn current_cache(&self) -> Option<&[Box<dyn KvCache>]> {
    if let CacheSlot::Realised { cache, .. } = &self.cache {
        Some(cache.as_slice())
    } else {
        None
    }
}
/// Resets the conversation: drops any cache or pending replay history and empties the
/// message history. Instructions and generation parameters are left untouched.
pub fn clear(&mut self) {
self.cache = CacheSlot::Empty;
self.history.clear();
}
/// Writes the session's realised KV cache to `path` with empty metadata.
///
/// # Errors
///
/// * The error converted from [`ChatSessionError::SpeculativeCacheUnsupported`]
///   when the session uses speculative decoding. This is reported even before
///   the first turn, because running a turn can never produce a saveable cache
///   for such a session.
/// * The error converted from [`ChatSessionError::NoCacheAvailable`] when a
///   standard session has not produced a cache yet.
/// * Whatever `save_prompt_cache` returns if writing fails.
pub fn save_cache(&self, path: &std::path::Path) -> Result<()> {
    match &self.cache {
        CacheSlot::Realised { cache, .. } => {
            save_prompt_cache(path, cache, &std::collections::HashMap::new())
        }
        CacheSlot::SpeculativeSpent => Err(ChatSessionError::SpeculativeCacheUnsupported.into()),
        CacheSlot::Empty | CacheSlot::History(_) => {
            // Telling a speculative session to "call respond() first" would be
            // misleading: the next call would fail with the unsupported error.
            let reason = if self.speculative.is_some() {
                ChatSessionError::SpeculativeCacheUnsupported
            } else {
                ChatSessionError::NoCacheAvailable
            };
            Err(reason.into())
        }
    }
}
fn build_turn_prompt(&self, prompt: &str, role: Role) -> Result<(Vec<u32>, Vec<ChatMessage>)> {
let mut messages: Vec<ChatMessage> = Vec::new();
if let Some(instructions) = &self.instructions {
messages.push(ChatMessage::system(instructions.clone()));
}
let replayed: Vec<ChatMessage> = match &self.cache {
CacheSlot::History(h) => h.clone(),
_ => Vec::new(),
};
messages.extend(replayed.iter().cloned());
messages.extend(self.history.iter().cloned());
messages.push(ChatMessage::new(role, prompt));
let json_messages = Value::Array(
messages
.iter()
.map(|m| json!({ "role": m.role.as_str(), "content": m.content() }))
.collect(),
);
let prompt_ids = self
.tokenizer
.apply_chat_template_ids(&json_messages, None, true, false, None)
.map_err(|e| {
Error::Parse(ParsePayload::new(
"ChatSession: apply_chat_template",
"chat template",
std::io::Error::other(e.to_string()),
))
})?;
Ok((prompt_ids, replayed))
}
/// Moves the cache state out of the session, leaving [`CacheSlot::Empty`].
///
/// A realised slot yields its caches and token record, with a draft cache
/// created if the session is speculative and none was stored. Any other
/// state yields a fresh main cache, a fresh draft cache when speculative,
/// and an empty token record.
fn take_cache(&mut self) -> (KvCaches, Option<KvCaches>, CachedTokens) {
    let previous = std::mem::replace(&mut self.cache, CacheSlot::Empty);
    let new_draft_cache = || {
        self.speculative
            .as_ref()
            .map(|spec| make_prompt_cache(&spec.draft_cache_config))
    };

    if let CacheSlot::Realised {
        cache,
        draft_cache,
        cached,
    } = previous
    {
        return (cache, draft_cache.or_else(new_draft_cache), cached);
    }

    let cache = make_prompt_cache(&self.cache_config);
    (cache, new_draft_cache(), CachedTokens::empty())
}
/// Decides whether `cache` can be reused for `prompt_ids`.
///
/// The cache is reused only when the prompt starts with every known cached
/// token and adds at least one more. In that case the result is
/// `(cache, known_len, opaque_len)`, where `known_len` is the index of the
/// first prompt token still to be fed. Otherwise the old cache is dropped and
/// `(fresh_cache, 0, 0)` is returned.
fn decide_prefill(
    &self,
    prompt_ids: &[u32],
    cache: KvCaches,
    cached: &CachedTokens,
) -> (KvCaches, usize, usize) {
    match prompt_ids.strip_prefix(cached.known.as_slice()) {
        Some(new_tokens) if !new_tokens.is_empty() => {
            (cache, cached.known.len(), cached.opaque_len)
        }
        _ => {
            // Release the stale cache before building its replacement.
            drop(cache);
            (make_prompt_cache(&self.cache_config), 0, 0)
        }
    }
}
/// Sends `prompt` as a [`Role::User`] message and returns the complete reply text.
///
/// Equivalent to `respond_as(prompt, Role::User)`.
pub fn respond(&mut self, prompt: &str) -> Result<String> {
self.respond_as(prompt, Role::User)
}
/// Sends `prompt` as a message from `role` and returns the complete reply.
///
/// Drives the stream from [`Self::stream_respond_as`] to completion and
/// concatenates the text of every chunk.
///
/// # Errors
///
/// Returns the error from starting the stream, or the first error the stream
/// yields. In the second case the stream is still dropped before returning,
/// so whatever was generated up to that point is committed to the session.
pub fn respond_as(&mut self, prompt: &str, role: Role) -> Result<String> {
    let mut stream = self.stream_respond_as(prompt, role)?;
    let mut output = String::new();
    for response in &mut stream {
        output.push_str(&response?.text);
    }
    Ok(output)
}
/// Sends `prompt` as a [`Role::User`] message and returns a stream of reply chunks.
///
/// Equivalent to `stream_respond_as(prompt, Role::User)`.
pub fn stream_respond(&mut self, prompt: &str) -> Result<ChatResponseStream<'_>> {
self.stream_respond_as(prompt, Role::User)
}
/// Sends `prompt` as a message from `role` and returns a stream of reply chunks.
///
/// The new message is appended to the session history immediately. The reply and the
/// updated cache state are written back when the returned stream is dropped.
///
/// # Errors
///
/// Returns the error from `build_turn_prompt` when the chat template cannot be applied;
/// in that case the session has not been modified.
pub fn stream_respond_as(&mut self, prompt: &str, role: Role) -> Result<ChatResponseStream<'_>> {
// Render the prompt first: a failure here must leave the session untouched.
let (prompt_ids, replayed) = self.build_turn_prompt(prompt, role)?;
// From here on the cache slot is `Empty` until the stream commits.
let (cache, draft_cache, cached) = self.take_cache();
// Builder-supplied history becomes regular history, followed by the new message.
self.history.extend(replayed);
self.history.push(ChatMessage::new(role, prompt));
let detok = self.tokenizer.detokenizer();
// The EOS ids go both into the generator config and into the stream itself.
let eos: Vec<u32> = self.tokenizer.eos_token_ids_iter().collect();
let mut cfg = self.generate_params.clone();
cfg.eos = eos.clone();
let max_tokens = cfg.max_tokens;
// Only the standard path tries to reuse the cache; the speculative path is always
// handed the full prompt (prefill offset 0, no opaque prefix).
let (std_cache, prefill_start, opaque_len) = if self.speculative.is_none() {
self.decide_prefill(&prompt_ids, cache, &cached)
} else {
(cache, 0, 0)
};
// Split `self` into per-field borrows so the stream can hold mutable references to
// the cache slot and history alongside shared references to the model and tokenizer.
let ChatSession {
model,
tokenizer,
speculative,
cache: cache_slot,
history,
..
} = self;
let model: &dyn Model = &**model;
let driver = match speculative.as_ref() {
Some(spec) => {
let draft_cache =
draft_cache.unwrap_or_else(|| make_prompt_cache(&spec.draft_cache_config));
Driver::Speculative(Box::new(SpeculativeTurn::new(
model,
tokenizer,
&prompt_ids,
std_cache,
draft_cache,
DraftConfig {
// The draft model is shared via `Rc`; `RcModel` wraps it as an owned `Model`.
draft_model: Box::new(RcModel(Rc::clone(&spec.draft_model))),
n_draft_tokens: spec.num_draft_tokens,
},
cfg,
)))
}
None => Driver::Standard(Box::new(StandardTurn {
// Feed only the part of the prompt the reused cache does not already cover.
generator: build_generator(model, &prompt_ids[prefill_start..], std_cache, cfg),
draft_cache,
})),
};
Ok(ChatResponseStream {
cache_slot,
history,
driver: Some(driver),
detok,
eos,
max_tokens,
prompt_tokens: prompt_ids.len(),
produced: 0,
reply: String::new(),
prompt_ids,
opaque_len,
generated: Vec::new(),
finished: false,
detok_finalized: false,
committed: false,
})
}
}
/// Adapter that lets a shared `Rc<dyn Model>` be used where an owned
/// [`Model`] value is required; both methods delegate to the wrapped model.
struct RcModel(Rc<dyn Model>);

impl Model for RcModel {
    /// Delegates to the wrapped model's `forward`.
    fn forward(
        &self,
        tokens: &crate::array::Array,
        cache: &mut [Box<dyn KvCache>],
    ) -> Result<crate::array::Array> {
        let Self(inner) = self;
        inner.forward(tokens, cache)
    }

    /// Delegates to the wrapped model's `forward_embeddings`.
    fn forward_embeddings(
        &self,
        embeddings: &crate::array::Array,
        cache: &mut [Box<dyn KvCache>],
    ) -> Result<crate::array::Array> {
        let Self(inner) = self;
        inner.forward_embeddings(embeddings, cache)
    }
}
/// The generator backing one response stream; variants are boxed.
enum Driver<'a> {
/// Token-by-token generation with a reusable KV cache.
Standard(Box<StandardTurn<'a>>),
/// Generation through the speculative stream.
Speculative(Box<SpeculativeTurn<'a>>),
}
/// State of a standard (non-speculative) turn.
struct StandardTurn<'a> {
/// Generator that yields one step per token and gives its cache back via `into_cache`.
generator: Generator<'a, dyn Model>,
/// Draft cache carried through unchanged so it can be stored back on commit.
draft_cache: Option<KvCaches>,
}
/// State of a speculative-decoding turn.
struct SpeculativeTurn<'a> {
/// Stream returned by `speculative_stream_generate`.
iter: SpeculativeStream<'a>,
}
impl<'a> SpeculativeTurn<'a> {
    /// Starts a speculative stream over `prompt`.
    ///
    /// Every argument is forwarded unchanged to
    /// `speculative_stream_generate`; both caches are moved into it.
    // The parameter list mirrors `speculative_stream_generate` one-to-one.
    #[allow(clippy::too_many_arguments)]
    fn new(
        target: &'a dyn Model,
        tokenizer: &'a Tokenizer,
        prompt: &[u32],
        cache: Vec<Box<dyn KvCache>>,
        draft_cache: Vec<Box<dyn KvCache>>,
        draft_cfg: DraftConfig,
        cfg: GenConfig,
    ) -> Self {
        Self {
            iter: speculative_stream_generate(
                target,
                tokenizer,
                prompt,
                cache,
                draft_cache,
                draft_cfg,
                cfg,
            ),
        }
    }
}
/// Iterator over the chunks of one reply.
///
/// Dropping the stream commits the turn: the accumulated reply is appended to the session
/// history and the session's cache slot is updated.
pub struct ChatResponseStream<'s> {
/// The session's cache slot, written on commit.
cache_slot: &'s mut CacheSlot,
/// The session's history, extended with the assistant reply on commit.
history: &'s mut Vec<ChatMessage>,
/// Active generator; taken (left as `None`) by commit.
driver: Option<Driver<'s>>,
/// Detokenizer used by the standard path.
detok: BoxedDetokenizer,
/// Token ids that end generation on the standard path.
eos: Vec<u32>,
/// Token budget enforced on the standard path.
max_tokens: usize,
/// Length of the full prompt in tokens, reported in every response.
prompt_tokens: usize,
/// Number of tokens produced so far.
produced: usize,
/// Reply text accumulated so far.
reply: String,
/// Full prompt token ids for this turn.
prompt_ids: Vec<u32>,
/// Opaque-prefix length of the cache handed to the standard generator.
opaque_len: usize,
/// Tokens yielded by the standard generator, including a terminating EOS token.
generated: Vec<u32>,
/// Set once the stream has ended; `next` then returns `None`.
finished: bool,
/// Set once the detokenizer tail has been flushed (or the speculative stream ended itself).
detok_finalized: bool,
/// Set once commit has run, making it idempotent.
committed: bool,
}
impl ChatResponseStream<'_> {
fn commit(&mut self) {
if self.committed {
return;
}
self.committed = true;
match self.driver.take() {
Some(Driver::Standard(turn)) => {
if !self.detok_finalized {
self.detok.finalize();
self.detok_finalized = true;
let tail = self.detok.last_segment();
self.reply.push_str(&tail);
}
let cache = turn.generator.into_cache();
let offset = cache.first().map(|c| c.offset()).unwrap_or(0);
let mut logical: Vec<u32> =
Vec::with_capacity(self.prompt_ids.len() + self.generated.len());
logical.extend_from_slice(&self.prompt_ids);
logical.extend_from_slice(&self.generated);
let opaque_len = self.opaque_len.min(offset);
let known_len = offset - opaque_len;
let cached = if known_len <= logical.len() {
logical.truncate(known_len);
CachedTokens {
opaque_len,
known: logical,
}
} else {
CachedTokens::opaque(offset)
};
*self.cache_slot = CacheSlot::Realised {
cache,
draft_cache: turn.draft_cache,
cached,
};
}
Some(Driver::Speculative(mut turn)) => {
if !self.detok_finalized {
let tail = turn.iter.finalize_tail();
self.detok_finalized = true;
self.reply.push_str(&tail);
}
*self.cache_slot = CacheSlot::SpeculativeSpent;
}
None => return,
}
self
.history
.push(ChatMessage::assistant(std::mem::take(&mut self.reply)));
}
}
impl Iterator for ChatResponseStream<'_> {
    type Item = Result<GenerationResponse>;

    /// Produces the next chunk of the reply.
    ///
    /// Returns `None` once the stream has finished; after the first `None`
    /// or error every later call also returns `None`.
    ///
    /// * Speculative turns forward the responses of the underlying stream and
    ///   stop on the first one that carries a finish reason.
    /// * Standard turns pull one token per call. An EOS token ends the stream
    ///   with [`FinishReason::Eos`] and is not detokenized. Reaching
    ///   `max_tokens` ends it with [`FinishReason::Length`]. In both cases
    ///   the detokenizer is flushed and its tail is included in the final
    ///   chunk.
    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        match self.driver.as_mut() {
            Some(Driver::Speculative(turn)) => match turn.iter.next() {
                Some(Ok(spec)) => {
                    let response = spec.response;
                    self.produced = response.generation_tokens;
                    if !response.text.is_empty() {
                        self.reply.push_str(&response.text);
                    }
                    if response.finish_reason.is_some() {
                        self.finished = true;
                        self.detok_finalized = true;
                    }
                    Some(Ok(response))
                }
                Some(Err(e)) => {
                    self.finished = true;
                    Some(Err(e))
                }
                None => {
                    self.finished = true;
                    self.detok_finalized = true;
                    None
                }
            },
            Some(Driver::Standard(turn)) => {
                let step = match turn.generator.next() {
                    Some(Ok(step)) => step,
                    Some(Err(e)) => {
                        self.finished = true;
                        return Some(Err(e));
                    }
                    None => {
                        self.finished = true;
                        return None;
                    }
                };
                let token = step.token;
                self.generated.push(token);

                // Work out the chunk text, the reported token count and
                // whether this chunk ends the stream.
                let (text, generation_tokens, finish_reason) = if self.eos.contains(&token) {
                    // EOS is counted in the reported total but never fed to
                    // the detokenizer; only the pending tail is flushed.
                    self.detok.finalize();
                    let tail = self.detok.last_segment();
                    (tail, self.produced + 1, Some(FinishReason::Eos))
                } else {
                    self.detok.add_token(token);
                    self.produced += 1;
                    let text = self.detok.last_segment();
                    if self.produced >= self.max_tokens {
                        self.detok.finalize();
                        let tail = self.detok.last_segment();
                        (
                            format!("{text}{tail}"),
                            self.produced,
                            Some(FinishReason::Length),
                        )
                    } else {
                        (text, self.produced, None)
                    }
                };

                if finish_reason.is_some() {
                    self.finished = true;
                    self.detok_finalized = true;
                }
                self.reply.push_str(&text);

                Some(Ok(GenerationResponse {
                    text,
                    token,
                    logprobs: step.logprobs,
                    prompt_tokens: self.prompt_tokens,
                    prompt_tps: 0.0,
                    generation_tokens,
                    generation_tps: 0.0,
                    peak_memory_bytes: crate::memory::peak_memory().ok(),
                    finish_reason,
                }))
            }
            None => {
                self.finished = true;
                None
            }
        }
    }
}
/// Commits the turn when the stream goes out of scope, whether or not it was read to the end.
impl Drop for ChatResponseStream<'_> {
fn drop(&mut self) {
// `commit` guards against running twice.
self.commit();
}
}
#[cfg(test)]
mod tests;