Skip to main content

rsllm/
client.rs

1//! # RSLLM Client
2//! 
3//! High-level client interface for RSLLM with multi-provider support.
4//! Provides unified API for chat completions, embeddings, and streaming.
5
6use crate::{
7    RsllmError, RsllmResult, Provider, ClientConfig,
8    ChatMessage, ChatResponse, EmbeddingResponse,
9    ChatStream,
10};
11
12#[cfg(feature = "openai")]
13use crate::provider::OpenAIProvider;
14
15#[cfg(feature = "ollama")]
16use crate::provider::OllamaProvider;
17
18use crate::provider::LLMProvider;
19use async_trait::async_trait;
20use std::sync::Arc;
21use std::collections::HashMap;
22
23/// High-level RSLLM client
24pub struct Client {
25    /// Client configuration
26    config: ClientConfig,
27    
28    /// Provider instance
29    provider: Arc<dyn LLMProvider>,
30    
31    /// Client metadata
32    metadata: HashMap<String, serde_json::Value>,
33}
34
35impl Client {
36    /// Create a new client with configuration
37    pub fn new(config: ClientConfig) -> RsllmResult<Self> {
38        config.validate()?;
39        
40        let provider = Self::create_provider(&config)?;
41        
42        Ok(Self {
43            config,
44            provider,
45            metadata: HashMap::new(),
46        })
47    }
48    
49    /// Create a client builder
50    pub fn builder() -> ClientBuilder {
51        ClientBuilder::new()
52    }
53    
54    /// Create a client from environment variables
55    pub fn from_env() -> RsllmResult<Self> {
56        let config = ClientConfig::from_env()?;
57        Self::new(config)
58    }
59    
60    /// Create provider instance based on configuration
61    fn create_provider(config: &ClientConfig) -> RsllmResult<Arc<dyn LLMProvider>> {
62        match config.provider.provider {
63            #[cfg(feature = "openai")]
64            Provider::OpenAI => {
65                let api_key = config.provider.api_key.as_ref()
66                    .ok_or_else(|| RsllmError::configuration("OpenAI API key required"))?;
67                
68                let provider = OpenAIProvider::new(
69                    api_key.clone(),
70                    config.provider.base_url.clone(),
71                    config.provider.organization_id.clone(),
72                )?;
73                
74                Ok(Arc::new(provider))
75            }
76            
77            #[cfg(feature = "ollama")]
78            Provider::Ollama => {
79                let provider = OllamaProvider::new(config.provider.base_url.clone())?;
80                Ok(Arc::new(provider))
81            }
82            
83            #[cfg(feature = "claude")]
84            Provider::Claude => {
85                // Claude provider implementation would go here
86                Err(RsllmError::configuration("Claude provider not yet implemented"))
87            }
88            
89            #[allow(unreachable_patterns)]
90            _ => Err(RsllmError::configuration(
91                format!("Provider {:?} not supported in current build", config.provider.provider)
92            )),
93        }
94    }
95    
96    /// Get client configuration
97    pub fn config(&self) -> &ClientConfig {
98        &self.config
99    }
100    
101    /// Get provider instance
102    pub fn provider(&self) -> &Arc<dyn LLMProvider> {
103        &self.provider
104    }
105    
106    /// Add client metadata
107    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
108        self.metadata.insert(key.into(), value);
109    }
110    
111    /// Get client metadata
112    pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
113        &self.metadata
114    }
115    
116    /// Health check for the underlying provider
117    pub async fn health_check(&self) -> RsllmResult<bool> {
118        self.provider.health_check().await
119    }
120    
121    /// Get supported models from the provider
122    pub fn supported_models(&self) -> Vec<String> {
123        self.provider.supported_models()
124    }
125    
126    /// Chat completion (non-streaming)
127    pub async fn chat_completion(
128        &self,
129        messages: Vec<ChatMessage>,
130    ) -> RsllmResult<ChatResponse> {
131        self.chat_completion_with_options(
132            messages,
133            None,
134            None,
135            None,
136        ).await
137    }
138    
139    /// Chat completion with custom options
140    pub async fn chat_completion_with_options(
141        &self,
142        messages: Vec<ChatMessage>,
143        model: Option<&str>,
144        temperature: Option<f32>,
145        max_tokens: Option<u32>,
146    ) -> RsllmResult<ChatResponse> {
147        // Validate messages
148        if messages.is_empty() {
149            return Err(RsllmError::validation("messages", "Messages cannot be empty"));
150        }
151        
152        // Use configured model if not specified
153        let model = model.unwrap_or(&self.config.model.model);
154        
155        // Use configured temperature if not specified
156        let temperature = temperature.or(self.config.model.temperature);
157        
158        // Use configured max_tokens if not specified
159        let max_tokens = max_tokens.or(self.config.model.max_tokens);
160        
161        self.provider.chat_completion(
162            messages,
163            Some(model),
164            temperature,
165            max_tokens,
166        ).await
167    }
168    
169    /// Chat completion (streaming)
170    pub async fn chat_completion_stream(
171        &self,
172        messages: Vec<ChatMessage>,
173    ) -> RsllmResult<ChatStream> {
174        self.chat_completion_stream_with_options(
175            messages,
176            None,
177            None,
178            None,
179        ).await
180    }
181    
182    /// Chat completion streaming with custom options
183    pub async fn chat_completion_stream_with_options(
184        &self,
185        messages: Vec<ChatMessage>,
186        model: Option<&str>,
187        temperature: Option<f32>,
188        max_tokens: Option<u32>,
189    ) -> RsllmResult<ChatStream> {
190        // Validate messages
191        if messages.is_empty() {
192            return Err(RsllmError::validation("messages", "Messages cannot be empty"));
193        }
194        
195        // Use configured model if not specified
196        let model = model.unwrap_or(&self.config.model.model);
197        
198        // Use configured temperature if not specified
199        let temperature = temperature.or(self.config.model.temperature);
200        
201        // Use configured max_tokens if not specified
202        let max_tokens = max_tokens.or(self.config.model.max_tokens);
203        
204        let stream = self.provider.chat_completion_stream(
205            messages,
206            Some(model.to_string()),
207            temperature,
208            max_tokens,
209        ).await?;
210        
211        // Convert Box<dyn Stream + Unpin> to Pin<Box<dyn Stream>>
212        Ok(Box::pin(stream) as ChatStream)
213    }
214    
215    /// Simple text completion
216    pub async fn complete(&self, prompt: impl Into<String>) -> RsllmResult<String> {
217        let messages = vec![ChatMessage::user(prompt.into())];
218        let response = self.chat_completion(messages).await?;
219        Ok(response.content)
220    }
221    
222    /// Simple streaming text completion
223    pub async fn complete_stream(&self, prompt: impl Into<String>) -> RsllmResult<ChatStream> {
224        let messages = vec![ChatMessage::user(prompt.into())];
225        self.chat_completion_stream(messages).await
226    }
227    
228    /// Create embeddings (placeholder - would need provider support)
229    pub async fn create_embeddings(
230        &self,
231        _inputs: Vec<String>,
232    ) -> RsllmResult<EmbeddingResponse> {
233        // TODO: Implement embeddings support in providers
234        Err(RsllmError::configuration("Embeddings not yet implemented"))
235    }
236    
237    /// Count tokens in text (placeholder - would need tokenizer)
238    pub fn count_tokens(&self, _text: &str) -> RsllmResult<u32> {
239        // TODO: Implement tokenization
240        Err(RsllmError::configuration("Token counting not yet implemented"))
241    }
242}
243
244impl std::fmt::Debug for Client {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("Client")
247            .field("provider_type", &self.provider.provider_type())
248            .field("model", &self.config.model.model)
249            .finish()
250    }
251}
252
253/// Client builder for fluent configuration
254pub struct ClientBuilder {
255    config: ClientConfig,
256}
257
258impl ClientBuilder {
259    /// Create a new client builder
260    pub fn new() -> Self {
261        Self {
262            config: ClientConfig::default(),
263        }
264    }
265    
266    /// Set the provider
267    pub fn provider(mut self, provider: Provider) -> Self {
268        self.config.provider.provider = provider;
269        self
270    }
271    
272    /// Set the API key
273    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
274        self.config.provider.api_key = Some(api_key.into());
275        self
276    }
277    
278    /// Set the base URL
279    pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
280        self.config.provider.base_url = Some(base_url.as_ref().parse()?);
281        Ok(self)
282    }
283    
284    /// Set the organization ID
285    pub fn organization_id(mut self, org_id: impl Into<String>) -> Self {
286        self.config.provider.organization_id = Some(org_id.into());
287        self
288    }
289    
290    /// Set the model
291    pub fn model(mut self, model: impl Into<String>) -> Self {
292        self.config.model.model = model.into();
293        self
294    }
295    
296    /// Set the temperature
297    pub fn temperature(mut self, temperature: f32) -> Self {
298        self.config.model.temperature = Some(temperature);
299        self
300    }
301    
302    /// Set max tokens
303    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
304        self.config.model.max_tokens = Some(max_tokens);
305        self
306    }
307    
308    /// Enable streaming
309    pub fn stream(mut self, stream: bool) -> Self {
310        self.config.model.stream = stream;
311        self
312    }
313    
314    /// Set timeout
315    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
316        self.config.http.timeout = timeout;
317        self
318    }
319    
320    /// Add a custom header
321    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
322        self.config.headers.insert(key.into(), value.into());
323        self
324    }
325    
326    /// Set retry configuration
327    pub fn max_retries(mut self, max_retries: u32) -> Self {
328        self.config.retry.max_retries = max_retries;
329        self
330    }
331    
332    /// Build the client
333    pub fn build(self) -> RsllmResult<Client> {
334        Client::new(self.config)
335    }
336}
337
338impl Default for ClientBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
344/// Async client trait for custom implementations
345#[async_trait]
346pub trait AsyncClient: Send + Sync {
347    /// Chat completion
348    async fn chat_completion(
349        &self,
350        messages: Vec<ChatMessage>,
351    ) -> RsllmResult<ChatResponse>;
352    
353    /// Chat completion streaming
354    async fn chat_completion_stream(
355        &self,
356        messages: Vec<ChatMessage>,
357    ) -> RsllmResult<ChatStream>;
358    
359    /// Health check
360    async fn health_check(&self) -> RsllmResult<bool>;
361}
362
363#[async_trait]
364impl AsyncClient for Client {
365    async fn chat_completion(
366        &self,
367        messages: Vec<ChatMessage>,
368    ) -> RsllmResult<ChatResponse> {
369        self.chat_completion(messages).await
370    }
371    
372    async fn chat_completion_stream(
373        &self,
374        messages: Vec<ChatMessage>,
375    ) -> RsllmResult<ChatStream> {
376        self.chat_completion_stream(messages).await
377    }
378    
379    async fn health_check(&self) -> RsllmResult<bool> {
380        self.health_check().await
381    }
382}
383
384/// Client pool for managing multiple clients
385pub struct ClientPool {
386    clients: HashMap<String, Arc<Client>>,
387    default_client: Option<String>,
388}
389
390impl ClientPool {
391    /// Create a new client pool
392    pub fn new() -> Self {
393        Self {
394            clients: HashMap::new(),
395            default_client: None,
396        }
397    }
398    
399    /// Add a client to the pool
400    pub fn add_client(&mut self, name: impl Into<String>, client: Client) {
401        let name = name.into();
402        let is_first = self.clients.is_empty();
403        
404        self.clients.insert(name.clone(), Arc::new(client));
405        
406        if is_first {
407            self.default_client = Some(name);
408        }
409    }
410    
411    /// Get a client by name
412    pub fn get_client(&self, name: &str) -> Option<&Arc<Client>> {
413        self.clients.get(name)
414    }
415    
416    /// Get the default client
417    pub fn default_client(&self) -> Option<&Arc<Client>> {
418        self.default_client.as_ref().and_then(|name| self.get_client(name))
419    }
420    
421    /// Set the default client
422    pub fn set_default(&mut self, name: impl Into<String>) -> RsllmResult<()> {
423        let name = name.into();
424        if self.clients.contains_key(&name) {
425            self.default_client = Some(name);
426            Ok(())
427        } else {
428            Err(RsllmError::not_found(format!("Client '{}'", name)))
429        }
430    }
431    
432    /// List all client names
433    pub fn client_names(&self) -> Vec<&String> {
434        self.clients.keys().collect()
435    }
436    
437    /// Remove a client
438    pub fn remove_client(&mut self, name: &str) -> Option<Arc<Client>> {
439        let removed = self.clients.remove(name);
440        
441        // Update default if we removed it
442        if self.default_client.as_deref() == Some(name) {
443            self.default_client = self.clients.keys().next().cloned();
444        }
445        
446        removed
447    }
448}
449
450impl Default for ClientPool {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;

    /// Every builder setter must land in the matching configuration field.
    #[test]
    fn test_client_builder() {
        // The builder is consumed here, so the config can be moved out
        // directly; no clone is needed.
        let config = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .model("gpt-4")
            .temperature(0.7)
            .max_tokens(1000)
            .timeout(std::time::Duration::from_secs(30))
            .header("Custom-Header", "value")
            .config;

        assert_eq!(config.provider.provider, Provider::OpenAI);
        assert_eq!(config.model.model, "gpt-4");
        assert_eq!(config.model.temperature, Some(0.7));
        assert_eq!(config.model.max_tokens, Some(1000));
        assert_eq!(config.http.timeout, std::time::Duration::from_secs(30));
        assert!(config.headers.contains_key("Custom-Header"));
    }

    /// An empty pool has no names and no default, and operations on unknown
    /// names fail cleanly without changing that.
    #[test]
    fn test_client_pool() {
        let mut pool = ClientPool::new();

        // Note: real clients would fail to build without proper API keys,
        // so only the empty-pool and missing-name paths are exercised here.
        assert_eq!(pool.client_names().len(), 0);
        assert!(pool.default_client().is_none());
        assert!(pool.get_client("missing").is_none());

        assert!(pool.set_default("missing").is_err());
        assert!(pool.default_client().is_none());

        assert!(pool.remove_client("missing").is_none());
        assert!(pool.default_client().is_none());
    }

    /// Building must not panic; when it succeeds the client must carry the
    /// provider that was requested.
    #[test]
    fn test_message_validation() {
        let result = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .api_key("test-key")
            .build();

        // Whether the build succeeds depends on which provider features are
        // compiled in, so only the success path has something to check.
        if let Ok(client) = result {
            assert_eq!(client.config().provider.provider, Provider::OpenAI);
        }
    }
}
501}