use std::path::{Path, PathBuf};
use std::sync::Arc;
use mistralrs::{
EmbeddingModelBuilder, EmbeddingRequest, Model, ModelBuilder, TextMessageRole, TextMessages,
};
use sha2::{Digest, Sha256};
use tokio::runtime::{Handle, Runtime};
use crate::api::{EmbeddingEngine, EngineError, InferenceEngine};
/// Engine backed by a single `mistralrs` [`Model`], exposed through the synchronous
/// [`InferenceEngine`] / [`EmbeddingEngine`] traits.
///
/// The async `mistralrs` calls are driven on a private multi-threaded Tokio runtime owned by
/// the engine (see `run`).
pub struct MistralEngine {
    model: Arc<Model>,
    /// SHA-256 of `source`; `infer` rejects requests whose id does not match this value.
    model_id: [u8; 32],
    /// The Hugging Face repo id or filesystem path the model was loaded from.
    source: String,
    /// Private runtime used to drive model futures from synchronous trait methods.
    rt: DispatchRuntime,
}

/// Owns the engine's dispatch runtime and shuts it down without blocking when dropped.
///
/// A plain `Drop` of a Tokio [`Runtime`] waits for its worker threads, and Tokio panics if that
/// happens inside an asynchronous context. All of `MistralEngine`'s constructors are `async`,
/// so the engine is likely to be dropped from async code. Shutting down with
/// [`Runtime::shutdown_background`] avoids that panic.
struct DispatchRuntime(Option<Runtime>);

impl std::ops::Deref for DispatchRuntime {
    type Target = Runtime;

    fn deref(&self) -> &Runtime {
        // The runtime is only taken out in `Drop`, after which no `&self` access is possible.
        self.0
            .as_ref()
            .expect("dispatch runtime is only taken during drop")
    }
}

impl Drop for DispatchRuntime {
    fn drop(&mut self) {
        if let Some(rt) = self.0.take() {
            rt.shutdown_background();
        }
    }
}

impl MistralEngine {
    /// Loads a chat model from a Hugging Face repository id.
    ///
    /// # Errors
    /// Fails if the model cannot be built or the dispatch runtime cannot be created.
    pub async fn from_hf_repo(repo: impl Into<String>) -> anyhow::Result<Self> {
        let source = repo.into();
        let model = ModelBuilder::new(&source).build().await?;
        Self::wrap(Arc::new(model), source)
    }

    /// Loads a chat model from a local filesystem path.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - The path is not valid UTF-8. The builder takes a string id, and a lossy conversion
    ///   would silently hand it a different, mangled path.
    /// - The model cannot be built.
    /// - The dispatch runtime cannot be created.
    pub async fn from_model_path(path: impl AsRef<Path>) -> anyhow::Result<Self> {
        let path: PathBuf = path.as_ref().to_path_buf();
        let source = path
            .to_str()
            .ok_or_else(|| anyhow::anyhow!("model path is not valid UTF-8: {}", path.display()))?
            .to_owned();
        let model = ModelBuilder::new(&source).build().await?;
        Self::wrap(Arc::new(model), source)
    }

    /// Loads an embedding model from a Hugging Face repository id.
    ///
    /// # Errors
    /// Fails if the model cannot be built or the dispatch runtime cannot be created.
    pub async fn embedding_from_hf_repo(repo: impl Into<String>) -> anyhow::Result<Self> {
        let source = repo.into();
        let model = EmbeddingModelBuilder::new(&source).build().await?;
        Self::wrap(Arc::new(model), source)
    }

    /// Wraps an already-built model.
    ///
    /// Derives the model id from `source` via [`Self::hash_source`] and creates the private
    /// dispatch runtime, whose threads are named `hanzo-engine-dispatch`.
    ///
    /// # Errors
    /// Fails if the Tokio runtime cannot be built.
    pub fn wrap(model: Arc<Model>, source: impl Into<String>) -> anyhow::Result<Self> {
        let source = source.into();
        let model_id = Self::hash_source(&source);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .thread_name("hanzo-engine-dispatch")
            .build()?;
        Ok(Self {
            model,
            model_id,
            source,
            rt: DispatchRuntime(Some(rt)),
        })
    }

    /// Returns the SHA-256 digest of `source`'s UTF-8 bytes; this is the engine's model id.
    pub fn hash_source(source: &str) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(source.as_bytes());
        hasher.finalize().into()
    }

    /// The model id (SHA-256 of [`Self::source`]).
    pub fn model_id(&self) -> &[u8; 32] {
        &self.model_id
    }

    /// The repo id or path string the engine was created from.
    pub fn source(&self) -> &str {
        &self.source
    }

    /// The wrapped `mistralrs` model.
    pub fn model(&self) -> &Arc<Model> {
        &self.model
    }
}
impl MistralEngine {
fn run<F, T>(&self, fut: F) -> T
where
F: std::future::Future<Output = T> + Send + 'static,
T: Send + 'static,
{
match Handle::try_current() {
Err(_) => self.rt.block_on(fut),
Ok(_) => {
let (tx, rx) = std::sync::mpsc::channel();
self.rt.spawn(async move {
let _ = tx.send(fut.await);
});
rx.recv().expect("engine runtime task panicked")
}
}
}
}
impl InferenceEngine for MistralEngine {
    /// Sends `prompt` as a single user message and returns the first choice's text as bytes.
    ///
    /// # Errors
    /// - [`EngineError::ModelNotFound`] (carrying the hex-encoded id) when `model_id` is not
    ///   this engine's id.
    /// - [`EngineError::Other`] in any of these cases:
    ///   - the prompt is not valid UTF-8;
    ///   - the chat request fails;
    ///   - the response has no choice with content.
    fn infer(&self, model_id: &[u8; 32], prompt: &[u8]) -> Result<Vec<u8>, EngineError> {
        if *model_id != self.model_id {
            return Err(EngineError::ModelNotFound(hex_id(model_id)));
        }

        let prompt_text = match std::str::from_utf8(prompt) {
            Ok(text) => text.to_owned(),
            Err(e) => return Err(EngineError::Other(format!("prompt is not UTF-8: {e}"))),
        };

        // Clone the handle so the spawned future owns everything it touches ('static).
        let engine_model = Arc::clone(&self.model);
        let completion = self.run(async move {
            let chat = TextMessages::new().add_message(TextMessageRole::User, prompt_text);
            let reply = match engine_model.send_chat_request(chat).await {
                Ok(reply) => reply,
                Err(e) => return Err(EngineError::Other(format!("chat request failed: {e}"))),
            };
            let first_content = reply
                .choices
                .into_iter()
                .next()
                .and_then(|choice| choice.message.content);
            match first_content {
                Some(text) => Ok(text),
                None => Err(EngineError::Other(
                    "model returned no completion content".into(),
                )),
            }
        })?;

        Ok(completion.into_bytes())
    }
}
impl EmbeddingEngine for MistralEngine {
    /// Embeds `text` with the wrapped model and returns the first vector of the response.
    ///
    /// # Errors
    /// Returns [`EngineError::Other`] in any of these cases:
    /// - `text` is not valid UTF-8;
    /// - the embedding request fails;
    /// - the response contains no vectors;
    /// - the returned vector's length differs from `dim`.
    fn embed(&self, dim: usize, text: &[u8]) -> Result<Vec<f32>, EngineError> {
        let owned_text = match std::str::from_utf8(text) {
            Ok(s) => s.to_owned(),
            Err(e) => return Err(EngineError::Other(format!("text is not UTF-8: {e}"))),
        };

        let engine_model = Arc::clone(&self.model);
        let embedding = self.run(async move {
            let request = EmbeddingRequest::builder().add_prompt(owned_text);
            let mut batch = match engine_model.generate_embeddings(request).await {
                Ok(vectors) => vectors.into_iter(),
                Err(e) => {
                    return Err(EngineError::Other(format!(
                        "embedding request failed: {e}"
                    )))
                }
            };
            batch
                .next()
                .ok_or_else(|| EngineError::Other("embedding response was empty".into()))
        })?;

        let returned = embedding.len();
        if returned == dim {
            Ok(embedding)
        } else {
            Err(EngineError::Other(format!(
                "embedding dim mismatch: requested {dim}, model returned {}",
                returned
            )))
        }
    }
}
/// Formats a 32-byte id as 64 lowercase hexadecimal characters, high nibble first.
fn hex_id(id: &[u8; 32]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
    id.iter()
        .flat_map(|&byte| {
            [
                char::from(HEX_DIGITS[usize::from(byte >> 4)]),
                char::from(HEX_DIGITS[usize::from(byte & 0x0f)]),
            ]
        })
        .collect()
}