1use crate::{
7 ChatMessage, ChatResponse, ChatStream, ClientConfig, EmbeddingResponse, Provider, RsllmError,
8 RsllmResult,
9};
10
11#[cfg(feature = "openai")]
12use crate::provider::OpenAIProvider;
13
14#[cfg(feature = "ollama")]
15use crate::provider::OllamaProvider;
16
17use crate::provider::LLMProvider;
18use async_trait::async_trait;
19use std::collections::HashMap;
20use std::sync::Arc;
21
/// High-level entry point for issuing chat, streaming, and (eventually)
/// embedding requests against a single configured LLM provider.
pub struct Client {
    /// Validated configuration this client was constructed from.
    config: ClientConfig,
    /// Provider implementation selected by the configuration
    /// (feature-gated; see `Client::create_provider`).
    provider: Arc<dyn LLMProvider>,
    /// Arbitrary caller-attached key/value annotations. Not forwarded to the
    /// provider by any code in this file.
    metadata: HashMap<String, serde_json::Value>,
}
33
34impl Client {
35 pub fn new(config: ClientConfig) -> RsllmResult<Self> {
37 config.validate()?;
38
39 let provider = Self::create_provider(&config)?;
40
41 Ok(Self {
42 config,
43 provider,
44 metadata: HashMap::new(),
45 })
46 }
47
48 pub fn builder() -> ClientBuilder {
50 ClientBuilder::new()
51 }
52
53 pub fn from_env() -> RsllmResult<Self> {
55 let config = ClientConfig::from_env()?;
56 Self::new(config)
57 }
58
59 fn create_provider(config: &ClientConfig) -> RsllmResult<Arc<dyn LLMProvider>> {
61 match config.provider.provider {
62 #[cfg(feature = "openai")]
63 Provider::OpenAI => {
64 let api_key = config
65 .provider
66 .api_key
67 .as_ref()
68 .ok_or_else(|| RsllmError::configuration("OpenAI API key required"))?;
69
70 let provider = OpenAIProvider::new(
71 api_key.clone(),
72 config.provider.base_url.clone(),
73 config.provider.organization_id.clone(),
74 )?;
75
76 Ok(Arc::new(provider))
77 }
78
79 #[cfg(feature = "ollama")]
80 Provider::Ollama => {
81 let provider = OllamaProvider::new(config.provider.base_url.clone())?;
82 Ok(Arc::new(provider))
83 }
84
85 #[cfg(feature = "claude")]
86 Provider::Claude => {
87 Err(RsllmError::configuration(
89 "Claude provider not yet implemented",
90 ))
91 }
92
93 #[allow(unreachable_patterns)]
94 _ => Err(RsllmError::configuration(format!(
95 "Provider {:?} not supported in current build",
96 config.provider.provider
97 ))),
98 }
99 }
100
101 pub fn config(&self) -> &ClientConfig {
103 &self.config
104 }
105
106 pub fn provider(&self) -> &Arc<dyn LLMProvider> {
108 &self.provider
109 }
110
111 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
113 self.metadata.insert(key.into(), value);
114 }
115
116 pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
118 &self.metadata
119 }
120
121 pub async fn health_check(&self) -> RsllmResult<bool> {
123 self.provider.health_check().await
124 }
125
126 pub fn supported_models(&self) -> Vec<String> {
128 self.provider.supported_models()
129 }
130
131 pub async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse> {
133 self.chat_completion_with_options(messages, None, None, None)
134 .await
135 }
136
137 pub async fn chat_completion_with_options(
139 &self,
140 messages: Vec<ChatMessage>,
141 model: Option<&str>,
142 temperature: Option<f32>,
143 max_tokens: Option<u32>,
144 ) -> RsllmResult<ChatResponse> {
145 if messages.is_empty() {
147 return Err(RsllmError::validation(
148 "messages",
149 "Messages cannot be empty",
150 ));
151 }
152
153 let model = model.unwrap_or(&self.config.model.model);
155
156 let temperature = temperature.or(self.config.model.temperature);
158
159 let max_tokens = max_tokens.or(self.config.model.max_tokens);
161
162 self.provider
163 .chat_completion(messages, Some(model), temperature, max_tokens)
164 .await
165 }
166
167 pub async fn chat_completion_with_tools(
169 &self,
170 messages: Vec<ChatMessage>,
171 tools: Vec<crate::tools::ToolDefinition>,
172 ) -> RsllmResult<ChatResponse> {
173 self.chat_completion_with_tools_and_options(messages, tools, None, None, None)
174 .await
175 }
176
177 pub async fn chat_completion_with_tools_and_options(
179 &self,
180 messages: Vec<ChatMessage>,
181 tools: Vec<crate::tools::ToolDefinition>,
182 model: Option<&str>,
183 temperature: Option<f32>,
184 max_tokens: Option<u32>,
185 ) -> RsllmResult<ChatResponse> {
186 if messages.is_empty() {
188 return Err(RsllmError::validation(
189 "messages",
190 "Messages cannot be empty",
191 ));
192 }
193
194 let model = model.unwrap_or(&self.config.model.model);
196
197 let temperature = temperature.or(self.config.model.temperature);
199
200 let max_tokens = max_tokens.or(self.config.model.max_tokens);
202
203 self.provider
204 .chat_completion_with_tools(messages, tools, Some(model), temperature, max_tokens)
205 .await
206 }
207
208 pub async fn chat_completion_stream(
210 &self,
211 messages: Vec<ChatMessage>,
212 ) -> RsllmResult<ChatStream> {
213 self.chat_completion_stream_with_options(messages, None, None, None)
214 .await
215 }
216
217 pub async fn chat_completion_stream_with_options(
219 &self,
220 messages: Vec<ChatMessage>,
221 model: Option<&str>,
222 temperature: Option<f32>,
223 max_tokens: Option<u32>,
224 ) -> RsllmResult<ChatStream> {
225 if messages.is_empty() {
227 return Err(RsllmError::validation(
228 "messages",
229 "Messages cannot be empty",
230 ));
231 }
232
233 let model = model.unwrap_or(&self.config.model.model);
235
236 let temperature = temperature.or(self.config.model.temperature);
238
239 let max_tokens = max_tokens.or(self.config.model.max_tokens);
241
242 let stream = self
243 .provider
244 .chat_completion_stream(messages, Some(model.to_string()), temperature, max_tokens)
245 .await?;
246
247 Ok(Box::pin(stream) as ChatStream)
249 }
250
251 pub async fn complete(&self, prompt: impl Into<String>) -> RsllmResult<String> {
253 let messages = vec![ChatMessage::user(prompt.into())];
254 let response = self.chat_completion(messages).await?;
255 Ok(response.content)
256 }
257
258 pub async fn complete_stream(&self, prompt: impl Into<String>) -> RsllmResult<ChatStream> {
260 let messages = vec![ChatMessage::user(prompt.into())];
261 self.chat_completion_stream(messages).await
262 }
263
264 pub async fn create_embeddings(&self, _inputs: Vec<String>) -> RsllmResult<EmbeddingResponse> {
266 Err(RsllmError::configuration("Embeddings not yet implemented"))
268 }
269
270 pub fn count_tokens(&self, _text: &str) -> RsllmResult<u32> {
272 Err(RsllmError::configuration(
274 "Token counting not yet implemented",
275 ))
276 }
277}
278
279impl std::fmt::Debug for Client {
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.debug_struct("Client")
282 .field("provider_type", &self.provider.provider_type())
283 .field("model", &self.config.model.model)
284 .finish()
285 }
286}
287
/// Fluent builder for [`Client`]; wraps a `ClientConfig` that is mutated by
/// each setter and validated on `build()`.
pub struct ClientBuilder {
    /// Accumulated configuration, handed to `Client::new` on build.
    config: ClientConfig,
}
292
293impl ClientBuilder {
294 pub fn new() -> Self {
296 Self {
297 config: ClientConfig::default(),
298 }
299 }
300
301 pub fn provider(mut self, provider: Provider) -> Self {
303 self.config.provider.provider = provider;
304 self
305 }
306
307 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
309 self.config.provider.api_key = Some(api_key.into());
310 self
311 }
312
313 pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
315 self.config.provider.base_url = Some(base_url.as_ref().parse()?);
316 Ok(self)
317 }
318
319 pub fn organization_id(mut self, org_id: impl Into<String>) -> Self {
321 self.config.provider.organization_id = Some(org_id.into());
322 self
323 }
324
325 pub fn model(mut self, model: impl Into<String>) -> Self {
327 self.config.model.model = model.into();
328 self
329 }
330
331 pub fn temperature(mut self, temperature: f32) -> Self {
333 self.config.model.temperature = Some(temperature);
334 self
335 }
336
337 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
339 self.config.model.max_tokens = Some(max_tokens);
340 self
341 }
342
343 pub fn stream(mut self, stream: bool) -> Self {
345 self.config.model.stream = stream;
346 self
347 }
348
349 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
351 self.config.http.timeout = timeout;
352 self
353 }
354
355 pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
357 self.config.headers.insert(key.into(), value.into());
358 self
359 }
360
361 pub fn max_retries(mut self, max_retries: u32) -> Self {
363 self.config.retry.max_retries = max_retries;
364 self
365 }
366
367 pub fn build(self) -> RsllmResult<Client> {
369 Client::new(self.config)
370 }
371}
372
373impl Default for ClientBuilder {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
/// Object-safe async interface covering the core chat operations of
/// [`Client`], so callers can hold a `dyn AsyncClient`.
#[async_trait]
pub trait AsyncClient: Send + Sync {
    /// Runs a chat completion with the implementation's default options.
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse>;

    /// Starts a streaming chat completion with default options.
    async fn chat_completion_stream(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatStream>;

    /// Probes whether the backing provider is healthy/reachable.
    async fn health_check(&self) -> RsllmResult<bool>;
}
391
// Blanket delegation of the trait to Client's inherent methods. Inherent
// methods take precedence over trait methods in Rust method resolution, so
// the identically-named calls below hit the inherent implementations and do
// NOT recurse into this trait impl.
#[async_trait]
impl AsyncClient for Client {
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse> {
        self.chat_completion(messages).await
    }

    async fn chat_completion_stream(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatStream> {
        self.chat_completion_stream(messages).await
    }

    async fn health_check(&self) -> RsllmResult<bool> {
        self.health_check().await
    }
}
406
/// A named collection of shared [`Client`] instances with an optional
/// default. The first client added becomes the default automatically.
pub struct ClientPool {
    /// Clients keyed by their registration name.
    clients: HashMap<String, Arc<Client>>,
    /// Name of the current default client, if any; always a key of `clients`
    /// (maintained by `add_client`/`set_default`/`remove_client`).
    default_client: Option<String>,
}
412
413impl ClientPool {
414 pub fn new() -> Self {
416 Self {
417 clients: HashMap::new(),
418 default_client: None,
419 }
420 }
421
422 pub fn add_client(&mut self, name: impl Into<String>, client: Client) {
424 let name = name.into();
425 let is_first = self.clients.is_empty();
426
427 self.clients.insert(name.clone(), Arc::new(client));
428
429 if is_first {
430 self.default_client = Some(name);
431 }
432 }
433
434 pub fn get_client(&self, name: &str) -> Option<&Arc<Client>> {
436 self.clients.get(name)
437 }
438
439 pub fn default_client(&self) -> Option<&Arc<Client>> {
441 self.default_client
442 .as_ref()
443 .and_then(|name| self.get_client(name))
444 }
445
446 pub fn set_default(&mut self, name: impl Into<String>) -> RsllmResult<()> {
448 let name = name.into();
449 if self.clients.contains_key(&name) {
450 self.default_client = Some(name);
451 Ok(())
452 } else {
453 Err(RsllmError::not_found(format!("Client '{}'", name)))
454 }
455 }
456
457 pub fn client_names(&self) -> Vec<&String> {
459 self.clients.keys().collect()
460 }
461
462 pub fn remove_client(&mut self, name: &str) -> Option<Arc<Client>> {
464 let removed = self.clients.remove(name);
465
466 if self.default_client.as_deref() == Some(name) {
468 self.default_client = self.clients.keys().next().cloned();
469 }
470
471 removed
472 }
473}
474
475impl Default for ClientPool {
476 fn default() -> Self {
477 Self::new()
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{MessageRole, Provider};

    /// Builder setters must land in the corresponding config fields.
    #[test]
    fn test_client_builder() {
        let config = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .model("gpt-4")
            .temperature(0.7)
            .max_tokens(1000)
            .timeout(std::time::Duration::from_secs(30))
            .header("Custom-Header", "value")
            .config
            .clone();

        assert_eq!(config.provider.provider, Provider::OpenAI);
        assert_eq!(config.model.model, "gpt-4");
        assert_eq!(config.model.temperature, Some(0.7));
        assert_eq!(config.model.max_tokens, Some(1000));
        assert_eq!(config.http.timeout, std::time::Duration::from_secs(30));
        assert!(config.headers.contains_key("Custom-Header"));
    }

    /// A fresh pool is empty, has no default, and rejects setting a default
    /// for an unregistered name.
    #[test]
    fn test_client_pool() {
        let mut pool = ClientPool::new();

        assert_eq!(pool.client_names().len(), 0);
        assert!(pool.default_client().is_none());
        assert!(pool.set_default("missing").is_err());
    }

    /// Building with an API key either succeeds (provider feature enabled)
    /// or fails (feature compiled out); when it succeeds, the provider
    /// selection must round-trip through the built client.
    ///
    /// NOTE: the previous assertion `is_err() || is_ok()` was a tautology
    /// that could never fail; this version actually checks something.
    #[test]
    fn test_message_validation() {
        let result = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .api_key("test-key")
            .build();

        if let Ok(client) = result {
            assert_eq!(client.config().provider.provider, Provider::OpenAI);
        }
    }
}