1use crate::{
7 RsllmError, RsllmResult, Provider, ClientConfig,
8 ChatMessage, ChatResponse, EmbeddingResponse,
9 ChatStream,
10};
11
12#[cfg(feature = "openai")]
13use crate::provider::OpenAIProvider;
14
15#[cfg(feature = "ollama")]
16use crate::provider::OllamaProvider;
17
18use crate::provider::LLMProvider;
19use async_trait::async_trait;
20use std::sync::Arc;
21use std::collections::HashMap;
22
23pub struct Client {
25 config: ClientConfig,
27
28 provider: Arc<dyn LLMProvider>,
30
31 metadata: HashMap<String, serde_json::Value>,
33}
34
35impl Client {
36 pub fn new(config: ClientConfig) -> RsllmResult<Self> {
38 config.validate()?;
39
40 let provider = Self::create_provider(&config)?;
41
42 Ok(Self {
43 config,
44 provider,
45 metadata: HashMap::new(),
46 })
47 }
48
49 pub fn builder() -> ClientBuilder {
51 ClientBuilder::new()
52 }
53
54 pub fn from_env() -> RsllmResult<Self> {
56 let config = ClientConfig::from_env()?;
57 Self::new(config)
58 }
59
60 fn create_provider(config: &ClientConfig) -> RsllmResult<Arc<dyn LLMProvider>> {
62 match config.provider.provider {
63 #[cfg(feature = "openai")]
64 Provider::OpenAI => {
65 let api_key = config.provider.api_key.as_ref()
66 .ok_or_else(|| RsllmError::configuration("OpenAI API key required"))?;
67
68 let provider = OpenAIProvider::new(
69 api_key.clone(),
70 config.provider.base_url.clone(),
71 config.provider.organization_id.clone(),
72 )?;
73
74 Ok(Arc::new(provider))
75 }
76
77 #[cfg(feature = "ollama")]
78 Provider::Ollama => {
79 let provider = OllamaProvider::new(config.provider.base_url.clone())?;
80 Ok(Arc::new(provider))
81 }
82
83 #[cfg(feature = "claude")]
84 Provider::Claude => {
85 Err(RsllmError::configuration("Claude provider not yet implemented"))
87 }
88
89 #[allow(unreachable_patterns)]
90 _ => Err(RsllmError::configuration(
91 format!("Provider {:?} not supported in current build", config.provider.provider)
92 )),
93 }
94 }
95
96 pub fn config(&self) -> &ClientConfig {
98 &self.config
99 }
100
101 pub fn provider(&self) -> &Arc<dyn LLMProvider> {
103 &self.provider
104 }
105
106 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
108 self.metadata.insert(key.into(), value);
109 }
110
111 pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
113 &self.metadata
114 }
115
116 pub async fn health_check(&self) -> RsllmResult<bool> {
118 self.provider.health_check().await
119 }
120
121 pub fn supported_models(&self) -> Vec<String> {
123 self.provider.supported_models()
124 }
125
126 pub async fn chat_completion(
128 &self,
129 messages: Vec<ChatMessage>,
130 ) -> RsllmResult<ChatResponse> {
131 self.chat_completion_with_options(
132 messages,
133 None,
134 None,
135 None,
136 ).await
137 }
138
139 pub async fn chat_completion_with_options(
141 &self,
142 messages: Vec<ChatMessage>,
143 model: Option<&str>,
144 temperature: Option<f32>,
145 max_tokens: Option<u32>,
146 ) -> RsllmResult<ChatResponse> {
147 if messages.is_empty() {
149 return Err(RsllmError::validation("messages", "Messages cannot be empty"));
150 }
151
152 let model = model.unwrap_or(&self.config.model.model);
154
155 let temperature = temperature.or(self.config.model.temperature);
157
158 let max_tokens = max_tokens.or(self.config.model.max_tokens);
160
161 self.provider.chat_completion(
162 messages,
163 Some(model),
164 temperature,
165 max_tokens,
166 ).await
167 }
168
169 pub async fn chat_completion_stream(
171 &self,
172 messages: Vec<ChatMessage>,
173 ) -> RsllmResult<ChatStream> {
174 self.chat_completion_stream_with_options(
175 messages,
176 None,
177 None,
178 None,
179 ).await
180 }
181
182 pub async fn chat_completion_stream_with_options(
184 &self,
185 messages: Vec<ChatMessage>,
186 model: Option<&str>,
187 temperature: Option<f32>,
188 max_tokens: Option<u32>,
189 ) -> RsllmResult<ChatStream> {
190 if messages.is_empty() {
192 return Err(RsllmError::validation("messages", "Messages cannot be empty"));
193 }
194
195 let model = model.unwrap_or(&self.config.model.model);
197
198 let temperature = temperature.or(self.config.model.temperature);
200
201 let max_tokens = max_tokens.or(self.config.model.max_tokens);
203
204 let stream = self.provider.chat_completion_stream(
205 messages,
206 Some(model.to_string()),
207 temperature,
208 max_tokens,
209 ).await?;
210
211 Ok(Box::pin(stream) as ChatStream)
213 }
214
215 pub async fn complete(&self, prompt: impl Into<String>) -> RsllmResult<String> {
217 let messages = vec![ChatMessage::user(prompt.into())];
218 let response = self.chat_completion(messages).await?;
219 Ok(response.content)
220 }
221
222 pub async fn complete_stream(&self, prompt: impl Into<String>) -> RsllmResult<ChatStream> {
224 let messages = vec![ChatMessage::user(prompt.into())];
225 self.chat_completion_stream(messages).await
226 }
227
228 pub async fn create_embeddings(
230 &self,
231 _inputs: Vec<String>,
232 ) -> RsllmResult<EmbeddingResponse> {
233 Err(RsllmError::configuration("Embeddings not yet implemented"))
235 }
236
237 pub fn count_tokens(&self, _text: &str) -> RsllmResult<u32> {
239 Err(RsllmError::configuration("Token counting not yet implemented"))
241 }
242}
243
244impl std::fmt::Debug for Client {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("Client")
247 .field("provider_type", &self.provider.provider_type())
248 .field("model", &self.config.model.model)
249 .finish()
250 }
251}
252
253pub struct ClientBuilder {
255 config: ClientConfig,
256}
257
258impl ClientBuilder {
259 pub fn new() -> Self {
261 Self {
262 config: ClientConfig::default(),
263 }
264 }
265
266 pub fn provider(mut self, provider: Provider) -> Self {
268 self.config.provider.provider = provider;
269 self
270 }
271
272 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
274 self.config.provider.api_key = Some(api_key.into());
275 self
276 }
277
278 pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
280 self.config.provider.base_url = Some(base_url.as_ref().parse()?);
281 Ok(self)
282 }
283
284 pub fn organization_id(mut self, org_id: impl Into<String>) -> Self {
286 self.config.provider.organization_id = Some(org_id.into());
287 self
288 }
289
290 pub fn model(mut self, model: impl Into<String>) -> Self {
292 self.config.model.model = model.into();
293 self
294 }
295
296 pub fn temperature(mut self, temperature: f32) -> Self {
298 self.config.model.temperature = Some(temperature);
299 self
300 }
301
302 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
304 self.config.model.max_tokens = Some(max_tokens);
305 self
306 }
307
308 pub fn stream(mut self, stream: bool) -> Self {
310 self.config.model.stream = stream;
311 self
312 }
313
314 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
316 self.config.http.timeout = timeout;
317 self
318 }
319
320 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
322 self.config.headers.insert(key.into(), value.into());
323 self
324 }
325
326 pub fn max_retries(mut self, max_retries: u32) -> Self {
328 self.config.retry.max_retries = max_retries;
329 self
330 }
331
332 pub fn build(self) -> RsllmResult<Client> {
334 Client::new(self.config)
335 }
336}
337
338impl Default for ClientBuilder {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344#[async_trait]
346pub trait AsyncClient: Send + Sync {
347 async fn chat_completion(
349 &self,
350 messages: Vec<ChatMessage>,
351 ) -> RsllmResult<ChatResponse>;
352
353 async fn chat_completion_stream(
355 &self,
356 messages: Vec<ChatMessage>,
357 ) -> RsllmResult<ChatStream>;
358
359 async fn health_check(&self) -> RsllmResult<bool>;
361}
362
363#[async_trait]
364impl AsyncClient for Client {
365 async fn chat_completion(
366 &self,
367 messages: Vec<ChatMessage>,
368 ) -> RsllmResult<ChatResponse> {
369 self.chat_completion(messages).await
370 }
371
372 async fn chat_completion_stream(
373 &self,
374 messages: Vec<ChatMessage>,
375 ) -> RsllmResult<ChatStream> {
376 self.chat_completion_stream(messages).await
377 }
378
379 async fn health_check(&self) -> RsllmResult<bool> {
380 self.health_check().await
381 }
382}
383
384pub struct ClientPool {
386 clients: HashMap<String, Arc<Client>>,
387 default_client: Option<String>,
388}
389
390impl ClientPool {
391 pub fn new() -> Self {
393 Self {
394 clients: HashMap::new(),
395 default_client: None,
396 }
397 }
398
399 pub fn add_client(&mut self, name: impl Into<String>, client: Client) {
401 let name = name.into();
402 let is_first = self.clients.is_empty();
403
404 self.clients.insert(name.clone(), Arc::new(client));
405
406 if is_first {
407 self.default_client = Some(name);
408 }
409 }
410
411 pub fn get_client(&self, name: &str) -> Option<&Arc<Client>> {
413 self.clients.get(name)
414 }
415
416 pub fn default_client(&self) -> Option<&Arc<Client>> {
418 self.default_client.as_ref().and_then(|name| self.get_client(name))
419 }
420
421 pub fn set_default(&mut self, name: impl Into<String>) -> RsllmResult<()> {
423 let name = name.into();
424 if self.clients.contains_key(&name) {
425 self.default_client = Some(name);
426 Ok(())
427 } else {
428 Err(RsllmError::not_found(format!("Client '{}'", name)))
429 }
430 }
431
432 pub fn client_names(&self) -> Vec<&String> {
434 self.clients.keys().collect()
435 }
436
437 pub fn remove_client(&mut self, name: &str) -> Option<Arc<Client>> {
439 let removed = self.clients.remove(name);
440
441 if self.default_client.as_deref() == Some(name) {
443 self.default_client = self.clients.keys().next().cloned();
444 }
445
446 removed
447 }
448}
449
450impl Default for ClientPool {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;

    #[test]
    fn test_client_builder() {
        let config = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .model("gpt-4")
            .temperature(0.7)
            .max_tokens(1000)
            .timeout(std::time::Duration::from_secs(30))
            .header("Custom-Header", "value")
            .config
            .clone();

        assert_eq!(config.provider.provider, Provider::OpenAI);
        assert_eq!(config.model.model, "gpt-4");
        assert_eq!(config.model.temperature, Some(0.7));
        assert_eq!(config.model.max_tokens, Some(1000));
        assert_eq!(config.http.timeout, std::time::Duration::from_secs(30));
        assert!(config.headers.contains_key("Custom-Header"));
    }

    #[test]
    fn test_client_pool() {
        let mut pool = ClientPool::new();

        assert!(pool.client_names().is_empty());
        assert!(pool.default_client().is_none());
        assert!(pool.get_client("missing").is_none());

        // An unknown name can never become the default, and removing it is a no-op.
        assert!(pool.set_default("missing").is_err());
        assert!(pool.remove_client("missing").is_none());
        assert!(pool.default_client().is_none());
    }

    #[test]
    fn test_build_with_api_key() {
        let result = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .api_key("test-key")
            .build();

        // Whether construction succeeds depends on the enabled provider features and
        // on config validation. When it does succeed, the builder's settings must be
        // carried into the client unchanged.
        if let Ok(client) = result {
            assert_eq!(client.config().provider.provider, Provider::OpenAI);
            assert_eq!(client.config().provider.api_key.as_deref(), Some("test-key"));
        }
    }
}