use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, Metadata};
use chat_prompts::PromptTemplateType;
#[cfg(feature = "logging")]
use log::{error, info};
use wasmedge_wasi_nn::{
Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
};
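/// Builder for constructing a [`Graph`] on top of a `wasmedge_wasi_nn` graph.
///
/// The builder is created for an [`EngineType`], optionally configured with
/// [`Metadata`] and an execution target, and then consumed by one of the
/// `build_from_*` methods.
///
/// # Example
///
/// A minimal sketch; the `metadata` value is assumed to be prepared by the
/// caller:
///
/// ```ignore
/// let graph = GraphBuilder::new(EngineType::Ggml)?
///     .with_config(&metadata)?
///     .use_gpu()
///     .build_from_cache()?;
/// ```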
#[derive(Debug)]
pub struct GraphBuilder {
metadata: Option<Metadata>,
wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
}
impl GraphBuilder {
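    /// Creates a builder for the given engine type.
    ///
    /// The execution target defaults to `ExecutionTarget::AUTO`, letting the
    /// runtime pick a device; use [`use_cpu`](Self::use_cpu),
    /// [`use_gpu`](Self::use_gpu), or [`use_tpu`](Self::use_tpu) to pin one.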
pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
let encoding = match ty {
EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
};
let wasi_nn_graph_builder =
wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);
Ok(Self {
metadata: None,
wasi_nn_graph_builder,
})
}
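    /// Serializes `metadata` to a JSON string and installs it as the graph
    /// config. The metadata is also retained so the built [`Graph`] can
    /// report its model name, alias, and prompt template.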
pub fn with_config(mut self, metadata: &Metadata) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(metadata).map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
self.metadata = Some(metadata.clone());
Ok(self)
}
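    /// Pins execution to the CPU.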
pub fn use_cpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
self
}
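    /// Pins execution to the GPU.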
pub fn use_gpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
self
}
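    /// Pins execution to the TPU.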
pub fn use_tpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
self
}
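    /// Builds a [`Graph`] from one or more in-memory model buffers and
    /// initializes an execution context for it.
    ///
    /// # Example
    ///
    /// A minimal sketch; the file name is an assumption:
    ///
    /// ```ignore
    /// let bytes = std::fs::read("model.gguf").expect("failed to read model");
    /// let graph = GraphBuilder::new(EngineType::Ggml)?.build_from_buffer([bytes])?;
    /// ```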
pub fn build_from_buffer<B>(self, bytes_array: impl AsRef<[B]>) -> Result<Graph, LlamaCoreError>
where
B: AsRef<[u8]>,
{
let graph = self
.wasi_nn_graph_builder
.build_from_bytes(bytes_array)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
Ok(Graph {
created,
            metadata: self.metadata.unwrap_or_default(),
_graph: graph,
context,
})
}
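    /// Builds a [`Graph`] from one or more model files on disk and
    /// initializes an execution context for it.
    ///
    /// # Example
    ///
    /// A minimal sketch; the file name is an assumption:
    ///
    /// ```ignore
    /// let graph = GraphBuilder::new(EngineType::Whisper)?
    ///     .build_from_files(["ggml-base.bin"])?;
    /// ```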
pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph, LlamaCoreError>
where
P: AsRef<std::path::Path>,
{
let graph = self
.wasi_nn_graph_builder
.build_from_files(files)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
Ok(Graph {
created,
            metadata: self.metadata.unwrap_or_default(),
_graph: graph,
context,
})
}
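    /// Builds a [`Graph`] from a model preloaded into the wasi-nn cache,
    /// looked up by the `model_alias` recorded in the metadata.
    ///
    /// Returns an error if [`with_config`](Self::with_config) was not called
    /// first, since the alias is taken from the metadata.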
pub fn build_from_cache(self) -> Result<Graph, LlamaCoreError> {
match &self.metadata {
Some(metadata) => {
let graph = self
.wasi_nn_graph_builder
.build_from_cache(&metadata.model_alias)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
Ok(Graph {
created,
metadata: metadata.clone(),
_graph: graph,
context,
})
}
None => {
let err_msg =
"Failed to create a Graph from cache. Reason: Metadata is not provided."
.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg))
}
}
}
}
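/// A loaded wasi-nn graph together with its execution context, its
/// [`Metadata`], and its creation time as a duration since the Unix epoch.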
#[derive(Debug)]
pub struct Graph {
pub created: std::time::Duration,
pub metadata: Metadata,
_graph: WasiNnGraph,
context: GraphExecutionContext,
}
impl Graph {
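    /// Creates a GGML-backed `Graph` from the wasi-nn cache, using the
    /// `model_alias` in `metadata`.
    ///
    /// This is a shorthand for
    /// `GraphBuilder::new(EngineType::Ggml)?.with_config(metadata)?.build_from_cache()`.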
pub fn new(metadata: &Metadata) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(metadata).map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let graph = wasmedge_wasi_nn::GraphBuilder::new(
wasmedge_wasi_nn::GraphEncoding::Ggml,
wasmedge_wasi_nn::ExecutionTarget::AUTO,
)
.config(config)
.build_from_cache(&metadata.model_alias)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
Ok(Self {
created,
metadata: metadata.clone(),
_graph: graph,
context,
})
}
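    /// Returns the name of the model.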
pub fn name(&self) -> &str {
&self.metadata.model_name
}
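    /// Returns the alias of the model.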
pub fn alias(&self) -> &str {
&self.metadata.model_alias
}
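    /// Returns the prompt template type of the model.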
pub fn prompt_template(&self) -> PromptTemplateType {
self.metadata.prompt_template
}
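    /// Re-serializes the current `metadata` to JSON and writes it to input
    /// tensor index 1, the slot this crate uses for config updates.
    ///
    /// # Example
    ///
    /// A minimal sketch; `temperature` is a hypothetical metadata field:
    ///
    /// ```ignore
    /// graph.metadata.temperature = 0.8; // hypothetical field
    /// graph.update_metadata()?;
    /// ```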
pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
#[cfg(feature = "logging")]
info!(target: "stdout", "Update metadata for the model named {}", self.name());
let config = match serde_json::to_string(&self.metadata) {
Ok(config) => config,
Err(e) => {
let err_msg = format!("Failed to update metadta. Reason: Fail to serialize metadata to a JSON string. {}", e);
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
return Err(LlamaCoreError::Operation(err_msg));
}
};
        // Propagate any failure before logging success; the previous version
        // logged "updated successfully" even when the write failed.
        set_tensor_data_u8(self, 1, config.as_bytes())?;
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Metadata updated successfully.");
        Ok(())
}
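    /// Sets the input tensor at `index`, forwarding to the underlying
    /// wasi-nn execution context.
    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the GGML backend's convention of reading
    /// the prompt from input tensor 0 as UTF-8 bytes:
    ///
    /// ```ignore
    /// let prompt = "Hello, world!";
    /// graph.set_input(0, TensorType::U8, &[1], prompt.as_bytes())?;
    /// ```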
pub fn set_input<T: Sized>(
&mut self,
index: usize,
tensor_type: TensorType,
dimensions: &[usize],
data: impl AsRef<[T]>,
) -> Result<(), WasiNnError> {
self.context.set_input(index, tensor_type, dimensions, data)
}
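    /// Runs inference over the inputs set so far.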
pub fn compute(&mut self) -> Result<(), WasiNnError> {
self.context.compute()
}
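    /// Runs a single step of a token-by-token (streaming) computation.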
pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
self.context.compute_single()
}
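    /// Copies the output tensor at `index` into `out_buffer`, returning the
    /// number of bytes written.
    ///
    /// # Example
    ///
    /// A minimal sketch; the buffer size and UTF-8 decoding are assumptions:
    ///
    /// ```ignore
    /// let mut buffer = vec![0u8; 4096];
    /// let size = graph.get_output(0, &mut buffer)?;
    /// let text = String::from_utf8_lossy(&buffer[..size]);
    /// ```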
pub fn get_output<T: Sized>(
&self,
index: usize,
out_buffer: &mut [T],
) -> Result<usize, WasiNnError> {
self.context.get_output(index, out_buffer)
}
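    /// Copies the output of the latest streaming step at `index` into
    /// `out_buffer`, returning the number of bytes written.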
pub fn get_output_single<T: Sized>(
&self,
index: usize,
out_buffer: &mut [T],
) -> Result<usize, WasiNnError> {
self.context.get_output_single(index, out_buffer)
}
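    /// Ends a token-by-token (streaming) computation and clears the
    /// backend's single-step state.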
pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
self.context.fini_single()
}
}
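/// The inference engine backing a [`Graph`]: GGML (llama.cpp) for text
/// generation, Whisper for speech-to-text, or Piper for text-to-speech.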
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum EngineType {
Ggml,
Whisper,
Piper,
}