1use std::sync::Arc;
15
16use async_trait::async_trait;
17use futures::stream::{BoxStream, StreamExt};
18
19use crate::{
20 Chunk, CompletionRequest, LlmProvider,
21 error::{LlmError, LlmErrorKind},
22};
23
24#[derive(Debug)]
33pub struct BoxError(
34 Box<dyn std::error::Error + Send + Sync + 'static>,
35 LlmErrorKind,
36);
37
38impl BoxError {
39 #[must_use]
42 pub fn new<E: LlmError>(err: E) -> Self {
43 let kind = err.kind();
44 Self(Box::new(err), kind)
45 }
46}
47
48impl LlmError for BoxError {
49 fn kind(&self) -> LlmErrorKind {
50 self.1
51 }
52}
53
54impl std::fmt::Display for BoxError {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 std::fmt::Display::fmt(&self.0, f)
57 }
58}
59
60impl std::error::Error for BoxError {
61 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
62 Some(&*self.0)
63 }
64}
65
/// Adapter that wraps a concrete [`LlmProvider`] and re-exposes it with its
/// error type erased to [`BoxError`]. The field is private, so outside this
/// module instances come from [`into_dyn`].
pub struct ErasedProvider<Inner>(Inner);
72
73#[async_trait]
74impl<P: LlmProvider> LlmProvider for ErasedProvider<P> {
75 type Error = BoxError;
76
77 async fn complete(
78 &self,
79 req: CompletionRequest,
80 ) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
81 let stream = self.0.complete(req).await.map_err(BoxError::new)?;
82 Ok(stream.map(|item| item.map_err(BoxError::new)).boxed())
83 }
84}
85
86pub type DynProvider = dyn LlmProvider<Error = BoxError>;
89
90#[must_use]
95pub fn into_dyn<P: LlmProvider>(provider: P) -> Arc<DynProvider> {
96 Arc::new(ErasedProvider(provider))
97}
98
#[cfg(test)]
mod tests {
    // Test code deliberately favours brevity (panics, `expect`, undocumented
    // helpers) over pedantic/nursery lint hygiene and doc coverage.
    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]

    use futures::{StreamExt, stream};

    use super::{BoxError, DynProvider, into_dyn};
    use crate::{
        Chunk, CompletionRequest, LlmProvider, StopReason, Usage,
        error::{DummyError, LlmError, LlmErrorKind},
    };

    /// Rejects requests with no messages; otherwise streams a fixed,
    /// successful three-chunk reply.
    struct DummyProvider;

    #[async_trait::async_trait]
    impl LlmProvider for DummyProvider {
        type Error = DummyError;

        async fn complete(
            &self,
            req: CompletionRequest,
        ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error>
        {
            if req.messages.is_empty() {
                return Err(DummyError::Other("no messages".to_owned()));
            }
            let chunks = vec![
                Ok(Chunk::text_delta("hi")),
                Ok(Chunk::Usage(Usage {
                    input_tokens: 1,
                    output_tokens: 1,
                })),
                Ok(Chunk::Stop(StopReason::EndTurn)),
            ];
            Ok(stream::iter(chunks).boxed())
        }
    }

    /// Opens its stream successfully but fails on the second item. This
    /// exercises the per-chunk error mapping of the erased provider.
    struct MidStreamFailureProvider;

    #[async_trait::async_trait]
    impl LlmProvider for MidStreamFailureProvider {
        type Error = DummyError;

        async fn complete(
            &self,
            _req: CompletionRequest,
        ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error>
        {
            let chunks = vec![
                Ok(Chunk::text_delta("partial")),
                Err(DummyError::Provider {
                    status: 429,
                    body: String::new(),
                }),
            ];
            Ok(stream::iter(chunks).boxed())
        }
    }

    #[tokio::test]
    async fn erased_provider_streams_to_completion() {
        let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
        let mut req = CompletionRequest::new("m");
        req.messages.push(crate::Message::user("yo"));

        let stream = provider.complete(req).await.expect("stream opens");
        let items: Vec<_> = stream.collect().await;
        assert_eq!(items.len(), 3);
        assert!(items.iter().all(Result::is_ok), "every chunk should be Ok");
    }

    #[tokio::test]
    async fn erased_pre_stream_error_is_preserved() {
        let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
        let req = CompletionRequest::new("m");
        let Err(err) = provider.complete(req).await else {
            panic!("expected pre-stream rejection");
        };
        assert_eq!(format!("{err}"), "other: no messages");
        assert_eq!(err.kind(), LlmErrorKind::Other);

        // The wrapped error is exposed as the source and renders the same text.
        let src = std::error::Error::source(&err).expect("source present");
        assert_eq!(format!("{src}"), "other: no messages");
    }

    #[tokio::test]
    async fn erased_mid_stream_error_keeps_kind() {
        let provider: std::sync::Arc<DynProvider> = into_dyn(MidStreamFailureProvider);
        let stream = provider
            .complete(CompletionRequest::new("m"))
            .await
            .expect("stream opens");
        let items: Vec<_> = stream.collect().await;
        assert_eq!(items.len(), 2);
        assert!(items[0].is_ok());
        let Err(err) = &items[1] else {
            panic!("expected the second chunk to be an error");
        };
        assert_eq!(err.kind(), LlmErrorKind::RateLimit);
    }

    #[test]
    fn box_error_satisfies_llm_error() {
        fn require_llm_error<E: LlmError>() {}
        require_llm_error::<BoxError>();
    }

    #[test]
    fn box_error_preserves_kind_through_erasure() {
        let boxed = BoxError::new(DummyError::Provider {
            status: 429,
            body: String::new(),
        });
        assert_eq!(boxed.kind(), LlmErrorKind::RateLimit);

        let other = BoxError::new(DummyError::Other("x".to_owned()));
        assert_eq!(other.kind(), LlmErrorKind::Other);
    }
}