1use std::sync::Arc;
2
3use crate::error::ProviderError;
4use crate::middleware::{Middleware, Next};
5use crate::provider::Provider;
6use crate::types::{Request, Response, StreamResponse};
7
8pub struct AlpineClient {
9 provider: Arc<dyn Provider>,
10 middleware: Vec<Arc<dyn Middleware>>,
11}
12
13impl AlpineClient {
14 pub fn new(provider: impl Provider + 'static) -> Self {
15 Self {
16 provider: Arc::new(provider),
17 middleware: Vec::new(),
18 }
19 }
20
21 pub fn with_middleware(mut self, m: impl Middleware + 'static) -> Self {
22 self.middleware.push(Arc::new(m));
23 self
24 }
25
26 pub async fn complete(&self, req: Request) -> Result<Response, ProviderError> {
28 let provider = Arc::clone(&self.provider);
29
30 let core: Next = Box::new(move |r| Box::pin(async move { provider.complete(&r).await }));
31
32 let chain = self.middleware.iter().rev().fold(core, |next, mw| {
33 let mw = Arc::clone(mw);
34 Box::new(move |r| mw.handle(r, next))
35 });
36
37 chain(req).await
38 }
39
40 pub async fn stream(&self, req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
43 self.provider.stream(req).await
44 }
45}
46
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ProviderError;
    use crate::middleware::Middleware;
    use crate::provider::Provider;
    use crate::types::*;
    use async_trait::async_trait;
    use futures::StreamExt;
    use std::future::Future;
    use std::pin::Pin;
    use std::time::Duration;

    /// Provider that returns a fixed completion (`content`, 1 input / 2 output tokens) and a
    /// fixed two-chunk stream (`Delta("hello")` then `Done`).
    struct StubProvider {
        model: ModelId,
        content: String,
    }

    impl StubProvider {
        fn new(content: &str) -> Self {
            Self {
                model: ModelId::new("stub"),
                content: content.into(),
            }
        }
    }

    #[async_trait]
    impl Provider for StubProvider {
        async fn complete(&self, _req: &Request) -> Result<Response, ProviderError> {
            Ok(Response {
                content: self.content.clone(),
                tool_calls: vec![],
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 2,
                },
                model: self.model.clone(),
                finish_reason: FinishReason::Stop,
                latency: Duration::ZERO,
                raw: serde_json::Value::Null,
            })
        }

        async fn stream(&self, _req: &Request) -> Result<StreamResponse<'_>, ProviderError> {
            let chunks = vec![
                StreamChunk::Delta("hello".into()),
                StreamChunk::Done { usage: None },
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }

        fn model_id(&self) -> &ModelId {
            &self.model
        }
    }

    /// Middleware that calls the next layer and appends `suffix` to the response content.
    /// Errors from the next layer are propagated unchanged via `?`.
    struct AppendMiddleware {
        suffix: String,
    }

    impl AppendMiddleware {
        fn new(suffix: &str) -> Self {
            Self {
                suffix: suffix.into(),
            }
        }
    }

    impl Middleware for AppendMiddleware {
        fn handle(
            self: Arc<Self>,
            req: Request,
            next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async move {
                let mut resp = next(req).await?;
                resp.content.push_str(&self.suffix);
                Ok(resp)
            })
        }
    }

    /// Middleware that fails immediately without calling the next layer.
    struct ErrorMiddleware;

    impl Middleware for ErrorMiddleware {
        fn handle(
            self: Arc<Self>,
            _req: Request,
            _next: crate::middleware::Next,
        ) -> Pin<Box<dyn Future<Output = Result<Response, ProviderError>> + Send>> {
            Box::pin(async { Err(ProviderError::Other("middleware error".into())) })
        }
    }

    // Construction does no async work, so no runtime is needed.
    #[test]
    fn client_new() {
        let _client = AlpineClient::new(StubProvider::new("x"));
    }

    #[tokio::test]
    async fn complete_no_middleware() {
        let client = AlpineClient::new(StubProvider::new("hello"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.usage.input_tokens, 1);
        assert_eq!(resp.usage.output_tokens, 2);
    }

    #[tokio::test]
    async fn complete_with_one_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        let resp = client.complete(Request::default()).await.unwrap();
        assert_eq!(resp.content, "base [m1]");
    }

    #[tokio::test]
    async fn complete_with_two_middleware() {
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [first]"))
            .with_middleware(AppendMiddleware::new(" [second]"));
        let resp = client.complete(Request::default()).await.unwrap();
        // The first-registered middleware is outermost, so it sees the response last and its
        // suffix is appended after the inner one's.
        assert_eq!(resp.content, "base [second] [first]");
    }

    #[tokio::test]
    async fn complete_middleware_error() {
        let client = AlpineClient::new(StubProvider::new("x")).with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn complete_error_propagates_through_outer_middleware() {
        // AppendMiddleware is outermost; the inner ErrorMiddleware's failure must surface
        // through it rather than being swallowed or decorated.
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [outer]"))
            .with_middleware(ErrorMiddleware);
        let err = client.complete(Request::default()).await.unwrap_err();
        assert!(err.to_string().contains("middleware error"));
    }

    #[tokio::test]
    async fn complete_is_repeatable() {
        // The chain is rebuilt per call, so the same client must produce identical results
        // across multiple requests.
        let client = AlpineClient::new(StubProvider::new("base"))
            .with_middleware(AppendMiddleware::new(" [m1]"));
        for _ in 0..2 {
            let resp = client.complete(Request::default()).await.unwrap();
            assert_eq!(resp.content, "base [m1]");
        }
    }

    #[tokio::test]
    async fn stream_bypasses_middleware() {
        let client = AlpineClient::new(StubProvider::new("x"))
            .with_middleware(AppendMiddleware::new(" [mod]"));
        let mut stream = client.stream(&Request::default()).await.unwrap();

        let first = stream.next().await.unwrap();
        match first {
            StreamChunk::Delta(text) => assert_eq!(text, "hello"),
            other => panic!("expected Delta, got: {other:?}"),
        }

        let second = stream.next().await.unwrap();
        match second {
            StreamChunk::Done { .. } => {}
            other => panic!("expected Done, got: {other:?}"),
        }

        assert!(stream.next().await.is_none());
    }
}