// cognis_llm/client.rs
1//! User-facing `Client`. Holds an `Arc<dyn LLMProvider>` and dispatches
2//! through it. Implements `Runnable<Vec<Message>, Message>` so it composes
3//! inside graphs.
4
5use std::str::FromStr;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use cognis_core::{CognisError, Result, Runnable, RunnableConfig, RunnableStream};
11
12use crate::chat::{ChatOptions, ChatResponse, StreamChunk};
13use crate::provider::{LLMProvider, Provider};
14use crate::tools::ToolDefinition;
15use crate::Message;
16
/// Top-level LLM client. Cheap to clone: the only field is an `Arc`, so
/// `Clone` is a reference-count bump, not a copy of provider state.
#[derive(Clone)]
pub struct Client {
    // Type-erased backend (OpenAI, Anthropic, Ollama, ...); all calls dispatch
    // through this trait object.
    provider: Arc<dyn LLMProvider>,
}
22
23impl std::fmt::Debug for Client {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("Client")
26            .field("provider", &self.provider.name())
27            .finish()
28    }
29}
30
31impl Client {
32    /// Wrap any `LLMProvider`.
33    pub fn new(provider: Arc<dyn LLMProvider>) -> Self {
34        Self { provider }
35    }
36
37    /// Fluent builder.
38    pub fn builder() -> ClientBuilder {
39        ClientBuilder::default()
40    }
41
42    /// Build from env vars. Provider-namespaced with fallback:
43    /// `COGNIS_OPENAI_API_KEY` overrides `COGNIS_API_KEY`.
44    pub fn from_env() -> Result<Self> {
45        let provider_str = std::env::var("COGNIS_PROVIDER")
46            .map_err(|_| CognisError::Configuration("COGNIS_PROVIDER not set".into()))?;
47        let provider = Provider::from_str(&provider_str)?;
48        let mut b = Self::builder().provider(provider);
49
50        let key = std::env::var(format!(
51            "COGNIS_{}_API_KEY",
52            provider.to_string().to_uppercase()
53        ))
54        .or_else(|_| std::env::var("COGNIS_API_KEY"))
55        .ok();
56        if let Some(k) = key {
57            b = b.api_key(k);
58        }
59
60        let url = std::env::var(format!(
61            "COGNIS_{}_BASE_URL",
62            provider.to_string().to_uppercase()
63        ))
64        .or_else(|_| std::env::var("COGNIS_BASE_URL"))
65        .ok();
66        if let Some(u) = url {
67            b = b.base_url(u);
68        }
69
70        let model = std::env::var(format!(
71            "COGNIS_{}_MODEL",
72            provider.to_string().to_uppercase()
73        ))
74        .or_else(|_| std::env::var("COGNIS_MODEL"))
75        .ok();
76        if let Some(m) = model {
77            b = b.model(m);
78        }
79
80        b.build()
81    }
82
83    /// One-shot chat completion (no tools).
84    pub async fn invoke(&self, messages: Vec<Message>) -> Result<Message> {
85        Ok(self
86            .provider
87            .chat_completion(messages, ChatOptions::default())
88            .await?
89            .message)
90    }
91
92    /// Streaming chat completion.
93    pub async fn stream(&self, messages: Vec<Message>) -> Result<RunnableStream<StreamChunk>> {
94        self.provider
95            .chat_completion_stream(messages, ChatOptions::default())
96            .await
97    }
98
99    /// Chat completion with tool definitions.
100    pub async fn invoke_with_tools(
101        &self,
102        messages: Vec<Message>,
103        tools: &[Arc<dyn crate::tools::Tool>],
104    ) -> Result<Message> {
105        let defs: Vec<ToolDefinition> = tools
106            .iter()
107            .map(|t| ToolDefinition::from_tool(t.as_ref()))
108            .collect();
109        Ok(self
110            .provider
111            .chat_completion_with_tools(messages, defs, ChatOptions::default())
112            .await?
113            .message)
114    }
115
116    /// Provider-level full chat completion (with all options).
117    pub async fn chat(&self, messages: Vec<Message>, opts: ChatOptions) -> Result<ChatResponse> {
118        self.provider.chat_completion(messages, opts).await
119    }
120
121    /// Underlying provider.
122    pub fn provider(&self) -> &Arc<dyn LLMProvider> {
123        &self.provider
124    }
125}
126
127#[async_trait]
128impl Runnable<Vec<Message>, Message> for Client {
129    async fn invoke(&self, input: Vec<Message>, _: RunnableConfig) -> Result<Message> {
130        Client::invoke(self, input).await
131    }
132    fn name(&self) -> &str {
133        "Client"
134    }
135}
136
/// Fluent builder for `Client`. All fields are optional except the provider,
/// which `build()` requires; unset fields fall back to each provider
/// builder's own defaults.
#[derive(Default)]
pub struct ClientBuilder {
    // Required: which backend to construct.
    provider: Option<Provider>,
    // Credential; unused by providers that need none (e.g. Ollama).
    api_key: Option<String>,
    // Overrides the provider's default endpoint.
    base_url: Option<String>,
    // Overrides the provider's default model.
    model: Option<String>,
    // Request timeout, forwarded to every provider builder.
    timeout_secs: Option<u64>,
    // OpenAI-only organization header.
    organization: Option<String>,
    // Azure-only connection settings (endpoint / deployment / API version).
    azure_endpoint: Option<String>,
    azure_deployment: Option<String>,
    azure_api_version: Option<String>,
}
150
impl ClientBuilder {
    /// Provider variant.
    pub fn provider(mut self, p: Provider) -> Self {
        self.provider = Some(p);
        self
    }
    /// API key.
    pub fn api_key(mut self, k: impl Into<String>) -> Self {
        self.api_key = Some(k.into());
        self
    }
    /// Base URL override.
    pub fn base_url(mut self, u: impl Into<String>) -> Self {
        self.base_url = Some(u.into());
        self
    }
    /// Model.
    pub fn model(mut self, m: impl Into<String>) -> Self {
        self.model = Some(m.into());
        self
    }
    /// Timeout in seconds.
    pub fn timeout_secs(mut self, s: u64) -> Self {
        self.timeout_secs = Some(s);
        self
    }
    /// OpenAI organization (only used for OpenAI provider).
    pub fn organization(mut self, o: impl Into<String>) -> Self {
        self.organization = Some(o.into());
        self
    }
    /// Azure resource endpoint (e.g. `https://my-resource.openai.azure.com/`).
    pub fn azure_endpoint(mut self, e: impl Into<String>) -> Self {
        self.azure_endpoint = Some(e.into());
        self
    }
    /// Azure deployment name.
    pub fn azure_deployment(mut self, d: impl Into<String>) -> Self {
        self.azure_deployment = Some(d.into());
        self
    }
    /// Azure API version.
    pub fn azure_api_version(mut self, v: impl Into<String>) -> Self {
        self.azure_api_version = Some(v.into());
        self
    }
    /// Construct the Client.
    ///
    /// # Errors
    /// - no provider was set;
    /// - the selected provider's feature flag was not compiled in;
    /// - the provider-specific builder rejects the configuration.
    pub fn build(self) -> Result<Client> {
        let provider = self
            .provider
            .ok_or_else(|| CognisError::Configuration("Client: provider required".into()))?;
        // Each arm is feature-gated, so only compiled-in providers are
        // constructible; everything else lands in the fallback arm below.
        // Only fields the builder actually set are forwarded, letting each
        // provider builder keep its own defaults.
        let arc_provider: Arc<dyn LLMProvider> = match provider {
            #[cfg(feature = "openai")]
            Provider::OpenAI => {
                use crate::provider::openai::OpenAIBuilder;
                let mut b = OpenAIBuilder::default();
                if let Some(k) = self.api_key {
                    b = b.api_key(k);
                }
                if let Some(u) = self.base_url {
                    b = b.base_url(u);
                }
                if let Some(m) = self.model {
                    b = b.model(m);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                if let Some(o) = self.organization {
                    b = b.organization(o);
                }
                Arc::new(b.build()?)
            }
            #[cfg(feature = "openai")]
            Provider::OpenRouter => {
                // OpenRouter reuses the OpenAI builder, pre-seeded with
                // OpenRouter's base URL and default model; explicit overrides
                // below still win. NOTE(review): `organization` is not
                // forwarded here — presumably OpenAI-only; confirm intended.
                use crate::provider::openai::OpenAIBuilder;
                let mut b = OpenAIBuilder::default()
                    .base_url(Provider::OpenRouter.default_base_url())
                    .model(Provider::OpenRouter.default_model());
                if let Some(k) = self.api_key {
                    b = b.api_key(k);
                }
                if let Some(u) = self.base_url {
                    b = b.base_url(u);
                }
                if let Some(m) = self.model {
                    b = b.model(m);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                Arc::new(b.build()?)
            }
            #[cfg(feature = "ollama")]
            Provider::Ollama => {
                // Local provider: no API key is forwarded.
                use crate::provider::ollama::OllamaBuilder;
                let mut b = OllamaBuilder::default();
                if let Some(u) = self.base_url {
                    b = b.base_url(u);
                }
                if let Some(m) = self.model {
                    b = b.model(m);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                Arc::new(b.build()?)
            }
            #[cfg(feature = "anthropic")]
            Provider::Anthropic => {
                use crate::provider::anthropic::AnthropicBuilder;
                let mut b = AnthropicBuilder::default();
                if let Some(k) = self.api_key {
                    b = b.api_key(k);
                }
                if let Some(u) = self.base_url {
                    b = b.base_url(u);
                }
                if let Some(m) = self.model {
                    b = b.model(m);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                Arc::new(b.build()?)
            }
            #[cfg(feature = "google")]
            Provider::Google => {
                use crate::provider::google::GoogleBuilder;
                let mut b = GoogleBuilder::default();
                if let Some(k) = self.api_key {
                    b = b.api_key(k);
                }
                if let Some(u) = self.base_url {
                    b = b.base_url(u);
                }
                if let Some(m) = self.model {
                    b = b.model(m);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                Arc::new(b.build()?)
            }
            #[cfg(feature = "azure")]
            Provider::Azure => {
                // Azure addresses models via endpoint + deployment + API
                // version rather than `base_url`/`model`, which are unused
                // in this arm.
                use crate::provider::azure::AzureBuilder;
                let mut b = AzureBuilder::default();
                if let Some(e) = self.azure_endpoint {
                    b = b.endpoint(e);
                }
                if let Some(d) = self.azure_deployment {
                    b = b.deployment(d);
                }
                if let Some(v) = self.azure_api_version {
                    b = b.api_version(v);
                }
                if let Some(k) = self.api_key {
                    b = b.api_key(k);
                }
                if let Some(t) = self.timeout_secs {
                    b = b.timeout_secs(t);
                }
                Arc::new(b.build()?)
            }
            // Reached only when the requested provider's feature flag was not
            // enabled at compile time; with every feature on, the match above
            // is exhaustive, hence the allow.
            #[allow(unreachable_patterns)]
            other => {
                return Err(CognisError::Configuration(format!(
                    "provider `{other}` not compiled in (enable the matching feature flag)"
                )))
            }
        };
        Ok(Client {
            provider: arc_provider,
        })
    }
}
328
#[cfg(test)]
mod tests {
    use super::*;

    // Per-provider round trips: build through the fluent API and confirm the
    // expected backend was selected. Each is gated on its feature flag.

    #[cfg(feature = "openai")]
    #[test]
    fn openai_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::OpenAI)
            .api_key("sk-test")
            .model("gpt-4o")
            .build()
            .unwrap();
        assert_eq!(client.provider().name(), "openai");
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn ollama_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Ollama)
            .model("llama3.2")
            .build()
            .unwrap();
        assert_eq!(client.provider().name(), "ollama");
    }

    // Omitting the provider must fail with a configuration error.
    #[test]
    fn missing_provider_errors() {
        let err = ClientBuilder::default().build().unwrap_err();
        let rendered = format!("{err}");
        assert!(rendered.contains("provider required"));
    }

    #[cfg(feature = "anthropic")]
    #[test]
    fn anthropic_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Anthropic)
            .api_key("sk-ant-test")
            .build()
            .unwrap();
        assert_eq!(client.provider().name(), "anthropic");
    }

    #[cfg(feature = "google")]
    #[test]
    fn google_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Google)
            .api_key("AIza-test")
            .build()
            .unwrap();
        assert_eq!(client.provider().name(), "google");
    }

    #[cfg(feature = "azure")]
    #[test]
    fn azure_builder_round_trip() {
        let client = Client::builder()
            .provider(Provider::Azure)
            .azure_endpoint("https://r.openai.azure.com/")
            .azure_deployment("gpt-4o")
            .api_key("k")
            .build()
            .unwrap();
        assert_eq!(client.provider().name(), "azure");
    }
}