use crate::error::{Error, Result};
use crate::model::{DefaultMultiscreenModel, ModelInferenceConfig, MultiscreenModelConfig};
use crate::runtime::{default_device, Device};
use std::fs;
use std::path::{Path, PathBuf};
/// Per-call settings for [`ChatModel`] generation.
///
/// Both fields are forwarded unchanged into the model's `ModelInferenceConfig`
/// by `ChatModel::generate` / `ChatModel::generate_stream`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GenerationConfig {
    /// Maximum number of new tokens requested from the model.
    pub max_new_tokens: usize,
    /// Padding token id handed to the model's inference routine.
    pub pad_token_id: u32,
}
impl Default for GenerationConfig {
fn default() -> Self {
Self {
max_new_tokens: 64,
pad_token_id: 0,
}
}
}
/// A Multiscreen model bundled with the device it was created on and the
/// configuration it was built from.
///
/// Construct one with `ChatModel::load`.
pub struct ChatModel {
    /// The model instance used for token generation.
    model: DefaultMultiscreenModel,
    /// Device the model was created on; handed to each inference call.
    device: Device,
    /// Configuration that was used to construct `model`.
    config: MultiscreenModelConfig,
}
impl ChatModel {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let checkpoint_path = path.as_ref();
let checkpoint_dir = checkpoint_path
.parent()
.ok_or_else(|| Error::Io(format!("cannot determine parent of {:?}", checkpoint_path)))?
.to_path_buf();
let config = match find_file(&[checkpoint_dir.join("config.json")]) {
Ok(config_path) => {
let json = fs::read_to_string(&config_path).map_err(|e| {
Error::Io(format!("failed to read {}: {e}", config_path.display()))
})?;
serde_json::from_str::<MultiscreenModelConfig>(&json).map_err(|e| {
Error::Serialization(format!("failed to parse {}: {e}", config_path.display()))
})?
}
Err(_) => {
MultiscreenModelConfig::preset_10m(8192, 512)
}
};
let device = default_device()?;
let mut model = DefaultMultiscreenModel::new(config.clone(), &device)?;
model.load_parameters(checkpoint_path)?;
Ok(Self {
model,
device,
config,
})
}
pub fn generate(&self, prompt: &[u32], config: GenerationConfig) -> Result<Vec<u32>> {
let inference_config = ModelInferenceConfig {
max_new_tokens: config.max_new_tokens,
pad_token_id: config.pad_token_id,
};
let output = self
.model
.infer_tokens(prompt, &inference_config, &self.device)?;
Ok(output.token_ids)
}
pub fn generate_stream(
&self,
prompt: &[u32],
config: GenerationConfig,
on_token: impl FnMut(u32, usize) -> bool,
) -> Result<Vec<u32>> {
let inference_config = ModelInferenceConfig {
max_new_tokens: config.max_new_tokens,
pad_token_id: config.pad_token_id,
};
let output =
self.model
.infer_tokens_stream(prompt, &inference_config, &self.device, on_token)?;
Ok(output.token_ids)
}
pub fn model(&self) -> &DefaultMultiscreenModel {
&self.model
}
pub fn config(&self) -> &MultiscreenModelConfig {
&self.config
}
pub fn device(&self) -> &Device {
&self.device
}
}
/// Returns the first entry of `candidates` that exists on disk, checked in order.
///
/// The check uses `Path::exists`, so a directory also counts as a hit.
///
/// # Errors
///
/// Returns `Error::Io` when no candidate exists. The message lists every
/// searched path, one per line.
fn find_file(candidates: &[PathBuf]) -> Result<PathBuf> {
    if let Some(hit) = candidates.iter().find(|path| path.exists()) {
        return Ok(hit.clone());
    }

    let mut searched: Vec<String> = Vec::with_capacity(candidates.len());
    for path in candidates {
        searched.push(format!(" {}", path.display()));
    }
    let descriptions = searched.join("\n");

    Err(Error::Io(format!(
        "file not found; searched:\n{descriptions}"
    )))
}