use crate::{error::LlamaCoreError, utils::set_tensor_data_u8, BaseMetadata};
use wasmedge_wasi_nn::{
Error as WasiNnError, Graph as WasiNnGraph, GraphExecutionContext, TensorType,
};
#[derive(Debug)]
pub struct GraphBuilder<M: BaseMetadata + serde::Serialize + Clone + Default> {
metadata: Option<M>,
wasi_nn_graph_builder: wasmedge_wasi_nn::GraphBuilder,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> GraphBuilder<M> {
pub fn new(ty: EngineType) -> Result<Self, LlamaCoreError> {
let encoding = match ty {
EngineType::Ggml => wasmedge_wasi_nn::GraphEncoding::Ggml,
EngineType::Whisper => wasmedge_wasi_nn::GraphEncoding::Whisper,
EngineType::Piper => wasmedge_wasi_nn::GraphEncoding::Piper,
};
let wasi_nn_graph_builder =
wasmedge_wasi_nn::GraphBuilder::new(encoding, wasmedge_wasi_nn::ExecutionTarget::AUTO);
Ok(Self {
metadata: None,
wasi_nn_graph_builder,
})
}
pub fn with_config(mut self, metadata: M) -> Result<Self, LlamaCoreError> {
let config = serde_json::to_string(&metadata).map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.config(config);
self.metadata = Some(metadata.clone());
Ok(self)
}
pub fn use_cpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.cpu();
self
}
pub fn use_gpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.gpu();
self
}
pub fn use_tpu(mut self) -> Self {
self.wasi_nn_graph_builder = self.wasi_nn_graph_builder.tpu();
self
}
pub fn build_from_buffer<B>(
self,
bytes_array: impl AsRef<[B]>,
) -> Result<Graph<M>, LlamaCoreError>
where
B: AsRef<[u8]>,
{
let graph = self
.wasi_nn_graph_builder
.build_from_bytes(bytes_array)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now();
Ok(Graph {
created,
metadata: self.metadata.clone().unwrap_or_default(),
graph,
context,
})
}
pub fn build_from_files<P>(self, files: impl AsRef<[P]>) -> Result<Graph<M>, LlamaCoreError>
where
P: AsRef<std::path::Path>,
{
let graph = self
.wasi_nn_graph_builder
.build_from_files(files)
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now();
Ok(Graph {
created,
metadata: self.metadata.clone().unwrap_or_default(),
graph,
context,
})
}
pub fn build_from_cache(self) -> Result<Graph<M>, LlamaCoreError> {
match &self.metadata {
Some(metadata) => {
let graph = self
.wasi_nn_graph_builder
.build_from_cache(metadata.model_alias())
.map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let context = graph.init_execution_context().map_err(|e| {
let err_msg = e.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
LlamaCoreError::Operation(err_msg)
})?;
let created = std::time::SystemTime::now();
Ok(Graph {
created,
metadata: metadata.clone(),
graph,
context,
})
}
None => {
let err_msg =
"Failed to create a Graph from cache. Reason: Metadata is not provided."
.to_string();
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", &err_msg);
Err(LlamaCoreError::Operation(err_msg))
}
}
}
}
/// A wasi-nn graph paired with its execution context and the metadata it was
/// built from.
///
/// NOTE(review): the trait bounds are repeated on this struct (rather than
/// living only on the `impl` blocks) because the `Drop` impl below must carry
/// the same bounds, and `Drop` bounds must be implied by the struct definition.
#[derive(Debug)]
pub struct Graph<M: BaseMetadata + serde::Serialize + Clone + Default> {
    // Timestamp taken when the graph was successfully created.
    pub created: std::time::SystemTime,
    // Metadata used to configure/build this graph.
    pub metadata: M,
    // Underlying wasi-nn graph handle; unloaded in `Drop`.
    graph: WasiNnGraph,
    // Execution context used by `set_input`/`compute`/`get_output`.
    context: GraphExecutionContext,
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Graph<M> {
    /// Log `e` (when the `logging` feature is enabled) and wrap it into
    /// `LlamaCoreError::Operation`. Shared by the fallible steps below.
    fn op_err(e: impl ToString) -> LlamaCoreError {
        let err_msg = e.to_string();
        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);
        LlamaCoreError::Operation(err_msg)
    }

    /// Build a GGML graph from the model cache keyed by
    /// `metadata.model_alias()`, using the JSON-serialized metadata as the
    /// graph config, and initialize its execution context.
    ///
    /// # Errors
    /// `LlamaCoreError::Operation` if the metadata cannot be serialized, the
    /// graph cannot be built, or the execution context cannot be initialized.
    pub fn new(metadata: M) -> Result<Self, LlamaCoreError> {
        let config = serde_json::to_string(&metadata).map_err(Self::op_err)?;
        let graph = wasmedge_wasi_nn::GraphBuilder::new(
            wasmedge_wasi_nn::GraphEncoding::Ggml,
            wasmedge_wasi_nn::ExecutionTarget::AUTO,
        )
        .config(config)
        .build_from_cache(metadata.model_alias())
        .map_err(Self::op_err)?;
        let context = graph.init_execution_context().map_err(Self::op_err)?;
        Ok(Self {
            created: std::time::SystemTime::now(),
            // `metadata` is owned by this constructor; no clone needed.
            metadata,
            graph,
            context,
        })
    }

    /// Model name, as reported by the metadata.
    pub fn name(&self) -> &str {
        self.metadata.model_name()
    }

    /// Model alias, as reported by the metadata.
    pub fn alias(&self) -> &str {
        self.metadata.model_alias()
    }

    /// Re-serialize `self.metadata` to JSON and write it into the backend via
    /// `set_tensor_data_u8` at tensor index 1.
    ///
    /// # Errors
    /// `LlamaCoreError::Operation` if serialization fails, or whatever error
    /// `set_tensor_data_u8` returns if the tensor update fails.
    pub fn update_metadata(&mut self) -> Result<(), LlamaCoreError> {
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Update metadata for the model named {}", self.name());
        let config = serde_json::to_string(&self.metadata).map_err(|e| {
            let err_msg = format!("Failed to update metadata. Reason: Fail to serialize metadata to a JSON string. {e}");
            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);
            LlamaCoreError::Operation(err_msg)
        })?;
        set_tensor_data_u8(self, 1, config.as_bytes())?;
        // Only report success once the tensor update actually succeeded.
        #[cfg(feature = "logging")]
        info!(target: "stdout", "Metadata updated successfully.");
        Ok(())
    }

    /// Set the data of the input tensor at `index`.
    /// Delegates to `GraphExecutionContext::set_input`.
    pub fn set_input<T: Sized>(
        &mut self,
        index: usize,
        tensor_type: TensorType,
        dimensions: &[usize],
        data: impl AsRef<[T]>,
    ) -> Result<(), WasiNnError> {
        self.context.set_input(index, tensor_type, dimensions, data)
    }

    /// Run the computation. Delegates to `GraphExecutionContext::compute`.
    pub fn compute(&mut self) -> Result<(), WasiNnError> {
        self.context.compute()
    }

    /// Run a single-step computation. Delegates to
    /// `GraphExecutionContext::compute_single`.
    pub fn compute_single(&mut self) -> Result<(), WasiNnError> {
        self.context.compute_single()
    }

    /// Copy the output at `index` into `out_buffer`, returning the size
    /// reported by wasi-nn. Delegates to `GraphExecutionContext::get_output`.
    pub fn get_output<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output(index, out_buffer)
    }

    /// Copy the single-step output at `index` into `out_buffer`. Delegates to
    /// `GraphExecutionContext::get_output_single`.
    pub fn get_output_single<T: Sized>(
        &self,
        index: usize,
        out_buffer: &mut [T],
    ) -> Result<usize, WasiNnError> {
        self.context.get_output_single(index, out_buffer)
    }

    /// Finish a single-step computation. Delegates to
    /// `GraphExecutionContext::fini_single`.
    pub fn finish_single(&mut self) -> Result<(), WasiNnError> {
        self.context.fini_single()
    }
}
impl<M: BaseMetadata + serde::Serialize + Clone + Default> Drop for Graph<M> {
    // Unload the underlying wasi-nn graph when the wrapper is dropped.
    // A failure is reported (never panicked on): via `error!` when the
    // `logging` feature is enabled, otherwise on stderr.
    fn drop(&mut self) {
        match self.graph.unload() {
            Ok(()) => {}
            Err(e) => {
                let err_msg = format!("Failed to unload the wasi-nn graph. Reason: {e}");
                #[cfg(feature = "logging")]
                error!(target: "stdout", "{err_msg}");
                #[cfg(not(feature = "logging"))]
                eprintln!("{}", err_msg);
            }
        }
    }
}
/// Backend engine a `GraphBuilder` can target; each variant maps 1:1 onto a
/// `wasmedge_wasi_nn::GraphEncoding` in `GraphBuilder::new`.
///
/// Fieldless, so it additionally derives `Copy` and `Hash`. Variant order is
/// preserved because the derived `Ord` depends on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EngineType {
    /// Maps to `GraphEncoding::Ggml`.
    Ggml,
    /// Maps to `GraphEncoding::Whisper`.
    Whisper,
    /// Maps to `GraphEncoding::Piper`.
    Piper,
}