// rexis_llm/client.rs

1//! # RSLLM Client
2//!
3//! High-level client interface for RSLLM with multi-provider support.
4//! Provides unified API for chat completions, embeddings, and streaming.
5
6use crate::{
7    ChatMessage, ChatResponse, ChatStream, ClientConfig, EmbeddingResponse, Provider, RsllmError,
8    RsllmResult,
9};
10
11#[cfg(feature = "openai")]
12use crate::provider::OpenAIProvider;
13
14#[cfg(feature = "ollama")]
15use crate::provider::OllamaProvider;
16
17use crate::provider::LLMProvider;
18use async_trait::async_trait;
19use std::collections::HashMap;
20use std::sync::Arc;
21
/// High-level RSLLM client.
///
/// Wraps a configured [`LLMProvider`] behind a unified API for chat
/// completions (plain, tool-calling, and streaming). Construct with
/// [`Client::new`], [`Client::builder`], or [`Client::from_env`].
pub struct Client {
    /// Client configuration: provider selection, model defaults, HTTP settings
    config: ClientConfig,

    /// Shared provider instance that performs the actual API calls
    provider: Arc<dyn LLMProvider>,

    /// Arbitrary client-level metadata, set via `add_metadata`
    metadata: HashMap<String, serde_json::Value>,
}
33
34impl Client {
35    /// Create a new client with configuration
36    pub fn new(config: ClientConfig) -> RsllmResult<Self> {
37        config.validate()?;
38
39        let provider = Self::create_provider(&config)?;
40
41        Ok(Self {
42            config,
43            provider,
44            metadata: HashMap::new(),
45        })
46    }
47
48    /// Create a client builder
49    pub fn builder() -> ClientBuilder {
50        ClientBuilder::new()
51    }
52
53    /// Create a client from environment variables
54    pub fn from_env() -> RsllmResult<Self> {
55        let config = ClientConfig::from_env()?;
56        Self::new(config)
57    }
58
59    /// Create provider instance based on configuration
60    fn create_provider(config: &ClientConfig) -> RsllmResult<Arc<dyn LLMProvider>> {
61        match config.provider.provider {
62            #[cfg(feature = "openai")]
63            Provider::OpenAI => {
64                let api_key = config
65                    .provider
66                    .api_key
67                    .as_ref()
68                    .ok_or_else(|| RsllmError::configuration("OpenAI API key required"))?;
69
70                let provider = OpenAIProvider::new(
71                    api_key.clone(),
72                    config.provider.base_url.clone(),
73                    config.provider.organization_id.clone(),
74                )?;
75
76                Ok(Arc::new(provider))
77            }
78
79            #[cfg(feature = "ollama")]
80            Provider::Ollama => {
81                let provider = OllamaProvider::new(config.provider.base_url.clone())?;
82                Ok(Arc::new(provider))
83            }
84
85            #[cfg(feature = "claude")]
86            Provider::Claude => {
87                // Claude provider implementation would go here
88                Err(RsllmError::configuration(
89                    "Claude provider not yet implemented",
90                ))
91            }
92
93            #[allow(unreachable_patterns)]
94            _ => Err(RsllmError::configuration(format!(
95                "Provider {:?} not supported in current build",
96                config.provider.provider
97            ))),
98        }
99    }
100
101    /// Get client configuration
102    pub fn config(&self) -> &ClientConfig {
103        &self.config
104    }
105
106    /// Get provider instance
107    pub fn provider(&self) -> &Arc<dyn LLMProvider> {
108        &self.provider
109    }
110
111    /// Add client metadata
112    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
113        self.metadata.insert(key.into(), value);
114    }
115
116    /// Get client metadata
117    pub fn metadata(&self) -> &HashMap<String, serde_json::Value> {
118        &self.metadata
119    }
120
121    /// Health check for the underlying provider
122    pub async fn health_check(&self) -> RsllmResult<bool> {
123        self.provider.health_check().await
124    }
125
126    /// Get supported models from the provider
127    pub fn supported_models(&self) -> Vec<String> {
128        self.provider.supported_models()
129    }
130
131    /// Chat completion (non-streaming)
132    pub async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse> {
133        self.chat_completion_with_options(messages, None, None, None)
134            .await
135    }
136
137    /// Chat completion with custom options
138    pub async fn chat_completion_with_options(
139        &self,
140        messages: Vec<ChatMessage>,
141        model: Option<&str>,
142        temperature: Option<f32>,
143        max_tokens: Option<u32>,
144    ) -> RsllmResult<ChatResponse> {
145        // Validate messages
146        if messages.is_empty() {
147            return Err(RsllmError::validation(
148                "messages",
149                "Messages cannot be empty",
150            ));
151        }
152
153        // Use configured model if not specified
154        let model = model.unwrap_or(&self.config.model.model);
155
156        // Use configured temperature if not specified
157        let temperature = temperature.or(self.config.model.temperature);
158
159        // Use configured max_tokens if not specified
160        let max_tokens = max_tokens.or(self.config.model.max_tokens);
161
162        self.provider
163            .chat_completion(messages, Some(model), temperature, max_tokens)
164            .await
165    }
166
167    /// Chat completion with tool calling support
168    pub async fn chat_completion_with_tools(
169        &self,
170        messages: Vec<ChatMessage>,
171        tools: Vec<crate::tools::ToolDefinition>,
172    ) -> RsllmResult<ChatResponse> {
173        self.chat_completion_with_tools_and_options(messages, tools, None, None, None)
174            .await
175    }
176
177    /// Chat completion with tools and custom options
178    pub async fn chat_completion_with_tools_and_options(
179        &self,
180        messages: Vec<ChatMessage>,
181        tools: Vec<crate::tools::ToolDefinition>,
182        model: Option<&str>,
183        temperature: Option<f32>,
184        max_tokens: Option<u32>,
185    ) -> RsllmResult<ChatResponse> {
186        // Validate messages
187        if messages.is_empty() {
188            return Err(RsllmError::validation(
189                "messages",
190                "Messages cannot be empty",
191            ));
192        }
193
194        // Use configured model if not specified
195        let model = model.unwrap_or(&self.config.model.model);
196
197        // Use configured temperature if not specified
198        let temperature = temperature.or(self.config.model.temperature);
199
200        // Use configured max_tokens if not specified
201        let max_tokens = max_tokens.or(self.config.model.max_tokens);
202
203        self.provider
204            .chat_completion_with_tools(messages, tools, Some(model), temperature, max_tokens)
205            .await
206    }
207
208    /// Chat completion (streaming)
209    pub async fn chat_completion_stream(
210        &self,
211        messages: Vec<ChatMessage>,
212    ) -> RsllmResult<ChatStream> {
213        self.chat_completion_stream_with_options(messages, None, None, None)
214            .await
215    }
216
217    /// Chat completion streaming with custom options
218    pub async fn chat_completion_stream_with_options(
219        &self,
220        messages: Vec<ChatMessage>,
221        model: Option<&str>,
222        temperature: Option<f32>,
223        max_tokens: Option<u32>,
224    ) -> RsllmResult<ChatStream> {
225        // Validate messages
226        if messages.is_empty() {
227            return Err(RsllmError::validation(
228                "messages",
229                "Messages cannot be empty",
230            ));
231        }
232
233        // Use configured model if not specified
234        let model = model.unwrap_or(&self.config.model.model);
235
236        // Use configured temperature if not specified
237        let temperature = temperature.or(self.config.model.temperature);
238
239        // Use configured max_tokens if not specified
240        let max_tokens = max_tokens.or(self.config.model.max_tokens);
241
242        let stream = self
243            .provider
244            .chat_completion_stream(messages, Some(model.to_string()), temperature, max_tokens)
245            .await?;
246
247        // Convert Box<dyn Stream + Unpin> to Pin<Box<dyn Stream>>
248        Ok(Box::pin(stream) as ChatStream)
249    }
250
251    /// Simple text completion
252    pub async fn complete(&self, prompt: impl Into<String>) -> RsllmResult<String> {
253        let messages = vec![ChatMessage::user(prompt.into())];
254        let response = self.chat_completion(messages).await?;
255        Ok(response.content)
256    }
257
258    /// Simple streaming text completion
259    pub async fn complete_stream(&self, prompt: impl Into<String>) -> RsllmResult<ChatStream> {
260        let messages = vec![ChatMessage::user(prompt.into())];
261        self.chat_completion_stream(messages).await
262    }
263
264    /// Create embeddings (placeholder - would need provider support)
265    pub async fn create_embeddings(&self, _inputs: Vec<String>) -> RsllmResult<EmbeddingResponse> {
266        // TODO: Implement embeddings support in providers
267        Err(RsllmError::configuration("Embeddings not yet implemented"))
268    }
269
270    /// Count tokens in text (placeholder - would need tokenizer)
271    pub fn count_tokens(&self, _text: &str) -> RsllmResult<u32> {
272        // TODO: Implement tokenization
273        Err(RsllmError::configuration(
274            "Token counting not yet implemented",
275        ))
276    }
277}
278
279impl std::fmt::Debug for Client {
280    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281        f.debug_struct("Client")
282            .field("provider_type", &self.provider.provider_type())
283            .field("model", &self.config.model.model)
284            .finish()
285    }
286}
287
/// Client builder for fluent configuration.
///
/// Accumulates settings into a [`ClientConfig`] and produces a [`Client`]
/// via [`ClientBuilder::build`].
pub struct ClientBuilder {
    /// Configuration under construction; starts from `ClientConfig::default()`
    config: ClientConfig,
}
292
293impl ClientBuilder {
294    /// Create a new client builder
295    pub fn new() -> Self {
296        Self {
297            config: ClientConfig::default(),
298        }
299    }
300
301    /// Set the provider
302    pub fn provider(mut self, provider: Provider) -> Self {
303        self.config.provider.provider = provider;
304        self
305    }
306
307    /// Set the API key
308    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
309        self.config.provider.api_key = Some(api_key.into());
310        self
311    }
312
313    /// Set the base URL
314    pub fn base_url(mut self, base_url: impl AsRef<str>) -> RsllmResult<Self> {
315        self.config.provider.base_url = Some(base_url.as_ref().parse()?);
316        Ok(self)
317    }
318
319    /// Set the organization ID
320    pub fn organization_id(mut self, org_id: impl Into<String>) -> Self {
321        self.config.provider.organization_id = Some(org_id.into());
322        self
323    }
324
325    /// Set the model
326    pub fn model(mut self, model: impl Into<String>) -> Self {
327        self.config.model.model = model.into();
328        self
329    }
330
331    /// Set the temperature
332    pub fn temperature(mut self, temperature: f32) -> Self {
333        self.config.model.temperature = Some(temperature);
334        self
335    }
336
337    /// Set max tokens
338    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
339        self.config.model.max_tokens = Some(max_tokens);
340        self
341    }
342
343    /// Enable streaming
344    pub fn stream(mut self, stream: bool) -> Self {
345        self.config.model.stream = stream;
346        self
347    }
348
349    /// Set timeout
350    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
351        self.config.http.timeout = timeout;
352        self
353    }
354
355    /// Add a custom header
356    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
357        self.config.headers.insert(key.into(), value.into());
358        self
359    }
360
361    /// Set retry configuration
362    pub fn max_retries(mut self, max_retries: u32) -> Self {
363        self.config.retry.max_retries = max_retries;
364        self
365    }
366
367    /// Build the client
368    pub fn build(self) -> RsllmResult<Client> {
369        Client::new(self.config)
370    }
371}
372
373impl Default for ClientBuilder {
374    fn default() -> Self {
375        Self::new()
376    }
377}
378
/// Async client trait for custom implementations.
///
/// Implemented by [`Client`]; lets callers abstract over alternative client
/// implementations (e.g. mocks or wrappers) behind a trait object.
#[async_trait]
pub trait AsyncClient: Send + Sync {
    /// Chat completion (non-streaming).
    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse>;

    /// Chat completion returning a [`ChatStream`].
    async fn chat_completion_stream(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatStream>;

    /// Check whether the underlying provider is reachable and healthy.
    async fn health_check(&self) -> RsllmResult<bool>;
}
391
392#[async_trait]
393impl AsyncClient for Client {
394    async fn chat_completion(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatResponse> {
395        self.chat_completion(messages).await
396    }
397
398    async fn chat_completion_stream(&self, messages: Vec<ChatMessage>) -> RsllmResult<ChatStream> {
399        self.chat_completion_stream(messages).await
400    }
401
402    async fn health_check(&self) -> RsllmResult<bool> {
403        self.health_check().await
404    }
405}
406
/// Client pool for managing multiple clients by name.
///
/// The first client added becomes the default until overridden with
/// `set_default` or removed.
pub struct ClientPool {
    // Clients keyed by caller-supplied name
    clients: HashMap<String, Arc<Client>>,
    // Name of the default client; when Some, it is a key of `clients`
    default_client: Option<String>,
}
412
413impl ClientPool {
414    /// Create a new client pool
415    pub fn new() -> Self {
416        Self {
417            clients: HashMap::new(),
418            default_client: None,
419        }
420    }
421
422    /// Add a client to the pool
423    pub fn add_client(&mut self, name: impl Into<String>, client: Client) {
424        let name = name.into();
425        let is_first = self.clients.is_empty();
426
427        self.clients.insert(name.clone(), Arc::new(client));
428
429        if is_first {
430            self.default_client = Some(name);
431        }
432    }
433
434    /// Get a client by name
435    pub fn get_client(&self, name: &str) -> Option<&Arc<Client>> {
436        self.clients.get(name)
437    }
438
439    /// Get the default client
440    pub fn default_client(&self) -> Option<&Arc<Client>> {
441        self.default_client
442            .as_ref()
443            .and_then(|name| self.get_client(name))
444    }
445
446    /// Set the default client
447    pub fn set_default(&mut self, name: impl Into<String>) -> RsllmResult<()> {
448        let name = name.into();
449        if self.clients.contains_key(&name) {
450            self.default_client = Some(name);
451            Ok(())
452        } else {
453            Err(RsllmError::not_found(format!("Client '{}'", name)))
454        }
455    }
456
457    /// List all client names
458    pub fn client_names(&self) -> Vec<&String> {
459        self.clients.keys().collect()
460    }
461
462    /// Remove a client
463    pub fn remove_client(&mut self, name: &str) -> Option<Arc<Client>> {
464        let removed = self.clients.remove(name);
465
466        // Update default if we removed it
467        if self.default_client.as_deref() == Some(name) {
468            self.default_client = self.clients.keys().next().cloned();
469        }
470
471        removed
472    }
473}
474
475impl Default for ClientPool {
476    fn default() -> Self {
477        Self::new()
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE: `MessageRole` was imported but never used; dropped to silence
    // the unused-import warning.
    use crate::Provider;

    /// Builder setters must land in the corresponding config fields.
    #[test]
    fn test_client_builder() {
        let config = ClientBuilder::new()
            .provider(Provider::OpenAI)
            .model("gpt-4")
            .temperature(0.7)
            .max_tokens(1000)
            .timeout(std::time::Duration::from_secs(30))
            .header("Custom-Header", "value")
            .config
            .clone();

        assert_eq!(config.provider.provider, Provider::OpenAI);
        assert_eq!(config.model.model, "gpt-4");
        assert_eq!(config.model.temperature, Some(0.7));
        assert_eq!(config.model.max_tokens, Some(1000));
        assert_eq!(config.http.timeout, std::time::Duration::from_secs(30));
        assert!(config.headers.contains_key("Custom-Header"));
    }

    /// An empty pool has no clients, no default, and misses on lookup.
    #[test]
    fn test_client_pool() {
        // Building real clients needs provider credentials, so only the
        // empty-pool invariants are exercised here.
        let pool = ClientPool::new();

        assert!(pool.client_names().is_empty());
        assert!(pool.default_client().is_none());
        assert!(pool.get_client("missing").is_none());
    }

    /// Building with OpenAI selected may succeed or fail depending on which
    /// provider features are enabled; either way it must not panic, and a
    /// failure must carry a non-empty diagnostic.
    ///
    /// (Replaces a tautological `is_err() || is_ok()` assertion that could
    /// never fail.)
    #[test]
    fn test_message_validation() {
        match ClientBuilder::new()
            .provider(Provider::OpenAI)
            .api_key("test-key")
            .build()
        {
            Ok(_) => {}
            Err(err) => assert!(!format!("{:?}", err).is_empty()),
        }
    }
}