Skip to main content

polyc_llm/
erased.rs

1//! Type erasure for [`LlmProvider`] so the control plane can hold a single
2//! `Arc<dyn LlmProvider>` regardless of which backend is configured.
3//!
4//! The trait keeps a per-provider associated [`LlmError`] type (each backend
5//! ships its own concrete error). A trait object must fix that associated
6//! type, so this module supplies one uniform error — [`BoxError`] — and an
7//! [`ErasedProvider`] adapter that maps any provider's error into it. The
8//! result is [`DynProvider`], the single trait-object type callers store and
9//! dispatch through.
10//!
11//! Adding a backend then costs one trait impl plus one [`into_dyn`] call at the
12//! wiring boundary — no change to the dispatch site or the planner.
13
14use std::sync::Arc;
15
16use async_trait::async_trait;
17use futures::stream::{BoxStream, StreamExt};
18
19use crate::{
20    Chunk, CompletionRequest, LlmProvider,
21    error::{LlmError, LlmErrorKind},
22};
23
24/// A provider error erased to one concrete type, so backends with differing
25/// associated `Error`s can be stored behind a single trait object.
26///
27/// Transparent wrapper: [`Display`](std::fmt::Display) delegates to the inner
28/// error and [`source`](std::error::Error::source) exposes it, so logs and
29/// error chains read exactly as the un-erased error did. The second field
30/// preserves the original [`LlmErrorKind`] across erasure (the boxed `dyn Error`
31/// alone could not be re-classified).
32#[derive(Debug)]
33pub struct BoxError(
34    Box<dyn std::error::Error + Send + Sync + 'static>,
35    LlmErrorKind,
36);
37
38impl BoxError {
39    /// Erase any [`LlmError`] into a `BoxError`, capturing its
40    /// [`kind`](LlmError::kind) so the classification survives erasure.
41    #[must_use]
42    pub fn new<E: LlmError>(err: E) -> Self {
43        let kind = err.kind();
44        Self(Box::new(err), kind)
45    }
46}
47
48impl LlmError for BoxError {
49    fn kind(&self) -> LlmErrorKind {
50        self.1
51    }
52}
53
54impl std::fmt::Display for BoxError {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        std::fmt::Display::fmt(&self.0, f)
57    }
58}
59
60impl std::error::Error for BoxError {
61    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
62        Some(&*self.0)
63    }
64}
65
66/// Adapter that wraps a concrete [`LlmProvider`] and erases its associated
67/// error to [`BoxError`], so the wrapped value coerces to [`DynProvider`].
68///
69/// The completion stream is mapped lazily — each item's error is boxed as it
70/// arrives, preserving the bytes-as-they-arrive latency of the inner provider.
71pub struct ErasedProvider<P>(P);
72
73#[async_trait]
74impl<P: LlmProvider> LlmProvider for ErasedProvider<P> {
75    type Error = BoxError;
76
77    async fn complete(
78        &self,
79        req: CompletionRequest,
80    ) -> Result<BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error> {
81        let stream = self.0.complete(req).await.map_err(BoxError::new)?;
82        Ok(stream.map(|item| item.map_err(BoxError::new)).boxed())
83    }
84}
85
86/// The single trait-object provider type the control plane stores. Every
87/// concrete backend is erased to this via [`into_dyn`].
88pub type DynProvider = dyn LlmProvider<Error = BoxError>;
89
90/// Erase a concrete provider and wrap it in an `Arc` as a [`DynProvider`].
91///
92/// The one wiring-boundary call that lets a concrete backend be dispatched
93/// behind the runtime-swappable trait object.
94#[must_use]
95pub fn into_dyn<P: LlmProvider>(provider: P) -> Arc<DynProvider> {
96    Arc::new(ErasedProvider(provider))
97}
98
99#[cfg(test)]
100mod tests {
101    #![allow(clippy::pedantic, clippy::nursery, missing_docs)]
102
103    use futures::{StreamExt, stream};
104
105    use super::{BoxError, DynProvider, into_dyn};
106    use crate::{Chunk, CompletionRequest, LlmProvider, StopReason, Usage, error::DummyError};
107
108    /// Provider whose `Error` is `DummyError` — a different concrete type than
109    /// `BoxError`, so erasing it actually exercises the conversion.
110    struct DummyProvider;
111
112    #[async_trait::async_trait]
113    impl LlmProvider for DummyProvider {
114        type Error = DummyError;
115
116        async fn complete(
117            &self,
118            req: CompletionRequest,
119        ) -> Result<futures::stream::BoxStream<'static, Result<Chunk, Self::Error>>, Self::Error>
120        {
121            if req.messages.is_empty() {
122                return Err(DummyError::Other("no messages".to_owned()));
123            }
124            let chunks = vec![
125                Ok(Chunk::text_delta("hi")),
126                Ok(Chunk::Usage(Usage {
127                    input_tokens: 1,
128                    output_tokens: 1,
129                })),
130                Ok(Chunk::Stop(StopReason::EndTurn)),
131            ];
132            Ok(stream::iter(chunks).boxed())
133        }
134    }
135
136    #[tokio::test]
137    async fn erased_provider_streams_to_completion() {
138        let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
139        let mut req = CompletionRequest::new("m");
140        req.messages.push(crate::Message::user("yo"));
141
142        let stream = provider.complete(req).await.expect("stream opens");
143        let n = stream.count().await;
144        assert_eq!(n, 3);
145    }
146
147    #[tokio::test]
148    async fn erased_pre_stream_error_is_preserved() {
149        let provider: std::sync::Arc<DynProvider> = into_dyn(DummyProvider);
150        let req = CompletionRequest::new("m"); // no messages → pre-stream error
151
152        // `Ok` here is a stream (not `Debug`), so match rather than `expect_err`.
153        let Err(err) = provider.complete(req).await else {
154            panic!("expected pre-stream rejection");
155        };
156        // Display delegates to the inner DummyError's message.
157        assert_eq!(format!("{err}"), "other: no messages");
158        // The original error is exposed as the source of the chain.
159        let src = std::error::Error::source(&err).expect("source present");
160        assert_eq!(format!("{src}"), "other: no messages");
161    }
162
163    #[test]
164    fn box_error_satisfies_llm_error() {
165        fn require_llm_error<E: crate::error::LlmError>() {}
166        require_llm_error::<BoxError>();
167    }
168
169    #[test]
170    fn box_error_preserves_kind_through_erasure() {
171        use crate::error::{DummyError, LlmError, LlmErrorKind};
172        // A 429 provider error keeps its RateLimit kind after erasure.
173        let boxed = BoxError::new(DummyError::Provider {
174            status: 429,
175            body: String::new(),
176        });
177        assert_eq!(boxed.kind(), LlmErrorKind::RateLimit);
178        // The default (unclassified) kind also round-trips.
179        let other = BoxError::new(DummyError::Other("x".to_owned()));
180        assert_eq!(other.kind(), LlmErrorKind::Other);
181    }
182}