#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
pub mod error;
pub mod params;
#[cfg(feature = "audio")]
use endpoints::audio::{transcription::TranscriptionObject, translation::TranslationObject};
#[cfg(feature = "image")]
use endpoints::images::{ImageCreateRequestBuilder, ImageObject, ListImagesResponse};
#[cfg(feature = "rag")]
use endpoints::{
chat::ChatCompletionRequestBuilder,
embeddings::{ChunksRequest, ChunksResponse},
rag::RetrieveObject,
};
use endpoints::{
chat::{
ChatCompletionObject, ChatCompletionRequest, ChatCompletionRequestMessage, StreamOptions,
},
embeddings::{EmbeddingRequest, EmbeddingsResponse, InputText},
files::FileObject,
models::{ListModelsResponse, Model},
};
use error::LlamaEdgeError;
use futures::{stream::TryStream, StreamExt};
#[cfg(feature = "rag")]
use params::RagChatParams;
use params::{ChatParams, EmbeddingsParams};
#[cfg(feature = "image")]
use params::{ImageCreateParams, ImageEditParams};
#[cfg(feature = "audio")]
use params::{TranscriptionParams, TranslationParams};
use reqwest::multipart;
use std::path::Path;
use url::Url;
pub struct Client {
server_base_url: Url,
}
impl Client {
pub fn new(server_base_url: impl AsRef<str>) -> Result<Self, LlamaEdgeError> {
let url_str = server_base_url.as_ref().trim_end_matches('/');
match Url::parse(url_str) {
Ok(url) => Ok(Self {
server_base_url: url,
}),
Err(e) => Err(LlamaEdgeError::UrlParse(e)),
}
}
pub fn server_base_url(&self) -> &Url {
&self.server_base_url
}
pub async fn chat(
&self,
chat_history: &[ChatCompletionRequestMessage],
params: &ChatParams,
) -> Result<String, LlamaEdgeError> {
if chat_history.is_empty() {
return Err(LlamaEdgeError::InvalidArgument(
"chat_history cannot be empty".to_string(),
));
}
let request = ChatCompletionRequest {
messages: chat_history.to_vec(),
model: params.model.clone(),
temperature: params.temperature,
top_p: params.top_p,
n_choice: params.n_choice,
stop: params.stop.clone(),
max_completion_tokens: params.max_completion_tokens,
presence_penalty: params.presence_penalty,
frequency_penalty: params.frequency_penalty,
user: params.user.clone(),
response_format: params.response_format.clone(),
tools: params.tools.clone(),
tool_choice: params.tool_choice.clone(),
..Default::default()
};
let url = self.server_base_url.join("/v1/chat/completions")?;
let response = reqwest::Client::new()
.post(url)
.json(&request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let response_body = response
.json::<ChatCompletionObject>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
match &response_body.choices[0].message.content {
Some(content) => Ok(content.clone()),
None => Ok("".to_string()),
}
}
pub async fn chat_stream(
&self,
chat_history: &[ChatCompletionRequestMessage],
params: &ChatParams,
) -> Result<
impl TryStream<Item = Result<String, LlamaEdgeError>, Error = LlamaEdgeError>,
LlamaEdgeError,
> {
if chat_history.is_empty() {
return Err(LlamaEdgeError::InvalidArgument(
"chat_history cannot be empty".to_string(),
));
}
let request = ChatCompletionRequest {
messages: chat_history.to_vec(),
model: params.model.clone(),
temperature: params.temperature,
top_p: params.top_p,
n_choice: params.n_choice,
stop: params.stop.clone(),
max_completion_tokens: params.max_completion_tokens,
presence_penalty: params.presence_penalty,
frequency_penalty: params.frequency_penalty,
user: params.user.clone(),
response_format: params.response_format.clone(),
tools: params.tools.clone(),
tool_choice: params.tool_choice.clone(),
stream: Some(true),
stream_options: Some(StreamOptions {
include_usage: Some(true),
}),
..Default::default()
};
let url = self.server_base_url.join("/v1/chat/completions")?;
let response = reqwest::Client::new()
.post(url)
.json(&request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let stream = response.bytes_stream().map(|r| match r {
Ok(bytes) => Ok(String::from_utf8_lossy(&bytes).to_string()),
Err(e) => Err(LlamaEdgeError::Operation(e.to_string())),
});
Ok(stream)
}
pub async fn upload_file(&self, file: impl AsRef<Path>) -> Result<FileObject, LlamaEdgeError> {
let abs_file_path = if file.as_ref().is_absolute() {
file.as_ref().to_path_buf()
} else {
std::env::current_dir().unwrap().join(file.as_ref())
};
if !abs_file_path.exists() {
return Err(LlamaEdgeError::InvalidArgument(
"The file does not exist".to_string(),
));
}
let filename = abs_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file_extension = abs_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file = tokio::fs::read(abs_file_path)
.await
.map_err(|e| LlamaEdgeError::Operation(format!("Failed to read audio file: {}", e)))?;
let file_part = multipart::Part::bytes(file)
.file_name(filename)
.mime_str(&format!("audio/{}", file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let form = multipart::Form::new().part("file", file_part);
let url = self.server_base_url.join("/v1/files")?;
let response = reqwest::Client::new()
.post(url)
.multipart(form)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let file_object = response
.json::<FileObject>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(file_object)
}
pub async fn models(&self) -> Result<Vec<Model>, LlamaEdgeError> {
let url = self.server_base_url.join("/v1/models")?;
let response = reqwest::Client::new()
.get(url)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let list_models_response = response
.json::<ListModelsResponse>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(list_models_response.data)
}
pub async fn embeddings(
&self,
input: InputText,
params: EmbeddingsParams,
) -> Result<EmbeddingsResponse, LlamaEdgeError> {
let url = self.server_base_url.join("/v1/embeddings")?;
let request = EmbeddingRequest {
input,
model: params.model,
encoding_format: Some(params.encoding_format),
user: params.user,
#[cfg(feature = "rag")]
vdb_server_url: params.vdb_server_url,
#[cfg(feature = "rag")]
vdb_collection_name: params.vdb_collection_name,
#[cfg(feature = "rag")]
vdb_api_key: params.vdb_api_key,
};
let response = reqwest::Client::new()
.post(url)
.json(&request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let embeddings_response = response
.json::<EmbeddingsResponse>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(embeddings_response)
}
#[cfg(feature = "audio")]
pub async fn transcribe(
&self,
audio_file: impl AsRef<Path>,
spoken_language: impl AsRef<str>,
params: TranscriptionParams,
) -> Result<TranscriptionObject, LlamaEdgeError> {
let abs_file_path = if audio_file.as_ref().is_absolute() {
audio_file.as_ref().to_path_buf()
} else {
std::env::current_dir().unwrap().join(audio_file.as_ref())
};
if !abs_file_path.exists() {
let error_message =
format!("The audio file does not exist: {}", abs_file_path.display());
return Err(LlamaEdgeError::InvalidArgument(error_message));
}
let filename = abs_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file_extension = abs_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file = tokio::fs::read(abs_file_path).await.map_err(|e| {
LlamaEdgeError::Operation(format!("Failed to read the audio file: {}", e))
})?;
let form = {
let file_part = multipart::Part::bytes(file)
.file_name(filename)
.mime_str(&format!("audio/{}", file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let language = if spoken_language.as_ref().is_empty() {
"en".to_string()
} else {
spoken_language.as_ref().to_string()
};
let language_part = multipart::Part::text(language)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let response_format_part = multipart::Part::text(params.response_format)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let temperature_part = multipart::Part::text(params.temperature.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let detect_language_part = multipart::Part::text(params.detect_language.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let offset_time_part = multipart::Part::text(params.offset_time.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let duration_part = multipart::Part::text(params.duration.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let max_context_part = multipart::Part::text(params.max_context.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let max_len_part = multipart::Part::text(params.max_len.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let split_on_word_part = multipart::Part::text(params.split_on_word.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let use_new_context_part = multipart::Part::text(params.use_new_context.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let mut form = multipart::Form::new()
.part("file", file_part)
.part("language", language_part)
.part("response_format", response_format_part)
.part("temperature", temperature_part)
.part("detect_language", detect_language_part)
.part("offset_time", offset_time_part)
.part("duration", duration_part)
.part("max_context", max_context_part)
.part("max_len", max_len_part)
.part("split_on_word", split_on_word_part)
.part("use_new_context", use_new_context_part);
if let Some(model) = ¶ms.model {
let model_part = multipart::Part::text(model.clone())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("model", model_part);
}
if let Some(prompt) = ¶ms.prompt {
let prompt_part = multipart::Part::text(prompt.clone())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("prompt", prompt_part);
}
form
};
let url = self.server_base_url.join("/v1/audio/transcriptions")?;
let response = reqwest::Client::new()
.post(url)
.multipart(form)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let transcription_object = response
.json::<TranscriptionObject>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(transcription_object)
}
#[cfg(feature = "audio")]
pub async fn translate(
&self,
audio_file: impl AsRef<Path>,
spoken_language: impl AsRef<str>,
params: TranslationParams,
) -> Result<TranslationObject, LlamaEdgeError> {
let abs_file_path = if audio_file.as_ref().is_absolute() {
audio_file.as_ref().to_path_buf()
} else {
std::env::current_dir().unwrap().join(audio_file.as_ref())
};
if !abs_file_path.exists() {
let error_message =
format!("The audio file does not exist: {}", abs_file_path.display());
return Err(LlamaEdgeError::InvalidArgument(error_message));
}
let filename = abs_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file_extension = abs_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file = tokio::fs::read(abs_file_path)
.await
.map_err(|e| LlamaEdgeError::Operation(format!("Failed to read audio file: {}", e)))?;
let form = {
let file_part = multipart::Part::bytes(file)
.file_name(filename)
.mime_str(&format!("audio/{}", file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let response_format_part = multipart::Part::text(params.response_format)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let language = if spoken_language.as_ref().is_empty() {
"en".to_string()
} else {
spoken_language.as_ref().to_string()
};
let language_part = multipart::Part::text(language)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let temperature_part = multipart::Part::text(params.temperature.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let detect_language_part = multipart::Part::text(params.detect_language.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let offset_time_part = multipart::Part::text(params.offset_time.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let duration_part = multipart::Part::text(params.duration.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let max_context_part = multipart::Part::text(params.max_context.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let max_len_part = multipart::Part::text(params.max_len.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let split_on_word_part = multipart::Part::text(params.split_on_word.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let use_new_context_part = multipart::Part::text(params.use_new_context.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let mut form = multipart::Form::new()
.part("file", file_part)
.part("response_format", response_format_part)
.part("language", language_part)
.part("temperature", temperature_part)
.part("detect_language", detect_language_part)
.part("offset_time", offset_time_part)
.part("duration", duration_part)
.part("max_context", max_context_part)
.part("max_len", max_len_part)
.part("split_on_word", split_on_word_part)
.part("use_new_context", use_new_context_part);
if let Some(model) = ¶ms.model {
let model_part = multipart::Part::text(model.clone())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("model", model_part);
}
if let Some(prompt) = ¶ms.prompt {
let prompt_part = multipart::Part::text(prompt.clone())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("prompt", prompt_part);
}
form
};
let url = self.server_base_url.join("/v1/audio/translations")?;
let response = reqwest::Client::new()
.post(url)
.multipart(form)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let translation_object = response
.json::<TranslationObject>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(translation_object)
}
#[cfg(feature = "image")]
pub async fn create_image(
&self,
prompt: impl AsRef<str>,
params: ImageCreateParams,
) -> Result<Vec<ImageObject>, LlamaEdgeError> {
let url = self.server_base_url.join("/v1/images/generations")?;
let mut builder = ImageCreateRequestBuilder::new(params.model, prompt.as_ref())
.with_number_of_images(params.n)
.with_response_format(params.response_format)
.with_cfg_scale(params.cfg_scale)
.with_sample_method(params.sample_method)
.with_steps(params.steps)
.with_image_size(params.height, params.width)
.with_control_strength(params.control_strength)
.with_seed(params.seed)
.with_strength(params.strength)
.with_scheduler(params.scheduler)
.apply_canny_preprocessor(params.apply_canny_preprocessor)
.with_style_ratio(params.style_ratio);
if let Some(negative_prompt) = params.negative_prompt {
builder = builder.with_negative_prompt(negative_prompt);
}
if let Some(user) = params.user {
builder = builder.with_user(user);
}
if let Some(control_image) = params.control_image {
builder = builder.with_control_image(control_image);
}
let request = builder.build();
let response = reqwest::Client::new()
.post(url)
.json(&request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let list_images_response = response
.json::<ListImagesResponse>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(list_images_response.data)
}
#[cfg(feature = "image")]
pub async fn edit_image(
&self,
image: impl AsRef<Path>,
prompt: impl AsRef<str>,
params: ImageEditParams,
) -> Result<Vec<ImageObject>, LlamaEdgeError> {
let abs_file_path = if image.as_ref().is_absolute() {
image.as_ref().to_path_buf()
} else {
std::env::current_dir().unwrap().join(image.as_ref())
};
if !abs_file_path.exists() {
let error_message =
format!("The image file does not exist: {}", abs_file_path.display());
return Err(LlamaEdgeError::InvalidArgument(error_message));
}
let filename = abs_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file_extension = abs_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let file = tokio::fs::read(abs_file_path).await.map_err(|e| {
LlamaEdgeError::Operation(format!("Failed to read the image file: {}", e))
})?;
let form = {
let file_part = multipart::Part::bytes(file)
.file_name(filename)
.mime_str(&format!("image/{}", file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let prompt_part = multipart::Part::text(prompt.as_ref().to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let model_part = multipart::Part::text(params.model.clone())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let n_part = multipart::Part::text(params.n.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let response_format_part = multipart::Part::text(params.response_format.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let cfg_scale_part = multipart::Part::text(params.cfg_scale.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let sample_method_part = multipart::Part::text(params.sample_method.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let steps_part = multipart::Part::text(params.steps.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let height_part = multipart::Part::text(params.height.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let width_part = multipart::Part::text(params.width.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let control_strength_part = multipart::Part::text(params.control_strength.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let seed_part = multipart::Part::text(params.seed.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let strength_part = multipart::Part::text(params.strength.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let scheduler_part = multipart::Part::text(params.scheduler.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let apply_canny_preprocessor_part =
multipart::Part::text(params.apply_canny_preprocessor.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let style_ratio_part = multipart::Part::text(params.style_ratio.to_string())
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let mut form = multipart::Form::new()
.part("file", file_part)
.part("prompt", prompt_part)
.part("model", model_part)
.part("n", n_part)
.part("response_format", response_format_part)
.part("cfg_scale", cfg_scale_part)
.part("sample_method", sample_method_part)
.part("steps", steps_part)
.part("height", height_part)
.part("width", width_part)
.part("control_strength", control_strength_part)
.part("seed", seed_part)
.part("strength", strength_part)
.part("scheduler", scheduler_part)
.part("apply_canny_preprocessor", apply_canny_preprocessor_part)
.part("style_ratio", style_ratio_part);
if let Some(user) = params.user {
let user_part = multipart::Part::text(user)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("user", user_part);
}
if let Some(negative_prompt) = params.negative_prompt {
let negative_prompt_part = multipart::Part::text(negative_prompt)
.mime_str("text/plain")
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("negative_prompt", negative_prompt_part);
}
if let Some(mask) = params.mask {
let abs_mask_file_path = if mask.is_absolute() {
mask.to_path_buf()
} else {
std::env::current_dir().unwrap().join(mask)
};
if !abs_mask_file_path.exists() {
let error_message = format!(
"The mask image file does not exist: {}",
abs_mask_file_path.display()
);
return Err(LlamaEdgeError::InvalidArgument(error_message));
}
let mask_filename = abs_mask_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let mask_file_extension = abs_mask_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let mask_file = tokio::fs::read(abs_mask_file_path).await.map_err(|e| {
LlamaEdgeError::Operation(format!("Failed to read the image file: {}", e))
})?;
let mask_file_part = multipart::Part::bytes(mask_file)
.file_name(mask_filename)
.mime_str(&format!("image/{}", mask_file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("mask", mask_file_part);
}
if let Some(control_image) = params.control_image {
let abs_control_image_file_path = if control_image.is_absolute() {
control_image.to_path_buf()
} else {
std::env::current_dir().unwrap().join(control_image)
};
if !abs_control_image_file_path.exists() {
let error_message = format!(
"The control image file does not exist: {}",
abs_control_image_file_path.display()
);
return Err(LlamaEdgeError::InvalidArgument(error_message));
}
let control_image_filename = abs_control_image_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string();
let control_image_file_extension = abs_control_image_file_path
.extension()
.unwrap()
.to_str()
.unwrap()
.to_string();
let control_image_file = tokio::fs::read(abs_control_image_file_path)
.await
.map_err(|e| {
LlamaEdgeError::Operation(format!("Failed to read the image file: {}", e))
})?;
let control_image_file_part = multipart::Part::bytes(control_image_file)
.file_name(control_image_filename)
.mime_str(&format!("image/{}", control_image_file_extension))
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
form = form.part("control_image", control_image_file_part);
}
form
};
let url = self.server_base_url.join("/v1/images/edits")?;
let response = reqwest::Client::new()
.post(url)
.multipart(form)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let list_images_response = response
.json::<ListImagesResponse>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(list_images_response.data)
}
#[cfg(feature = "rag")]
pub async fn rag_retrieve_context(
&self,
chat_history: &[ChatCompletionRequestMessage],
params: RagChatParams,
) -> Result<Vec<RetrieveObject>, LlamaEdgeError> {
let url = self.server_base_url.join("/v1/retrieve")?;
let mut builder = ChatCompletionRequestBuilder::new(chat_history)
.with_n_choices(params.n_choice)
.with_max_completion_tokens(params.max_completion_tokens)
.with_presence_penalty(params.presence_penalty)
.with_frequency_penalty(params.frequency_penalty)
.with_rag_context_window(params.context_window);
if let Some(model) = params.model {
builder = builder.with_model(model);
}
if let Some(user) = params.user {
builder = builder.with_user(user);
}
if let Some(response_format) = params.response_format {
builder = builder.with_reponse_format(response_format);
}
if let Some(tools) = params.tools {
builder = builder.with_tools(tools);
}
if let Some(tool_choice) = params.tool_choice {
builder = builder.with_tool_choice(tool_choice);
}
if let Some(vdb_config) = params.vdb_config {
builder = builder.with_rag_vdb_settings(
vdb_config.server_url,
vdb_config.collection_name,
vdb_config.limit,
vdb_config.score_threshold,
vdb_config.api_key,
);
}
let mut request = builder.build();
request.temperature = Some(params.temperature);
request.top_p = Some(params.top_p);
let response = reqwest::Client::new()
.post(url)
.json(&request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let rag_context_response = response
.json::<Vec<RetrieveObject>>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(rag_context_response)
}
#[cfg(feature = "rag")]
pub async fn rag_chunk_file(
&self,
file_path: impl AsRef<Path>,
chunk_capacity: usize,
) -> Result<ChunksResponse, LlamaEdgeError> {
let url = self.server_base_url.join("/v1/chunks")?;
let fo = self.upload_file(file_path.as_ref()).await?;
let chunks_request = ChunksRequest {
id: fo.id,
filename: fo.filename,
chunk_capacity,
};
let response = reqwest::Client::new()
.post(url)
.json(&chunks_request)
.send()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
let chunks_response = response
.json::<ChunksResponse>()
.await
.map_err(|e| LlamaEdgeError::Operation(e.to_string()))?;
Ok(chunks_response)
}
}