Skip to main content

alpine/
client.rs

1use std::sync::Arc;
2
3use crate::error::ProviderError;
4use crate::middleware::{Middleware, Next};
5use crate::provider::Provider;
6use crate::types::{Request, Response, StreamResponse};
7
8pub struct AlpineClient {
9    provider: Arc<dyn Provider>,
10    middleware: Vec<Arc<dyn Middleware>>,
11}
12
13impl AlpineClient {
14    pub fn new(provider: impl Provider + 'static) -> Self {
15        Self {
16            provider: Arc::new(provider),
17            middleware: Vec::new(),
18        }
19    }
20
21    pub fn with_middleware(mut self, m: impl Middleware + 'static) -> Self {
22        self.middleware.push(Arc::new(m));
23        self
24    }
25
26    /// Run the request through the middleware chain, then the provider.
27    pub async fn complete(&self, req: Request) -> Result<Response, ProviderError> {
28        let provider = Arc::clone(&self.provider);
29
30        let core: Next = Box::new(move |r| Box::pin(async move { provider.complete(&r).await }));
31
32        let chain = self.middleware.iter().rev().fold(core, |next, mw| {
33            let mw = Arc::clone(mw);
34            Box::new(move |r| mw.handle(r, next))
35        });
36
37        chain(req).await
38    }
39
40    /// Stream bypasses middleware for now — middleware is request/response
41    /// oriented. Streaming middleware is a separate concern.
42    pub async fn stream(&self, req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
43        self.provider.stream(req).await
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProviderError;
    use crate::middleware::Middleware;
    use crate::provider::Provider;
    use crate::types::*;
    use async_trait::async_trait;
    use futures::StreamExt;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    // -- Stub provider ---------------------------------------------------------

    /// Provider that always answers with fixed content and fixed token usage.
    struct StubProvider {
        model: ModelId,
        content: String,
    }

    impl StubProvider {
        fn new(content: &str) -> Self {
            Self {
                model: ModelId::new("stub"),
                content: content.into(),
            }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        async fn complete(&self, _req: &Request) -> Result<Response, ProviderError> {
            Ok(Response {
                content: self.content.clone(),
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 2,
                },
                model: self.model.clone(),
                finish_reason: FinishReason::Stop,
                latency: Duration::ZERO,
                raw: serde_json::Value::Null,
            })
        }

        async fn stream(&self, _req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
            let chunks = vec![
                StreamChunk::Delta("hello".into()),
                StreamChunk::Done { usage: None },
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }

        fn model_id(&self) -> &ModelId {
            &self.model
        }
    }

    // -- Stub middleware --------------------------------------------------------

    /// Appends `suffix` to the response content after the rest of the chain has run.
    struct AppendMiddleware {
        suffix: String,
    }

    impl AppendMiddleware {
        fn new(suffix: &str) -> Self {
            Self {
                suffix: suffix.into(),
            }
        }
    }

    impl Middleware for AppendMiddleware {
        fn handle(
            self: Arc<Self>,
            req: Request,
            next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async move {
                let mut resp = next(req).await?;
                resp.content.push_str(&self.suffix);
                Ok(resp)
            })
        }
    }

    /// Fails immediately without ever invoking `next`.
    struct ErrorMiddleware;

    impl Middleware for ErrorMiddleware {
        fn handle(
            self: Arc<Self>,
            _req: Request,
            _next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async { Err(ProviderError::Other("middleware error".into())) })
        }
    }

    /// Records how many times it is entered, then delegates to the rest of the chain.
    struct CountingMiddleware {
        calls: Arc<AtomicUsize>,
    }

    impl Middleware for CountingMiddleware {
        fn handle(
            self: Arc<Self>,
            req: Request,
            next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async move {
                self.calls.fetch_add(1, Ordering::SeqCst);
                next(req).await
            })
        }
    }

    // -- Tests -----------------------------------------------------------------

    // Construction needs no async runtime, so a plain test suffices.
    #[test]
    fn client_new() {
        let _client = AlpineClient::new(StubProvider::new("x"));
    }

    #[tokio::test]
    async fn complete_no_middleware() {
        let client = AlpineClient::new(StubProvider::new("hello"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.usage.input_tokens, 1);
        assert_eq!(resp.usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn complete_with_one_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "base [m1]");
    }

    #[tokio::test]
    async fn complete_with_two_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [first]"))
            .with_middleware(AppendMiddleware::new(" [second]"));
        let resp = client.complete(Request::default()).await.unwrap();
        // Middleware wraps onion-style: first added is outermost.
        // Inner (second) runs first on response, then outer (first).
        assert_eq!(resp.content, "base [second] [first]");
    }

    #[tokio::test]
    async fn complete_middleware_error() {
        let client = AlpineClient::new(StubProvider::new("x")).with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn complete_middleware_error_short_circuits_inner_layers() {
        let calls = Arc::new(AtomicUsize::new(0));
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(ErrorMiddleware)
            .with_middleware(CountingMiddleware {
                calls: Arc::clone(&calls),
            });
        assert!(client.complete(Request::default()).await.is_err());
        // The outer error layer never calls `next`, so the inner layer must not run.
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn complete_error_propagates_through_outer_middleware() {
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [outer]"))
            .with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn complete_rebuilds_chain_per_call() {
        let calls = Arc::new(AtomicUsize::new(0));
        let client = AlpineClient::new(StubProvider::new("base")).with_middleware(
            CountingMiddleware {
                calls: Arc::clone(&calls),
            },
        );
        for expected in 1..=2 {
            let resp = client.complete(Request::default()).await.unwrap();
            assert_eq!(resp.content, "base");
            assert_eq!(calls.load(Ordering::SeqCst), expected);
        }
    }

    #[tokio::test]
    async fn stream_bypasses_middleware() {
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [mod]"));
        let mut stream = client.stream(&Request::default()).await.unwrap();

        let first = stream.next().await.unwrap();
        match first {
            StreamChunk::Delta(text) => assert_eq!(text, "hello"),
            other => panic!("expected Delta, got: {other:?}"),
        }

        let second = stream.next().await.unwrap();
        match second {
            StreamChunk::Done { .. } => {}
            other => panic!("expected Done, got: {other:?}"),
        }

        assert!(stream.next().await.is_none());
    }
}
207}