#[cfg(feature = "logging")]
#[macro_use]
extern crate log;
pub mod audio;
pub mod chat;
pub mod completions;
pub mod embeddings;
pub mod error;
pub mod files;
pub mod graph;
pub mod images;
pub mod metadata;
pub mod models;
pub mod rag;
#[cfg(feature = "search")]
pub mod search;
pub mod utils;
pub use error::LlamaCoreError;
pub use graph::{EngineType, Graph, GraphBuilder};
pub use metadata::{
ggml::GgmlMetadata, piper::PiperMetadata, whisper::WhisperMetadata, BaseMetadata,
};
use once_cell::sync::OnceCell;
use std::{
collections::HashMap,
path::Path,
sync::{Mutex, RwLock},
};
use utils::get_output_buffer;
use wasmedge_stable_diffusion::*;
// Process-wide singletons. Each is set at most once by the corresponding
// `init_*` function in this module and read by the sibling modules.

/// Chat-completion graphs, keyed by graph (model) name.
pub(crate) static CHAT_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
OnceCell::new();
/// Embedding graphs, keyed by graph (model) name.
pub(crate) static EMBEDDING_GRAPHS: OnceCell<Mutex<HashMap<String, Graph<GgmlMetadata>>>> =
OnceCell::new();
/// Byte buffer of cached UTF-8 encodings.
/// NOTE(review): presumably holds partial UTF-8 sequences between decode calls — confirm against `utils`.
pub(crate) static CACHED_UTF8_ENCODINGS: OnceCell<Mutex<Vec<u8>>> = OnceCell::new();
/// The running mode of the crate, set during context initialization.
pub(crate) static RUNNING_MODE: OnceCell<RwLock<RunningMode>> = OnceCell::new();
/// Stable-diffusion text-to-image context.
pub(crate) static SD_TEXT_TO_IMAGE: OnceCell<Mutex<TextToImage>> = OnceCell::new();
/// Stable-diffusion image-to-image context.
pub(crate) static SD_IMAGE_TO_IMAGE: OnceCell<Mutex<ImageToImage>> = OnceCell::new();
/// Whisper (audio) graph.
pub(crate) static AUDIO_GRAPH: OnceCell<Mutex<Graph<WhisperMetadata>>> = OnceCell::new();
/// Piper (text-to-speech) graph.
pub(crate) static PIPER_GRAPH: OnceCell<Mutex<Graph<PiperMetadata>>> = OnceCell::new();
/// Maximum buffer size in bytes.
/// NOTE(review): presumably bounds output buffers read back from a graph — confirm against `utils::get_output_buffer`.
pub(crate) const MAX_BUFFER_SIZE: usize = 2usize.pow(14) * 15 + 128;
/// Output tensor index.
/// NOTE(review): presumably the wasi-nn output slot for inference results — confirm against `utils`.
pub(crate) const OUTPUT_TENSOR: usize = 0;
/// Buffer index passed to `get_output_buffer` to retrieve the plugin version metadata
/// (see `get_plugin_info_by_graph`).
const PLUGIN_VERSION: usize = 1;
/// Name of the directory used for archives.
pub const ARCHIVES_DIR: &str = "archives";
pub fn init_ggml_context(
metadata_for_chats: Option<&[GgmlMetadata]>,
metadata_for_embeddings: Option<&[GgmlMetadata]>,
) -> Result<(), LlamaCoreError> {
#[cfg(feature = "logging")]
info!(target: "stdout", "Initializing the core context");
if metadata_for_chats.is_none() && metadata_for_embeddings.is_none() {
let err_msg = "Failed to initialize the core context. Please set metadata for chat completions and/or embeddings.";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", err_msg);
return Err(LlamaCoreError::InitContext(err_msg.into()));
}
let mut mode = RunningMode::Embeddings;
if let Some(metadata_chats) = metadata_for_chats {
let mut chat_graphs = HashMap::new();
for metadata in metadata_chats {
let graph = Graph::new(metadata.clone())?;
chat_graphs.insert(graph.name().to_string(), graph);
}
CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", err_msg);
LlamaCoreError::InitContext(err_msg.into())
})?;
mode = RunningMode::Chat
}
if let Some(metadata_embeddings) = metadata_for_embeddings {
let mut embedding_graphs = HashMap::new();
for metadata in metadata_embeddings {
let graph = Graph::new(metadata.clone())?;
embedding_graphs.insert(graph.name().to_string(), graph);
}
EMBEDDING_GRAPHS
.set(Mutex::new(embedding_graphs))
.map_err(|_| {
let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", err_msg);
LlamaCoreError::InitContext(err_msg.into())
})?;
if mode == RunningMode::Chat {
mode = RunningMode::ChatEmbedding;
}
}
#[cfg(feature = "logging")]
info!(target: "stdout", "running mode: {}", mode);
RUNNING_MODE.set(RwLock::new(mode)).map_err(|_| {
let err_msg = "Failed to initialize the core context. Reason: The `RUNNING_MODE` has already been initialized";
#[cfg(feature = "logging")]
error!(target: "stdout", "{}", err_msg);
LlamaCoreError::InitContext(err_msg.into())
})?;
#[cfg(feature = "logging")]
info!(target: "stdout", "The core context has been initialized");
Ok(())
}
/// Initialize the core (ggml) context for RAG scenarios.
///
/// # Arguments
///
/// * `metadata_for_chats` - Metadata for building the chat-completion graphs. Must be non-empty.
///
/// * `metadata_for_embeddings` - Metadata for building the embedding graphs. Must be non-empty.
pub fn init_ggml_rag_context(
    metadata_for_chats: &[GgmlMetadata],
    metadata_for_embeddings: &[GgmlMetadata],
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the core context for RAG scenarios");

    // Validate both inputs before mutating any global state. (Previously the
    // embeddings check ran only after `CHAT_GRAPHS` had been set, so a
    // validation failure left the context half-initialized.)
    if metadata_for_chats.is_empty() {
        let err_msg = "The metadata for chat models is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }
    if metadata_for_embeddings.is_empty() {
        let err_msg = "The metadata for embeddings is empty";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    // Create and register the chat-completion graphs, keyed by graph name.
    let mut chat_graphs = HashMap::new();
    for metadata in metadata_for_chats {
        let graph = Graph::new(metadata.clone())?;
        chat_graphs.insert(graph.name().to_string(), graph);
    }
    CHAT_GRAPHS.set(Mutex::new(chat_graphs)).map_err(|_| {
        let err_msg = "Failed to initialize the core context. Reason: The `CHAT_GRAPHS` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::InitContext(err_msg.into())
    })?;

    // Create and register the embedding graphs, keyed by graph name.
    let mut embedding_graphs = HashMap::new();
    for metadata in metadata_for_embeddings {
        let graph = Graph::new(metadata.clone())?;
        embedding_graphs.insert(graph.name().to_string(), graph);
    }
    EMBEDDING_GRAPHS
        .set(Mutex::new(embedding_graphs))
        .map_err(|_| {
            let err_msg = "Failed to initialize the core context. Reason: The `EMBEDDING_GRAPHS` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

    // RAG scenarios always run in the dedicated `Rag` mode.
    let running_mode = RunningMode::Rag;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "running mode: {}", running_mode);

    RUNNING_MODE.set(RwLock::new(running_mode)).map_err(|_| {
        let err_msg = "Failed to initialize the core context. Reason: The `RUNNING_MODE` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::InitContext(err_msg.into())
    })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The core context for RAG scenarios has been initialized");

    Ok(())
}
/// Get the build number and commit id of the active wasi-nn plugin.
///
/// Probes the first available graph: the embedding graphs when running in
/// `Embeddings` mode, the chat graphs otherwise.
pub fn get_plugin_info() -> Result<PluginInfo, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Getting the plugin info");

    match running_mode()? {
        RunningMode::Embeddings => {
            let embedding_graphs = EMBEDDING_GRAPHS.get().ok_or_else(|| {
                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::Operation(err_msg.into())
            })?;

            let embedding_graphs = embedding_graphs.lock().map_err(|e| {
                let err_msg = format!("Fail to acquire the lock of `EMBEDDING_GRAPHS`. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // Any graph will do; they all report the same plugin metadata.
            let graph = embedding_graphs.values().next().ok_or_else(|| {
                let err_msg = "Fail to get the underlying value of `EMBEDDING_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::Operation(err_msg.into())
            })?;

            get_plugin_info_by_graph(graph)
        }
        _ => {
            let chat_graphs = CHAT_GRAPHS.get().ok_or_else(|| {
                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::Operation(err_msg.into())
            })?;

            let chat_graphs = chat_graphs.lock().map_err(|e| {
                let err_msg = format!("Fail to acquire the lock of `CHAT_GRAPHS`. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", &err_msg);

                LlamaCoreError::Operation(err_msg)
            })?;

            // Any graph will do; they all report the same plugin metadata.
            let graph = chat_graphs.values().next().ok_or_else(|| {
                let err_msg = "Fail to get the underlying value of `CHAT_GRAPHS`.";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                LlamaCoreError::Operation(err_msg.into())
            })?;

            get_plugin_info_by_graph(graph)
        }
    }
}
/// Extract the plugin build number (`llama_build_number`) and commit id
/// (`llama_commit`) from the metadata reported by the given graph.
fn get_plugin_info_by_graph<M: BaseMetadata + serde::Serialize + Clone + Default>(
    graph: &Graph<M>,
) -> Result<PluginInfo, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Getting the plugin info by the graph named {}", graph.name());

    // The plugin metadata is exposed through a dedicated output buffer slot.
    let output_buffer = get_output_buffer(graph, PLUGIN_VERSION)?;
    let metadata: serde_json::Value = serde_json::from_slice(&output_buffer[..]).map_err(|e| {
        let err_msg = format!("Fail to deserialize the plugin metadata. {}", e);

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", &err_msg);

        LlamaCoreError::Operation(err_msg)
    })?;

    // Pull the build number out of the metadata; missing field and bad type
    // are reported separately, in that order.
    let plugin_build_number = metadata
        .get("llama_build_number")
        .ok_or_else(|| {
            let err_msg = "Metadata does not have the field `llama_build_number`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::Operation(err_msg.into())
        })?
        .as_u64()
        .ok_or_else(|| {
            let err_msg = "Failed to convert the build number of the plugin to u64";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::Operation(err_msg.into())
        })?;

    // Same for the commit id.
    let plugin_commit = metadata
        .get("llama_commit")
        .ok_or_else(|| {
            let err_msg = "Metadata does not have the field `llama_commit`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::Operation(err_msg.into())
        })?
        .as_str()
        .ok_or_else(|| {
            let err_msg = "Failed to convert the commit id of the plugin to string";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::Operation(err_msg.into())
        })?;

    #[cfg(feature = "logging")]
    info!(target: "stdout", "Plugin info: b{}(commit {})", plugin_build_number, plugin_commit);

    Ok(PluginInfo {
        build_number: plugin_build_number,
        commit_id: plugin_commit.to_string(),
    })
}
/// Version information reported by the wasi-nn ggml plugin.
#[derive(Debug, Clone)]
pub struct PluginInfo {
    /// Build number reported by the plugin (metadata field `llama_build_number`).
    pub build_number: u64,
    /// Commit id reported by the plugin (metadata field `llama_commit`).
    pub commit_id: String,
}
impl std::fmt::Display for PluginInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "wasinn-ggml plugin: b{build}(commit {commit})",
            build = self.build_number,
            commit = self.commit_id
        )
    }
}
/// The running mode of the crate, decided by which graphs were initialized.
///
/// Derives `Copy` in addition to `Clone`: the enum is fieldless, so copying is
/// free, and callers such as `running_mode()` currently clone it explicitly.
/// Adding `Copy` is backward-compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum RunningMode {
    /// Only chat-completion graphs were initialized.
    Chat,
    /// Only embedding graphs were initialized.
    Embeddings,
    /// Both chat-completion and embedding graphs were initialized.
    ChatEmbedding,
    /// Chat and embedding graphs were initialized together for RAG scenarios.
    Rag,
}
impl std::fmt::Display for RunningMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RunningMode::Chat => write!(f, "chat"),
            RunningMode::Embeddings => write!(f, "embeddings"),
            RunningMode::ChatEmbedding => write!(f, "chat-embeddings"),
            RunningMode::Rag => write!(f, "rag"),
        }
    }
}
/// Return the current running mode.
///
/// # Errors
///
/// Returns `LlamaCoreError::Operation` when `RUNNING_MODE` has not been
/// initialized yet, or when its `RwLock` is poisoned.
pub fn running_mode() -> Result<RunningMode, LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Get the running mode.");

    let mode = match RUNNING_MODE.get() {
        Some(mode) => match mode.read() {
            // Clone the mode out of the read guard so the lock is released
            // before returning.
            Ok(mode) => mode.to_owned(),
            Err(e) => {
                let err_msg = format!("Fail to get the underlying value of `RUNNING_MODE`. {}", e);

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::Operation(err_msg));
            }
        },
        None => {
            let err_msg = "Fail to get the underlying value of `RUNNING_MODE`.";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            return Err(LlamaCoreError::Operation(err_msg.into()));
        }
    };

    #[cfg(feature = "logging")]
    info!(target: "stdout", "running mode: {}", &mode);

    // `mode` is already owned here; the previous `Ok(mode.to_owned())`
    // performed a redundant second clone.
    Ok(mode)
}
/// Initialize the stable diffusion context with the given full model.
///
/// # Arguments
///
/// * `model_file` - Path to the full stable diffusion model file.
///
/// * `lora_model_dir` - Optional directory containing LoRA models.
///
/// * `controlnet_path` - Optional path to the ControlNet model.
///
/// * `controlnet_on_cpu` - Keep ControlNet on CPU. Only honored when a
///   non-empty `controlnet_path` is given; otherwise forced to `false`.
///
/// * `clip_on_cpu` - Keep the CLIP model on CPU.
///
/// * `vae_on_cpu` - Keep the VAE model on CPU.
///
/// * `n_threads` - Number of threads to use.
///
/// * `task` - Which context(s) to create: text-to-image, image-to-image, or both.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_full_model(
    model_file: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the full model");

    // Shared error mapper for the builder steps below. Previously this closure
    // was duplicated verbatim at every `map_err` site.
    fn build_err(e: impl std::fmt::Display) -> LlamaCoreError {
        let err_msg = format!(
            "Failed to initialize the stable diffusion context. Reason: {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::InitContext(err_msg)
    }

    // ControlNet-on-CPU is only meaningful when a non-empty path was provided.
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new(Task::TextToImage, model_file.as_ref())
            .map_err(build_err)?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(build_err)?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(build_err)?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new(Task::ImageToImage, model_file.as_ref())
            .map_err(build_err)?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(build_err)?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(build_err)?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
/// Initialize the stable diffusion context with a standalone diffusion model
/// plus separately supplied VAE, CLIP-L, and T5-XXL model files.
///
/// # Arguments
///
/// * `model_file` - Path to the standalone diffusion model file.
///
/// * `vae` - Path to the VAE model file.
///
/// * `clip_l` - Path to the CLIP-L text encoder file.
///
/// * `t5xxl` - Path to the T5-XXL text encoder file.
///
/// * `lora_model_dir` - Optional directory containing LoRA models.
///
/// * `controlnet_path` - Optional path to the ControlNet model.
///
/// * `controlnet_on_cpu` - Keep ControlNet on CPU. Only honored when a
///   non-empty `controlnet_path` is given; otherwise forced to `false`.
///
/// * `clip_on_cpu` - Keep the CLIP model on CPU.
///
/// * `vae_on_cpu` - Keep the VAE model on CPU.
///
/// * `n_threads` - Number of threads to use.
///
/// * `task` - Which context(s) to create: text-to-image, image-to-image, or both.
#[allow(clippy::too_many_arguments)]
pub fn init_sd_context_with_standalone_model(
    model_file: impl AsRef<str>,
    vae: impl AsRef<str>,
    clip_l: impl AsRef<str>,
    t5xxl: impl AsRef<str>,
    lora_model_dir: Option<&str>,
    controlnet_path: Option<&str>,
    controlnet_on_cpu: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    n_threads: i32,
    task: StableDiffusionTask,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the stable diffusion context with the standalone diffusion model");

    // Shared error mapper for the builder steps below. Previously this closure
    // was duplicated verbatim at every `map_err` site.
    fn build_err(e: impl std::fmt::Display) -> LlamaCoreError {
        let err_msg = format!(
            "Failed to initialize the stable diffusion context. Reason: {}",
            e
        );

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        LlamaCoreError::InitContext(err_msg)
    }

    // ControlNet-on-CPU is only meaningful when a non-empty path was provided.
    let control_net_on_cpu = match controlnet_path {
        Some(path) if !path.is_empty() => controlnet_on_cpu,
        _ => false,
    };

    if task == StableDiffusionTask::Full || task == StableDiffusionTask::TextToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::TextToImage, model_file.as_ref())
            .map_err(build_err)?
            .with_vae_path(vae.as_ref())
            .map_err(build_err)?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(build_err)?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(build_err)?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(build_err)?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(build_err)?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        let ctx = match ctx {
            Context::TextToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the text-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd text_to_image context: {:?}", &ctx);

        SD_TEXT_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_TEXT_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion text-to-image context has been initialized");
    }

    if task == StableDiffusionTask::Full || task == StableDiffusionTask::ImageToImage {
        let sd = SDBuidler::new_with_standalone_model(Task::ImageToImage, model_file.as_ref())
            .map_err(build_err)?
            .with_vae_path(vae.as_ref())
            .map_err(build_err)?
            .with_clip_l_path(clip_l.as_ref())
            .map_err(build_err)?
            .with_t5xxl_path(t5xxl.as_ref())
            .map_err(build_err)?
            .with_lora_model_dir(lora_model_dir.unwrap_or_default())
            .map_err(build_err)?
            .use_control_net(controlnet_path.unwrap_or_default(), control_net_on_cpu)
            .map_err(build_err)?
            .clip_on_cpu(clip_on_cpu)
            .vae_on_cpu(vae_on_cpu)
            .with_n_threads(n_threads)
            .build();

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd: {:?}", &sd);

        let ctx = sd.create_context().map_err(|e| {
            let err_msg = format!("Fail to create the context. {}", e);

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", &err_msg);

            LlamaCoreError::InitContext(err_msg)
        })?;

        let ctx = match ctx {
            Context::ImageToImage(ctx) => ctx,
            _ => {
                let err_msg = "Fail to get the context for the image-to-image task";

                #[cfg(feature = "logging")]
                error!(target: "stdout", "{}", err_msg);

                return Err(LlamaCoreError::InitContext(err_msg.into()));
            }
        };

        #[cfg(feature = "logging")]
        info!(target: "stdout", "sd image_to_image context: {:?}", &ctx);

        SD_IMAGE_TO_IMAGE.set(Mutex::new(ctx)).map_err(|_| {
            let err_msg = "Failed to initialize the stable diffusion context. Reason: The `SD_IMAGE_TO_IMAGE` has already been initialized";

            #[cfg(feature = "logging")]
            error!(target: "stdout", "{}", err_msg);

            LlamaCoreError::InitContext(err_msg.into())
        })?;

        #[cfg(feature = "logging")]
        info!(target: "stdout", "The stable diffusion image-to-image context has been initialized");
    }

    Ok(())
}
/// Which stable-diffusion context(s) to create during initialization.
#[derive(Clone, Debug, Copy, PartialEq, Eq)]
pub enum StableDiffusionTask {
    /// Create only the text-to-image context.
    TextToImage,
    /// Create only the image-to-image context.
    ImageToImage,
    /// Create both the text-to-image and the image-to-image contexts.
    Full,
}
/// Initialize the audio (whisper) context.
///
/// # Arguments
///
/// * `whisper_metadata` - Configuration for the whisper graph.
///
/// * `model_file` - Path to the whisper model file.
pub fn init_whisper_context(
    whisper_metadata: &WhisperMetadata,
    model_file: impl AsRef<Path>,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the audio context");

    // Build a CPU-backed whisper graph from the model file.
    let graph = GraphBuilder::new(EngineType::Whisper)?
        .with_config(whisper_metadata.clone())?
        .use_cpu()
        .build_from_files([model_file.as_ref()])?;

    // The graph can only be installed once for the lifetime of the process.
    if AUDIO_GRAPH.set(Mutex::new(graph)).is_err() {
        let err_msg = "Failed to initialize the audio context. Reason: The `AUDIO_GRAPH` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The audio context has been initialized");

    Ok(())
}
/// Initialize the piper (text-to-speech) context.
///
/// # Arguments
///
/// * `piper_metadata` - Configuration for the piper graph.
///
/// * `voice_model` - Path to the voice model file.
///
/// * `voice_config` - Path to the voice config file.
///
/// * `espeak_ng_data` - Path to the espeak-ng data directory.
pub fn init_piper_context(
    piper_metadata: &PiperMetadata,
    voice_model: impl AsRef<Path>,
    voice_config: impl AsRef<Path>,
    espeak_ng_data: impl AsRef<Path>,
) -> Result<(), LlamaCoreError> {
    #[cfg(feature = "logging")]
    info!(target: "stdout", "Initializing the piper context");

    // Piper takes its model/config/espeak paths as a single JSON document.
    let config = serde_json::json!({
        "model": voice_model.as_ref().to_owned(),
        "config": voice_config.as_ref().to_owned(),
        "espeak_data": espeak_ng_data.as_ref().to_owned(),
    });

    // Build a CPU-backed piper graph from the JSON config buffer.
    let graph = GraphBuilder::new(EngineType::Piper)?
        .with_config(piper_metadata.clone())?
        .use_cpu()
        .build_from_buffer([config.to_string()])?;

    // The graph can only be installed once for the lifetime of the process.
    if PIPER_GRAPH.set(Mutex::new(graph)).is_err() {
        let err_msg = "Failed to initialize the piper context. Reason: The `PIPER_GRAPH` has already been initialized";

        #[cfg(feature = "logging")]
        error!(target: "stdout", "{}", err_msg);

        return Err(LlamaCoreError::InitContext(err_msg.into()));
    }

    #[cfg(feature = "logging")]
    info!(target: "stdout", "The piper context has been initialized");

    Ok(())
}