use std::sync::Arc;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use crate::{
Chunk, CompletionRequest, LlmProvider,
error::{LlmError, LlmErrorKind},
};
/// Type-erased [`LlmError`] used as the error type of [`DynProvider`].
///
/// Holds the original provider error behind a `Box<dyn Error>` together with
/// the [`LlmErrorKind`] that error reported when it was wrapped. Callers can
/// therefore still classify failures after the concrete error type is gone.
/// `Display` forwards to the wrapped error, and `source()` returns it.
#[derive(Debug)]
pub struct BoxError(
// The original, concrete provider error.
Box<dyn std::error::Error + Send + Sync + 'static>,
// Kind reported by the original error at wrap time (see `BoxError::new`).
LlmErrorKind,
);
impl BoxError {
    /// Erases `err` into a [`BoxError`].
    ///
    /// The error's [`LlmErrorKind`] is recorded first, so [`LlmError::kind`]
    /// keeps working once the concrete type is gone.
    ///
    /// If `err` is already a `BoxError`, it is returned unchanged rather than
    /// boxed a second time. This happens, for example, when a provider whose
    /// error type is `BoxError` is erased again via [`into_dyn`]. Re-wrapping
    /// would make `source()` return the intermediate `BoxError` instead of
    /// the original provider error. Error reporters would then also print the
    /// same message once per layer.
    #[must_use]
    pub fn new<E: LlmError>(err: E) -> Self {
        let kind = err.kind();
        let boxed: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(err);
        match boxed.downcast::<Self>() {
            // Already erased: keep the existing box and its recorded kind.
            Ok(already_erased) => *already_erased,
            Err(original) => Self(original, kind),
        }
    }
}
impl LlmError for BoxError {
    /// Reports the [`LlmErrorKind`] recorded when the original error was
    /// erased. The wrapped error is not asked again.
    fn kind(&self) -> LlmErrorKind {
        let Self(_, recorded) = self;
        *recorded
    }
}
impl std::fmt::Display for BoxError {
    /// Renders exactly what the wrapped error renders. The same formatter is
    /// handed through, so width, fill and precision flags still apply.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wrapped: &(dyn std::error::Error + Send + Sync) = &*self.0;
        std::fmt::Display::fmt(wrapped, f)
    }
}
impl std::error::Error for BoxError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&*self.0)
}
}
/// Adapter that implements [`LlmProvider`] for any provider `P` by erasing
/// `P::Error` into [`BoxError`].
///
/// The wrapped provider field is private, so this adapter cannot be built
/// outside this module. [`into_dyn`] wraps a provider in it and returns it as
/// an `Arc<DynProvider>`.
#[derive(Debug)]
pub struct ErasedProvider<P>(P);
#[async_trait]
impl<P: LlmProvider> LlmProvider for ErasedProvider<P> {
    type Error = BoxError;

    /// Delegates to the wrapped provider and erases its error type.
    ///
    /// A rejection before the stream opens, and every error yielded by the
    /// stream, is converted with [`BoxError::new`]. That conversion keeps the
    /// original error's kind. Successful chunks pass through untouched.
    async fn complete(
        &self,
        req: CompletionRequest,
    ) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
        match self.0.complete(req).await {
            Ok(chunks) => Ok(chunks
                .map(|chunk| chunk.map_err(BoxError::new))
                .boxed()),
            Err(rejection) => Err(BoxError::new(rejection)),
        }
    }
}
/// Trait-object provider type whose error is always [`BoxError`].
///
/// Lets providers with different concrete error types be stored and called
/// through a single type. Obtain one with [`into_dyn`].
pub type DynProvider = dyn LlmProvider<Error = BoxError>;
#[must_use]
pub fn into_dyn<P: LlmProvider>(provider: P) -> Arc<DynProvider> {
Arc::new(ErasedProvider(provider))
}
// Unit tests for the type-erasure layer: streaming through a `DynProvider`,
// pre-stream error preservation, and kind preservation across erasure.
#[cfg(test)]
mod tests {
#![allow(clippy::pedantic, clippy::nursery, missing_docs)]
use futures::{StreamExt, stream};
use super::{BoxError, DynProvider, into_dyn};
use crate::{Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError};
// Test double: rejects requests with no messages before opening a stream;
// otherwise yields a fixed three-chunk stream (text, usage, stop).
struct DummyProvider;
#[async_trait::async_trait]
impl LlmProvider for DummyProvider {
type Error = DummyError;
async fn complete(
&self,
req: CompletionRequest,
) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error>
{
// Pre-stream rejection path exercised by `erased_pre_stream_error_is_preserved`.
if req.messages.is_empty() {
return Err(DummyError::Other("no messages".to_owned()));
}
let chunks = vec![
Ok(Chunk::text_delta("hi")),
Ok(Chunk::Usage(Usage {
input_tokens: 1,
output_tokens: 1,
})),
Ok(Chunk::Stop(StopReason::EndTurn)),
];
Ok(stream::iter(chunks).boxed())
}
}
// The erased provider must forward every chunk of the inner stream.
#[tokio::test]
async fn erased_provider_streams_to_completion() {
let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
let mut req = CompletionRequest::new("m");
req.messages.push(crate::Message::user("yo"));
let stream = provider.complete(req).await.expect("stream opens");
let n = stream.count().await;
assert_eq!(n, 3);
}
// A pre-stream error keeps its message after erasure, and the original
// error is reachable through `source()`.
#[tokio::test]
async fn erased_pre_stream_error_is_preserved() {
let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
let req = CompletionRequest::new("m");
let Err(err) = provider.complete(req).await else {
panic!("expected pre-stream rejection");
};
assert_eq!(format!("{err}"), "other: no messages");
let src = std::error::Error::source(&err).expect("source present");
assert_eq!(format!("{src}"), "other: no messages");
}
// Compile-time check: `BoxError` satisfies the `LlmError` bound.
#[test]
fn box_error_satisfies_llm_error() {
fn require_llm_error<E: crate::error::LlmError>() {}
require_llm_error::<BoxError>();
}
// `BoxError::new` records the wrapped error's kind (here a 429 status and
// an `Other` error) so `kind()` still reports it after erasure.
#[test]
fn box_error_preserves_kind_through_erasure() {
use crate::error::{DummyError, LlmError, LlmErrorKind};
let boxed = BoxError::new(DummyError::Provider {
status: 429,
body: String::new(),
});
assert_eq!(boxed.kind(), LlmErrorKind::RateLimit);
let other = BoxError::new(DummyError::Other("x".to_owned()));
assert_eq!(other.kind(), LlmErrorKind::Other);
}
}