// alpine/client.rs

1use std::sync::Arc;
2
3use crate::error::ProviderError;
4use crate::middleware::{Middleware, Next};
5use crate::provider::Provider;
6use crate::types::{Request, Response, StreamResponse};
7
8pub struct AlpineClient {
9    provider: Arc<dyn Provider>,
10    middleware: Vec<Arc<dyn Middleware>>,
11}
12
13impl AlpineClient {
14    pub fn new(provider: impl Provider + 'static) -> Self {
15        Self {
16            provider: Arc::new(provider),
17            middleware: Vec::new(),
18        }
19    }
20
21    pub fn with_middleware(mut self, m: impl Middleware + 'static) -> Self {
22        self.middleware.push(Arc::new(m));
23        self
24    }
25
26    /// Run the request through the middleware chain, then the provider.
27    pub async fn complete(&self, req: Request) -> Result<Response, ProviderError> {
28        let provider = Arc::clone(&self.provider);
29
30        let core: Next = Box::new(move |r| Box::pin(async move { provider.complete(&r).await }));
31
32        let chain = self.middleware.iter().rev().fold(core, |next, mw| {
33            let mw = Arc::clone(mw);
34            Box::new(move |r| mw.handle(r, next))
35        });
36
37        chain(req).await
38    }
39
40    /// Stream bypasses middleware for now — middleware is request/response
41    /// oriented. Streaming middleware is a separate concern.
42    pub async fn stream(&self, req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
43        self.provider.stream(req).await
44    }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProviderError;
    use crate::middleware::Middleware;
    use crate::provider::Provider;
    use crate::types::*;
    use async_trait::async_trait;
    use futures::StreamExt;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    // -- Stub provider ---------------------------------------------------------

    /// Provider that answers every completion with fixed `content` and fixed
    /// usage (1 input token, 2 output tokens). It streams a single
    /// `Delta("hello")` followed by `Done`.
    struct StubProvider {
        model: ModelId,
        content: String,
    }

    impl StubProvider {
        fn new(content: &str) -> Self {
            Self {
                model: ModelId::new("stub"),
                content: content.into(),
            }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        async fn complete(&self, _req: &Request) -> Result<Response, ProviderError> {
            Ok(Response {
                content: self.content.clone(),
                tool_calls: vec![],
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 2,
                },
                model: self.model.clone(),
                finish_reason: FinishReason::Stop,
                latency: Duration::ZERO,
                raw: serde_json::Value::Null,
            })
        }

        async fn stream(&self, _req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
            let chunks = vec![
                StreamChunk::Delta("hello".into()),
                StreamChunk::Done { usage: None },
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }

        fn model_id(&self) -> &ModelId {
            &self.model
        }
    }

    // -- Stub middleware --------------------------------------------------------

    /// Middleware that appends `suffix` to the response content once the rest
    /// of the chain has returned successfully.
    struct AppendMiddleware {
        suffix: String,
    }

    impl AppendMiddleware {
        fn new(suffix: &str) -> Self {
            Self {
                suffix: suffix.into(),
            }
        }
    }

    impl Middleware for AppendMiddleware {
        fn handle(
            self: Arc<Self>,
            req: Request,
            next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async move {
                let mut resp = next(req).await?;
                resp.content.push_str(&self.suffix);
                Ok(resp)
            })
        }
    }

    /// Middleware that fails immediately without calling the rest of the chain.
    struct ErrorMiddleware;

    impl Middleware for ErrorMiddleware {
        fn handle(
            self: Arc<Self>,
            _req: Request,
            _next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async { Err(ProviderError::Other("middleware error".into())) })
        }
    }

    // -- Tests -----------------------------------------------------------------

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn client_new() {
        let _client = AlpineClient::new(StubProvider::new("x"));
    }

    #[tokio::test]
    async fn complete_no_middleware() {
        let client = AlpineClient::new(StubProvider::new("hello"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.usage.input_tokens, 1);
        assert_eq!(resp.usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn complete_with_one_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "base [m1]");
    }

    #[tokio::test]
    async fn complete_with_two_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [first]"))
            .with_middleware(AppendMiddleware::new(" [second]"));
        let resp = client.complete(Request::default()).await.unwrap();
        // Middleware wraps onion-style: first added is outermost.
        // Inner (second) runs first on response, then outer (first).
        assert_eq!(resp.content, "base [second] [first]");
    }

    #[tokio::test]
    async fn complete_can_be_called_repeatedly() {
        // The middleware chain is one-shot, so `complete` must rebuild it on
        // every call; the same client has to keep producing the same result.
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        for _ in 0..2 {
            let resp = client.complete(Request::default()).await.unwrap();
            assert_eq!(resp.content, "base [m1]");
        }
    }

    #[tokio::test]
    async fn complete_middleware_error() {
        let client = AlpineClient::new(StubProvider::new("x")).with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn complete_inner_error_propagates_through_outer_middleware() {
        // Outer AppendMiddleware wraps the failing inner one; its `?` must
        // surface the error instead of appending to a response.
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [outer]"))
            .with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn stream_bypasses_middleware() {
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [mod]"));
        let mut stream = client.stream(&Request::default()).await.unwrap();

        let first = stream.next().await.unwrap();
        match first {
            StreamChunk::Delta(text) => assert_eq!(text, "hello"),
            other => panic!("expected Delta, got: {other:?}"),
        }

        let second = stream.next().await.unwrap();
        match second {
            StreamChunk::Done { .. } => {}
            other => panic!("expected Done, got: {other:?}"),
        }

        assert!(stream.next().await.is_none());
    }
}