use serde::de::DeserializeOwned;
use crate::backend::{LLMClient, MediaFile};
use crate::error::Result;
use crate::model::Instructor;
/// Builder-style description of a single call made through an LLM client.
///
/// A `Request` borrows the client it will be dispatched on and accumulates
/// an optional system prompt and a list of media attachments. When the
/// `tools` feature is enabled it additionally carries an optional toolbox
/// and the iteration limit handed to the tool loop.
pub struct Request<'a, C>
where
    C: ?Sized,
{
    /// Client the request is dispatched on.
    client: &'a C,
    /// System prompt, if one was set.
    system: Option<String>,
    /// Media files attached to the request; empty when none were supplied.
    media: Vec<MediaFile>,
    /// Toolbox used by the tool loop, if one was set.
    #[cfg(feature = "tools")]
    tools: Option<&'a crate::backend::tools::Toolbox>,
    /// Iteration limit passed to the tool loop.
    #[cfg(feature = "tools")]
    max_iterations: usize,
}
impl<'a, C: ?Sized> Request<'a, C> {
fn new(client: &'a C) -> Self {
Self {
client,
system: None,
media: Vec::new(),
#[cfg(feature = "tools")]
tools: None,
#[cfg(feature = "tools")]
max_iterations: crate::backend::tools::DEFAULT_MAX_TOOL_ITERATIONS,
}
}
#[must_use]
pub fn system(mut self, system: impl Into<String>) -> Self {
self.system = Some(system.into());
self
}
#[must_use]
pub fn media(mut self, media: impl Into<Vec<MediaFile>>) -> Self {
self.media = media.into();
self
}
#[cfg(feature = "tools")]
#[must_use]
pub fn tools(mut self, toolbox: &'a crate::backend::tools::Toolbox) -> Self {
self.tools = Some(toolbox);
self
}
#[cfg(feature = "tools")]
#[must_use]
pub fn max_iterations(mut self, max_iterations: usize) -> Self {
self.max_iterations = max_iterations;
self
}
fn combined(&self, prompt: &str) -> String {
match &self.system {
Some(system) => format!("{system}\n\n{prompt}"),
None => prompt.to_string(),
}
}
}
impl<C: LLMClient + Sync + ?Sized> Request<'_, C> {
    /// Executes the request and returns the response as a `T`.
    ///
    /// The text sent is `prompt`, preceded by the system prompt when one is
    /// set. If any media is attached the call goes through the client's
    /// `materialize_with_media`; otherwise it goes through `materialize`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying client call returns.
    pub async fn materialize<T>(self, prompt: &str) -> Result<T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        let full_prompt = self.combined(prompt);
        if !self.media.is_empty() {
            return self
                .client
                .materialize_with_media(&full_prompt, &self.media)
                .await;
        }
        self.client.materialize(&full_prompt).await
    }

    /// Executes the request and returns the client's text response.
    ///
    /// The text sent is `prompt`, preceded by the system prompt when one is
    /// set. Media attached to the request is not forwarded by this method.
    ///
    /// # Errors
    ///
    /// Returns whatever error the client's `generate` call returns.
    pub async fn generate(self, prompt: &str) -> Result<String> {
        let full_prompt = self.combined(prompt);
        self.client.generate(&full_prompt).await
    }
}
#[cfg(feature = "streaming")]
impl<'a, C: LLMClient + Sync + ?Sized> Request<'a, C> {
    /// Streams the items produced by the client's `materialize_iter` for this
    /// request.
    ///
    /// The text sent is `prompt`, preceded by the system prompt when one is
    /// set; that string is owned by the returned stream, and the client's
    /// stream is created inside it. Errors from the client's stream are
    /// propagated with `?`. Media attached to the request is not forwarded.
    pub fn materialize_iter<T>(
        self,
        prompt: &str,
    ) -> crate::backend::streaming::ItemStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        use futures_util::StreamExt;

        let client = self.client;
        let full_prompt = self.combined(prompt);
        Box::pin(async_stream::try_stream! {
            let mut upstream = client.materialize_iter::<T>(&full_prompt);
            while let Some(next_item) = upstream.next().await {
                yield next_item?;
            }
        })
    }

    /// Streams the text chunks produced by the client's `generate_stream`
    /// for this request.
    ///
    /// The text sent is `prompt`, preceded by the system prompt when one is
    /// set; that string is owned by the returned stream. Errors from the
    /// client's stream are propagated with `?`. Media attached to the
    /// request is not forwarded.
    pub fn generate_stream(self, prompt: &str) -> crate::backend::streaming::TextStream<'a> {
        use futures_util::StreamExt;

        let client = self.client;
        let full_prompt = self.combined(prompt);
        Box::pin(async_stream::try_stream! {
            let mut upstream = client.generate_stream(&full_prompt);
            while let Some(next_chunk) = upstream.next().await {
                yield next_chunk?;
            }
        })
    }

    /// Streams the values produced by the client's `materialize_stream` for
    /// this request.
    ///
    /// The text sent is `prompt`, preceded by the system prompt when one is
    /// set; that string is owned by the returned stream. Errors from the
    /// client's stream are propagated with `?`. Media attached to the
    /// request is not forwarded.
    pub fn materialize_stream<T>(
        self,
        prompt: &str,
    ) -> crate::backend::streaming::ObjectStream<'a, T>
    where
        T: Instructor + DeserializeOwned + Send + 'static,
    {
        use futures_util::StreamExt;

        let client = self.client;
        let full_prompt = self.combined(prompt);
        Box::pin(async_stream::try_stream! {
            let mut upstream = client.materialize_stream::<T>(&full_prompt);
            while let Some(next_object) = upstream.next().await {
                yield next_object?;
            }
        })
    }
}
#[cfg(feature = "tools")]
impl<C: crate::backend::tools::ToolRunner + LLMClient + Sync + ?Sized> Request<'_, C> {
    /// Executes the request, using the tool loop when a toolbox is set.
    ///
    /// With a toolbox, the client's `run_tool_loop` is called with the system
    /// prompt and `prompt` passed as separate arguments, together with the
    /// toolbox and the configured iteration limit. Without a toolbox, this
    /// falls back to the client's `generate` with the system prompt (if any)
    /// prepended to `prompt`. Media attached to the request is not forwarded
    /// on either path.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying client call returns.
    pub async fn run(self, prompt: &str) -> Result<String> {
        if let Some(toolbox) = self.tools {
            return self
                .client
                .run_tool_loop(self.system.as_deref(), prompt, toolbox, self.max_iterations)
                .await;
        }
        let full_prompt = self.combined(prompt);
        self.client.generate(&full_prompt).await
    }
}
/// Extension methods that start a [`Request`] from any [`LLMClient`].
///
/// Every method borrows the client and returns a `Request` that can be
/// configured further with its builder methods before being executed.
pub trait RequestExt: LLMClient {
    /// Starts an empty request on this client.
    fn request(&self) -> Request<'_, Self> {
        Request::new(self)
    }

    /// Starts a request on this client with `system` as its system prompt.
    fn with_system(&self, system: impl Into<String>) -> Request<'_, Self> {
        Request::new(self).system(system)
    }

    /// Starts a request on this client with a copy of `media` attached.
    ///
    /// The slice is cloned into the request, so the returned `Request` only
    /// borrows the client; `media` may be dropped as soon as this call
    /// returns.
    fn with_media(&self, media: &[MediaFile]) -> Request<'_, Self> {
        Request::new(self).media(media.to_vec())
    }

    /// Starts a request on this client that uses `toolbox`.
    ///
    /// The toolbox is stored by reference, so it must live at least as long
    /// as the returned `Request`.
    #[cfg(feature = "tools")]
    fn with_tools<'a>(
        &'a self,
        toolbox: &'a crate::backend::tools::Toolbox,
    ) -> Request<'a, Self> {
        Request::new(self).tools(toolbox)
    }
}
/// Blanket implementation: every `LLMClient`, sized or not, gets the
/// `RequestExt` methods.
impl<C> RequestExt for C where C: LLMClient + ?Sized {}