use std::sync::Arc;
use crate::{chat::{chat::ChatApi, prompter::Prompter, ChatError}, chunking::Chunker, convert::{ConversionError, Converter}, embedding::{Embedder, EmbeddingError}, extension::{ActiveExtension, Extension, F11y, UseExtensionError}, storage::{self, Content, Document, Folder, Scope}};
use mime::Mime;
use thiserror::Error;
use tokio::{io::AsyncRead, sync::mpsc::{error::SendError, UnboundedReceiver, UnboundedSender}, task::JoinHandle};
use tracing::instrument;
pub mod search;
mod chat;
pub use chat::{chat, simple_rag};
/// A deferred unit of work together with the assets it will be run against.
///
/// `F` is the callback executed when the job is run; it receives the job's [`Assets`] by
/// mutable reference and resolves to `Result<T, RunJobError>`.
///
/// NOTE(review): `T` is referenced only through the bound on `F`, not by any field —
/// confirm the compiler accepts this without a `PhantomData<T>` marker.
pub struct Job<T, F>
where
    F: AsyncFnOnce(&mut Assets) -> Result<T, RunJobError> + Send,
{
    /// Work to perform once the job is run.
    callback: F,
    /// Documents handed to the callback through [`Assets`].
    documents: Vec<Document>,
    /// Scopes handed to the callback through [`Assets`].
    scopes: Vec<Scope>,
    /// Sender/receiver pair for assets queued from elsewhere; `None` until a sender is requested.
    asset_channel: Option<(AssetSender, UnboundedReceiver<AssetItem>)>,
    /// Extensions made available to the callback through [`Assets`].
    extensions: Arc<Vec<ActiveExtension>>,
}
impl<T, F: AsyncFnOnce(&mut Assets) -> Result<T, RunJobError> + Send> Job<T, F> {
/// Creates a job around `callback` with no documents, no scopes, no asset channel and an
/// empty extension list.
pub fn new(callback: F) -> Self {
    let documents = Vec::new();
    let scopes = Vec::new();
    let extensions = Arc::new(Vec::new());
    Self {
        callback,
        documents,
        scopes,
        asset_channel: None,
        extensions,
    }
}
/// Chains a follow-up step onto this job.
///
/// The returned job first awaits this job's callback. If that fails, the error is
/// returned and `callback` is never invoked; otherwise `callback` is awaited with the
/// same [`Assets`] and the first callback's result. Documents, scopes, the asset channel
/// and the extensions are carried over unchanged.
pub fn and_then<T2, C: AsyncFnOnce(&mut Assets, T) -> Result<T2, RunJobError> + Send>(
    self,
    callback: C,
) -> Job<T2, impl AsyncFnOnce(&mut Assets) -> Result<T2, RunJobError> + Send> {
    let Job {
        callback: first,
        documents,
        scopes,
        asset_channel,
        extensions,
    } = self;
    Job {
        callback: async move |assets| {
            let intermediate = first(assets).await?;
            callback(assets, intermediate).await
        },
        documents,
        scopes,
        asset_channel,
        extensions,
    }
}
/// Chains a second job that is built from this job's result.
///
/// The returned job first awaits this job's callback; on failure the error is returned
/// and `callback` is never invoked. On success `callback` turns the result into a
/// follow-up job, which then
/// - has its extension list replaced by clones of the current extensions,
/// - receives every document currently held by the assets (they are moved out, leaving
///   the current assets without documents), and
/// - is run to completion; its outcome becomes the outcome of the returned job.
///
/// NOTE(review): only documents are handed over; scopes stay behind in the current
/// assets — confirm this is intended.
pub fn and_chain<
    T2,
    F2: AsyncFnOnce(&mut Assets) -> Result<T2, RunJobError> + Send,
    C: FnOnce(T) -> Job<T2, F2> + Send,
>(
    self,
    callback: C,
) -> Job<T2, impl AsyncFnOnce(&mut Assets) -> Result<T2, RunJobError> + Send> {
    let Job {
        callback: first,
        documents,
        scopes,
        asset_channel,
        extensions,
    } = self;
    Job {
        callback: async move |assets| {
            let intermediate = first(assets).await?;
            let mut follow_up =
                callback(intermediate).with_extensions(assets.extensions.iter().cloned());
            // Move (not copy) the documents gathered so far into the follow-up job.
            follow_up.documents.extend(assets.documents.drain(..));
            follow_up.run().await
        },
        documents,
        scopes,
        asset_channel,
        extensions,
    }
}
/// Chains a second job that is built asynchronously from this job's result.
///
/// Behaves like `and_chain`, except that `callback` is itself asynchronous and is awaited
/// to obtain the follow-up job. The follow-up job has its extension list replaced by
/// clones of the current extensions, receives every document currently held by the
/// assets (they are moved out), and is then run to completion.
///
/// NOTE(review): only documents are handed over; scopes stay behind in the current
/// assets — confirm this is intended.
pub fn and_chain_async<
    T2,
    F2: AsyncFnOnce(&mut Assets) -> Result<T2, RunJobError> + Send,
    C: AsyncFnOnce(T) -> Job<T2, F2> + Send,
>(
    self,
    callback: C,
) -> Job<T2, impl AsyncFnOnce(&mut Assets) -> Result<T2, RunJobError> + Send> {
    let Job {
        callback: first,
        documents,
        scopes,
        asset_channel,
        extensions,
    } = self;
    Job {
        callback: async move |assets| {
            let intermediate = first(assets).await?;
            // The extension iterator is created only after the await so that no borrow of
            // the assets is held across it.
            let mut follow_up = callback(intermediate)
                .await
                .with_extensions(assets.extensions.iter().cloned());
            // Move (not copy) the documents gathered so far into the follow-up job.
            follow_up.documents.extend(assets.documents.drain(..));
            follow_up.run().await
        },
        documents,
        scopes,
        asset_channel,
        extensions,
    }
}
pub fn with_extensions<I: Iterator<Item = ActiveExtension>>(self, extensions: I) -> Self {
Job {
callback: self.callback,
documents: self.documents,
scopes: self.scopes,
asset_channel: None,
extensions: Arc::new(extensions.collect()),
}
}
/// Appends `document` to the job's document list and returns `self` so calls can be chained.
pub fn add_document(&mut self, document: Document) -> &mut Self {
self.documents.push(document);
self
}
/// Adds every document found in `folder` and, transitively, in all of its sub-folders.
///
/// Folders are visited depth-first in listing order: a folder's own documents are added
/// before those of its sub-folders, and each sub-folder is fully processed before its
/// next sibling.
///
/// # Errors
///
/// Returns the first error reported while listing documents or folders. Documents added
/// before the failure remain attached to the job.
pub async fn add_folder(&mut self, folder: Folder) -> Result<&mut Self, storage::Error> {
    // Explicit stack instead of recursion; the top of the stack is the next folder to visit.
    let mut pending = vec![folder];
    while let Some(current) = pending.pop() {
        let found = current.list_documents().await?;
        self.documents.extend(found);

        let subfolders = current.list_folders().await?;
        let first_new = pending.len();
        pending.extend(subfolders);
        // Reverse only the freshly pushed entries so the first sub-folder is popped first.
        pending[first_new..].reverse();
    }
    Ok(self)
}
/// Returns the documents currently attached to the job.
pub fn documents(&self) -> &[Document] {
&self.documents
}
/// Returns the scopes currently attached to the job.
pub fn scopes(&self) -> &[Scope] {
&self.scopes
}
/// Returns a sender through which documents and scopes can be queued for this job.
///
/// The underlying unbounded channel is created on first use; later calls hand out clones
/// of the same sender, so all of them feed a single receiver.
pub fn asset_sender(&mut self) -> AssetSender {
    let (sender, _) = self.asset_channel.get_or_insert_with(|| {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<AssetItem>();
        (AssetSender { inner: tx }, rx)
    });
    sender.clone()
}
pub async fn run(mut self) -> Result<T, RunJobError> {
let mut assets = Assets {
documents: self.documents,
scopes: self.scopes,
extensions: self.extensions,
asset_channel: self.asset_channel.take(),
};
(self.callback)(&mut assets).await
}
}
/// The working set handed to a job's callback while the job runs.
pub struct Assets {
/// Documents available to the callback; grows when `refresh` drains the asset channel.
documents: Vec<Document>,
/// Scopes available to the callback; grows when `refresh` drains the asset channel.
scopes: Vec<Scope>,
/// Extensions consulted for converters, chat providers, embedders, chunkers and prompters.
extensions: Arc<Vec<ActiveExtension>>,
/// Sender/receiver pair for assets queued from elsewhere; `None` when no sender has been requested.
asset_channel: Option<(AssetSender, UnboundedReceiver<AssetItem>)>,
}
impl Assets {
/// Pulls every asset currently queued on the asset channel into this working set.
///
/// Queued documents and scopes are appended to the respective lists without waiting:
/// draining stops as soon as the receiver reports an error instead of an item. The
/// returned [`Refresh`] starts at the positions where the lists ended before this call,
/// i.e. it covers exactly the items appended here.
#[instrument(skip(self))]
pub fn refresh(&mut self) -> Refresh<'_> {
    // Remember where the pre-existing items end so the new ones can be told apart.
    let (doc_idx, scope_idx) = (self.documents.len(), self.scopes.len());
    if let Some((_, inbox)) = self.asset_channel.as_mut() {
        while let Ok(incoming) = inbox.try_recv() {
            match incoming {
                AssetItem::Document(document) => self.documents.push(document),
                AssetItem::Scope(scope) => self.scopes.push(scope),
            }
        }
    }
    tracing::debug!("Added {} documents", self.documents.len() - doc_idx);
    tracing::debug!("Added {} scopes", self.scopes.len() - scope_idx);
    Refresh {
        assets: self,
        doc_idx,
        scope_idx,
    }
}
/// Returns all documents currently in the working set.
pub fn documents(&self) -> &[Document] {
&self.documents
}
/// Returns all scopes currently in the working set.
pub fn scopes(&self) -> &[Scope] {
&self.scopes
}
/// Returns the extensions available to the running job.
pub fn extensions(&self) -> &Vec<ActiveExtension> {
&self.extensions
}
/// Returns a sender through which documents and scopes can be queued for this working set.
///
/// The underlying unbounded channel is created on first use; later calls hand out clones
/// of the same sender. Queued items become visible after a call to `refresh`.
pub fn asset_sender(&mut self) -> AssetSender {
    let (sender, _) = self.asset_channel.get_or_insert_with(|| {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<AssetItem>();
        (AssetSender { inner: tx }, rx)
    });
    sender.clone()
}
/// Converts `input` to `output_type` using the converters provided by the extensions.
///
/// Only the first converter of each extension is considered. Candidates are tried in
/// extension order and the first successful conversion is returned.
///
/// # Errors
///
/// A candidate that reports `ConversionError::UnsupportedMimeType` is skipped; any other
/// error aborts the search and is returned immediately. If no candidate succeeds,
/// `ConversionError::UnsupportedMimeType(output_type)` is returned.
pub async fn convert(
    &self,
    input: Content,
    output_type: Mime,
) -> Result<Vec<Box<dyn AsyncRead + Unpin>>, ConversionError> {
    tracing::debug!("Converting content to {}", output_type);
    let candidates: Vec<_> = self
        .extensions
        .iter()
        .filter_map(|ext| ext.converters().next())
        .collect();
    tracing::debug!("Found {} converters", candidates.len());
    for converter in candidates {
        // Each attempt consumes its arguments, hence the per-attempt clones.
        match converter.convert(input.clone(), output_type.clone()).await {
            Ok(converted) => return Ok(converted),
            Err(ConversionError::UnsupportedMimeType(_)) => continue,
            Err(other) => return Err(other),
        }
    }
    Err(ConversionError::UnsupportedMimeType(output_type))
}
/// Looks up a chat provider among the extensions.
///
/// Only the first chat provider of each extension is considered, in extension order.
/// With `model == None` the first provider found is returned. Otherwise each provider's
/// model list is fetched and the first provider listing a model whose id equals the
/// requested one is returned.
///
/// # Errors
///
/// Returns `ChatError::Provider` when listing a provider's models fails (the search
/// stops at that point) or when no suitable provider is found.
pub async fn chat_model(&self, model: Option<String>) -> Result<F11y<dyn ChatApi>, ChatError> {
    tracing::debug!("Getting chat model");
    for ext in self.extensions.iter() {
        tracing::debug!("Checking extension {}", ext.name());
        let Some(chat_client) = ext.chat_providers().next() else {
            continue;
        };
        tracing::debug!("Found chat model in extension {}", ext.name());
        let Some(requested_model) = &model else {
            tracing::debug!("No model specified, returning default model");
            return Ok(chat_client);
        };
        tracing::debug!("Looking for model {}", requested_model);
        let available = chat_client
            .list_models()
            .await
            .map_err(|e| ChatError::Provider(Box::new(e)))?;
        for candidate in available {
            if *candidate.id == *requested_model {
                tracing::debug!("Found model {}", requested_model);
                return Ok(chat_client);
            }
        }
        // This provider does not list the requested model; try the next extension.
    }
    Err(ChatError::Provider("No chat model found".into()))
}
/// Collects at most one embedder per extension: the first one that the extension's
/// `embedders()` iterator yields. Extensions without embedders contribute nothing.
pub fn embedders(&self) -> Vec<F11y<dyn Embedder>> {
    tracing::debug!("Getting embedders");
    self.extensions
        .iter()
        .filter_map(|ext| ext.embedders().next())
        .collect()
}
/// Collects at most one chunker per extension: the first one that the extension's
/// `chunkers()` iterator yields. Extensions without chunkers contribute nothing.
pub fn chunkers(&self) -> Vec<F11y<dyn Chunker>> {
    tracing::debug!("Getting chunkers");
    self.extensions
        .iter()
        .filter_map(|ext| ext.chunkers().next())
        .collect()
}
/// Collects every prompter of every extension, in extension order, logging how many each
/// extension contributed.
#[instrument(skip(self))]
pub fn prompters(&self) -> Vec<F11y<dyn Prompter>> {
    tracing::debug!("Getting prompters");
    let mut collected = Vec::new();
    for ext in self.extensions.iter() {
        let already_collected = collected.len();
        collected.extend(ext.prompters());
        tracing::debug!(
            "Found {} prompters in extension {}",
            collected.len() - already_collected,
            ext.name()
        );
    }
    collected
}
}
/// Cursor over the tail of an [`Assets`] value's document and scope lists, as returned by `Assets::refresh`.
///
/// As an iterator it yields clones of the documents from `doc_idx` onward, followed by
/// clones of the scopes from `scope_idx` onward.
pub struct Refresh<'a> {
/// The working set being read.
assets: &'a Assets,
/// Index of the next document to yield.
doc_idx: usize,
/// Index of the next scope to yield.
scope_idx: usize,
}
impl<'a> Refresh<'a> {
pub fn documents(self) -> impl Iterator<Item = Document> {
self.filter_map(|item| {
if let AssetItem::Document(doc) = item {
Some(doc)
} else {
None
}
})
}
pub fn scopes(self) -> impl Iterator<Item = Scope> {
self.filter_map(|item| {
if let AssetItem::Scope(scope) = item {
Some(scope)
} else {
None
}
})
}
}
impl<'a> Iterator for Refresh<'a> {
    type Item = AssetItem;

    /// Yields a clone of the next pending document; once the documents are exhausted,
    /// yields a clone of the next pending scope; returns `None` when both are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(document) = self.assets.documents.get(self.doc_idx) {
            self.doc_idx += 1;
            return Some(AssetItem::Document(document.clone()));
        }
        let scope = self.assets.scopes.get(self.scope_idx)?;
        self.scope_idx += 1;
        Some(AssetItem::Scope(scope.clone()))
    }
}
/// Cloneable handle for queueing documents and scopes on an asset channel.
#[derive(Debug, Clone)]
pub struct AssetSender {
/// Sending half of the unbounded asset channel.
inner: UnboundedSender<AssetItem>,
}
impl AssetSender {
pub fn send_document(&self, document: Document) -> Result<(), SendError<Document>> {
self.inner.send(AssetItem::Document(document)).map_err(|e| match e.0 {
AssetItem::Document(document) => SendError(document),
_ => unreachable!(),
})
}
pub fn send_scope(&self, scope: Scope) -> Result<(), SendError<Scope>> {
self.inner.send(AssetItem::Scope(scope)).map_err(|e| match e.0 {
AssetItem::Scope(folder) => SendError(folder),
_ => unreachable!(),
})
}
}
/// An asset travelling over the asset channel or yielded by [`Refresh`].
pub enum AssetItem {
/// A document asset.
Document(Document),
/// A scope asset.
Scope(Scope),
}
/// Errors a job callback can fail with.
///
/// Every variant except `Other` is marked `#[from]`, so the wrapped error type converts
/// into `RunJobError` automatically (for example via `?`).
#[derive(Debug, Error)]
pub enum RunJobError {
/// An extension could not be used.
#[error("Job failed due to extension error: {0}")]
Extension(#[from] UseExtensionError),
/// A chat operation failed.
#[error("Job failed due to chat error: {0}")]
Chat(#[from] ChatError),
/// An embedding operation failed.
#[error("Job failed due to embedding error: {0}")]
Embedding(#[from] EmbeddingError),
/// A content conversion failed.
#[error("Job failed due to conversion error: {0}")]
Conversion(#[from] ConversionError),
/// A prompter reported an error.
#[error("Job failed due to prompt error: {0}")]
Prompt(#[from] crate::chat::prompter::PromptError),
/// A storage operation failed.
#[error("Job failed due to storage error: {0}")]
Storage(#[from] storage::Error),
/// Any other failure; has no automatic conversion and must be constructed explicitly.
#[error("Job failed: {0}")]
Other(Box<dyn std::error::Error + Send + Sync>),
}