1use std::sync::Arc;
2
3use crate::error::ProviderError;
4use crate::middleware::{Middleware, Next};
5use crate::provider::Provider;
6use crate::types::{Request, Response, StreamResponse};
7
8pub struct AlpineClient {
9 provider: Arc<dyn Provider>,
10 middleware: Vec<Arc<dyn Middleware>>,
11}
12
13impl AlpineClient {
14 pub fn new(provider: impl Provider + 'static) -> Self {
15 Self {
16 provider: Arc::new(provider),
17 middleware: Vec::new(),
18 }
19 }
20
21 pub fn with_middleware(mut self, m: impl Middleware + 'static) -> Self {
22 self.middleware.push(Arc::new(m));
23 self
24 }
25
26 pub async fn complete(&self, req: Request) -> Result<Response, ProviderError> {
28 let provider = Arc::clone(&self.provider);
29
30 let core: Next = Box::new(move |r| Box::pin(async move { provider.complete(&r).await }));
31
32 let chain = self.middleware.iter().rev().fold(core, |next, mw| {
33 let mw = Arc::clone(mw);
34 Box::new(move |r| mw.handle(r, next))
35 });
36
37 chain(req).await
38 }
39
40 pub async fn stream(&self, req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
43 self.provider.stream(req).await
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProviderError;
    use crate::middleware::Middleware;
    use crate::provider::Provider;
    use crate::types::*;
    use async_trait::async_trait;
    use futures::StreamExt;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Provider that returns a fixed completion (`content`, 1 input / 2 output
    /// tokens) and a fixed stream of `Delta("hello")` followed by `Done`.
    struct StubProvider {
        model: ModelId,
        content: String,
    }

    impl StubProvider {
        fn new(content: &str) -> Self {
            Self {
                model: ModelId::new("stub"),
                content: content.into(),
            }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        async fn complete(&self, _req: &Request) -> Result<Response, ProviderError> {
            Ok(Response {
                content: self.content.clone(),
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 2,
                },
                model: self.model.clone(),
                finish_reason: FinishReason::Stop,
                latency: Duration::ZERO,
                raw: serde_json::Value::Null,
            })
        }

        async fn stream(&self, _req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
            let chunks = vec![
                StreamChunk::Delta("hello".into()),
                StreamChunk::Done { usage: None },
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }

        fn model_id(&self) -> &ModelId {
            &self.model
        }
    }

    /// Middleware that forwards the request, then appends `suffix` to the
    /// response content, which makes the unwinding order observable.
    struct AppendMiddleware {
        suffix: String,
    }

    impl AppendMiddleware {
        fn new(suffix: &str) -> Self {
            Self {
                suffix: suffix.into(),
            }
        }
    }

    impl Middleware for AppendMiddleware {
        fn handle(
            self: Arc<Self>,
            req: Request,
            next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async move {
                let mut resp = next(req).await?;
                resp.content.push_str(&self.suffix);
                Ok(resp)
            })
        }
    }

    /// Middleware that rejects every request without calling `next`.
    struct ErrorMiddleware;

    impl Middleware for ErrorMiddleware {
        fn handle(
            self: Arc<Self>,
            _req: Request,
            _next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async { Err(ProviderError::Other("middleware error".into())) })
        }
    }

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn client_new() {
        let _client = AlpineClient::new(StubProvider::new("x"));
    }

    #[tokio::test]
    async fn complete_no_middleware() {
        let client = AlpineClient::new(StubProvider::new("hello"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.usage.input_tokens, 1);
        assert_eq!(resp.usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn complete_with_one_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "base [m1]");
    }

    // The first-registered middleware is outermost, so it sees the response last
    // and its suffix ends up at the very end.
    #[tokio::test]
    async fn complete_with_two_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [first]"))
            .with_middleware(AppendMiddleware::new(" [second]"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "base [second] [first]");
    }

    #[tokio::test]
    async fn complete_middleware_error() {
        let client = AlpineClient::new(StubProvider::new("x")).with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn stream_bypasses_middleware() {
        // ErrorMiddleware rejects everything it handles, so the stream below can
        // only succeed if `stream` never enters the middleware chain.
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [mod]"))
            .with_middleware(ErrorMiddleware);

        // Sanity check: the very same chain does reject `complete`.
        assert!(client.complete(Request::default()).await.is_err());

        let mut stream = client.stream(&Request::default()).await.unwrap();

        let first = stream.next().await.unwrap();
        match first {
            StreamChunk::Delta(text) => assert_eq!(text, "hello"),
            other => panic!("expected Delta, got: {other:?}"),
        }

        let second = stream.next().await.unwrap();
        match second {
            StreamChunk::Done { .. } => {}
            other => panic!("expected Done, got: {other:?}"),
        }

        assert!(stream.next().await.is_none());
    }
}