// llm_core/provider.rs

1use async_trait::async_trait;
2
3use crate::error::Result;
4use crate::stream::ResponseStream;
5use crate::types::{ModelInfo, Prompt};
6
7#[cfg(not(target_arch = "wasm32"))]
8#[async_trait]
9pub trait Provider: Send + Sync {
10    fn id(&self) -> &str;
11    fn models(&self) -> Vec<ModelInfo>;
12
13    fn needs_key(&self) -> Option<&str> {
14        None
15    }
16
17    fn key_env_var(&self) -> Option<&str> {
18        None
19    }
20
21    async fn execute(
22        &self,
23        model: &str,
24        prompt: &Prompt,
25        key: Option<&str>,
26        stream: bool,
27    ) -> Result<ResponseStream>;
28}
29
30#[cfg(target_arch = "wasm32")]
31#[async_trait(?Send)]
32pub trait Provider {
33    fn id(&self) -> &str;
34    fn models(&self) -> Vec<ModelInfo>;
35
36    fn needs_key(&self) -> Option<&str> {
37        None
38    }
39
40    fn key_env_var(&self) -> Option<&str> {
41        None
42    }
43
44    async fn execute(
45        &self,
46        model: &str,
47        prompt: &Prompt,
48        key: Option<&str>,
49        stream: bool,
50    ) -> Result<ResponseStream>;
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::LlmError;
    use crate::stream::Chunk;
    use futures::StreamExt;

    // A keyless mock provider: always streams one text chunk followed by `Done`.
    struct MockProvider;

    // These attributes mirror the trait's: `?Send` futures on wasm32, `Send`
    // futures elsewhere. A plain `#[async_trait]` would not match the trait's
    // method signatures on wasm32.
    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    impl Provider for MockProvider {
        fn id(&self) -> &str {
            "mock"
        }

        fn models(&self) -> Vec<ModelInfo> {
            vec![
                ModelInfo::new("mock-fast"),
                ModelInfo {
                    id: "mock-smart".into(),
                    can_stream: true,
                    supports_tools: true,
                    supports_schema: true,
                    attachment_types: vec!["image/png".into()],
                },
            ]
        }

        async fn execute(
            &self,
            _model: &str,
            _prompt: &Prompt,
            _key: Option<&str>,
            _stream: bool,
        ) -> Result<ResponseStream> {
            let chunks = vec![
                Ok(Chunk::Text("Hello from mock".into())),
                Ok(Chunk::Done),
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }
    }

    // A mock provider that requires a key and echoes it back as `key=<key>`.
    struct KeyProvider;

    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
    impl Provider for KeyProvider {
        fn id(&self) -> &str {
            "key-provider"
        }

        fn models(&self) -> Vec<ModelInfo> {
            vec![ModelInfo::new("key-model")]
        }

        fn needs_key(&self) -> Option<&str> {
            Some("test_key")
        }

        fn key_env_var(&self) -> Option<&str> {
            Some("TEST_API_KEY")
        }

        async fn execute(
            &self,
            _model: &str,
            _prompt: &Prompt,
            key: Option<&str>,
            _stream: bool,
        ) -> Result<ResponseStream> {
            let key = key.ok_or_else(|| LlmError::NeedsKey("test_key required".into()))?;
            let chunks = vec![
                Ok(Chunk::Text(format!("key={key}"))),
                Ok(Chunk::Done),
            ];
            Ok(Box::pin(futures::stream::iter(chunks)))
        }
    }

    #[test]
    fn provider_id() {
        let p = MockProvider;
        assert_eq!(p.id(), "mock");
    }

    #[test]
    fn provider_lists_models() {
        let p = MockProvider;
        let models = p.models();
        assert_eq!(models.len(), 2);
        assert_eq!(models[0].id, "mock-fast");
        assert_eq!(models[1].id, "mock-smart");
        assert!(models[1].supports_tools);
    }

    #[test]
    fn provider_needs_key_defaults_to_none() {
        let p = MockProvider;
        assert_eq!(p.needs_key(), None);
        assert_eq!(p.key_env_var(), None);
    }

    #[test]
    fn provider_needs_key_returns_alias() {
        let p = KeyProvider;
        assert_eq!(p.needs_key(), Some("test_key"));
        assert_eq!(p.key_env_var(), Some("TEST_API_KEY"));
    }

    // The trait must stay usable as a trait object so heterogeneous providers
    // can be stored together.
    #[test]
    fn provider_usable_as_trait_object() {
        let providers: [Box<dyn Provider>; 2] = [Box::new(MockProvider), Box::new(KeyProvider)];
        let ids: Vec<&str> = providers.iter().map(|p| p.id()).collect();
        assert_eq!(ids, ["mock", "key-provider"]);
    }

    // On native targets `Provider` requires `Send + Sync`, so trait objects
    // must be shareable across threads.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn provider_trait_objects_are_send_and_sync() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn Provider>();
    }

    #[tokio::test]
    async fn provider_execute_returns_stream() {
        let p = MockProvider;
        let prompt = Prompt::new("Hello");
        let stream = p.execute("mock-fast", &prompt, None, true).await.unwrap();
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
        match &chunks[0] {
            Ok(Chunk::Text(t)) => assert_eq!(t, "Hello from mock"),
            _ => panic!("expected Text chunk"),
        }
        assert!(matches!(chunks[1], Ok(Chunk::Done)), "expected Done chunk last");
    }

    #[tokio::test]
    async fn provider_execute_through_trait_object() {
        let p: Box<dyn Provider> = Box::new(MockProvider);
        let prompt = Prompt::new("Hello");
        let stream = p.execute("mock-fast", &prompt, None, false).await.unwrap();
        let chunks: Vec<_> = stream.collect().await;
        assert_eq!(chunks.len(), 2);
    }

    #[tokio::test]
    async fn provider_execute_with_key() {
        let p = KeyProvider;
        let prompt = Prompt::new("Hello");
        let stream = p
            .execute("key-model", &prompt, Some("sk-test"), true)
            .await
            .unwrap();
        let chunks: Vec<_> = stream.collect().await;
        // Check the length first so an empty stream fails with a clear
        // assertion instead of an index-out-of-bounds panic.
        assert_eq!(chunks.len(), 2);
        match &chunks[0] {
            Ok(Chunk::Text(t)) => assert_eq!(t, "key=sk-test"),
            _ => panic!("expected Text chunk"),
        }
        assert!(matches!(chunks[1], Ok(Chunk::Done)), "expected Done chunk last");
    }

    #[tokio::test]
    async fn provider_execute_without_key_errors() {
        let p = KeyProvider;
        let prompt = Prompt::new("Hello");
        match p.execute("key-model", &prompt, None, true).await {
            Err(LlmError::NeedsKey(msg)) => assert!(msg.contains("test_key")),
            Err(_) => panic!("expected NeedsKey error, got a different error"),
            Ok(_) => panic!("expected NeedsKey error, got a stream"),
        }
    }
}