1use async_trait::async_trait;
2
3use crate::error::Result;
4use crate::stream::ResponseStream;
5use crate::types::{ModelInfo, Prompt};
6
7#[cfg(not(target_arch = "wasm32"))]
8#[async_trait]
9pub trait Provider: Send + Sync {
10 fn id(&self) -> &str;
11 fn models(&self) -> Vec<ModelInfo>;
12
13 fn needs_key(&self) -> Option<&str> {
14 None
15 }
16
17 fn key_env_var(&self) -> Option<&str> {
18 None
19 }
20
21 async fn execute(
22 &self,
23 model: &str,
24 prompt: &Prompt,
25 key: Option<&str>,
26 stream: bool,
27 ) -> Result<ResponseStream>;
28}
29
/// A backend that can run prompts against one or more models (`wasm32` build).
///
/// Same methods as the native definition of this trait, but without the
/// `Send + Sync` supertraits, and `#[async_trait(?Send)]` means the future
/// returned by [`Provider::execute`] is not required to be `Send`.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait Provider {
    /// Returns this provider's identifier.
    fn id(&self) -> &str;

    /// Returns metadata for every model this provider exposes.
    fn models(&self) -> Vec<ModelInfo>;

    /// Runs `prompt` against `model` and returns the response as a chunk stream.
    ///
    /// * `key` - the API key supplied by the caller, if any.
    /// * `stream` - whether the caller asked for streamed output; how (or
    ///   whether) this is honoured is up to the implementation.
    ///
    /// # Errors
    ///
    /// Implementation-defined. For example, a provider that requires a key may
    /// return `LlmError::NeedsKey` when `key` is `None`.
    async fn execute(
        &self,
        model: &str,
        prompt: &Prompt,
        key: Option<&str>,
        stream: bool,
    ) -> Result<ResponseStream>;

    /// Alias of the API key this provider requires, or `None` if it needs none.
    ///
    /// The default implementation returns `None`.
    fn needs_key(&self) -> Option<&str> {
        None
    }

    /// Name of the environment variable associated with this provider's key.
    ///
    /// The default implementation returns `None`.
    fn key_env_var(&self) -> Option<&str> {
        None
    }
}
52
// The mock providers below implement the `Send` flavour of `#[async_trait]`
// and the async tests run on `#[tokio::test]`; neither matches the `?Send`
// `wasm32` definition of `Provider`, so this module is native-only.
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use crate::error::LlmError;
    use crate::stream::Chunk;
    use futures::StreamExt;

    /// Keyless provider that yields a fixed text chunk followed by `Done`.
    struct MockProvider;

    #[async_trait]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            "mock"
        }

        fn models(&self) -> Vec<ModelInfo> {
            vec![
                ModelInfo::new("mock-fast"),
                ModelInfo {
                    id: "mock-smart".into(),
                    can_stream: true,
                    supports_tools: true,
                    supports_schema: true,
                    attachment_types: vec!["image/png".into()],
                },
            ]
        }

        async fn execute(
            &self,
            _model: &str,
            _prompt: &Prompt,
            _key: Option<&str>,
            _stream: bool,
        ) -> Result<ResponseStream> {
            let chunks = vec![
                Ok(Chunk::Text("Hello from mock".into())),
                Ok(Chunk::Done),
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }
    }

    /// Provider that requires a key and echoes it back as `key=<value>`.
    struct KeyProvider;

    #[async_trait]
    impl Provider for KeyProvider {
        fn id(&self) -> &str {
            "key-provider"
        }

        fn models(&self) -> Vec<ModelInfo> {
            vec![ModelInfo::new("key-model")]
        }

        fn needs_key(&self) -> Option<&str> {
            Some("test_key")
        }

        fn key_env_var(&self) -> Option<&str> {
            Some("TEST_API_KEY")
        }

        async fn execute(
            &self,
            _model: &str,
            _prompt: &Prompt,
            key: Option<&str>,
            _stream: bool,
        ) -> Result<ResponseStream> {
            let key = key.ok_or_else(|| LlmError::NeedsKey("test_key required".into()))?;
            let chunks = vec![
                Ok(Chunk::Text(format!("key={key}"))),
                Ok(Chunk::Done),
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }
    }

    #[test]
    fn provider_id() {
        let p = MockProvider;
        assert_eq!(p.id(), "mock");
    }

    #[test]
    fn provider_lists_models() {
        let p = MockProvider;
        let models = p.models();
        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "mock-fast");
        assert!(models[1].supports_tools);
    }

    #[test]
    fn provider_needs_key_defaults_to_none() {
        let p = MockProvider;
        assert_eq!(p.needs_key(), None);
        assert_eq!(p.key_env_var(), None);
    }

    #[test]
    fn provider_needs_key_returns_alias() {
        let p = KeyProvider;
        assert_eq!(p.needs_key(), Some("test_key"));
        assert_eq!(p.key_env_var(), Some("TEST_API_KEY"));
    }

    #[tokio::test]
    async fn provider_execute_returns_stream() {
        let p = MockProvider;
        let prompt = Prompt::new("Hello");
        let stream = p.execute("mock-fast", &prompt, None, true).await.unwrap();
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        if let Ok(Chunk::Text(t)) = &chunks[0] {
            assert_eq!(t, "Hello from mock");
        } else {
            panic!("expected Text chunk");
        }
        assert!(matches!(chunks.last(), Some(Ok(Chunk::Done))));
    }

    #[tokio::test]
    async fn provider_execute_with_key() {
        let p = KeyProvider;
        let prompt = Prompt::new("Hello");
        let stream = p
            .execute("key-model", &prompt, Some("sk-test"), true)
            .await
            .unwrap();
        let chunks: Vec<_> = stream.collect().await;
        // Check the length first so a short stream fails with a clear message
        // instead of an index-out-of-bounds panic.
        assert_eq!(chunks.len(), 2);
        if let Ok(Chunk::Text(t)) = &chunks[0] {
            assert_eq!(t, "key=sk-test");
        } else {
            panic!("expected Text chunk");
        }
        assert!(matches!(chunks.last(), Some(Ok(Chunk::Done))));
    }

    #[tokio::test]
    async fn provider_execute_without_key_errors() {
        let p = KeyProvider;
        let prompt = Prompt::new("Hello");
        // Exhaustive match so a failure says which case was hit, without
        // requiring `Debug` on the stream or error types.
        match p.execute("key-model", &prompt, None, true).await {
            Err(LlmError::NeedsKey(msg)) => assert!(msg.contains("test_key")),
            Err(_) => panic!("expected NeedsKey error, got a different error"),
            Ok(_) => panic!("expected NeedsKey error, got a stream"),
        }
    }
}