1use crate::api::{ChatApi, ChatCompletionChunk};
2use crate::provider::{GeminiAdapter, GenericAdapter, OpenAiAdapter, ProviderConfigs};
3use crate::types::{AiLibError, ChatCompletionRequest, ChatCompletionResponse};
4use futures::stream::Stream;
5use futures::Future;
6use tokio::sync::oneshot;
7
/// AI service providers that an `AiClient` can be bound to.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare the value returned by
/// `AiClient::current_provider` or use providers as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Provider {
    /// Groq.
    Groq,
    /// OpenAI.
    OpenAI,
    /// DeepSeek.
    DeepSeek,
    /// Google Gemini.
    Gemini,
    /// Anthropic.
    Anthropic,
}
19
20pub struct AiClient {
63 provider: Provider,
64 adapter: Box<dyn ChatApi>,
65}
66
67impl AiClient {
68 pub fn new(provider: Provider) -> Result<Self, AiLibError> {
84 let adapter: Box<dyn ChatApi> = match provider {
85 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
87 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
88 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
89 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
91 Provider::Gemini => Box::new(GeminiAdapter::new()?),
92 };
93
94 Ok(Self { provider, adapter })
95 }
96
97 pub async fn chat_completion(
105 &self,
106 request: ChatCompletionRequest,
107 ) -> Result<ChatCompletionResponse, AiLibError> {
108 self.adapter.chat_completion(request).await
109 }
110
111 pub async fn chat_completion_stream(
119 &self,
120 mut request: ChatCompletionRequest,
121 ) -> Result<
122 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
123 AiLibError,
124 > {
125 request.stream = Some(true);
126 self.adapter.chat_completion_stream(request).await
127 }
128
129 pub async fn chat_completion_stream_with_cancel(
137 &self,
138 request: ChatCompletionRequest,
139 ) -> Result<
140 (
141 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
142 CancelHandle,
143 ),
144 AiLibError,
145 > {
146 let (cancel_tx, cancel_rx) = oneshot::channel();
147 let stream = self.chat_completion_stream(request).await?;
148
149 let cancel_handle = CancelHandle {
150 sender: Some(cancel_tx),
151 };
152 let controlled_stream = ControlledStream::new(stream, cancel_rx);
153
154 Ok((Box::new(Box::pin(controlled_stream)), cancel_handle))
155 }
156
157 pub async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
162 self.adapter.list_models().await
163 }
164
165 pub fn switch_provider(&mut self, provider: Provider) -> Result<(), AiLibError> {
183 let new_adapter: Box<dyn ChatApi> = match provider {
184 Provider::Groq => Box::new(GenericAdapter::new(ProviderConfigs::groq())?),
185 Provider::DeepSeek => Box::new(GenericAdapter::new(ProviderConfigs::deepseek())?),
186 Provider::OpenAI => Box::new(OpenAiAdapter::new()?),
187 Provider::Anthropic => Box::new(GenericAdapter::new(ProviderConfigs::anthropic())?),
188 Provider::Gemini => Box::new(GeminiAdapter::new()?),
189 };
190
191 self.provider = provider;
192 self.adapter = new_adapter;
193 Ok(())
194 }
195
196 pub fn current_provider(&self) -> Provider {
198 self.provider
199 }
200}
201
202pub struct CancelHandle {
204 sender: Option<oneshot::Sender<()>>,
205}
206
207impl CancelHandle {
208 pub fn cancel(mut self) {
210 if let Some(sender) = self.sender.take() {
211 let _ = sender.send(());
212 }
213 }
214}
215
216struct ControlledStream {
218 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
219 cancel_rx: Option<oneshot::Receiver<()>>,
220}
221
222impl ControlledStream {
223 fn new(
224 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
225 cancel_rx: oneshot::Receiver<()>,
226 ) -> Self {
227 Self {
228 inner,
229 cancel_rx: Some(cancel_rx),
230 }
231 }
232}
233
234impl Stream for ControlledStream {
235 type Item = Result<ChatCompletionChunk, AiLibError>;
236
237 fn poll_next(
238 mut self: std::pin::Pin<&mut Self>,
239 cx: &mut std::task::Context<'_>,
240 ) -> std::task::Poll<Option<Self::Item>> {
241 use futures::stream::StreamExt;
242 use std::task::Poll;
243
244 if let Some(ref mut cancel_rx) = self.cancel_rx {
246 match Future::poll(std::pin::Pin::new(cancel_rx), cx) {
247 Poll::Ready(_) => {
248 self.cancel_rx = None;
249 return Poll::Ready(Some(Err(AiLibError::ProviderError(
250 "Stream cancelled".to_string(),
251 ))));
252 }
253 Poll::Pending => {}
254 }
255 }
256
257 self.inner.poll_next_unpin(cx)
259 }
260}